From 98eaae9436adf63ec3023ee077f2fff8e23dfa35 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 14 Nov 2017 18:41:00 +0100 Subject: [PATCH 1/9] [SPARK-22520][SQL] Support code generation for large CaseWhen --- .../expressions/EquivalentExpressions.scala | 3 +- .../expressions/conditionalExpressions.scala | 152 +++++++----------- .../sql/catalyst/optimizer/Optimizer.scala | 2 - .../sql/catalyst/optimizer/expressions.scala | 15 -- .../apache/spark/sql/internal/SQLConf.scala | 8 - .../expressions/CodeGenerationSuite.scala | 2 +- .../optimizer/OptimizeCodegenSuite.scala | 101 ------------ .../FlatMapGroupsWithState_StateManager.scala | 2 +- .../spark/sql/internal/SQLConfSuite.scala | 29 ---- 9 files changed, 64 insertions(+), 250 deletions(-) delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index f8644c2cd672c..8d06804ce1e90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -87,8 +87,7 @@ class EquivalentExpressions { def childrenToRecurse: Seq[Expression] = expr match { case _: CodegenFallback => Nil case i: If => i.predicate :: Nil - // `CaseWhen` implements `CodegenFallback`, we only need to handle `CaseWhenCodegen` here. - case c: CaseWhenCodegen => c.children.head :: Nil + case c: CaseWhen => c.children.head :: Nil case c: Coalesce => c.children.head :: Nil case other => other.children } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index d95b59d5ec423..b936f26581034 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -141,14 +141,34 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } /** - * Abstract parent class for common logic in CaseWhen and CaseWhenCodegen. + * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". + * When a = true, returns b; when c = true, returns d; else returns e. * * @param branches seq of (branch condition, branch value) * @param elseValue optional value for the else branch */ -abstract class CaseWhenBase( +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END - When `expr1` = true, returns `expr2`; else when `expr3` = true, returns `expr4`; else returns `expr5`.", + arguments = """ + Arguments: + * expr1, expr3 - the branch condition expressions should all be boolean type. + * expr2, expr4, expr5 - the branch value expressions and else value expression should all be + same type or coercible to a common type. + """, + examples = """ + Examples: + > SELECT CASE WHEN 1 > 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END; + 1 + > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END; + 2 + > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 ELSE null END; + NULL + """) +// scalastyle:on line.size.limit +case class CaseWhen( branches: Seq[(Expression, Expression)], - elseValue: Option[Expression]) + elseValue: Option[Expression] = None) extends Expression with Serializable { override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue @@ -211,111 +231,61 @@ abstract class CaseWhenBase( val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("") "CASE" + cases + elseCase + " END" } -} - - -/** - * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". - * When a = true, returns b; when c = true, returns d; else returns e. - * - * @param branches seq of (branch condition, branch value) - * @param elseValue optional value for the else branch - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END - When `expr1` = true, returns `expr2`; else when `expr3` = true, returns `expr4`; else returns `expr5`.", - arguments = """ - Arguments: - * expr1, expr3 - the branch condition expressions should all be boolean type. - * expr2, expr4, expr5 - the branch value expressions and else value expression should all be - same type or coercible to a common type. - """, - examples = """ - Examples: - > SELECT CASE WHEN 1 > 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END; - 1 - > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END; - 2 - > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 ELSE null END; - NULL - """) -// scalastyle:on line.size.limit -case class CaseWhen( - val branches: Seq[(Expression, Expression)], - val elseValue: Option[Expression] = None) - extends CaseWhenBase(branches, elseValue) with CodegenFallback with Serializable { - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - super[CodegenFallback].doGenCode(ctx, ev) - } - - def toCodegen(): CaseWhenCodegen = { - CaseWhenCodegen(branches, elseValue) - } -} - -/** - * CaseWhen expression used when code generation condition is satisfied. - * OptimizeCodegen optimizer replaces CaseWhen into CaseWhenCodegen. - * - * @param branches seq of (branch condition, branch value) - * @param elseValue optional value for the else branch - */ -case class CaseWhenCodegen( - val branches: Seq[(Expression, Expression)], - val elseValue: Option[Expression] = None) - extends CaseWhenBase(branches, elseValue) with Serializable { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - // Generate code that looks like: - // - // condA = ... - // if (condA) { - // valueA - // } else { - // condB = ... - // if (condB) { - // valueB - // } else { - // condC = ... - // if (condC) { - // valueC - // } else { - // elseValue - // } - // } - // } + val conditionMet = ctx.freshName("caseWhenConditionMet") + ctx.addMutableState("boolean", ev.isNull, "") + ctx.addMutableState(ctx.javaType(dataType), ev.value, "") val cases = branches.map { case (condExpr, valueExpr) => val cond = condExpr.genCode(ctx) val res = valueExpr.genCode(ctx) s""" - ${cond.code} - if (!${cond.isNull} && ${cond.value}) { - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.value} = ${res.value}; + if(!$conditionMet) { + ${cond.code} + if (!${cond.isNull} && ${cond.value}) { + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.value} = ${res.value}; + $conditionMet = true; + } } """ } - var generatedCode = cases.mkString("", "\nelse {\n", "\nelse {\n") - - elseValue.foreach { elseExpr => + val elseCode = elseValue.map { elseExpr => val res = elseExpr.genCode(ctx) - generatedCode += - s""" + s""" + if(!$conditionMet) { ${res.code} ${ev.isNull} = ${res.isNull}; ${ev.value} = ${res.value}; - """ - } + } + """ + }.getOrElse("") - generatedCode += "}\n" * cases.size + val casesCode = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { + cases.mkString("\n") + } else { + ctx.splitExpressions(cases, "caseWhen", + ("InternalRow", ctx.INPUT_ROW) :: ("boolean", conditionMet) :: Nil, returnType = "boolean", + makeSplitFunction = { + func => + s""" + $func + return $conditionMet; + """ + }, + foldFunctions = { funcCalls => + funcCalls.map(funcCall => s"$conditionMet = $funcCall;").mkString("\n") + }) + } ev.copy(code = s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $generatedCode""") + ${ev.isNull} = true; + ${ev.value} = ${ctx.defaultValue(dataType)}; + boolean $conditionMet = false; + $casesCode + $elseCode""") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3a3ccd5ff5e60..0d961bf2e6e5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -138,8 +138,6 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) // The following batch should be executed after batch "Join Reorder" and "LocalRelation". Batch("Check Cartesian Products", Once, CheckCartesianProducts) :: - Batch("OptimizeCodegen", Once, - OptimizeCodegen) :: Batch("RewriteSubquery", Once, RewritePredicateSubquery, CollapseProject) :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 523b53b39d6b5..785e815b41185 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -552,21 +552,6 @@ object FoldablePropagation extends Rule[LogicalPlan] { } -/** - * Optimizes expressions by replacing according to CodeGen configuration. - */ -object OptimizeCodegen extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case e: CaseWhen if canCodegen(e) => e.toCodegen() - } - - private def canCodegen(e: CaseWhen): Boolean = { - val numBranches = e.branches.size + e.elseValue.size - numBranches <= SQLConf.get.maxCaseBranchesForCodegen - } -} - - /** * Removes [[Cast Casts]] that are unnecessary because the input is already the correct type. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ede116e964a03..1c6f897716d01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -570,12 +570,6 @@ object SQLConf { .booleanConf .createWithDefault(true) - val MAX_CASES_BRANCHES = buildConf("spark.sql.codegen.maxCaseBranches") - .internal() - .doc("The maximum number of switches supported with codegen.") - .intConf - .createWithDefault(20) - val CODEGEN_LOGGING_MAX_LINES = buildConf("spark.sql.codegen.logging.maxLines") .internal() .doc("The maximum number of codegen lines to log when errors occur. Use -1 for unlimited.") @@ -1084,8 +1078,6 @@ class SQLConf extends Serializable with Logging { def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK) - def maxCaseBranchesForCodegen: Int = getConf(MAX_CASES_BRANCHES) - def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES) def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 1e6f7b65e7e72..2b61f884c7b5d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -77,7 +77,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { } test("SPARK-13242: case-when expression with large number of branches (or cases)") { - val cases = 50 + val cases = 500 val clauses = 20 // Generate an individual case diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala deleted file mode 100644 index b1157f3e3edd2..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.Literal._ -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules._ - - -class OptimizeCodegenSuite extends PlanTest { - - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen) :: Nil - } - - protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { - val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze - val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze) - comparePlans(actual, correctAnswer) - } - - test("Codegen only when the number of branches is small.") { - assertEquivalent( - CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), - CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen()) - - assertEquivalent( - CaseWhen(List.fill(100)((TrueLiteral, Literal(1))), Literal(2)), - CaseWhen(List.fill(100)((TrueLiteral, Literal(1))), Literal(2))) - } - - test("Nested CaseWhen Codegen.") { - assertEquivalent( - CaseWhen( - Seq((CaseWhen(Seq((TrueLiteral, TrueLiteral)), FalseLiteral), Literal(3))), - CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5))), - CaseWhen( - Seq((CaseWhen(Seq((TrueLiteral, TrueLiteral)), FalseLiteral).toCodegen(), Literal(3))), - CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5)).toCodegen()).toCodegen()) - } - - test("Multiple CaseWhen in one operator.") { - val plan = OneRowRelation() - .select( - CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), - CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)), - CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)), - CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6))).analyze - val correctAnswer = OneRowRelation() - .select( - CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(), - CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(), - CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)), - CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen()).analyze - val optimized = Optimize.execute(plan) - comparePlans(optimized, correctAnswer) - } - - test("Multiple CaseWhen in different operators") { - val plan = OneRowRelation() - .select( - CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), - CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)), - CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0))) - .where( - LessThan( - CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)), - CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0))) - ).analyze - val correctAnswer = OneRowRelation() - .select( - CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(), - CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(), - CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0))) - .where( - LessThan( - CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen(), - CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0))) - ).analyze - val optimized = Optimize.execute(plan) - comparePlans(optimized, correctAnswer) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala index d077836da847c..e49546830286b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala @@ -90,7 +90,7 @@ class FlatMapGroupsWithState_StateManager( val deser = stateEncoder.resolveAndBind().deserializer.transformUp { case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal) } - CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deser).toCodegen() + CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deser) } // Converters for translating state between rows and Java objects diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index 205c303b6cc4b..7ad018eb7f908 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -221,35 +221,6 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { .sessionState.conf.warehousePath.stripSuffix("/")) } - test("MAX_CASES_BRANCHES") { - withTable("tab1") { - spark.range(10).write.saveAsTable("tab1") - val sql_one_branch_caseWhen = "SELECT CASE WHEN id = 1 THEN 1 END FROM tab1" - val sql_two_branch_caseWhen = "SELECT CASE WHEN id = 1 THEN 1 ELSE 0 END FROM tab1" - - withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "0") { - assert(!sql(sql_one_branch_caseWhen) - .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) - assert(!sql(sql_two_branch_caseWhen) - .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) - } - - withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "1") { - assert(sql(sql_one_branch_caseWhen) - .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) - assert(!sql(sql_two_branch_caseWhen) - .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) - } - - withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "2") { - assert(sql(sql_one_branch_caseWhen) - .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) - assert(sql(sql_two_branch_caseWhen) - .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) - } - } - } - test("static SQL conf comes from SparkConf") { val previousValue = sparkContext.conf.get(SCHEMA_STRING_LENGTH_THRESHOLD) try { From 6225c8ecb00bab4fa892e7847f5aa9bdee54409b Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 24 Nov 2017 16:15:18 +0100 Subject: [PATCH 2/9] adding test case --- .../expressions/conditionalExpressions.scala | 13 +++++++------ .../org/apache/spark/sql/DataFrameSuite.scala | 15 ++++++++++++++- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index b936f26581034..c629db6b73309 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -261,12 +261,14 @@ case class CaseWhen( ${ev.value} = ${res.value}; } """ - }.getOrElse("") + } + + val allConditions = cases ++ elseCode - val casesCode = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { - cases.mkString("\n") + val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { + allConditions.mkString("\n") } else { - ctx.splitExpressions(cases, "caseWhen", + ctx.splitExpressions(allConditions, "caseWhen", ("InternalRow", ctx.INPUT_ROW) :: ("boolean", conditionMet) :: Nil, returnType = "boolean", makeSplitFunction = { func => @@ -284,8 +286,7 @@ case class CaseWhen( ${ev.isNull} = true; ${ev.value} = ${ctx.defaultValue(dataType)}; boolean $conditionMet = false; - $casesCode - $elseCode""") + $code""") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 17c88b0690800..3c31f64693e6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} -import org.apache.spark.sql.execution.{FilterExec, QueryExecution} +import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.functions._ @@ -2126,4 +2126,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val mean = result.select("DecimalCol").where($"summary" === "mean") assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000"))) } + + test("SPARK-22520: support code generation for large CaseWhen") { + val N = 30 + var expr1 = when($"id" === lit(0), 0) + var expr2 = when($"id" === lit(0), 10) + (1 to N).foreach { i => + expr1 = expr1.when($"id" === lit(i), -i) + expr2 = expr2.when($"id" === lit(i + 10), i) + } + val df = spark.range(1).select(expr1, expr2.otherwise(0)) + df.show + assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) + } } From f9c20bea19e1e03394a976c90012fc8744267065 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 26 Nov 2017 22:15:20 +0100 Subject: [PATCH 3/9] review comments --- .../expressions/conditionalExpressions.scala | 43 ++++++++++++------- .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index c629db6b73309..67f856bda9416 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -233,8 +233,12 @@ case class CaseWhen( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + // This variable represents whether the first successful condition is met or not. + // It is initialized to `false` and it is set to `true` when the first condition which + // evaluates to `true` is met and therefore is not needed to go on anymore on the computation + // of the following conditions. val conditionMet = ctx.freshName("caseWhenConditionMet") - ctx.addMutableState("boolean", ev.isNull, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull, "") ctx.addMutableState(ctx.javaType(dataType), ev.value, "") val cases = branches.map { case (condExpr, valueExpr) => val cond = condExpr.genCode(ctx) @@ -266,21 +270,28 @@ case class CaseWhen( val allConditions = cases ++ elseCode val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { - allConditions.mkString("\n") - } else { - ctx.splitExpressions(allConditions, "caseWhen", - ("InternalRow", ctx.INPUT_ROW) :: ("boolean", conditionMet) :: Nil, returnType = "boolean", - makeSplitFunction = { - func => - s""" - $func - return $conditionMet; - """ - }, - foldFunctions = { funcCalls => - funcCalls.map(funcCall => s"$conditionMet = $funcCall;").mkString("\n") - }) - } + allConditions.mkString("\n") + } else { + ctx.splitExpressions(allConditions, "caseWhen", + ("InternalRow", ctx.INPUT_ROW) :: (ctx.JAVA_BOOLEAN, conditionMet) :: Nil, + returnType = ctx.JAVA_BOOLEAN, + makeSplitFunction = { + func => + s""" + $func + return $conditionMet; + """ + }, + foldFunctions = { funcCalls => + funcCalls.map { funcCall => + s""" + $conditionMet = $funcCall; + if ($conditionMet) { + continue; + }""" + }.mkString("do {", "", "\n} while (false);") + }) + } ev.copy(code = s""" ${ev.isNull} = true; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 3c31f64693e6c..1ab3294a23b7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2136,7 +2136,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { expr2 = expr2.when($"id" === lit(i + 10), i) } val df = spark.range(1).select(expr1, expr2.otherwise(0)) - df.show + checkAnswer(df, Row(0, 10) :: Nil) assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) } } From 9063583bce77348b9da61abec6e9fb5ae7aef117 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 26 Nov 2017 22:23:11 +0100 Subject: [PATCH 4/9] change description example --- .../spark/sql/catalyst/expressions/conditionalExpressions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 67f856bda9416..7eade486bc3c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -162,7 +162,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi 1 > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END; 2 - > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 ELSE null END; + > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 END; NULL """) // scalastyle:on line.size.limit From f4c78965a8cee34e3be8b9d8e264f2d6eb0d27f5 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 27 Nov 2017 08:09:36 +0100 Subject: [PATCH 5/9] nit: remove useless init empty string --- .../sql/catalyst/expressions/conditionalExpressions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index f1bfa03af9d2f..5b30949eac9c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -185,8 +185,8 @@ case class CaseWhen( // evaluates to `true` is met and therefore is not needed to go on anymore on the computation // of the following conditions. val conditionMet = ctx.freshName("caseWhenConditionMet") - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull, "") - ctx.addMutableState(ctx.javaType(dataType), ev.value, "") + ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ctx.addMutableState(ctx.javaType(dataType), ev.value) val cases = branches.map { case (condExpr, valueExpr) => val cond = condExpr.genCode(ctx) val res = valueExpr.genCode(ctx) From 5adb513ddba69e02b83aa47b69f45c73753a8457 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 27 Nov 2017 10:08:26 +0100 Subject: [PATCH 6/9] making conditionMet a local variable --- .../expressions/conditionalExpressions.scala | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 5b30949eac9c7..c104892a55a61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -219,12 +219,25 @@ case class CaseWhen( val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { allConditions.mkString("\n") } else { + // This generates code like: + // do { + // conditionMet = caseWhen_1(i); + // if(conditionMet) { + // continue; + // } + // conditionMet = caseWhen_2(i); + // if(conditionMet) { + // continue; + // } + // ... + // } while (false); ctx.splitExpressions(allConditions, "caseWhen", - ("InternalRow", ctx.INPUT_ROW) :: (ctx.JAVA_BOOLEAN, conditionMet) :: Nil, + ("InternalRow", ctx.INPUT_ROW) :: Nil, returnType = ctx.JAVA_BOOLEAN, makeSplitFunction = { func => s""" + ${ctx.JAVA_BOOLEAN} $conditionMet = false; $func return $conditionMet; """ @@ -243,7 +256,7 @@ case class CaseWhen( ev.copy(code = s""" ${ev.isNull} = true; ${ev.value} = ${ctx.defaultValue(dataType)}; - boolean $conditionMet = false; + ${ctx.JAVA_BOOLEAN} $conditionMet = false; $code""") } } From 6b280fd69fe53e11a9ab54874b9eda9d9340b63c Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 27 Nov 2017 17:23:53 +0100 Subject: [PATCH 7/9] implement do while optimization also inside the methods --- .../expressions/conditionalExpressions.scala | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index c104892a55a61..3856efe559b10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -179,6 +179,14 @@ case class CaseWhen( "CASE" + cases + elseCase + " END" } + private def wrapInDoWhileFalse(code: String): String = { + s""" + do { + $code + } while (false); + """ + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // This variable represents whether the first successful condition is met or not. // It is initialized to `false` and it is set to `true` when the first condition which @@ -187,6 +195,12 @@ case class CaseWhen( val conditionMet = ctx.freshName("caseWhenConditionMet") ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) ctx.addMutableState(ctx.javaType(dataType), ev.value) + + // these blocks are meant to be inside a + // do { + // ... + // } while (false); + // loop val cases = branches.map { case (condExpr, valueExpr) => val cond = condExpr.genCode(ctx) val res = valueExpr.genCode(ctx) @@ -198,6 +212,7 @@ case class CaseWhen( ${ev.isNull} = ${res.isNull}; ${ev.value} = ${res.value}; $conditionMet = true; + continue; } } """ @@ -217,7 +232,7 @@ case class CaseWhen( val allConditions = cases ++ elseCode val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { - allConditions.mkString("\n") + wrapInDoWhileFalse(allConditions.mkString("\n")) } else { // This generates code like: // do { @@ -238,18 +253,19 @@ case class CaseWhen( func => s""" ${ctx.JAVA_BOOLEAN} $conditionMet = false; - $func + ${wrapInDoWhileFalse(func)} return $conditionMet; """ }, foldFunctions = { funcCalls => - funcCalls.map { funcCall => + val loopBody = funcCalls.map { funcCall => s""" $conditionMet = $funcCall; if ($conditionMet) { continue; }""" - }.mkString("do {", "", "\n} while (false);") + }.mkString + wrapInDoWhileFalse(loopBody) }) } From c7347b1565ed2697c0327218bd2c61cff59a8f33 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 27 Nov 2017 17:25:29 +0100 Subject: [PATCH 8/9] minor: test style warn --- .../sql/catalyst/expressions/CodeGenerationSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 0343fa8d9970b..a4198f826cedb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -88,13 +88,13 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { (condition, Literal(n)) } - val expression = CaseWhen((1 to cases).map(generateCase(_))) + val expression = CaseWhen((1 to cases).map(generateCase)) val plan = GenerateMutableProjection.generate(Seq(expression)) - val input = new GenericInternalRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}"))) + val input = new GenericInternalRow(Array[Any](UTF8String.fromString(s"$clauses:$cases"))) val actual = plan(input).toSeq(Seq(expression.dataType)) - assert(actual(0) == cases) + assert(actual.head == cases) } test("SPARK-22543: split large if expressions into blocks due to JVM code size limit") { From dd5f455541babc3b594d071f1aae8591cd01f8de Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 27 Nov 2017 21:02:50 +0100 Subject: [PATCH 9/9] fix bug --- .../expressions/conditionalExpressions.scala | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 3856efe559b10..a8629c15c7564 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -179,14 +179,6 @@ case class CaseWhen( "CASE" + cases + elseCase + " END" } - private def wrapInDoWhileFalse(code: String): String = { - s""" - do { - $code - } while (false); - """ - } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // This variable represents whether the first successful condition is met or not. // It is initialized to `false` and it is set to `true` when the first condition which @@ -232,10 +224,9 @@ case class CaseWhen( val allConditions = cases ++ elseCode val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { - wrapInDoWhileFalse(allConditions.mkString("\n")) + allConditions.mkString("\n") } else { // This generates code like: - // do { // conditionMet = caseWhen_1(i); // if(conditionMet) { // continue; @@ -245,7 +236,14 @@ case class CaseWhen( // continue; // } // ... - // } while (false); + // and the declared methods are: + // private boolean caseWhen_1234() { + // boolean conditionMet = false; + // do { + // // here the evaluation of the conditions + // } while (false); + // return conditionMet; + // } ctx.splitExpressions(allConditions, "caseWhen", ("InternalRow", ctx.INPUT_ROW) :: Nil, returnType = ctx.JAVA_BOOLEAN, @@ -253,19 +251,20 @@ case class CaseWhen( func => s""" ${ctx.JAVA_BOOLEAN} $conditionMet = false; - ${wrapInDoWhileFalse(func)} + do { + $func + } while (false); return $conditionMet; """ }, foldFunctions = { funcCalls => - val loopBody = funcCalls.map { funcCall => + funcCalls.map { funcCall => s""" $conditionMet = $funcCall; if ($conditionMet) { continue; }""" }.mkString - wrapInDoWhileFalse(loopBody) }) } @@ -273,7 +272,9 @@ case class CaseWhen( ${ev.isNull} = true; ${ev.value} = ${ctx.defaultValue(dataType)}; ${ctx.JAVA_BOOLEAN} $conditionMet = false; - $code""") + do { + $code + } while (false);""") } }