From 6dbfed64f298e1725cda65e5bc32a478cefb3fc1 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 27 Jan 2024 23:21:59 -0800 Subject: [PATCH 1/8] Support join filter for SortMergeJoin --- .../enforce_distribution.rs | 5 + .../physical_optimizer/projection_pushdown.rs | 1 + .../core/src/physical_optimizer/test_utils.rs | 1 + datafusion/core/src/physical_planner.rs | 26 ++-- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 1 + .../src/joins/sort_merge_join.rs | 122 ++++++++++++++++-- datafusion/sqllogictest/test_files/join.slt | 43 +++++- .../join_disable_repartition_joins.slt | 50 ++++--- 8 files changed, 205 insertions(+), 44 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index fab26c49c2daa..4f8806a685923 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -342,6 +342,7 @@ fn adjust_input_keys_ordering( left, right, on, + filter, join_type, sort_options, null_equals_null, @@ -356,6 +357,7 @@ fn adjust_input_keys_ordering( left.clone(), right.clone(), new_conditions.0, + filter.clone(), *join_type, new_conditions.1, *null_equals_null, @@ -635,6 +637,7 @@ pub(crate) fn reorder_join_keys_to_inputs( left, right, on, + filter, join_type, sort_options, null_equals_null, @@ -664,6 +667,7 @@ pub(crate) fn reorder_join_keys_to_inputs( left.clone(), right.clone(), new_join_on, + filter.clone(), *join_type, new_sort_options, *null_equals_null, @@ -1642,6 +1646,7 @@ pub(crate) mod tests { left, right, join_on.clone(), + None, *join_type, vec![SortOptions::default(); join_on.len()], false, diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 301a97bba4c5b..b2014635b6342 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ 
b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -736,6 +736,7 @@ fn try_swapping_with_sort_merge_join( Arc::new(new_left), Arc::new(new_right), new_on, + sm_join.filter.clone(), sm_join.join_type, sm_join.sort_options.clone(), sm_join.null_equals_null, diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 5de6cff0b4fad..ca7fb78d21b15 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -175,6 +175,7 @@ pub fn sort_merge_join_exec( left, right, join_on.clone(), + None, *join_type, vec![SortOptions::default(); join_on.len()], false, diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index d4ef40493df38..1a22ade0ba9a1 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1111,6 +1111,7 @@ impl DefaultPhysicalPlanner { }; let prefer_hash_join = session_state.config_options().optimizer.prefer_hash_join; + if join_on.is_empty() { // there is no equal join condition, use the nested loop join // TODO optimize the plan, and use the config of `target_partitions` and `repartition_joins` @@ -1126,20 +1127,17 @@ impl DefaultPhysicalPlanner { { // Use SortMergeJoin if hash join is not preferred // Sort-Merge join support currently is experimental - if join_filter.is_some() { - // TODO SortMergeJoinExec need to support join filter - not_impl_err!("SortMergeJoinExec does not support join_filter now.") - } else { - let join_on_len = join_on.len(); - Ok(Arc::new(SortMergeJoinExec::try_new( - physical_left, - physical_right, - join_on, - *join_type, - vec![SortOptions::default(); join_on_len], - null_equals_null, - )?)) - } + + let join_on_len = join_on.len(); + Ok(Arc::new(SortMergeJoinExec::try_new( + physical_left, + physical_right, + join_on, + join_filter, + *join_type, + vec![SortOptions::default(); join_on_len], + 
null_equals_null, + )?)) } else if session_state.config().target_partitions() > 1 && session_state.config().repartition_joins() && prefer_hash_join { diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 1c819ac466dfb..78f8ee7723fc4 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -130,6 +130,7 @@ async fn run_join_test( left, right, on_columns.clone(), + None, join_type, vec![SortOptions::default(), SortOptions::default()], false, diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 675e90fb63d7f..af7538b3c340c 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -33,7 +33,7 @@ use std::task::{Context, Poll}; use crate::expressions::PhysicalSortExpr; use crate::joins::utils::{ build_join_schema, calculate_join_output_ordering, check_join_is_valid, - estimate_join_statistics, partitioned_join_output_partitioning, JoinOn, + estimate_join_statistics, partitioned_join_output_partitioning, JoinFilter, JoinOn, }; use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use crate::{ @@ -42,6 +42,7 @@ use crate::{ }; use arrow::array::*; +use arrow::compute; use arrow::compute::{concat_batches, take, SortOptions}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; @@ -68,6 +69,8 @@ pub struct SortMergeJoinExec { pub right: Arc, /// Set of common columns used to join on pub on: JoinOn, + /// Filters which are applied while finding matching rows + pub filter: Option, /// How the join is performed pub join_type: JoinType, /// The schema once the join is applied @@ -95,6 +98,7 @@ impl SortMergeJoinExec { left: Arc, right: Arc, on: JoinOn, + filter: Option, join_type: JoinType, sort_options: Vec, null_equals_null: bool, @@ -150,6 +154,7 @@ impl 
SortMergeJoinExec { left, right, on, + filter, join_type, schema, metrics: ExecutionPlanMetricsSet::new(), @@ -210,6 +215,11 @@ impl SortMergeJoinExec { impl DisplayAs for SortMergeJoinExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + let display_filter = self.filter.as_ref().map_or_else( + || "".to_string(), + |f| format!(", filter={}", f.expression()), + ); + match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { let on = self @@ -220,8 +230,8 @@ impl DisplayAs for SortMergeJoinExec { .join(", "); write!( f, - "SortMergeJoin: join_type={:?}, on=[{}]", - self.join_type, on + "SortMergeJoin: join_type={:?}, on=[{}]{}", + self.join_type, on, display_filter ) } } @@ -300,6 +310,7 @@ impl ExecutionPlan for SortMergeJoinExec { left.clone(), right.clone(), self.on.clone(), + self.filter.clone(), self.join_type, self.sort_options.clone(), self.null_equals_null, @@ -349,6 +360,7 @@ impl ExecutionPlan for SortMergeJoinExec { buffered, on_streamed, on_buffered, + self.filter.clone(), self.join_type, batch_size, SortMergeJoinMetrics::new(partition, &self.metrics), @@ -456,8 +468,9 @@ enum BufferedState { Exhausted, } +/// Represents a chunk of joined data from streamed and buffered side struct StreamedJoinedChunk { - /// Index of batch buffered_data + /// Index of batch in buffered_data buffered_batch_idx: Option, /// Array builder for streamed indices streamed_indices: UInt64Builder, @@ -466,13 +479,17 @@ struct StreamedJoinedChunk { } struct StreamedBatch { + /// The streamed record batch pub batch: RecordBatch, + /// The index of row in the streamed batch to compare with buffered batches pub idx: usize, + /// The join key arrays of streamed batch which are used to compare with buffered batches + /// and to produce output. They are produced by evaluating `on` expressions. 
pub join_arrays: Vec, - // Chunks of indices from buffered side (may be nulls) joined to streamed + /// Chunks of indices from buffered side (may be nulls) joined to streamed pub output_indices: Vec, - // Index of currently scanned batch from buffered data + /// Index of currently scanned batch from buffered data pub buffered_batch_idx: Option, } @@ -505,6 +522,8 @@ impl StreamedBatch { buffered_batch_idx: Option, buffered_idx: Option, ) { + // If no current chunk exists or current chunk is not for current buffered batch, + // create a new chunk if self.output_indices.is_empty() || self.buffered_batch_idx != buffered_batch_idx { self.output_indices.push(StreamedJoinedChunk { @@ -516,6 +535,7 @@ impl StreamedBatch { }; let current_chunk = self.output_indices.last_mut().unwrap(); + // Append index of streamed batch and index of buffered batch into current chunk current_chunk.streamed_indices.append_value(self.idx as u64); if let Some(idx) = buffered_idx { current_chunk.buffered_indices.append_value(idx as u64); @@ -610,6 +630,8 @@ struct SMJStream { pub on_streamed: Vec, /// Join key columns of buffered pub on_buffered: Vec, + /// optional join filter + pub filter: Option, /// Staging output array builders pub output_record_batches: Vec, /// Staging output size, including output batches and staging joined results @@ -736,6 +758,7 @@ impl SMJStream { buffered: SendableRecordBatchStream, on_streamed: Vec>, on_buffered: Vec>, + filter: Option, join_type: JoinType, batch_size: usize, join_metrics: SortMergeJoinMetrics, @@ -761,6 +784,7 @@ impl SMJStream { current_ordering: Ordering::Equal, on_streamed, on_buffered, + filter, output_record_batches: vec![], output_size: 0, batch_size, @@ -943,7 +967,9 @@ impl SMJStream { /// Produce join and fill output buffer until reaching target batch size /// or the join is finished fn join_partial(&mut self) -> Result<()> { + // Whether to join streamed rows let mut join_streamed = false; + // Whether to join buffered rows let mut 
join_buffered = false; // determine whether we need to join streamed/buffered rows @@ -991,11 +1017,13 @@ impl SMJStream { { let scanning_idx = self.buffered_data.scanning_idx(); if join_streamed { + // Join streamed row and buffered row self.streamed_batch.append_output_pair( Some(self.buffered_data.scanning_batch_idx), Some(scanning_idx), ); } else { + // Join nulls and buffered row self.buffered_data .scanning_batch_mut() .null_joined @@ -1059,6 +1087,7 @@ impl SMJStream { } buffered_batch.null_joined.clear(); + // Take buffered (right) columns let buffered_columns = buffered_batch .batch .columns() @@ -1067,6 +1096,7 @@ impl SMJStream { .collect::, ArrowError>>() .map_err(Into::::into)?; + // Create null streamed (left) columns let mut streamed_columns = self .streamed_schema .fields() @@ -1074,11 +1104,32 @@ impl SMJStream { .map(|f| new_null_array(f.data_type(), buffered_indices.len())) .collect::>(); + let filter_columns = + get_filter_column(&self.filter, &streamed_columns, &buffered_columns); + streamed_columns.extend(buffered_columns); let columns = streamed_columns; - self.output_record_batches - .push(RecordBatch::try_new(self.schema.clone(), columns)?); + let output_batch = RecordBatch::try_new(self.schema.clone(), columns)?; + + // Apply join filter if any + let output_batch = if let Some(f) = &self.filter { + // Construct batch with only filter columns + let filter_batch = + RecordBatch::try_new(Arc::new(f.schema().clone()), filter_columns)?; + + let filter_result = f + .expression() + .evaluate(&filter_batch)? + .into_array(filter_batch.num_rows())?; + let mask = datafusion_common::cast::as_boolean_array(&filter_result)?; + + compute::filter_record_batch(&output_batch, mask)? 
+ } else { + output_batch + }; + + self.output_record_batches.push(output_batch); } Ok(()) } @@ -1121,6 +1172,9 @@ impl SMJStream { .collect::>() }; + let filter_columns = + get_filter_column(&self.filter, &streamed_columns, &buffered_columns); + let columns = if matches!(self.join_type, JoinType::Right) { buffered_columns.extend(streamed_columns); buffered_columns @@ -1129,8 +1183,26 @@ impl SMJStream { streamed_columns }; - self.output_record_batches - .push(RecordBatch::try_new(self.schema.clone(), columns)?); + let output_batch = RecordBatch::try_new(self.schema.clone(), columns)?; + + // Apply join filter if any + let output_batch = if let Some(f) = &self.filter { + // Construct batch with only filter columns + let filter_batch = + RecordBatch::try_new(Arc::new(f.schema().clone()), filter_columns)?; + + let filter_result = f + .expression() + .evaluate(&filter_batch)? + .into_array(filter_batch.num_rows())?; + let mask = datafusion_common::cast::as_boolean_array(&filter_result)?; + + compute::filter_record_batch(&output_batch, mask)? + } else { + output_batch + }; + + self.output_record_batches.push(output_batch); } self.streamed_batch.output_indices.clear(); @@ -1148,6 +1220,36 @@ impl SMJStream { } } +/// Gets the arrays which join filters are applied on. 
+fn get_filter_column( + join_filter: &Option, + streamed_columns: &Vec, + buffered_columns: &Vec, +) -> Vec { + let mut filter_columns = vec![]; + + if let Some(f) = join_filter { + let left_columns = f + .column_indices() + .iter() + .filter(|col_index| (*col_index).side == JoinSide::Left) + .map(|i| streamed_columns[i.index].clone()) + .collect::>(); + + let right_columns = f + .column_indices() + .iter() + .filter(|col_index| (*col_index).side == JoinSide::Right) + .map(|i| buffered_columns[i.index].clone()) + .collect::>(); + + filter_columns.extend(left_columns); + filter_columns.extend(right_columns); + } + + filter_columns +} + /// Buffered data contains all buffered batches with one unique join key #[derive(Debug, Default)] struct BufferedData { diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index ca9b918ff3ee0..2e716be3d2152 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -655,8 +655,49 @@ CoalesceBatchesExec: target_batch_size=8192 statement ok set datafusion.execution.target_partitions = 4; +# equijoin and join filter (sort merge join) statement ok -set datafusion.optimizer.repartition_joins = false; +set datafusion.optimizer.prefer_hash_join = false; + +query TT +EXPLAIN SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b +---- +logical_plan +Inner Join: t1.a = t2.a Filter: CAST(t2.b AS Int64) * Int64(50) <= CAST(t1.b AS Int64) +--TableScan: t1 projection=[a, b] +--TableScan: t2 projection=[a, b] +physical_plan +SortMergeJoin: join_type=Inner, on=[(a@0, a@0)], filter=CAST(b@1 AS Int64) * 50 <= CAST(b@0 AS Int64) +--SortExec: expr=[a@0 ASC] +----CoalesceBatchesExec: target_batch_size=8192 +------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 +--------MemoryExec: partitions=1, partition_sizes=[1] +--SortExec: expr=[a@0 ASC] +----CoalesceBatchesExec: target_batch_size=8192 
+------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 +--------MemoryExec: partitions=1, partition_sizes=[1] + +query TITI rowsort +SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b +---- +Alice 100 Alice 1 +Alice 100 Alice 2 +Alice 50 Alice 1 + +query TITI rowsort +SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a AND t2.b < t1.b +---- +Alice 100 Alice 1 +Alice 100 Alice 2 +Alice 50 Alice 1 +Alice 50 Alice 2 + +query TITI rowsort +SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a AND t2.b > t1.b +---- + +statement ok +set datafusion.optimizer.prefer_hash_join = true; statement ok DROP TABLE t1; diff --git a/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt b/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt index 1312f2916ed61..805c189ed6edd 100644 --- a/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt +++ b/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt @@ -57,12 +57,18 @@ Limit: skip=0, fetch=5 physical_plan GlobalLimitExec: skip=0, fetch=5 --SortPreservingMergeExec: [a@0 ASC NULLS LAST], fetch=5 -----ProjectionExec: expr=[a@1 as a] -------CoalesceBatchesExec: target_batch_size=8192 ---------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(c@0, c@1)] -----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], has_header=true -----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_ordering=[a@0 ASC NULLS LAST], has_header=true +----SortExec: TopK(fetch=5), expr=[a@0 ASC NULLS LAST] +------ProjectionExec: expr=[a@1 as a] +--------CoalesceBatchesExec: target_batch_size=8192 +----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@0, c@1)] +------------CoalesceBatchesExec: 
target_batch_size=8192 +--------------RepartitionExec: partitioning=Hash([c@0], 4), input_partitions=4 +----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], has_header=true +------------CoalesceBatchesExec: target_batch_size=8192 +--------------RepartitionExec: partitioning=Hash([c@1], 4), input_partitions=4 +----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_ordering=[a@0 ASC NULLS LAST], has_header=true # preserve_inner_join query IIII nosort @@ -72,11 +78,11 @@ SELECT t1.a, t1.b, t1.c, t2.a as a2 ON t1.d = t2.d ORDER BY a2, t2.b LIMIT 5 ---- -0 0 0 0 -0 0 2 0 -0 0 3 0 -0 0 6 0 -0 0 20 0 +0 0 7 0 +0 0 11 0 +0 0 12 0 +0 0 14 0 +0 0 1 0 query TT EXPLAIN SELECT t2.a as a2, t2.b @@ -100,14 +106,20 @@ Limit: skip=0, fetch=10 physical_plan GlobalLimitExec: skip=0, fetch=10 --SortPreservingMergeExec: [a2@0 ASC NULLS LAST,b@1 ASC NULLS LAST], fetch=10 -----ProjectionExec: expr=[a@0 as a2, b@1 as b] -------CoalesceBatchesExec: target_batch_size=8192 ---------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(d@1, d@3), (c@0, c@2)] -----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], has_header=true -----------CoalesceBatchesExec: target_batch_size=8192 -------------FilterExec: d@3 = 3 ---------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +----SortExec: TopK(fetch=10), expr=[a2@0 ASC NULLS LAST,b@1 ASC NULLS LAST] 
+------ProjectionExec: expr=[a@0 as a2, b@1 as b] +--------CoalesceBatchesExec: target_batch_size=8192 +----------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(d@1, d@3), (c@0, c@2)] +------------CoalesceBatchesExec: target_batch_size=8192 +--------------RepartitionExec: partitioning=Hash([d@1, c@0], 4), input_partitions=4 +----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], has_header=true +------------CoalesceBatchesExec: target_batch_size=8192 +--------------RepartitionExec: partitioning=Hash([d@3, c@2], 4), input_partitions=4 +----------------CoalesceBatchesExec: target_batch_size=8192 +------------------FilterExec: d@3 = 3 +--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true # preserve_right_semi_join query II nosort From 159c69337406d4ea958aa0034bbf130e7622d790 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 30 Jan 2024 22:35:55 -0800 Subject: [PATCH 2/8] Move test --- datafusion/sqllogictest/test_files/join.slt | 43 +---------- .../join_disable_repartition_joins.slt | 50 +++++------- .../test_files/sort_merge_join.slt | 77 +++++++++++++++++++ 3 files changed, 97 insertions(+), 73 deletions(-) create mode 100644 datafusion/sqllogictest/test_files/sort_merge_join.slt diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index 2e716be3d2152..ca9b918ff3ee0 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -655,49 +655,8 @@ CoalesceBatchesExec: target_batch_size=8192 statement ok set 
datafusion.execution.target_partitions = 4; -# equijoin and join filter (sort merge join) statement ok -set datafusion.optimizer.prefer_hash_join = false; - -query TT -EXPLAIN SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b ----- -logical_plan -Inner Join: t1.a = t2.a Filter: CAST(t2.b AS Int64) * Int64(50) <= CAST(t1.b AS Int64) ---TableScan: t1 projection=[a, b] ---TableScan: t2 projection=[a, b] -physical_plan -SortMergeJoin: join_type=Inner, on=[(a@0, a@0)], filter=CAST(b@1 AS Int64) * 50 <= CAST(b@0 AS Int64) ---SortExec: expr=[a@0 ASC] -----CoalesceBatchesExec: target_batch_size=8192 -------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 ---------MemoryExec: partitions=1, partition_sizes=[1] ---SortExec: expr=[a@0 ASC] -----CoalesceBatchesExec: target_batch_size=8192 -------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 ---------MemoryExec: partitions=1, partition_sizes=[1] - -query TITI rowsort -SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b ----- -Alice 100 Alice 1 -Alice 100 Alice 2 -Alice 50 Alice 1 - -query TITI rowsort -SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a AND t2.b < t1.b ----- -Alice 100 Alice 1 -Alice 100 Alice 2 -Alice 50 Alice 1 -Alice 50 Alice 2 - -query TITI rowsort -SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a AND t2.b > t1.b ----- - -statement ok -set datafusion.optimizer.prefer_hash_join = true; +set datafusion.optimizer.repartition_joins = false; statement ok DROP TABLE t1; diff --git a/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt b/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt index 805c189ed6edd..1312f2916ed61 100644 --- a/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt +++ b/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt @@ -57,18 +57,12 @@ Limit: skip=0, fetch=5 physical_plan GlobalLimitExec: skip=0, 
fetch=5 --SortPreservingMergeExec: [a@0 ASC NULLS LAST], fetch=5 -----SortExec: TopK(fetch=5), expr=[a@0 ASC NULLS LAST] -------ProjectionExec: expr=[a@1 as a] ---------CoalesceBatchesExec: target_batch_size=8192 -----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@0, c@1)] -------------CoalesceBatchesExec: target_batch_size=8192 ---------------RepartitionExec: partitioning=Hash([c@0], 4), input_partitions=4 -----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], has_header=true -------------CoalesceBatchesExec: target_batch_size=8192 ---------------RepartitionExec: partitioning=Hash([c@1], 4), input_partitions=4 -----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_ordering=[a@0 ASC NULLS LAST], has_header=true +----ProjectionExec: expr=[a@1 as a] +------CoalesceBatchesExec: target_batch_size=8192 +--------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(c@0, c@1)] +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], has_header=true +----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_ordering=[a@0 ASC NULLS LAST], has_header=true # preserve_inner_join query IIII nosort @@ -78,11 +72,11 @@ SELECT t1.a, t1.b, t1.c, t2.a as a2 ON t1.d = t2.d ORDER BY a2, t2.b LIMIT 5 ---- -0 0 7 0 -0 0 11 0 -0 0 12 0 -0 0 14 0 -0 0 1 0 +0 0 0 0 +0 0 2 0 +0 0 3 0 +0 0 6 0 +0 0 20 0 query TT EXPLAIN SELECT t2.a as a2, t2.b @@ -106,20 +100,14 @@ Limit: skip=0, fetch=10 physical_plan GlobalLimitExec: skip=0, fetch=10 
--SortPreservingMergeExec: [a2@0 ASC NULLS LAST,b@1 ASC NULLS LAST], fetch=10 -----SortExec: TopK(fetch=10), expr=[a2@0 ASC NULLS LAST,b@1 ASC NULLS LAST] -------ProjectionExec: expr=[a@0 as a2, b@1 as b] ---------CoalesceBatchesExec: target_batch_size=8192 -----------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(d@1, d@3), (c@0, c@2)] -------------CoalesceBatchesExec: target_batch_size=8192 ---------------RepartitionExec: partitioning=Hash([d@1, c@0], 4), input_partitions=4 -----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], has_header=true -------------CoalesceBatchesExec: target_batch_size=8192 ---------------RepartitionExec: partitioning=Hash([d@3, c@2], 4), input_partitions=4 -----------------CoalesceBatchesExec: target_batch_size=8192 -------------------FilterExec: d@3 = 3 ---------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +----ProjectionExec: expr=[a@0 as a2, b@1 as b] +------CoalesceBatchesExec: target_batch_size=8192 +--------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(d@1, d@3), (c@0, c@2)] +----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], has_header=true +----------CoalesceBatchesExec: target_batch_size=8192 +------------FilterExec: d@3 = 3 +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS 
LAST], has_header=true # preserve_right_semi_join query II nosort diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt new file mode 100644 index 0000000000000..3dffc0a24dc23 --- /dev/null +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +########## +## Sort Merge Join Tests +########## + +statement ok +set datafusion.optimizer.prefer_hash_join = false; + +statement ok +CREATE TABLE t1(a text, b int) AS VALUES ('Alice', 50), ('Alice', 100); + +statement ok +CREATE TABLE t2(a text, b int) AS VALUES ('Alice', 2), ('Alice', 1); + +# equijoin and join filter (sort merge join) + +query TT +EXPLAIN SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b +---- +logical_plan +Inner Join: t1.a = t2.a Filter: CAST(t2.b AS Int64) * Int64(50) <= CAST(t1.b AS Int64) +--TableScan: t1 projection=[a, b] +--TableScan: t2 projection=[a, b] +physical_plan +SortMergeJoin: join_type=Inner, on=[(a@0, a@0)], filter=CAST(b@1 AS Int64) * 50 <= CAST(b@0 AS Int64) +--SortExec: expr=[a@0 ASC] +----CoalesceBatchesExec: target_batch_size=8192 +------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 +--------MemoryExec: partitions=1, partition_sizes=[1] +--SortExec: expr=[a@0 ASC] +----CoalesceBatchesExec: target_batch_size=8192 +------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 +--------MemoryExec: partitions=1, partition_sizes=[1] + +query TITI rowsort +SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b +---- +Alice 100 Alice 1 +Alice 100 Alice 2 +Alice 50 Alice 1 + +query TITI rowsort +SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a AND t2.b < t1.b +---- +Alice 100 Alice 1 +Alice 100 Alice 2 +Alice 50 Alice 1 +Alice 50 Alice 2 + +query TITI rowsort +SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a AND t2.b > t1.b +---- + +statement ok +set datafusion.optimizer.prefer_hash_join = true; + +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; From d9e40afcd3a55837ffe55bc37b05c2f7758dcb8a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 31 Jan 2024 12:30:56 -0800 Subject: [PATCH 3/8] Fix test --- datafusion/physical-plan/src/joins/sort_merge_join.rs | 3 ++- 1 file changed, 2 insertions(+), 1 
deletion(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index af7538b3c340c..0d245a3ae18a3 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1600,7 +1600,7 @@ mod tests { join_type: JoinType, ) -> Result { let sort_options = vec![SortOptions::default(); on.len()]; - SortMergeJoinExec::try_new(left, right, on, join_type, sort_options, false) + SortMergeJoinExec::try_new(left, right, on, None, join_type, sort_options, false) } fn join_with_options( @@ -1615,6 +1615,7 @@ mod tests { left, right, on, + None, join_type, sort_options, null_equals_null, From 00d416013b06c7e35558861a5ae93c1bf73a92fc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 31 Jan 2024 12:51:27 -0800 Subject: [PATCH 4/8] Fix clippy --- datafusion/physical-plan/src/joins/sort_merge_join.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 0d245a3ae18a3..504e3f011a194 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1223,8 +1223,8 @@ impl SMJStream { /// Gets the arrays which join filters are applied on. 
fn get_filter_column( join_filter: &Option<JoinFilter>, - streamed_columns: &Vec<ArrayRef>, - buffered_columns: &Vec<ArrayRef>, + streamed_columns: &[ArrayRef], + buffered_columns: &[ArrayRef], ) -> Vec<ArrayRef> { let mut filter_columns = vec![]; @@ -1232,14 +1232,14 @@ fn get_filter_column( let left_columns = f .column_indices() .iter() - .filter(|col_index| (*col_index).side == JoinSide::Left) + .filter(|col_index| col_index.side == JoinSide::Left) .map(|i| streamed_columns[i.index].clone()) .collect::<Vec<_>>(); let right_columns = f .column_indices() .iter() - .filter(|col_index| (*col_index).side == JoinSide::Right) + .filter(|col_index| col_index.side == JoinSide::Right) .map(|i| buffered_columns[i.index].clone()) .collect::<Vec<_>>(); From 99940c2b71054563e3eff816825e371b6de1788d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 31 Jan 2024 23:43:56 -0800 Subject: [PATCH 5/8] Add outer join tests --- .../src/joins/sort_merge_join.rs | 7 +- .../test_files/sort_merge_join.slt | 79 ++++++++++++++++++- 2 files changed, 81 insertions(+), 5 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 504e3f011a194..2d196acdb31cf 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1172,8 +1172,11 @@ impl SMJStream { .collect::<Vec<_>>() }; - let filter_columns = - get_filter_column(&self.filter, &streamed_columns, &buffered_columns); + let filter_columns = if matches!(self.join_type, JoinType::Right) { + get_filter_column(&self.filter, &buffered_columns, &streamed_columns) + } else { + get_filter_column(&self.filter, &streamed_columns, &buffered_columns) + }; let columns = if matches!(self.join_type, JoinType::Right) { buffered_columns.extend(streamed_columns); diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index 3dffc0a24dc23..d9266dea5ab1a 100644 ---
a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -23,13 +23,12 @@ statement ok set datafusion.optimizer.prefer_hash_join = false; statement ok -CREATE TABLE t1(a text, b int) AS VALUES ('Alice', 50), ('Alice', 100); +CREATE TABLE t1(a text, b int) AS VALUES ('Alice', 50), ('Alice', 100), ('Bob', 1); statement ok CREATE TABLE t2(a text, b int) AS VALUES ('Alice', 2), ('Alice', 1); -# equijoin and join filter (sort merge join) - +# inner join query plan with join filter query TT EXPLAIN SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b ---- @@ -48,6 +47,7 @@ SortMergeJoin: join_type=Inner, on=[(a@0, a@0)], filter=CAST(b@1 AS Int64) * 50 ------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 --------MemoryExec: partitions=1, partition_sizes=[1] +# inner join with join filter query TITI rowsort SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b ---- @@ -67,6 +67,79 @@ query TITI rowsort SELECT t1.a, t1.b, t2.a, t2.b FROM t1 JOIN t2 ON t1.a = t2.a AND t2.b > t1.b ---- +# left join without join filter +query TITI rowsort +SELECT * FROM t1 LEFT JOIN t2 ON t1.a = t2.a +---- +Alice 100 Alice 1 +Alice 100 Alice 2 +Alice 50 Alice 1 +Alice 50 Alice 2 +Bob 1 NULL NULL + +# left join with join filter +query TITI rowsort +SELECT * FROM t1 LEFT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b +---- +Alice 100 Alice 1 +Alice 100 Alice 2 +Alice 50 Alice 1 + +query TITI rowsort +SELECT * FROM t1 LEFT JOIN t2 ON t1.a = t2.a AND t2.b < t1.b +---- +Alice 100 Alice 1 +Alice 100 Alice 2 +Alice 50 Alice 1 +Alice 50 Alice 2 + +# right join without join filter +query TITI rowsort +SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a +---- +Alice 100 Alice 1 +Alice 100 Alice 2 +Alice 50 Alice 1 +Alice 50 Alice 2 + +# right join with join filter +query TITI rowsort +SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b +---- +Alice 100 
Alice 1 +Alice 100 Alice 2 +Alice 50 Alice 1 + +query TITI rowsort +SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t1.b > t2.b +---- +Alice 100 Alice 1 +Alice 100 Alice 2 +Alice 50 Alice 1 +Alice 50 Alice 2 + +# full join without join filter +query TITI rowsort +SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a +---- +Alice 100 Alice 1 +Alice 100 Alice 2 +Alice 50 Alice 1 +Alice 50 Alice 2 +Bob 1 NULL NULL + +# full join with join filter +query TITI rowsort +SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t2.b * 50 > t1.b +---- +Alice 50 Alice 2 + +query TITI rowsort +SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50 +---- +Alice 100 Alice 1 +Alice 100 Alice 2 + statement ok set datafusion.optimizer.prefer_hash_join = true; From be35c906b01a452fb0e519050c88b7062a4fbb8f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 2 Feb 2024 17:56:37 -0800 Subject: [PATCH 6/8] Fix outer join --- .../src/joins/sort_merge_join.rs | 174 +++++++++++++----- datafusion/sqllogictest/test_files/join.slt | 21 +++ .../test_files/sort_merge_join.slt | 119 +++++++++++- 3 files changed, 268 insertions(+), 46 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 2d196acdb31cf..3f9c9f7393f25 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1104,32 +1104,11 @@ impl SMJStream { .map(|f| new_null_array(f.data_type(), buffered_indices.len())) .collect::<Vec<_>>(); - let filter_columns = - get_filter_column(&self.filter, &streamed_columns, &buffered_columns); - streamed_columns.extend(buffered_columns); let columns = streamed_columns; - let output_batch = RecordBatch::try_new(self.schema.clone(), columns)?; - - // Apply join filter if any - let output_batch = if let Some(f) = &self.filter { - // Construct batch with only filter columns - let filter_batch = - RecordBatch::try_new(Arc::new(f.schema().clone()),
filter_columns)?; - - let filter_result = f - .expression() - .evaluate(&filter_batch)? - .into_array(filter_batch.num_rows())?; - let mask = datafusion_common::cast::as_boolean_array(&filter_result)?; - - compute::filter_record_batch(&output_batch, mask)? - } else { - output_batch - }; - - self.output_record_batches.push(output_batch); + self.output_record_batches + .push(RecordBatch::try_new(self.schema.clone(), columns)?); } Ok(()) } @@ -1172,40 +1151,138 @@ impl SMJStream { .collect::<Vec<_>>() }; - let filter_columns = if matches!(self.join_type, JoinType::Right) { - get_filter_column(&self.filter, &buffered_columns, &streamed_columns) + let streamed_columns_length = streamed_columns.len(); + let buffered_columns_length = buffered_columns.len(); + + // Prepare the columns we apply join filter on later. + // Only for joined rows between streamed and buffered. + let filter_columns = if chunk.buffered_batch_idx.is_some() { + if matches!(self.join_type, JoinType::Right) { + get_filter_column(&self.filter, &buffered_columns, &streamed_columns) + } else { + get_filter_column(&self.filter, &streamed_columns, &buffered_columns) + } } else { - get_filter_column(&self.filter, &streamed_columns, &buffered_columns) + vec![] }; let columns = if matches!(self.join_type, JoinType::Right) { - buffered_columns.extend(streamed_columns); + buffered_columns.extend(streamed_columns.clone()); buffered_columns } else { streamed_columns.extend(buffered_columns); streamed_columns }; - let output_batch = RecordBatch::try_new(self.schema.clone(), columns)?; + let output_batch = + RecordBatch::try_new(self.schema.clone(), columns.clone())?; // Apply join filter if any - let output_batch = if let Some(f) = &self.filter { - // Construct batch with only filter columns - let filter_batch = - RecordBatch::try_new(Arc::new(f.schema().clone()), filter_columns)?; - - let filter_result = f - .expression() - .evaluate(&filter_batch)?
- .into_array(filter_batch.num_rows())?; - let mask = datafusion_common::cast::as_boolean_array(&filter_result)?; - - compute::filter_record_batch(&output_batch, mask)? - } else { - output_batch - }; + if !filter_columns.is_empty() { + if let Some(f) = &self.filter { + // Construct batch with only filter columns + let filter_batch = RecordBatch::try_new( + Arc::new(f.schema().clone()), + filter_columns, + )?; + + let filter_result = f + .expression() + .evaluate(&filter_batch)? + .into_array(filter_batch.num_rows())?; + + // The selection mask of the filter + let mask = datafusion_common::cast::as_boolean_array(&filter_result)?; + + // Push the filtered batch to the output + let filtered_batch = + compute::filter_record_batch(&output_batch, mask)?; + self.output_record_batches.push(filtered_batch); + + // For outer joins, we need to push the null joined rows to the output. + if matches!( + self.join_type, + JoinType::Left | JoinType::Right | JoinType::Full + ) { + // The reverse of the selection mask, which is for null joined rows + let not_mask = compute::not(mask)?; + let null_joined_batch = + compute::filter_record_batch(&output_batch, &not_mask)?; + + let mut buffered_columns = self + .buffered_schema + .fields() + .iter() + .map(|f| { + new_null_array( + f.data_type(), + null_joined_batch.num_rows(), + ) + }) + .collect::<Vec<_>>(); + + let columns = if matches!(self.join_type, JoinType::Right) { + let streamed_columns = null_joined_batch + .columns() + .iter() + .skip(buffered_columns_length) + .cloned() + .collect::<Vec<_>>(); + + buffered_columns.extend(streamed_columns); + buffered_columns + } else { + let mut streamed_columns = null_joined_batch + .columns() + .iter() + .take(streamed_columns_length) + .cloned() + .collect::<Vec<_>>(); + + streamed_columns.extend(buffered_columns); + streamed_columns + }; - self.output_record_batches.push(output_batch); + let null_joined_streamed_batch = + RecordBatch::try_new(self.schema.clone(), columns.clone())?; +
self.output_record_batches.push(null_joined_streamed_batch); + + // For full join, we also need to output the null joined rows from the buffered side + if matches!(self.join_type, JoinType::Full) { + let mut streamed_columns = self + .streamed_schema + .fields() + .iter() + .map(|f| { + new_null_array( + f.data_type(), + null_joined_batch.num_rows(), + ) + }) + .collect::<Vec<_>>(); + + let buffered_columns = null_joined_batch + .columns() + .iter() + .skip(streamed_columns_length) + .cloned() + .collect::<Vec<_>>(); + + streamed_columns.extend(buffered_columns); + + let null_joined_buffered_batch = RecordBatch::try_new( + self.schema.clone(), + streamed_columns, + )?; + self.output_record_batches.push(null_joined_buffered_batch); + } + } + } else { + self.output_record_batches.push(output_batch); + } + } else { + self.output_record_batches.push(output_batch); + } } self.streamed_batch.output_indices.clear(); @@ -1217,7 +1294,14 @@ impl SMJStream { let record_batch = concat_batches(&self.schema, &self.output_record_batches)?; self.join_metrics.output_batches.add(1); self.join_metrics.output_rows.add(record_batch.num_rows()); - self.output_size -= record_batch.num_rows(); + // If join filter exists, `self.output_size` is not accurate as we don't know the exact + // number of rows in the output record batch. If streamed row joined with buffered rows, + // once join filter is applied, the number of output rows may be more than 1.
+ if record_batch.num_rows() > self.output_size { + self.output_size = 0; + } else { + self.output_size -= record_batch.num_rows(); + } self.output_record_batches.clear(); Ok(record_batch) } diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index ca9b918ff3ee0..d287d11041eb9 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -238,6 +238,27 @@ SELECT t1_int, t2_int, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t2_int < NULL 3 11 NULL 3 55 +# equijoin_full +query ITIITI rowsort +SELECT * FROM t1 FULL JOIN t2 ON t1_id = t2_id +---- +11 a 1 11 z 3 +22 b 2 22 y 1 +33 c 3 NULL NULL NULL +44 d 4 44 x 3 +NULL NULL NULL 55 w 3 + +# equijoin_full_and_condition_from_both +query ITIITI rowsort +SELECT * FROM t1 FULL JOIN t2 ON t1_id = t2_id AND t2_int <= t1_int +---- +11 a 1 NULL NULL NULL +22 b 2 22 y 1 +33 c 3 NULL NULL NULL +44 d 4 44 x 3 +NULL NULL NULL 11 z 3 +NULL NULL NULL 55 w 3 + # left_join query ITT rowsort SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index d9266dea5ab1a..426b9a3a52919 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -84,6 +84,8 @@ SELECT * FROM t1 LEFT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b Alice 100 Alice 1 Alice 100 Alice 2 Alice 50 Alice 1 +Alice 50 NULL NULL +Bob 1 NULL NULL query TITI rowsort SELECT * FROM t1 LEFT JOIN t2 ON t1.a = t2.a AND t2.b < t1.b @@ -92,6 +94,7 @@ Alice 100 Alice 1 Alice 100 Alice 2 Alice 50 Alice 1 Alice 50 Alice 2 +Bob 1 NULL NULL # right join without join filter query TITI rowsort @@ -109,6 +112,7 @@ SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b Alice 100 Alice 1 Alice 100 Alice 2 Alice 50 Alice 1 +NULL NULL Alice 2 query TITI rowsort SELECT * FROM 
t1 RIGHT JOIN t2 ON t1.a = t2.a AND t1.b > t2.b @@ -132,19 +136,132 @@ Bob 1 NULL NULL query TITI rowsort SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t2.b * 50 > t1.b ---- +Alice 100 NULL NULL +Alice 100 NULL NULL Alice 50 Alice 2 +Alice 50 NULL NULL +Bob 1 NULL NULL +NULL NULL Alice 1 +NULL NULL Alice 1 +NULL NULL Alice 2 query TITI rowsort SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50 ---- Alice 100 Alice 1 Alice 100 Alice 2 +Alice 50 NULL NULL +Alice 50 NULL NULL +Bob 1 NULL NULL +NULL NULL Alice 1 +NULL NULL Alice 2 statement ok -set datafusion.optimizer.prefer_hash_join = true; +DROP TABLE t1; + +statement ok +DROP TABLE t2; + +statement ok +CREATE TABLE IF NOT EXISTS t1(t1_id INT, t1_name TEXT, t1_int INT) AS VALUES +(11, 'a', 1), +(22, 'b', 2), +(33, 'c', 3), +(44, 'd', 4); + +statement ok +CREATE TABLE IF NOT EXISTS t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES +(11, 'z', 3), +(22, 'y', 1), +(44, 'x', 3), +(55, 'w', 3); + +# inner join with join filter +query III rowsort +SELECT t1_id, t1_int, t2_int FROM t1 JOIN t2 ON t1_id = t2_id AND t1_int >= t2_int +---- +22 2 1 +44 4 3 + +# equijoin_multiple_condition_ordering +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t1_name <> t2_name +---- +11 a z +22 b y +44 d x + +# equijoin_right_and_condition_from_left +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t1_id >= 22 +---- +22 b y +44 d x +NULL NULL w +NULL NULL z + +# equijoin_left_and_condition_from_left +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t1_id >= 44 +---- +11 a NULL +22 b NULL +33 c NULL +44 d x + +# equijoin_left_and_condition_from_both +query III rowsort +SELECT t1_id, t1_int, t2_int FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t1_int >= t2_int +---- +11 1 NULL +22 2 1 +33 3 NULL +44 4 3 + +# equijoin_right_and_condition_from_right +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 
RIGHT JOIN t2 ON t1_id = t2_id AND t2_id >= 22 +---- +22 b y +44 d x +NULL NULL w +NULL NULL z + +# equijoin_right_and_condition_from_both +query III rowsort +SELECT t1_int, t2_int, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t2_int <= t1_int +---- +2 1 22 +4 3 44 +NULL 3 11 +NULL 3 55 + +# equijoin_full +query ITIITI rowsort +SELECT * FROM t1 FULL JOIN t2 ON t1_id = t2_id +---- +11 a 1 11 z 3 +22 b 2 22 y 1 +33 c 3 NULL NULL NULL +44 d 4 44 x 3 +NULL NULL NULL 55 w 3 + +# equijoin_full_and_condition_from_both +query ITIITI rowsort +SELECT * FROM t1 FULL JOIN t2 ON t1_id = t2_id AND t2_int <= t1_int +---- +11 a 1 NULL NULL NULL +22 b 2 22 y 1 +33 c 3 NULL NULL NULL +44 d 4 44 x 3 +NULL NULL NULL 11 z 3 +NULL NULL NULL 55 w 3 statement ok DROP TABLE t1; statement ok DROP TABLE t2; + +statement ok +set datafusion.optimizer.prefer_hash_join = true; From c7c25ce24051ed68d3c8fb675f80a125d1e88b82 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 5 Feb 2024 11:32:57 -0800 Subject: [PATCH 7/8] For review --- datafusion/physical-plan/src/joins/sort_merge_join.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 3f9c9f7393f25..2d06fd7902c64 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1163,6 +1163,7 @@ impl SMJStream { get_filter_column(&self.filter, &streamed_columns, &buffered_columns) } } else { + // This chunk is for null joined rows (outer join), we don't need to apply join filter. vec![] }; @@ -1204,7 +1205,8 @@ impl SMJStream { self.join_type, JoinType::Left | JoinType::Right | JoinType::Full ) { - // The reverse of the selection mask, which is for null joined rows + // The reverse of the selection mask. For the rows not pass join filter above, + // we need to join them (left or right) with null rows for outer joins. 
let not_mask = compute::not(mask)?; let null_joined_batch = compute::filter_record_batch(&output_batch, &not_mask)?; @@ -1232,6 +1234,7 @@ impl SMJStream { buffered_columns.extend(streamed_columns); buffered_columns } else { + // Left join or full outer join let mut streamed_columns = null_joined_batch .columns() .iter() From e75629d64eac839e309e120dfe7f2406d9478e55 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 6 Feb 2024 09:13:33 -0800 Subject: [PATCH 8/8] Update datafusion/physical-plan/src/joins/sort_merge_join.rs Co-authored-by: Andrew Lamb --- datafusion/physical-plan/src/joins/sort_merge_join.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 2d06fd7902c64..107fd7dde0f6e 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -634,7 +634,9 @@ struct SMJStream { pub filter: Option<JoinFilter>, /// Staging output array builders pub output_record_batches: Vec<RecordBatch>, - /// Staging output size, including output batches and staging joined results + /// Staging output size, including output batches and staging joined results. + /// Increased when we put rows into buffer and decreased after we actually output batches. + /// Used to trigger output when sufficient rows are ready pub output_size: usize, /// Target output batch size pub batch_size: usize,