diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 119508a37e717..8533e107db831 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -175,12 +175,19 @@ public Collation( * Auxiliary methods for collation aware string operations. */ + /** + * Creates an instance of ICU's StringSearch with provided parameters. + * @param targetUTF8String UTF8String representation of the string to be searched. + * @param patternUTF8String UTF8String representation of the string to search for. + * @param collationId ID of the collation to use. + * @return Created instance of StringSearch. + */ public static StringSearch getStringSearch( - final UTF8String left, - final UTF8String right, + final UTF8String targetUTF8String, + final UTF8String patternUTF8String, final int collationId) { - String pattern = right.toString(); - CharacterIterator target = new StringCharacterIterator(left.toString()); + String pattern = patternUTF8String.toString(); + CharacterIterator target = new StringCharacterIterator(targetUTF8String.toString()); Collator collator = CollationFactory.fetchCollation(collationId).collator; return new StringSearch(pattern, target, (RuleBasedCollator) collator); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index c5dfb91f06c63..42000c07aad43 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -585,6 +585,23 @@ public UTF8String trim() { return copyUTF8String(s, e); } + /** + * Trims space characters from both ends of this string - same as {@link UTF8String#trim()}. + * This variant of the method additionally applies provided collation to this string + * and space character before searching. + * + * @param collationId Id of the collation to use. + * @return this string with no spaces at the start or end. + */ + public UTF8String trim(int collationId) { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality + || CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID == collationId) { + return trim(); + } else { + return trim(UTF8String.fromString(" "), collationId); + } + } + /** * Trims whitespace ASCII characters from both ends of this string. * @@ -628,6 +645,27 @@ public UTF8String trim(UTF8String trimString) { } } + /** + * Trims characters of the given trim string from both ends of this string. + * This variant of the method additionally applies provided collation to this string + * and trim characters before searching. + * + * @param trimString The trim characters string. + * @param collationId Id of the collation to use. + * @return this string with no occurrences of the characters from trim string. + */ + public UTF8String trim(UTF8String trimString, int collationId) { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + return trim(trimString); + } + + if (CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID == collationId) { + return lowercaseTrimLeft(trimString).lowercaseTrimRight(trimString); + } + + return trimLeft(trimString, collationId).trimRight(trimString, collationId); + } + /** * Trims space characters (ASCII 32) from the start of this string. * @@ -648,6 +686,23 @@ public UTF8String trimLeft() { return copyUTF8String(s, this.numBytes - 1); } + /** + * Trims space characters from the start of this string - same as {@link UTF8String#trimLeft()}. + * This variant of the method additionally applies provided collation to this string + * and space character before searching. + * + * @param collationId Id of the collation to use. + * @return this string with no spaces at the start. + */ + public UTF8String trimLeft(int collationId) { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality + || CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID == collationId) { + return trimLeft(); + } else { + return trimLeft(UTF8String.fromString(" "), collationId); + } + } + /** * Trims instances of the given trim string from the start of this string. * @@ -686,6 +741,109 @@ public UTF8String trimLeft(UTF8String trimString) { return copyUTF8String(trimIdx, numBytes - 1); } + /** + * Trims characters of the given trim string from the start of this string. + * This variant of the method additionally applies provided collation to this string + * and trim characters before searching. + * + * @param trimString The trim characters string. + * @param collationId Id of the collation to use. + * @return this string with no occurrences of the trim characters at the start. + */ + public UTF8String trimLeft(UTF8String trimString, int collationId) { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + return trimLeft(trimString); + } + + if (CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID == collationId) { + return lowercaseTrimLeft(trimString); + } + + return collationAwareTrimLeft(trimString, collationId); + } + + private UTF8String lowercaseTrimLeft(UTF8String trimString) { + if (trimString == null) { + return null; + } + + // The searching byte position in the lowercase source string + int searchIdx = 0; + // The byte position of a first non-matching character in the lowercase source string + int trimByteIdx = 0; + + // Convert trimString to lowercase so it can be searched properly + trimString = trimString.toLowerCase(); + + while (searchIdx < numBytes) { + UTF8String searchChar = copyUTF8String( + searchIdx, + searchIdx + numBytesForFirstByte(getByte(searchIdx)) - 1); + int searchCharBytes = searchChar.numBytes; + + // Try to find the matching for the lowercase searchChar in the trimString + if (trimString.find(searchChar.toLowerCase(), 0) >= 0) { + trimByteIdx += searchCharBytes; + searchIdx += searchCharBytes; + } else { + // No matching, exit the search + break; + } + } + + if (searchIdx == 0) { + // Nothing trimmed - return original string (not converted to lowercase) + return this; + } + if (trimByteIdx >= numBytes) { + // Everything trimmed + return EMPTY_UTF8; + } + return copyUTF8String(trimByteIdx, numBytes - 1); + } + + private UTF8String collationAwareTrimLeft(UTF8String trimString, int collationId) { + if (trimString == null) { + return null; + } + + // The searching byte position in the source string + int searchIdx = 0; + // The byte position of a first non-matching character in the source string + int trimByteIdx = 0; + + while (searchIdx < numBytes) { + UTF8String searchChar = copyUTF8String( + searchIdx, + searchIdx + numBytesForFirstByte(getByte(searchIdx)) - 1); + int searchCharBytes = searchChar.numBytes; + + // Try to find the matching for the searchChar in the trimString + StringSearch stringSearch = CollationFactory.getStringSearch( + trimString, searchChar, collationId); + int searchCharIdx = stringSearch.next(); + + if (searchCharIdx != StringSearch.DONE + && stringSearch.getMatchLength() == stringSearch.getPattern().length()) { + trimByteIdx += searchCharBytes; + searchIdx += searchCharBytes; + } else { + // No matching, exit the search + break; + } + } + + if (searchIdx == 0) { + // Nothing trimmed - return original string (not converted to lowercase) + return this; + } + if (trimByteIdx >= numBytes) { + // Everything trimmed + return EMPTY_UTF8; + } + return copyUTF8String(trimByteIdx, numBytes - 1); + } + /** * Trims space characters (ASCII 32) from the end of this string. * @@ -706,6 +864,23 @@ public UTF8String trimRight() { return copyUTF8String(0, e); } + /** + * Trims space characters from the end of this string - same as {@link UTF8String#trimRight()}. + * This variant of the method additionally applies provided collation to this string + * and space character before searching. + * + * @param collationId Id of the collation to use. + * @return this string with no spaces at the end. + */ + public UTF8String trimRight(int collationId) { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality + || CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID == collationId) { + return trimRight(); + } else { + return trimRight(UTF8String.fromString(" "), collationId); + } + } + /** * Trims at most `numSpaces` space characters (ASCII 32) from the end of this string. */ @@ -767,6 +942,137 @@ public UTF8String trimRight(UTF8String trimString) { return copyUTF8String(0, trimEnd); } + /** + * Trims characters of the given trim string from the end of this string. + * This variant of the method additionally applies provided collation to this string + * and trim characters before searching. + * + * @param trimString The trim characters string. + * @param collationId Id of the collation to use. + * @return this string with no occurrences of the trim characters at the end. + */ + public UTF8String trimRight(UTF8String trimString, int collationId) { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + return trimRight(trimString); + } + + if (CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID == collationId) { + return lowercaseTrimRight(trimString); + } + + return collationAwareTrimRight(trimString, collationId); + } + + private UTF8String lowercaseTrimRight(UTF8String trimString) { + if (trimString == null) { + return null; + } + + // Convert trimString to lowercase so it can be searched properly + trimString = trimString.toLowerCase(); + + // Number of bytes iterated from the source string + int byteIdx = 0; + // Number of characters iterated from the source string + int numChars = 0; + // Array of character length for the source string + int[] stringCharLen = new int[numBytes]; + // Array of the first byte position for each character in the source string + int[] stringCharPos = new int[numBytes]; + + // Build the position and length array + while (byteIdx < numBytes) { + stringCharPos[numChars] = byteIdx; + stringCharLen[numChars] = numBytesForFirstByte(getByte(byteIdx)); + byteIdx += stringCharLen[numChars]; + numChars++; + } + + // Index trimEnd points to the first no matching byte position from the right side of + // the source string. + int trimByteIdx = numBytes - 1; + + while (numChars > 0) { + UTF8String searchChar = copyUTF8String( + stringCharPos[numChars - 1], + stringCharPos[numChars - 1] + stringCharLen[numChars - 1] - 1); + + // Try to find the matching for the lowercase searchChar in the trimString + if (trimString.find(searchChar.toLowerCase(), 0) >= 0) { + trimByteIdx -= stringCharLen[numChars - 1]; + numChars--; + } else { + break; + } + } + + if (trimByteIdx == numBytes - 1) { + // Nothing trimmed + return this; + } + if (trimByteIdx < 0) { + // Everything trimmed + return EMPTY_UTF8; + } + return copyUTF8String(0, trimByteIdx); + } + + private UTF8String collationAwareTrimRight(UTF8String trimString, int collationId) { + if (trimString == null) { + return null; + } + + // Number of bytes iterated from the source string + int byteIdx = 0; + // Number of characters iterated from the source string + int numChars = 0; + // Array of character length for the source string + int[] stringCharLen = new int[numBytes]; + // Array of the first byte position for each character in the source string + int[] stringCharPos = new int[numBytes]; + + // Build the position and length array + while (byteIdx < numBytes) { + stringCharPos[numChars] = byteIdx; + stringCharLen[numChars] = numBytesForFirstByte(getByte(byteIdx)); + byteIdx += stringCharLen[numChars]; + numChars++; + } + + // Index trimEnd points to the first no matching byte position from the right side of + // the source string. + int trimByteIdx = numBytes - 1; + + while (numChars > 0) { + UTF8String searchChar = copyUTF8String( + stringCharPos[numChars - 1], + stringCharPos[numChars - 1] + stringCharLen[numChars - 1] - 1); + + // Try to find the matching for the searchChar in the trimString + StringSearch stringSearch = CollationFactory.getStringSearch( + trimString, searchChar, collationId); + int searchCharIdx = stringSearch.next(); + + if (searchCharIdx != StringSearch.DONE + && stringSearch.getMatchLength() == stringSearch.getPattern().length()) { + trimByteIdx -= stringCharLen[numChars - 1]; + numChars--; + } else { + break; + } + } + + if (trimByteIdx == numBytes - 1) { + // Nothing trimmed + return this; + } + if (trimByteIdx < 0) { + // Everything trimmed + return EMPTY_UTF8; + } + return copyUTF8String(0, trimByteIdx); + } + public UTF8String reverse() { byte[] result = new byte[this.numBytes]; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index e73dc5f2ee1b4..1415affaf9c86 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1028,8 +1028,11 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { protected def direction: String override def children: Seq[Expression] = srcStr +: trimStr.toSeq - override def dataType: DataType = StringType - override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) + override def dataType: DataType = srcStr.dataType + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, StringTypeAnyCollation) + + final lazy val collationId: Int = srcStr.dataType.asInstanceOf[StringType].collationId override def nullable: Boolean = children.exists(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -1037,6 +1040,19 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { protected def doEval(srcString: UTF8String): UTF8String protected def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String + override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheckResult = super.checkInputDataTypes() + if (defaultCheckResult.isFailure) { + return defaultCheckResult + } + + trimStr match { + case None => TypeCheckResult.TypeCheckSuccess + case Some(trimChars) => + CollationTypeConstraints.checkCollationCompatibility(collationId, Seq(trimChars.dataType)) + } + } + override def eval(input: InternalRow): Any = { val srcString = srcStr.eval(input).asInstanceOf[UTF8String] if (srcString == null) { @@ -1055,31 +1071,39 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { val srcString = evals(0) if (evals.length == 1) { + val collationIdStr = + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) "" + else collationId + ev.copy(code = code""" - |${srcString.code} - |boolean ${ev.isNull} = false; - |UTF8String ${ev.value} = null; - |if (${srcString.isNull}) { - | ${ev.isNull} = true; - |} else { - | ${ev.value} = ${srcString.value}.$trimMethod(); - |}""".stripMargin) + |${srcString.code} + |boolean ${ev.isNull} = false; + |UTF8String ${ev.value} = null; + |if (${srcString.isNull}) { + | ${ev.isNull} = true; + |} else { + | ${ev.value} = ${srcString.value}.$trimMethod($collationIdStr); + |}""".stripMargin) } else { val trimString = evals(1) + val collationIdStr = + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) "" + else ", " + collationId + ev.copy(code = code""" - |${srcString.code} - |boolean ${ev.isNull} = false; - |UTF8String ${ev.value} = null; - |if (${srcString.isNull}) { - | ${ev.isNull} = true; - |} else { - | ${trimString.code} - | if (${trimString.isNull}) { - | ${ev.isNull} = true; - | } else { - | ${ev.value} = ${srcString.value}.$trimMethod(${trimString.value}); - | } - |}""".stripMargin) + |${srcString.code} + |boolean ${ev.isNull} = false; + |UTF8String ${ev.value} = null; + |if (${srcString.isNull}) { + | ${ev.isNull} = true; + |} else { + | ${trimString.code} + | if (${trimString.isNull}) { + | ${ev.isNull} = true; + | } else { + | ${ev.value} = ${srcString.value}.$trimMethod(${trimString.value}$collationIdStr); + | } + |}""".stripMargin) } } @@ -1170,10 +1194,21 @@ case class StringTrim(srcStr: Expression, trimStr: Option[Expression] = None) override protected def direction: String = "BOTH" - override def doEval(srcString: UTF8String): UTF8String = srcString.trim() + override def doEval(srcString: UTF8String): UTF8String = { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + srcString.trim() + } else { + srcString.trim(collationId) + } + } - override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = - srcString.trim(trimString) + override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + srcString.trim(trimString) + } else { + srcString.trim(trimString, collationId) + } + } override val trimMethod: String = "trim" @@ -1278,10 +1313,21 @@ case class StringTrimLeft(srcStr: Expression, trimStr: Option[Expression] = None override protected def direction: String = "LEADING" - override def doEval(srcString: UTF8String): UTF8String = srcString.trimLeft() + override def doEval(srcString: UTF8String): UTF8String = { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + srcString.trimLeft() + } else { + srcString.trimLeft(collationId) + } + } - override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = - srcString.trimLeft(trimString) + override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + srcString.trimLeft(trimString) + } else { + srcString.trimLeft(trimString, collationId) + } + } override val trimMethod: String = "trimLeft" @@ -1339,10 +1385,21 @@ case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non override protected def direction: String = "TRAILING" - override def doEval(srcString: UTF8String): UTF8String = srcString.trimRight() + override def doEval(srcString: UTF8String): UTF8String = { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + srcString.trimRight() + } else { + srcString.trimRight(collationId) + } + } - override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = - srcString.trimRight(trimString) + override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = { + if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + srcString.trimRight(trimString) + } else { + srcString.trimRight(trimString, collationId) + } + } override val trimMethod: String = "trimRight" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index ab2d768256c14..13459c24ce827 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -86,8 +86,326 @@ class CollationStringExpressionsSuite extends QueryTest testRepeat("UNICODE_CI", 3, "abc", 2) } - // TODO: Add more tests for other string expressions + case class StringTrimTestCase( + collation: String, + trimFunc: String, + sourceString: String, + trimString: String, + expectedResultString: String) + + test("string trim functions with collation - success") { + // scalastyle:off + + val testCases = Seq( + // Basic test cases + StringTrimTestCase("UTF8_BINARY", "TRIM", "asd", null, "asd"), + StringTrimTestCase("UTF8_BINARY", "TRIM", " asd ", null, "asd"), + StringTrimTestCase("UTF8_BINARY", "TRIM", " a世a ", null, "a世a"), + StringTrimTestCase("UTF8_BINARY", "TRIM", "asd", "x", "asd"), + StringTrimTestCase("UTF8_BINARY", "TRIM", "xxasdxx", "x", "asd"), + StringTrimTestCase("UTF8_BINARY", "TRIM", "xa世ax", "x", "a世a"), + + StringTrimTestCase("UTF8_BINARY", "BTRIM", "asd", null, "asd"), + StringTrimTestCase("UTF8_BINARY", "BTRIM", " asd ", null, "asd"), + StringTrimTestCase("UTF8_BINARY", "BTRIM", " a世a ", null, "a世a"), + StringTrimTestCase("UTF8_BINARY", "BTRIM", "asd", "x", "asd"), + StringTrimTestCase("UTF8_BINARY", "BTRIM", "xxasdxx", "x", "asd"), + StringTrimTestCase("UTF8_BINARY", "BTRIM", "xa世ax", "x", "a世a"), + + StringTrimTestCase("UTF8_BINARY", "LTRIM", "asd", null, "asd"), + StringTrimTestCase("UTF8_BINARY", "LTRIM", " asd ", null, "asd "), + StringTrimTestCase("UTF8_BINARY", "LTRIM", " a世a ", null, "a世a "), + StringTrimTestCase("UTF8_BINARY", "LTRIM", "asd", "x", "asd"), + StringTrimTestCase("UTF8_BINARY", "LTRIM", "xxasdxx", "x", "asdxx"), + StringTrimTestCase("UTF8_BINARY", "LTRIM", "xa世ax", "x", "a世ax"), + + StringTrimTestCase("UTF8_BINARY", "RTRIM", "asd", null, "asd"), + StringTrimTestCase("UTF8_BINARY", "RTRIM", " asd ", null, " asd"), + StringTrimTestCase("UTF8_BINARY", "RTRIM", " a世a ", null, " a世a"), + StringTrimTestCase("UTF8_BINARY", "RTRIM", "asd", "x", "asd"), + StringTrimTestCase("UTF8_BINARY", "RTRIM", "xxasdxx", "x", "xxasd"), + StringTrimTestCase("UTF8_BINARY", "RTRIM", "xa世ax", "x", "xa世a"), + + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "asd", null, "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", " asd ", null, "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", " a世a ", null, "a世a"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "asd", "x", "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "xxasdxx", "x", "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "xa世ax", "x", "a世a"), + + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", "asd", null, "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", " asd ", null, "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", " a世a ", null, "a世a"), + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", "asd", "x", "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", "xxasdxx", "x", "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", "xa世ax", "x", "a世a"), + + StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", "asd", null, "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", " asd ", null, "asd "), + StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", " a世a ", null, "a世a "), + StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", "asd", "x", "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", "xxasdxx", "x", "asdxx"), + StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", "xa世ax", "x", "a世ax"), + + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", "asd", null, "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", " asd ", null, " asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", " a世a ", null, " a世a"), + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", "asd", "x", "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", "xxasdxx", "x", "xxasd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", "xa世ax", "x", "xa世a"), + + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "asd", null, "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", " asd ", null, "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", " a世a ", null, "a世a"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "asd", "x", "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "xxasdxx", "x", "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "xa世ax", "x", "a世a"), + + StringTrimTestCase("UNICODE", "BTRIM", "asd", null, "asd"), + StringTrimTestCase("UNICODE", "BTRIM", " asd ", null, "asd"), + StringTrimTestCase("UNICODE", "BTRIM", " a世a ", null, "a世a"), + StringTrimTestCase("UNICODE", "BTRIM", "asd", "x", "asd"), + StringTrimTestCase("UNICODE", "BTRIM", "xxasdxx", "x", "asd"), + StringTrimTestCase("UNICODE", "BTRIM", "xa世ax", "x", "a世a"), + + StringTrimTestCase("UNICODE", "LTRIM", "asd", null, "asd"), + StringTrimTestCase("UNICODE", "LTRIM", " asd ", null, "asd "), + StringTrimTestCase("UNICODE", "LTRIM", " a世a ", null, "a世a "), + StringTrimTestCase("UNICODE", "LTRIM", "asd", "x", "asd"), + StringTrimTestCase("UNICODE", "LTRIM", "xxasdxx", "x", "asdxx"), + StringTrimTestCase("UNICODE", "LTRIM", "xa世ax", "x", "a世ax"), + + StringTrimTestCase("UNICODE", "RTRIM", "asd", null, "asd"), + StringTrimTestCase("UNICODE", "RTRIM", " asd ", null, " asd"), + StringTrimTestCase("UNICODE", "RTRIM", " a世a ", null, " a世a"), + StringTrimTestCase("UNICODE", "RTRIM", "asd", "x", "asd"), + StringTrimTestCase("UNICODE", "RTRIM", "xxasdxx", "x", "xxasd"), + StringTrimTestCase("UNICODE", "RTRIM", "xa世ax", "x", "xa世a"), + + StringTrimTestCase("UNICODE_CI", "TRIM", "asd", null, "asd"), + StringTrimTestCase("UNICODE_CI", "TRIM", " asd ", null, "asd"), + StringTrimTestCase("UNICODE_CI", "TRIM", " a世a ", null, "a世a"), + StringTrimTestCase("UNICODE_CI", "TRIM", "asd", "x", "asd"), + StringTrimTestCase("UNICODE_CI", "TRIM", "xxasdxx", "x", "asd"), + StringTrimTestCase("UNICODE_CI", "TRIM", "xa世ax", "x", "a世a"), + + StringTrimTestCase("UNICODE_CI", "BTRIM", "asd", null, "asd"), + StringTrimTestCase("UNICODE_CI", "BTRIM", " asd ", null, "asd"), + StringTrimTestCase("UNICODE_CI", "BTRIM", " a世a ", null, "a世a"), + StringTrimTestCase("UNICODE_CI", "BTRIM", "asd", "x", "asd"), + StringTrimTestCase("UNICODE_CI", "BTRIM", "xxasdxx", "x", "asd"), + StringTrimTestCase("UNICODE_CI", "BTRIM", "xa世ax", "x", "a世a"), + + StringTrimTestCase("UNICODE_CI", "LTRIM", "asd", null, "asd"), + StringTrimTestCase("UNICODE_CI", "LTRIM", " asd ", null, "asd "), + StringTrimTestCase("UNICODE_CI", "LTRIM", " a世a ", null, "a世a "), + StringTrimTestCase("UNICODE_CI", "LTRIM", "asd", "x", "asd"), + StringTrimTestCase("UNICODE_CI", "LTRIM", "xxasdxx", "x", "asdxx"), + StringTrimTestCase("UNICODE_CI", "LTRIM", "xa世ax", "x", "a世ax"), + + StringTrimTestCase("UNICODE_CI", "RTRIM", "asd", null, "asd"), + StringTrimTestCase("UNICODE_CI", "RTRIM", " asd ", null, " asd"), + StringTrimTestCase("UNICODE_CI", "RTRIM", " a世a ", null, " a世a"), + StringTrimTestCase("UNICODE_CI", "RTRIM", "asd", "x", "asd"), + StringTrimTestCase("UNICODE_CI", "RTRIM", "xxasdxx", "x", "xxasd"), + StringTrimTestCase("UNICODE_CI", "RTRIM", "xa世ax", "x", "xa世a"), + + // Test cases where trimString has more than one character + StringTrimTestCase("UTF8_BINARY", "TRIM", "ddsXXXaa", "asd", "XXX"), + StringTrimTestCase("UTF8_BINARY", "BTRIM", "ddsXXXaa", "asd", "XXX"), + StringTrimTestCase("UTF8_BINARY", "LTRIM", "ddsXXXaa", "asd", "XXXaa"), + StringTrimTestCase("UTF8_BINARY", "RTRIM", "ddsXXXaa", "asd", "ddsXXX"), + + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "ddsXXXaa", "asd", "XXX"), + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", "ddsXXXaa", "asd", "XXX"), + StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", "ddsXXXaa", "asd", "XXXaa"), + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", "ddsXXXaa", "asd", "ddsXXX"), + + StringTrimTestCase("UNICODE", "TRIM", "ddsXXXaa", "asd", "XXX"), + StringTrimTestCase("UNICODE", "BTRIM", "ddsXXXaa", "asd", "XXX"), + StringTrimTestCase("UNICODE", "LTRIM", "ddsXXXaa", "asd", "XXXaa"), + StringTrimTestCase("UNICODE", "RTRIM", "ddsXXXaa", "asd", "ddsXXX"), + + StringTrimTestCase("UNICODE_CI", "TRIM", "ddsXXXaa", "asd", "XXX"), + StringTrimTestCase("UNICODE_CI", "BTRIM", "ddsXXXaa", "asd", "XXX"), + StringTrimTestCase("UNICODE_CI", "LTRIM", "ddsXXXaa", "asd", "XXXaa"), + StringTrimTestCase("UNICODE_CI", "RTRIM", "ddsXXXaa", "asd", "ddsXXX"), + + // Test cases specific to collation type + // uppercase trim, lowercase src + StringTrimTestCase("UTF8_BINARY", "TRIM", "asd", "A", "asd"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "asd", "A", "sd"), + StringTrimTestCase("UNICODE", "TRIM", "asd", "A", "asd"), + StringTrimTestCase("UNICODE_CI", "TRIM", "asd", "A", "sd"), + + // lowercase trim, uppercase src + StringTrimTestCase("UTF8_BINARY", "TRIM", "ASD", "a", "ASD"), + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "ASD", "a", "SD"), + StringTrimTestCase("UNICODE", "TRIM", "ASD", "a", "ASD"), + StringTrimTestCase("UNICODE_CI", "TRIM", "ASD", "a", "SD"), + + // uppercase and lowercase chars of different byte-length (utf8) + StringTrimTestCase("UTF8_BINARY", "TRIM", "ẞaaaẞ", "ß", "ẞaaaẞ"), + StringTrimTestCase("UTF8_BINARY", "BTRIM", "ẞaaaẞ", "ß", "ẞaaaẞ"), + StringTrimTestCase("UTF8_BINARY", "LTRIM", "ẞaaaẞ", "ß", "ẞaaaẞ"), + StringTrimTestCase("UTF8_BINARY", "RTRIM", "ẞaaaẞ", "ß", "ẞaaaẞ"), + + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "ẞaaaẞ", "ß", "aaa"), + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", "ẞaaaẞ", "ß", "aaa"), + StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", "ẞaaaẞ", "ß", "aaaẞ"), + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", "ẞaaaẞ", "ß", "ẞaaa"), + + StringTrimTestCase("UNICODE", "TRIM", "ẞaaaẞ", "ß", "ẞaaaẞ"), + StringTrimTestCase("UNICODE", "BTRIM", "ẞaaaẞ", "ß", "ẞaaaẞ"), + StringTrimTestCase("UNICODE", "LTRIM", "ẞaaaẞ", "ß", "ẞaaaẞ"), + StringTrimTestCase("UNICODE", "RTRIM", "ẞaaaẞ", "ß", "ẞaaaẞ"), + + StringTrimTestCase("UNICODE_CI", "TRIM", "ẞaaaẞ", "ß", "aaa"), + StringTrimTestCase("UNICODE_CI", "BTRIM", "ẞaaaẞ", "ß", "aaa"), + StringTrimTestCase("UNICODE_CI", "LTRIM", "ẞaaaẞ", "ß", "aaaẞ"), + StringTrimTestCase("UNICODE_CI", "RTRIM", "ẞaaaẞ", "ß", "ẞaaa"), + StringTrimTestCase("UTF8_BINARY", "TRIM", "ßaaaß", "ẞ", "ßaaaß"), + StringTrimTestCase("UTF8_BINARY", "BTRIM", "ßaaaß", "ẞ", "ßaaaß"), + StringTrimTestCase("UTF8_BINARY", "LTRIM", "ßaaaß", "ẞ", "ßaaaß"), + StringTrimTestCase("UTF8_BINARY", "RTRIM", "ßaaaß", "ẞ", "ßaaaß"), + + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "ßaaaß", "ẞ", "aaa"), + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", "ßaaaß", "ẞ", "aaa"), + StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", "ßaaaß", "ẞ", "aaaß"), + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", "ßaaaß", "ẞ", "ßaaa"), + + StringTrimTestCase("UNICODE", "TRIM", "ßaaaß", "ẞ", "ßaaaß"), + StringTrimTestCase("UNICODE", "BTRIM", "ßaaaß", "ẞ", "ßaaaß"), + StringTrimTestCase("UNICODE", "LTRIM", "ßaaaß", "ẞ", "ßaaaß"), + StringTrimTestCase("UNICODE", "RTRIM", "ßaaaß", "ẞ", "ßaaaß"), + + StringTrimTestCase("UNICODE_CI", "TRIM", "ßaaaß", "ẞ", "aaa"), + StringTrimTestCase("UNICODE_CI", "BTRIM", "ßaaaß", "ẞ", "aaa"), + StringTrimTestCase("UNICODE_CI", "LTRIM", "ßaaaß", "ẞ", "aaaß"), + StringTrimTestCase("UNICODE_CI", "RTRIM", "ßaaaß", "ẞ", "ßaaa"), + + // different byte-length (utf8) chars trimmed + StringTrimTestCase("UTF8_BINARY", "TRIM", "Ëaaaẞ", "Ëẞ", "aaa"), + StringTrimTestCase("UTF8_BINARY", "BTRIM", "Ëaaaẞ", "Ëẞ", "aaa"), + StringTrimTestCase("UTF8_BINARY", "LTRIM", "Ëaaaẞ", "Ëẞ", "aaaẞ"), + StringTrimTestCase("UTF8_BINARY", "RTRIM", "Ëaaaẞ", "Ëẞ", "Ëaaa"), + + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "Ëaaaẞ", "Ëẞ", "aaa"), + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", "Ëaaaẞ", "Ëẞ", "aaa"), + StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", "Ëaaaẞ", "Ëẞ", "aaaẞ"), + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", "Ëaaaẞ", "Ëẞ", "Ëaaa"), + + StringTrimTestCase("UNICODE", "TRIM", "Ëaaaẞ", "Ëẞ", "aaa"), + StringTrimTestCase("UNICODE", "BTRIM", "Ëaaaẞ", "Ëẞ", "aaa"), + StringTrimTestCase("UNICODE", "LTRIM", "Ëaaaẞ", "Ëẞ", "aaaẞ"), + StringTrimTestCase("UNICODE", "RTRIM", "Ëaaaẞ", "Ëẞ", "Ëaaa"), + + StringTrimTestCase("UNICODE_CI", "TRIM", "Ëaaaẞ", "Ëẞ", "aaa"), + StringTrimTestCase("UNICODE_CI", "BTRIM", "Ëaaaẞ", "Ëẞ", "aaa"), + StringTrimTestCase("UNICODE_CI", "LTRIM", "Ëaaaẞ", "Ëẞ", "aaaẞ"), + StringTrimTestCase("UNICODE_CI", "RTRIM", "Ëaaaẞ", "Ëẞ", "Ëaaa") + ) + + testCases.foreach(testCase => { + var df: DataFrame = null + + if (testCase.trimFunc.equalsIgnoreCase("BTRIM")) { + // BTRIM has arguments in (srcStr, trimStr) order + df = sql(s"SELECT ${testCase.trimFunc}(" + + s"COLLATE('${testCase.sourceString}', '${testCase.collation}')" + + (if (testCase.trimString == null) "" else s", COLLATE('${testCase.trimString}', '${testCase.collation}')") + + ")") + } + else { + // While other functions have arguments in (trimStr, srcStr) order + df = sql(s"SELECT ${testCase.trimFunc}(" + + (if (testCase.trimString == null) "" else s"COLLATE('${testCase.trimString}', '${testCase.collation}'), ") + + s"COLLATE('${testCase.sourceString}', '${testCase.collation}')" + + ")") + } + + checkAnswer(df = df, expectedAnswer = Row(testCase.expectedResultString)) + }) + + // scalastyle:on + } + + test("string trim functions with collation - exceptions") { + // scalastyle:off + + // TRIM + checkError( + exception = intercept[ExtendedAnalysisException] { + spark.sql(s"SELECT TRIM(COLLATE('trimStr', 'UTF8_BINARY_LCASE'), COLLATE('sourceStr', 'UNICODE'))") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "collationNameLeft" -> "UNICODE", + "collationNameRight" -> "UTF8_BINARY_LCASE", + "sqlExpr" -> "\"TRIM(BOTH collate(trimStr) FROM collate(sourceStr))\"" + ), + context = ExpectedContext(fragment = + "TRIM(COLLATE('trimStr', 'UTF8_BINARY_LCASE'), COLLATE('sourceStr', 'UNICODE'))", + start = 7, stop = 84) + ) + + // BTRIM + checkError( + exception = intercept[ExtendedAnalysisException] { + spark.sql(s"SELECT BTRIM(COLLATE('sourceStr', 'UTF8_BINARY_LCASE'), COLLATE('trimStr', 'UNICODE'))") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "collationNameLeft" -> "UTF8_BINARY_LCASE", + "collationNameRight" -> "UNICODE", + "sqlExpr" -> "\"TRIM(BOTH collate(trimStr) FROM collate(sourceStr))\"" + ), + context = ExpectedContext(fragment = + "BTRIM(COLLATE('sourceStr', 'UTF8_BINARY_LCASE'), COLLATE('trimStr', 'UNICODE'))", + start = 7, stop = 85) + ) + + // LTRIM + checkError( + exception = intercept[ExtendedAnalysisException] { + spark.sql(s"SELECT LTRIM(COLLATE('trimStr', 'UTF8_BINARY_LCASE'), COLLATE('sourceStr', 'UNICODE'))") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "collationNameLeft" -> "UNICODE", + "collationNameRight" -> "UTF8_BINARY_LCASE", + "sqlExpr" -> "\"TRIM(LEADING collate(trimStr) FROM collate(sourceStr))\"" + ), + context = ExpectedContext(fragment = + "LTRIM(COLLATE('trimStr', 'UTF8_BINARY_LCASE'), COLLATE('sourceStr', 'UNICODE'))", + start = 7, stop = 85) + ) + + // RTRIM + checkError( + exception = intercept[ExtendedAnalysisException] { + spark.sql(s"SELECT RTRIM(COLLATE('trimStr', 'UTF8_BINARY_LCASE'), COLLATE('sourceStr', 'UNICODE'))") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "collationNameLeft" -> "UNICODE", + "collationNameRight" -> "UTF8_BINARY_LCASE", + "sqlExpr" -> "\"TRIM(TRAILING collate(trimStr) FROM collate(sourceStr))\"" + ), + context = ExpectedContext(fragment = + "RTRIM(COLLATE('trimStr', 'UTF8_BINARY_LCASE'), COLLATE('sourceStr', 'UNICODE'))", + start = 7, stop = 85) + ) + + // scalastyle:on + } + + // TODO: Add more tests for other string expressions } class CollationStringExpressionsANSISuite extends CollationRegexpExpressionsSuite {