diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index d05113431df41..4b2d4195ee906 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -276,12 +276,12 @@ case class MapElementsExec( } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val (funcClass, methodName) = func match { + val (funcClass, funcName) = func match { case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call" case _ => FunctionUtils.getFunctionOneName(outputObjectType, child.output(0).dataType) } val funcObj = Literal.create(func, ObjectType(funcClass)) - val callFunc = Invoke(funcObj, methodName, outputObjectType, child.output) + val callFunc = Invoke(funcObj, funcName, outputObjectType, child.output, propagateNull = false) val result = BindReferences.bindReference(callFunc, child.output).genCode(ctx) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index af65957691b37..06600c1e4b1d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1916,6 +1916,16 @@ class DatasetSuite extends QueryTest assert(df1.semanticHash !== df3.semanticHash) assert(df3.semanticHash === df4.semanticHash) } + + test("SPARK-31854: Invoke in MapElementsExec should not propagate null") { + Seq("true", "false").foreach { wholeStage => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholeStage) { + val ds = Seq(1.asInstanceOf[Integer], null.asInstanceOf[Integer]).toDS() + val expectedAnswer = Seq[(Integer, Integer)]((1, 1), (null, null)) + checkDataset(ds.map(v => (v, v)), expectedAnswer: _*) + } + } + } } object AssertExecutionId {