Skip to content

Commit faeab94

Browse files
zanmato1984 and raulcd
authored and committed
GH-39778: [C++] Fix tail-byte access cross buffer boundary in key hash avx2 (#39800)
### Rationale for this change

Issue #39778 seems caused by a careless (but hard to spot) bug in key hash avx2.

### What changes are included in this PR?

Fix the careless bug.

### Are these changes tested?

UT included.

### Are there any user-facing changes?

No.

* Closes: #39778

Authored-by: Ruoxi Sun <zanmato1984@gmail.com>
Signed-off-by: Antoine Pitrou <antoine@python.org>
1 parent 5e8e20d commit faeab94

4 files changed

Lines changed: 145 additions & 80 deletions

File tree

cpp/src/arrow/compute/key_hash.cc

Lines changed: 74 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -105,23 +105,23 @@ inline void Hashing32::StripeMask(int i, uint32_t* mask1, uint32_t* mask2,
105105
}
106106

107107
template <bool T_COMBINE_HASHES>
108-
void Hashing32::HashFixedLenImp(uint32_t num_rows, uint64_t length, const uint8_t* keys,
109-
uint32_t* hashes) {
108+
void Hashing32::HashFixedLenImp(uint32_t num_rows, uint64_t key_length,
109+
const uint8_t* keys, uint32_t* hashes) {
110110
// Calculate the number of rows that skip the last 16 bytes
111111
//
112112
uint32_t num_rows_safe = num_rows;
113-
while (num_rows_safe > 0 && (num_rows - num_rows_safe) * length < kStripeSize) {
113+
while (num_rows_safe > 0 && (num_rows - num_rows_safe) * key_length < kStripeSize) {
114114
--num_rows_safe;
115115
}
116116

117117
// Compute masks for the last 16 byte stripe
118118
//
119-
uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize);
119+
uint64_t num_stripes = bit_util::CeilDiv(key_length, kStripeSize);
120120
uint32_t mask1, mask2, mask3, mask4;
121-
StripeMask(((length - 1) & (kStripeSize - 1)) + 1, &mask1, &mask2, &mask3, &mask4);
121+
StripeMask(((key_length - 1) & (kStripeSize - 1)) + 1, &mask1, &mask2, &mask3, &mask4);
122122

123123
for (uint32_t i = 0; i < num_rows_safe; ++i) {
124-
const uint8_t* key = keys + static_cast<uint64_t>(i) * length;
124+
const uint8_t* key = keys + static_cast<uint64_t>(i) * key_length;
125125
uint32_t acc1, acc2, acc3, acc4;
126126
ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4);
127127
ProcessLastStripe(mask1, mask2, mask3, mask4, key + (num_stripes - 1) * kStripeSize,
@@ -138,11 +138,11 @@ void Hashing32::HashFixedLenImp(uint32_t num_rows, uint64_t length, const uint8_
138138

139139
uint32_t last_stripe_copy[4];
140140
for (uint32_t i = num_rows_safe; i < num_rows; ++i) {
141-
const uint8_t* key = keys + static_cast<uint64_t>(i) * length;
141+
const uint8_t* key = keys + static_cast<uint64_t>(i) * key_length;
142142
uint32_t acc1, acc2, acc3, acc4;
143143
ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4);
144144
memcpy(last_stripe_copy, key + (num_stripes - 1) * kStripeSize,
145-
length - (num_stripes - 1) * kStripeSize);
145+
key_length - (num_stripes - 1) * kStripeSize);
146146
ProcessLastStripe(mask1, mask2, mask3, mask4,
147147
reinterpret_cast<const uint8_t*>(last_stripe_copy), &acc1, &acc2,
148148
&acc3, &acc4);
@@ -168,15 +168,16 @@ void Hashing32::HashVarLenImp(uint32_t num_rows, const T* offsets,
168168
}
169169

170170
for (uint32_t i = 0; i < num_rows_safe; ++i) {
171-
uint64_t length = offsets[i + 1] - offsets[i];
171+
uint64_t key_length = offsets[i + 1] - offsets[i];
172172

173173
// Compute masks for the last 16 byte stripe.
174174
// For an empty string set number of stripes to 1 but mask to all zeroes.
175175
//
176-
int is_non_empty = length == 0 ? 0 : 1;
177-
uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize) + (1 - is_non_empty);
176+
int is_non_empty = key_length == 0 ? 0 : 1;
177+
uint64_t num_stripes =
178+
bit_util::CeilDiv(key_length, kStripeSize) + (1 - is_non_empty);
178179
uint32_t mask1, mask2, mask3, mask4;
179-
StripeMask(((length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
180+
StripeMask(((key_length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
180181
&mask2, &mask3, &mask4);
181182

182183
const uint8_t* key = concatenated_keys + offsets[i];
@@ -198,23 +199,24 @@ void Hashing32::HashVarLenImp(uint32_t num_rows, const T* offsets,
198199

199200
uint32_t last_stripe_copy[4];
200201
for (uint32_t i = num_rows_safe; i < num_rows; ++i) {
201-
uint64_t length = offsets[i + 1] - offsets[i];
202+
uint64_t key_length = offsets[i + 1] - offsets[i];
202203

203204
// Compute masks for the last 16 byte stripe.
204205
// For an empty string set number of stripes to 1 but mask to all zeroes.
205206
//
206-
int is_non_empty = length == 0 ? 0 : 1;
207-
uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize) + (1 - is_non_empty);
207+
int is_non_empty = key_length == 0 ? 0 : 1;
208+
uint64_t num_stripes =
209+
bit_util::CeilDiv(key_length, kStripeSize) + (1 - is_non_empty);
208210
uint32_t mask1, mask2, mask3, mask4;
209-
StripeMask(((length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
211+
StripeMask(((key_length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
210212
&mask2, &mask3, &mask4);
211213

212214
const uint8_t* key = concatenated_keys + offsets[i];
213215
uint32_t acc1, acc2, acc3, acc4;
214216
ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4);
215-
if (length > 0) {
217+
if (key_length > 0) {
216218
memcpy(last_stripe_copy, key + (num_stripes - 1) * kStripeSize,
217-
length - (num_stripes - 1) * kStripeSize);
219+
key_length - (num_stripes - 1) * kStripeSize);
218220
}
219221
if (num_stripes > 0) {
220222
ProcessLastStripe(mask1, mask2, mask3, mask4,
@@ -309,9 +311,9 @@ void Hashing32::HashIntImp(uint32_t num_keys, const T* keys, uint32_t* hashes) {
309311
}
310312
}
311313

312-
void Hashing32::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t length_key,
314+
void Hashing32::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t key_length,
313315
const uint8_t* keys, uint32_t* hashes) {
314-
switch (length_key) {
316+
switch (key_length) {
315317
case sizeof(uint8_t):
316318
if (combine_hashes) {
317319
HashIntImp<true, uint8_t>(num_keys, keys, hashes);
@@ -352,27 +354,27 @@ void Hashing32::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t length_
352354
}
353355
}
354356

355-
void Hashing32::HashFixed(int64_t hardware_flags, bool combine_hashes, uint32_t num_rows,
356-
uint64_t length, const uint8_t* keys, uint32_t* hashes,
357-
uint32_t* hashes_temp_for_combine) {
358-
if (ARROW_POPCOUNT64(length) == 1 && length <= sizeof(uint64_t)) {
359-
HashInt(combine_hashes, num_rows, length, keys, hashes);
357+
void Hashing32::HashFixed(int64_t hardware_flags, bool combine_hashes, uint32_t num_keys,
358+
uint64_t key_length, const uint8_t* keys, uint32_t* hashes,
359+
uint32_t* temp_hashes_for_combine) {
360+
if (ARROW_POPCOUNT64(key_length) == 1 && key_length <= sizeof(uint64_t)) {
361+
HashInt(combine_hashes, num_keys, key_length, keys, hashes);
360362
return;
361363
}
362364

363365
uint32_t num_processed = 0;
364366
#if defined(ARROW_HAVE_RUNTIME_AVX2)
365367
if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
366-
num_processed = HashFixedLen_avx2(combine_hashes, num_rows, length, keys, hashes,
367-
hashes_temp_for_combine);
368+
num_processed = HashFixedLen_avx2(combine_hashes, num_keys, key_length, keys, hashes,
369+
temp_hashes_for_combine);
368370
}
369371
#endif
370372
if (combine_hashes) {
371-
HashFixedLenImp<true>(num_rows - num_processed, length, keys + length * num_processed,
372-
hashes + num_processed);
373+
HashFixedLenImp<true>(num_keys - num_processed, key_length,
374+
keys + key_length * num_processed, hashes + num_processed);
373375
} else {
374-
HashFixedLenImp<false>(num_rows - num_processed, length,
375-
keys + length * num_processed, hashes + num_processed);
376+
HashFixedLenImp<false>(num_keys - num_processed, key_length,
377+
keys + key_length * num_processed, hashes + num_processed);
376378
}
377379
}
378380

@@ -423,13 +425,13 @@ void Hashing32::HashMultiColumn(const std::vector<KeyColumnArray>& cols,
423425
}
424426

425427
if (cols[icol].metadata().is_fixed_length) {
426-
uint32_t col_width = cols[icol].metadata().fixed_length;
427-
if (col_width == 0) {
428+
uint32_t key_length = cols[icol].metadata().fixed_length;
429+
if (key_length == 0) {
428430
HashBit(icol > 0, cols[icol].bit_offset(1), batch_size_next,
429431
cols[icol].data(1) + first_row / 8, hashes + first_row);
430432
} else {
431-
HashFixed(ctx->hardware_flags, icol > 0, batch_size_next, col_width,
432-
cols[icol].data(1) + first_row * col_width, hashes + first_row,
433+
HashFixed(ctx->hardware_flags, icol > 0, batch_size_next, key_length,
434+
cols[icol].data(1) + first_row * key_length, hashes + first_row,
433435
hash_temp);
434436
}
435437
} else if (cols[icol].metadata().fixed_length == sizeof(uint32_t)) {
@@ -463,8 +465,9 @@ void Hashing32::HashMultiColumn(const std::vector<KeyColumnArray>& cols,
463465
Status Hashing32::HashBatch(const ExecBatch& key_batch, uint32_t* hashes,
464466
std::vector<KeyColumnArray>& column_arrays,
465467
int64_t hardware_flags, util::TempVectorStack* temp_stack,
466-
int64_t offset, int64_t length) {
467-
RETURN_NOT_OK(ColumnArraysFromExecBatch(key_batch, offset, length, &column_arrays));
468+
int64_t start_rows, int64_t num_rows) {
469+
RETURN_NOT_OK(
470+
ColumnArraysFromExecBatch(key_batch, start_rows, num_rows, &column_arrays));
468471

469472
LightContext ctx;
470473
ctx.hardware_flags = hardware_flags;
@@ -574,23 +577,23 @@ inline void Hashing64::StripeMask(int i, uint64_t* mask1, uint64_t* mask2,
574577
}
575578

576579
template <bool T_COMBINE_HASHES>
577-
void Hashing64::HashFixedLenImp(uint32_t num_rows, uint64_t length, const uint8_t* keys,
578-
uint64_t* hashes) {
580+
void Hashing64::HashFixedLenImp(uint32_t num_rows, uint64_t key_length,
581+
const uint8_t* keys, uint64_t* hashes) {
579582
// Calculate the number of rows that skip the last 32 bytes
580583
//
581584
uint32_t num_rows_safe = num_rows;
582-
while (num_rows_safe > 0 && (num_rows - num_rows_safe) * length < kStripeSize) {
585+
while (num_rows_safe > 0 && (num_rows - num_rows_safe) * key_length < kStripeSize) {
583586
--num_rows_safe;
584587
}
585588

586589
// Compute masks for the last 32 byte stripe
587590
//
588-
uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize);
591+
uint64_t num_stripes = bit_util::CeilDiv(key_length, kStripeSize);
589592
uint64_t mask1, mask2, mask3, mask4;
590-
StripeMask(((length - 1) & (kStripeSize - 1)) + 1, &mask1, &mask2, &mask3, &mask4);
593+
StripeMask(((key_length - 1) & (kStripeSize - 1)) + 1, &mask1, &mask2, &mask3, &mask4);
591594

592595
for (uint32_t i = 0; i < num_rows_safe; ++i) {
593-
const uint8_t* key = keys + static_cast<uint64_t>(i) * length;
596+
const uint8_t* key = keys + static_cast<uint64_t>(i) * key_length;
594597
uint64_t acc1, acc2, acc3, acc4;
595598
ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4);
596599
ProcessLastStripe(mask1, mask2, mask3, mask4, key + (num_stripes - 1) * kStripeSize,
@@ -607,11 +610,11 @@ void Hashing64::HashFixedLenImp(uint32_t num_rows, uint64_t length, const uint8_
607610

608611
uint64_t last_stripe_copy[4];
609612
for (uint32_t i = num_rows_safe; i < num_rows; ++i) {
610-
const uint8_t* key = keys + static_cast<uint64_t>(i) * length;
613+
const uint8_t* key = keys + static_cast<uint64_t>(i) * key_length;
611614
uint64_t acc1, acc2, acc3, acc4;
612615
ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4);
613616
memcpy(last_stripe_copy, key + (num_stripes - 1) * kStripeSize,
614-
length - (num_stripes - 1) * kStripeSize);
617+
key_length - (num_stripes - 1) * kStripeSize);
615618
ProcessLastStripe(mask1, mask2, mask3, mask4,
616619
reinterpret_cast<const uint8_t*>(last_stripe_copy), &acc1, &acc2,
617620
&acc3, &acc4);
@@ -637,15 +640,16 @@ void Hashing64::HashVarLenImp(uint32_t num_rows, const T* offsets,
637640
}
638641

639642
for (uint32_t i = 0; i < num_rows_safe; ++i) {
640-
uint64_t length = offsets[i + 1] - offsets[i];
643+
uint64_t key_length = offsets[i + 1] - offsets[i];
641644

642645
// Compute masks for the last 32 byte stripe.
643646
// For an empty string set number of stripes to 1 but mask to all zeroes.
644647
//
645-
int is_non_empty = length == 0 ? 0 : 1;
646-
uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize) + (1 - is_non_empty);
648+
int is_non_empty = key_length == 0 ? 0 : 1;
649+
uint64_t num_stripes =
650+
bit_util::CeilDiv(key_length, kStripeSize) + (1 - is_non_empty);
647651
uint64_t mask1, mask2, mask3, mask4;
648-
StripeMask(((length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
652+
StripeMask(((key_length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
649653
&mask2, &mask3, &mask4);
650654

651655
const uint8_t* key = concatenated_keys + offsets[i];
@@ -667,22 +671,23 @@ void Hashing64::HashVarLenImp(uint32_t num_rows, const T* offsets,
667671

668672
uint64_t last_stripe_copy[4];
669673
for (uint32_t i = num_rows_safe; i < num_rows; ++i) {
670-
uint64_t length = offsets[i + 1] - offsets[i];
674+
uint64_t key_length = offsets[i + 1] - offsets[i];
671675

672676
// Compute masks for the last 32 byte stripe
673677
//
674-
int is_non_empty = length == 0 ? 0 : 1;
675-
uint64_t num_stripes = bit_util::CeilDiv(length, kStripeSize) + (1 - is_non_empty);
678+
int is_non_empty = key_length == 0 ? 0 : 1;
679+
uint64_t num_stripes =
680+
bit_util::CeilDiv(key_length, kStripeSize) + (1 - is_non_empty);
676681
uint64_t mask1, mask2, mask3, mask4;
677-
StripeMask(((length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
682+
StripeMask(((key_length - is_non_empty) & (kStripeSize - 1)) + is_non_empty, &mask1,
678683
&mask2, &mask3, &mask4);
679684

680685
const uint8_t* key = concatenated_keys + offsets[i];
681686
uint64_t acc1, acc2, acc3, acc4;
682687
ProcessFullStripes(num_stripes, key, &acc1, &acc2, &acc3, &acc4);
683-
if (length > 0) {
688+
if (key_length > 0) {
684689
memcpy(last_stripe_copy, key + (num_stripes - 1) * kStripeSize,
685-
length - (num_stripes - 1) * kStripeSize);
690+
key_length - (num_stripes - 1) * kStripeSize);
686691
}
687692
if (num_stripes > 0) {
688693
ProcessLastStripe(mask1, mask2, mask3, mask4,
@@ -759,9 +764,9 @@ void Hashing64::HashIntImp(uint32_t num_keys, const T* keys, uint64_t* hashes) {
759764
}
760765
}
761766

762-
void Hashing64::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t length_key,
767+
void Hashing64::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t key_length,
763768
const uint8_t* keys, uint64_t* hashes) {
764-
switch (length_key) {
769+
switch (key_length) {
765770
case sizeof(uint8_t):
766771
if (combine_hashes) {
767772
HashIntImp<true, uint8_t>(num_keys, keys, hashes);
@@ -802,17 +807,17 @@ void Hashing64::HashInt(bool combine_hashes, uint32_t num_keys, uint64_t length_
802807
}
803808
}
804809

805-
void Hashing64::HashFixed(bool combine_hashes, uint32_t num_rows, uint64_t length,
810+
void Hashing64::HashFixed(bool combine_hashes, uint32_t num_keys, uint64_t key_length,
806811
const uint8_t* keys, uint64_t* hashes) {
807-
if (ARROW_POPCOUNT64(length) == 1 && length <= sizeof(uint64_t)) {
808-
HashInt(combine_hashes, num_rows, length, keys, hashes);
812+
if (ARROW_POPCOUNT64(key_length) == 1 && key_length <= sizeof(uint64_t)) {
813+
HashInt(combine_hashes, num_keys, key_length, keys, hashes);
809814
return;
810815
}
811816

812817
if (combine_hashes) {
813-
HashFixedLenImp<true>(num_rows, length, keys, hashes);
818+
HashFixedLenImp<true>(num_keys, key_length, keys, hashes);
814819
} else {
815-
HashFixedLenImp<false>(num_rows, length, keys, hashes);
820+
HashFixedLenImp<false>(num_keys, key_length, keys, hashes);
816821
}
817822
}
818823

@@ -860,13 +865,13 @@ void Hashing64::HashMultiColumn(const std::vector<KeyColumnArray>& cols,
860865
}
861866

862867
if (cols[icol].metadata().is_fixed_length) {
863-
uint64_t col_width = cols[icol].metadata().fixed_length;
864-
if (col_width == 0) {
868+
uint64_t key_length = cols[icol].metadata().fixed_length;
869+
if (key_length == 0) {
865870
HashBit(icol > 0, cols[icol].bit_offset(1), batch_size_next,
866871
cols[icol].data(1) + first_row / 8, hashes + first_row);
867872
} else {
868-
HashFixed(icol > 0, batch_size_next, col_width,
869-
cols[icol].data(1) + first_row * col_width, hashes + first_row);
873+
HashFixed(icol > 0, batch_size_next, key_length,
874+
cols[icol].data(1) + first_row * key_length, hashes + first_row);
870875
}
871876
} else if (cols[icol].metadata().fixed_length == sizeof(uint32_t)) {
872877
HashVarLen(icol > 0, batch_size_next, cols[icol].offsets() + first_row,
@@ -897,8 +902,9 @@ void Hashing64::HashMultiColumn(const std::vector<KeyColumnArray>& cols,
897902
Status Hashing64::HashBatch(const ExecBatch& key_batch, uint64_t* hashes,
898903
std::vector<KeyColumnArray>& column_arrays,
899904
int64_t hardware_flags, util::TempVectorStack* temp_stack,
900-
int64_t offset, int64_t length) {
901-
RETURN_NOT_OK(ColumnArraysFromExecBatch(key_batch, offset, length, &column_arrays));
905+
int64_t start_row, int64_t num_rows) {
906+
RETURN_NOT_OK(
907+
ColumnArraysFromExecBatch(key_batch, start_row, num_rows, &column_arrays));
902908

903909
LightContext ctx;
904910
ctx.hardware_flags = hardware_flags;

0 commit comments

Comments (0)