diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs index b91c1732260cf..090839e469825 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs @@ -278,7 +278,7 @@ async fn test_dynamic_filter_pushdown_through_hash_join_with_topk() { - SortExec: TopK(fetch=2), expr=[e@4 ASC], preserve_partitioning=[false], filter=[e@4 IS NULL OR e@4 < bb] - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)] - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ d@0 >= aa AND d@0 <= ab ] AND DynamicFilter [ e@1 IS NULL OR e@1 < bb ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ d@0 >= aa AND d@0 <= ab AND d@0 IN BLOOM_FILTER ] AND DynamicFilter [ e@1 IS NULL OR e@1 < bb ] " ); } @@ -1078,7 +1078,7 @@ async fn test_hashjoin_dynamic_filter_pushdown() { @r" - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb AND (a@0, b@1) IN BLOOM_FILTER ] " ); } @@ -1309,7 +1309,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { - 
DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= ab AND a@0 <= ab AND b@1 >= bb AND b@1 <= bb OR a@0 >= aa AND a@0 <= aa AND b@1 >= ba AND b@1 <= ba ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= ab AND a@0 <= ab AND b@1 >= bb AND b@1 <= bb AND (a@0, b@1) IN BLOOM_FILTER OR a@0 >= aa AND a@0 <= aa AND b@1 >= ba AND b@1 <= ba AND (a@0, b@1) IN BLOOM_FILTER ] " ); @@ -1326,7 +1326,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb AND (a@0, b@1) IN BLOOM_FILTER ] " ); @@ -1503,7 +1503,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_collect_left() { - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, 
e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb AND (a@0, b@1) IN BLOOM_FILTER ] " ); @@ -1671,8 +1671,8 @@ async fn test_nested_hashjoin_dynamic_filter_pushdown() { - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@0)] - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, x], file_type=test, pushdown_supported=true - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, d@0)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[b, c, y], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ b@0 >= aa AND b@0 <= ab ] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, z], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ d@0 >= ca AND d@0 <= cb ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[b, c, y], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ b@0 >= aa AND b@0 <= ab AND b@0 IN BLOOM_FILTER ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, z], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ d@0 >= ca AND d@0 <= cb AND d@0 IN BLOOM_FILTER ] " ); } diff --git a/datafusion/physical-expr/src/bloom_filter.rs b/datafusion/physical-expr/src/bloom_filter.rs new file mode 100644 index 0000000000000..15a51fc0efa0d --- /dev/null +++ b/datafusion/physical-expr/src/bloom_filter.rs @@ -0,0 +1,361 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Bloom filter implementation for physical expressions +//! +//! This module contains a vendored copy of the Split Block Bloom Filter (SBBF) +//! implementation from the parquet crate. This avoids circular dependencies +//! while allowing physical expressions to use bloom filters for runtime pruning. +//! +//! TODO: If this bloom filter approach is successful, extract this into a shared +//! crate (e.g., `datafusion-bloom-filter`) that both parquet and physical-expr +//! can depend on. +//! +//! The implementation below is adapted from: +//! arrow-rs/parquet/src/bloom_filter/mod.rs +//! +//! One thing to consider is if we can make this implementation compatible with Parquet's +//! byte for byte (it currently is not) so that we can do binary intersection of bloom filters +//! between DataFusion and Parquet. + +use datafusion_common::{internal_err, Result}; +use std::mem::size_of; + +/// Salt values as defined in the Parquet specification +/// Although we don't *need* to follow the Parquet spec here, using the same +/// constants allows us to be compatible with Parquet bloom filters in the future +/// e.g. to do binary intersection of bloom filters. 
+const SALT: [u32; 8] = [ + 0x47b6137b_u32, + 0x44974d91_u32, + 0x8824ad5b_u32, + 0xa2b7289d_u32, + 0x705495c7_u32, + 0x2df1424b_u32, + 0x9efc4947_u32, + 0x5c6bfb31_u32, +]; + +/// Minimum bitset length in bytes +const BITSET_MIN_LENGTH: usize = 32; +/// Maximum bitset length in bytes +const BITSET_MAX_LENGTH: usize = 128 * 1024 * 1024; + +/// Each block is 256 bits, broken up into eight contiguous "words", each consisting of 32 bits. +/// Each word is thought of as an array of bits; each bit is either "set" or "not set". +#[derive(Debug, Copy, Clone)] +#[repr(transparent)] +struct Block([u32; 8]); + +impl Block { + const ZERO: Block = Block([0; 8]); + + /// Takes as its argument a single unsigned 32-bit integer and returns a block in which each + /// word has exactly one bit set. + fn mask(x: u32) -> Self { + let mut result = [0_u32; 8]; + for i in 0..8 { + // wrapping instead of checking for overflow + let y = x.wrapping_mul(SALT[i]); + let y = y >> 27; + result[i] = 1 << y; + } + Self(result) + } + + /// Setting every bit in the block that was also set in the result from mask + fn insert(&mut self, hash: u32) { + let mask = Self::mask(hash); + for i in 0..8 { + self[i] |= mask[i]; + } + } + + /// Returns true when every bit that is set in the result of mask is also set in the block. + fn check(&self, hash: u32) -> bool { + let mask = Self::mask(hash); + for i in 0..8 { + if self[i] & mask[i] == 0 { + return false; + } + } + true + } +} + +impl std::ops::Index for Block { + type Output = u32; + + #[inline] + fn index(&self, index: usize) -> &Self::Output { + self.0.index(index) + } +} + +impl std::ops::IndexMut for Block { + #[inline] + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + self.0.index_mut(index) + } +} + +/// A Split Block Bloom Filter (SBBF) +/// +/// This is a space-efficient probabilistic data structure used to test whether +/// an element is a member of a set. False positive matches are possible, but +/// false negatives are not. 
+#[derive(Debug, Clone)] +pub(crate) struct Sbbf(Vec); + +impl Sbbf { + /// Create a new Sbbf with given number of distinct values and false positive probability. + /// Will return an error if `fpp` is greater than or equal to 1.0 or less than 0.0. + pub fn new_with_ndv_fpp(ndv: u64, fpp: f64) -> Result { + if !(0.0..1.0).contains(&fpp) { + return internal_err!( + "False positive probability must be between 0.0 and 1.0, got {fpp}" + ); + } + let num_bits = num_of_bits_from_ndv_fpp(ndv, fpp); + Ok(Self::new_with_num_of_bytes(num_bits / 8)) + } + + /// Create a new Sbbf with given number of bytes, the exact number of bytes will be adjusted + /// to the next power of two bounded by BITSET_MIN_LENGTH and BITSET_MAX_LENGTH. + fn new_with_num_of_bytes(num_bytes: usize) -> Self { + let num_bytes = optimal_num_of_bytes(num_bytes); + assert_eq!(num_bytes % size_of::(), 0); + let num_blocks = num_bytes / size_of::(); + let bitset = vec![Block::ZERO; num_blocks]; + Self(bitset) + } + + #[inline] + fn hash_to_block_index(&self, hash: u64) -> usize { + // unchecked_mul is unstable, but in reality this is safe, we'd just use saturating mul + // but it will not saturate + (((hash >> 32).saturating_mul(self.0.len() as u64)) >> 32) as usize + } + + /// Insert a hash into the filter + pub(crate) fn insert_hash(&mut self, hash: u64) { + let block_index = self.hash_to_block_index(hash); + self.0[block_index].insert(hash as u32) + } + + /// Check if a hash is in the filter. May return + /// true for values that were never inserted ("false positive") + /// but will always return false if a hash has not been inserted. 
+ pub(crate) fn check_hash(&self, hash: u64) -> bool { + let block_index = self.hash_to_block_index(hash); + self.0[block_index].check(hash as u32) + } +} + +/// Trait for types that can be converted to bytes for hashing +pub trait AsBytes { + /// Return a byte slice representation of this value + fn as_bytes(&self) -> &[u8]; +} + +impl AsBytes for str { + fn as_bytes(&self) -> &[u8] { + str::as_bytes(self) + } +} + +impl AsBytes for [u8] { + fn as_bytes(&self) -> &[u8] { + self + } +} + +impl AsBytes for bool { + fn as_bytes(&self) -> &[u8] { + if *self { + &[1u8] + } else { + &[0u8] + } + } +} + +impl AsBytes for i8 { + fn as_bytes(&self) -> &[u8] { + unsafe { + std::slice::from_raw_parts(self as *const i8 as *const u8, size_of::()) + } + } +} + +impl AsBytes for i16 { + fn as_bytes(&self) -> &[u8] { + unsafe { + std::slice::from_raw_parts(self as *const i16 as *const u8, size_of::()) + } + } +} + +impl AsBytes for i32 { + fn as_bytes(&self) -> &[u8] { + unsafe { + std::slice::from_raw_parts(self as *const i32 as *const u8, size_of::()) + } + } +} + +impl AsBytes for i64 { + fn as_bytes(&self) -> &[u8] { + unsafe { + std::slice::from_raw_parts(self as *const i64 as *const u8, size_of::()) + } + } +} + +impl AsBytes for u8 { + fn as_bytes(&self) -> &[u8] { + unsafe { std::slice::from_raw_parts(self as *const u8, size_of::()) } + } +} + +impl AsBytes for u16 { + fn as_bytes(&self) -> &[u8] { + unsafe { + std::slice::from_raw_parts(self as *const u16 as *const u8, size_of::()) + } + } +} + +impl AsBytes for u32 { + fn as_bytes(&self) -> &[u8] { + unsafe { + std::slice::from_raw_parts(self as *const u32 as *const u8, size_of::()) + } + } +} + +impl AsBytes for u64 { + fn as_bytes(&self) -> &[u8] { + unsafe { + std::slice::from_raw_parts(self as *const u64 as *const u8, size_of::()) + } + } +} + +impl AsBytes for f32 { + fn as_bytes(&self) -> &[u8] { + unsafe { + std::slice::from_raw_parts(self as *const f32 as *const u8, size_of::()) + } + } +} + +impl AsBytes for 
f64 { + fn as_bytes(&self) -> &[u8] { + unsafe { + std::slice::from_raw_parts(self as *const f64 as *const u8, size_of::()) + } + } +} + +impl AsBytes for i128 { + fn as_bytes(&self) -> &[u8] { + // Use big-endian for i128 to match Parquet's FIXED_LEN_BYTE_ARRAY representation + // This allows compatibility with Parquet bloom filters + unsafe { + std::slice::from_raw_parts( + self as *const i128 as *const u8, + size_of::(), + ) + } + } +} + +impl AsBytes for [u8; 32] { + fn as_bytes(&self) -> &[u8] { + self + } +} + +/// Calculate optimal number of bytes, bounded by min/max and rounded to power of 2 +#[inline] +fn optimal_num_of_bytes(num_bytes: usize) -> usize { + let num_bytes = num_bytes.min(BITSET_MAX_LENGTH); + let num_bytes = num_bytes.max(BITSET_MIN_LENGTH); + num_bytes.next_power_of_two() +} + +/// Calculate number of bits needed given NDV and FPP +/// Formula: m = -k * n / ln(1 - f^(1/k)) +/// where k=8 (number of hash functions), n=ndv, f=fpp, m=num_bits +#[inline] +fn num_of_bits_from_ndv_fpp(ndv: u64, fpp: f64) -> usize { + let num_bits = -8.0 * ndv as f64 / (1.0 - fpp.powf(1.0 / 8.0)).ln(); + num_bits as usize +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mask_set_quick_check() { + for i in 0..1_000 { + let result = Block::mask(i); + assert!(result.0.iter().all(|&x| x.is_power_of_two())); + } + } + + #[test] + fn test_block_insert_and_check() { + for i in 0..1_000 { + let mut block = Block::ZERO; + block.insert(i); + assert!(block.check(i)); + } + } + + #[test] + fn test_optimal_num_of_bytes() { + for (input, expected) in &[ + (0, 32), + (9, 32), + (31, 32), + (32, 32), + (33, 64), + (99, 128), + (1024, 1024), + (999_000_000, 128 * 1024 * 1024), + ] { + assert_eq!(*expected, optimal_num_of_bytes(*input)); + } + } + + #[test] + fn test_num_of_bits_from_ndv_fpp() { + for (fpp, ndv, num_bits) in &[ + (0.1, 10, 57), + (0.01, 10, 96), + (0.001, 10, 146), + (0.1, 100, 577), + (0.01, 100, 968), + (0.001, 100, 1460), + ] { + 
assert_eq!(*num_bits, num_of_bits_from_ndv_fpp(*ndv, *fpp) as u64); + } + } +} diff --git a/datafusion/physical-expr/src/expressions/bloom_filter_expr.rs b/datafusion/physical-expr/src/expressions/bloom_filter_expr.rs new file mode 100644 index 0000000000000..404f7b277c3cd --- /dev/null +++ b/datafusion/physical-expr/src/expressions/bloom_filter_expr.rs @@ -0,0 +1,442 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Bloom filter physical expression + +use crate::bloom_filter::Sbbf; +use crate::PhysicalExpr; +use ahash::RandomState; +use arrow::array::{ArrayRef, BooleanArray}; +use arrow::datatypes::{DataType, Schema}; +use arrow::record_batch::RecordBatch; +use datafusion_common::hash_utils::create_hashes; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr_common::columnar_value::ColumnarValue; +use std::any::Any; +use std::fmt; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +/// A progressive builder for creating bloom filters +/// +/// This builder allows incremental insertion of values from record batches +/// and produces a static `BloomFilterExpr` when finished. 
+/// +/// # Example +/// ```ignore +/// let random_state = RandomState::with_seeds(0, 0, 0, 0); +/// let mut builder = BloomFilterBuilder::new(1000, 0.01, random_state)?; +/// builder.insert_hashes(&hashes)?; +/// let expr = builder.build(col_expr); // Consumes builder +/// ``` +#[derive(Debug)] +pub struct BloomFilterBuilder { + /// The underlying bloom filter + sbbf: Sbbf, + /// Random state for consistent hashing + random_state: RandomState, +} + +impl BloomFilterBuilder { + /// Create a new bloom filter builder + /// + /// # Arguments + /// * `ndv` - Expected number of distinct values + /// * `fpp` - Desired false positive probability (0.0 to 1.0) + /// * `random_state` - Random state for consistent hashing across build and probe phases + pub fn new(ndv: u64, fpp: f64, random_state: RandomState) -> Result { + let sbbf = Sbbf::new_with_ndv_fpp(ndv, fpp)?; + Ok(Self { sbbf, random_state }) + } + + /// Insert pre-computed hash values into the bloom filter + /// + /// This method allows reusing hash values that were already computed + /// for other purposes (e.g., hash table insertion), avoiding redundant + /// hash computation. + /// + /// # Arguments + /// * `hashes` - Pre-computed hash values to insert + pub fn insert_hashes(&mut self, hashes: &[u64]) { + for &hash in hashes { + self.sbbf.insert_hash(hash); + } + } + + /// Build a `BloomFilterExpr` from this builder, consuming the builder. + /// + /// This consumes the builder and moves the bloom filter data into the expression, + /// avoiding any clones of the (potentially large) bloom filter. 
+ /// + /// # Arguments + /// * `exprs` - The expressions to evaluate and check against the bloom filter + pub fn build(self, exprs: Vec>) -> BloomFilterExpr { + BloomFilterExpr::new(exprs, self.sbbf, self.random_state) + } +} + +/// Physical expression that checks values against a bloom filter +/// +/// This is a static expression (similar to `InListExpr`) that evaluates +/// one or more child expressions and checks each value against a pre-built bloom filter. +/// When multiple expressions are provided, they are combined via hashing (similar to join key hashing). +/// Returns a boolean array indicating whether each value might be present +/// (true) or is definitely absent (false). +/// +/// Note: Bloom filters can produce false positives but never false negatives. +#[derive(Debug, Clone)] +pub struct BloomFilterExpr { + /// The expressions to evaluate (one or more columns) + exprs: Vec>, + /// The bloom filter to check against + bloom_filter: Arc, + /// Random state for consistent hashing + random_state: RandomState, +} + +impl BloomFilterExpr { + /// Create a new bloom filter expression (internal use only) + /// + /// Users should create bloom filter expressions through `BloomFilterBuilder::build()` + pub(crate) fn new( + exprs: Vec>, + bloom_filter: Sbbf, + random_state: RandomState, + ) -> Self { + Self { + exprs, + bloom_filter: Arc::new(bloom_filter), + random_state, + } + } + + /// Check scalar expressions against the bloom filter + fn check_scalar_batch(&self, batch: &RecordBatch) -> Result { + // Evaluate all expressions to get their scalar values + let arrays: Vec = self + .exprs + .iter() + .map(|expr| { + let value = expr.evaluate(batch)?; + match value { + ColumnarValue::Scalar(s) => s.to_array(), + ColumnarValue::Array(a) => Ok(a), + } + }) + .collect::>>()?; + + // Compute combined hash + let mut hashes = vec![0u64; 1]; + create_hashes(&arrays, &self.random_state, &mut hashes)?; + + Ok(self.bloom_filter.check_hash(hashes[0])) + } + + /// Check 
arrays against the bloom filter + fn check_arrays(&self, batch: &RecordBatch) -> Result { + // Evaluate all expressions to get their arrays + let arrays: Vec = self + .exprs + .iter() + .map(|expr| expr.evaluate(batch)?.into_array(batch.num_rows())) + .collect::>>()?; + + // Use create_hashes to compute combined hash values for all columns + // This matches how the build side computes hashes (combining all join columns) + let mut hashes = vec![0u64; batch.num_rows()]; + create_hashes(&arrays, &self.random_state, &mut hashes)?; + + // Check each hash against the bloom filter + let mut builder = BooleanArray::builder(batch.num_rows()); + for hash in hashes { + builder.append_value(self.bloom_filter.check_hash(hash)); + } + + Ok(builder.finish()) + } +} + +impl fmt::Display for BloomFilterExpr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if self.exprs.len() == 1 { + write!(f, "{} IN BLOOM_FILTER", self.exprs[0]) + } else { + write!(f, "(")?; + for (i, expr) in self.exprs.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", expr)?; + } + write!(f, ") IN BLOOM_FILTER") + } + } +} + +impl PartialEq for BloomFilterExpr { + fn eq(&self, other: &Self) -> bool { + // Two bloom filter expressions are equal if they have the same child expressions + // We can't compare bloom filters directly, so we use pointer equality + self.exprs.eq(&other.exprs) + && Arc::ptr_eq(&self.bloom_filter, &other.bloom_filter) + } +} + +impl Eq for BloomFilterExpr {} + +impl Hash for BloomFilterExpr { + fn hash(&self, state: &mut H) { + for expr in &self.exprs { + expr.hash(state); + } + // Hash the pointer to the bloom filter + Arc::as_ptr(&self.bloom_filter).hash(state); + } +} + +impl PhysicalExpr for BloomFilterExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(DataType::Boolean) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(false) + } + + fn evaluate(&self, batch: 
&RecordBatch) -> Result { + // Check if all expressions return scalars + let all_scalars = self + .exprs + .iter() + .map(|expr| expr.evaluate(batch)) + .collect::>>()? + .iter() + .all(|v| matches!(v, ColumnarValue::Scalar(_))); + + if all_scalars { + // If all are scalars, check them and return a scalar + let result = self.check_scalar_batch(batch)?; + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(result)))) + } else { + // Otherwise, check arrays + let result = self.check_arrays(batch)?; + Ok(ColumnarValue::Array(Arc::new(result))) + } + } + + fn children(&self) -> Vec<&Arc> { + self.exprs.iter().collect() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(BloomFilterExpr { + exprs: children, + bloom_filter: Arc::clone(&self.bloom_filter), + random_state: self.random_state.clone(), + })) + } + + fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.exprs.len() == 1 { + write!(f, "{} IN BLOOM_FILTER", self.exprs[0]) + } else { + write!(f, "(")?; + for (i, expr) in self.exprs.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", expr)?; + } + write!(f, ") IN BLOOM_FILTER") + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::col; + use arrow::datatypes::{Field, Schema}; + + // Helper trait to add insert_scalar for tests + trait BloomFilterBuilderTestExt { + fn insert_scalar(&mut self, value: &ScalarValue) -> Result<()>; + } + + impl BloomFilterBuilderTestExt for BloomFilterBuilder { + /// Insert a single scalar value by converting to array and computing hashes + /// This is less efficient but sufficient for tests + fn insert_scalar(&mut self, value: &ScalarValue) -> Result<()> { + let array = value.to_array()?; + let mut hashes = vec![0u64; array.len()]; + create_hashes(&[array], &self.random_state, &mut hashes)?; + self.insert_hashes(&hashes); + Ok(()) + } + } + + #[test] + fn test_bloom_filter_builder() -> Result<()> { + let random_state = 
RandomState::with_seeds(0, 0, 0, 0); + let mut builder = BloomFilterBuilder::new(100, 0.01, random_state)?; + + // Insert some values + builder.insert_scalar(&ScalarValue::Int32(Some(1)))?; + builder.insert_scalar(&ScalarValue::Int32(Some(2)))?; + builder.insert_scalar(&ScalarValue::Int32(Some(3)))?; + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let expr = col("a", &schema)?; + let bloom_expr = Arc::new(builder.build(vec![expr])); + + // Check that inserted values are found by creating test batches + let test_array = Arc::new(arrow::array::Int32Array::from(vec![1])) as ArrayRef; + let batch = RecordBatch::try_new(Arc::clone(&schema), vec![test_array])?; + let result = bloom_expr.evaluate(&batch)?; + assert!(result + .into_array(1)? + .as_any() + .downcast_ref::() + .unwrap() + .value(0)); + + Ok(()) + } + + #[test] + fn test_bloom_filter_expr_evaluation() -> Result<()> { + use arrow::array::Int32Array; + + // Build a bloom filter with values 1, 2, 3 + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut builder = BloomFilterBuilder::new(100, 0.01, random_state)?; + let training_array = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; + let mut hashes = vec![0u64; training_array.len()]; + create_hashes( + &[Arc::clone(&training_array)], + &builder.random_state, + &mut hashes, + )?; + builder.insert_hashes(&hashes); + + // Create the expression + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let expr = col("a", &schema)?; + let bloom_expr = Arc::new(builder.build(vec![expr])); + + // Create a test batch with values [1, 2, 4, 5] + let test_array = Arc::new(Int32Array::from(vec![1, 2, 4, 5])) as ArrayRef; + let batch = RecordBatch::try_new(Arc::clone(&schema), vec![test_array])?; + + // Evaluate the expression + let result = bloom_expr.evaluate(&batch)?; + let result_array = result.into_array(4)?; + let result_bool = result_array + .as_any() + .downcast_ref::() + 
.unwrap(); + + // Values 1 and 2 should definitely be found + assert!(result_bool.value(0)); // 1 is in the filter + assert!(result_bool.value(1)); // 2 is in the filter + + // Values 4 and 5 were not inserted, but might be false positives + // We can't assert they're false without making the test flaky + + Ok(()) + } + + #[test] + fn test_bloom_filter_with_strings() -> Result<()> { + use arrow::array::StringArray; + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut builder = BloomFilterBuilder::new(100, 0.01, random_state)?; + builder.insert_scalar(&ScalarValue::Utf8(Some("hello".to_string())))?; + builder.insert_scalar(&ScalarValue::Utf8(Some("world".to_string())))?; + + let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, false)])); + let expr = col("s", &schema)?; + let bloom_expr = Arc::new(builder.build(vec![expr])); + + let test_array = + Arc::new(StringArray::from(vec!["hello", "world", "foo"])) as ArrayRef; + let batch = RecordBatch::try_new(Arc::clone(&schema), vec![test_array])?; + + let result = bloom_expr.evaluate(&batch)?; + let result_array = result.into_array(3)?; + let result_bool = result_array + .as_any() + .downcast_ref::() + .unwrap(); + + assert!(result_bool.value(0)); // "hello" is in the filter + assert!(result_bool.value(1)); // "world" is in the filter + + Ok(()) + } + + #[test] + fn test_bloom_filter_with_decimals() -> Result<()> { + use arrow::array::Decimal128Array; + + // Build a bloom filter with decimal values + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut builder = BloomFilterBuilder::new(100, 0.01, random_state)?; + builder.insert_scalar(&ScalarValue::Decimal128(Some(12345), 10, 2))?; + builder.insert_scalar(&ScalarValue::Decimal128(Some(67890), 10, 2))?; + + let schema = Arc::new(Schema::new(vec![Field::new( + "d", + DataType::Decimal128(10, 2), + false, + )])); + let expr = col("d", &schema)?; + let bloom_expr = Arc::new(builder.build(vec![expr])); + + // Create test array 
with decimal values + let test_array = Arc::new( + Decimal128Array::from(vec![12345, 67890, 11111]) + .with_precision_and_scale(10, 2)?, + ) as ArrayRef; + let batch = RecordBatch::try_new(Arc::clone(&schema), vec![test_array])?; + + // Evaluate the expression + let result = bloom_expr.evaluate(&batch)?; + let result_array = result.into_array(3)?; + let result_bool = result_array + .as_any() + .downcast_ref::() + .unwrap(); + + // Values that were inserted should be found + assert!(result_bool.value(0)); // 12345 is in the filter + assert!(result_bool.value(1)); // 67890 is in the filter + + // Value 11111 was not inserted, but might be a false positive + // We can't assert it's false without making the test flaky + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 59d675753d985..846ccfccb8d7c 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -19,6 +19,7 @@ #[macro_use] mod binary; +mod bloom_filter_expr; mod case; mod cast; mod cast_column; @@ -40,6 +41,7 @@ pub use crate::aggregate::stats::StatsType; pub use crate::PhysicalSortExpr; pub use binary::{binary, similar_to, BinaryExpr}; +pub use bloom_filter_expr::{BloomFilterBuilder, BloomFilterExpr}; pub use case::{case, CaseExpr}; pub use cast::{cast, CastExpr}; pub use cast_column::CastColumnExpr; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index aa8c9e50fd71e..b67b1db0e92b7 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -31,6 +31,7 @@ pub mod binary_map { pub use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType}; } pub mod async_scalar_function; +pub mod bloom_filter; pub mod equivalence; pub mod expressions; pub mod intervals; diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 
b5fe5ee5cda14..59f59447ad655 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -26,7 +26,9 @@ use crate::filter_pushdown::{ ChildPushdownResult, FilterDescription, FilterPushdownPhase, FilterPushdownPropagation, }; -use crate::joins::hash_join::shared_bounds::{ColumnBounds, SharedBoundsAccumulator}; +use crate::joins::hash_join::shared_bounds::{ + ColumnBounds, ColumnFilterData, SharedBoundsAccumulator, +}; use crate::joins::hash_join::stream::{ BuildSide, BuildSideInitialState, HashJoinStream, HashJoinStreamState, }; @@ -53,6 +55,7 @@ use crate::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PlanProperties, SendableRecordBatchStream, Statistics, }; +use parking_lot::Mutex; use arrow::array::{ArrayRef, BooleanBufferBuilder}; use arrow::compute::concat_batches; @@ -72,13 +75,14 @@ use datafusion_functions_aggregate_common::min_max::{MaxAccumulator, MinAccumula use datafusion_physical_expr::equivalence::{ join_equivalence_properties, ProjectionMapping, }; -use datafusion_physical_expr::expressions::{lit, DynamicFilterPhysicalExpr}; +use datafusion_physical_expr::expressions::{ + lit, BloomFilterBuilder, DynamicFilterPhysicalExpr, +}; use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; use ahash::RandomState; use datafusion_physical_expr_common::physical_expr::fmt_sql; use futures::TryStreamExt; -use parking_lot::Mutex; /// Hard-coded seed to ensure hash values from the hash join differ from `RepartitionExec`, avoiding collisions. const HASH_JOIN_SEED: RandomState = @@ -102,8 +106,9 @@ pub(super) struct JoinLeftData { /// This could hide potential out-of-memory issues, especially when upstream operators increase their memory consumption. /// The MemoryReservation ensures proper tracking of memory resources throughout the join operation's lifecycle. 
_reservation: MemoryReservation, - /// Bounds computed from the build side for dynamic filter pushdown - pub(super) bounds: Option>, + /// Filter data (bounds + bloom filters) computed from the build side for dynamic filter pushdown. + /// Wrapped in Mutex> to allow taking ownership when reporting to the accumulator. + pub(super) column_filters: Mutex>>, } impl JoinLeftData { @@ -115,7 +120,7 @@ impl JoinLeftData { visited_indices_bitmap: SharedBitmapBuilder, probe_threads_counter: AtomicUsize, reservation: MemoryReservation, - bounds: Option>, + column_filters: Option>, ) -> Self { Self { hash_map, @@ -124,7 +129,7 @@ impl JoinLeftData { visited_indices_bitmap, probe_threads_counter, _reservation: reservation, - bounds, + column_filters: Mutex::new(column_filters), } } @@ -1207,14 +1212,14 @@ impl ExecutionPlan for HashJoinExec { } } -/// Accumulator for collecting min/max bounds from build-side data during hash join. +/// Accumulator for collecting min/max bounds and bloom filters from build-side data during hash join. /// /// This struct encapsulates the logic for progressively computing column bounds -/// (minimum and maximum values) for a specific join key expression as batches +/// (minimum and maximum values) and bloom filters for a specific join key expression as batches /// are processed during the build phase of a hash join. /// -/// The bounds are used for dynamic filter pushdown optimization, where filters -/// based on the actual data ranges can be pushed down to the probe side to +/// The bounds and bloom filters are used for dynamic filter pushdown optimization, where filters +/// based on the actual data ranges and membership can be pushed down to the probe side to /// eliminate unnecessary data early. 
struct CollectLeftAccumulator { /// The physical expression to evaluate for each batch @@ -1223,6 +1228,8 @@ struct CollectLeftAccumulator { min: MinAccumulator, /// Accumulator for tracking the maximum value across all batches max: MaxAccumulator, + /// Bloom filter builder for membership testing + bloom_filter: BloomFilterBuilder, } impl CollectLeftAccumulator { @@ -1231,10 +1238,15 @@ impl CollectLeftAccumulator { /// # Arguments /// * `expr` - The physical expression to track bounds for /// * `schema` - The schema of the input data + /// * `random_state` - Random state for consistent hashing /// /// # Returns /// A new `CollectLeftAccumulator` instance configured for the expression's data type - fn try_new(expr: Arc<dyn PhysicalExpr>, schema: &SchemaRef) -> Result<Self> { + fn try_new( + expr: Arc<dyn PhysicalExpr>, + schema: &SchemaRef, + random_state: RandomState, + ) -> Result<Self> { /// Recursively unwraps dictionary types to get the underlying value type. fn dictionary_value_type(data_type: &DataType) -> DataType { match data_type { @@ -1249,17 +1261,23 @@ impl CollectLeftAccumulator { .data_type(schema) // Min/Max can operate on dictionary data but expect to be initialized with the underlying value type .map(|dt| dictionary_value_type(&dt))?; + + // Create bloom filter with default parameters + // NDV (number of distinct values) = 1000, FPP (false positive probability) = 0.05 (5%) + let bloom_filter = BloomFilterBuilder::new(1000, 0.05, random_state)?; + Ok(Self { expr, min: MinAccumulator::try_new(&data_type)?, max: MaxAccumulator::try_new(&data_type)?, + bloom_filter, }) } /// Updates the accumulators with values from a new batch. /// - /// Evaluates the expression on the batch and updates both min and max - /// accumulators with the resulting values. + /// Evaluates the expression on the batch and updates min and max bounds. + /// Bloom filter population is deferred to Pass 2 to reuse hash computations.
/// /// # Arguments /// * `batch` - The record batch to process @@ -1270,20 +1288,20 @@ impl CollectLeftAccumulator { let array = self.expr.evaluate(batch)?.into_array(batch.num_rows())?; self.min.update_batch(std::slice::from_ref(&array))?; self.max.update_batch(std::slice::from_ref(&array))?; + // Bloom filter population deferred to Pass 2 to reuse hash table hashes Ok(()) } - /// Finalizes the accumulation and returns the computed bounds. + /// Finalizes the accumulation and returns the computed filter data. /// - /// Consumes self to extract the final min and max values from the accumulators. + /// Consumes self to extract the final min and max values from the accumulators + /// and the bloom filter builder. /// /// # Returns - /// The `ColumnBounds` containing the minimum and maximum values observed - fn evaluate(mut self) -> Result<ColumnBounds> { - Ok(ColumnBounds::new( - self.min.evaluate()?, - self.max.evaluate()?, - )) + /// `ColumnFilterData` containing the bounds and bloom filter builder + fn evaluate(mut self) -> Result<ColumnFilterData> { + let bounds = ColumnBounds::new(self.min.evaluate()?, self.max.evaluate()?); + Ok(ColumnFilterData::new(bounds, self.bloom_filter)) } } @@ -1304,6 +1322,7 @@ impl BuildSideState { on_left: Vec<Arc<dyn PhysicalExpr>>, schema: &SchemaRef, should_compute_bounds: bool, + random_state: &RandomState, ) -> Result<Self> { Ok(Self { batches: Vec::new(), @@ -1315,7 +1334,11 @@ on_left .iter() .map(|expr| { - CollectLeftAccumulator::try_new(Arc::clone(expr), schema) + CollectLeftAccumulator::try_new( + Arc::clone(expr), + schema, + random_state.clone(), + ) }) .collect::<Result<Vec<_>>>() }) @@ -1374,6 +1397,7 @@ async fn collect_left_input( on_left.clone(), &schema, should_compute_bounds, + &random_state, )?; let state = left_stream @@ -1407,7 +1431,7 @@ async fn collect_left_input( num_rows, metrics, mut reservation, - bounds_accumulators, + mut bounds_accumulators, } = state; // Estimation of memory size, required for hashtable, prior to allocation.
@@ -1449,6 +1473,14 @@ async fn collect_left_input( 0, true, )?; + + // Populate bloom filters with computed hashes to avoid redundant hash computation + if let Some(ref mut accumulators) = bounds_accumulators { + for accumulator in accumulators.iter_mut() { + accumulator.bloom_filter.insert_hashes(&hashes_buffer); + } + } + offset += batch.num_rows(); } // Merge all batches into a single batch, so we can directly index into the arrays @@ -1475,14 +1507,14 @@ async fn collect_left_input( }) .collect::>>()?; - // Compute bounds for dynamic filter if enabled - let bounds = match bounds_accumulators { + // Compute filter data (bounds + bloom filters) for dynamic filter if enabled + let column_filters = match bounds_accumulators { Some(accumulators) if num_rows > 0 => { - let bounds = accumulators + let column_filters: Vec<_> = accumulators .into_iter() .map(CollectLeftAccumulator::evaluate) .collect::>>()?; - Some(bounds) + Some(column_filters) } _ => None, }; @@ -1494,7 +1526,7 @@ async fn collect_left_input( Mutex::new(visited_indices_bitmap), AtomicUsize::new(probe_threads_count), reservation, - bounds, + column_filters, ); Ok(data) diff --git a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs index 25f7a0de31acd..d8a590f16a595 100644 --- a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs +++ b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs @@ -27,10 +27,11 @@ use crate::ExecutionPlanProperties; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Operator; -use datafusion_physical_expr::expressions::{lit, BinaryExpr, DynamicFilterPhysicalExpr}; +use datafusion_physical_expr::expressions::{ + lit, BinaryExpr, BloomFilterBuilder, DynamicFilterPhysicalExpr, +}; use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; -use itertools::Itertools; use parking_lot::Mutex; use tokio::sync::Barrier; @@ -40,7 +41,7 @@ use tokio::sync::Barrier; 
pub(crate) struct ColumnBounds { /// The minimum value observed for this column min: ScalarValue, - /// The maximum value observed for this column + /// The maximum value observed for this column max: ScalarValue, } @@ -50,32 +51,43 @@ impl ColumnBounds { } } +/// Filter data for a single join key column, combining bounds and bloom filter. +/// Used in dynamic filter pushdown for comprehensive filtering. +#[derive(Debug)] +pub(crate) struct ColumnFilterData { + /// Min/max bounds for range filtering + pub(crate) bounds: ColumnBounds, + /// Bloom filter builder for membership testing + pub(crate) bloom_filter: BloomFilterBuilder, +} + +impl ColumnFilterData { + pub(crate) fn new(bounds: ColumnBounds, bloom_filter: BloomFilterBuilder) -> Self { + Self { + bounds, + bloom_filter, + } + } +} + /// Represents the bounds for all join key columns from a single partition. -/// This contains the min/max values computed from one partition's build-side data. -#[derive(Debug, Clone)] +/// This contains the filter data (min/max values and bloom filters) computed from one partition's build-side data. +#[derive(Debug)] pub(crate) struct PartitionBounds { /// Partition identifier for debugging and determinism (not strictly necessary) - partition: usize, - /// Min/max bounds for each join key column in this partition. + pub(crate) partition: usize, + /// Filter data (bounds + bloom filter) for each join key column in this partition. /// Index corresponds to the join key expression index. 
- column_bounds: Vec, + pub(crate) column_filters: Vec, } impl PartitionBounds { - pub(crate) fn new(partition: usize, column_bounds: Vec) -> Self { + pub(crate) fn new(partition: usize, column_filters: Vec) -> Self { Self { partition, - column_bounds, + column_filters, } } - - pub(crate) fn len(&self) -> usize { - self.column_bounds.len() - } - - pub(crate) fn get_column_bounds(&self, index: usize) -> Option<&ColumnBounds> { - self.column_bounds.get(index) - } } /// Coordinates dynamic filter bounds collection across multiple partitions @@ -102,8 +114,9 @@ impl PartitionBounds { /// All fields use a single mutex to ensure correct coordination between concurrent /// partition executions. pub(crate) struct SharedBoundsAccumulator { - /// Shared state protected by a single mutex to avoid ordering concerns - inner: Mutex, + /// Shared state protected by a single mutex to avoid ordering concerns. + /// After filter creation, this becomes None as the state is consumed. + inner: Mutex>, barrier: Barrier, /// Dynamic filter for pushdown to probe side dynamic_filter: Arc, @@ -166,68 +179,98 @@ impl SharedBoundsAccumulator { PartitionMode::Auto => unreachable!("PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!"), }; Self { - inner: Mutex::new(SharedBoundsState { + inner: Mutex::new(Some(SharedBoundsState { bounds: Vec::with_capacity(expected_calls), - }), + })), barrier: Barrier::new(expected_calls), dynamic_filter, on_right, } } - /// Create a filter expression from individual partition bounds using OR logic. + /// Create a filter expression from individual partition bounds and bloom filters using OR logic. /// - /// This creates a filter where each partition's bounds form a conjunction (AND) - /// of column range predicates, and all partitions are combined with OR. 
+ /// This creates a filter where each partition's bounds and bloom filters form a conjunction (AND) + /// of column range predicates and bloom filter checks, and all partitions are combined with OR. /// /// For example, with 2 partitions and 2 columns: - /// ((col0 >= p0_min0 AND col0 <= p0_max0 AND col1 >= p0_min1 AND col1 <= p0_max1) + /// ((col0 >= p0_min0 AND col0 <= p0_max0 AND col0 IN BLOOM_FILTER_0 AND col1 >= p0_min1 AND col1 <= p0_max1 AND col1 IN BLOOM_FILTER_1) /// OR - /// (col0 >= p1_min0 AND col0 <= p1_max0 AND col1 >= p1_min1 AND col1 <= p1_max1)) + /// (col0 >= p1_min0 AND col0 <= p1_max0 AND col0 IN BLOOM_FILTER_0 AND col1 >= p1_min1 AND col1 <= p1_max1 AND col1 IN BLOOM_FILTER_1)) + /// + /// This method consumes the bounds to allow moving bloom filter builders. pub(crate) fn create_filter_from_partition_bounds( &self, - bounds: &[PartitionBounds], + mut bounds: Vec, ) -> Result> { if bounds.is_empty() { return Ok(lit(true)); } + // Sort by partition for determinism + bounds.sort_by_key(|b| b.partition); + // Create a predicate for each partition let mut partition_predicates = Vec::with_capacity(bounds.len()); - for partition_bounds in bounds.iter().sorted_by_key(|b| b.partition) { + for partition_bounds in bounds.into_iter() { // Create range predicates for each join key in this partition - let mut column_predicates = Vec::with_capacity(partition_bounds.len()); - - for (col_idx, right_expr) in self.on_right.iter().enumerate() { - if let Some(column_bounds) = partition_bounds.get_column_bounds(col_idx) { - // Create predicate: col >= min AND col <= max - let min_expr = Arc::new(BinaryExpr::new( - Arc::clone(right_expr), - Operator::GtEq, - lit(column_bounds.min.clone()), - )) as Arc; - let max_expr = Arc::new(BinaryExpr::new( - Arc::clone(right_expr), - Operator::LtEq, - lit(column_bounds.max.clone()), - )) as Arc; - let range_expr = - Arc::new(BinaryExpr::new(min_expr, Operator::And, max_expr)) - as Arc; - column_predicates.push(range_expr); + 
let mut range_predicates = + Vec::with_capacity(partition_bounds.column_filters.len()); + + // Get the first bloom filter (they all have the same data - combined hash of all columns) + let mut bloom_filter_builder = None; + + // Consume column_filters by taking ownership + for (col_idx, filter_data) in + partition_bounds.column_filters.into_iter().enumerate() + { + let right_expr = &self.on_right[col_idx]; + + // Create predicate: col >= min AND col <= max + let min_expr = Arc::new(BinaryExpr::new( + Arc::clone(right_expr), + Operator::GtEq, + lit(filter_data.bounds.min.clone()), + )) as Arc; + let max_expr = Arc::new(BinaryExpr::new( + Arc::clone(right_expr), + Operator::LtEq, + lit(filter_data.bounds.max.clone()), + )) as Arc; + let range_expr = + Arc::new(BinaryExpr::new(min_expr, Operator::And, max_expr)) + as Arc; + + range_predicates.push(range_expr); + + // Save the first bloom filter (all bloom filters have identical data) + if bloom_filter_builder.is_none() { + bloom_filter_builder = Some(filter_data.bloom_filter); } } - // Combine all column predicates for this partition with AND - if !column_predicates.is_empty() { - let partition_predicate = column_predicates + // Create a single bloom filter check for all columns combined + if let Some(builder) = bloom_filter_builder { + // Build bloom filter with all on_right expressions (matching how build side populated it) + let bloom_expr = Arc::new(builder.build(self.on_right.clone())) + as Arc; + + // Combine all range predicates with bloom filter + // First combine all range predicates with AND + let combined_ranges = range_predicates .into_iter() .reduce(|acc, pred| { Arc::new(BinaryExpr::new(acc, Operator::And, pred)) as Arc }) .unwrap(); + + // Then AND with the bloom filter + let partition_predicate = + Arc::new(BinaryExpr::new(combined_ranges, Operator::And, bloom_expr)) + as Arc; + partition_predicates.push(partition_predicate); } } @@ -247,8 +290,8 @@ impl SharedBoundsAccumulator { /// Report bounds from 
a completed partition and update dynamic filter if all partitions are done /// /// This method coordinates the dynamic filter updates across all partitions. It stores the - /// bounds from the current partition, increments the completion counter, and when all - /// partitions have reported, creates an OR'd filter from individual partition bounds. + /// filter data from the current partition, increments the completion counter, and when all + /// partitions have reported, creates an OR'd filter from individual partition filter data. /// /// This method is async and uses a [`tokio::sync::Barrier`] to wait for all partitions /// to report their bounds. Once that occurs, the method will resolve for all callers and the @@ -263,20 +306,23 @@ impl SharedBoundsAccumulator { /// /// # Arguments /// * `left_side_partition_id` - The identifier for the **left-side** partition reporting its bounds - /// * `partition_bounds` - The bounds computed by this partition (if any) + /// * `column_filters` - The filter data computed by this partition (if any) /// /// # Returns /// * `Result<()>` - Ok if successful, Err if filter update failed pub(crate) async fn report_partition_bounds( &self, left_side_partition_id: usize, - partition_bounds: Option>, + column_filters: Option>, ) -> Result<()> { - // Store bounds in the accumulator - this runs once per partition - if let Some(bounds) = partition_bounds { + // Store filter data in the accumulator - this runs once per partition + if let Some(filters) = column_filters { let mut guard = self.inner.lock(); + let state = guard + .as_mut() + .expect("SharedBoundsState should exist during partition reporting"); - let should_push = if let Some(last_bound) = guard.bounds.last() { + let should_push = if let Some(last_bound) = state.bounds.last() { // In `PartitionMode::CollectLeft`, all streams on the left side share the same partition id (0). 
// Since this function can be called multiple times for that same partition, we must deduplicate // by checking against the last recorded bound. @@ -286,18 +332,24 @@ impl SharedBoundsAccumulator { }; if should_push { - guard + state .bounds - .push(PartitionBounds::new(left_side_partition_id, bounds)); + .push(PartitionBounds::new(left_side_partition_id, filters)); } } if self.barrier.wait().await.is_leader() { // All partitions have reported, so we can update the filter - let inner = self.inner.lock(); - if !inner.bounds.is_empty() { + // Take ownership of the state (consuming it) + let state = self + .inner + .lock() + .take() + .expect("SharedBoundsState should exist when creating filter"); + + if !state.bounds.is_empty() { let filter_expr = - self.create_filter_from_partition_bounds(&inner.bounds)?; + self.create_filter_from_partition_bounds(state.bounds)?; self.dynamic_filter.update(filter_expr)?; } } diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index 88c50c2eb2cee..8fed06d0ffd26 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -419,10 +419,11 @@ impl HashJoinStream { PartitionMode::Auto => unreachable!("PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!"), }; - let left_data_bounds = left_data.bounds.clone(); + // Take ownership of column_filters to avoid cloning + let column_filters = left_data.column_filters.lock().take(); self.bounds_waiter = Some(OnceFut::new(async move { bounds_accumulator - .report_partition_bounds(left_side_partition_id, left_data_bounds) + .report_partition_bounds(left_side_partition_id, column_filters) .await })); self.state = HashJoinStreamState::WaitPartitionBoundsReport;