Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf

/**
* Apply all of the GroupExpressions to every input row, hence we will get
Expand Down Expand Up @@ -152,40 +152,82 @@ case class ExpandExec(
// This column is the same across all output rows. Just generate code for it here.
BindReferences.bindReference(firstExpr, attributeSeq).genCode(ctx)
} else {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
val code = code"""
|boolean $isNull = true;
|${CodeGenerator.javaType(firstExpr.dataType)} $value =
| ${CodeGenerator.defaultValue(firstExpr.dataType)};
""".stripMargin
val isNull = ctx.addMutableState(
CodeGenerator.JAVA_BOOLEAN,
"resultIsNull",
v => s"$v = true;")
val value = ctx.addMutableState(
CodeGenerator.javaType(firstExpr.dataType),
"resultValue",
v => s"$v = ${CodeGenerator.defaultValue(firstExpr.dataType)};")

ExprCode(
code,
JavaCode.isNullVariable(isNull),
JavaCode.variable(value, firstExpr.dataType))
}
}

// Part 2: switch/case statements
val cases = projections.zipWithIndex.map { case (exprs, row) =>
var updateCode = ""
for (col <- exprs.indices) {
val switchCaseExprs = projections.zipWithIndex.map { case (exprs, row) =>
val (exprCodesWithIndices, inputVarSets) = exprs.indices.flatMap { col =>
if (!sameOutput(col)) {
val ev = BindReferences.bindReference(exprs(col), attributeSeq).genCode(ctx)
updateCode +=
val boundExpr = BindReferences.bindReference(exprs(col), attributeSeq)
val exprCode = boundExpr.genCode(ctx)
val inputVars = CodeGenerator.getLocalInputVariableValues(ctx, boundExpr)._1
Some(((col, exprCode), inputVars))
} else {
None
}
}.unzip

val inputVars = inputVarSets.foldLeft(Set.empty[VariableValue])(_ ++ _)
(row, exprCodesWithIndices, inputVars.toSeq)
}

val updateCodes = switchCaseExprs.map { case (_, exprCodes, _) =>
exprCodes.map { case (col, ev) =>
s"""
|${ev.code}
|${outputColumns(col).isNull} = ${ev.isNull};
|${outputColumns(col).value} = ${ev.value};
""".stripMargin
}.mkString("\n")
}

val splitThreshold = SQLConf.get.methodSplitThreshold
val cases = if (switchCaseExprs.flatMap(_._2.map(_._2.code.length)).sum > splitThreshold) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we call generateUpdateCode() only once, before line 198?
IMHO, the code in all three cases (starting at line 203, at line 216, and starting at line 225) is generated by generateUpdateCode().

switchCaseExprs.zip(updateCodes).map { case ((row, _, inputVars), updateCode) =>
val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVars)
val maybeSplitUpdateCode = if (CodeGenerator.isValidParamLength(paramLength)) {
val switchCaseFunc = ctx.freshName("switchCaseCode")
val argList = inputVars.map { v =>
Comment thread
kiszk marked this conversation as resolved.
s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}"
}
ctx.addNewFunction(switchCaseFunc,
s"""
|${ev.code}
|${outputColumns(col).isNull} = ${ev.isNull};
|${outputColumns(col).value} = ${ev.value};
""".stripMargin
|private void $switchCaseFunc(${argList.mkString(", ")}) {
| $updateCode
|}
""".stripMargin)

s"$switchCaseFunc(${inputVars.map(_.variableName).mkString(", ")});"
} else {
updateCode
}
s"""
|case $row:
| $maybeSplitUpdateCode
| break;
""".stripMargin
}
} else {
switchCaseExprs.map(_._1).zip(updateCodes).map { case (row, updateCode) =>
s"""
|case $row:
| $updateCode
| break;
""".stripMargin
}

s"""
|case $row:
| ${updateCode.trim}
| break;
""".stripMargin
}

val numOutput = metricTerm(ctx, "numOutputRows")
Expand Down