Skip to content

Commit be2a55d

Browse files
committed
Faster version of hash join
1 parent fa78edc commit be2a55d

24 files changed

Lines changed: 6039 additions & 888 deletions

cpp/src/arrow/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,9 +393,11 @@ if(ARROW_COMPUTE)
393393
compute/exec/key_map.cc
394394
compute/exec/options.cc
395395
compute/exec/order_by_impl.cc
396+
compute/exec/partition_util.cc
396397
compute/exec/project_node.cc
397398
compute/exec/sink_node.cc
398399
compute/exec/source_node.cc
400+
compute/exec/swiss_join.cc
399401
compute/exec/task_util.cc
400402
compute/exec/union_node.cc
401403
compute/exec/util.cc
@@ -445,6 +447,7 @@ if(ARROW_COMPUTE)
445447
append_avx2_src(compute/exec/key_encode_avx2.cc)
446448
append_avx2_src(compute/exec/key_hash_avx2.cc)
447449
append_avx2_src(compute/exec/key_map_avx2.cc)
450+
append_avx2_src(compute/exec/swiss_join_avx2.cc)
448451
append_avx2_src(compute/exec/util_avx2.cc)
449452

450453
list(APPEND ARROW_TESTING_SRCS compute/exec/test_util.cc)

cpp/src/arrow/compute/exec/hash_join.cc

Lines changed: 47 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
8383
}
8484

8585
Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution,
86-
size_t num_threads, HashJoinSchema* schema_mgr,
86+
size_t num_threads, const HashJoinProjectionMaps* proj_map_left,
87+
const HashJoinProjectionMaps* proj_map_right,
8788
std::vector<JoinKeyCmp> key_cmp, Expression filter,
8889
OutputBatchCallback output_batch_callback,
8990
FinishedCallback finished_callback,
@@ -98,7 +99,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
9899
ctx_ = ctx;
99100
join_type_ = join_type;
100101
num_threads_ = num_threads;
101-
schema_mgr_ = schema_mgr;
102+
schema_[0] = proj_map_left;
103+
schema_[1] = proj_map_right;
102104
key_cmp_ = std::move(key_cmp);
103105
filter_ = std::move(filter);
104106
output_batch_callback_ = std::move(output_batch_callback);
@@ -139,12 +141,11 @@ class HashJoinBasicImpl : public HashJoinImpl {
139141
private:
140142
void InitEncoder(int side, HashJoinProjection projection_handle, RowEncoder* encoder) {
141143
std::vector<ValueDescr> data_types;
142-
int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle);
144+
int num_cols = schema_[side]->num_cols(projection_handle);
143145
data_types.resize(num_cols);
144146
for (int icol = 0; icol < num_cols; ++icol) {
145-
data_types[icol] =
146-
ValueDescr(schema_mgr_->proj_maps[side].data_type(projection_handle, icol),
147-
ValueDescr::ARRAY);
147+
data_types[icol] = ValueDescr(schema_[side]->data_type(projection_handle, icol),
148+
ValueDescr::ARRAY);
148149
}
149150
encoder->Init(data_types, ctx_);
150151
encoder->Clear();
@@ -155,8 +156,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
155156
ThreadLocalState& local_state = local_states_[thread_index];
156157
if (!local_state.is_initialized) {
157158
InitEncoder(0, HashJoinProjection::KEY, &local_state.exec_batch_keys);
158-
bool has_payload =
159-
(schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0);
159+
bool has_payload = (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0);
160160
if (has_payload) {
161161
InitEncoder(0, HashJoinProjection::PAYLOAD, &local_state.exec_batch_payloads);
162162
}
@@ -168,11 +168,10 @@ class HashJoinBasicImpl : public HashJoinImpl {
168168
Status EncodeBatch(int side, HashJoinProjection projection_handle, RowEncoder* encoder,
169169
const ExecBatch& batch, ExecBatch* opt_projected_batch = nullptr) {
170170
ExecBatch projected({}, batch.length);
171-
int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle);
171+
int num_cols = schema_[side]->num_cols(projection_handle);
172172
projected.values.resize(num_cols);
173173

174-
auto to_input =
175-
schema_mgr_->proj_maps[side].map(projection_handle, HashJoinProjection::INPUT);
174+
auto to_input = schema_[side]->map(projection_handle, HashJoinProjection::INPUT);
176175
for (int icol = 0; icol < num_cols; ++icol) {
177176
projected.values[icol] = batch.values[to_input.get(icol)];
178177
}
@@ -235,16 +234,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
235234
ExecBatch* opt_left_payload, ExecBatch* opt_right_key,
236235
ExecBatch* opt_right_payload) {
237236
ExecBatch result({}, batch_size_next);
238-
int num_out_cols_left =
239-
schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT);
240-
int num_out_cols_right =
241-
schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT);
237+
int num_out_cols_left = schema_[0]->num_cols(HashJoinProjection::OUTPUT);
238+
int num_out_cols_right = schema_[1]->num_cols(HashJoinProjection::OUTPUT);
242239

243240
result.values.resize(num_out_cols_left + num_out_cols_right);
244-
auto from_key = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
245-
HashJoinProjection::KEY);
246-
auto from_payload = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
247-
HashJoinProjection::PAYLOAD);
241+
auto from_key = schema_[0]->map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY);
242+
auto from_payload =
243+
schema_[0]->map(HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD);
248244
for (int icol = 0; icol < num_out_cols_left; ++icol) {
249245
bool is_from_key = (from_key.get(icol) != HashJoinSchema::kMissingField());
250246
bool is_from_payload = (from_payload.get(icol) != HashJoinSchema::kMissingField());
@@ -262,10 +258,9 @@ class HashJoinBasicImpl : public HashJoinImpl {
262258
? opt_left_key->values[from_key.get(icol)]
263259
: opt_left_payload->values[from_payload.get(icol)];
264260
}
265-
from_key = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT,
266-
HashJoinProjection::KEY);
267-
from_payload = schema_mgr_->proj_maps[1].map(HashJoinProjection::OUTPUT,
268-
HashJoinProjection::PAYLOAD);
261+
from_key = schema_[1]->map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY);
262+
from_payload =
263+
schema_[1]->map(HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD);
269264
for (int icol = 0; icol < num_out_cols_right; ++icol) {
270265
bool is_from_key = (from_key.get(icol) != HashJoinSchema::kMissingField());
271266
bool is_from_payload = (from_payload.get(icol) != HashJoinSchema::kMissingField());
@@ -284,7 +279,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
284279
: opt_right_payload->values[from_payload.get(icol)];
285280
}
286281

287-
output_batch_callback_(std::move(result));
282+
output_batch_callback_(0, std::move(result));
288283

289284
// Update the counter of produced batches
290285
//
@@ -310,13 +305,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
310305
hash_table_keys_.Decode(match_right.size(), match_right.data()));
311306

312307
ExecBatch left_payload;
313-
if (!schema_mgr_->LeftPayloadIsEmpty()) {
308+
if (!schema_[0]->is_empty(HashJoinProjection::PAYLOAD)) {
314309
ARROW_ASSIGN_OR_RAISE(left_payload, local_state.exec_batch_payloads.Decode(
315310
match_left.size(), match_left.data()));
316311
}
317312

318313
ExecBatch right_payload;
319-
if (!schema_mgr_->RightPayloadIsEmpty()) {
314+
if (!schema_[1]->is_empty(HashJoinProjection::PAYLOAD)) {
320315
ARROW_ASSIGN_OR_RAISE(right_payload, hash_table_payloads_.Decode(
321316
match_right.size(), match_right.data()));
322317
}
@@ -336,14 +331,14 @@ class HashJoinBasicImpl : public HashJoinImpl {
336331
}
337332
};
338333

339-
SchemaProjectionMap left_to_key = schema_mgr_->proj_maps[0].map(
340-
HashJoinProjection::FILTER, HashJoinProjection::KEY);
341-
SchemaProjectionMap left_to_pay = schema_mgr_->proj_maps[0].map(
342-
HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
343-
SchemaProjectionMap right_to_key = schema_mgr_->proj_maps[1].map(
344-
HashJoinProjection::FILTER, HashJoinProjection::KEY);
345-
SchemaProjectionMap right_to_pay = schema_mgr_->proj_maps[1].map(
346-
HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
334+
SchemaProjectionMap left_to_key =
335+
schema_[0]->map(HashJoinProjection::FILTER, HashJoinProjection::KEY);
336+
SchemaProjectionMap left_to_pay =
337+
schema_[0]->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
338+
SchemaProjectionMap right_to_key =
339+
schema_[1]->map(HashJoinProjection::FILTER, HashJoinProjection::KEY);
340+
SchemaProjectionMap right_to_pay =
341+
schema_[1]->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
347342

348343
AppendFields(left_to_key, left_to_pay, left_key, left_payload);
349344
AppendFields(right_to_key, right_to_pay, right_key, right_payload);
@@ -419,15 +414,14 @@ class HashJoinBasicImpl : public HashJoinImpl {
419414

420415
bool has_left =
421416
(join_type_ != JoinType::RIGHT_SEMI && join_type_ != JoinType::RIGHT_ANTI &&
422-
schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT) > 0);
417+
schema_[0]->num_cols(HashJoinProjection::OUTPUT) > 0);
423418
bool has_right =
424419
(join_type_ != JoinType::LEFT_SEMI && join_type_ != JoinType::LEFT_ANTI &&
425-
schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT) > 0);
420+
schema_[1]->num_cols(HashJoinProjection::OUTPUT) > 0);
426421
bool has_left_payload =
427-
has_left && (schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0);
422+
has_left && (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0);
428423
bool has_right_payload =
429-
has_right &&
430-
(schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0);
424+
has_right && (schema_[1]->num_cols(HashJoinProjection::PAYLOAD) > 0);
431425

432426
ThreadLocalState& local_state = local_states_[thread_index];
433427
InitLocalStateIfNeeded(thread_index);
@@ -450,7 +444,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
450444
ARROW_ASSIGN_OR_RAISE(right_key,
451445
hash_table_keys_.Decode(batch_size_next, opt_right_ids));
452446
// Post process build side keys that use dictionary
453-
RETURN_NOT_OK(dict_build_.PostDecode(schema_mgr_->proj_maps[1], &right_key, ctx_));
447+
RETURN_NOT_OK(dict_build_.PostDecode(*schema_[1], &right_key, ctx_));
454448
}
455449
if (has_right_payload) {
456450
ARROW_ASSIGN_OR_RAISE(right_payload,
@@ -550,8 +544,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
550544

551545
RETURN_NOT_OK(EncodeBatch(0, HashJoinProjection::KEY, &local_state.exec_batch_keys,
552546
batch, &batch_key_for_lookups));
553-
bool has_left_payload =
554-
(schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) > 0);
547+
bool has_left_payload = (schema_[0]->num_cols(HashJoinProjection::PAYLOAD) > 0);
555548
if (has_left_payload) {
556549
local_state.exec_batch_payloads.Clear();
557550
RETURN_NOT_OK(EncodeBatch(0, HashJoinProjection::PAYLOAD,
@@ -563,13 +556,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
563556
local_state.match_left.clear();
564557
local_state.match_right.clear();
565558

566-
bool use_key_batch_for_dicts = dict_probe_.BatchRemapNeeded(
567-
thread_index, schema_mgr_->proj_maps[0], schema_mgr_->proj_maps[1], ctx_);
559+
bool use_key_batch_for_dicts =
560+
dict_probe_.BatchRemapNeeded(thread_index, *schema_[0], *schema_[1], ctx_);
568561
RowEncoder* row_encoder_for_lookups = &local_state.exec_batch_keys;
569562
if (use_key_batch_for_dicts) {
570-
RETURN_NOT_OK(dict_probe_.EncodeBatch(
571-
thread_index, schema_mgr_->proj_maps[0], schema_mgr_->proj_maps[1], dict_build_,
572-
batch, &row_encoder_for_lookups, &batch_key_for_lookups, ctx_));
563+
RETURN_NOT_OK(dict_probe_.EncodeBatch(thread_index, *schema_[0], *schema_[1],
564+
dict_build_, batch, &row_encoder_for_lookups,
565+
&batch_key_for_lookups, ctx_));
573566
}
574567

575568
// Collect information about all nulls in key columns.
@@ -609,9 +602,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
609602
if (batches.empty()) {
610603
hash_table_empty_ = true;
611604
} else {
612-
dict_build_.InitEncoder(schema_mgr_->proj_maps[1], &hash_table_keys_, ctx_);
613-
bool has_payload =
614-
(schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) > 0);
605+
dict_build_.InitEncoder(*schema_[1], &hash_table_keys_, ctx_);
606+
bool has_payload = (schema_[1]->num_cols(HashJoinProjection::PAYLOAD) > 0);
615607
if (has_payload) {
616608
InitEncoder(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_);
617609
}
@@ -626,11 +618,11 @@ class HashJoinBasicImpl : public HashJoinImpl {
626618
} else if (hash_table_empty_) {
627619
hash_table_empty_ = false;
628620

629-
RETURN_NOT_OK(dict_build_.Init(schema_mgr_->proj_maps[1], &batch, ctx_));
621+
RETURN_NOT_OK(dict_build_.Init(*schema_[1], &batch, ctx_));
630622
}
631623
int32_t num_rows_before = hash_table_keys_.num_rows();
632-
RETURN_NOT_OK(dict_build_.EncodeBatch(thread_index, schema_mgr_->proj_maps[1],
633-
batch, &hash_table_keys_, ctx_));
624+
RETURN_NOT_OK(dict_build_.EncodeBatch(thread_index, *schema_[1], batch,
625+
&hash_table_keys_, ctx_));
634626
if (has_payload) {
635627
RETURN_NOT_OK(
636628
EncodeBatch(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_, batch));
@@ -643,7 +635,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
643635
}
644636

645637
if (hash_table_empty_) {
646-
RETURN_NOT_OK(dict_build_.Init(schema_mgr_->proj_maps[1], nullptr, ctx_));
638+
RETURN_NOT_OK(dict_build_.Init(*schema_[1], nullptr, ctx_));
647639
}
648640

649641
return Status::OK();
@@ -869,7 +861,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
869861
ExecContext* ctx_;
870862
JoinType join_type_;
871863
size_t num_threads_;
872-
HashJoinSchema* schema_mgr_;
864+
const HashJoinProjectionMaps* schema_[2];
873865
std::vector<JoinKeyCmp> key_cmp_;
874866
Expression filter_;
875867
std::unique_ptr<TaskScheduler> scheduler_;

cpp/src/arrow/compute/exec/hash_join.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ class ARROW_EXPORT HashJoinSchema {
5757
const std::string& left_field_name_prefix,
5858
const std::string& right_field_name_prefix);
5959

60+
bool HasDictionaries() const;
61+
62+
bool HasLargeBinary() const;
63+
6064
Result<Expression> BindFilter(Expression filter, const Schema& left_schema,
6165
const Schema& right_schema);
6266
std::shared_ptr<Schema> MakeOutputSchema(const std::string& left_field_name_suffix,
@@ -98,12 +102,13 @@ class ARROW_EXPORT HashJoinSchema {
98102

99103
class HashJoinImpl {
100104
public:
101-
using OutputBatchCallback = std::function<void(ExecBatch)>;
105+
using OutputBatchCallback = std::function<void(int64_t, ExecBatch)>;
102106
using FinishedCallback = std::function<void(int64_t)>;
103107

104108
virtual ~HashJoinImpl() = default;
105109
virtual Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution,
106-
size_t num_threads, HashJoinSchema* schema_mgr,
110+
size_t num_threads, const HashJoinProjectionMaps* proj_map_left,
111+
const HashJoinProjectionMaps* proj_map_right,
107112
std::vector<JoinKeyCmp> key_cmp, Expression filter,
108113
OutputBatchCallback output_batch_callback,
109114
FinishedCallback finished_callback,
@@ -113,6 +118,7 @@ class HashJoinImpl {
113118
virtual void Abort(TaskScheduler::AbortContinuationImpl pos_abort_callback) = 0;
114119

115120
static Result<std::unique_ptr<HashJoinImpl>> MakeBasic();
121+
static Result<std::unique_ptr<HashJoinImpl>> MakeSwiss();
116122

117123
protected:
118124
util::tracing::Span span_;

cpp/src/arrow/compute/exec/hash_join_benchmark.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,9 @@ class JoinBenchmark {
135135

136136
DCHECK_OK(join_->Init(
137137
ctx_.get(), settings.join_type, !is_parallel, settings.num_threads,
138-
schema_mgr_.get(), {JoinKeyCmp::EQ}, std::move(filter), [](ExecBatch) {},
139-
[](int64_t x) {}, schedule_callback));
138+
&(schema_mgr_->proj_maps[0]), &(schema_mgr_->proj_maps[1]), {JoinKeyCmp::EQ},
139+
std::move(filter), [](int64_t, ExecBatch) {}, [](int64_t x) {},
140+
schedule_callback));
140141
}
141142

142143
void RunJoin() {

cpp/src/arrow/compute/exec/hash_join_node.cc

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,34 @@ Status HashJoinSchema::CollectFilterColumns(std::vector<FieldRef>& left_filter,
453453
return Status::OK();
454454
}
455455

456+
bool HashJoinSchema::HasDictionaries() const {
457+
for (int side = 0; side <= 1; ++side) {
458+
for (int icol = 0; icol < proj_maps[side].num_cols(HashJoinProjection::INPUT);
459+
++icol) {
460+
const std::shared_ptr<DataType>& column_type =
461+
proj_maps[side].data_type(HashJoinProjection::INPUT, icol);
462+
if (column_type->id() == Type::DICTIONARY) {
463+
return true;
464+
}
465+
}
466+
}
467+
return false;
468+
}
469+
470+
bool HashJoinSchema::HasLargeBinary() const {
471+
for (int side = 0; side <= 1; ++side) {
472+
for (int icol = 0; icol < proj_maps[side].num_cols(HashJoinProjection::INPUT);
473+
++icol) {
474+
const std::shared_ptr<DataType>& column_type =
475+
proj_maps[side].data_type(HashJoinProjection::INPUT, icol);
476+
if (is_large_binary_like(column_type->id())) {
477+
return true;
478+
}
479+
}
480+
}
481+
return false;
482+
}
483+
456484
class HashJoinNode : public ExecNode {
457485
public:
458486
HashJoinNode(ExecPlan* plan, NodeVector inputs, const HashJoinNodeOptions& join_options,
@@ -504,8 +532,26 @@ class HashJoinNode : public ExecNode {
504532
// Generate output schema
505533
std::shared_ptr<Schema> output_schema = schema_mgr->MakeOutputSchema(
506534
join_options.output_suffix_for_left, join_options.output_suffix_for_right);
535+
507536
// Create hash join implementation object
508-
ARROW_ASSIGN_OR_RAISE(std::unique_ptr<HashJoinImpl> impl, HashJoinImpl::MakeBasic());
537+
// SwissJoin does not support:
538+
// a) 64-bit string offsets
539+
// b) residual predicates
540+
// c) dictionaries
541+
//
542+
bool use_swiss_join;
543+
#if ARROW_LITTLE_ENDIAN
544+
use_swiss_join = (filter == literal(true)) && !schema_mgr->HasDictionaries() &&
545+
!schema_mgr->HasLargeBinary();
546+
#else
547+
use_swiss_join = false;
548+
#endif
549+
std::unique_ptr<HashJoinImpl> impl;
550+
if (use_swiss_join) {
551+
ARROW_ASSIGN_OR_RAISE(impl, HashJoinImpl::MakeSwiss());
552+
} else {
553+
ARROW_ASSIGN_OR_RAISE(impl, HashJoinImpl::MakeBasic());
554+
}
509555

510556
return plan->EmplaceNode<HashJoinNode>(
511557
plan, inputs, join_options, std::move(output_schema), std::move(schema_mgr),
@@ -584,8 +630,10 @@ class HashJoinNode : public ExecNode {
584630

585631
RETURN_NOT_OK(impl_->Init(
586632
plan_->exec_context(), join_type_, use_sync_execution, num_threads,
587-
schema_mgr_.get(), key_cmp_, filter_,
588-
[this](ExecBatch batch) { this->OutputBatchCallback(batch); },
633+
&(schema_mgr_->proj_maps[0]), &(schema_mgr_->proj_maps[1]), key_cmp_, filter_,
634+
[this](int64_t /*ignored*/, ExecBatch batch) {
635+
this->OutputBatchCallback(batch);
636+
},
589637
[this](int64_t total_num_batches) { this->FinishedCallback(total_num_batches); },
590638
[this](std::function<Status(size_t)> func) -> Status {
591639
return this->ScheduleTaskCallback(std::move(func));

0 commit comments

Comments
 (0)