diff --git a/datafusion/physical-expr/src/equivalence.rs b/datafusion/physical-expr/src/equivalence.rs index d8aa09b904605..84291653fb4f9 100644 --- a/datafusion/physical-expr/src/equivalence.rs +++ b/datafusion/physical-expr/src/equivalence.rs @@ -20,11 +20,11 @@ use std::hash::Hash; use std::sync::Arc; use crate::expressions::Column; -use crate::physical_expr::{deduplicate_physical_exprs, have_common_entries}; use crate::sort_properties::{ExprOrdering, SortProperties}; use crate::{ - physical_exprs_contains, LexOrdering, LexOrderingRef, LexRequirement, - LexRequirementRef, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, + physical_exprs_bag_equal, physical_exprs_contains, LexOrdering, LexOrderingRef, + LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalSortExpr, + PhysicalSortRequirement, }; use arrow::datatypes::SchemaRef; @@ -32,14 +32,110 @@ use arrow_schema::SortOptions; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{JoinSide, JoinType, Result}; +use crate::physical_expr::deduplicate_physical_exprs; use indexmap::map::Entry; use indexmap::IndexMap; /// An `EquivalenceClass` is a set of [`Arc`]s that are known /// to have the same value for all tuples in a relation. These are generated by -/// equality predicates, typically equi-join conditions and equality conditions -/// in filters. -pub type EquivalenceClass = Vec>; +/// equality predicates (e.g. `a = b`), typically equi-join conditions and +/// equality conditions in filters. +/// +/// Two `EquivalenceClass`es are equal if they contains the same expressions in +/// without any ordering. +#[derive(Debug, Clone)] +pub struct EquivalenceClass { + /// The expressions in this equivalence class. The order doesn't + /// matter for equivalence purposes + /// + /// TODO: use a HashSet for this instead of a Vec + exprs: Vec>, +} + +impl PartialEq for EquivalenceClass { + /// Returns true if other is equal in the sense + /// of bags (multi-sets), disregarding their orderings. + fn eq(&self, other: &Self) -> bool { + physical_exprs_bag_equal(&self.exprs, &other.exprs) + } +} + +impl EquivalenceClass { + /// Create a new empty equivalence class + pub fn new_empty() -> Self { + Self { exprs: vec![] } + } + + // Create a new equivalence class from a pre-existing `Vec` + pub fn new(mut exprs: Vec>) -> Self { + deduplicate_physical_exprs(&mut exprs); + Self { exprs } + } + + /// Return the inner vector of expressions + pub fn into_vec(self) -> Vec> { + self.exprs + } + + /// Return the "canonical" expression for this class (the first element) + /// if any + fn canonical_expr(&self) -> Option> { + self.exprs.first().cloned() + } + + /// Insert the expression into this class, meaning it is known to be equal to + /// all other expressions in this class + pub fn push(&mut self, expr: Arc) { + if !self.contains(&expr) { + self.exprs.push(expr); + } + } + + /// Inserts all the expressions from other into this class + pub fn extend(&mut self, other: Self) { + for expr in other.exprs { + // use push so entries are deduplicated + self.push(expr); + } + } + + /// Returns true if this equivalence class contains t expression + pub fn contains(&self, expr: &Arc) -> bool { + physical_exprs_contains(&self.exprs, expr) + } + + /// Returns true if this equivalence class has any entries in common with `other` + pub fn contains_any(&self, other: &Self) -> bool { + self.exprs.iter().any(|e| other.contains(e)) + } + + /// return the number of items in this class + pub fn len(&self) -> usize { + self.exprs.len() + } + + /// return true if this class is empty + pub fn is_empty(&self) -> bool { + self.exprs.is_empty() + } + + /// Iterate over all elements in this class, in some arbitrary order + pub fn iter(&self) -> impl Iterator> { + self.exprs.iter() + } + + /// Return a new equivalence class that have the specified offset added to + /// each expression (used when schemas are appended such as in joins) + pub fn with_offset(&self, offset: usize) -> Self { + let new_exprs = self + .exprs + .iter() + .cloned() + .map(|e| add_offset_to_expr(e, offset)) + .collect(); + Self::new(new_exprs) + } +} /// Stores the mapping between source expressions and target expressions for a /// projection. @@ -148,10 +244,10 @@ impl EquivalenceGroup { let mut first_class = None; let mut second_class = None; for (idx, cls) in self.classes.iter().enumerate() { - if physical_exprs_contains(cls, left) { + if cls.contains(left) { first_class = Some(idx); } - if physical_exprs_contains(cls, right) { + if cls.contains(right) { second_class = Some(idx); } } @@ -181,7 +277,8 @@ impl EquivalenceGroup { (None, None) => { // None of the expressions is among existing classes. // Create a new equivalence class and extend the group. - self.classes.push(vec![left.clone(), right.clone()]); + self.classes + .push(EquivalenceClass::new(vec![left.clone(), right.clone()])); } } } @@ -192,7 +289,6 @@ impl EquivalenceGroup { self.classes.retain_mut(|cls| { // Keep groups that have at least two entries as singleton class is // meaningless (i.e. it contains no non-trivial information): - deduplicate_physical_exprs(cls); cls.len() > 1 }); // Unify/bridge groups that have common expressions: @@ -209,7 +305,7 @@ impl EquivalenceGroup { let mut next_idx = idx + 1; let start_size = self.classes[idx].len(); while next_idx < self.classes.len() { - if have_common_entries(&self.classes[idx], &self.classes[next_idx]) { + if self.classes[idx].contains_any(&self.classes[next_idx]) { let extension = self.classes.swap_remove(next_idx); self.classes[idx].extend(extension); } else { @@ -217,10 +313,7 @@ impl EquivalenceGroup { } } if self.classes[idx].len() > start_size { - deduplicate_physical_exprs(&mut self.classes[idx]); - if self.classes[idx].len() > start_size { - continue; - } + continue; } idx += 1; } @@ -239,8 +332,8 @@ impl EquivalenceGroup { expr.clone() .transform(&|expr| { for cls in self.iter() { - if physical_exprs_contains(cls, &expr) { - return Ok(Transformed::Yes(cls[0].clone())); + if cls.contains(&expr) { + return Ok(Transformed::Yes(cls.canonical_expr().unwrap())); } } Ok(Transformed::No(expr)) @@ -330,7 +423,7 @@ impl EquivalenceGroup { if source.eq(expr) || self .get_equivalence_class(source) - .map_or(false, |group| physical_exprs_contains(group, expr)) + .map_or(false, |group| group.contains(expr)) { return Some(target.clone()); } @@ -380,7 +473,7 @@ impl EquivalenceGroup { .iter() .filter_map(|expr| self.project_expr(mapping, expr)) .collect::>(); - (new_class.len() > 1).then_some(new_class) + (new_class.len() > 1).then_some(EquivalenceClass::new(new_class)) }); // TODO: Convert the algorithm below to a version that uses `HashMap`. // once `Arc` can be stored in `HashMap`. @@ -402,7 +495,9 @@ impl EquivalenceGroup { // equivalence classes are meaningless. let new_classes = new_classes .into_iter() - .filter_map(|(_, values)| (values.len() > 1).then_some(values)); + .filter_map(|(_, values)| (values.len() > 1).then_some(values)) + .map(EquivalenceClass::new); + let classes = projected_classes.chain(new_classes).collect(); Self::new(classes) } @@ -412,10 +507,8 @@ impl EquivalenceGroup { fn get_equivalence_class( &self, expr: &Arc, - ) -> Option<&[Arc]> { - self.iter() - .map(|cls| cls.as_slice()) - .find(|cls| physical_exprs_contains(cls, expr)) + ) -> Option<&EquivalenceClass> { + self.iter().find(|cls| cls.contains(expr)) } /// Combine equivalence groups of the given join children. @@ -431,12 +524,11 @@ impl EquivalenceGroup { let mut result = Self::new( self.iter() .cloned() - .chain(right_equivalences.iter().map(|item| { - item.iter() - .cloned() - .map(|expr| add_offset_to_expr(expr, left_size)) - .collect() - })) + .chain( + right_equivalences + .iter() + .map(|cls| cls.with_offset(left_size)), + ) .collect(), ); // In we have an inner join, expressions in the "on" condition @@ -1246,14 +1338,13 @@ mod tests { use std::sync::Arc; use super::*; - use crate::expressions::{col, lit, BinaryExpr, Column}; - use crate::physical_expr::{physical_exprs_bag_equal, physical_exprs_equal}; + use crate::expressions::{col, lit, BinaryExpr, Column, Literal}; use arrow::compute::{lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field, Schema}; use arrow_array::{ArrayRef, RecordBatch, UInt32Array, UInt64Array}; use arrow_schema::{Fields, SortOptions}; - use datafusion_common::Result; + use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Operator; use itertools::{izip, Itertools}; @@ -1440,8 +1531,8 @@ mod tests { assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = &eq_properties.eq_group().classes[0]; assert_eq!(eq_groups.len(), 2); - assert!(physical_exprs_contains(eq_groups, &col_a_expr)); - assert!(physical_exprs_contains(eq_groups, &col_b_expr)); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); // b and c are aliases. Exising equivalence class should expand, // however there shouldn't be any new equivalence class @@ -1449,9 +1540,9 @@ mod tests { assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = &eq_properties.eq_group().classes[0]; assert_eq!(eq_groups.len(), 3); - assert!(physical_exprs_contains(eq_groups, &col_a_expr)); - assert!(physical_exprs_contains(eq_groups, &col_b_expr)); - assert!(physical_exprs_contains(eq_groups, &col_c_expr)); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); // This is a new set of equality. Hence equivalent class count should be 2. eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr); @@ -1463,11 +1554,11 @@ mod tests { assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = &eq_properties.eq_group().classes[0]; assert_eq!(eq_groups.len(), 5); - assert!(physical_exprs_contains(eq_groups, &col_a_expr)); - assert!(physical_exprs_contains(eq_groups, &col_b_expr)); - assert!(physical_exprs_contains(eq_groups, &col_c_expr)); - assert!(physical_exprs_contains(eq_groups, &col_x_expr)); - assert!(physical_exprs_contains(eq_groups, &col_y_expr)); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); + assert!(eq_groups.contains(&col_x_expr)); + assert!(eq_groups.contains(&col_y_expr)); Ok(()) } @@ -1509,10 +1600,10 @@ mod tests { assert_eq!(out_properties.eq_group().len(), 1); let eq_class = &out_properties.eq_group().classes[0]; assert_eq!(eq_class.len(), 4); - assert!(physical_exprs_contains(eq_class, col_a1)); - assert!(physical_exprs_contains(eq_class, col_a2)); - assert!(physical_exprs_contains(eq_class, col_a3)); - assert!(physical_exprs_contains(eq_class, col_a4)); + assert!(eq_class.contains(col_a1)); + assert!(eq_class.contains(col_a2)); + assert!(eq_class.contains(col_a3)); + assert!(eq_class.contains(col_a4)); Ok(()) } @@ -1852,10 +1943,12 @@ mod tests { let entries = entries .into_iter() .map(|entry| entry.into_iter().map(lit).collect::>()) + .map(EquivalenceClass::new) .collect::>(); let expected = expected .into_iter() .map(|entry| entry.into_iter().map(lit).collect::>()) + .map(EquivalenceClass::new) .collect::>(); let mut eq_groups = EquivalenceGroup::new(entries.clone()); eq_groups.bridge_classes(); @@ -1866,11 +1959,7 @@ mod tests { ); assert_eq!(eq_groups.len(), expected.len(), "{}", err_msg); for idx in 0..eq_groups.len() { - assert!( - physical_exprs_bag_equal(&eq_groups[idx], &expected[idx]), - "{}", - err_msg - ); + assert_eq!(&eq_groups[idx], &expected[idx], "{}", err_msg); } } Ok(()) @@ -1879,14 +1968,17 @@ mod tests { #[test] fn test_remove_redundant_entries_eq_group() -> Result<()> { let entries = vec![ - vec![lit(1), lit(1), lit(2)], + EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]), // This group is meaningless should be removed - vec![lit(3), lit(3)], - vec![lit(4), lit(5), lit(6)], + EquivalenceClass::new(vec![lit(3), lit(3)]), + EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), ]; // Given equivalences classes are not in succinct form. // Expected form is the most plain representation that is functionally same. - let expected = vec![vec![lit(1), lit(2)], vec![lit(4), lit(5), lit(6)]]; + let expected = vec![ + EquivalenceClass::new(vec![lit(1), lit(2)]), + EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), + ]; let mut eq_groups = EquivalenceGroup::new(entries); eq_groups.remove_redundant_entries(); @@ -1894,8 +1986,8 @@ mod tests { assert_eq!(eq_groups.len(), expected.len()); assert_eq!(eq_groups.len(), 2); - assert!(physical_exprs_equal(&eq_groups[0], &expected[0])); - assert!(physical_exprs_equal(&eq_groups[1], &expected[1])); + assert_eq!(eq_groups[0], expected[0]); + assert_eq!(eq_groups[1], expected[1]); Ok(()) } @@ -2151,7 +2243,7 @@ mod tests { // expressions in the equivalence classes. For other expressions in the same // equivalence class use same result. This util gets already calculated result, when available. fn get_representative_arr( - eq_group: &[Arc], + eq_group: &EquivalenceClass, existing_vec: &[Option], schema: SchemaRef, ) -> Option { @@ -2224,7 +2316,7 @@ mod tests { get_representative_arr(eq_group, &schema_vec, schema.clone()) .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); - for expr in eq_group { + for expr in eq_group.iter() { let col = expr.as_any().downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); schema_vec[idx] = Some(representative_array.clone()); @@ -2626,6 +2718,29 @@ mod tests { Ok(()) } + #[test] + fn test_contains_any() { + let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) + as Arc; + let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) + as Arc; + let lit2 = + Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; + let lit1 = + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; + let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; + + let cls1 = EquivalenceClass::new(vec![lit_true.clone(), lit_false.clone()]); + let cls2 = EquivalenceClass::new(vec![lit_true.clone(), col_b_expr.clone()]); + let cls3 = EquivalenceClass::new(vec![lit2.clone(), lit1.clone()]); + + // lit_true is common + assert!(cls1.contains_any(&cls2)); + // there is no common entry + assert!(!cls1.contains_any(&cls3)); + assert!(!cls2.contains_any(&cls3)); + } + #[test] fn test_get_indices_of_matching_sort_exprs_with_order_eq() -> Result<()> { let sort_options = SortOptions::default(); diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index 79cbe6828b64b..455ca84a792f5 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -228,14 +228,6 @@ pub fn physical_exprs_contains( .any(|physical_expr| physical_expr.eq(expr)) } -/// Checks whether the given slices have any common entries. -pub fn have_common_entries( - lhs: &[Arc], - rhs: &[Arc], -) -> bool { - lhs.iter().any(|expr| physical_exprs_contains(rhs, expr)) -} - /// Checks whether the given physical expression slices are equal. pub fn physical_exprs_equal( lhs: &[Arc], @@ -293,8 +285,8 @@ mod tests { use crate::expressions::{Column, Literal}; use crate::physical_expr::{ - deduplicate_physical_exprs, have_common_entries, physical_exprs_bag_equal, - physical_exprs_contains, physical_exprs_equal, PhysicalExpr, + deduplicate_physical_exprs, physical_exprs_bag_equal, physical_exprs_contains, + physical_exprs_equal, PhysicalExpr, }; use datafusion_common::ScalarValue; @@ -334,29 +326,6 @@ mod tests { assert!(!physical_exprs_contains(&physical_exprs, &lit1)); } - #[test] - fn test_have_common_entries() { - let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) - as Arc; - let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) - as Arc; - let lit2 = - Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; - let lit1 = - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; - let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - - let vec1 = vec![lit_true.clone(), lit_false.clone()]; - let vec2 = vec![lit_true.clone(), col_b_expr.clone()]; - let vec3 = vec![lit2.clone(), lit1.clone()]; - - // lit_true is common - assert!(have_common_entries(&vec1, &vec2)); - // there is no common entry - assert!(!have_common_entries(&vec1, &vec3)); - assert!(!have_common_entries(&vec2, &vec3)); - } - #[test] fn test_physical_exprs_equal() { let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true))))