From f684c4b1b43411a3ee38eb5b82b11a41c8ccc93d Mon Sep 17 00:00:00 2001 From: Nikola Mandic Date: Wed, 3 Apr 2024 16:03:25 +0200 Subject: [PATCH 1/6] Add collations support to split regex expression --- .../sql/catalyst/util/CollationSupport.java | 32 ++++++++++++++++ .../apache/spark/unsafe/types/UTF8String.java | 22 +++++++++-- .../expressions/regexpExpressions.scala | 18 ++++++--- .../sql/CollationRegexpExpressionsSuite.scala | 37 +++++++++++-------- 4 files changed, 84 insertions(+), 25 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index fe1952921b7fb..bea46653c7896 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.catalyst.util; +import java.util.regex.Pattern; + import com.ibm.icu.text.StringSearch; import org.apache.spark.unsafe.types.UTF8String; @@ -143,6 +145,36 @@ public static boolean execICU(final UTF8String l, final UTF8String r, * Collation-aware regexp expressions. */ + public static class StringSplit { + public static UTF8String[] exec(final UTF8String l, final UTF8String r, final int limit, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(l, r, limit); + } else { + return execLowercase(l, r, limit); + } + } + public static String genCode(final String l, final String r, final String limit, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.StringSplit.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s, %s, %s)", l, r, limit); + } else { + return String.format(expr + "Lowercase(%s, %s, %s)", l, r, limit); + } + } + public static UTF8String[] execBinary(final UTF8String l, final UTF8String r, + final int limit) { + return l.split(r, limit); + } + public static UTF8String[] execLowercase(final UTF8String l, final UTF8String r, + final int limit) { + return l.split(r, limit, Pattern.UNICODE_CASE | Pattern.CASE_INSENSITIVE); + } + } + // TODO: Add more collation-aware regexp expressions. /** diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 2009f1d20442c..9fbdb0cfaf0c7 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1028,7 +1028,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { return fromBytes(result); } - public UTF8String[] split(UTF8String pattern, int limit) { + public UTF8String[] split(UTF8String pattern, int limit, int regexFlags) { // For the empty `pattern` a `split` function ignores trailing empty strings unless original // string is empty. if (numBytes() != 0 && pattern.numBytes() == 0) { @@ -1044,7 +1044,11 @@ public UTF8String[] split(UTF8String pattern, int limit) { } return result; } - return split(pattern.toString(), limit); + return split(pattern.toString(), limit, regexFlags); + } + + public UTF8String[] split(UTF8String pattern, int limit) { + return split(pattern, limit, 0); // Pattern without regex flags } public UTF8String[] splitSQL(UTF8String delimiter, int limit) { @@ -1061,14 +1065,20 @@ public UTF8String[] splitSQL(UTF8String delimiter, int limit) { } } - private UTF8String[] split(String delimiter, int limit) { + private UTF8String[] split(String delimiter, int limit, int regexFlags) { // Java String's split method supports "ignore empty string" behavior when the limit is 0 // whereas other languages do not. To avoid this java specific behavior, we fall back to // -1 when the limit is 0. if (limit == 0) { limit = -1; } - String[] splits = toString().split(delimiter, limit); + String[] splits; + if (regexFlags == 0) { + // Pattern without regex flags + splits = toString().split(delimiter, limit); + } else { + splits = Pattern.compile(delimiter, regexFlags).split(toString(), limit); + } UTF8String[] res = new UTF8String[splits.length]; for (int i = 0; i < res.length; i++) { res[i] = fromString(splits[i]); @@ -1076,6 +1086,10 @@ private UTF8String[] split(String delimiter, int limit) { return res; } + private UTF8String[] split(String delimiter, int limit) { + return split(delimiter, limit, 0); + } + public UTF8String replace(UTF8String search, UTF8String replace) { // This implementation is loosely based on commons-lang3's StringUtils.replace(). if (numBytes == 0 || search.numBytes == 0) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index b33de303b5d55..958f9424403c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -33,8 +33,9 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern} -import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} +import org.apache.spark.sql.catalyst.util.{CollationSupport, GenericArrayData, StringUtils} import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.internal.types.{StringTypeAnyCollation, StringTypeBinaryLcase} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -543,17 +544,21 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress case class StringSplit(str: Expression, regex: Expression, limit: Expression) extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { - override def dataType: DataType = ArrayType(StringType, containsNull = false) - override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) + override def dataType: DataType = ArrayType(str.dataType, containsNull = false) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeBinaryLcase, StringTypeAnyCollation, IntegerType) + override def first: Expression = str override def second: Expression = regex override def third: Expression = limit + final lazy val collationId: Int = str.dataType.asInstanceOf[StringType].collationId + def this(exp: Expression, regex: Expression) = this(exp, regex, Literal(-1)) override def nullSafeEval(string: Any, regex: Any, limit: Any): Any = { - val strings = string.asInstanceOf[UTF8String].split( - regex.asInstanceOf[UTF8String], limit.asInstanceOf[Int]) + val strings = CollationSupport.StringSplit.exec(string.asInstanceOf[UTF8String], + regex.asInstanceOf[UTF8String], limit.asInstanceOf[Int], collationId) new GenericArrayData(strings.asInstanceOf[Array[Any]]) } @@ -561,7 +566,8 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, (str, regex, limit) => { // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. - s"""${ev.value} = new $arrayClass($str.split($regex,$limit));""".stripMargin + val genCode = CollationSupport.StringSplit.genCode(str, regex, limit, collationId) + s"""${ev.value} = new $arrayClass($genCode);""".stripMargin }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala index 0876425847bbb..799352ecf0b78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -115,27 +115,34 @@ class CollationRegexpExpressionsSuite } test("Support StringSplit string expression with collation") { - // Supported collations - case class StringSplitTestCase[R](l: String, r: String, c: String, result: R) + case class StringSplitTestCase[R](l: String, r: String, c: String, result: R, limit: Int = -1) val testCases = Seq( - StringSplitTestCase("ABC", "[B]", "UTF8_BINARY", Seq("A", "C")) + StringSplitTestCase("ABC", "[B]", "UTF8_BINARY", Seq("A", "C")), + StringSplitTestCase("ABC", "[b]", "UTF8_BINARY", Seq("ABC")), + StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("A", "C")), + StringSplitTestCase("AAA", "[a]", "UTF8_BINARY_LCASE", Seq("", "", "", "")), + StringSplitTestCase("AAA", "[b]", "UTF8_BINARY_LCASE", Seq("AAA")), + StringSplitTestCase("aAbB", "[ab]", "UTF8_BINARY_LCASE", Seq("", "", "", "", "")), + StringSplitTestCase("", "", "UTF8_BINARY_LCASE", Seq("")), + StringSplitTestCase("", "[a]", "UTF8_BINARY_LCASE", Seq("")), + StringSplitTestCase("xAxBxaxbx", "[AB]", "UTF8_BINARY_LCASE", Seq("x", "x", "x", "x", "x")), + StringSplitTestCase("ABC", "", "UTF8_BINARY_LCASE", Seq("A", "B", "C")), + // test split with limit + StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("ABC"), 1), + StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("A", "C"), 2), + StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("A", "C"), 3), + StringSplitTestCase("ABC", "[B]", "UNICODE", Seq("A", "C")), + StringSplitTestCase("ABC", "[b]", "UNICODE", Seq("ABC")) ) testCases.foreach(t => { - val query = s"SELECT split(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" - // Result & data type - checkAnswer(sql(query), Row(t.result)) + val query = s"SELECT split(collate('${t.l}', '${t.c}'), '${t.r}', ${t.limit})" + checkAnswer(sql(query), Seq(Row(t.result))) assert(sql(query).schema.fields.head.dataType.sameType(ArrayType(StringType(t.c)))) - // TODO: Implicit casting (not currently supported) }) - // Unsupported collations - case class StringSplitTestFail(l: String, r: String, c: String) - val failCases = Seq( - StringSplitTestFail("ABC", "[b]", "UTF8_BINARY_LCASE"), - StringSplitTestFail("ABC", "[B]", "UNICODE"), - StringSplitTestFail("ABC", "[b]", "UNICODE_CI") - ) + case class StringSplitFail(l: String, r: String, c: String) + val failCases = Seq(StringSplitFail("ABC", "[b]", "UNICODE_CI")) failCases.foreach(t => { - val query = s"SELECT split(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))" + val query = s"SELECT split(collate('${t.l}', '${t.c}'), '${t.r}')" val unsupportedCollation = intercept[AnalysisException] { sql(query) } assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") }) From 5b27a259fa7150bfe0fe9d5261738dc6875654ac Mon Sep 17 00:00:00 2001 From: Nikola Mandic Date: Mon, 15 Apr 2024 11:32:51 +0200 Subject: [PATCH 2/6] Fixes --- .../java/org/apache/spark/unsafe/types/UTF8String.java | 2 +- .../sql/catalyst/expressions/regexpExpressions.scala | 1 - .../spark/sql/CollationRegexpExpressionsSuite.scala | 8 ++++++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 9fbdb0cfaf0c7..caf64e2a23325 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1087,7 +1087,7 @@ private UTF8String[] split(String delimiter, int limit, int regexFlags) { } private UTF8String[] split(String delimiter, int limit) { - return split(delimiter, limit, 0); + return split(delimiter, limit, 0); // Pattern without regex flags } public UTF8String replace(UTF8String search, UTF8String replace) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 958f9424403c2..40a6b03f5e1d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -547,7 +547,6 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) override def dataType: DataType = ArrayType(str.dataType, containsNull = false) override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeBinaryLcase, StringTypeAnyCollation, IntegerType) - override def first: Expression = str override def second: Expression = regex override def third: Expression = limit diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala index 799352ecf0b78..1897eca0cc242 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -115,6 +115,7 @@ class CollationRegexpExpressionsSuite } test("Support StringSplit string expression with collation") { + // Supported collations case class StringSplitTestCase[R](l: String, r: String, c: String, result: R, limit: Int = -1) val testCases = Seq( StringSplitTestCase("ABC", "[B]", "UTF8_BINARY", Seq("A", "C")), @@ -136,11 +137,14 @@ class CollationRegexpExpressionsSuite ) testCases.foreach(t => { val query = s"SELECT split(collate('${t.l}', '${t.c}'), '${t.r}', ${t.limit})" + // Result & data type checkAnswer(sql(query), Seq(Row(t.result))) assert(sql(query).schema.fields.head.dataType.sameType(ArrayType(StringType(t.c)))) + // TODO: Implicit casting (not currently supported) }) - case class StringSplitFail(l: String, r: String, c: String) - val failCases = Seq(StringSplitFail("ABC", "[b]", "UNICODE_CI")) + // Unsupported collations + case class StringSplitTestFail(l: String, r: String, c: String) + val failCases = Seq(StringSplitTestFail("ABC", "[b]", "UNICODE_CI")) failCases.foreach(t => { val query = s"SELECT split(collate('${t.l}', '${t.c}'), '${t.r}')" val unsupportedCollation = intercept[AnalysisException] { sql(query) } From 2bed330d457e22979f65790a228e497c4098a21d Mon Sep 17 00:00:00 2001 From: Nikola Mandic Date: Mon, 15 Apr 2024 11:34:32 +0200 Subject: [PATCH 3/6] Fixes --- .../org/apache/spark/sql/CollationRegexpExpressionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala index 1897eca0cc242..cf3d1f11ce4ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -138,7 +138,7 @@ class CollationRegexpExpressionsSuite testCases.foreach(t => { val query = s"SELECT split(collate('${t.l}', '${t.c}'), '${t.r}', ${t.limit})" // Result & data type - checkAnswer(sql(query), Seq(Row(t.result))) + checkAnswer(sql(query), Row(t.result)) assert(sql(query).schema.fields.head.dataType.sameType(ArrayType(StringType(t.c)))) // TODO: Implicit casting (not currently supported) }) From ecb5e309aceda0ea3646a5602d7832359979cf67 Mon Sep 17 00:00:00 2001 From: Nikola Mandic Date: Mon, 15 Apr 2024 14:23:52 +0200 Subject: [PATCH 4/6] Refactor --- .../sql/catalyst/util/CollationSupport.java | 29 ++++++++++--------- .../apache/spark/unsafe/types/UTF8String.java | 22 +++----------- .../expressions/regexpExpressions.scala | 5 ++-- .../sql/CollationRegexpExpressionsSuite.scala | 2 -- 4 files changed, 22 insertions(+), 36 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index bea46653c7896..69a2132d4f515 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -16,8 +16,6 @@ */ package org.apache.spark.sql.catalyst.util; -import java.util.regex.Pattern; - import com.ibm.icu.text.StringSearch; import org.apache.spark.unsafe.types.UTF8String; @@ -146,32 +144,37 @@ public static boolean execICU(final UTF8String l, final UTF8String r, */ public static class StringSplit { - public static UTF8String[] exec(final UTF8String l, final UTF8String r, final int limit, - final int collationId) { + public static UTF8String[] exec(final UTF8String string, final UTF8String regex, + final int limit, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); if (collation.supportsBinaryEquality) { - return execBinary(l, r, limit); + return execBinary(string, regex, limit); } else { - return execLowercase(l, r, limit); + return execLowercase(string, regex, limit); } } - public static String genCode(final String l, final String r, final String limit, + public static String genCode(final String string, final String regex, final String limit, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringSplit.exec"; if (collation.supportsBinaryEquality) { - return String.format(expr + "Binary(%s, %s, %s)", l, r, limit); + return String.format(expr + "Binary(%s, %s, %s)", string, regex, limit); } else { - return String.format(expr + "Lowercase(%s, %s, %s)", l, r, limit); + return String.format(expr + "Lowercase(%s, %s, %s)", string, regex, limit); } } - public static UTF8String[] execBinary(final UTF8String l, final UTF8String r, + public static UTF8String[] execBinary(final UTF8String string, final UTF8String regex, final int limit) { - return l.split(r, limit); + return string.split(regex, limit); } - public static UTF8String[] execLowercase(final UTF8String l, final UTF8String r, + public static UTF8String[] execLowercase(final UTF8String string, final UTF8String regex, final int limit) { - return l.split(r, limit, Pattern.UNICODE_CASE | Pattern.CASE_INSENSITIVE); + if (string.numBytes() != 0 && regex.numBytes() == 0) { + return string.split(regex, limit); + } else { + // ui flags toggle unicode case-insensitive matching + return string.split(UTF8String.fromString("(?ui)" + regex.toString()), limit); + } } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index caf64e2a23325..2009f1d20442c 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1028,7 +1028,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { return fromBytes(result); } - public UTF8String[] split(UTF8String pattern, int limit, int regexFlags) { + public UTF8String[] split(UTF8String pattern, int limit) { // For the empty `pattern` a `split` function ignores trailing empty strings unless original // string is empty. if (numBytes() != 0 && pattern.numBytes() == 0) { @@ -1044,11 +1044,7 @@ public UTF8String[] split(UTF8String pattern, int limit, int regexFlags) { } return result; } - return split(pattern.toString(), limit, regexFlags); - } - - public UTF8String[] split(UTF8String pattern, int limit) { - return split(pattern, limit, 0); // Pattern without regex flags + return split(pattern.toString(), limit); } public UTF8String[] splitSQL(UTF8String delimiter, int limit) { @@ -1065,20 +1061,14 @@ public UTF8String[] splitSQL(UTF8String delimiter, int limit) { } } - private UTF8String[] split(String delimiter, int limit, int regexFlags) { + private UTF8String[] split(String delimiter, int limit) { // Java String's split method supports "ignore empty string" behavior when the limit is 0 // whereas other languages do not. To avoid this java specific behavior, we fall back to // -1 when the limit is 0. if (limit == 0) { limit = -1; } - String[] splits; - if (regexFlags == 0) { - // Pattern without regex flags - splits = toString().split(delimiter, limit); - } else { - splits = Pattern.compile(delimiter, regexFlags).split(toString(), limit); - } + String[] splits = toString().split(delimiter, limit); UTF8String[] res = new UTF8String[splits.length]; for (int i = 0; i < res.length; i++) { res[i] = fromString(splits[i]); @@ -1086,10 +1076,6 @@ private UTF8String[] split(String delimiter, int limit, int regexFlags) { return res; } - private UTF8String[] split(String delimiter, int limit) { - return split(delimiter, limit, 0); // Pattern without regex flags - } - public UTF8String replace(UTF8String search, UTF8String replace) { // This implementation is loosely based on commons-lang3's StringUtils.replace(). if (numBytes == 0 || search.numBytes == 0) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 40a6b03f5e1d6..a91cc74276df9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -563,10 +563,9 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayClass = classOf[GenericArrayData].getName - nullSafeCodeGen(ctx, ev, (str, regex, limit) => { + defineCodeGen(ctx, ev, (str, regex, limit) => { // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. - val genCode = CollationSupport.StringSplit.genCode(str, regex, limit, collationId) - s"""${ev.value} = new $arrayClass($genCode);""".stripMargin + s"new $arrayClass(${CollationSupport.StringSplit.genCode(str, regex, limit, collationId)})" }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala index cf3d1f11ce4ce..87bc3b5ea180e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -140,7 +140,6 @@ class CollationRegexpExpressionsSuite // Result & data type checkAnswer(sql(query), Row(t.result)) assert(sql(query).schema.fields.head.dataType.sameType(ArrayType(StringType(t.c)))) - // TODO: Implicit casting (not currently supported) }) // Unsupported collations case class StringSplitTestFail(l: String, r: String, c: String) @@ -150,7 +149,6 @@ class CollationRegexpExpressionsSuite val unsupportedCollation = intercept[AnalysisException] { sql(query) } assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") }) - // TODO: Collation mismatch (not currently supported) } test("Support RegExpReplace string expression with collation") { From 7f337fdb55cc3287e13a69a9ca5b79447c0a7aff Mon Sep 17 00:00:00 2001 From: Nikola Mandic Date: Mon, 15 Apr 2024 16:31:32 +0200 Subject: [PATCH 5/6] Optimize regex allocation --- .../spark/sql/catalyst/util/CollationSupport.java | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index 69a2132d4f515..c3fc2835ccb21 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -172,8 +172,7 @@ public static UTF8String[] execLowercase(final UTF8String string, final UTF8Stri if (string.numBytes() != 0 && regex.numBytes() == 0) { return string.split(regex, limit); } else { - // ui flags toggle unicode case-insensitive matching - return string.split(UTF8String.fromString("(?ui)" + regex.toString()), limit); + return string.split(CollationAwareUTF8String.getLowercaseRegex(regex), limit); } } } @@ -204,6 +203,13 @@ private static boolean matchAt(final UTF8String target, final UTF8String pattern pos, pos + pattern.numChars()), pattern, collationId).last() == 0; } + // ui flags toggle unicode case-insensitive matching + private static final UTF8String lowercaseRegexPrefix = UTF8String.fromString("(?ui)"); + + private static UTF8String getLowercaseRegex(UTF8String regex) { + return UTF8String.concat(lowercaseRegexPrefix, regex); + } + } } From d336245bcb2f9de1a81863117a60739263fd108c Mon Sep 17 00:00:00 2001 From: Nikola Mandic Date: Mon, 15 Apr 2024 21:21:41 +0200 Subject: [PATCH 6/6] Updates --- .../sql/catalyst/util/CollationSupport.java | 1 + .../unsafe/types/CollationSupportSuite.java | 78 +++++++++++++++++++ .../sql/CollationRegexpExpressionsSuite.scala | 14 +--- 3 files changed, 82 insertions(+), 11 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index c3fc2835ccb21..a09ccf7e23d93 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -150,6 +150,7 @@ public static UTF8String[] exec(final UTF8String string, final UTF8String regex, if (collation.supportsBinaryEquality) { return execBinary(string, regex, limit); } else { + assert(collation.supportsLowercaseEquality); return execLowercase(string, regex, limit); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index bfb696c35fff6..a390b8410e314 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -16,6 +16,8 @@ */ package org.apache.spark.unsafe.types; +import java.util.Arrays; + import org.apache.spark.SparkException; import org.apache.spark.sql.catalyst.util.CollationFactory; import org.apache.spark.sql.catalyst.util.CollationSupport; @@ -255,6 +257,82 @@ public void testEndsWith() throws SparkException { * Collation-aware regexp expressions. */ + @Test + public void testStringSplit() throws SparkException { + // binary equality + assertStringSplit("ABC", "[B]", "UTF8_BINARY", new String[]{"A", "C"}); + assertStringSplit("ABC", "[b]", "UTF8_BINARY", new String[]{"ABC"}); + assertStringSplit("aaaa", "", "UTF8_BINARY", new String[]{"a", "a", "a", "a"}); + assertStringSplit("aaaa", "[a-z]", "UTF8_BINARY", new String[]{"", "", "", "", ""}); + assertStringSplit("aaaa", "[0-9]", "UTF8_BINARY", new String[]{"aaaa"}); + assertStringSplit("a1b2", "[a-z0-9]", "UTF8_BINARY", new String[]{"", "", "", "", ""}); + assertStringSplit("ABC", "[B]", "UNICODE", new String[]{"A", "C"}); + assertStringSplit("ABC", "[b]", "UNICODE", new String[]{"ABC"}); + assertStringSplit("aaaa", "", "UNICODE", new String[]{"a", "a", "a", "a"}); + assertStringSplit("aaaa", "[a-z]", "UNICODE", new String[]{"", "", "", "", ""}); + assertStringSplit("aaaa", "[0-9]", "UNICODE", new String[]{"aaaa"}); + assertStringSplit("a1b2", "[a-z0-9]", "UNICODE", new String[]{"", "", "", "", ""}); + // non-binary equality (lowercase) + assertStringSplit("ABC", "[B]", "UTF8_BINARY_LCASE", new String[]{"A", "C"}); + assertStringSplit("ABC", "[b]", "UTF8_BINARY_LCASE", new String[]{"A", "C"}); + assertStringSplit("aaaa", "", "UTF8_BINARY_LCASE", new String[]{"a", "a", "a", "a"}); + assertStringSplit("aaaa", "[a-z]", "UTF8_BINARY_LCASE", new String[]{"", "", "", "", ""}); + assertStringSplit("aaaa", "[0-9]", "UTF8_BINARY_LCASE", new String[]{"aaaa"}); + assertStringSplit("a1b2", "[a-z0-9]", "UTF8_BINARY_LCASE", new String[]{"", "", "", "", ""}); + assertStringSplit("AAA", "[a]", "UTF8_BINARY_LCASE", new String[]{"", "", "", ""}); + assertStringSplit("AAA", "[b]", "UTF8_BINARY_LCASE", new String[]{"AAA"}); + assertStringSplit("aAbB", "[ab]", "UTF8_BINARY_LCASE",new String[]{"", "", "", "", ""}); + assertStringSplit("", "", "UTF8_BINARY_LCASE", new String[]{""}); + assertStringSplit("", "[a]", "UTF8_BINARY_LCASE", new String[]{""}); + assertStringSplit("xAxBxaxbx", "[AB]", "UTF8_BINARY_LCASE", + new String[]{"x", "x", "x", "x", "x"}); + assertStringSplit("ABC", "", "UTF8_BINARY_LCASE", new String[]{"A", "B", "C"}); + // special characters + assertStringSplit("ä", "", "UTF8_BINARY", new String[]{"ä"}); + assertStringSplit("ääää", "", "UTF8_BINARY", new String[]{"ä", "ä", "ä", "ä"}); + assertStringSplit("äbćδ", "", "UTF8_BINARY", new String[]{"ä", "b", "ć", "δ"}); + assertStringSplit("äbćδ", "[äbćδ]", "UTF8_BINARY", new String[]{"", "", "", "", ""}); + assertStringSplit("ä", "", "UTF8_BINARY_LCASE", new String[]{"ä"}); + assertStringSplit("ääää", "", "UTF8_BINARY_LCASE", new String[]{"ä", "ä", "ä", "ä"}); + assertStringSplit("äbćδ", "", "UTF8_BINARY_LCASE", new String[]{"ä", "b", "ć", "δ"}); + assertStringSplit("äbćδ", "[äbćδ]", "UTF8_BINARY_LCASE", new String[]{"", "", "", "", ""}); + assertStringSplit("äbćδ", "[ÄBĆΔ]", "UTF8_BINARY_LCASE", new String[]{"", "", "", "", ""}); + assertStringSplit("äbćδ", "[äBćΔ]", "UTF8_BINARY_LCASE", new String[]{"", "", "", "", ""}); + assertStringSplit("ääää", "Ä", "UTF8_BINARY_LCASE", new String[]{"", "", "", "", ""}); + assertStringSplit("AäBÄCä", "Ä", "UTF8_BINARY_LCASE", new String[]{"A", "B", "C", ""}); + assertStringSplit("AäBÄCäD", "Ä", "UTF8_BINARY_LCASE", new String[]{"A", "B", "C", "D"}); + assertStringSplit("ä", "", "UNICODE", new String[]{"ä"}); + assertStringSplit("ääää", "", "UNICODE", new String[]{"ä", "ä", "ä", "ä"}); + assertStringSplit("äbćδ", "", "UNICODE", new String[]{"ä", "b", "ć", "δ"}); + assertStringSplit("äbćδ", "[äbćδ]", "UNICODE", new String[]{"", "", "", "", ""}); + // set limit + assertStringSplit("ABC", "[B]", 0, "UTF8_BINARY", new String[]{"A", "C"}); + assertStringSplit("ABC", "[B]", 1, "UTF8_BINARY", new String[]{"ABC"}); + assertStringSplit("ABC", "[B]", 2, "UTF8_BINARY", new String[]{"A", "C"}); + assertStringSplit("ABC", "[B]", 3, "UTF8_BINARY", new String[]{"A", "C"}); + assertStringSplit("ABC", "[b]", 0, "UTF8_BINARY_LCASE", new String[]{"A", "C"}); + assertStringSplit("ABC", "[b]", 1, "UTF8_BINARY_LCASE", new String[]{"ABC"}); + assertStringSplit("ABC", "[b]", 2, "UTF8_BINARY_LCASE", new String[]{"A", "C"}); + assertStringSplit("ABC", "[b]", 3, "UTF8_BINARY_LCASE", new String[]{"A", "C"}); + assertStringSplit("ABC", "[B]", 0, "UNICODE", new String[]{"A", "C"}); + assertStringSplit("ABC", "[B]", 1, "UNICODE", new String[]{"ABC"}); + assertStringSplit("ABC", "[B]", 2, "UNICODE", new String[]{"A", "C"}); + assertStringSplit("ABC", "[B]", 3, "UNICODE", new String[]{"A", "C"}); + } + + private void assertStringSplit(String string, String regex, int limit, String collationName, + String[] value) throws SparkException { + UTF8String[] result = CollationSupport.StringSplit.exec(UTF8String.fromString(string), + UTF8String.fromString(regex), limit, CollationFactory.collationNameToId(collationName)); + String[] actual = Arrays.stream(result).map(UTF8String::toString).toArray(String[]::new); + assertArrayEquals(value, actual); + } + + private void assertStringSplit(String string, String regex, String collationName, + String[] value) throws SparkException { + assertStringSplit(string, regex, -1, collationName, value); + } + // TODO: Test more collation-aware regexp expressions. /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala index 87bc3b5ea180e..0a5a5055e6fa3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala @@ -122,18 +122,10 @@ class CollationRegexpExpressionsSuite StringSplitTestCase("ABC", "[b]", "UTF8_BINARY", Seq("ABC")), StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("A", "C")), StringSplitTestCase("AAA", "[a]", "UTF8_BINARY_LCASE", Seq("", "", "", "")), - StringSplitTestCase("AAA", "[b]", "UTF8_BINARY_LCASE", Seq("AAA")), - StringSplitTestCase("aAbB", "[ab]", "UTF8_BINARY_LCASE", Seq("", "", "", "", "")), - StringSplitTestCase("", "", "UTF8_BINARY_LCASE", Seq("")), - StringSplitTestCase("", "[a]", "UTF8_BINARY_LCASE", Seq("")), - StringSplitTestCase("xAxBxaxbx", "[AB]", "UTF8_BINARY_LCASE", Seq("x", "x", "x", "x", "x")), - StringSplitTestCase("ABC", "", "UTF8_BINARY_LCASE", Seq("A", "B", "C")), - // test split with limit - StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("ABC"), 1), - StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("A", "C"), 2), - StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("A", "C"), 3), StringSplitTestCase("ABC", "[B]", "UNICODE", Seq("A", "C")), - StringSplitTestCase("ABC", "[b]", "UNICODE", Seq("ABC")) + StringSplitTestCase("ABC", "[b]", "UTF8_BINARY", Seq("ABC"), 1), + StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("ABC"), 1), + StringSplitTestCase("ABC", "[b]", "UTF8_BINARY_LCASE", Seq("A", "C"), 2) ) testCases.foreach(t => { val query = s"SELECT split(collate('${t.l}', '${t.c}'), '${t.r}', ${t.limit})"