Skip to content

Commit 72f7fbe

Browse files
alvinj15 authored and lidavidm committed
ARROW-14795: [C++] Fix issue on replace with mask for null values
ARROW-14795: [C++] Fix vector replace-with-mask for null values: the null (validity) bitmap of the output was not updated for the replaced slots.
Closes #11759 from AlvinJ15/achunga/14795-fix_vector_replace_with_mask
Authored-by: alvinj15 <Alvin258461@>
Signed-off-by: David Li <li.davidm96@gmail.com>
1 parent 236362a commit 72f7fbe

2 files changed

Lines changed: 54 additions & 11 deletions

File tree

cpp/src/arrow/compute/kernels/vector_replace.cc

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,7 @@ struct CopyArrayBitmap {
8181

8282
void SetBit(uint8_t* out_bitmap, int64_t out_offset, int64_t offset) const {
8383
BitUtil::SetBitTo(out_bitmap, out_offset,
84-
BitUtil::GetBit(in_bitmap, in_offset + offset));
84+
in_bitmap ? BitUtil::GetBit(in_bitmap, in_offset + offset) : true);
8585
}
8686
};
8787

@@ -122,7 +122,7 @@ void ReplaceWithArrayMaskImpl(const ArrayData& array, const ArrayData& mask,
122122
if (replacements_bitmap) {
123123
copy_bitmap.CopyBitmap(out_bitmap, out_offset + write_offset, replacements_offset,
124124
block.length);
125-
} else if (!replacements_bitmap && out_bitmap) {
125+
} else if (out_bitmap) {
126126
BitUtil::SetBitsTo(out_bitmap, out_offset + write_offset, block.length, true);
127127
}
128128
replacements_offset += block.length;
@@ -133,10 +133,9 @@ void ReplaceWithArrayMaskImpl(const ArrayData& array, const ArrayData& mask,
133133
BitUtil::GetBit(mask_bitmap, write_offset + mask.offset + i))) {
134134
Functor::CopyData(*array.type, out_values, out_offset + write_offset + i,
135135
replacements, replacements_offset, /*length=*/1);
136-
if (replacements_bitmap) {
137-
copy_bitmap.SetBit(out_bitmap, out_offset + write_offset + i,
138-
replacements_offset);
139-
}
136+
copy_bitmap.SetBit(out_bitmap, out_offset + write_offset + i,
137+
138+
replacements_offset);
140139
replacements_offset++;
141140
}
142141
}
@@ -154,9 +153,8 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array,
154153
uint8_t* out_values = output->buffers[1]->mutable_data();
155154
const uint8_t* mask_bitmap = mask.MayHaveNulls() ? mask.buffers[0]->data() : nullptr;
156155
const uint8_t* mask_values = mask.buffers[1]->data();
157-
const bool replacements_bitmap = replacements.is_array()
158-
? replacements.array()->MayHaveNulls()
159-
: !replacements.scalar()->is_valid;
156+
const bool replacements_bitmap =
157+
replacements.is_array() ? replacements.array()->MayHaveNulls() : true;
160158
if (replacements.is_array()) {
161159
// Check that we have enough replacement values
162160
const int64_t replacements_length = replacements.array()->length;
@@ -189,7 +187,7 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array,
189187
const ArrayData& array_repl = *replacements.array();
190188
ReplaceWithArrayMaskImpl<Functor>(
191189
array, mask, array_repl, replacements_bitmap,
192-
CopyArrayBitmap{replacements_bitmap ? array_repl.buffers[0]->data() : nullptr,
190+
CopyArrayBitmap{(replacements_bitmap) ? array_repl.buffers[0]->data() : nullptr,
193191
array_repl.offset},
194192
mask_bitmap, mask_values, out_bitmap, out_values, out_offset);
195193
} else {
@@ -254,7 +252,9 @@ struct ReplaceWithMask<Type, enable_if_boolean<Type>> {
254252
}
255253
static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset,
256254
const Scalar& in, const int64_t in_offset, const int64_t length) {
257-
BitUtil::SetBitsTo(out, out_offset, length, in.is_valid);
255+
BitUtil::SetBitsTo(
256+
out, out_offset, length,
257+
in.is_valid ? checked_cast<const BooleanScalar&>(in).value : false);
258258
}
259259

260260
static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,

cpp/src/arrow/compute/kernels/vector_replace_test.cc

Lines changed: 43 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -235,6 +235,38 @@ TYPED_TEST(TestReplaceNumeric, ReplaceWithMask) {
235235
this->array("[0, null, 10]"));
236236
}
237237

238+
TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskForNullValuesAndMaskEnabled) {
239+
this->Assert(ReplaceWithMask, this->array("[1, null, 1]"),
240+
this->mask("[false, true, false]"), this->array("[7]"),
241+
this->array("[1, 7, 1]"));
242+
this->Assert(ReplaceWithMask, this->array("[1, null, 1, 7]"),
243+
this->mask("[false, true, false, true]"), this->array("[7, 20]"),
244+
this->array("[1, 7, 1, 20]"));
245+
this->Assert(ReplaceWithMask, this->array("[1, 2, 3, 4]"),
246+
this->mask("[false, true, false, true]"), this->array("[null, null]"),
247+
this->array("[1, null, 3, null]"));
248+
this->Assert(ReplaceWithMask, this->array("[null, 2, 3, 4]"),
249+
this->mask("[true, true, false, true]"), this->array("[1, null, null]"),
250+
this->array("[1, null, 3, null]"));
251+
this->Assert(ReplaceWithMask, this->array("[1, null, 1]"),
252+
this->mask("[false, true, false]"), this->scalar("null"),
253+
this->array("[1, null, 1]"));
254+
this->Assert(ReplaceWithMask, this->array("[1, null, 1]"),
255+
this->mask("[true, true, true]"), this->array("[7, 7, 7]"),
256+
this->array("[7, 7, 7]"));
257+
this->Assert(ReplaceWithMask, this->array("[1, null, 1]"),
258+
this->mask("[true, true, true]"), this->array("[null, null, null]"),
259+
this->array("[null, null, null]"));
260+
this->Assert(ReplaceWithMask, this->array("[1, null, 1]"),
261+
this->mask("[false, true, false]"), this->scalar("null"),
262+
this->array("[1, null, 1]"));
263+
this->Assert(ReplaceWithMask, this->array("[1, null, 1]"),
264+
this->mask("[true, true, true]"), this->scalar("null"),
265+
this->array("[null, null, null]"));
266+
this->Assert(ReplaceWithMask, this->array("[null, null]"), this->mask("[true, true]"),
267+
this->array("[1, 1]"), this->array("[1, 1]"));
268+
}
269+
238270
TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskRandom) {
239271
using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
240272
using CType = typename TypeTraits<TypeParam>::CType;
@@ -340,16 +372,24 @@ TEST_F(TestReplaceBoolean, ReplaceWithMask) {
340372
this->mask("[false, false, null, null, true, true]"),
341373
this->array("[false, null]"),
342374
this->array("[null, null, null, null, false, null]"));
375+
this->Assert(ReplaceWithMask, this->array("[true, null, true]"),
376+
this->mask("[false, true, false]"), this->array("[true]"),
377+
this->array("[true, true, true]"));
343378

344379
this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->scalar("true"),
345380
this->array("[]"));
381+
this->Assert(ReplaceWithMask, this->array("[null, false, true]"),
382+
this->mask("[true, false, false]"), this->scalar("false"),
383+
this->array("[false, false, true]"));
346384
this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask("[true, true]"),
347385
this->scalar("true"), this->array("[true, true]"));
348386
this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask("[true, true]"),
349387
this->scalar("null"), this->array("[null, null]"));
350388
this->Assert(ReplaceWithMask, this->array("[false, false, false]"),
351389
this->mask("[false, null, true]"), this->scalar("true"),
352390
this->array("[false, null, true]"));
391+
this->Assert(ReplaceWithMask, this->array("[null, null]"), this->mask("[true, true]"),
392+
this->array("[true, true]"), this->array("[true, true]"));
353393
}
354394

355395
TEST_F(TestReplaceBoolean, ReplaceWithMaskErrors) {
@@ -427,6 +467,9 @@ TEST_F(TestReplaceFixedSizeBinary, ReplaceWithMask) {
427467
this->mask("[false, false, null, null, true, true]"),
428468
this->array(R"(["aaa", null])"),
429469
this->array(R"([null, null, null, null, "aaa", null])"));
470+
this->Assert(ReplaceWithMask, this->array(R"(["aaa", null, "bbb"])"),
471+
this->mask("[false, true, false]"), this->array(R"(["aba"])"),
472+
this->array(R"(["aaa", "aba", "bbb"])"));
430473

431474
this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"),
432475
this->scalar(R"("zzz")"), this->array("[]"));

0 commit comments

Comments
 (0)