@@ -83,7 +83,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
8383 }
8484
8585 Status Init (ExecContext* ctx, JoinType join_type, bool use_sync_execution,
86- size_t num_threads, HashJoinSchema* schema_mgr,
86+ size_t num_threads, const HashJoinProjectionMaps* proj_map_left,
87+ const HashJoinProjectionMaps* proj_map_right,
8788 std::vector<JoinKeyCmp> key_cmp, Expression filter,
8889 OutputBatchCallback output_batch_callback,
8990 FinishedCallback finished_callback,
@@ -98,7 +99,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
9899 ctx_ = ctx;
99100 join_type_ = join_type;
100101 num_threads_ = num_threads;
101- schema_mgr_ = schema_mgr;
102+ schema_[0 ] = proj_map_left;
103+ schema_[1 ] = proj_map_right;
102104 key_cmp_ = std::move (key_cmp);
103105 filter_ = std::move (filter);
104106 output_batch_callback_ = std::move (output_batch_callback);
@@ -139,12 +141,11 @@ class HashJoinBasicImpl : public HashJoinImpl {
139141 private:
140142 void InitEncoder (int side, HashJoinProjection projection_handle, RowEncoder* encoder) {
141143 std::vector<ValueDescr> data_types;
142- int num_cols = schema_mgr_-> proj_maps [side]. num_cols (projection_handle);
144+ int num_cols = schema_ [side]-> num_cols (projection_handle);
143145 data_types.resize (num_cols);
144146 for (int icol = 0 ; icol < num_cols; ++icol) {
145- data_types[icol] =
146- ValueDescr (schema_mgr_->proj_maps [side].data_type (projection_handle, icol),
147- ValueDescr::ARRAY);
147+ data_types[icol] = ValueDescr (schema_[side]->data_type (projection_handle, icol),
148+ ValueDescr::ARRAY);
148149 }
149150 encoder->Init (data_types, ctx_);
150151 encoder->Clear ();
@@ -155,8 +156,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
155156 ThreadLocalState& local_state = local_states_[thread_index];
156157 if (!local_state.is_initialized ) {
157158 InitEncoder (0 , HashJoinProjection::KEY, &local_state.exec_batch_keys );
158- bool has_payload =
159- (schema_mgr_->proj_maps [0 ].num_cols (HashJoinProjection::PAYLOAD) > 0 );
159+ bool has_payload = (schema_[0 ]->num_cols (HashJoinProjection::PAYLOAD) > 0 );
160160 if (has_payload) {
161161 InitEncoder (0 , HashJoinProjection::PAYLOAD, &local_state.exec_batch_payloads );
162162 }
@@ -168,11 +168,10 @@ class HashJoinBasicImpl : public HashJoinImpl {
168168 Status EncodeBatch (int side, HashJoinProjection projection_handle, RowEncoder* encoder,
169169 const ExecBatch& batch, ExecBatch* opt_projected_batch = nullptr ) {
170170 ExecBatch projected ({}, batch.length );
171- int num_cols = schema_mgr_-> proj_maps [side]. num_cols (projection_handle);
171+ int num_cols = schema_ [side]-> num_cols (projection_handle);
172172 projected.values .resize (num_cols);
173173
174- auto to_input =
175- schema_mgr_->proj_maps [side].map (projection_handle, HashJoinProjection::INPUT);
174+ auto to_input = schema_[side]->map (projection_handle, HashJoinProjection::INPUT);
176175 for (int icol = 0 ; icol < num_cols; ++icol) {
177176 projected.values [icol] = batch.values [to_input.get (icol)];
178177 }
@@ -235,16 +234,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
235234 ExecBatch* opt_left_payload, ExecBatch* opt_right_key,
236235 ExecBatch* opt_right_payload) {
237236 ExecBatch result ({}, batch_size_next);
238- int num_out_cols_left =
239- schema_mgr_->proj_maps [0 ].num_cols (HashJoinProjection::OUTPUT);
240- int num_out_cols_right =
241- schema_mgr_->proj_maps [1 ].num_cols (HashJoinProjection::OUTPUT);
237+ int num_out_cols_left = schema_[0 ]->num_cols (HashJoinProjection::OUTPUT);
238+ int num_out_cols_right = schema_[1 ]->num_cols (HashJoinProjection::OUTPUT);
242239
243240 result.values .resize (num_out_cols_left + num_out_cols_right);
244- auto from_key = schema_mgr_->proj_maps [0 ].map (HashJoinProjection::OUTPUT,
245- HashJoinProjection::KEY);
246- auto from_payload = schema_mgr_->proj_maps [0 ].map (HashJoinProjection::OUTPUT,
247- HashJoinProjection::PAYLOAD);
241+ auto from_key = schema_[0 ]->map (HashJoinProjection::OUTPUT, HashJoinProjection::KEY);
242+ auto from_payload =
243+ schema_[0 ]->map (HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD);
248244 for (int icol = 0 ; icol < num_out_cols_left; ++icol) {
249245 bool is_from_key = (from_key.get (icol) != HashJoinSchema::kMissingField ());
250246 bool is_from_payload = (from_payload.get (icol) != HashJoinSchema::kMissingField ());
@@ -262,10 +258,9 @@ class HashJoinBasicImpl : public HashJoinImpl {
262258 ? opt_left_key->values [from_key.get (icol)]
263259 : opt_left_payload->values [from_payload.get (icol)];
264260 }
265- from_key = schema_mgr_->proj_maps [1 ].map (HashJoinProjection::OUTPUT,
266- HashJoinProjection::KEY);
267- from_payload = schema_mgr_->proj_maps [1 ].map (HashJoinProjection::OUTPUT,
268- HashJoinProjection::PAYLOAD);
261+ from_key = schema_[1 ]->map (HashJoinProjection::OUTPUT, HashJoinProjection::KEY);
262+ from_payload =
263+ schema_[1 ]->map (HashJoinProjection::OUTPUT, HashJoinProjection::PAYLOAD);
269264 for (int icol = 0 ; icol < num_out_cols_right; ++icol) {
270265 bool is_from_key = (from_key.get (icol) != HashJoinSchema::kMissingField ());
271266 bool is_from_payload = (from_payload.get (icol) != HashJoinSchema::kMissingField ());
@@ -284,7 +279,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
284279 : opt_right_payload->values [from_payload.get (icol)];
285280 }
286281
287- output_batch_callback_ (std::move (result));
282+ output_batch_callback_ (0 , std::move (result));
288283
289284 // Update the counter of produced batches
290285 //
@@ -310,13 +305,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
310305 hash_table_keys_.Decode (match_right.size (), match_right.data ()));
311306
312307 ExecBatch left_payload;
313- if (!schema_mgr_-> LeftPayloadIsEmpty ( )) {
308+ if (!schema_[ 0 ]-> is_empty (HashJoinProjection::PAYLOAD )) {
314309 ARROW_ASSIGN_OR_RAISE (left_payload, local_state.exec_batch_payloads .Decode (
315310 match_left.size (), match_left.data ()));
316311 }
317312
318313 ExecBatch right_payload;
319- if (!schema_mgr_-> RightPayloadIsEmpty ( )) {
314+ if (!schema_[ 1 ]-> is_empty (HashJoinProjection::PAYLOAD )) {
320315 ARROW_ASSIGN_OR_RAISE (right_payload, hash_table_payloads_.Decode (
321316 match_right.size (), match_right.data ()));
322317 }
@@ -336,14 +331,14 @@ class HashJoinBasicImpl : public HashJoinImpl {
336331 }
337332 };
338333
339- SchemaProjectionMap left_to_key = schema_mgr_-> proj_maps [ 0 ]. map (
340- HashJoinProjection::FILTER, HashJoinProjection::KEY);
341- SchemaProjectionMap left_to_pay = schema_mgr_-> proj_maps [ 0 ]. map (
342- HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
343- SchemaProjectionMap right_to_key = schema_mgr_-> proj_maps [ 1 ]. map (
344- HashJoinProjection::FILTER, HashJoinProjection::KEY);
345- SchemaProjectionMap right_to_pay = schema_mgr_-> proj_maps [ 1 ]. map (
346- HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
334+ SchemaProjectionMap left_to_key =
335+ schema_[ 0 ]-> map ( HashJoinProjection::FILTER, HashJoinProjection::KEY);
336+ SchemaProjectionMap left_to_pay =
337+ schema_[ 0 ]-> map ( HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
338+ SchemaProjectionMap right_to_key =
339+ schema_[ 1 ]-> map ( HashJoinProjection::FILTER, HashJoinProjection::KEY);
340+ SchemaProjectionMap right_to_pay =
341+ schema_[ 1 ]-> map ( HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
347342
348343 AppendFields (left_to_key, left_to_pay, left_key, left_payload);
349344 AppendFields (right_to_key, right_to_pay, right_key, right_payload);
@@ -419,15 +414,14 @@ class HashJoinBasicImpl : public HashJoinImpl {
419414
420415 bool has_left =
421416 (join_type_ != JoinType::RIGHT_SEMI && join_type_ != JoinType::RIGHT_ANTI &&
422- schema_mgr_-> proj_maps [0 ]. num_cols (HashJoinProjection::OUTPUT) > 0 );
417+ schema_ [0 ]-> num_cols (HashJoinProjection::OUTPUT) > 0 );
423418 bool has_right =
424419 (join_type_ != JoinType::LEFT_SEMI && join_type_ != JoinType::LEFT_ANTI &&
425- schema_mgr_-> proj_maps [1 ]. num_cols (HashJoinProjection::OUTPUT) > 0 );
420+ schema_ [1 ]-> num_cols (HashJoinProjection::OUTPUT) > 0 );
426421 bool has_left_payload =
427- has_left && (schema_mgr_-> proj_maps [0 ]. num_cols (HashJoinProjection::PAYLOAD) > 0 );
422+ has_left && (schema_ [0 ]-> num_cols (HashJoinProjection::PAYLOAD) > 0 );
428423 bool has_right_payload =
429- has_right &&
430- (schema_mgr_->proj_maps [1 ].num_cols (HashJoinProjection::PAYLOAD) > 0 );
424+ has_right && (schema_[1 ]->num_cols (HashJoinProjection::PAYLOAD) > 0 );
431425
432426 ThreadLocalState& local_state = local_states_[thread_index];
433427 InitLocalStateIfNeeded (thread_index);
@@ -450,7 +444,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
450444 ARROW_ASSIGN_OR_RAISE (right_key,
451445 hash_table_keys_.Decode (batch_size_next, opt_right_ids));
452446 // Post process build side keys that use dictionary
453- RETURN_NOT_OK (dict_build_.PostDecode (schema_mgr_-> proj_maps [1 ], &right_key, ctx_));
447+ RETURN_NOT_OK (dict_build_.PostDecode (*schema_ [1 ], &right_key, ctx_));
454448 }
455449 if (has_right_payload) {
456450 ARROW_ASSIGN_OR_RAISE (right_payload,
@@ -550,8 +544,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
550544
551545 RETURN_NOT_OK (EncodeBatch (0 , HashJoinProjection::KEY, &local_state.exec_batch_keys ,
552546 batch, &batch_key_for_lookups));
553- bool has_left_payload =
554- (schema_mgr_->proj_maps [0 ].num_cols (HashJoinProjection::PAYLOAD) > 0 );
547+ bool has_left_payload = (schema_[0 ]->num_cols (HashJoinProjection::PAYLOAD) > 0 );
555548 if (has_left_payload) {
556549 local_state.exec_batch_payloads .Clear ();
557550 RETURN_NOT_OK (EncodeBatch (0 , HashJoinProjection::PAYLOAD,
@@ -563,13 +556,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
563556 local_state.match_left .clear ();
564557 local_state.match_right .clear ();
565558
566- bool use_key_batch_for_dicts = dict_probe_. BatchRemapNeeded (
567- thread_index, schema_mgr_-> proj_maps [0 ], schema_mgr_-> proj_maps [1 ], ctx_);
559+ bool use_key_batch_for_dicts =
560+ dict_probe_. BatchRemapNeeded ( thread_index, *schema_ [0 ], *schema_ [1 ], ctx_);
568561 RowEncoder* row_encoder_for_lookups = &local_state.exec_batch_keys ;
569562 if (use_key_batch_for_dicts) {
570- RETURN_NOT_OK (dict_probe_.EncodeBatch (
571- thread_index, schema_mgr_-> proj_maps [ 0 ], schema_mgr_-> proj_maps [ 1 ], dict_build_ ,
572- batch, &row_encoder_for_lookups, &batch_key_for_lookups, ctx_));
563+ RETURN_NOT_OK (dict_probe_.EncodeBatch (thread_index, *schema_[ 0 ], *schema_[ 1 ],
564+ dict_build_, batch, &row_encoder_for_lookups ,
565+ &batch_key_for_lookups, ctx_));
573566 }
574567
575568 // Collect information about all nulls in key columns.
@@ -609,9 +602,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
609602 if (batches.empty ()) {
610603 hash_table_empty_ = true ;
611604 } else {
612- dict_build_.InitEncoder (schema_mgr_->proj_maps [1 ], &hash_table_keys_, ctx_);
613- bool has_payload =
614- (schema_mgr_->proj_maps [1 ].num_cols (HashJoinProjection::PAYLOAD) > 0 );
605+ dict_build_.InitEncoder (*schema_[1 ], &hash_table_keys_, ctx_);
606+ bool has_payload = (schema_[1 ]->num_cols (HashJoinProjection::PAYLOAD) > 0 );
615607 if (has_payload) {
616608 InitEncoder (1 , HashJoinProjection::PAYLOAD, &hash_table_payloads_);
617609 }
@@ -626,11 +618,11 @@ class HashJoinBasicImpl : public HashJoinImpl {
626618 } else if (hash_table_empty_) {
627619 hash_table_empty_ = false ;
628620
629- RETURN_NOT_OK (dict_build_.Init (schema_mgr_-> proj_maps [1 ], &batch, ctx_));
621+ RETURN_NOT_OK (dict_build_.Init (*schema_ [1 ], &batch, ctx_));
630622 }
631623 int32_t num_rows_before = hash_table_keys_.num_rows ();
632- RETURN_NOT_OK (dict_build_.EncodeBatch (thread_index, schema_mgr_-> proj_maps [1 ],
633- batch, &hash_table_keys_, ctx_));
624+ RETURN_NOT_OK (dict_build_.EncodeBatch (thread_index, *schema_ [1 ], batch ,
625+ &hash_table_keys_, ctx_));
634626 if (has_payload) {
635627 RETURN_NOT_OK (
636628 EncodeBatch (1 , HashJoinProjection::PAYLOAD, &hash_table_payloads_, batch));
@@ -643,7 +635,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
643635 }
644636
645637 if (hash_table_empty_) {
646- RETURN_NOT_OK (dict_build_.Init (schema_mgr_-> proj_maps [1 ], nullptr , ctx_));
638+ RETURN_NOT_OK (dict_build_.Init (*schema_ [1 ], nullptr , ctx_));
647639 }
648640
649641 return Status::OK ();
@@ -869,7 +861,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
869861 ExecContext* ctx_;
870862 JoinType join_type_;
871863 size_t num_threads_;
872- HashJoinSchema* schema_mgr_ ;
864+ const HashJoinProjectionMaps* schema_[ 2 ] ;
873865 std::vector<JoinKeyCmp> key_cmp_;
874866 Expression filter_;
875867 std::unique_ptr<TaskScheduler> scheduler_;
0 commit comments