Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
2a5fce7
Update StringReplace class
miland-db Mar 25, 2024
e0ce699
Add UTF8_BINARY_LCASE collation support using custom function
miland-db Mar 25, 2024
0711242
Merge branch 'apache:master' into miland-db/string-replace
miland-db Mar 25, 2024
d2e90f8
Improve testReplace signature
miland-db Mar 26, 2024
1e41ebd
Merge branch 'master' into miland-db/string-replace
miland-db Mar 26, 2024
93c6eb7
Resolve merge problems with master
miland-db Mar 26, 2024
7a1b240
Improve scala style
miland-db Mar 26, 2024
c59d71e
Solve whitespace scala style problem
miland-db Mar 27, 2024
a5c75b3
Add lowercase StringSearch and remove lowercaseReplace
miland-db Apr 1, 2024
76878b9
Remove repeated code
miland-db Apr 1, 2024
572bd54
Improve naming of collation aware methods
miland-db Apr 2, 2024
839b39a
Merge branch 'master' into string-replace
miland-db Apr 2, 2024
e2bea13
Improve java style
miland-db Apr 3, 2024
d719fe2
Merge branch 'master' into string-replace
miland-db Apr 3, 2024
a194292
Remove unnecessary check for matched length
miland-db Apr 3, 2024
7b6720b
Improve style in CollationFactory
miland-db Apr 3, 2024
3042d7e
Add doc comment
miland-db Apr 3, 2024
84e41a3
Improve comment style
miland-db Apr 3, 2024
cc940cb
Improve naming in getStringSearch
miland-db Apr 3, 2024
ec960b8
Merge branch 'master' into string-replace
miland-db Apr 4, 2024
4e93874
Remove type checks for collation mismatch
miland-db Apr 4, 2024
9cb0944
Remove checkInputDataTypes
miland-db Apr 4, 2024
41c3872
Add empty lines between imports
miland-db Apr 4, 2024
74f69b9
Handle all collationIds in getStringSearch
miland-db Apr 4, 2024
ea3730c
Improve Java style
miland-db Apr 5, 2024
5b2a9d3
Merge branch 'master' into string-replace
miland-db Apr 12, 2024
bc5c256
Refactor StringReplace
miland-db Apr 12, 2024
8a81536
Break lines to 100 characters
miland-db Apr 12, 2024
c456325
Refactor tests
miland-db Apr 15, 2024
68d55f2
Merge branch 'master' into string-replace
miland-db Apr 16, 2024
09f13d8
Sync with the latest master
miland-db Apr 16, 2024
67ecb47
Merge branch 'master' into string-replace
miland-db Apr 17, 2024
a67fc9b
Merge branch 'master' into string-replace
miland-db Apr 17, 2024
08d1462
Merge branch 'master' into string-replace
miland-db Apr 18, 2024
d9f56d6
Added new tests (2 failing)
miland-db Apr 18, 2024
ade12fc
Merge branch 'master' into string-replace
miland-db Apr 23, 2024
f6b4413
Merge branch 'master' into string-replace
miland-db Apr 24, 2024
0c725f9
Fix bug with case-variable length characters
miland-db Apr 24, 2024
816a49a
Fix java linter errors
miland-db Apr 24, 2024
feda2b9
Merge branch 'master' into string-replace
miland-db Apr 25, 2024
0ef49d0
Fix import scalastyle
miland-db Apr 25, 2024
a4747f1
Merge branch 'master' into miland-db/string-replace
uros-db Apr 26, 2024
91b32f2
Merge branch 'master' into miland-db/string-replace
uros-db Apr 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.ibm.icu.text.StringSearch;
import com.ibm.icu.util.ULocale;

import org.apache.spark.unsafe.UTF8StringBuilder;
import org.apache.spark.unsafe.types.UTF8String;

import java.util.ArrayList;
Expand Down Expand Up @@ -364,6 +365,44 @@ public static int execICU(final UTF8String string, final UTF8String substring,
}
}

/**
 * Collation-aware entry points for the string replace expression.
 * `exec` is the interpreted path and `genCode` produces the matching call for generated code;
 * both select one of the `execBinary` / `execLowercase` / `execICU` variants from the flags of
 * the collation identified by `collationId`.
 */
public static class StringReplace {

  /**
   * Replaces occurrences of `search` in `src` with `replace` under the given collation.
   * Binary-equality collations are checked first, then lowercase-equality ones; every other
   * collation is handled by the ICU variant.
   */
  public static UTF8String exec(final UTF8String src, final UTF8String search,
      final UTF8String replace, final int collationId) {
    final CollationFactory.Collation resolved = CollationFactory.fetchCollation(collationId);
    if (resolved.supportsBinaryEquality) {
      return execBinary(src, search, replace);
    }
    if (resolved.supportsLowercaseEquality) {
      return execLowercase(src, search, replace);
    }
    return execICU(src, search, replace, collationId);
  }

  /**
   * Builds the Java source text of the call that `exec` would dispatch to. The arguments are
   * the names of the generated variables holding the three operands; the collation id is
   * embedded only for the ICU variant.
   */
  public static String genCode(final String src, final String search, final String replace,
      final int collationId) {
    final CollationFactory.Collation resolved = CollationFactory.fetchCollation(collationId);
    final String callPrefix = "CollationSupport.StringReplace.exec";
    if (resolved.supportsBinaryEquality) {
      return String.format(callPrefix + "Binary(%s, %s, %s)", src, search, replace);
    }
    if (resolved.supportsLowercaseEquality) {
      return String.format(callPrefix + "Lowercase(%s, %s, %s)", src, search, replace);
    }
    return String.format(callPrefix + "ICU(%s, %s, %s, %d)", src, search, replace, collationId);
  }

  /** Byte-wise replacement, delegated to `UTF8String.replace`. */
  public static UTF8String execBinary(final UTF8String src, final UTF8String search,
      final UTF8String replace) {
    return src.replace(search, replace);
  }

  /** Replacement that delegates to `CollationAwareUTF8String.lowercaseReplace`. */
  public static UTF8String execLowercase(final UTF8String src, final UTF8String search,
      final UTF8String replace) {
    return CollationAwareUTF8String.lowercaseReplace(src, search, replace);
  }

  /** Replacement that delegates to `CollationAwareUTF8String.replace` for `collationId`. */
  public static UTF8String execICU(final UTF8String src, final UTF8String search,
      final UTF8String replace, final int collationId) {
    return CollationAwareUTF8String.replace(src, search, replace, collationId);
  }
}

// TODO: Add more collation-aware string expressions.

/**
Expand Down Expand Up @@ -401,6 +440,107 @@ public static UTF8String collationAwareRegex(final UTF8String regex, final int c

private static class CollationAwareUTF8String {

/**
 * Replaces every match of `search` in `src` with `replace`, where matches are located by an
 * ICU `StringSearch` built for the given collation.
 *
 * The unmatched parts of `src` are copied byte-for-byte; only matched spans are substituted.
 * If either `src` or `search` is empty, or nothing matches, `src` itself is returned.
 *
 * `StringSearch` reports match positions and lengths as indexes into its UTF-16 text, so the
 * loops below translate those indexes to UTF-8 byte offsets by walking `src` one character at
 * a time and counting a 4-byte UTF-8 character as two UTF-16 code units (a surrogate pair).
 * NOTE(review): this assumes `CollationFactory.getStringSearch` searches the plain string
 * form of `src` and that `src` holds valid UTF-8 - confirm against that helper.
 *
 * @param src the string to search in
 * @param search the string to look for
 * @param replace the string substituted for every match
 * @param collationId id of the collation used for matching
 * @return `src` with all matches replaced
 */
private static UTF8String replace(final UTF8String src, final UTF8String search,
    final UTF8String replace, final int collationId) {
  // This collation aware implementation is based on existing implementation on UTF8String
  if (src.numBytes() == 0 || search.numBytes() == 0) {
    return src;
  }

  StringSearch stringSearch = CollationFactory.getStringSearch(src, search, collationId);

  // Find the first occurrence of the search string.
  int end = stringSearch.next();
  if (end == StringSearch.DONE) {
    // Search string was not found, so string is unchanged.
    return src;
  }

  // Initialize positions.
  int c = 0; // position in UTF-16 code units, the unit StringSearch reports in
  int byteStart = 0; // byte position where the next unmatched span begins
  int byteEnd = 0; // byte position where the current match begins
  while (byteEnd < src.numBytes() && c < end) {
    int charBytes = UTF8String.numBytesForFirstByte(src.getByte(byteEnd));
    byteEnd += charBytes;
    // A 4-byte UTF-8 character is a surrogate pair, i.e. two UTF-16 code units.
    c += (charBytes == 4) ? 2 : 1;
  }

  // At least one match was found. Estimate space needed for result.
  // The 16x multiplier here is chosen to match commons-lang3's implementation.
  int increase = Math.max(0, replace.numBytes() - search.numBytes()) * 16;
  final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase);
  while (end != StringSearch.DONE) {
    // Copy the unmatched span that precedes the current match, then the replacement.
    buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart);
    buf.append(replace);

    // Move byteStart to the end of the current match. The matched text is walked in src
    // because its length may differ from the length of the search string under the collation.
    // The match length must be read before next() advances the search.
    byteStart = byteEnd;
    int matchEnd = c + stringSearch.getMatchLength();
    int cs = c;
    while (byteStart < src.numBytes() && cs < matchEnd) {
      int charBytes = UTF8String.numBytesForFirstByte(src.getByte(byteStart));
      byteStart += charBytes;
      cs += (charBytes == 4) ? 2 : 1;
    }
    // Go to next match
    end = stringSearch.next();
    // Move byteEnd to the beginning of the next match. When no match is left (end is DONE)
    // the loop does not run and byteEnd keeps its value.
    while (byteEnd < src.numBytes() && c < end) {
      int charBytes = UTF8String.numBytesForFirstByte(src.getByte(byteEnd));
      byteEnd += charBytes;
      c += (charBytes == 4) ? 2 : 1;
    }
  }
  // Copy whatever follows the last match.
  buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart,
    src.numBytes() - byteStart);
  return buf.build();
}

/**
 * Replaces every occurrence of `search` in `src` with `replace`, where occurrences are found
 * by searching the lowercased `search` inside the lowercased `src`.
 *
 * The unmatched parts of `src` are copied byte-for-byte. If either `src` or `search` is
 * empty, or nothing matches, `src` itself is returned.
 *
 * The end of each match is located by walking the matched characters in `src` itself rather
 * than by adding `search.numBytes()`: the text matched in `src` may use a different number of
 * bytes than `search` when two case variants have different UTF-8 lengths.
 * NOTE(review): character positions found in the lowercased string are applied to `src`,
 * which presumes lowercasing does not change the number of characters - confirm for
 * characters whose lowercase form expands to several characters.
 *
 * @param src the string to search in
 * @param search the string to look for, compared case-insensitively
 * @param replace the string substituted for every match
 * @return `src` with all matches replaced
 */
private static UTF8String lowercaseReplace(final UTF8String src, final UTF8String search,
    final UTF8String replace) {
  if (src.numBytes() == 0 || search.numBytes() == 0) {
    return src;
  }
  UTF8String lowercaseString = src.toLowerCase();
  UTF8String lowercaseSearch = search.toLowerCase();

  int start = 0;
  int end = lowercaseString.indexOf(lowercaseSearch, 0);
  if (end == -1) {
    // Search string was not found, so string is unchanged.
    return src;
  }

  // Initialize positions.
  int c = 0; // position in characters
  int byteStart = 0; // byte position where the next unmatched span begins
  int byteEnd = 0; // byte position where the current match begins
  while (byteEnd < src.numBytes() && c < end) {
    byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd));
    c += 1;
  }

  // At least one match was found. Estimate space needed for result.
  // The 16x multiplier here is chosen to match commons-lang3's implementation.
  int increase = Math.max(0, replace.numBytes() - search.numBytes()) * 16;
  final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase);
  while (end != -1) {
    // Copy the unmatched span that precedes the current match, then the replacement.
    buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart);
    buf.append(replace);
    // Character position just past the current match.
    start = end + lowercaseSearch.numChars();
    // Move byteStart to the end of the current match by walking the matched characters in
    // src; the bound on src.numBytes() keeps it inside the string.
    byteStart = byteEnd;
    int cs = c;
    while (byteStart < src.numBytes() && cs < start) {
      byteStart += UTF8String.numBytesForFirstByte(src.getByte(byteStart));
      cs += 1;
    }
    // Go to next match
    end = lowercaseString.indexOf(lowercaseSearch, start);
    // Move byteEnd to the beginning of the next match. When no match is left (end is -1)
    // the loop does not run and byteEnd keeps its value.
    while (byteEnd < src.numBytes() && c < end) {
      byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd));
      c += 1;
    }
  }
  // Copy whatever follows the last match.
  buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart,
    src.numBytes() - byteStart);
  return buf.build();
}

private static String toUpperCase(final String target, final int collationId) {
ULocale locale = CollationFactory.fetchCollation(collationId)
.collator.getLocale(ULocale.ACTUAL_LOCALE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ public void writeTo(OutputStream out) throws IOException {
* Returns the number of bytes for a code point with the first byte as `b`
* @param b The first byte of a code point
*/
private static int numBytesForFirstByte(final byte b) {
public static int numBytesForFirstByte(final byte b) {
final int offset = b & 0xFF;
byte numBytes = bytesOfCodePointInUTF8[offset];
return (numBytes == 0) ? 1: numBytes; // Skip the first byte disallowed in UTF-8
Expand Down Expand Up @@ -382,7 +382,7 @@ public boolean containsInLowerCase(final UTF8String substring) {
/**
* Returns the byte at position `i`.
*/
private byte getByte(int i) {
public byte getByte(int i) {
return Platform.getByte(base, offset + i);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,44 @@ public void testFindInSet() throws SparkException {
assertFindInSet("İo", "ab,i̇o,12", "UNICODE_CI", 2);
}

/**
 * Runs `CollationSupport.StringReplace.exec` on the given operands under the collation named
 * `collationName` and asserts that the result, as a Java string, equals `expected`.
 */
private void assertReplace(String source, String search, String replace, String collationName,
    String expected) throws SparkException {
  final int collationId = CollationFactory.collationNameToId(collationName);
  final UTF8String actual = CollationSupport.StringReplace.exec(
    UTF8String.fromString(source),
    UTF8String.fromString(search),
    UTF8String.fromString(replace),
    collationId);
  assertEquals(expected, actual.toString());
}

/**
 * Checks string replacement under the UTF8_BINARY, UTF8_BINARY_LCASE, UNICODE and UNICODE_CI
 * collations, including empty search strings, empty replacements, multi-byte characters and
 * case variants of dotted I.
 */
@Test
public void testReplace() throws SparkException {
  // Each row is: source, search, replacement, collation name, expected result.
  final String[][] cases = {
    {"r世eplace", "pl", "123", "UTF8_BINARY", "r世e123ace"},
    {"replace", "pl", "", "UTF8_BINARY", "reace"},
    {"repl世ace", "Pl", "", "UTF8_BINARY", "repl世ace"},
    {"replace", "", "123", "UTF8_BINARY", "replace"},
    {"abcabc", "b", "12", "UTF8_BINARY", "a12ca12c"},
    {"abcdabcd", "bc", "", "UTF8_BINARY", "adad"},
    {"r世eplace", "pl", "xx", "UTF8_BINARY_LCASE", "r世exxace"},
    {"repl世ace", "PL", "AB", "UTF8_BINARY_LCASE", "reAB世ace"},
    {"Replace", "", "123", "UTF8_BINARY_LCASE", "Replace"},
    {"re世place", "世", "x", "UTF8_BINARY_LCASE", "rexplace"},
    {"abcaBc", "B", "12", "UTF8_BINARY_LCASE", "a12ca12c"},
    {"AbcdabCd", "Bc", "", "UTF8_BINARY_LCASE", "Adad"},
    {"re世place", "plx", "123", "UNICODE", "re世place"},
    {"世Replace", "re", "", "UNICODE", "世Replace"},
    {"replace世", "", "123", "UNICODE", "replace世"},
    {"aBc世abc", "b", "12", "UNICODE", "aBc世a12c"},
    {"abcdabcd", "bc", "", "UNICODE", "adad"},
    {"replace", "plx", "123", "UNICODE_CI", "replace"},
    {"Replace", "re", "", "UNICODE_CI", "place"},
    {"replace", "", "123", "UNICODE_CI", "replace"},
    {"aBc世abc", "b", "12", "UNICODE_CI", "a12c世a12c"},
    {"a世Bcdabcd", "bC", "", "UNICODE_CI", "a世dad"},
    {"abİo12i̇o", "i̇o", "xx", "UNICODE_CI", "abxx12xx"},
    {"abi̇o12i̇o", "İo", "yy", "UNICODE_CI", "abyy12yy"},
  };
  for (final String[] row : cases) {
    assertReplace(row[0], row[1], row[2], row[3], row[4]);
  }
}

// TODO: Test more collation-aware string expressions.

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ object CollationTypeCasts extends TypeCoercionRule {

case otherExpr @ (
_: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least |
_: Coalesce | _: BinaryExpression | _: ConcatWs | _: Mask) =>
_: Coalesce | _: BinaryExpression | _: ConcatWs | _: Mask | _: StringReplace) =>
val newChildren = collateToSingleType(otherExpr.children)
otherExpr.withNewChildren(newChildren)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -710,23 +710,25 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate
case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExpr: Expression)
extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {

final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId

def this(srcExpr: Expression, searchExpr: Expression) = {
this(srcExpr, searchExpr, Literal(""))
}

override def nullSafeEval(srcEval: Any, searchEval: Any, replaceEval: Any): Any = {
srcEval.asInstanceOf[UTF8String].replace(
searchEval.asInstanceOf[UTF8String], replaceEval.asInstanceOf[UTF8String])
CollationSupport.StringReplace.exec(srcEval.asInstanceOf[UTF8String],
searchEval.asInstanceOf[UTF8String], replaceEval.asInstanceOf[UTF8String], collationId);
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (src, search, replace) => {
s"""${ev.value} = $src.replace($search, $replace);"""
})
defineCodeGen(ctx, ev, (src, search, replace) =>
CollationSupport.StringReplace.genCode(src, search, replace, collationId))
}

override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType)
override def dataType: DataType = srcExpr.dataType
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation)
override def first: Expression = srcExpr
override def second: Expression = searchExpr
override def third: Expression = replaceExpr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataType, IntegerType, StringType}
Expand Down Expand Up @@ -217,6 +218,41 @@ class CollationStringExpressionsSuite
assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
}

test("Support Replace string expression with collation") {
  // One case per row: the three `replace` arguments, the collation name `c` applied to them,
  // and the expected query result.
  case class ReplaceTestCase[R](source: String, search: String, replace: String,
      c: String, result: R)
  val testCases = Seq(
    // scalastyle:off
    ReplaceTestCase("r世eplace", "pl", "123", "UTF8_BINARY", "r世e123ace"),
    ReplaceTestCase("repl世ace", "PL", "AB", "UTF8_BINARY_LCASE", "reAB世ace"),
    ReplaceTestCase("abcdabcd", "bc", "", "UNICODE", "adad"),
    ReplaceTestCase("aBc世abc", "b", "12", "UNICODE_CI", "a12c世a12c"),
    ReplaceTestCase("abi̇o12i̇o", "İo", "yy", "UNICODE_CI", "abyy12yy"),
    ReplaceTestCase("abİo12i̇o", "i̇o", "xx", "UNICODE_CI", "abxx12xx")
    // scalastyle:on
  )
  testCases.foreach(t => {
    // All three arguments carry the same explicit collation.
    val query = s"SELECT replace(collate('${t.source}','${t.c}'),collate('${t.search}'," +
      s"'${t.c}'),collate('${t.replace}','${t.c}'))"
    // Result & data type
    checkAnswer(sql(query), Row(t.result))
    assert(sql(query).schema.fields.head.dataType.sameType(
      StringType(CollationFactory.collationNameToId(t.c))))
    // Implicit casting: only one argument is explicitly collated, in each position in turn.
    checkAnswer(sql(s"SELECT replace(collate('${t.source}','${t.c}'),'${t.search}'," +
      s"'${t.replace}')"), Row(t.result))
    checkAnswer(sql(s"SELECT replace('${t.source}',collate('${t.search}','${t.c}')," +
      s"'${t.replace}')"), Row(t.result))
    checkAnswer(sql(s"SELECT replace('${t.source}','${t.search}'," +
      s"collate('${t.replace}','${t.c}'))"), Row(t.result))
  })
  // Collation mismatch: two different explicit collations passed to `replace` itself, so the
  // check exercises the expression under test rather than `startswith`.
  val collationMismatch = intercept[AnalysisException] {
    sql("SELECT replace(collate('abcde', 'UTF8_BINARY_LCASE'),collate('C', 'UNICODE_CI'),'x')")
  }
  assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
}

test("Support EndsWith string expression with collation") {
// Supported collations
case class EndsWithTestCase[R](l: String, r: String, c: String, result: R)
Expand Down