diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 554722f37ba2a..8e088e7a0b567 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -22,68 +22,25 @@ use std::sync::Arc; use crate::Result; -/// This macro is used to control continuation behaviors during tree traversals -/// based on the specified direction. Depending on `$DIRECTION` and the value of -/// the given expression (`$EXPR`), which should be a variant of [`TreeNodeRecursion`], -/// the macro results in the following behavior: -/// -/// - If the expression returns [`TreeNodeRecursion::Continue`], normal execution -/// continues. -/// - If it returns [`TreeNodeRecursion::Stop`], recursion halts and propagates -/// [`TreeNodeRecursion::Stop`]. -/// - If it returns [`TreeNodeRecursion::Jump`], the continuation behavior depends -/// on the traversal direction: -/// - For `UP` direction, the function returns with [`TreeNodeRecursion::Jump`], -/// bypassing further bottom-up closures until the next top-down closure. -/// - For `DOWN` direction, the function returns with [`TreeNodeRecursion::Continue`], -/// skipping further exploration. -/// - If no direction is specified, `Jump` is treated like `Continue`. -#[macro_export] -macro_rules! handle_visit_recursion { - // Internal helper macro for handling the `Jump` case based on the direction: - (@handle_jump UP) => { - return Ok(TreeNodeRecursion::Jump) - }; - (@handle_jump DOWN) => { - return Ok(TreeNodeRecursion::Continue) - }; - (@handle_jump) => { - {} // Treat `Jump` like `Continue`, do nothing and continue execution. - }; +/// These macros are used to determine continuation during transforming traversals. +macro_rules! handle_transform_recursion { + ($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{ + #[allow(clippy::redundant_closure_call)] + $F_DOWN? + .transform_children(|n| n.map_children($F_CHILD))? + .transform_parent(|n| $F_UP(n)) + }}; +} - // Main macro logic with variables to handle directionality. - ($EXPR:expr $(, $DIRECTION:ident)?) => { - match $EXPR { - TreeNodeRecursion::Continue => {} - TreeNodeRecursion::Jump => handle_visit_recursion!(@handle_jump $($DIRECTION)?), - TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), - } - }; +macro_rules! handle_transform_recursion_down { + ($F_DOWN:expr, $F_CHILD:expr) => {{ + $F_DOWN?.transform_children(|n| n.map_children($F_CHILD)) + }}; } -/// This macro is used to determine continuation during combined transforming -/// traversals. -/// -/// Depending on the [`TreeNodeRecursion`] the bottom-up closure returns, -/// [`Transformed::try_transform_node_with()`] decides recursion continuation -/// and if state propagation is necessary. Then, the same procedure recursively -/// applies to the children of the node in question. -macro_rules! handle_transform_recursion { - ($F_DOWN:expr, $F_SELF:expr, $F_UP:expr) => {{ - let pre_visited = $F_DOWN?; - match pre_visited.tnr { - TreeNodeRecursion::Continue => pre_visited - .data - .map_children($F_SELF)? - .try_transform_node_with($F_UP, TreeNodeRecursion::Jump), - #[allow(clippy::redundant_closure_call)] - TreeNodeRecursion::Jump => $F_UP(pre_visited.data), - TreeNodeRecursion::Stop => return Ok(pre_visited), - } - .map(|mut post_visited| { - post_visited.transformed |= pre_visited.transformed; - post_visited - }) +macro_rules! handle_transform_recursion_up { + ($SELF:expr, $F_CHILD:expr, $F_UP:expr) => {{ + $SELF.map_children($F_CHILD)?.transform_parent(|n| $F_UP(n)) }}; } @@ -128,17 +85,10 @@ pub trait TreeNode: Sized { &self, visitor: &mut V, ) -> Result { - match visitor.f_down(self)? { - TreeNodeRecursion::Continue => { - handle_visit_recursion!( - self.apply_children(&mut |n| n.visit(visitor))?, - UP - ); - visitor.f_up(self) - } - TreeNodeRecursion::Jump => visitor.f_up(self), - TreeNodeRecursion::Stop => Ok(TreeNodeRecursion::Stop), - } + visitor + .f_down(self)? + .visit_children(|| self.apply_children(|c| c.visit(visitor)))? + .visit_parent(|| visitor.f_up(self)) } /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for @@ -184,8 +134,7 @@ pub trait TreeNode: Sized { &self, f: &mut F, ) -> Result { - handle_visit_recursion!(f(self)?, DOWN); - self.apply_children(&mut |n| n.apply(f)) + f(self)?.visit_children(|| self.apply_children(|c| c.apply(f))) } /// Convenience utility for writing optimizer rules: Recursively apply the @@ -205,10 +154,7 @@ pub trait TreeNode: Sized { self, f: &F, ) -> Result> { - f(self)?.try_transform_node_with( - |n| n.map_children(|c| c.transform_down(f)), - TreeNodeRecursion::Continue, - ) + handle_transform_recursion_down!(f(self), |c| c.transform_down(f)) } /// Convenience utility for writing optimizer rules: Recursively apply the @@ -218,10 +164,7 @@ pub trait TreeNode: Sized { self, f: &mut F, ) -> Result> { - f(self)?.try_transform_node_with( - |n| n.map_children(|c| c.transform_down_mut(f)), - TreeNodeRecursion::Continue, - ) + handle_transform_recursion_down!(f(self), |c| c.transform_down_mut(f)) } /// Convenience utility for writing optimizer rules: Recursively apply the @@ -232,8 +175,7 @@ pub trait TreeNode: Sized { self, f: &F, ) -> Result> { - self.map_children(|c| c.transform_up(f))? - .try_transform_node_with(f, TreeNodeRecursion::Jump) + handle_transform_recursion_up!(self, |c| c.transform_up(f), f) } /// Convenience utility for writing optimizer rules: Recursively apply the @@ -244,8 +186,7 @@ pub trait TreeNode: Sized { self, f: &mut F, ) -> Result> { - self.map_children(|c| c.transform_up_mut(f))? - .try_transform_node_with(f, TreeNodeRecursion::Jump) + handle_transform_recursion_up!(self, |c| c.transform_up_mut(f), f) } /// Transforms the tree using `f_down` while traversing the tree top-down @@ -355,7 +296,7 @@ pub trait TreeNode: Sized { /// Apply the closure `F` to the node's children. fn apply_children Result>( &self, - f: &mut F, + f: F, ) -> Result; /// Apply transform `F` to the node's children. Note that the transform `F` @@ -432,6 +373,45 @@ pub enum TreeNodeRecursion { Stop, } +impl TreeNodeRecursion { + /// Continues visiting nodes with `f` depending on the current [`TreeNodeRecursion`] + /// value and the fact that `f` is visiting the current node's children. + pub fn visit_children Result>( + self, + f: F, + ) -> Result { + match self { + TreeNodeRecursion::Continue => f(), + TreeNodeRecursion::Jump => Ok(TreeNodeRecursion::Continue), + TreeNodeRecursion::Stop => Ok(self), + } + } + + /// Continues visiting nodes with `f` depending on the current [`TreeNodeRecursion`] + /// value and the fact that `f` is visiting the current node's sibling. + pub fn visit_sibling Result>( + self, + f: F, + ) -> Result { + match self { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => f(), + TreeNodeRecursion::Stop => Ok(self), + } + } + + /// Continues visiting nodes with `f` depending on the current [`TreeNodeRecursion`] + /// value and the fact that `f` is visiting the current node's parent. + pub fn visit_parent Result>( + self, + f: F, + ) -> Result { + match self { + TreeNodeRecursion::Continue => f(), + TreeNodeRecursion::Jump | TreeNodeRecursion::Stop => Ok(self), + } + } +} + /// This struct is used by tree transformation APIs such as /// - [`TreeNode::rewrite`], /// - [`TreeNode::transform_down`], @@ -489,15 +469,23 @@ impl Transformed { f(self.data).map(|data| Transformed::new(data, self.transformed, self.tnr)) } - /// Handling [`TreeNodeRecursion::Continue`] and [`TreeNodeRecursion::Stop`] - /// is straightforward, but [`TreeNodeRecursion::Jump`] can behave differently - /// when we are traversing down or up on a tree. If [`TreeNodeRecursion`] of - /// the node is [`TreeNodeRecursion::Jump`], recursion stops with the given - /// `return_if_jump` value. - fn try_transform_node_with Result>>( + /// Maps the [`Transformed`] object to the result of the given `f`. + pub fn transform_data Result>>( + self, + f: F, + ) -> Result> { + f(self.data).map(|mut t| { + t.transformed |= self.transformed; + t + }) + } + + /// Maps the [`Transformed`] object to the result of the given `f` depending on the + /// current [`TreeNodeRecursion`] value and the fact that `f` is changing the current + /// node's children. + pub fn transform_children Result>>( mut self, f: F, - return_if_jump: TreeNodeRecursion, ) -> Result> { match self.tnr { TreeNodeRecursion::Continue => { @@ -507,37 +495,67 @@ impl Transformed { }); } TreeNodeRecursion::Jump => { - self.tnr = return_if_jump; + self.tnr = TreeNodeRecursion::Continue; } TreeNodeRecursion::Stop => {} } Ok(self) } - /// If [`TreeNodeRecursion`] of the node is [`TreeNodeRecursion::Continue`] or - /// [`TreeNodeRecursion::Jump`], transformation is applied to the node. - /// Otherwise, it remains as it is. - pub fn try_transform_node Result>>( + /// Maps the [`Transformed`] object to the result of the given `f` depending on the + /// current [`TreeNodeRecursion`] value and the fact that `f` is changing the current + /// node's sibling. + pub fn transform_sibling Result>>( self, f: F, ) -> Result> { - if self.tnr == TreeNodeRecursion::Stop { - Ok(self) - } else { - f(self.data).map(|mut t| { + match self.tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { + f(self.data).map(|mut t| { + t.transformed |= self.transformed; + t + }) + } + TreeNodeRecursion::Stop => Ok(self), + } + } + + /// Maps the [`Transformed`] object to the result of the given `f` depending on the + /// current [`TreeNodeRecursion`] value and the fact that `f` is changing the current + /// node's parent. + pub fn transform_parent Result>>( + self, + f: F, + ) -> Result> { + match self.tnr { + TreeNodeRecursion::Continue => f(self.data).map(|mut t| { t.transformed |= self.transformed; t - }) + }), + TreeNodeRecursion::Jump | TreeNodeRecursion::Stop => Ok(self), } } } /// Transformation helper to process a sequence of iterable tree nodes that are siblings. -pub trait TransformedIterator: Iterator { +pub trait TreeNodeIterator: Iterator { /// Apples `f` to each item in this iterator /// /// Visits all items in the iterator unless - /// `f` returns an error or `f` returns TreeNodeRecursion::stop. + /// `f` returns an error or `f` returns `TreeNodeRecursion::Stop`. + /// + /// # Returns + /// Error if `f` returns an error or `Ok(TreeNodeRecursion)` from the last invocation + /// of `f` or `Continue` if the iterator is empty + fn apply_until_stop Result>( + self, + f: F, + ) -> Result; + + /// Apples `f` to each item in this iterator + /// + /// Visits all items in the iterator unless + /// `f` returns an error or `f` returns `TreeNodeRecursion::Stop`. /// /// # Returns /// Error if `f` returns an error @@ -554,7 +572,22 @@ pub trait TransformedIterator: Iterator { ) -> Result>>; } -impl TransformedIterator for I { +impl TreeNodeIterator for I { + fn apply_until_stop Result>( + self, + mut f: F, + ) -> Result { + let mut tnr = TreeNodeRecursion::Continue; + for i in self { + tnr = f(i)?; + match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {} + TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), + } + } + Ok(tnr) + } + fn map_until_stop_and_collect< F: FnMut(Self::Item) -> Result>, >( @@ -580,7 +613,7 @@ impl TransformedIterator for I { /// Transformation helper to process a heterogeneous sequence of tree node containing /// expressions. -/// This macro is very similar to [TransformedIterator::map_until_stop_and_collect] to +/// This macro is very similar to [TreeNodeIterator::map_until_stop_and_collect] to /// process nodes that are siblings, but it accepts an initial transformation (`F0`) and /// a sequence of pairs. Each pair is made of an expression (`EXPR`) and its /// transformation (`F`). @@ -664,14 +697,9 @@ pub trait DynTreeNode { impl TreeNode for Arc { fn apply_children Result>( &self, - f: &mut F, + f: F, ) -> Result { - let mut tnr = TreeNodeRecursion::Continue; - for child in self.arc_children() { - tnr = f(&child)?; - handle_visit_recursion!(tnr) - } - Ok(tnr) + self.arc_children().iter().apply_until_stop(f) } fn map_children Result>>( @@ -714,14 +742,9 @@ pub trait ConcreteTreeNode: Sized { impl TreeNode for T { fn apply_children Result>( &self, - f: &mut F, + f: F, ) -> Result { - let mut tnr = TreeNodeRecursion::Continue; - for child in self.children() { - tnr = f(child)?; - handle_visit_recursion!(tnr) - } - Ok(tnr) + self.children().into_iter().apply_until_stop(f) } fn map_children Result>>( @@ -745,7 +768,7 @@ mod tests { use std::fmt::Display; use crate::tree_node::{ - Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; use crate::Result; @@ -763,22 +786,17 @@ mod tests { } impl TreeNode for TestTreeNode { - fn apply_children(&self, f: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { - let mut tnr = TreeNodeRecursion::Continue; - for child in &self.children { - tnr = f(child)?; - handle_visit_recursion!(tnr); - } - Ok(tnr) + fn apply_children Result>( + &self, + f: F, + ) -> Result { + self.children.iter().apply_until_stop(f) } - fn map_children(self, f: F) -> Result> - where - F: FnMut(Self) -> Result>, - { + fn map_children Result>>( + self, + f: F, + ) -> Result> { Ok(self .children .into_iter() diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index df1585e5a5985..97331720ce7d0 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -25,16 +25,14 @@ use crate::expr::{ use crate::{Expr, GetFieldAccess}; use datafusion_common::tree_node::{ - Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, -}; -use datafusion_common::{ - handle_visit_recursion, internal_err, map_until_stop_and_collect, Result, + Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, }; +use datafusion_common::{internal_err, map_until_stop_and_collect, Result}; impl TreeNode for Expr { fn apply_children Result>( &self, - f: &mut F, + f: F, ) -> Result { let children = match self { Expr::Alias(Alias{expr,..}) @@ -133,19 +131,13 @@ impl TreeNode for Expr { } }; - let mut tnr = TreeNodeRecursion::Continue; - for child in children { - tnr = f(child)?; - handle_visit_recursion!(tnr, DOWN); - } - - Ok(tnr) + children.into_iter().apply_until_stop(f) } - fn map_children(self, mut f: F) -> Result> - where - F: FnMut(Self) -> Result>, - { + fn map_children Result>>( + self, + mut f: F, + ) -> Result> { Ok(match self { Expr::Column(_) | Expr::Wildcard { .. } diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index 02d5d18512890..7a6b1005fedec 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -20,9 +20,9 @@ use crate::LogicalPlan; use datafusion_common::tree_node::{ - Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, TreeNodeVisitor, + Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeVisitor, }; -use datafusion_common::{handle_visit_recursion, Result}; +use datafusion_common::Result; impl TreeNode for LogicalPlan { fn apply Result>( @@ -31,9 +31,10 @@ impl TreeNode for LogicalPlan { ) -> Result { // Compared to the default implementation, we need to invoke // [`Self::apply_subqueries`] before visiting its children - handle_visit_recursion!(f(self)?, DOWN); - self.apply_subqueries(f)?; - self.apply_children(&mut |n| n.apply(f)) + f(self)?.visit_children(|| { + self.apply_subqueries(f)?; + self.apply_children(|n| n.apply(f)) + }) } /// To use, define a struct that implements the trait [`TreeNodeVisitor`] and then invoke @@ -62,39 +63,26 @@ impl TreeNode for LogicalPlan { ) -> Result { // Compared to the default implementation, we need to invoke // [`Self::visit_subqueries`] before visiting its children - match visitor.f_down(self)? { - TreeNodeRecursion::Continue => { + visitor + .f_down(self)? + .visit_children(|| { self.visit_subqueries(visitor)?; - handle_visit_recursion!( - self.apply_children(&mut |n| n.visit(visitor))?, - UP - ); - visitor.f_up(self) - } - TreeNodeRecursion::Jump => { - self.visit_subqueries(visitor)?; - visitor.f_up(self) - } - TreeNodeRecursion::Stop => Ok(TreeNodeRecursion::Stop), - } + self.apply_children(|n| n.visit(visitor)) + })? + .visit_parent(|| visitor.f_up(self)) } fn apply_children Result>( &self, - f: &mut F, + f: F, ) -> Result { - let mut tnr = TreeNodeRecursion::Continue; - for child in self.inputs() { - tnr = f(child)?; - handle_visit_recursion!(tnr, DOWN) - } - Ok(tnr) + self.inputs().into_iter().apply_until_stop(f) } - fn map_children(self, f: F) -> Result> - where - F: FnMut(Self) -> Result>, - { + fn map_children Result>>( + self, + f: F, + ) -> Result> { let new_children = self .inputs() .iter() diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index b7f513727d39e..038361c3ee8c3 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -146,7 +146,7 @@ fn check_inner_plan( // We want to support as many operators as possible inside the correlated subquery match inner_plan { LogicalPlan::Aggregate(_) => { - inner_plan.apply_children(&mut |plan| { + inner_plan.apply_children(|plan| { check_inner_plan(plan, is_scalar, true, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; @@ -171,7 +171,7 @@ fn check_inner_plan( } LogicalPlan::Window(window) => { check_mixed_out_refer_in_window(window)?; - inner_plan.apply_children(&mut |plan| { + inner_plan.apply_children(|plan| { check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; @@ -188,7 +188,7 @@ fn check_inner_plan( | LogicalPlan::Values(_) | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) => { - inner_plan.apply_children(&mut |plan| { + inner_plan.apply_children(|plan| { check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; @@ -201,7 +201,7 @@ fn check_inner_plan( .. }) => match join_type { JoinType::Inner => { - inner_plan.apply_children(&mut |plan| { + inner_plan.apply_children(|plan| { check_inner_plan( plan, is_scalar, @@ -221,7 +221,7 @@ fn check_inner_plan( check_inner_plan(right, is_scalar, is_aggregate, can_contain_outer_ref) } JoinType::Full => { - inner_plan.apply_children(&mut |plan| { + inner_plan.apply_children(|plan| { check_inner_plan(plan, is_scalar, is_aggregate, false)?; Ok(TreeNodeRecursion::Continue) })?;