diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 9102f54cc7474..2358a21940912 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -579,6 +579,25 @@ pub fn fmt_sql(expr: &dyn PhysicalExpr) -> impl Display + '_ { pub fn snapshot_physical_expr( expr: Arc, ) -> Result> { + snapshot_physical_expr_opt(expr).data() +} + +/// Take a snapshot of the given `PhysicalExpr` if it is dynamic. +/// +/// Take a snapshot of this `PhysicalExpr` if it is dynamic. +/// This is used to capture the current state of `PhysicalExpr`s that may contain +/// dynamic references to other operators in order to serialize it over the wire +/// or treat it via downcast matching. +/// +/// See the documentation of [`PhysicalExpr::snapshot`] for more details. +/// +/// # Returns +/// +/// Returns a `[`Transformed`] indicating whether a snapshot was taken, +/// along with the resulting `PhysicalExpr`. +pub fn snapshot_physical_expr_opt( + expr: Arc, +) -> Result>> { expr.transform_up(|e| { if let Some(snapshot) = e.snapshot()? { Ok(Transformed::yes(snapshot)) @@ -586,7 +605,6 @@ pub fn snapshot_physical_expr( Ok(Transformed::no(Arc::clone(&e))) } }) - .data() } /// Check the generation of this `PhysicalExpr`. diff --git a/datafusion/pruning/src/pruning_predicate.rs b/datafusion/pruning/src/pruning_predicate.rs index 4110391514dfb..2de8116cfeaee 100644 --- a/datafusion/pruning/src/pruning_predicate.rs +++ b/datafusion/pruning/src/pruning_predicate.rs @@ -35,7 +35,7 @@ use datafusion_physical_plan::metrics::Count; use log::{debug, trace}; use datafusion_common::error::Result; -use datafusion_common::tree_node::TransformedResult; +use datafusion_common::tree_node::{TransformedResult, TreeNodeRecursion}; use datafusion_common::{assert_eq_or_internal_err, Column, DFSchema}; use datafusion_common::{ internal_datafusion_err, plan_datafusion_err, plan_err, @@ -44,9 +44,9 @@ use datafusion_common::{ }; use datafusion_expr_common::operator::Operator; use datafusion_physical_expr::expressions::CastColumnExpr; -use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarantee}; +use datafusion_physical_expr::utils::{Guarantee, LiteralGuarantee}; use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; -use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr; +use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr_opt; use datafusion_physical_plan::{ColumnarValue, PhysicalExpr}; /// Used to prove that arbitrary predicates (boolean expression) can not @@ -456,10 +456,29 @@ impl PruningPredicate { /// /// See the struct level documentation on [`PruningPredicate`] for more /// details. - pub fn try_new(expr: Arc, schema: SchemaRef) -> Result { - // Get a (simpler) snapshot of the physical expr here to use with `PruningPredicate` - // which does not handle dynamic exprs in general - let expr = snapshot_physical_expr(expr)?; + /// + /// Note that `PruningPredicate` does not attempt to normalize or simplify + /// the input expression unless calling [`snapshot_physical_expr_opt`] + /// returns a new expression. + /// It is recommended that you pass the expressions through [`PhysicalExprSimplifier`] + /// before calling this method to make sure the expressions can be used for pruning. + pub fn try_new(mut expr: Arc, schema: SchemaRef) -> Result { + // Get a (simpler) snapshot of the physical expr here to use with `PruningPredicate`. + // In particular this unravels any `DynamicFilterPhysicalExpr`s by snapshotting them + // so that PruningPredicate can work with a static expression. + let tf = snapshot_physical_expr_opt(expr)?; + if tf.transformed { + // If we had an expression such as Dynamic(part_col < 5 and col < 10) + // (this could come from something like `select * from t order by part_col, col, limit 10`) + // after snapshotting and because `DynamicFilterPhysicalExpr` applies child replacements to its + // children after snapshotting and previously `replace_columns_with_literals` may have been called with partition values + // the expression we have now is `8 < 5 and col < 10`. + // Thus we need as simplifier pass to get `false and col < 10` => `false` here. + let simplifier = PhysicalExprSimplifier::new(&schema); + expr = simplifier.simplify(tf.data)?; + } else { + expr = tf.data; + } let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::default()) as _; // build predicate expression once @@ -960,24 +979,41 @@ impl<'a> PruningExpressionBuilder<'a> { fn try_new( left: &'a Arc, right: &'a Arc, + left_columns: ColumnReferenceCount, + right_columns: ColumnReferenceCount, op: Operator, schema: &'a SchemaRef, required_columns: &'a mut RequiredColumns, ) -> Result { // find column name; input could be a more complicated expression - let left_columns = collect_columns(left); - let right_columns = collect_columns(right); - let (column_expr, scalar_expr, columns, correct_operator) = - match (left_columns.len(), right_columns.len()) { - (1, 0) => (left, right, left_columns, op), - (0, 1) => (right, left, right_columns, reverse_operator(op)?), - _ => { - // if more than one column used in expression - not supported - return plan_err!( - "Multi-column expressions are not currently supported" + let (column_expr, scalar_expr, column, correct_operator) = match ( + left_columns, + right_columns, + ) { + (ColumnReferenceCount::One(column), ColumnReferenceCount::Zero) => { + (left, right, column, op) + } + (ColumnReferenceCount::Zero, ColumnReferenceCount::One(column)) => { + (right, left, column, reverse_operator(op)?) + } + (ColumnReferenceCount::One(_), ColumnReferenceCount::One(_)) => { + // both sides have one column - not supported + return plan_err!( + "Expression not supported for pruning: left has 1 column, right has 1 column" ); - } - }; + } + (ColumnReferenceCount::Zero, ColumnReferenceCount::Zero) => { + // both sides are literals - should be handled before calling try_new + return plan_err!( + "Pruning literal expressions is not supported, please call PhysicalExprSimplifier first" + ); + } + (ColumnReferenceCount::Many, _) | (_, ColumnReferenceCount::Many) => { + return plan_err!( + "Expression not supported for pruning: left or right has multiple columns" + ); + } + }; let df_schema = DFSchema::try_from(Arc::clone(schema))?; let (column_expr, correct_operator, scalar_expr) = rewrite_expr_to_prunable( @@ -986,7 +1022,6 @@ impl<'a> PruningExpressionBuilder<'a> { scalar_expr, df_schema, )?; - let column = columns.iter().next().unwrap().clone(); let field = match schema.column_with_name(column.name()) { Some((_, f)) => f, _ => { @@ -1529,8 +1564,17 @@ fn build_predicate_expression( return expr; } - let expr_builder = - PruningExpressionBuilder::try_new(&left, &right, op, schema, required_columns); + let left_columns = ColumnReferenceCount::from_expression(&left); + let right_columns = ColumnReferenceCount::from_expression(&right); + let expr_builder = PruningExpressionBuilder::try_new( + &left, + &right, + left_columns, + right_columns, + op, + schema, + required_columns, + ); let mut expr_builder = match expr_builder { Ok(builder) => builder, // allow partial failure in predicate expression generation @@ -1545,6 +1589,50 @@ fn build_predicate_expression( .unwrap_or_else(|_| unhandled_hook.handle(expr)) } +/// Count of distinct column references in an expression. +/// This is the same as [`collect_columns`] but optimized to stop counting +/// once more than one distinct column is found. +/// +/// For example, in expression `col1 + col2`, the count is `Many`. +/// In expression `col1 + 5`, the count is `One`. +/// In expression `5 + 10`, the count is `Zero`. +/// +/// [`collect_columns`]: datafusion_physical_expr::utils::collect_columns +#[derive(Debug, PartialEq, Eq)] +enum ColumnReferenceCount { + /// no column references + Zero, + /// Only one column reference + One(phys_expr::Column), + /// More than one column reference + Many, +} + +impl ColumnReferenceCount { + /// Count the number of distinct column references in an expression + fn from_expression(expr: &Arc) -> Self { + let mut seen = HashSet::::new(); + expr.apply(|expr| { + if let Some(column) = expr.as_any().downcast_ref::() { + seen.insert(column.clone()); + if seen.len() > 1 { + return Ok(TreeNodeRecursion::Stop); + } + } + Ok(TreeNodeRecursion::Continue) + }) + // pre_visit always returns OK, so this will always too + .expect("no way to return error during recursion"); + match seen.len() { + 0 => ColumnReferenceCount::Zero, + 1 => ColumnReferenceCount::One( + seen.into_iter().next().expect("just checked len==1"), + ), + _ => ColumnReferenceCount::Many, + } + } +} + fn build_statistics_expr( expr_builder: &mut PruningExpressionBuilder, ) -> Result> { @@ -1884,6 +1972,7 @@ mod tests { use super::*; use datafusion_common::test_util::batches_to_string; use datafusion_expr::{and, col, lit, or}; + use datafusion_physical_expr::utils::collect_columns; use insta::assert_snapshot; use arrow::array::Decimal128Array; @@ -1894,8 +1983,11 @@ mod tests { use datafusion_expr::expr::InList; use datafusion_expr::{cast, is_null, try_cast, Expr}; use datafusion_functions_nested::expr_fn::{array_has, make_array}; - use datafusion_physical_expr::expressions as phys_expr; + use datafusion_physical_expr::expressions::{ + self as phys_expr, DynamicFilterPhysicalExpr, + }; use datafusion_physical_expr::planner::logical2physical; + use itertools::Itertools; #[derive(Debug, Default)] /// Mock statistic provider for tests @@ -2774,6 +2866,164 @@ mod tests { Ok(()) } + /// Test that non-boolean literal expressions don't prune any containers and error gracefully by not pruning anything instead of e.g. panicking + #[test] + fn row_group_predicate_non_boolean() { + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); + let statistics = TestStatistics::new() + .with("c1", ContainerStats::new_i32(vec![Some(0)], vec![Some(10)])); + let expected_ret = &[true]; + prune_with_expr(lit(1), &schema, &statistics, expected_ret); + } + + // Test that literal-to-literal comparisons are correctly evaluated. + // When both sides are constants, the expression should be evaluated directly + // and if it's false, all containers should be pruned. + #[test] + fn row_group_predicate_literal_false() { + // lit(1) = lit(2) is always false, so all containers should be pruned + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); + let statistics = TestStatistics::new() + .with("c1", ContainerStats::new_i32(vec![Some(0)], vec![Some(10)])); + let expected_ret = &[false]; + prune_with_simplified_expr(lit(1).eq(lit(2)), &schema, &statistics, expected_ret); + } + + /// Test nested/complex literal expression trees. + /// This is an integration test that PhysicalExprSimplifier + PruningPredicate work together as expected. + #[test] + fn row_group_predicate_literal_true() { + // lit(1) = lit(1) is always true, so no containers should be pruned + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); + let statistics = TestStatistics::new() + .with("c1", ContainerStats::new_i32(vec![Some(0)], vec![Some(10)])); + let expected_ret = &[true]; + prune_with_simplified_expr(lit(1).eq(lit(1)), &schema, &statistics, expected_ret); + } + + /// Test nested/complex literal expression trees. + /// This is an integration test that PhysicalExprSimplifier + PruningPredicate work together as expected. + #[test] + fn row_group_predicate_literal_null() { + // lit(1) = null is always null, so no containers should be pruned + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); + let statistics = TestStatistics::new() + .with("c1", ContainerStats::new_i32(vec![Some(0)], vec![Some(10)])); + let expected_ret = &[true]; + prune_with_simplified_expr( + lit(1).eq(lit(ScalarValue::Null)), + &schema, + &statistics, + expected_ret, + ); + } + + /// Test nested/complex literal expression trees. + /// This is an integration test that PhysicalExprSimplifier + PruningPredicate work together as expected. + #[test] + fn row_group_predicate_complex_literals() { + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); + let statistics = TestStatistics::new() + .with("c1", ContainerStats::new_i32(vec![Some(0)], vec![Some(10)])); + + // (1 + 2) > 0 is always true + prune_with_simplified_expr( + (lit(1) + lit(2)).gt(lit(0)), + &schema, + &statistics, + &[true], + ); + + // (1 + 2) < 0 is always false + prune_with_simplified_expr( + (lit(1) + lit(2)).lt(lit(0)), + &schema, + &statistics, + &[false], + ); + + // Nested AND of literals: true AND false = false + prune_with_simplified_expr( + lit(true).and(lit(false)), + &schema, + &statistics, + &[false], + ); + + // Nested OR of literals: true OR false = true + prune_with_simplified_expr( + lit(true).or(lit(false)), + &schema, + &statistics, + &[true], + ); + + // Complex nested: (1 < 2) AND (3 > 1) = true AND true = true + prune_with_simplified_expr( + lit(1).lt(lit(2)).and(lit(3).gt(lit(1))), + &schema, + &statistics, + &[true], + ); + + // Complex nested: (1 > 2) OR (3 < 1) = false OR false = false + prune_with_simplified_expr( + lit(1).gt(lit(2)).or(lit(3).lt(lit(1))), + &schema, + &statistics, + &[false], + ); + } + + /// Integration test demonstrating that a dynamic filter with replaced children as literals will be snapshotted, simplified and then pruned correctly. + #[test] + fn row_group_predicate_dynamic_filter_with_literals() { + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("part", DataType::Utf8, true), + ])); + let statistics = TestStatistics::new() + // Note that we have no stats, pruning can only happen via partition value pruning from the dynamic filter + .with_row_counts("c1", vec![Some(10)]); + let dynamic_filter_expr = col("c1").gt(lit(5)).and(col("part").eq(lit("B"))); + let phys_expr = logical2physical(&dynamic_filter_expr, &schema); + let children = collect_columns(&phys_expr) + .iter() + .map(|c| Arc::new(c.clone()) as Arc) + .collect_vec(); + let dynamic_phys_expr = + Arc::new(DynamicFilterPhysicalExpr::new(children, phys_expr)) + as Arc; + // Simulate the partition value substitution that would happen in ParquetOpener + let remapped_expr = dynamic_phys_expr + .children() + .into_iter() + .map(|child_expr| { + let Some(col_expr) = + child_expr.as_any().downcast_ref::() + else { + return Arc::clone(child_expr); + }; + if col_expr.name() == "part" { + // simulate dynamic filter replacement with literal "A" + Arc::new(phys_expr::Literal::new(ScalarValue::Utf8(Some( + "A".to_string(), + )))) as Arc + } else { + Arc::clone(child_expr) + } + }) + .collect_vec(); + let dynamic_filter_expr = + dynamic_phys_expr.with_new_children(remapped_expr).unwrap(); + // After substitution the expression is c1 > 5 AND part = "B" which should prune the file since the partition value is "A" + let expected = &[false]; + let p = + PruningPredicate::try_new(dynamic_filter_expr, Arc::clone(&schema)).unwrap(); + let result = p.prune(&statistics).unwrap(); + assert_eq!(result, expected); + } + #[test] fn row_group_predicate_lt_bool() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Boolean, false)]); @@ -5137,6 +5387,21 @@ mod tests { assert_eq!(result, expected); } + fn prune_with_simplified_expr( + expr: Expr, + schema: &SchemaRef, + statistics: &TestStatistics, + expected: &[bool], + ) { + println!("Pruning with expr: {expr}"); + let expr = logical2physical(&expr, schema); + let simplifier = PhysicalExprSimplifier::new(schema); + let expr = simplifier.simplify(expr).unwrap(); + let p = PruningPredicate::try_new(expr, Arc::::clone(schema)).unwrap(); + let result = p.prune(statistics).unwrap(); + assert_eq!(result, expected); + } + fn test_build_predicate_expression( expr: &Expr, schema: &Schema,