From e94a255be1627164c15f9524bfeb96dfb14f5c77 Mon Sep 17 00:00:00 2001 From: Michal Senkyr Date: Sun, 12 Feb 2017 16:51:51 +0100 Subject: [PATCH 1/8] Arbitrary map support implementation --- .../spark/sql/catalyst/ScalaReflection.scala | 37 +-- .../expressions/objects/objects.scala | 249 ++++++++++++++++++ .../sql/catalyst/ScalaReflectionSuite.scala | 25 ++ .../org/apache/spark/sql/SQLImplicits.scala | 5 + .../spark/sql/DatasetPrimitiveSuite.scala | 81 ++++++ 5 files changed, 373 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 82710a2a183ab..3db113138cec0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -333,31 +333,20 @@ object ScalaReflection extends ScalaReflection { // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t - val keyData = - Invoke( - MapObjects( - p => deserializerFor(keyType, Some(p), walkedTypePath), - Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType), - returnNullable = false), - schemaFor(keyType).dataType), - "array", - ObjectType(classOf[Array[Any]]), returnNullable = false) - - val valueData = - Invoke( - MapObjects( - p => deserializerFor(valueType, Some(p), walkedTypePath), - Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType), - returnNullable = false), - schemaFor(valueType).dataType), - "array", - ObjectType(classOf[Array[Any]]), returnNullable = false) + val cls = t.companion.member(TermName("newBuilder")) match { + case NoSymbol => classOf[Map[_, _]] + case _ => mirror.runtimeClass(t.typeSymbol.asClass) + } - StaticInvoke( - ArrayBasedMapData.getClass, - ObjectType(classOf[scala.collection.immutable.Map[_, _]]), - "toScalaMap", - keyData :: valueData :: Nil) + 
CollectObjectsToMap( + p => deserializerFor(keyType, Some(p), walkedTypePath), + Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), + schemaFor(keyType).dataType, + p => deserializerFor(valueType, Some(p), walkedTypePath), + Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), + schemaFor(valueType).dataType, + cls + ) case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index f446c3e4a75f6..e00e4ad833d7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -649,6 +649,255 @@ case class MapObjects private( } } +object CollectObjectsToMap { + private val curId = new java.util.concurrent.atomic.AtomicInteger() + + /** + * Construct an instance of CollectObjectsToMap case class. + * + * @param keyFunction The function applied on the key collection elements. + * @param keyInputData An expression that when evaluated returns a key collection object. + * @param keyElementType The data type of key elements in the collection. + * @param valueFunction The function applied on the value collection elements. + * @param valueInputData An expression that when evaluated returns a value collection object. + * @param valueElementType The data type of value elements in the collection. + * @param collClass The type of the resulting collection.
+ */ + def apply( + keyFunction: Expression => Expression, + keyInputData: Expression, + keyElementType: DataType, + valueFunction: Expression => Expression, + valueInputData: Expression, + valueElementType: DataType, + collClass: Class[_]): CollectObjectsToMap = { + val keyLoopValue = "CollectObjectsToMap_loopValue" + curId.getAndIncrement() + val keyLoopIsNull = "CollectObjectsToMap_loopIsNull" + curId.getAndIncrement() + val keyLoopVar = LambdaVariable(keyLoopValue, keyLoopIsNull, keyElementType) + val valueLoopValue = "CollectObjectsToMap_loopValue" + curId.getAndIncrement() + val valueLoopIsNull = "CollectObjectsToMap_loopIsNull" + curId.getAndIncrement() + val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, valueElementType) + val tupleLoopVar = "CollectObjectsToMap_loopValue" + curId.getAndIncrement() + val builderValue = "CollectObjectsToMap_builderValue" + curId.getAndIncrement() + CollectObjectsToMap( + keyLoopValue, keyLoopIsNull, keyElementType, keyFunction(keyLoopVar), keyInputData, + valueLoopValue, valueLoopIsNull, valueElementType, valueFunction(valueLoopVar), + valueInputData, + tupleLoopVar, collClass, builderValue) + } +} + +/** + * An equivalent to the [[MapObjects]] case class but returning an ObjectType containing + * a Scala collection constructed using the associated builder, obtained by calling `newBuilder` + * on the collection's companion object. 
+ * + * @param keyLoopValue the name of the loop variable that is used when iterating over the key + * collection, and which is used as input for the `keyLambdaFunction` + * @param keyLoopIsNull the nullability of the loop variable that is used when iterating over + * the key collection, and which is used as input for the `keyLambdaFunction` + * @param keyLoopVarDataType the data type of the loop variable that is used when iterating over + * the key collection, and which is used as input for the + * `keyLambdaFunction` + * @param keyLambdaFunction A function that takes the `keyLoopVar` as input, and is used as + * a lambda function to handle collection elements. + * @param keyInputData An expression that when evaluated returns a collection object. + * @param valueLoopValue the name of the loop variable that is used when iterating over the value + * collection, and which is used as input for the `valueLambdaFunction` + * @param valueLoopIsNull the nullability of the loop variable that is used when iterating over + * the value collection, and which is used as input for the + * `valueLambdaFunction` + * @param valueLoopVarDataType the data type of the loop variable that is used when iterating over + * the value collection, and which is used as input for the + * `valueLambdaFunction` + * @param valueLambdaFunction A function that takes the `valueLoopVar` as input, and is used as + * a lambda function to handle collection elements. + * @param valueInputData An expression that when evaluated returns a collection object. + * @param tupleLoopValue the name of the loop variable that holds the tuple to be added to the + * resulting map + * @param collClass The type of the resulting collection. + * @param builderValue The name of the builder variable used to construct the resulting collection. 
+ */ +case class CollectObjectsToMap private( + keyLoopValue: String, + keyLoopIsNull: String, + keyLoopVarDataType: DataType, + keyLambdaFunction: Expression, + keyInputData: Expression, + valueLoopValue: String, + valueLoopIsNull: String, + valueLoopVarDataType: DataType, + valueLambdaFunction: Expression, + valueInputData: Expression, + tupleLoopValue: String, + collClass: Class[_], + builderValue: String) extends Expression with NonSQLExpression { + + override def nullable: Boolean = keyInputData.nullable + + override def children: Seq[Expression] = + keyLambdaFunction :: keyInputData :: valueLambdaFunction :: valueInputData :: Nil + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def dataType: DataType = ObjectType(collClass) + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val collObjectName = s"${collClass.getName}$$.MODULE$$" + val getBuilderVar = s"$collObjectName.newBuilder()" + val keyElementJavaType = ctx.javaType(keyLoopVarDataType) + ctx.addMutableState("boolean", keyLoopIsNull, "") + ctx.addMutableState(keyElementJavaType, keyLoopValue, "") + val genKeyInputData = keyInputData.genCode(ctx) + val genKeyFunction = keyLambdaFunction.genCode(ctx) + val valueElementJavaType = ctx.javaType(valueLoopVarDataType) + ctx.addMutableState("boolean", valueLoopIsNull, "") + ctx.addMutableState(valueElementJavaType, valueLoopValue, "") + val genValueInputData = valueInputData.genCode(ctx) + val genValueFunction = valueLambdaFunction.genCode(ctx) + val dataLength = ctx.freshName("dataLength") + val loopIndex = ctx.freshName("loopIndex") + + // In RowEncoder, we use `Object` to represent Array or Seq, so we need to determine the type + // of input collection at runtime for this case. 
+ val keySeq = ctx.freshName("keySeq") + val keyArray = ctx.freshName("keyArray") + val valueSeq = ctx.freshName("valueSeq") + val valueArray = ctx.freshName("valueArray") + def determineCollectionType(inputData: Expression, genInputData: ExprCode, + elementJavaType: String, seq: String, array: String) = + inputData.dataType match { + case ObjectType(cls) if cls == classOf[Object] => + val seqClass = classOf[Seq[_]].getName + s""" + $seqClass $seq = null; + $elementJavaType[] $array = null; + if (${genInputData.value}.getClass().isArray()) { + $array = ($elementJavaType[]) ${genInputData.value}; + } else { + $seq = ($seqClass) ${genInputData.value}; + } + """ + case _ => "" + } + val determineKeyCollectionType = determineCollectionType( + keyInputData, genKeyInputData, keyElementJavaType, keySeq, keyArray) + val determineValueCollectionType = determineCollectionType( + valueInputData, genValueInputData, valueElementJavaType, valueSeq, valueArray) + + // The data with PythonUserDefinedType are actually stored with the data type of its sqlType. + // When we want to apply CollectObjectsToMap on it, we have to use its sqlType.
+ def inputDataType(inputData: Expression) = inputData.dataType match { + case p: PythonUserDefinedType => p.sqlType + case _ => inputData.dataType + } + val keyInputDataType = inputDataType(keyInputData) + val valueInputDataType = inputDataType(valueInputData) + + def lengthAndLoopVar(inputDataType: DataType, genInputData: ExprCode, + seq: String, array: String) = + inputDataType match { + case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)" + case ObjectType(cls) if cls.isArray => + s"${genInputData.value}.length" -> s"${genInputData.value}[$loopIndex]" + case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + s"${genInputData.value}.size()" -> s"${genInputData.value}.get($loopIndex)" + case ArrayType(et, _) => + s"${genInputData.value}.numElements()" -> ctx.getValue(genInputData.value, et, loopIndex) + case ObjectType(cls) if cls == classOf[Object] => + s"$seq == null ? $array.length : $seq.size()" -> + s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)" + } + val ((getKeyLength, getKeyLoopVar), (getValueLength, getValueLoopVar)) = ( + lengthAndLoopVar(inputDataType(keyInputData), genKeyInputData, keySeq, keyArray), + lengthAndLoopVar(inputDataType(valueInputData), genValueInputData, valueSeq, valueArray) + ) + + // Make a copy of the data if it's unsafe-backed + def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = + s"$value instanceof ${clazz.getSimpleName}? 
$value.copy() : $value" + def genFunctionValue(lambdaFunction: Expression, genFunction: ExprCode) = + lambdaFunction.dataType match { + case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) + case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) + case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) + case _ => genFunction.value + } + val genKeyFunctionValue = genFunctionValue(keyLambdaFunction, genKeyFunction) + val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) + + def loopNullCheck(genInputData: ExprCode, inputDataType: DataType, + loopIsNull: String, loopValue: String) = + inputDataType match { + case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" + // The element of primitive array will never be null. + case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive => + s"$loopIsNull = false;" + case _ => s"$loopIsNull = $loopValue == null;" + } + val keyLoopNullCheck = + loopNullCheck(genKeyInputData, keyInputDataType, keyLoopIsNull, keyLoopValue) + val valueLoopNullCheck = + loopNullCheck(genValueInputData, valueInputDataType, valueLoopIsNull, valueLoopValue) + + val tupleClass = classOf[(_, _)].getName + + val code = s""" + ${genKeyInputData.code} + ${genValueInputData.code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + + if ((${genKeyInputData.isNull} && !${genValueInputData.isNull}) || + (!${genKeyInputData.isNull} && ${genValueInputData.isNull})) { + throw new RuntimeException("Invalid state: Inconsistent nullability of key-value"); + } + + if (!${genKeyInputData.isNull}) { + $determineKeyCollectionType + $determineValueCollectionType + if ($getKeyLength != $getValueLength) { + throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays"); + } + int $dataLength = $getKeyLength; + ${classOf[Builder[_, _]].getName} $builderValue =
$getBuilderVar; + $builderValue.sizeHint($dataLength); + + int $loopIndex = 0; + while ($loopIndex < $dataLength) { + $keyLoopValue = ($keyElementJavaType) ($getKeyLoopVar); + $valueLoopValue = ($valueElementJavaType) ($getValueLoopVar); + $keyLoopNullCheck + $valueLoopNullCheck + + ${genKeyFunction.code} + ${genValueFunction.code} + + $tupleClass $tupleLoopValue; + + if (${genKeyFunction.isNull}) { + throw new RuntimeException("Found null in map key!"); + } + + if (${genValueFunction.isNull}) { + $tupleLoopValue = new $tupleClass($genKeyFunctionValue, null); + } else { + $tupleLoopValue = new $tupleClass($genKeyFunctionValue, $genValueFunctionValue); + } + + $builderValue.$$plus$$eq($tupleLoopValue); + + $loopIndex += 1; + } + + ${ev.value} = (${collClass.getName}) $builderValue.result(); + } + """ + ev.copy(code = code, isNull = genKeyInputData.isNull) + } +} + object ExternalMapToCatalyst { private val curId = new java.util.concurrent.atomic.AtomicInteger() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 70ad064f93ebc..ff2414b174acb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -314,6 +314,31 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]])) } + test("serialize and deserialize arbitrary map types") { + val mapSerializer = serializerFor[Map[Int, Int]](BoundReference( + 0, ObjectType(classOf[Map[Int, Int]]), nullable = false)) + assert(mapSerializer.dataType.head.dataType == + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mapDeserializer = deserializerFor[Map[Int, Int]] + assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]])) + + import 
scala.collection.immutable.HashMap + val hashMapSerializer = serializerFor[HashMap[Int, Int]](BoundReference( + 0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false)) + assert(hashMapSerializer.dataType.head.dataType == + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val hashMapDeserializer = deserializerFor[HashMap[Int, Int]] + assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]])) + + import scala.collection.mutable.{LinkedHashMap => LHMap} + val linkedHashMapSerializer = serializerFor[LHMap[Long, String]](BoundReference( + 0, ObjectType(classOf[LHMap[Long, String]]), nullable = false)) + assert(linkedHashMapSerializer.dataType.head.dataType == + MapType(LongType, StringType, valueContainsNull = true)) + val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]] + assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]])) + } + private val dataTypeForComplexData = dataTypeFor[ComplexData] private val typeOfComplexData = typeOf[ComplexData] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 375df64d39734..f38a2e766d84b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.Map import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag @@ -199,6 +200,10 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits { implicit def newProductSequenceEncoder[T <: Seq[Product] : TypeTag]: Encoder[T] = ExpressionEncoder() + // Maps + /** @since 2.2.0 */ + implicit def newMapEncoder[T <: Map[_, _] : TypeTag]: Encoder[T] = ExpressionEncoder() + // Arrays /** @since 1.6.1 */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 541565344f758..8533afb624968 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import scala.collection.immutable.Queue +import scala.collection.mutable.{LinkedHashMap => LHMap} import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.test.SharedSQLContext @@ -30,8 +31,14 @@ case class ListClass(l: List[Int]) case class QueueClass(q: Queue[Int]) +case class MapClass(m: Map[Int, Int]) + +case class LHMapClass(m: LHMap[Int, Int]) + case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass) +case class ComplexMapClass(map: MapClass, lhmap: LHMapClass) + package object packageobject { case class PackageClass(value: Int) } @@ -258,6 +265,80 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2)))) } + test("arbitrary maps") { + checkDataset(Seq(Map(1 -> 2)).toDS(), Map(1 -> 2)) + checkDataset(Seq(Map(1.toLong -> 2.toLong)).toDS(), Map(1.toLong -> 2.toLong)) + checkDataset(Seq(Map(1.toDouble -> 2.toDouble)).toDS(), Map(1.toDouble -> 2.toDouble)) + checkDataset(Seq(Map(1.toFloat -> 2.toFloat)).toDS(), Map(1.toFloat -> 2.toFloat)) + checkDataset(Seq(Map(1.toByte -> 2.toByte)).toDS(), Map(1.toByte -> 2.toByte)) + checkDataset(Seq(Map(1.toShort -> 2.toShort)).toDS(), Map(1.toShort -> 2.toShort)) + checkDataset(Seq(Map(true -> false)).toDS(), Map(true -> false)) + checkDataset(Seq(Map("test1" -> "test2")).toDS(), Map("test1" -> "test2")) + checkDataset(Seq(Map(Tuple1(1) -> Tuple1(2))).toDS(), Map(Tuple1(1) -> Tuple1(2))) + checkDataset(Seq(Map(1 -> Tuple1(2))).toDS(), Map(1 -> Tuple1(2))) + checkDataset(Seq(Map("test" -> 2.toLong)).toDS(), Map("test" -> 2.toLong)) + + checkDataset(Seq(LHMap(1 -> 2)).toDS(), LHMap(1 -> 2)) + 
checkDataset(Seq(LHMap(1.toLong -> 2.toLong)).toDS(), LHMap(1.toLong -> 2.toLong)) + checkDataset(Seq(LHMap(1.toDouble -> 2.toDouble)).toDS(), LHMap(1.toDouble -> 2.toDouble)) + checkDataset(Seq(LHMap(1.toFloat -> 2.toFloat)).toDS(), LHMap(1.toFloat -> 2.toFloat)) + checkDataset(Seq(LHMap(1.toByte -> 2.toByte)).toDS(), LHMap(1.toByte -> 2.toByte)) + checkDataset(Seq(LHMap(1.toShort -> 2.toShort)).toDS(), LHMap(1.toShort -> 2.toShort)) + checkDataset(Seq(LHMap(true -> false)).toDS(), LHMap(true -> false)) + checkDataset(Seq(LHMap("test1" -> "test2")).toDS(), LHMap("test1" -> "test2")) + checkDataset(Seq(LHMap(Tuple1(1) -> Tuple1(2))).toDS(), LHMap(Tuple1(1) -> Tuple1(2))) + checkDataset(Seq(LHMap(1 -> Tuple1(2))).toDS(), LHMap(1 -> Tuple1(2))) + checkDataset(Seq(LHMap("test" -> 2.toLong)).toDS(), LHMap("test" -> 2.toLong)) + } + + ignore("SPARK-19104: map and product combinations") { + // Case classes + checkDataset(Seq(MapClass(Map(1 -> 2))).toDS(), MapClass(Map(1 -> 2))) + checkDataset(Seq(Map(1 -> MapClass(Map(2 -> 3)))).toDS(), Map(1 -> MapClass(Map(2 -> 3)))) + checkDataset(Seq(Map(MapClass(Map(1 -> 2)) -> 3)).toDS(), Map(MapClass(Map(1 -> 2)) -> 3)) + checkDataset(Seq(Map(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))).toDS(), + Map(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))) + checkDataset(Seq(LHMap(1 -> MapClass(Map(2 -> 3)))).toDS(), LHMap(1 -> MapClass(Map(2 -> 3)))) + checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> 3)).toDS(), LHMap(MapClass(Map(1 -> 2)) -> 3)) + checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))).toDS(), + LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))) + + checkDataset(Seq(LHMapClass(LHMap(1 -> 2))).toDS(), LHMapClass(LHMap(1 -> 2))) + checkDataset(Seq(Map(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(), + Map(1 -> LHMapClass(LHMap(2 -> 3)))) + checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(), + Map(LHMapClass(LHMap(1 -> 2)) -> 3)) + checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> 
LHMapClass(LHMap(3 -> 4)))).toDS(), + Map(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))) + checkDataset(Seq(LHMap(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(), + LHMap(1 -> LHMapClass(LHMap(2 -> 3)))) + checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(), + LHMap(LHMapClass(LHMap(1 -> 2)) -> 3)) + checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))).toDS(), + LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))) + + val complex = ComplexMapClass(MapClass(Map(1 -> 2)), LHMapClass(LHMap(3 -> 4))) + checkDataset(Seq(complex).toDS(), complex) + checkDataset(Seq(Map(1 -> complex)).toDS(), Map(1 -> complex)) + checkDataset(Seq(Map(complex -> 5)).toDS(), Map(complex -> 5)) + checkDataset(Seq(Map(complex -> complex)).toDS(), Map(complex -> complex)) + checkDataset(Seq(LHMap(1 -> complex)).toDS(), LHMap(1 -> complex)) + checkDataset(Seq(LHMap(complex -> 5)).toDS(), LHMap(complex -> 5)) + checkDataset(Seq(LHMap(complex -> complex)).toDS(), LHMap(complex -> complex)) + + // Tuples + checkDataset(Seq(Map(1 -> 2) -> Map(3 -> 4)).toDS(), Map(1 -> 2) -> Map(3 -> 4)) + checkDataset(Seq(LHMap(1 -> 2) -> Map(3 -> 4)).toDS(), LHMap(1 -> 2) -> Map(3 -> 4)) + checkDataset(Seq(Map(1 -> 2) -> LHMap(3 -> 4)).toDS(), Map(1 -> 2) -> LHMap(3 -> 4)) + checkDataset(Seq(LHMap(1 -> 2) -> LHMap(3 -> 4)).toDS(), LHMap(1 -> 2) -> LHMap(3 -> 4)) + checkDataset(Seq(LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2")))).toDS(), + LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2")))) + + // Complex + checkDataset(Seq(LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4)))).toDS(), + LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4)))) + } + test("package objects") { import packageobject._ checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1)) From 8e1d924d6f554becce982698d24403113e3c3ba9 Mon Sep 17 00:00:00 2001 From: Michal Senkyr Date: Sun, 26 Feb 2017 15:33:37 +0100 Subject: [PATCH 2/8] Add support 
for Java Maps --- .../spark/sql/catalyst/ScalaReflection.scala | 13 +--- .../expressions/objects/objects.scala | 76 +++++++++++++++---- .../sql/catalyst/ScalaReflectionSuite.scala | 25 ++++++ .../org/apache/spark/sql/SQLImplicits.scala | 4 + .../spark/sql/DatasetPrimitiveSuite.scala | 49 ++++++++++++ 5 files changed, 142 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 3db113138cec0..56de59eefb8f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -329,15 +329,10 @@ object ScalaReflection extends ScalaReflection { } UnresolvedMapObjects(mapFunction, getPath, Some(cls)) - case t if t <:< localTypeOf[Map[_, _]] => + case t if t <:< localTypeOf[Map[_, _]] || t <:< localTypeOf[java.util.Map[_, _]] => // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t - val cls = t.companion.member(TermName("newBuilder")) match { - case NoSymbol => classOf[Map[_, _]] - case _ => mirror.runtimeClass(t.typeSymbol.asClass) - } - CollectObjectsToMap( p => deserializerFor(keyType, Some(p), walkedTypePath), Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), @@ -345,7 +340,7 @@ object ScalaReflection extends ScalaReflection { p => deserializerFor(valueType, Some(p), walkedTypePath), Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), schemaFor(valueType).dataType, - cls + mirror.runtimeClass(t.typeSymbol.asClass) ) case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => @@ -489,7 +484,7 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(elementType)) = t toCatalystArray(inputObject, elementType) - case t if t <:< localTypeOf[Map[_, _]] => + case t if t <:< localTypeOf[Map[_, _]] || t 
<:< localTypeOf[java.util.Map[_, _]] => val TypeRef(_, _, Seq(keyType, valueType)) = t val keyClsName = getClassNameFromType(keyType) val valueClsName = getClassNameFromType(valueType) @@ -703,7 +698,7 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, nullable) = schemaFor(elementType) Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< localTypeOf[Map[_, _]] => + case t if t <:< localTypeOf[Map[_, _]] || t <:< localTypeOf[java.util.Map[_, _]] => val TypeRef(_, _, Seq(keyType, valueType)) = t val Schema(valueDataType, valueNullable) = schemaFor(valueType) Schema(MapType(schemaFor(keyType).dataType, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index e00e4ad833d7e..35417db169daa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -22,6 +22,7 @@ import java.lang.reflect.Modifier import scala.collection.mutable.Builder import scala.language.existentials import scala.reflect.ClassTag +import scala.util.Try import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ @@ -714,7 +715,7 @@ object CollectObjectsToMap { * a lambda function to handle collection elements. * @param valueInputData An expression that when evaluated returns a collection object. * @param tupleLoopValue the name of the loop variable that holds the tuple to be added to the - * resulting map + * resulting map (used only for Scala Map) * @param collClass The type of the resulting collection. * @param builderValue The name of the builder variable used to construct the resulting collection. 
*/ @@ -744,8 +745,6 @@ case class CollectObjectsToMap private( override def dataType: DataType = ObjectType(collClass) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val collObjectName = s"${collClass.getName}$$.MODULE$$" - val getBuilderVar = s"$collObjectName.newBuilder()" val keyElementJavaType = ctx.javaType(keyLoopVarDataType) ctx.addMutableState("boolean", keyLoopIsNull, "") ctx.addMutableState(keyElementJavaType, keyLoopValue, "") @@ -842,7 +841,61 @@ case class CollectObjectsToMap private( val valueLoopNullCheck = loopNullCheck(genValueInputData, valueInputDataType, valueLoopIsNull, valueLoopValue) - val tupleClass = classOf[(_, _)].getName + val constructBuilder = collClass match { + // Scala Map + case cls if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => + val builderClass = classOf[Builder[_, _]].getName + s""" + $builderClass $builderValue = ${collClass.getName}$$.MODULE$$.newBuilder(); + $builderValue.sizeHint($dataLength); + """ + // Java Map, AbstractMap => HashMap + case cls if classOf[java.util.Map[_, _]] == cls || + classOf[java.util.AbstractMap[_, _]] == cls => + val builderClass = classOf[java.util.HashMap[_, _]].getName + s"$builderClass $builderValue = new $builderClass($dataLength);" + // Java SortedMap, NavigableMap => TreeMap + case cls if classOf[java.util.SortedMap[_, _]] == cls || + classOf[java.util.NavigableMap[_, _]] == cls => + val builderClass = classOf[java.util.TreeMap[_, _]].getName + s"$builderClass $builderValue = new $builderClass();" + // Java ConcurrentMap => ConcurrentHashMap + case cls if classOf[java.util.concurrent.ConcurrentMap[_, _]] == cls => + val builderClass = classOf[java.util.concurrent.ConcurrentHashMap[_, _]].getName + s"$builderClass $builderValue = new $builderClass();" + // Java ConcurrentNavigableMap => ConcurrentSkipListMap + case cls if classOf[java.util.concurrent.ConcurrentNavigableMap[_, _]] == cls => + val builderClass = 
classOf[java.util.concurrent.ConcurrentSkipListMap[_, _]].getName + s"$builderClass $builderValue = new $builderClass();" + // Java concrete Map implementation + case cls => + val builderClass = classOf[java.util.Map[_, _]].getName + // Check for constructor with capacity specification + if (Try(cls.getConstructor(Integer.TYPE)).isSuccess) { + s"$builderClass $builderValue = new ${cls.getName}($dataLength);" + } else { + s"$builderClass $builderValue = new ${cls.getName}();" + } + } + + val (appendToBuilder, getBuilderResult) = + if (classOf[scala.collection.Map[_, _]].isAssignableFrom(collClass)) { + val tupleClass = classOf[(_, _)].getName + s""" + $tupleClass $tupleLoopValue; + + if (${genValueFunction.isNull}) { + $tupleLoopValue = new $tupleClass($genKeyFunctionValue, null); + } else { + $tupleLoopValue = new $tupleClass($genKeyFunctionValue, $genValueFunctionValue); + } + + $builderValue.$$plus$$eq($tupleLoopValue); + """ -> s"${ev.value} = (${collClass.getName}) $builderValue.result();" + } else { + s"$builderValue.put($genKeyFunctionValue, $genValueFunctionValue);" -> + s"${ev.value} = (${collClass.getName}) $builderValue;" + } val code = s""" ${genKeyInputData.code} @@ -861,8 +914,7 @@ case class CollectObjectsToMap private( throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays"); } int $dataLength = $getKeyLength; - ${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar; - $builderValue.sizeHint($dataLength); + $constructBuilder int $loopIndex = 0; while ($loopIndex < $dataLength) { @@ -874,24 +926,16 @@ case class CollectObjectsToMap private( ${genKeyFunction.code} ${genValueFunction.code} - $tupleClass $tupleLoopValue; - if (${genKeyFunction.isNull}) { throw new RuntimeException("Found null in map key!"); } - if (${genValueFunction.isNull}) { - $tupleLoopValue = new $tupleClass($genKeyFunctionValue, null); - } else { - $tupleLoopValue = new $tupleClass($genKeyFunctionValue, $genValueFunctionValue); - } - - 
$builderValue.$$plus$$eq($tupleLoopValue); + $appendToBuilder $loopIndex += 1; } - ${ev.value} = (${collClass.getName}) $builderValue.result(); + $getBuilderResult } """ ev.copy(code = code, isNull = genKeyInputData.isNull) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index ff2414b174acb..a1e786da0abba 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -337,6 +337,31 @@ class ScalaReflectionSuite extends SparkFunSuite { MapType(LongType, StringType, valueContainsNull = true)) val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]] assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]])) + + import java.util.{Map => JMap} + + val jmapSerializer = serializerFor[JMap[Int, Int]](BoundReference( + 0, ObjectType(classOf[JMap[Int, Int]]), nullable = false)) + assert(jmapSerializer.dataType.head.dataType == + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val jmapDeserializer = deserializerFor[JMap[Int, Int]] + assert(jmapDeserializer.dataType == ObjectType(classOf[JMap[_, _]])) + + import java.util.{LinkedHashMap => JLHMap} + val jLHMapSerializer = serializerFor[JLHMap[Int, Int]](BoundReference( + 0, ObjectType(classOf[JLHMap[Int, Int]]), nullable = false)) + assert(jLHMapSerializer.dataType.head.dataType == + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val jLHMapDeserializer = deserializerFor[JLHMap[Int, Int]] + assert(jLHMapDeserializer.dataType == ObjectType(classOf[JLHMap[_, _]])) + + import java.util.{TreeMap => JTreeMap} + val jTreeMapSerializer = serializerFor[JTreeMap[Long, String]](BoundReference( + 0, ObjectType(classOf[JTreeMap[Long, String]]), nullable = false)) + assert(jTreeMapSerializer.dataType.head.dataType 
== + MapType(LongType, StringType, valueContainsNull = true)) + val jTreeMapDeserializer = deserializerFor[JTreeMap[Long, String]] + assert(jTreeMapDeserializer.dataType == ObjectType(classOf[JTreeMap[_, _]])) } private val dataTypeForComplexData = dataTypeFor[ComplexData] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index f38a2e766d84b..8b20aced1f641 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -204,6 +204,10 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits { /** @since 2.2.0 */ implicit def newMapEncoder[T <: Map[_, _] : TypeTag]: Encoder[T] = ExpressionEncoder() + /** @since 2.2.0 */ + implicit def newJavaMapEncoder[T <: java.util.Map[_, _] : TypeTag]: Encoder[T] = + ExpressionEncoder() + // Arrays /** @since 1.6.1 */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 8533afb624968..ca5f4e06a9816 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -289,6 +289,55 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { checkDataset(Seq(LHMap(Tuple1(1) -> Tuple1(2))).toDS(), LHMap(Tuple1(1) -> Tuple1(2))) checkDataset(Seq(LHMap(1 -> Tuple1(2))).toDS(), LHMap(1 -> Tuple1(2))) checkDataset(Seq(LHMap("test" -> 2.toLong)).toDS(), LHMap("test" -> 2.toLong)) + + def JMap[K, V](tuples: (K, V)*): java.util.Map[K, V] = { + import scala.collection.JavaConverters._ + scala.collection.mutable.Map(tuples: _*).asJava + } + + checkDataset(Seq(JMap(1 -> 2)).toDS(), JMap(1 -> 2)) + checkDataset(Seq(JMap(1.toLong -> 2.toLong)).toDS(), JMap(1.toLong -> 2.toLong)) + checkDataset(Seq(JMap(1.toDouble -> 2.toDouble)).toDS(), 
JMap(1.toDouble -> 2.toDouble)) + checkDataset(Seq(JMap(1.toFloat -> 2.toFloat)).toDS(), JMap(1.toFloat -> 2.toFloat)) + checkDataset(Seq(JMap(1.toByte -> 2.toByte)).toDS(), JMap(1.toByte -> 2.toByte)) + checkDataset(Seq(JMap(1.toShort -> 2.toShort)).toDS(), JMap(1.toShort -> 2.toShort)) + checkDataset(Seq(JMap(true -> false)).toDS(), JMap(true -> false)) + checkDataset(Seq(JMap("test1" -> "test2")).toDS(), JMap("test1" -> "test2")) + checkDataset(Seq(JMap(Tuple1(1) -> Tuple1(2))).toDS(), JMap(Tuple1(1) -> Tuple1(2))) + checkDataset(Seq(JMap(1 -> Tuple1(2))).toDS(), JMap(1 -> Tuple1(2))) + checkDataset(Seq(JMap("test" -> 2.toLong)).toDS(), JMap("test" -> 2.toLong)) + + def JLHMap[K, V](tuples: (K, V)*): java.util.LinkedHashMap[K, V] = { + new java.util.LinkedHashMap[K, V](JMap(tuples: _*)) + } + + checkDataset(Seq(JLHMap(1 -> 2)).toDS(), JLHMap(1 -> 2)) + checkDataset(Seq(JLHMap(1.toLong -> 2.toLong)).toDS(), JLHMap(1.toLong -> 2.toLong)) + checkDataset(Seq(JLHMap(1.toDouble -> 2.toDouble)).toDS(), JLHMap(1.toDouble -> 2.toDouble)) + checkDataset(Seq(JLHMap(1.toFloat -> 2.toFloat)).toDS(), JLHMap(1.toFloat -> 2.toFloat)) + checkDataset(Seq(JLHMap(1.toByte -> 2.toByte)).toDS(), JLHMap(1.toByte -> 2.toByte)) + checkDataset(Seq(JLHMap(1.toShort -> 2.toShort)).toDS(), JLHMap(1.toShort -> 2.toShort)) + checkDataset(Seq(JLHMap(true -> false)).toDS(), JLHMap(true -> false)) + checkDataset(Seq(JLHMap("test1" -> "test2")).toDS(), JLHMap("test1" -> "test2")) + checkDataset(Seq(JLHMap(Tuple1(1) -> Tuple1(2))).toDS(), JLHMap(Tuple1(1) -> Tuple1(2))) + checkDataset(Seq(JLHMap(1 -> Tuple1(2))).toDS(), JLHMap(1 -> Tuple1(2))) + checkDataset(Seq(JLHMap("test" -> 2.toLong)).toDS(), JLHMap("test" -> 2.toLong)) + + def JTreeMap[K, V](tuples: (K, V)*): java.util.TreeMap[K, V] = { + new java.util.TreeMap[K, V](JMap(tuples: _*)) + } + + checkDataset(Seq(JTreeMap(1 -> 2)).toDS(), JTreeMap(1 -> 2)) + checkDataset(Seq(JTreeMap(1.toLong -> 2.toLong)).toDS(), JTreeMap(1.toLong -> 2.toLong)) 
+ checkDataset(Seq(JTreeMap(1.toDouble -> 2.toDouble)).toDS(), + JTreeMap(1.toDouble -> 2.toDouble)) + checkDataset(Seq(JTreeMap(1.toFloat -> 2.toFloat)).toDS(), JTreeMap(1.toFloat -> 2.toFloat)) + checkDataset(Seq(JTreeMap(1.toByte -> 2.toByte)).toDS(), JTreeMap(1.toByte -> 2.toByte)) + checkDataset(Seq(JTreeMap(1.toShort -> 2.toShort)).toDS(), JTreeMap(1.toShort -> 2.toShort)) + checkDataset(Seq(JTreeMap(true -> false)).toDS(), JTreeMap(true -> false)) + checkDataset(Seq(JTreeMap("test1" -> "test2")).toDS(), JTreeMap("test1" -> "test2")) + checkDataset(Seq(JTreeMap(1 -> Tuple1(2))).toDS(), JTreeMap(1 -> Tuple1(2))) + checkDataset(Seq(JTreeMap("test" -> 2.toLong)).toDS(), JTreeMap("test" -> 2.toLong)) } ignore("SPARK-19104: map and product combinations") { From b65f6ce70d6388dda3e1ae047326be9b55124f81 Mon Sep 17 00:00:00 2001 From: Michal Senkyr Date: Sun, 16 Apr 2017 16:05:42 +0200 Subject: [PATCH 3/8] Fix code style Add shared curId --- .../expressions/objects/objects.scala | 53 ++++++++++--------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 35417db169daa..c5575221d6310 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -665,21 +665,22 @@ object CollectObjectsToMap { * @param collClass The type of the resulting collection. 
*/ def apply( - keyFunction: Expression => Expression, - keyInputData: Expression, - keyElementType: DataType, - valueFunction: Expression => Expression, - valueInputData: Expression, - valueElementType: DataType, - collClass: Class[_]): CollectObjectsToMap = { - val keyLoopValue = "CollectObjectsToMap_loopValue" + curId.getAndIncrement() - val keyLoopIsNull = "CollectObjectsToMap_loopIsNull" + curId.getAndIncrement() + keyFunction: Expression => Expression, + keyInputData: Expression, + keyElementType: DataType, + valueFunction: Expression => Expression, + valueInputData: Expression, + valueElementType: DataType, + collClass: Class[_]): CollectObjectsToMap = { + val id = curId.getAndIncrement() + val keyLoopValue = s"CollectObjectsToMap_keyLoopValue$id" + val keyLoopIsNull = s"CollectObjectsToMap_keyLoopIsNull$id" val keyLoopVar = LambdaVariable(keyLoopValue, keyLoopIsNull, keyElementType) - val valueLoopValue = "CollectObjectsToMap_loopValue" + curId.getAndIncrement() - val valueLoopIsNull = "CollectObjectsToMap_loopIsNull" + curId.getAndIncrement() + val valueLoopValue = s"CollectObjectsToMap_valueLoopValue$id" + val valueLoopIsNull = s"CollectObjectsToMap_valueLoopIsNull$id" val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, valueElementType) - val tupleLoopVar = "CollectObjectsToMap_loopValue" + curId.getAndIncrement() - val builderValue = "CollectObjectsToMap_builderValue" + curId.getAndIncrement() + val tupleLoopVar = s"CollectObjectsToMap_tupleLoopValue$id" + val builderValue = s"CollectObjectsToMap_builderValue$id" CollectObjectsToMap( keyLoopValue, keyLoopIsNull, keyElementType, keyFunction(keyLoopVar), keyInputData, valueLoopValue, valueLoopIsNull, valueElementType, valueFunction(valueLoopVar), @@ -720,19 +721,19 @@ object CollectObjectsToMap { * @param builderValue The name of the builder variable used to construct the resulting collection. 
*/ case class CollectObjectsToMap private( - keyLoopValue: String, - keyLoopIsNull: String, - keyLoopVarDataType: DataType, - keyLambdaFunction: Expression, - keyInputData: Expression, - valueLoopValue: String, - valueLoopIsNull: String, - valueLoopVarDataType: DataType, - valueLambdaFunction: Expression, - valueInputData: Expression, - tupleLoopValue: String, - collClass: Class[_], - builderValue: String) extends Expression with NonSQLExpression { + keyLoopValue: String, + keyLoopIsNull: String, + keyLoopVarDataType: DataType, + keyLambdaFunction: Expression, + keyInputData: Expression, + valueLoopValue: String, + valueLoopIsNull: String, + valueLoopVarDataType: DataType, + valueLambdaFunction: Expression, + valueInputData: Expression, + tupleLoopValue: String, + collClass: Class[_], + builderValue: String) extends Expression with NonSQLExpression { override def nullable: Boolean = keyInputData.nullable From bea90d561ffddac9338aaae0ef582b5bb623b339 Mon Sep 17 00:00:00 2001 From: Michal Senkyr Date: Fri, 19 May 2017 12:05:31 +0200 Subject: [PATCH 4/8] Remove Java Map support --- .../spark/sql/catalyst/ScalaReflection.scala | 6 +-- .../sql/catalyst/ScalaReflectionSuite.scala | 25 ---------- .../org/apache/spark/sql/SQLImplicits.scala | 4 -- .../spark/sql/DatasetPrimitiveSuite.scala | 49 ------------------- 4 files changed, 3 insertions(+), 81 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 56de59eefb8f8..f12fecaaea734 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -329,7 +329,7 @@ object ScalaReflection extends ScalaReflection { } UnresolvedMapObjects(mapFunction, getPath, Some(cls)) - case t if t <:< localTypeOf[Map[_, _]] || t <:< localTypeOf[java.util.Map[_, _]] => + case t if t <:< 
localTypeOf[Map[_, _]] => // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t @@ -484,7 +484,7 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(elementType)) = t toCatalystArray(inputObject, elementType) - case t if t <:< localTypeOf[Map[_, _]] || t <:< localTypeOf[java.util.Map[_, _]] => + case t if t <:< localTypeOf[Map[_, _]] => val TypeRef(_, _, Seq(keyType, valueType)) = t val keyClsName = getClassNameFromType(keyType) val valueClsName = getClassNameFromType(valueType) @@ -698,7 +698,7 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, nullable) = schemaFor(elementType) Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< localTypeOf[Map[_, _]] || t <:< localTypeOf[java.util.Map[_, _]] => + case t if t <:< localTypeOf[Map[_, _]] => val TypeRef(_, _, Seq(keyType, valueType)) = t val Schema(valueDataType, valueNullable) = schemaFor(valueType) Schema(MapType(schemaFor(keyType).dataType, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index a1e786da0abba..ff2414b174acb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -337,31 +337,6 @@ class ScalaReflectionSuite extends SparkFunSuite { MapType(LongType, StringType, valueContainsNull = true)) val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]] assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]])) - - import java.util.{Map => JMap} - - val jmapSerializer = serializerFor[JMap[Int, Int]](BoundReference( - 0, ObjectType(classOf[JMap[Int, Int]]), nullable = false)) - assert(jmapSerializer.dataType.head.dataType == - MapType(IntegerType, IntegerType, 
valueContainsNull = false)) - val jmapDeserializer = deserializerFor[JMap[Int, Int]] - assert(jmapDeserializer.dataType == ObjectType(classOf[JMap[_, _]])) - - import java.util.{LinkedHashMap => JLHMap} - val jLHMapSerializer = serializerFor[JLHMap[Int, Int]](BoundReference( - 0, ObjectType(classOf[JLHMap[Int, Int]]), nullable = false)) - assert(jLHMapSerializer.dataType.head.dataType == - MapType(IntegerType, IntegerType, valueContainsNull = false)) - val jLHMapDeserializer = deserializerFor[JLHMap[Int, Int]] - assert(jLHMapDeserializer.dataType == ObjectType(classOf[JLHMap[_, _]])) - - import java.util.{TreeMap => JTreeMap} - val jTreeMapSerializer = serializerFor[JTreeMap[Long, String]](BoundReference( - 0, ObjectType(classOf[JTreeMap[Long, String]]), nullable = false)) - assert(jTreeMapSerializer.dataType.head.dataType == - MapType(LongType, StringType, valueContainsNull = true)) - val jTreeMapDeserializer = deserializerFor[JTreeMap[Long, String]] - assert(jTreeMapDeserializer.dataType == ObjectType(classOf[JTreeMap[_, _]])) } private val dataTypeForComplexData = dataTypeFor[ComplexData] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 8b20aced1f641..f38a2e766d84b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -204,10 +204,6 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits { /** @since 2.2.0 */ implicit def newMapEncoder[T <: Map[_, _] : TypeTag]: Encoder[T] = ExpressionEncoder() - /** @since 2.2.0 */ - implicit def newJavaMapEncoder[T <: java.util.Map[_, _] : TypeTag]: Encoder[T] = - ExpressionEncoder() - // Arrays /** @since 1.6.1 */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index ca5f4e06a9816..8533afb624968 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -289,55 +289,6 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { checkDataset(Seq(LHMap(Tuple1(1) -> Tuple1(2))).toDS(), LHMap(Tuple1(1) -> Tuple1(2))) checkDataset(Seq(LHMap(1 -> Tuple1(2))).toDS(), LHMap(1 -> Tuple1(2))) checkDataset(Seq(LHMap("test" -> 2.toLong)).toDS(), LHMap("test" -> 2.toLong)) - - def JMap[K, V](tuples: (K, V)*): java.util.Map[K, V] = { - import scala.collection.JavaConverters._ - scala.collection.mutable.Map(tuples: _*).asJava - } - - checkDataset(Seq(JMap(1 -> 2)).toDS(), JMap(1 -> 2)) - checkDataset(Seq(JMap(1.toLong -> 2.toLong)).toDS(), JMap(1.toLong -> 2.toLong)) - checkDataset(Seq(JMap(1.toDouble -> 2.toDouble)).toDS(), JMap(1.toDouble -> 2.toDouble)) - checkDataset(Seq(JMap(1.toFloat -> 2.toFloat)).toDS(), JMap(1.toFloat -> 2.toFloat)) - checkDataset(Seq(JMap(1.toByte -> 2.toByte)).toDS(), JMap(1.toByte -> 2.toByte)) - checkDataset(Seq(JMap(1.toShort -> 2.toShort)).toDS(), JMap(1.toShort -> 2.toShort)) - checkDataset(Seq(JMap(true -> false)).toDS(), JMap(true -> false)) - checkDataset(Seq(JMap("test1" -> "test2")).toDS(), JMap("test1" -> "test2")) - checkDataset(Seq(JMap(Tuple1(1) -> Tuple1(2))).toDS(), JMap(Tuple1(1) -> Tuple1(2))) - checkDataset(Seq(JMap(1 -> Tuple1(2))).toDS(), JMap(1 -> Tuple1(2))) - checkDataset(Seq(JMap("test" -> 2.toLong)).toDS(), JMap("test" -> 2.toLong)) - - def JLHMap[K, V](tuples: (K, V)*): java.util.LinkedHashMap[K, V] = { - new java.util.LinkedHashMap[K, V](JMap(tuples: _*)) - } - - checkDataset(Seq(JLHMap(1 -> 2)).toDS(), JLHMap(1 -> 2)) - checkDataset(Seq(JLHMap(1.toLong -> 2.toLong)).toDS(), JLHMap(1.toLong -> 2.toLong)) - checkDataset(Seq(JLHMap(1.toDouble -> 2.toDouble)).toDS(), JLHMap(1.toDouble -> 2.toDouble)) - checkDataset(Seq(JLHMap(1.toFloat -> 2.toFloat)).toDS(), JLHMap(1.toFloat -> 2.toFloat)) - 
checkDataset(Seq(JLHMap(1.toByte -> 2.toByte)).toDS(), JLHMap(1.toByte -> 2.toByte)) - checkDataset(Seq(JLHMap(1.toShort -> 2.toShort)).toDS(), JLHMap(1.toShort -> 2.toShort)) - checkDataset(Seq(JLHMap(true -> false)).toDS(), JLHMap(true -> false)) - checkDataset(Seq(JLHMap("test1" -> "test2")).toDS(), JLHMap("test1" -> "test2")) - checkDataset(Seq(JLHMap(Tuple1(1) -> Tuple1(2))).toDS(), JLHMap(Tuple1(1) -> Tuple1(2))) - checkDataset(Seq(JLHMap(1 -> Tuple1(2))).toDS(), JLHMap(1 -> Tuple1(2))) - checkDataset(Seq(JLHMap("test" -> 2.toLong)).toDS(), JLHMap("test" -> 2.toLong)) - - def JTreeMap[K, V](tuples: (K, V)*): java.util.TreeMap[K, V] = { - new java.util.TreeMap[K, V](JMap(tuples: _*)) - } - - checkDataset(Seq(JTreeMap(1 -> 2)).toDS(), JTreeMap(1 -> 2)) - checkDataset(Seq(JTreeMap(1.toLong -> 2.toLong)).toDS(), JTreeMap(1.toLong -> 2.toLong)) - checkDataset(Seq(JTreeMap(1.toDouble -> 2.toDouble)).toDS(), - JTreeMap(1.toDouble -> 2.toDouble)) - checkDataset(Seq(JTreeMap(1.toFloat -> 2.toFloat)).toDS(), JTreeMap(1.toFloat -> 2.toFloat)) - checkDataset(Seq(JTreeMap(1.toByte -> 2.toByte)).toDS(), JTreeMap(1.toByte -> 2.toByte)) - checkDataset(Seq(JTreeMap(1.toShort -> 2.toShort)).toDS(), JTreeMap(1.toShort -> 2.toShort)) - checkDataset(Seq(JTreeMap(true -> false)).toDS(), JTreeMap(true -> false)) - checkDataset(Seq(JTreeMap("test1" -> "test2")).toDS(), JTreeMap("test1" -> "test2")) - checkDataset(Seq(JTreeMap(1 -> Tuple1(2))).toDS(), JTreeMap(1 -> Tuple1(2))) - checkDataset(Seq(JTreeMap("test" -> 2.toLong)).toDS(), JTreeMap("test" -> 2.toLong)) } ignore("SPARK-19104: map and product combinations") { From e47abc663eb70ec0247a00d5433eff79bf311046 Mon Sep 17 00:00:00 2001 From: Michal Senkyr Date: Sun, 4 Jun 2017 18:31:30 +0200 Subject: [PATCH 5/8] Set returnNullable to false on map key function Use customer map type for Java Map deserialization --- .../spark/sql/catalyst/ScalaReflection.scala | 3 ++- .../sql/catalyst/expressions/objects/objects.scala | 14 
+++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index e196203b062f0..3edb4080ea135 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -335,7 +335,8 @@ object ScalaReflection extends ScalaReflection { CollectObjectsToMap( p => deserializerFor(keyType, Some(p), walkedTypePath), - Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), + Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType), + returnNullable = false), schemaFor(keyType).dataType, p => deserializerFor(valueType, Some(p), walkedTypePath), Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index e03d56724e230..a419e820d6fc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -857,28 +857,28 @@ case class CollectObjectsToMap private( case cls if classOf[java.util.Map[_, _]] == cls || classOf[java.util.AbstractMap[_, _]] == cls => val builderClass = classOf[java.util.HashMap[_, _]].getName - s"$builderClass $builderValue = new $builderClass($dataLength);" + s"${collClass.getName} $builderValue = new $builderClass($dataLength);" // Java SortedMap, NavigableMap => TreeMap case cls if classOf[java.util.SortedMap[_, _]] == cls || classOf[java.util.NavigableMap[_, _]] == cls => val builderClass = classOf[java.util.TreeMap[_, _]].getName - s"$builderClass $builderValue = new $builderClass();" + 
s"${collClass.getName} $builderValue = new $builderClass();" // Java ConcurrentMap => ConcurrentHashMap case cls if classOf[java.util.concurrent.ConcurrentMap[_, _]] == cls => val builderClass = classOf[java.util.concurrent.ConcurrentHashMap[_, _]].getName - s"$builderClass $builderValue = new $builderClass();" + s"${collClass.getName} $builderValue = new $builderClass();" // Java ConcurrentNavigableMap => ConcurrentSkipListMap case cls if classOf[java.util.concurrent.ConcurrentNavigableMap[_, _]] == cls => val builderClass = classOf[java.util.concurrent.ConcurrentSkipListMap[_, _]].getName - s"$builderClass $builderValue = new $builderClass();" + s"${collClass.getName} $builderValue = new $builderClass();" // Java concrete Map implementation case cls => val builderClass = classOf[java.util.Map[_, _]].getName // Check for constructor with capacity specification if (Try(cls.getConstructor(Integer.TYPE)).isSuccess) { - s"$builderClass $builderValue = new ${cls.getName}($dataLength);" + s"${collClass.getName} $builderValue = new ${cls.getName}($dataLength);" } else { - s"$builderClass $builderValue = new ${cls.getName}();" + s"${collClass.getName} $builderValue = new ${cls.getName}();" } } @@ -898,7 +898,7 @@ case class CollectObjectsToMap private( """ -> s"${ev.value} = (${collClass.getName}) $builderValue.result();" } else { s"$builderValue.put($genKeyFunctionValue, $genValueFunctionValue);" -> - s"${ev.value} = (${collClass.getName}) $builderValue;" + s"${ev.value} = $builderValue;" } val code = s""" From 25ec2f0ca09f63d214e932af29371ebd2f81f840 Mon Sep 17 00:00:00 2001 From: Michal Senkyr Date: Sun, 4 Jun 2017 21:33:03 +0200 Subject: [PATCH 6/8] Remove parameters containing element data types Replace key/value input data with map input data Remove null check for keys Remove unneeded code for sequences --- .../spark/sql/catalyst/ScalaReflection.scala | 8 +- .../expressions/objects/objects.scala | 141 ++++-------------- 2 files changed, 32 insertions(+), 117 
deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 3edb4080ea135..019db906f14fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -335,12 +335,8 @@ object ScalaReflection extends ScalaReflection { CollectObjectsToMap( p => deserializerFor(keyType, Some(p), walkedTypePath), - Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType), - returnNullable = false), - schemaFor(keyType).dataType, p => deserializerFor(valueType, Some(p), walkedTypePath), - Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), - schemaFor(valueType).dataType, + getPath, mirror.runtimeClass(t.typeSymbol.asClass) ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index a419e820d6fc2..d21108b1a06a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -660,34 +660,28 @@ object CollectObjectsToMap { * Construct an instance of CollectObjects case class. 
* * @param keyFunction The function applied on the key collection elements. - * @param keyInputData An expression that when evaluated returns a key collection object. - * @param keyElementType The data type of key elements in the collection. * @param valueFunction The function applied on the value collection elements. - * @param valueInputData An expression that when evaluated returns a value collection object. - * @param valueElementType The data type of value elements in the collection. + * @param inputData An expression that when evaluated returns a map object. * @param collClass The type of the resulting collection. */ def apply( keyFunction: Expression => Expression, - keyInputData: Expression, - keyElementType: DataType, valueFunction: Expression => Expression, - valueInputData: Expression, - valueElementType: DataType, + inputData: Expression, collClass: Class[_]): CollectObjectsToMap = { val id = curId.getAndIncrement() val keyLoopValue = s"CollectObjectsToMap_keyLoopValue$id" - val keyLoopIsNull = s"CollectObjectsToMap_keyLoopIsNull$id" - val keyLoopVar = LambdaVariable(keyLoopValue, keyLoopIsNull, keyElementType) + val mapType = inputData.dataType.asInstanceOf[MapType] + val keyLoopVar = LambdaVariable(keyLoopValue, "", mapType.keyType, nullable = false) val valueLoopValue = s"CollectObjectsToMap_valueLoopValue$id" val valueLoopIsNull = s"CollectObjectsToMap_valueLoopIsNull$id" - val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, valueElementType) + val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType) val tupleLoopVar = s"CollectObjectsToMap_tupleLoopValue$id" val builderValue = s"CollectObjectsToMap_builderValue$id" CollectObjectsToMap( - keyLoopValue, keyLoopIsNull, keyElementType, keyFunction(keyLoopVar), keyInputData, - valueLoopValue, valueLoopIsNull, valueElementType, valueFunction(valueLoopVar), - valueInputData, + keyLoopValue, keyFunction(keyLoopVar), + valueLoopValue, valueLoopIsNull, 
valueFunction(valueLoopVar), + inputData, tupleLoopVar, collClass, builderValue) } } @@ -699,25 +693,16 @@ object CollectObjectsToMap { * * @param keyLoopValue the name of the loop variable that is used when iterating over the key * collection, and which is used as input for the `keyLambdaFunction` - * @param keyLoopIsNull the nullability of the loop variable that is used when iterating over - * the key collection, and which is used as input for the `keyLambdaFunction` - * @param keyLoopVarDataType the data type of the loop variable that is used when iterating over - * the key collection, and which is used as input for the - * `keyLambdaFunction` * @param keyLambdaFunction A function that takes the `keyLoopVar` as input, and is used as * a lambda function to handle collection elements. - * @param keyInputData An expression that when evaluated returns a collection object. * @param valueLoopValue the name of the loop variable that is used when iterating over the value * collection, and which is used as input for the `valueLambdaFunction` * @param valueLoopIsNull the nullability of the loop variable that is used when iterating over * the value collection, and which is used as input for the * `valueLambdaFunction` - * @param valueLoopVarDataType the data type of the loop variable that is used when iterating over - * the value collection, and which is used as input for the - * `valueLambdaFunction` * @param valueLambdaFunction A function that takes the `valueLoopVar` as input, and is used as * a lambda function to handle collection elements. - * @param valueInputData An expression that when evaluated returns a collection object. + * @param inputData An expression that when evaluated returns a map object. * @param tupleLoopValue the name of the loop variable that holds the tuple to be added to the * resulting map (used only for Scala Map) * @param collClass The type of the resulting collection. 
@@ -725,23 +710,19 @@ object CollectObjectsToMap { */ case class CollectObjectsToMap private( keyLoopValue: String, - keyLoopIsNull: String, - keyLoopVarDataType: DataType, keyLambdaFunction: Expression, - keyInputData: Expression, valueLoopValue: String, valueLoopIsNull: String, - valueLoopVarDataType: DataType, valueLambdaFunction: Expression, - valueInputData: Expression, + inputData: Expression, tupleLoopValue: String, collClass: Class[_], builderValue: String) extends Expression with NonSQLExpression { - override def nullable: Boolean = keyInputData.nullable + override def nullable: Boolean = inputData.nullable override def children: Seq[Expression] = - keyLambdaFunction :: keyInputData :: valueLambdaFunction :: valueInputData :: Nil + keyLambdaFunction :: valueLambdaFunction :: inputData :: Nil override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") @@ -749,73 +730,36 @@ case class CollectObjectsToMap private( override def dataType: DataType = ObjectType(collClass) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val keyElementJavaType = ctx.javaType(keyLoopVarDataType) - ctx.addMutableState("boolean", keyLoopIsNull, "") + val mapType = inputData.dataType.asInstanceOf[MapType] + val keyElementJavaType = ctx.javaType(mapType.keyType) ctx.addMutableState(keyElementJavaType, keyLoopValue, "") - val genKeyInputData = keyInputData.genCode(ctx) val genKeyFunction = keyLambdaFunction.genCode(ctx) - val valueElementJavaType = ctx.javaType(valueLoopVarDataType) + val valueElementJavaType = ctx.javaType(mapType.valueType) ctx.addMutableState("boolean", valueLoopIsNull, "") ctx.addMutableState(valueElementJavaType, valueLoopValue, "") - val genValueInputData = valueInputData.genCode(ctx) val genValueFunction = valueLambdaFunction.genCode(ctx) + val genInputData = inputData.genCode(ctx) val dataLength = ctx.freshName("dataLength") val loopIndex = 
ctx.freshName("loopIndex") - // In RowEncoder, we use `Object` to represent Array or Seq, so we need to determine the type - // of input collection at runtime for this case. - val keySeq = ctx.freshName("keySeq") val keyArray = ctx.freshName("keyArray") - val valueSeq = ctx.freshName("valueSeq") val valueArray = ctx.freshName("valueArray") - def determineCollectionType(inputData: Expression, genInputData: ExprCode, - elementJavaType: String, seq: String, array: String) = - inputData.dataType match { - case ObjectType(cls) if cls == classOf[Object] => - val seqClass = classOf[Seq[_]].getName - s""" - $seqClass $seq = null; - $elementJavaType[] $array = null; - if (${genInputData.value}.getClass().isArray()) { - $array = ($elementJavaType[]) ${genInputData.value}; - } else { - $seq = ($seqClass) ${genInputData.value}; - } - """ - case _ => "" - } - val determineKeyCollectionType = determineCollectionType( - keyInputData, genKeyInputData, keyElementJavaType, keySeq, keyArray) - val determineValueCollectionType = determineCollectionType( - valueInputData, genValueInputData, valueElementJavaType, valueSeq, valueArray) // The data with PythonUserDefinedType are actually stored with the data type of its sqlType. // When we want to apply MapObjects on it, we have to use it. 
- def inputDataType(inputData: Expression) = inputData.dataType match { + def inputDataType(dataType: DataType) = dataType match { case p: PythonUserDefinedType => p.sqlType - case _ => inputData.dataType + case _ => dataType } - val keyInputDataType = inputDataType(keyInputData) - val valueInputDataType = inputDataType(valueInputData) - - def lengthAndLoopVar(inputDataType: DataType, genInputData: ExprCode, - seq: String, array: String) = - inputDataType match { - case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => - s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)" - case ObjectType(cls) if cls.isArray => - s"${genInputData.value}.length" -> s"${genInputData.value}[$loopIndex]" - case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => - s"${genInputData.value}.size()" -> s"${genInputData.value}.get($loopIndex)" - case ArrayType(et, _) => - s"${genInputData.value}.numElements()" -> ctx.getValue(genInputData.value, et, loopIndex) - case ObjectType(cls) if cls == classOf[Object] => - s"$seq == null ? $array.length : $seq.size()" -> - s"$seq == null ? 
$array[$loopIndex] : $seq.apply($loopIndex)" - } + + def lengthAndLoopVar(elementType: DataType, genInputData: ExprCode, method: String, + array: String) = + s"${genInputData.value}.$method().numElements()" -> + ctx.getValue(s"${genInputData.value}.$method()", elementType, loopIndex) + val ((getKeyLength, getKeyLoopVar), (getValueLength, getValueLoopVar)) = ( - lengthAndLoopVar(inputDataType(keyInputData), genKeyInputData, keySeq, keyArray), - lengthAndLoopVar(inputDataType(valueInputData), genValueInputData, valueSeq, valueArray) + lengthAndLoopVar(inputDataType(mapType.keyType), genInputData, "keyArray", keyArray), + lengthAndLoopVar(inputDataType(mapType.valueType), genInputData, "valueArray", valueArray) ) // Make a copy of the data if it's unsafe-backed @@ -831,19 +775,8 @@ case class CollectObjectsToMap private( val genKeyFunctionValue = genFunctionValue(keyLambdaFunction, genKeyFunction) val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) - def loopNullCheck(genInputData: ExprCode, inputDataType: DataType, - loopIsNull: String, loopValue: String) = - inputDataType match { - case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" - // The element of primitive array will never be null. 
- case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive => - s"$loopIsNull = false" - case _ => s"$loopIsNull = $loopValue == null;" - } - val keyLoopNullCheck = - loopNullCheck(genKeyInputData, keyInputDataType, keyLoopIsNull, keyLoopValue) val valueLoopNullCheck = - loopNullCheck(genValueInputData, valueInputDataType, valueLoopIsNull, valueLoopValue) + s"$valueLoopIsNull = ${genInputData.value}.valueArray().isNullAt($loopIndex);" val constructBuilder = collClass match { // Scala Map @@ -873,7 +806,6 @@ case class CollectObjectsToMap private( s"${collClass.getName} $builderValue = new $builderClass();" // Java concrete Map implementation case cls => - val builderClass = classOf[java.util.Map[_, _]].getName // Check for constructor with capacity specification if (Try(cls.getConstructor(Integer.TYPE)).isSuccess) { s"${collClass.getName} $builderValue = new ${cls.getName}($dataLength);" @@ -902,18 +834,10 @@ case class CollectObjectsToMap private( } val code = s""" - ${genKeyInputData.code} - ${genValueInputData.code} + ${genInputData.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if ((${genKeyInputData.isNull} && !${genValueInputData.isNull}) || - (!${genKeyInputData.isNull} && ${genValueInputData.isNull})) { - throw new RuntimeException("Invalid state: Inconsistent nullability of key-value"); - } - - if (!${genKeyInputData.isNull}) { - $determineKeyCollectionType - $determineValueCollectionType + if (!${genInputData.isNull}) { if ($getKeyLength != $getValueLength) { throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays"); } @@ -924,16 +848,11 @@ case class CollectObjectsToMap private( while ($loopIndex < $dataLength) { $keyLoopValue = ($keyElementJavaType) ($getKeyLoopVar); $valueLoopValue = ($valueElementJavaType) ($getValueLoopVar); - $keyLoopNullCheck $valueLoopNullCheck ${genKeyFunction.code} ${genValueFunction.code} - if (${genKeyFunction.isNull}) { - throw new 
RuntimeException("Found null in map key!"); - } - $appendToBuilder $loopIndex += 1; @@ -942,7 +861,7 @@ case class CollectObjectsToMap private( $getBuilderResult } """ - ev.copy(code = code, isNull = genKeyInputData.isNull) + ev.copy(code = code, isNull = genInputData.isNull) } } From dbdcb9c70fb2a2d503cc0c15a9f168b886c33a50 Mon Sep 17 00:00:00 2001 From: Michal Senkyr Date: Sat, 10 Jun 2017 09:37:28 +0200 Subject: [PATCH 7/8] Remove Java Map specific code Bump version in scaladoc Minor alterations based on code review --- .../expressions/objects/objects.scala | 82 +++++-------------- .../org/apache/spark/sql/SQLImplicits.scala | 2 +- 2 files changed, 21 insertions(+), 63 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index d21108b1a06a7..a3f32efd23459 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -22,7 +22,6 @@ import java.lang.reflect.Modifier import scala.collection.mutable.Builder import scala.language.existentials import scala.reflect.ClassTag -import scala.util.Try import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ @@ -657,7 +656,7 @@ object CollectObjectsToMap { private val curId = new java.util.concurrent.atomic.AtomicInteger() /** - * Construct an instance of CollectObjects case class. + * Construct an instance of CollectObjectsToMap case class. * * @param keyFunction The function applied on the key collection elements. * @param valueFunction The function applied on the value collection elements. 
@@ -676,13 +675,10 @@ object CollectObjectsToMap { val valueLoopValue = s"CollectObjectsToMap_valueLoopValue$id" val valueLoopIsNull = s"CollectObjectsToMap_valueLoopIsNull$id" val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType) - val tupleLoopVar = s"CollectObjectsToMap_tupleLoopValue$id" - val builderValue = s"CollectObjectsToMap_builderValue$id" CollectObjectsToMap( keyLoopValue, keyFunction(keyLoopVar), valueLoopValue, valueLoopIsNull, valueFunction(valueLoopVar), - inputData, - tupleLoopVar, collClass, builderValue) + inputData, collClass) } } @@ -703,10 +699,7 @@ object CollectObjectsToMap { * @param valueLambdaFunction A function that takes the `valueLoopVar` as input, and is used as * a lambda function to handle collection elements. * @param inputData An expression that when evaluated returns a map object. - * @param tupleLoopValue the name of the loop variable that holds the tuple to be added to the - * resulting map (used only for Scala Map) * @param collClass The type of the resulting collection. - * @param builderValue The name of the builder variable used to construct the resulting collection. 
*/ case class CollectObjectsToMap private( keyLoopValue: String, @@ -715,9 +708,7 @@ case class CollectObjectsToMap private( valueLoopIsNull: String, valueLambdaFunction: Expression, inputData: Expression, - tupleLoopValue: String, - collClass: Class[_], - builderValue: String) extends Expression with NonSQLExpression { + collClass: Class[_]) extends Expression with NonSQLExpression { override def nullable: Boolean = inputData.nullable @@ -741,6 +732,8 @@ case class CollectObjectsToMap private( val genInputData = inputData.genCode(ctx) val dataLength = ctx.freshName("dataLength") val loopIndex = ctx.freshName("loopIndex") + val tupleLoopValue = ctx.freshName("tupleLoopValue") + val builderValue = ctx.freshName("builderValue") val keyArray = ctx.freshName("keyArray") val valueArray = ctx.freshName("valueArray") @@ -778,61 +771,26 @@ case class CollectObjectsToMap private( val valueLoopNullCheck = s"$valueLoopIsNull = ${genInputData.value}.valueArray().isNullAt($loopIndex);" - val constructBuilder = collClass match { - // Scala Map - case cls if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => - val builderClass = classOf[Builder[_, _]].getName - s""" - $builderClass $builderValue = ${collClass.getName}$$.MODULE$$.newBuilder(); - $builderValue.sizeHint($dataLength); - """ - // Java Map, AbstractMap => HashMap - case cls if classOf[java.util.Map[_, _]] == cls || - classOf[java.util.AbstractMap[_, _]] == cls => - val builderClass = classOf[java.util.HashMap[_, _]].getName - s"${collClass.getName} $builderValue = new $builderClass($dataLength);" - // Java SortedMap, NavigableMap => TreeMap - case cls if classOf[java.util.SortedMap[_, _]] == cls || - classOf[java.util.NavigableMap[_, _]] == cls => - val builderClass = classOf[java.util.TreeMap[_, _]].getName - s"${collClass.getName} $builderValue = new $builderClass();" - // Java ConcurrentMap => ConcurrentHashMap - case cls if classOf[java.util.concurrent.ConcurrentMap[_, _]] == cls => - val builderClass = 
classOf[java.util.concurrent.ConcurrentHashMap[_, _]].getName - s"${collClass.getName} $builderValue = new $builderClass();" - // Java ConcurrentNavigableMap => ConcurrentSkipListMap - case cls if classOf[java.util.concurrent.ConcurrentNavigableMap[_, _]] == cls => - val builderClass = classOf[java.util.concurrent.ConcurrentSkipListMap[_, _]].getName - s"${collClass.getName} $builderValue = new $builderClass();" - // Java concrete Map implementation - case cls => - // Check for constructor with capacity specification - if (Try(cls.getConstructor(Integer.TYPE)).isSuccess) { - s"${collClass.getName} $builderValue = new ${cls.getName}($dataLength);" - } else { - s"${collClass.getName} $builderValue = new ${cls.getName}();" - } - } - - val (appendToBuilder, getBuilderResult) = - if (classOf[scala.collection.Map[_, _]].isAssignableFrom(collClass)) { - val tupleClass = classOf[(_, _)].getName - s""" - $tupleClass $tupleLoopValue; + val builderClass = classOf[Builder[_, _]].getName + val constructBuilder = s""" + $builderClass $builderValue = ${collClass.getName}$$.MODULE$$.newBuilder(); + $builderValue.sizeHint($dataLength); + """ - if (${genValueFunction.isNull}) { - $tupleLoopValue = new $tupleClass($genKeyFunctionValue, null); - } else { - $tupleLoopValue = new $tupleClass($genKeyFunctionValue, $genValueFunctionValue); - } + val tupleClass = classOf[(_, _)].getName + val appendToBuilder = s""" + $tupleClass $tupleLoopValue; - $builderValue.$$plus$$eq($tupleLoopValue); - """ -> s"${ev.value} = (${collClass.getName}) $builderValue.result();" + if (${genValueFunction.isNull}) { + $tupleLoopValue = new $tupleClass($genKeyFunctionValue, null); } else { - s"$builderValue.put($genKeyFunctionValue, $genValueFunctionValue);" -> - s"${ev.value} = $builderValue;" + $tupleLoopValue = new $tupleClass($genKeyFunctionValue, $genValueFunctionValue); } + $builderValue.$$plus$$eq($tupleLoopValue); + """ + val getBuilderResult = s"${ev.value} = (${collClass.getName}) 
$builderValue.result();" + val code = s""" ${genInputData.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 0bd77f065a6c5..86574e2f71d92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -168,7 +168,7 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits { implicit def newSequenceEncoder[T <: Seq[_] : TypeTag]: Encoder[T] = ExpressionEncoder() // Maps - /** @since 2.2.0 */ + /** @since 2.3.0 */ implicit def newMapEncoder[T <: Map[_, _] : TypeTag]: Encoder[T] = ExpressionEncoder() // Arrays From e37e0cac53810d6c4e694302fb08d2219bc4bccb Mon Sep 17 00:00:00 2001 From: Michal Senkyr Date: Sat, 10 Jun 2017 23:38:07 +0200 Subject: [PATCH 8/8] Store key/value arrays in local vars Use size of map instead of length of key/value arrays Add Python UDT resolution to map type Add nested map tests Update scaladoc --- .../expressions/objects/objects.scala | 50 +++++++++---------- .../spark/sql/DatasetPrimitiveSuite.scala | 5 ++ 2 files changed, 28 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index a3f32efd23459..79b7b9f3d0e16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import 
org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ /** @@ -683,8 +683,8 @@ object CollectObjectsToMap { } /** - * An equivalent to the [[MapObjects]] case class but returning an ObjectType containing - * a Scala collection constructed using the associated builder, obtained by calling `newBuilder` + * Expression used to convert a Catalyst Map to an external Scala Map. + * The collection is constructed using the associated builder, obtained by calling `newBuilder` * on the collection's companion object. * * @param keyLoopValue the name of the loop variable that is used when iterating over the key @@ -721,7 +721,14 @@ case class CollectObjectsToMap private( override def dataType: DataType = ObjectType(collClass) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val mapType = inputData.dataType.asInstanceOf[MapType] + // Data of a PythonUserDefinedType is actually stored using the data type of its sqlType, + // so resolve the UDT to that underlying type before casting the input to MapType. + def inputDataType(dataType: DataType) = dataType match { + case p: PythonUserDefinedType => p.sqlType + case _ => dataType + } + + val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType] val keyElementJavaType = ctx.javaType(mapType.keyType) ctx.addMutableState(keyElementJavaType, keyLoopValue, "") val genKeyFunction = keyLambdaFunction.genCode(ctx) @@ -735,25 +742,16 @@ case class CollectObjectsToMap private( val tupleLoopValue = ctx.freshName("tupleLoopValue") val builderValue = ctx.freshName("builderValue") + val getLength = s"${genInputData.value}.numElements()" + val keyArray = ctx.freshName("keyArray") val valueArray = ctx.freshName("valueArray") - - // The data with PythonUserDefinedType are actually stored with the data type of its sqlType. - // When we want to apply MapObjects on it, we have to use it.
- def inputDataType(dataType: DataType) = dataType match { - case p: PythonUserDefinedType => p.sqlType - case _ => dataType - } - - def lengthAndLoopVar(elementType: DataType, genInputData: ExprCode, method: String, - array: String) = - s"${genInputData.value}.$method().numElements()" -> - ctx.getValue(s"${genInputData.value}.$method()", elementType, loopIndex) - - val ((getKeyLength, getKeyLoopVar), (getValueLength, getValueLoopVar)) = ( - lengthAndLoopVar(inputDataType(mapType.keyType), genInputData, "keyArray", keyArray), - lengthAndLoopVar(inputDataType(mapType.valueType), genInputData, "valueArray", valueArray) - ) + val getKeyArray = + s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();" + val getKeyLoopVar = ctx.getValue(keyArray, inputDataType(mapType.keyType), loopIndex) + val getValueArray = + s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();" + val getValueLoopVar = ctx.getValue(valueArray, inputDataType(mapType.valueType), loopIndex) // Make a copy of the data if it's unsafe-backed def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = @@ -768,8 +766,7 @@ case class CollectObjectsToMap private( val genKeyFunctionValue = genFunctionValue(keyLambdaFunction, genKeyFunction) val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) - val valueLoopNullCheck = - s"$valueLoopIsNull = ${genInputData.value}.valueArray().isNullAt($loopIndex);" + val valueLoopNullCheck = s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" val builderClass = classOf[Builder[_, _]].getName val constructBuilder = s""" @@ -796,11 +793,10 @@ case class CollectObjectsToMap private( ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${genInputData.isNull}) { - if ($getKeyLength != $getValueLength) { - throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays"); - } - int $dataLength = $getKeyLength; + int $dataLength = $getLength; 
$constructBuilder + $getKeyArray + $getValueArray int $loopIndex = 0; while ($loopIndex < $dataLength) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index beefe135f19a5..4126660b5d102 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -344,6 +344,11 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { checkDataset(Seq(List(Queue(1))).toDS(), List(Queue(1))) } + test("nested maps") { + checkDataset(Seq(Map(1 -> LHMap(2 -> 3))).toDS(), Map(1 -> LHMap(2 -> 3))) + checkDataset(Seq(LHMap(Map(1 -> 2) -> 3)).toDS(), LHMap(Map(1 -> 2) -> 3)) + } + test("package objects") { import packageobject._ checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))