From 5c534727b0d72015104c242e369d7edc5b0fe910 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 7 Nov 2016 11:34:28 +0900 Subject: [PATCH 1/2] Make to_json expression/function null safe --- .../sql/catalyst/expressions/jsonExpressions.scala | 14 +++++--------- .../expressions/JsonExpressionsSuite.scala | 13 +++++++++++-- .../org/apache/spark/sql/JsonFunctionsSuite.scala | 12 ++++++++++++ 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 89fe7c48c000e..b61583d0dafb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -484,7 +484,7 @@ case class JsonTuple(children: Seq[Expression]) * Converts an json input string to a [[StructType]] with the specified schema. */ case class JsonToStruct(schema: StructType, options: Map[String, String], child: Expression) - extends Expression with CodegenFallback with ExpectsInputTypes { + extends UnaryExpression with CodegenFallback with ExpectsInputTypes { override def nullable: Boolean = true @transient @@ -495,11 +495,8 @@ case class JsonToStruct(schema: StructType, options: Map[String, String], child: new JSONOptions(options ++ Map("mode" -> ParseModes.FAIL_FAST_MODE))) override def dataType: DataType = schema - override def children: Seq[Expression] = child :: Nil - override def eval(input: InternalRow): Any = { - val json = child.eval(input) - if (json == null) return null + override def nullSafeEval(json: Any): Any = { try parser.parse(json.toString).head catch { case _: SparkSQLJsonProcessingException => null } @@ -512,7 +509,7 @@ case class JsonToStruct(schema: StructType, options: Map[String, String], child: * Converts a [[StructType]] to a json output string. */ case class StructToJson(options: Map[String, String], child: Expression) - extends Expression with CodegenFallback with ExpectsInputTypes { + extends UnaryExpression with CodegenFallback with ExpectsInputTypes { override def nullable: Boolean = true @transient @@ -523,7 +520,6 @@ case class StructToJson(options: Map[String, String], child: Expression) new JacksonGenerator(child.dataType.asInstanceOf[StructType], writer) override def dataType: DataType = StringType - override def children: Seq[Expression] = child :: Nil override def checkInputDataTypes(): TypeCheckResult = { if (StructType.acceptsType(child.dataType)) { @@ -540,8 +536,8 @@ case class StructToJson(options: Map[String, String], child: Expression) } } - override def eval(input: InternalRow): Any = { - gen.write(child.eval(input).asInstanceOf[InternalRow]) + override def nullSafeEval(row: Any): Any = { + gen.write(row.asInstanceOf[InternalRow]) gen.flush() val json = writer.toString writer.reset() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 3bfa0bfda6209..3b0e90824b766 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.ParseModes -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -347,7 +347,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("from_json null input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStruct(schema, Map.empty, Literal(null)), + JsonToStruct(schema, Map.empty, Literal.create(null, StringType)), null ) } @@ -360,4 +360,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { """{"a":1}""" ) } + + test("to_json null input column") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + val struct = Literal.create(null, schema) + checkEvaluation( + StructToJson(Map.empty, struct), + null + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 59ae889cf3b92..ba048c4f4558b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -141,4 +141,16 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { assert(e.getMessage.contains( "Unable to convert column a of type calendarinterval to JSON.")) } + + test("to_json and from_json roundtrip") { + val dfOne = Seq(Some(Tuple1(Tuple1(1))), None).toDF("a") + val readBackOne = dfOne.select(to_json($"a").as("b")) + .select(from_json($"b", dfOne.schema.head.dataType.asInstanceOf[StructType])) + checkAnswer(dfOne, readBackOne) + + val dfTwo = Seq(Some("""{"a":1}"""), None).toDF("value") + val schema = new StructType().add("a", IntegerType) + val readBackTwo = dfTwo.select(from_json($"value", schema).as("b")).select(to_json($"b")) + checkAnswer(dfTwo, readBackTwo) + } } From ce0eddae4ee03002642c60cd21cc858ab4ae12a2 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 7 Nov 2016 11:53:51 +0900 Subject: [PATCH 2/2] Clean up the test --- .../apache/spark/sql/JsonFunctionsSuite.scala | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index ba048c4f4558b..7d63d31d9b979 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -142,15 +142,17 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { "Unable to convert column a of type calendarinterval to JSON.")) } - test("to_json and from_json roundtrip") { - val dfOne = Seq(Some(Tuple1(Tuple1(1))), None).toDF("a") - val readBackOne = dfOne.select(to_json($"a").as("b")) - .select(from_json($"b", dfOne.schema.head.dataType.asInstanceOf[StructType])) + test("roundtrip in to_json and from_json") { + val dfOne = Seq(Some(Tuple1(Tuple1(1))), None).toDF("struct") + val schemaOne = dfOne.schema(0).dataType.asInstanceOf[StructType] + val readBackOne = dfOne.select(to_json($"struct").as("json")) + .select(from_json($"json", schemaOne).as("struct")) checkAnswer(dfOne, readBackOne) - val dfTwo = Seq(Some("""{"a":1}"""), None).toDF("value") - val schema = new StructType().add("a", IntegerType) - val readBackTwo = dfTwo.select(from_json($"value", schema).as("b")).select(to_json($"b")) + val dfTwo = Seq(Some("""{"a":1}"""), None).toDF("json") + val schemaTwo = new StructType().add("a", IntegerType) + val readBackTwo = dfTwo.select(from_json($"json", schemaTwo).as("struct")) + .select(to_json($"struct").as("json")) checkAnswer(dfTwo, readBackTwo) } }