From 8ff10ac16d018a4af5f8a36341ac405ac779ac2c Mon Sep 17 00:00:00 2001 From: "Maarten A. Breddels" Date: Wed, 25 Nov 2020 17:10:56 +0100 Subject: [PATCH 1/3] ARROW-10306: [C++] Add string replacement kernel --- cpp/src/arrow/compute/api_scalar.h | 12 ++ .../arrow/compute/kernels/scalar_string.cc | 204 ++++++++++++++++++ .../compute/kernels/scalar_string_test.cc | 37 ++++ docs/source/cpp/compute.rst | 51 +++-- python/pyarrow/_compute.pyx | 20 ++ python/pyarrow/compute.py | 1 + python/pyarrow/includes/libarrow.pxd | 8 + python/pyarrow/tests/test_compute.py | 12 ++ 8 files changed, 328 insertions(+), 17 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 0d95092c95b..13450d54a50 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -68,6 +68,18 @@ struct ARROW_EXPORT SplitPatternOptions : public SplitOptions { std::string pattern; }; +struct ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions { + explicit ReplaceSubstringOptions(std::string pattern, std::string replacement, + int64_t max_replacements = -1) + : pattern(pattern), replacement(replacement), max_replacements(max_replacements) {} + + /// Pattern to match, literal, or regular expression depending on which kernel is used + std::string pattern; + /// String to replace the pattern with + std::string replacement; + int64_t max_replacements; +}; + /// Options for IsIn and IndexIn functions struct ARROW_EXPORT SetLookupOptions : public FunctionOptions { explicit SetLookupOptions(Datum value_set, bool skip_nulls = false) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 88c91a18818..02913b3b768 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -23,6 +23,10 @@ #include #endif +#ifdef ARROW_WITH_RE2 +#include +#endif + #include "arrow/array/builder_binary.h" #include 
"arrow/array/builder_nested.h" #include "arrow/buffer_builder.h" @@ -1230,6 +1234,198 @@ void AddSplit(FunctionRegistry* registry) { #endif } +// ---------------------------------------------------------------------- +// replace substring + +template +struct ReplaceSubStringBase { + using ArrayType = typename TypeTraits::ArrayType; + using ScalarType = typename TypeTraits::ScalarType; + using BuilderType = typename TypeTraits::BuilderType; + using offset_type = typename Type::offset_type; + using ValueDataBuilder = TypedBufferBuilder; + using OffsetBuilder = TypedBufferBuilder; + using State = OptionsWrapper; + + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + Derived derived(ctx, State::Get(ctx)); + if (ctx->status().ok()) { + derived.Replace(ctx, batch, out); + } + } + void Replace(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + std::shared_ptr value_data_builder = + std::make_shared(); + std::shared_ptr offset_builder = std::make_shared(); + + if (batch[0].kind() == Datum::ARRAY) { + // We already know how many strings we have, so we can use Reserve/UnsafeAppend + KERNEL_RETURN_IF_ERROR(ctx, offset_builder->Reserve(batch[0].array()->length)); + + const ArrayData& input = *batch[0].array(); + KERNEL_RETURN_IF_ERROR(ctx, offset_builder->Append(0)); // offsets start at 0 + KERNEL_RETURN_IF_ERROR( + ctx, VisitArrayDataInline( + input, + [&](util::string_view s) { + RETURN_NOT_OK(static_cast(*this).ReplaceString( + s, value_data_builder.get())); + offset_builder->UnsafeAppend( + static_cast(value_data_builder->length())); + return Status::OK(); + }, + [&]() { + // offset for null value + offset_builder->UnsafeAppend( + static_cast(value_data_builder->length())); + return Status::OK(); + })); + ArrayData* output = out->mutable_array(); + KERNEL_RETURN_IF_ERROR(ctx, value_data_builder->Finish(&output->buffers[2])); + KERNEL_RETURN_IF_ERROR(ctx, offset_builder->Finish(&output->buffers[1])); + } else { + const auto& input = 
checked_cast(*batch[0].scalar()); + auto result = std::make_shared(); + if (input.is_valid) { + util::string_view s = static_cast(*input.value); + KERNEL_RETURN_IF_ERROR( + ctx, static_cast(*this).ReplaceString(s, value_data_builder.get())); + KERNEL_RETURN_IF_ERROR(ctx, value_data_builder->Finish(&result->value)); + result->is_valid = true; + } + out->value = result; + } + } +}; + +template +struct ReplaceSubString : ReplaceSubStringBase> { + using Base = ReplaceSubStringBase>; + using ValueDataBuilder = typename Base::ValueDataBuilder; + using offset_type = typename Base::offset_type; + + ReplaceSubstringOptions options; + explicit ReplaceSubString(KernelContext* ctx, ReplaceSubstringOptions options) + : options(options) {} + + Status ReplaceString(util::string_view s, ValueDataBuilder* builder) { + const char* i = s.begin(); + const char* end = s.end(); + int64_t max_replacements = options.max_replacements; + while ((i < end) && (max_replacements != 0)) { + const char* pos = + std::search(i, end, options.pattern.begin(), options.pattern.end()); + if (pos == end) { + RETURN_NOT_OK(builder->Append(reinterpret_cast(i), + static_cast(end - i))); + i = end; + } else { + // the string before the pattern + RETURN_NOT_OK(builder->Append(reinterpret_cast(i), + static_cast(pos - i))); + // the replacement + RETURN_NOT_OK( + builder->Append(reinterpret_cast(options.replacement.data()), + options.replacement.length())); + // skip pattern + i = pos + options.pattern.length(); + max_replacements--; + } + } + // if we exited early due to max_replacements, add the trailing part + RETURN_NOT_OK(builder->Append(reinterpret_cast(i), + static_cast(end - i))); + return Status::OK(); + } +}; + +const FunctionDoc replace_substring_doc( + "Replace non-overlapping substrings that match pattern by replacement", + ("For each string in `strings`, replace non-overlapping substrings that match\n" + "`pattern` by `replacement`. 
If `max_replacements != -1`, it determines the\n" + "maximum amount of replacements made, counting from the left. Null values emit\n" + "null."), + {"strings"}, "ReplaceSubstringOptions"); + +#ifdef ARROW_WITH_RE2 +template +struct ReplaceSubStringRE2 : ReplaceSubStringBase> { + using Base = ReplaceSubStringBase>; + using ValueDataBuilder = typename Base::ValueDataBuilder; + using offset_type = typename Base::offset_type; + + ReplaceSubstringOptions options; + RE2 regex_find; + RE2 regex_replacement; + explicit ReplaceSubStringRE2(KernelContext* ctx, ReplaceSubstringOptions options) + : options(options), + regex_find("(" + options.pattern + ")"), + regex_replacement(options.pattern) { + // Using RE2::FindAndConsume we can only find the pattern if it is a group, therefore + // we have 2 regex, one with () around it, one without. + if (!(regex_find.ok() && regex_replacement.ok())) { + ctx->SetStatus(Status::Invalid("Regular expression error")); + return; + } + } + Status ReplaceString(util::string_view s, ValueDataBuilder* builder) { + re2::StringPiece replacement(options.replacement); + if (options.max_replacements == -1) { + std::string s_copy(s.to_string()); + re2::RE2::GlobalReplace(&s_copy, regex_replacement, replacement); + RETURN_NOT_OK(builder->Append(reinterpret_cast(s_copy.data()), + s_copy.length())); + return Status::OK(); + } + // Since RE2 does not have the concept of max_replacements, we have to do some work + // ourselves. 
+ // We might do this faster similar to RE2::GlobalReplace using Match and Rewrite + const char* i = s.begin(); + const char* end = s.end(); + re2::StringPiece piece(s.data(), s.length()); + + int64_t max_replacements = options.max_replacements; + while ((i < end) && (max_replacements != 0)) { + std::string found; + if (!re2::RE2::FindAndConsume(&piece, regex_find, &found)) { + RETURN_NOT_OK(builder->Append(reinterpret_cast(i), + static_cast(end - i))); + i = end; + } else { + // wind back to the beginning of the match + const char* pos = piece.begin() - found.length(); + // the string before the pattern + RETURN_NOT_OK(builder->Append(reinterpret_cast(i), + static_cast(pos - i))); + // replace the pattern in what we found + if (!re2::RE2::Replace(&found, regex_replacement, replacement)) { + return Status::Invalid("Regex found, but replacement failed"); + } + RETURN_NOT_OK(builder->Append(reinterpret_cast(found.data()), + static_cast(found.length()))); + // skip pattern + i = piece.begin(); + max_replacements--; + } + } + // If we exited early due to max_replacements, add the trailing part + RETURN_NOT_OK(builder->Append(reinterpret_cast(i), + static_cast(end - i))); + return Status::OK(); + } +}; + +const FunctionDoc replace_substring_regex_doc( + "Replace non-overlapping substrings that match regex `pattern` by `replacement`", + ("For each string in `strings`, replace non-overlapping substrings that match the\n" + "regular expression `pattern` by `replacement` using the Google RE2 library.\n" + "If `max_replacements != -1`, it determines the maximum amount of replacements\n" + "made, counting from the left. Note that if the pattern contains groups,\n" + "backreferencing macan be used. 
Null values emit null."), + {"strings"}, "ReplaceSubstringOptions"); + +#endif + // ---------------------------------------------------------------------- // strptime string parsing @@ -1904,6 +2100,14 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { AddBinaryLength(registry); AddUtf8Length(registry); AddMatchSubstring(registry); + MakeUnaryStringBatchKernelWithState("replace_substring", registry, + &replace_substring_doc, + MemAllocation::NO_PREALLOCATE); +#ifdef ARROW_WITH_RE2 + MakeUnaryStringBatchKernelWithState( + "replace_substring_regex", registry, &replace_substring_regex_doc, + MemAllocation::NO_PREALLOCATE); +#endif AddStrptime(registry); } diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index 281fcb5c7aa..4262681ba4d 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -48,6 +48,14 @@ class BaseTestStringKernels : public ::testing::Test { CheckScalarUnary(func_name, type(), json_input, out_ty, json_expected, options); } + void CheckBinaryScalar(std::string func_name, std::string json_left_input, + std::string json_right_scalar, std::shared_ptr out_ty, + std::string json_expected, + const FunctionOptions* options = nullptr) { + CheckScalarBinaryScalar(func_name, type(), json_left_input, json_right_scalar, out_ty, + json_expected, options); + } + std::shared_ptr type() { return TypeTraits::type_singleton(); } std::shared_ptr offset_type() { @@ -422,6 +430,35 @@ TYPED_TEST(TestStringKernels, SplitWhitespaceUTF8Reverse) { &options_max); } +TYPED_TEST(TestStringKernels, ReplaceSubstringNormal) { + ReplaceSubstringOptions options{"foo", "bazz"}; + this->CheckUnary("replace_substring", R"(["foo", "this foo that foo", null])", + this->type(), R"(["bazz", "this bazz that bazz", null])", &options); +} + +#ifdef ARROW_WITH_RE2 +TYPED_TEST(TestStringKernels, ReplaceSubstringRegex) { + ReplaceSubstringOptions 
options_regex{"(fo+)\\s*", "\\1-bazz"}; + this->CheckUnary("replace_substring_regex", R"(["foo ", "this foo that foo", null])", + this->type(), R"(["foo-bazz", "this foo-bazzthat foo-bazz", null])", + &options_regex); + // make sure we match non-overlapping + ReplaceSubstringOptions options_regex2{"(a.a)", "aba\\1"}; + this->CheckUnary("replace_substring_regex", R"(["aaaaaa"])", this->type(), + R"(["abaaaaabaaaa"])", &options_regex2); +} + +TYPED_TEST(TestStringKernels, ReplaceSubstringMax) { + ReplaceSubstringOptions options1{"foo", "bazz", 1}; + this->CheckUnary("replace_substring", R"(["foo", "this foo that foo", null])", + this->type(), R"(["bazz", "this bazz that foo", null])", &options1); + ReplaceSubstringOptions options_regex1{"(fo+)\\s*", "\\1-bazz", 1}; + this->CheckUnary("replace_substring_regex", R"(["foo ", "this foo that foo", null])", + this->type(), R"(["foo-bazz", "this foo-bazzthat foo", null])", + &options_regex1); +} +#endif + TYPED_TEST(TestStringKernels, Strptime) { std::string input1 = R"(["5/1/2020", null, "12/11/1900"])"; std::string output1 = R"(["2020-05-01", null, "1900-12-11"])"; diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index e4eaa94bc59..065b80736aa 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -426,21 +426,25 @@ The third set of functions examines string elements on a byte-per-byte basis: String transforms ~~~~~~~~~~~~~~~~~ -+--------------------------+------------+-------------------------+---------------------+---------+ -| Function name | Arity | Input types | Output type | Notes | -+==========================+============+=========================+=====================+=========+ -| ascii_lower | Unary | String-like | String-like | \(1) | -+--------------------------+------------+-------------------------+---------------------+---------+ -| ascii_upper | Unary | String-like | String-like | \(1) | 
-+--------------------------+------------+-------------------------+---------------------+---------+ -| binary_length | Unary | Binary- or String-like | Int32 or Int64 | \(2) | -+--------------------------+------------+-------------------------+---------------------+---------+ -| utf8_length | Unary | String-like | Int32 or Int64 | \(3) | -+--------------------------+------------+-------------------------+---------------------+---------+ -| utf8_lower | Unary | String-like | String-like | \(4) | -+--------------------------+------------+-------------------------+---------------------+---------+ -| utf8_upper | Unary | String-like | String-like | \(4) | -+--------------------------+------------+-------------------------+---------------------+---------+ ++--------------------------+------------+-------------------------+---------------------+-------------------------------------------------+ +| Function name | Arity | Input types | Output type | Notes | Options class | ++==========================+============+=========================+=====================+=========+=======================================+ +| ascii_lower | Unary | String-like | String-like | \(1) | | ++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ +| ascii_upper | Unary | String-like | String-like | \(1) | | ++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ +| binary_length | Unary | Binary- or String-like | Int32 or Int64 | \(2) | | ++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ +| replace_substring | Unary | String-like | String-like | \(3) | :struct:`ReplaceSubstringOptions` | ++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ +| 
replace_substring_regex | Unary | String-like | String-like | \(4) | :struct:`ReplaceSubstringOptions` | ++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ +| utf8_length | Unary | String-like | Int32 or Int64 | \(5) | | ++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ +| utf8_lower | Unary | String-like | String-like | \(6) | | ++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ +| utf8_upper | Unary | String-like | String-like | \(6) | | ++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ * \(1) Each ASCII character in the input is converted to lowercase or @@ -449,10 +453,23 @@ String transforms * \(2) Output is the physical length in bytes of each input element. Output type is Int32 for Binary / String, Int64 for LargeBinary / LargeString. -* \(3) Output is the number of characters (not bytes) of each input element. +* \(3) Replace non-overlapping substrings that match to + :member:`ReplaceSubstringOptions::pattern` by + :member:`ReplaceSubstringOptions::replacement`. If + :member:`ReplaceSubstringOptions::max_replacements` != -1, it determines the + maximum number of replacements made, counting from the left. + +* \(4) Replace non-overlapping substrings that match to the regular expression + :member:`ReplaceSubstringOptions::pattern` by + :member:`ReplaceSubstringOptions::replacement`, using the Google RE2 library. If + :member:`ReplaceSubstringOptions::max_replacements` != -1, it determines the + maximum number of replacements made, counting from the left. Note that if the + pattern contains groups, backreferencing can be used. + +* \(5) Output is the number of characters (not bytes) of each input element. 
Output type is Int32 for String, Int64 for LargeString. -* \(4) Each UTF8-encoded character in the input is converted to lowercase or +* \(6) Each UTF8-encoded character in the input is converted to lowercase or uppercase. diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index f3a8eb860d4..1515bdcfd36 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -684,6 +684,26 @@ class TrimOptions(_TrimOptions): self._set_options(characters) +cdef class _ReplaceSubstringOptions(FunctionOptions): + cdef: + unique_ptr[CReplaceSubstringOptions] replace_substring_options + + cdef const CFunctionOptions* get_options(self) except NULL: + return self.replace_substring_options.get() + + def _set_options(self, pattern, replacement, max_replacements): + self.replace_substring_options.reset( + new CReplaceSubstringOptions(tobytes(pattern), + tobytes(replacement), + max_replacements) + ) + + +class ReplaceSubstringOptions(_ReplaceSubstringOptions): + def __init__(self, pattern, replacement, max_replacements=-1): + self._set_options(pattern, replacement, max_replacements) + + cdef class _FilterOptions(FunctionOptions): cdef: unique_ptr[CFilterOptions] filter_options diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 2cdd843d81a..1b46a08c402 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -42,6 +42,7 @@ PartitionNthOptions, ProjectOptions, QuantileOptions, + ReplaceSubstringOptions, SetLookupOptions, SortOptions, StrptimeOptions, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 61deb658b0c..ebdcd08334c 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1815,6 +1815,14 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: c_bool reverse) c_string pattern + cdef cppclass CReplaceSubstringOptions \ + "arrow::compute::ReplaceSubstringOptions"(CFunctionOptions): + 
CReplaceSubstringOptions(c_string pattern, c_string replacement, + int64_t max_replacements) + c_string pattern + c_string replacement + int64_t max_replacements + cdef cppclass CCastOptions" arrow::compute::CastOptions"(CFunctionOptions): CCastOptions() CCastOptions(c_bool safe) diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 112629fc702..160375f93bd 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -579,6 +579,18 @@ def test_string_py_compat_boolean(function_name, variant): assert arrow_func(ar)[0].as_py() == getattr(c, py_name)() +def test_replace_plain(): + ar = pa.array(['foo', 'food', None]) + ar = pc.replace_substring(ar, pattern='foo', replacement='bar') + assert ar.tolist() == ['bar', 'bard', None] + + +def test_replace_regex(): + ar = pa.array(['foo', 'mood', None]) + ar = pc.replace_substring_regex(ar, pattern='(.)oo', replacement=r'\100') + assert ar.tolist() == ['f00', 'm00d', None] + + @pytest.mark.parametrize(('ty', 'values'), all_array_types) def test_take(ty, values): arr = pa.array(values, type=ty) From 9d8f0eae8d9b7eef6a41926bcf73db2a6290770b Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Thu, 25 Mar 2021 18:35:03 +0100 Subject: [PATCH 2/3] Avoid parameterizing the core string_replace implementation --- cpp/src/arrow/compute/api_scalar.h | 1 + .../arrow/compute/kernels/scalar_string.cc | 163 +++++++++--------- .../compute/kernels/scalar_string_test.cc | 21 ++- 3 files changed, 101 insertions(+), 84 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 13450d54a50..730836bd118 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -77,6 +77,7 @@ struct ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions { std::string pattern; /// String to replace the pattern with std::string replacement; + /// Max number of substrings to replace (-1 means unbounded) 
int64_t max_replacements; }; diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 02913b3b768..39869879561 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -1235,62 +1235,60 @@ void AddSplit(FunctionRegistry* registry) { } // ---------------------------------------------------------------------- -// replace substring +// Replace substring (plain, regex) -template -struct ReplaceSubStringBase { - using ArrayType = typename TypeTraits::ArrayType; +template +struct ReplaceSubString { using ScalarType = typename TypeTraits::ScalarType; - using BuilderType = typename TypeTraits::BuilderType; using offset_type = typename Type::offset_type; using ValueDataBuilder = TypedBufferBuilder; using OffsetBuilder = TypedBufferBuilder; using State = OptionsWrapper; static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - Derived derived(ctx, State::Get(ctx)); - if (ctx->status().ok()) { - derived.Replace(ctx, batch, out); + // TODO Cache replacer across invocations (for regex compilation) + Replacer replacer{ctx, State::Get(ctx)}; + if (!ctx->HasError()) { + Replace(ctx, batch, &replacer, out); } } - void Replace(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - std::shared_ptr value_data_builder = - std::make_shared(); - std::shared_ptr offset_builder = std::make_shared(); + + static void Replace(KernelContext* ctx, const ExecBatch& batch, Replacer* replacer, + Datum* out) { + ValueDataBuilder value_data_builder(ctx->memory_pool()); + OffsetBuilder offset_builder(ctx->memory_pool()); if (batch[0].kind() == Datum::ARRAY) { // We already know how many strings we have, so we can use Reserve/UnsafeAppend - KERNEL_RETURN_IF_ERROR(ctx, offset_builder->Reserve(batch[0].array()->length)); + KERNEL_RETURN_IF_ERROR(ctx, offset_builder.Reserve(batch[0].array()->length)); + offset_builder.UnsafeAppend(0); // offsets start at 0 const 
ArrayData& input = *batch[0].array(); - KERNEL_RETURN_IF_ERROR(ctx, offset_builder->Append(0)); // offsets start at 0 KERNEL_RETURN_IF_ERROR( ctx, VisitArrayDataInline( input, [&](util::string_view s) { - RETURN_NOT_OK(static_cast(*this).ReplaceString( - s, value_data_builder.get())); - offset_builder->UnsafeAppend( - static_cast(value_data_builder->length())); + RETURN_NOT_OK(replacer->ReplaceString(s, &value_data_builder)); + offset_builder.UnsafeAppend( + static_cast(value_data_builder.length())); return Status::OK(); }, [&]() { // offset for null value - offset_builder->UnsafeAppend( - static_cast(value_data_builder->length())); + offset_builder.UnsafeAppend( + static_cast(value_data_builder.length())); return Status::OK(); })); ArrayData* output = out->mutable_array(); - KERNEL_RETURN_IF_ERROR(ctx, value_data_builder->Finish(&output->buffers[2])); - KERNEL_RETURN_IF_ERROR(ctx, offset_builder->Finish(&output->buffers[1])); + KERNEL_RETURN_IF_ERROR(ctx, value_data_builder.Finish(&output->buffers[2])); + KERNEL_RETURN_IF_ERROR(ctx, offset_builder.Finish(&output->buffers[1])); } else { const auto& input = checked_cast(*batch[0].scalar()); auto result = std::make_shared(); if (input.is_valid) { util::string_view s = static_cast(*input.value); - KERNEL_RETURN_IF_ERROR( - ctx, static_cast(*this).ReplaceString(s, value_data_builder.get())); - KERNEL_RETURN_IF_ERROR(ctx, value_data_builder->Finish(&result->value)); + KERNEL_RETURN_IF_ERROR(ctx, replacer->ReplaceString(s, &value_data_builder)); + KERNEL_RETURN_IF_ERROR(ctx, value_data_builder.Finish(&result->value)); result->is_valid = true; } out->value = result; @@ -1298,85 +1296,71 @@ struct ReplaceSubStringBase { } }; -template -struct ReplaceSubString : ReplaceSubStringBase> { - using Base = ReplaceSubStringBase>; - using ValueDataBuilder = typename Base::ValueDataBuilder; - using offset_type = typename Base::offset_type; +struct PlainSubStringReplacer { + const ReplaceSubstringOptions& options_; - 
ReplaceSubstringOptions options; - explicit ReplaceSubString(KernelContext* ctx, ReplaceSubstringOptions options) - : options(options) {} + PlainSubStringReplacer(KernelContext* ctx, const ReplaceSubstringOptions& options) + : options_(options) {} - Status ReplaceString(util::string_view s, ValueDataBuilder* builder) { + Status ReplaceString(util::string_view s, TypedBufferBuilder* builder) { const char* i = s.begin(); const char* end = s.end(); - int64_t max_replacements = options.max_replacements; + int64_t max_replacements = options_.max_replacements; while ((i < end) && (max_replacements != 0)) { const char* pos = - std::search(i, end, options.pattern.begin(), options.pattern.end()); + std::search(i, end, options_.pattern.begin(), options_.pattern.end()); if (pos == end) { RETURN_NOT_OK(builder->Append(reinterpret_cast(i), - static_cast(end - i))); + static_cast(end - i))); i = end; } else { // the string before the pattern RETURN_NOT_OK(builder->Append(reinterpret_cast(i), - static_cast(pos - i))); + static_cast(pos - i))); // the replacement RETURN_NOT_OK( - builder->Append(reinterpret_cast(options.replacement.data()), - options.replacement.length())); + builder->Append(reinterpret_cast(options_.replacement.data()), + options_.replacement.length())); // skip pattern - i = pos + options.pattern.length(); + i = pos + options_.pattern.length(); max_replacements--; } } // if we exited early due to max_replacements, add the trailing part RETURN_NOT_OK(builder->Append(reinterpret_cast(i), - static_cast(end - i))); + static_cast(end - i))); return Status::OK(); } }; -const FunctionDoc replace_substring_doc( - "Replace non-overlapping substrings that match pattern by replacement", - ("For each string in `strings`, replace non-overlapping substrings that match\n" - "`pattern` by `replacement`. If `max_replacements != -1`, it determines the\n" - "maximum amount of replacements made, counting from the left. 
Null values emit\n" - "null."), - {"strings"}, "ReplaceSubstringOptions"); - #ifdef ARROW_WITH_RE2 -template -struct ReplaceSubStringRE2 : ReplaceSubStringBase> { - using Base = ReplaceSubStringBase>; - using ValueDataBuilder = typename Base::ValueDataBuilder; - using offset_type = typename Base::offset_type; - - ReplaceSubstringOptions options; - RE2 regex_find; - RE2 regex_replacement; - explicit ReplaceSubStringRE2(KernelContext* ctx, ReplaceSubstringOptions options) - : options(options), - regex_find("(" + options.pattern + ")"), - regex_replacement(options.pattern) { - // Using RE2::FindAndConsume we can only find the pattern if it is a group, therefore - // we have 2 regex, one with () around it, one without. - if (!(regex_find.ok() && regex_replacement.ok())) { +struct RegexSubStringReplacer { + const ReplaceSubstringOptions& options_; + const RE2 regex_find_; + const RE2 regex_replacement_; + + // Using RE2::FindAndConsume we can only find the pattern if it is a group, therefore + // we have 2 regexes, one with () around it, one without. 
+ RegexSubStringReplacer(KernelContext* ctx, const ReplaceSubstringOptions& options) + : options_(options), + regex_find_("(" + options_.pattern + ")"), + regex_replacement_(options_.pattern) { + if (!(regex_find_.ok() && regex_replacement_.ok())) { ctx->SetStatus(Status::Invalid("Regular expression error")); return; } } - Status ReplaceString(util::string_view s, ValueDataBuilder* builder) { - re2::StringPiece replacement(options.replacement); - if (options.max_replacements == -1) { + + Status ReplaceString(util::string_view s, TypedBufferBuilder* builder) { + re2::StringPiece replacement(options_.replacement); + if (options_.max_replacements == -1) { std::string s_copy(s.to_string()); - re2::RE2::GlobalReplace(&s_copy, regex_replacement, replacement); + re2::RE2::GlobalReplace(&s_copy, regex_replacement_, replacement); RETURN_NOT_OK(builder->Append(reinterpret_cast(s_copy.data()), s_copy.length())); return Status::OK(); } + // Since RE2 does not have the concept of max_replacements, we have to do some work // ourselves. 
// We might do this faster similar to RE2::GlobalReplace using Match and Rewrite @@ -1384,25 +1368,25 @@ struct ReplaceSubStringRE2 : ReplaceSubStringBaseAppend(reinterpret_cast(i), - static_cast(end - i))); + static_cast(end - i))); i = end; } else { // wind back to the beginning of the match const char* pos = piece.begin() - found.length(); // the string before the pattern RETURN_NOT_OK(builder->Append(reinterpret_cast(i), - static_cast(pos - i))); + static_cast(pos - i))); // replace the pattern in what we found - if (!re2::RE2::Replace(&found, regex_replacement, replacement)) { + if (!re2::RE2::Replace(&found, regex_replacement_, replacement)) { return Status::Invalid("Regex found, but replacement failed"); } RETURN_NOT_OK(builder->Append(reinterpret_cast(found.data()), - static_cast(found.length()))); + static_cast(found.length()))); // skip pattern i = piece.begin(); max_replacements--; @@ -1410,10 +1394,26 @@ struct ReplaceSubStringRE2 : ReplaceSubStringBaseAppend(reinterpret_cast(i), - static_cast(end - i))); + static_cast(end - i))); return Status::OK(); } }; +#endif + +template +using ReplaceSubStringPlain = ReplaceSubString; + +const FunctionDoc replace_substring_doc( + "Replace non-overlapping substrings that match pattern by replacement", + ("For each string in `strings`, replace non-overlapping substrings that match\n" + "`pattern` by `replacement`. If `max_replacements != -1`, it determines the\n" + "maximum amount of replacements made, counting from the left. Null values emit\n" + "null."), + {"strings"}, "ReplaceSubstringOptions"); + +#ifdef ARROW_WITH_RE2 +template +using ReplaceSubStringRegex = ReplaceSubString; const FunctionDoc replace_substring_regex_doc( "Replace non-overlapping substrings that match regex `pattern` by `replacement`", @@ -1423,7 +1423,6 @@ const FunctionDoc replace_substring_regex_doc( "made, counting from the left. Note that if the pattern contains groups,\n" "backreferencing macan be used. 
Null values emit null."), {"strings"}, "ReplaceSubstringOptions"); - #endif // ---------------------------------------------------------------------- @@ -2100,11 +2099,11 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { AddBinaryLength(registry); AddUtf8Length(registry); AddMatchSubstring(registry); - MakeUnaryStringBatchKernelWithState("replace_substring", registry, - &replace_substring_doc, - MemAllocation::NO_PREALLOCATE); + MakeUnaryStringBatchKernelWithState( + "replace_substring", registry, &replace_substring_doc, + MemAllocation::NO_PREALLOCATE); #ifdef ARROW_WITH_RE2 - MakeUnaryStringBatchKernelWithState( + MakeUnaryStringBatchKernelWithState( "replace_substring_regex", registry, &replace_substring_regex_doc, MemAllocation::NO_PREALLOCATE); #endif diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index 4262681ba4d..88622e842d1 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -430,12 +430,23 @@ TYPED_TEST(TestStringKernels, SplitWhitespaceUTF8Reverse) { &options_max); } -TYPED_TEST(TestStringKernels, ReplaceSubstringNormal) { +TYPED_TEST(TestStringKernels, ReplaceSubstring) { ReplaceSubstringOptions options{"foo", "bazz"}; this->CheckUnary("replace_substring", R"(["foo", "this foo that foo", null])", this->type(), R"(["bazz", "this bazz that bazz", null])", &options); } +TYPED_TEST(TestStringKernels, ReplaceSubstringLimited) { + ReplaceSubstringOptions options{"foo", "bazz", 1}; + this->CheckUnary("replace_substring", R"(["foo", "this foo that foo", null])", + this->type(), R"(["bazz", "this bazz that foo", null])", &options); +} + +TYPED_TEST(TestStringKernels, ReplaceSubstringNoOptions) { + Datum input = ArrayFromJSON(this->type(), "[]"); + ASSERT_RAISES(Invalid, CallFunction("replace_substring", {input})); +} + #ifdef ARROW_WITH_RE2 TYPED_TEST(TestStringKernels, ReplaceSubstringRegex) { 
ReplaceSubstringOptions options_regex{"(fo+)\\s*", "\\1-bazz"}; @@ -448,7 +459,8 @@ TYPED_TEST(TestStringKernels, ReplaceSubstringRegex) { R"(["abaaaaabaaaa"])", &options_regex2); } -TYPED_TEST(TestStringKernels, ReplaceSubstringMax) { +TYPED_TEST(TestStringKernels, ReplaceSubstringRegexLimited) { + // With a finite number of replacements ReplaceSubstringOptions options1{"foo", "bazz", 1}; this->CheckUnary("replace_substring", R"(["foo", "this foo that foo", null])", this->type(), R"(["bazz", "this bazz that foo", null])", &options1); @@ -457,6 +469,11 @@ TYPED_TEST(TestStringKernels, ReplaceSubstringMax) { this->type(), R"(["foo-bazz", "this foo-bazzthat foo", null])", &options_regex1); } + +TYPED_TEST(TestStringKernels, ReplaceSubstringRegexNoOptions) { + Datum input = ArrayFromJSON(this->type(), "[]"); + ASSERT_RAISES(Invalid, CallFunction("replace_substring_regex", {input})); +} #endif TYPED_TEST(TestStringKernels, Strptime) { From 82bb60d7c097bcb89aed458b3a59145cb9133538 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Thu, 25 Mar 2021 15:17:19 -0700 Subject: [PATCH 3/3] Turn off re2 in the rtools35 build --- ci/scripts/PKGBUILD | 3 +++ r/configure.win | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/ci/scripts/PKGBUILD b/ci/scripts/PKGBUILD index 1d9e41bba7a..c5b55eef42a 100644 --- a/ci/scripts/PKGBUILD +++ b/ci/scripts/PKGBUILD @@ -79,8 +79,10 @@ build() { export CPPFLAGS="${CPPFLAGS} -I${MINGW_PREFIX}/include" export LIBS="-L${MINGW_PREFIX}/libs" export ARROW_S3=OFF + export ARROW_WITH_RE2=OFF else export ARROW_S3=ON + export ARROW_WITH_RE2=ON fi MSYS2_ARG_CONV_EXCL="-DCMAKE_INSTALL_PREFIX=" \ @@ -105,6 +107,7 @@ build() { -DARROW_SNAPPY_USE_SHARED=OFF \ -DARROW_USE_GLOG=OFF \ -DARROW_WITH_LZ4=ON \ + -DARROW_WITH_RE2="${ARROW_WITH_RE2}" \ -DARROW_WITH_SNAPPY=ON \ -DARROW_WITH_ZLIB=ON \ -DARROW_WITH_ZSTD=ON \ diff --git a/r/configure.win b/r/configure.win index 88ac0e125e1..d645834fac8 100644 --- a/r/configure.win +++ 
b/r/configure.win @@ -50,13 +50,13 @@ AWS_LIBS="-laws-cpp-sdk-config -laws-cpp-sdk-transfer -laws-cpp-sdk-identity-man # NOTE: If you make changes to the libraries below, you should also change # ci/scripts/r_windows_build.sh and ci/scripts/PKGBUILD PKG_CFLAGS="-I${RWINLIB}/include -DARROW_STATIC -DPARQUET_STATIC -DARROW_DS_STATIC -DARROW_R_WITH_ARROW -DARROW_R_WITH_PARQUET -DARROW_R_WITH_DATASET" -PKG_LIBS="-L${RWINLIB}/lib"'$(subst gcc,,$(COMPILED_BY))$(R_ARCH) '"-L${RWINLIB}/lib"'$(R_ARCH) '"-lparquet -larrow_dataset -larrow -larrow_bundled_dependencies -lutf8proc -lre2 -lthrift -lsnappy -lz -lzstd -llz4 ${MIMALLOC_LIBS} ${OPENSSL_LIBS}" +PKG_LIBS="-L${RWINLIB}/lib"'$(subst gcc,,$(COMPILED_BY))$(R_ARCH) '"-L${RWINLIB}/lib"'$(R_ARCH) '"-lparquet -larrow_dataset -larrow -larrow_bundled_dependencies -lutf8proc -lthrift -lsnappy -lz -lzstd -llz4 ${MIMALLOC_LIBS} ${OPENSSL_LIBS}" -# S3 support only for Rtools40 (i.e. R >= 4.0) +# S3 and re2 support only for Rtools40 (i.e. R >= 4.0) "${R_HOME}/bin${R_ARCH_BIN}/Rscript.exe" -e 'R.version$major >= 4' | grep TRUE >/dev/null 2>&1 if [ $? -eq 0 ]; then PKG_CFLAGS="${PKG_CFLAGS} -DARROW_R_WITH_S3" - PKG_LIBS="${PKG_LIBS} ${AWS_LIBS}" + PKG_LIBS="${PKG_LIBS} -lre2 ${AWS_LIBS}" else # It seems that order matters PKG_LIBS="${PKG_LIBS} -lws2_32"