From e95d7f48c1d4a0a9a1dac1ecb6d0b32f895c4154 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Fri, 4 Jun 2021 14:04:06 +0200 Subject: [PATCH 1/2] Fix encoder (interpreted path) for Map with case classes Used the key/valueLambdaFunction to convert the elements instead of using CatalystTypeConverters.createToScalaConverter. This is how it is done in MapObjects and that correctly handles Arrays with case classes. --- .../expressions/objects/objects.scala | 31 +++++++++---------- .../encoders/ExpressionEncoderSuite.scala | 5 +++ 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index e5726544f9ca0..455050e6fd5b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -29,7 +29,7 @@ import org.apache.commons.lang3.reflect.MethodUtils import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection} +import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -1167,11 +1167,6 @@ case class CatalystToExternalMap private( private lazy val inputMapType = inputData.dataType.asInstanceOf[MapType] - private lazy val keyConverter = - CatalystTypeConverters.createToScalaConverter(inputMapType.keyType) - private lazy val valueConverter = - CatalystTypeConverters.createToScalaConverter(inputMapType.valueType) - private lazy val (newMapBuilderMethod, moduleField) = { val clazz = Utils.classForName(collClass.getCanonicalName + "$") (clazz.getMethod("newBuilder"), clazz.getField("MODULE$").get(null)) @@ -1181,21 +1176,25 @@ case class CatalystToExternalMap private( newMapBuilderMethod.invoke(moduleField).asInstanceOf[Builder[AnyRef, AnyRef]] } + private def keyValueIterator(md: MapData): Iterator[AnyRef] = { + val keyArray = md.keyArray() + val valueArray = md.valueArray() + val row = new GenericInternalRow(1) + 0.until(md.numElements()).iterator.map { i => + row.update(0, keyArray.get(i, inputMapType.keyType)) + val key = keyLambdaFunction.eval(row) + row.update(0, valueArray.get(i, inputMapType.valueType)) + val value = valueLambdaFunction.eval(row) + Tuple2(key, value) + } + } + override def eval(input: InternalRow): Any = { val result = inputData.eval(input).asInstanceOf[MapData] if (result != null) { val builder = newMapBuilder() builder.sizeHint(result.numElements()) - val keyArray = result.keyArray() - val valueArray = result.valueArray() - var i = 0 - while (i < result.numElements()) { - val key = keyConverter(keyArray.get(i, inputMapType.keyType)) - val value = valueConverter(valueArray.get(i, inputMapType.valueType)) - builder += Tuple2(key, value) - i += 1 - } - builder.result() + keyValueIterator(result).foldLeft(builder)(_ += _).result } else { null } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 7faab4e7aa757..bf4afac2f8be6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -114,6 +114,7 @@ case class ReferenceValueClass(wrapped: ReferenceValueClass.Container) extends A object ReferenceValueClass { case class Container(data: Int) } +case class IntAndString(i: Int, s: String) class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTest { OuterScopes.addOuterScope(this) @@ -174,6 +175,10 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes encodeDecodeTest(Map(1 -> "a", 2 -> "b"), "map") encodeDecodeTest(Map(1 -> "a", 2 -> null), "map with null") encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), "map of map") + encodeDecodeTest(Map(1 -> IntAndString(1, "a")), "map with case class as value") + encodeDecodeTest(Map(IntAndString(1, "a") -> 1), "map with case class as key") + encodeDecodeTest(Map(IntAndString(1, "a") -> IntAndString(2, "b")), + "map with case class as key and value") encodeDecodeTest(Tuple1[Seq[Int]](null), "null seq in tuple") encodeDecodeTest(Tuple1[Map[String, String]](null), "null map in tuple") From 9cc2484600374022bb76a95039e22a8c232a4700 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Tue, 8 Jun 2021 08:28:53 +0200 Subject: [PATCH 2/2] Use while loop --- .../expressions/objects/objects.scala | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 455050e6fd5b8..e78d442a651a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1176,25 +1176,24 @@ case class CatalystToExternalMap private( newMapBuilderMethod.invoke(moduleField).asInstanceOf[Builder[AnyRef, AnyRef]] } - private def keyValueIterator(md: MapData): Iterator[AnyRef] = { - val keyArray = md.keyArray() - val valueArray = md.valueArray() - val row = new GenericInternalRow(1) - 0.until(md.numElements()).iterator.map { i => - row.update(0, keyArray.get(i, inputMapType.keyType)) - val key = keyLambdaFunction.eval(row) - row.update(0, valueArray.get(i, inputMapType.valueType)) - val value = valueLambdaFunction.eval(row) - Tuple2(key, value) - } - } - override def eval(input: InternalRow): Any = { val result = inputData.eval(input).asInstanceOf[MapData] if (result != null) { val builder = newMapBuilder() builder.sizeHint(result.numElements()) - keyValueIterator(result).foldLeft(builder)(_ += _).result + val keyArray = result.keyArray() + val valueArray = result.valueArray() + val row = new GenericInternalRow(1) + var i = 0 + while (i < result.numElements()) { + row.update(0, keyArray.get(i, inputMapType.keyType)) + val key = keyLambdaFunction.eval(row) + row.update(0, valueArray.get(i, inputMapType.valueType)) + val value = valueLambdaFunction.eval(row) + builder += Tuple2(key, value) + i += 1 + } + builder.result() } else { null }