From f1bc30df4e2625b52f19b10b5bb0622d85012ea9 Mon Sep 17 00:00:00 2001 From: dsolow Date: Thu, 18 Mar 2021 19:41:48 -0400 Subject: [PATCH 1/5] Add AtomicInteger to make var names unique --- .../scala/org/apache/spark/sql/functions.scala | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 2c4b81c5df908..bdbb0c5cdb2e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.util.concurrent.atomic.AtomicInteger + import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.Try @@ -3799,23 +3801,26 @@ object functions { ArrayExcept(col1.expr, col2.expr) } + // counter to ensure lambdra variable names unique + private val lambdaVarNameCounter = new AtomicInteger(0) + private def createLambda(f: Column => Column) = { - val x = UnresolvedNamedLambdaVariable(Seq("x")) + val x = UnresolvedNamedLambdaVariable(Seq("x_" + lambdaVarNameCounter.incrementAndGet())) val function = f(Column(x)).expr LambdaFunction(function, Seq(x)) } private def createLambda(f: (Column, Column) => Column) = { - val x = UnresolvedNamedLambdaVariable(Seq("x")) - val y = UnresolvedNamedLambdaVariable(Seq("y")) + val x = UnresolvedNamedLambdaVariable(Seq("x_" + lambdaVarNameCounter.incrementAndGet())) + val y = UnresolvedNamedLambdaVariable(Seq("y_" + lambdaVarNameCounter.incrementAndGet())) val function = f(Column(x), Column(y)).expr LambdaFunction(function, Seq(x, y)) } private def createLambda(f: (Column, Column, Column) => Column) = { - val x = UnresolvedNamedLambdaVariable(Seq("x")) - val y = UnresolvedNamedLambdaVariable(Seq("y")) - val z = UnresolvedNamedLambdaVariable(Seq("z")) + val x = UnresolvedNamedLambdaVariable(Seq("x_" + lambdaVarNameCounter.incrementAndGet())) + val y = UnresolvedNamedLambdaVariable(Seq("y_" + lambdaVarNameCounter.incrementAndGet())) + val z = UnresolvedNamedLambdaVariable(Seq("z_" + lambdaVarNameCounter.incrementAndGet())) val function = f(Column(x), Column(y), Column(z)).expr LambdaFunction(function, Seq(x, y, z)) } From 2cd874ab1a6e97188b9e54a19f6658410312ed45 Mon Sep 17 00:00:00 2001 From: dsolow Date: Thu, 18 Mar 2021 23:23:07 -0400 Subject: [PATCH 2/5] added test to verify nested transform --- .../spark/sql/DataFrameFunctionsSuite.scala | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 70dc0d09bcad5..3642804f5a269 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2261,6 +2261,32 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { assert(ex3.getMessage.contains("cannot resolve 'a'")) } + test("nested transform (DSL)") { + val df = Seq( + (Seq(1, 2, 3), Seq("a", "b", "c")) + ).toDF("numbers", "letters") + + checkAnswer( + df.select( + flatten( + transform( + $"numbers", + (number: Column) => transform( + $"letters", + (letter: Column) => struct( + number.as("number"), + letter.as("letter") + ) + ) + ) + ).as("zipped") + ), + Seq(Row(Seq(Row(1, "a"), Row(1, "b"), Row(1, "c"), Row(2, "a"), Row(2, "b"), + Row(2, "c"), Row(3, "a"), Row(3, "b"), Row(3, "c") + ))) + ) + } + test("map_filter") { val dfInts = Seq( Map(1 -> 10, 2 -> 20, 3 -> 30), From e9a398bb87d0c23a4cf70a76898b8f4801ba8566 Mon Sep 17 00:00:00 2001 From: dmsolow Date: Fri, 19 Mar 2021 08:27:21 -0400 Subject: [PATCH 3/5] fixed typo --- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index bdbb0c5cdb2e9..cb3824ff67810 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3801,7 +3801,7 @@ object functions { ArrayExcept(col1.expr, col2.expr) } - // counter to ensure lambdra variable names unique + // counter to ensure lambda variable names are unique private val lambdaVarNameCounter = new AtomicInteger(0) private def createLambda(f: Column => Column) = { From 0164e0f9fccfc1854eb48c940c7e43cc4c1567d2 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 3 May 2021 23:53:06 +0900 Subject: [PATCH 4/5] Fix --- .../expressions/higherOrderFunctions.scala | 12 ++++- .../org/apache/spark/sql/functions.scala | 17 +++---- .../spark/sql/DataFrameFunctionsSuite.scala | 49 +++++++++---------- 3 files changed, 40 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index a0f9dc2f58b20..6920e58d4e39f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator -import java.util.concurrent.atomic.AtomicReference +import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} import scala.collection.mutable @@ -54,6 +54,16 @@ case class UnresolvedNamedLambdaVariable(nameParts: Seq[String]) override def sql: String = name } +object UnresolvedNamedLambdaVariable { + + // Counter to ensure lambda variable names are unique + private val nextVarNameId = new AtomicInteger(0) + + def freshVarName(name: String): String = { + s"${name}_${nextVarNameId.getAndIncrement()}" + } +} + /** * A named lambda variable. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index cb3824ff67810..bd222d150d46a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import java.util.concurrent.atomic.AtomicInteger - import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.Try @@ -3801,26 +3799,23 @@ object functions { ArrayExcept(col1.expr, col2.expr) } - // counter to ensure lambda variable names are unique - private val lambdaVarNameCounter = new AtomicInteger(0) - private def createLambda(f: Column => Column) = { - val x = UnresolvedNamedLambdaVariable(Seq("x_" + lambdaVarNameCounter.incrementAndGet())) + val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) val function = f(Column(x)).expr LambdaFunction(function, Seq(x)) } private def createLambda(f: (Column, Column) => Column) = { - val x = UnresolvedNamedLambdaVariable(Seq("x_" + lambdaVarNameCounter.incrementAndGet())) - val y = UnresolvedNamedLambdaVariable(Seq("y_" + lambdaVarNameCounter.incrementAndGet())) + val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) + val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) val function = f(Column(x), Column(y)).expr LambdaFunction(function, Seq(x, y)) } private def createLambda(f: (Column, Column, Column) => Column) = { - val x = UnresolvedNamedLambdaVariable(Seq("x_" + lambdaVarNameCounter.incrementAndGet())) - val y = UnresolvedNamedLambdaVariable(Seq("y_" + lambdaVarNameCounter.incrementAndGet())) - val z = UnresolvedNamedLambdaVariable(Seq("z_" + lambdaVarNameCounter.incrementAndGet())) + val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) + val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) + val z = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("z"))) val function = f(Column(x), Column(y), Column(z)).expr LambdaFunction(function, Seq(x, y, z)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 3642804f5a269..9aeafde381479 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2261,32 +2261,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { assert(ex3.getMessage.contains("cannot resolve 'a'")) } - test("nested transform (DSL)") { - val df = Seq( - (Seq(1, 2, 3), Seq("a", "b", "c")) - ).toDF("numbers", "letters") - - checkAnswer( - df.select( - flatten( - transform( - $"numbers", - (number: Column) => transform( - $"letters", - (letter: Column) => struct( - number.as("number"), - letter.as("letter") - ) - ) - ) - ).as("zipped") - ), - Seq(Row(Seq(Row(1, "a"), Row(1, "b"), Row(1, "c"), Row(2, "a"), Row(2, "b"), - Row(2, "c"), Row(3, "a"), Row(3, "b"), Row(3, "c") - ))) - ) - } - test("map_filter") { val dfInts = Seq( Map(1 -> 10, 2 -> 20, 3 -> 30), @@ -3655,6 +3629,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { df.select(map(map_entries($"m"), lit(1))), Row(Map(Seq(Row(1, "a")) -> 1))) } + + test("SPARK-34794: lambda variable name issues in nested functions") { + val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("numbers", "letters") + + checkAnswer(df1.select(flatten(transform($"numbers", (number: Column) => + transform($"letters", (letter: Column) => + struct(number, letter))))), + Seq(Row(Seq(Row(1, "a"), Row(1, "b"), Row(2, "a"), Row(2, "b")))) + ) + checkAnswer(df1.select(flatten(transform($"numbers", (number: Column, i: Column) => + transform($"letters", (letter: Column, j: Column) => + struct(number + j, concat(letter, i)))))), + Seq(Row(Seq(Row(1, "a0"), Row(2, "b0"), Row(2, "a1"), Row(3, "b1")))) + ) + + val df2 = Seq((Map("a" -> 1, "b" -> 2), Map("a" -> 2, "b" -> 3))).toDF("m1", "m2") + + checkAnswer(df2.select(map_zip_with($"m1", $"m2", (k1: Column, ov1: Column, ov2: Column) => + map_zip_with($"m1", $"m2", (k2: Column, iv1: Column, iv2: Column) => + ov1 + iv1 + ov2 + iv2))), + Seq(Row(Map("a" -> Map("a" -> 6, "b" -> 8), "b" -> Map("a" -> 8, "b" -> 10)))) + ) + } } object DataFrameFunctionsSuite { From ea961eeb32da606ddcedda9a22b0dff6d281587b Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 5 May 2021 09:14:28 +0900 Subject: [PATCH 5/5] review --- .../spark/sql/DataFrameFunctionsSuite.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 9aeafde381479..eb71f5966a9c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -3634,22 +3634,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("numbers", "letters") checkAnswer(df1.select(flatten(transform($"numbers", (number: Column) => - transform($"letters", (letter: Column) => - struct(number, letter))))), - Seq(Row(Seq(Row(1, "a"), Row(1, "b"), Row(2, "a"), Row(2, "b")))) + transform($"letters", (letter: Column) => + struct(number, letter))))), + Seq(Row(Seq(Row(1, "a"), Row(1, "b"), Row(2, "a"), Row(2, "b")))) ) checkAnswer(df1.select(flatten(transform($"numbers", (number: Column, i: Column) => - transform($"letters", (letter: Column, j: Column) => - struct(number + j, concat(letter, i)))))), - Seq(Row(Seq(Row(1, "a0"), Row(2, "b0"), Row(2, "a1"), Row(3, "b1")))) + transform($"letters", (letter: Column, j: Column) => + struct(number + j, concat(letter, i)))))), + Seq(Row(Seq(Row(1, "a0"), Row(2, "b0"), Row(2, "a1"), Row(3, "b1")))) ) val df2 = Seq((Map("a" -> 1, "b" -> 2), Map("a" -> 2, "b" -> 3))).toDF("m1", "m2") checkAnswer(df2.select(map_zip_with($"m1", $"m2", (k1: Column, ov1: Column, ov2: Column) => - map_zip_with($"m1", $"m2", (k2: Column, iv1: Column, iv2: Column) => - ov1 + iv1 + ov2 + iv2))), - Seq(Row(Map("a" -> Map("a" -> 6, "b" -> 8), "b" -> Map("a" -> 8, "b" -> 10)))) + map_zip_with($"m1", $"m2", (k2: Column, iv1: Column, iv2: Column) => + ov1 + iv1 + ov2 + iv2))), + Seq(Row(Map("a" -> Map("a" -> 6, "b" -> 8), "b" -> Map("a" -> 8, "b" -> 10)))) ) } }