Skip to content

Commit 391606e

Browse files
alambkszucs
authored andcommitted
ARROW-11330: [Rust][DataFusion] add ExpressionVisitor to encode expression walking
## Problem: * There are several places in the DataFusion codebase where a walk of an Expression tree is needed * The logic of how to walk the tree is replicated * Adding new expression types often require many mechanically different but semantically the same changes in many places where no special treatment of such types is needed This PR introduces a `ExpressionVisitor` trait and the `Expr::accept` function to consolidate this walking of the expression tree. It does not intend to change any functionality. If folks like this pattern, I have ideas for a similar type of trait `ExpressionRewriter` which can be used to rewrite expressions (much like `clone_with_replacement`) as a subsquent PR. I think this was mentioned by @Dandandan in the [Rust roadmap](https://docs.google.com/document/d/1qspsOM_dknOxJKdGvKbC1aoVoO0M3i6x1CIo58mmN2Y/edit#heading=h.kstb571j5g5j) cc @jorgecarleitao @Dandandan and @andygrove Closes #9278 from alamb/alamb/expression_visitor Authored-by: Andrew Lamb <andrew@nerdnetworks.org> Signed-off-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 0c3eb74 commit 391606e

4 files changed

Lines changed: 219 additions & 165 deletions

File tree

rust/datafusion/src/logical_plan/expr.rs

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,136 @@ impl Expr {
422422
nulls_first,
423423
}
424424
}
425+
426+
/// Performs a depth first walk of an expression and
427+
/// its children, calling [`ExpressionVisitor::pre_visit`] and
428+
/// `visitor.post_visit`.
429+
///
430+
/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to
431+
/// separate expression algorithms from the structure of the
432+
/// `Expr` tree and make it easier to add new types of expressions
433+
/// and algorithms that walk the tree.
434+
///
435+
/// For an expression tree such as
436+
/// BinaryExpr (GT)
437+
/// left: Column("foo")
438+
/// right: Column("bar")
439+
///
440+
/// The nodes are visited using the following order
441+
/// ```text
442+
/// pre_visit(BinaryExpr(GT))
443+
/// pre_visit(Column("foo"))
444+
/// pre_visit(Column("bar"))
445+
/// post_visit(Column("bar"))
446+
/// post_visit(Column("bar"))
447+
/// post_visit(BinaryExpr(GT))
448+
/// ```
449+
///
450+
/// If an Err result is returned, recursion is stopped immediately
451+
///
452+
/// If `Recursion::Stop` is returned on a call to pre_visit, no
453+
/// children of that expression are visited, nor is post_visit
454+
/// called on that expression
455+
///
456+
pub fn accept<V: ExpressionVisitor>(&self, visitor: V) -> Result<V> {
457+
let visitor = match visitor.pre_visit(self)? {
458+
Recursion::Continue(visitor) => visitor,
459+
// If the recursion should stop, do not visit children
460+
Recursion::Stop(visitor) => return Ok(visitor),
461+
};
462+
463+
// recurse (and cover all expression types)
464+
let visitor = match self {
465+
Expr::Alias(expr, _) => expr.accept(visitor),
466+
Expr::Column(..) => Ok(visitor),
467+
Expr::ScalarVariable(..) => Ok(visitor),
468+
Expr::Literal(..) => Ok(visitor),
469+
Expr::BinaryExpr { left, right, .. } => {
470+
let visitor = left.accept(visitor)?;
471+
right.accept(visitor)
472+
}
473+
Expr::Not(expr) => expr.accept(visitor),
474+
Expr::IsNotNull(expr) => expr.accept(visitor),
475+
Expr::IsNull(expr) => expr.accept(visitor),
476+
Expr::Negative(expr) => expr.accept(visitor),
477+
Expr::Between {
478+
expr, low, high, ..
479+
} => {
480+
let visitor = expr.accept(visitor)?;
481+
let visitor = low.accept(visitor)?;
482+
high.accept(visitor)
483+
}
484+
Expr::Case {
485+
expr,
486+
when_then_expr,
487+
else_expr,
488+
} => {
489+
let visitor = if let Some(expr) = expr.as_ref() {
490+
expr.accept(visitor)
491+
} else {
492+
Ok(visitor)
493+
}?;
494+
let visitor = when_then_expr.iter().try_fold(
495+
visitor,
496+
|visitor, (when, then)| {
497+
let visitor = when.accept(visitor)?;
498+
then.accept(visitor)
499+
},
500+
)?;
501+
if let Some(else_expr) = else_expr.as_ref() {
502+
else_expr.accept(visitor)
503+
} else {
504+
Ok(visitor)
505+
}
506+
}
507+
Expr::Cast { expr, .. } => expr.accept(visitor),
508+
Expr::Sort { expr, .. } => expr.accept(visitor),
509+
Expr::ScalarFunction { args, .. } => args
510+
.iter()
511+
.try_fold(visitor, |visitor, arg| arg.accept(visitor)),
512+
Expr::ScalarUDF { args, .. } => args
513+
.iter()
514+
.try_fold(visitor, |visitor, arg| arg.accept(visitor)),
515+
Expr::AggregateFunction { args, .. } => args
516+
.iter()
517+
.try_fold(visitor, |visitor, arg| arg.accept(visitor)),
518+
Expr::AggregateUDF { args, .. } => args
519+
.iter()
520+
.try_fold(visitor, |visitor, arg| arg.accept(visitor)),
521+
Expr::InList { expr, list, .. } => {
522+
let visitor = expr.accept(visitor)?;
523+
list.iter()
524+
.try_fold(visitor, |visitor, arg| arg.accept(visitor))
525+
}
526+
Expr::Wildcard => Ok(visitor),
527+
}?;
528+
529+
visitor.post_visit(self)
530+
}
531+
}
532+
533+
/// Controls how the visitor recursion should proceed.
534+
pub enum Recursion<V: ExpressionVisitor> {
535+
/// Attempt to visit all the children, recursively, of this expression.
536+
Continue(V),
537+
/// Do not visit the children of this expression, though the walk
538+
/// of parents of this expression will not be affected
539+
Stop(V),
540+
}
541+
542+
/// Encode the traversal of an expression tree. When passed to
543+
/// `Expr::accept`, `ExpressionVisitor::visit` is invoked
544+
/// recursively on all nodes of an expression tree. See the comments
545+
/// on `Expr::accept` for details on its use
546+
pub trait ExpressionVisitor: Sized {
547+
/// Invoked before any children of `expr` are visisted.
548+
fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>>;
549+
550+
/// Invoked after all children of `expr` are visited. Default
551+
/// implementation does nothing.
552+
fn post_visit(self, _expr: &Expr) -> Result<Self> {
553+
Ok(self)
554+
}
425555
}
426556

427557
pub struct CaseBuilder {

rust/datafusion/src/logical_plan/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ pub use expr::{
3838
count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor,
3939
in_list, length, lit, ln, log10, log2, lower, ltrim, max, md5, min, or, round, rtrim,
4040
sha224, sha256, sha384, sha512, signum, sin, sqrt, sum, tan, trim, trunc, upper,
41-
when, Expr, Literal,
41+
when, Expr, ExpressionVisitor, Literal, Recursion,
4242
};
4343
pub use extension::UserDefinedLogicalNode;
4444
pub use operators::Operator;

rust/datafusion/src/optimizer/utils.rs

Lines changed: 42 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,16 @@ use std::{collections::HashSet, sync::Arc};
2222
use arrow::datatypes::Schema;
2323

2424
use super::optimizer::OptimizerRule;
25-
use crate::error::{DataFusionError, Result};
2625
use crate::logical_plan::{
27-
Expr, LogicalPlan, Operator, Partitioning, PlanType, StringifiedPlan, ToDFSchema,
26+
Expr, LogicalPlan, Operator, Partitioning, PlanType, Recursion, StringifiedPlan,
27+
ToDFSchema,
2828
};
2929
use crate::prelude::{col, lit};
3030
use crate::scalar::ScalarValue;
31+
use crate::{
32+
error::{DataFusionError, Result},
33+
logical_plan::ExpressionVisitor,
34+
};
3135

3236
const CASE_EXPR_MARKER: &str = "__DATAFUSION_CASE_EXPR__";
3337
const CASE_ELSE_MARKER: &str = "__DATAFUSION_CASE_ELSE__";
@@ -46,75 +50,48 @@ pub fn exprlist_to_column_names(
4650

4751
/// Recursively walk an expression tree, collecting the unique set of column names
4852
/// referenced in the expression
49-
pub fn expr_to_column_names(expr: &Expr, accum: &mut HashSet<String>) -> Result<()> {
50-
match expr {
51-
Expr::Alias(expr, _) => expr_to_column_names(expr, accum),
52-
Expr::Column(name) => {
53-
accum.insert(name.clone());
54-
Ok(())
55-
}
56-
Expr::ScalarVariable(var_names) => {
57-
accum.insert(var_names.join("."));
58-
Ok(())
59-
}
60-
Expr::Literal(_) => {
61-
// not needed
62-
Ok(())
63-
}
64-
Expr::Not(e) => expr_to_column_names(e, accum),
65-
Expr::Negative(e) => expr_to_column_names(e, accum),
66-
Expr::IsNull(e) => expr_to_column_names(e, accum),
67-
Expr::IsNotNull(e) => expr_to_column_names(e, accum),
68-
Expr::BinaryExpr { left, right, .. } => {
69-
expr_to_column_names(left, accum)?;
70-
expr_to_column_names(right, accum)?;
71-
Ok(())
72-
}
73-
Expr::Case {
74-
expr,
75-
when_then_expr,
76-
else_expr,
77-
..
78-
} => {
79-
if let Some(e) = expr {
80-
expr_to_column_names(e, accum)?;
81-
}
82-
for (w, t) in when_then_expr {
83-
expr_to_column_names(w, accum)?;
84-
expr_to_column_names(t, accum)?;
85-
}
86-
if let Some(e) = else_expr {
87-
expr_to_column_names(e, accum)?
53+
struct ColumnNameVisitor<'a> {
54+
accum: &'a mut HashSet<String>,
55+
}
56+
57+
impl ExpressionVisitor for ColumnNameVisitor<'_> {
58+
fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>> {
59+
match expr {
60+
Expr::Column(name) => {
61+
self.accum.insert(name.clone());
8862
}
89-
Ok(())
90-
}
91-
Expr::Cast { expr, .. } => expr_to_column_names(expr, accum),
92-
Expr::Sort { expr, .. } => expr_to_column_names(expr, accum),
93-
Expr::AggregateFunction { args, .. } => exprlist_to_column_names(args, accum),
94-
Expr::AggregateUDF { args, .. } => exprlist_to_column_names(args, accum),
95-
Expr::ScalarFunction { args, .. } => exprlist_to_column_names(args, accum),
96-
Expr::ScalarUDF { args, .. } => exprlist_to_column_names(args, accum),
97-
Expr::Between {
98-
expr, low, high, ..
99-
} => {
100-
expr_to_column_names(expr, accum)?;
101-
expr_to_column_names(low, accum)?;
102-
expr_to_column_names(high, accum)?;
103-
Ok(())
104-
}
105-
Expr::InList { expr, list, .. } => {
106-
expr_to_column_names(expr, accum)?;
107-
for list_expr in list {
108-
expr_to_column_names(list_expr, accum)?;
63+
Expr::ScalarVariable(var_names) => {
64+
self.accum.insert(var_names.join("."));
10965
}
110-
Ok(())
66+
Expr::Alias(_, _) => {}
67+
Expr::Literal(_) => {}
68+
Expr::BinaryExpr { .. } => {}
69+
Expr::Not(_) => {}
70+
Expr::IsNotNull(_) => {}
71+
Expr::IsNull(_) => {}
72+
Expr::Negative(_) => {}
73+
Expr::Between { .. } => {}
74+
Expr::Case { .. } => {}
75+
Expr::Cast { .. } => {}
76+
Expr::Sort { .. } => {}
77+
Expr::ScalarFunction { .. } => {}
78+
Expr::ScalarUDF { .. } => {}
79+
Expr::AggregateFunction { .. } => {}
80+
Expr::AggregateUDF { .. } => {}
81+
Expr::InList { .. } => {}
82+
Expr::Wildcard => {}
11183
}
112-
Expr::Wildcard => Err(DataFusionError::Internal(
113-
"Wildcard expressions are not valid in a logical query plan".to_owned(),
114-
)),
84+
Ok(Recursion::Continue(self))
11585
}
11686
}
11787

88+
/// Recursively walk an expression tree, collecting the unique set of column names
89+
/// referenced in the expression
90+
pub fn expr_to_column_names(expr: &Expr, accum: &mut HashSet<String>) -> Result<()> {
91+
expr.accept(ColumnNameVisitor { accum })?;
92+
Ok(())
93+
}
94+
11895
/// Create a `LogicalPlan::Explain` node by running `optimizer` on the
11996
/// input plan and capturing the resulting plan string
12097
pub fn optimize_explain(

0 commit comments

Comments
 (0)