diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 9ce47153ba4a8..254166695271c 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -512,27 +512,34 @@ fn test_simplify(input_expr: Expr, expected_expr: Expr) { "Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}" ); } -fn test_simplify_with_cycle_count( +fn test_simplify_with_cycle_info( input_expr: Expr, expected_expr: Expr, - expected_count: u32, + expected_cycle_count: usize, + expected_iteration_count: usize, ) { let info: MyInfo = MyInfo { schema: expr_test_schema(), execution_props: ExecutionProps::new(), }; let simplifier = ExprSimplifier::new(info); - let (simplified_expr, count) = simplifier - .simplify_with_cycle_count(input_expr.clone()) + let (simplified_expr, info) = simplifier + .simplify_with_cycle_info(input_expr.clone()) .expect("successfully evaluated"); + let total_iterations = info.total_iterations(); + let completed_cycles = info.completed_cycles(); assert_eq!( simplified_expr, expected_expr, "Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}" ); assert_eq!( - count, expected_count, - "Mismatch simplifier cycle count\n Expected: {expected_count}\n Got:{count}" + completed_cycles, expected_cycle_count, + "Mismatch simplifier cycle count\n Expected: {expected_cycle_count}\n Got:{completed_cycles}" + ); + assert_eq!( + total_iterations, expected_iteration_count, + "Mismatch simplifier cycle count\n Expected: {expected_iteration_count}\n Got:{total_iterations}" ); } @@ -691,5 +698,5 @@ fn test_simplify_cycles() { let expr = cast(now(), DataType::Int64) .lt(cast(to_timestamp(vec![lit(0)]), DataType::Int64) + lit(i64::MAX)); let expected = lit(true); - test_simplify_with_cycle_count(expr, expected, 3); + test_simplify_with_cycle_info(expr, expected, 2, 7); } diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index c172d59797569..56e5054012245 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -49,6 +49,7 @@ pub mod propagate_empty_relation; pub mod push_down_filter; pub mod push_down_limit; pub mod replace_distinct_aggregate; +pub mod rewrite_cycle; pub mod rewrite_disjunctive_predicate; pub mod scalar_subquery_to_join; pub mod simplify_expressions; diff --git a/datafusion/optimizer/src/rewrite_cycle.rs b/datafusion/optimizer/src/rewrite_cycle.rs new file mode 100644 index 0000000000000..90e05358022d1 --- /dev/null +++ b/datafusion/optimizer/src/rewrite_cycle.rs @@ -0,0 +1,413 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// [`RewriteCycle`] API for executing a sequence of [TreeNodeRewriter]s in multiple passes. +use std::ops::ControlFlow; + +use datafusion_common::{ + tree_node::{Transformed, TreeNode, TreeNodeRewriter}, + Result, +}; + +/// A builder with methods for executing a "rewrite cycle". +/// +/// Often the results of one optimization rule can uncover more optimizations in other optimization +/// rules. A sequence of optimization rules can be ran in multiple "passes" until there are no +/// more optmizations to make. +/// +/// The [RewriteCycle] handles logic for running these multi-pass loops. +/// It applies a sequence of [TreeNodeRewriter]s to a [TreeNode] by calling +/// [TreeNode::rewrite] in a loop - passing the output of one rewrite as the input to the next +/// rewrite - until [RewriteCycle::max_cycles] is reached or until every [TreeNode::rewrite] +/// returns a [Transformed::no] result in a consecutive sequence. +#[derive(Debug)] +pub struct RewriteCycle { + max_cycles: usize, +} + +impl Default for RewriteCycle { + fn default() -> Self { + Self::new() + } +} + +impl RewriteCycle { + /// The default maximum number of completed cycles to run before terminating the rewrite loop. + /// You can override this default with [Self::with_max_cycles] + pub const DEFAULT_MAX_CYCLES: usize = 3; + + /// Creates a new [RewriteCycle] with default options. + pub fn new() -> Self { + Self { + max_cycles: Self::DEFAULT_MAX_CYCLES, + } + } + /// Sets the [Self::max_cycles] to run before terminating the rewrite loop. + pub fn with_max_cycles(mut self, max_cycles: usize) -> Self { + self.max_cycles = max_cycles; + self + } + + /// The maximum number of completed cycles to run before terminating the rewrite loop. + /// Defaults to [Self::DEFAULT_MAX_CYCLES]. + pub fn max_cycles(&self) -> usize { + self.max_cycles + } + + /// Runs a rewrite cycle on the given [TreeNode] using the given callback function to + /// explicitly handle the cycle iterations. + /// + /// The callback function is given a [RewriteCycleState], which manages the short-circuiting + /// logic of the loop. The function is expected to call [RewriteCycleState::rewrite] for each + /// individual [TreeNodeRewriter] in the cycle. [RewriteCycleState::rewrite] returns a [RewriteCycleControlFlow] + /// result, indicating whether the loop should break or continue. + /// + /// ```rust + /// use datafusion_common::{ + /// tree_node::{Transformed, TreeNodeRewriter}, + /// Result, ScalarValue + /// }; + /// use datafusion_expr::{lit, BinaryExpr, Expr, Operator}; + /// + /// use datafusion_optimizer::rewrite_cycle::RewriteCycle; + /// + /// ///Rewrites a BinaryExpr with operator `op` using a function `f` + /// struct ConstBinaryExprRewriter { + /// op: Operator, + /// f: Box Result>>, + /// } + /// impl TreeNodeRewriter for ConstBinaryExprRewriter { + /// type Node = Expr; + /// /// Rewrites BinaryExpr using the function + /// fn f_up(&mut self, node: Self::Node) -> Result> { + /// match node { + /// Expr::BinaryExpr(BinaryExpr { + /// ref left, + /// ref right, + /// op, + /// }) if op == self.op => match (left.as_ref(), right.as_ref()) { + /// (Expr::Literal(left), Expr::Literal(right)) => { + /// Ok((self.f)(left, right)?) + /// } + /// _ => Ok(Transformed::no(node)), + /// }, + /// _ => Ok(Transformed::no(node)), + /// } + /// } + /// } + /// // create two rewriters for evaluating literals + /// // first rewriter evaluates addition expressions + /// let mut addition_rewriter = ConstBinaryExprRewriter { + /// op: Operator::Plus, + /// f: Box::new(|left, right| { + /// Ok(Transformed::yes(Expr::Literal(left.add(right)?))) + /// }), + /// }; + /// // second rewriter evaluates multiplication expression + /// let mut multiplication_rewriter = ConstBinaryExprRewriter { + /// op: Operator::Multiply, + /// f: Box::new(|left, right| { + /// Ok(Transformed::yes(Expr::Literal(left.mul(right)?))) + /// }), + /// }; + /// // Create an expression from constant literals + /// let expr = lit(6) + (lit(4) * (lit(2) + (lit(3) * lit(5)))); + /// // Run rewriters in a loop until constant expression is fully evaluated + /// let (evaluated_expr, info) = RewriteCycle::new() + /// .with_max_cycles(4) + /// .each_cycle(expr, |cycle_state| { + /// cycle_state + /// .rewrite(&mut addition_rewriter)? + /// .rewrite(&mut multiplication_rewriter) + /// }) + /// .unwrap(); + /// assert_eq!(evaluated_expr, lit(74)); + /// assert_eq!(info.completed_cycles(), 3); + /// assert_eq!(info.total_iterations(), 7); + /// ``` + pub fn each_cycle< + Node: TreeNode, + F: FnMut( + RewriteCycleState, + ) -> RewriteCycleControlFlow>, + >( + &self, + node: Node, + mut f: F, + ) -> Result<(Node, RewriteCycleInfo)> { + let mut state = RewriteCycleState::new(node); + if self.max_cycles == 0 { + return state.finish(); + } + // run first cycle then record number of rewriters + state = match f(state) { + ControlFlow::Break(result) => return result?.finish(), + ControlFlow::Continue(node) => node, + }; + state.record_cycle_length(); + if state.is_done() { + return state.finish(); + } + // run remaining cycles + match (1..self.max_cycles).try_fold(state, |state, _| f(state)) { + ControlFlow::Break(result) => result?.finish(), + ControlFlow::Continue(state) => state.finish(), + } + } +} + +/// Iteration state of a rewrite cycle. See [RewriteCycle::each_cycle] for usage examples and information. +#[derive(Debug)] +pub struct RewriteCycleState { + node: Node, + consecutive_unchanged_count: usize, + rewrite_count: usize, + cycle_length: Option, +} + +impl RewriteCycleState { + fn new(node: Node) -> Self { + Self { + node, + cycle_length: None, + consecutive_unchanged_count: 0, + rewrite_count: 0, + } + } + + /// Records the rewrite cycle length based on the current iteration count + /// + /// When the total number of writers is not known upfront - such as when using + /// [RewriteCycle::each_cycle] we need to keep count of the number of [Self::rewrite] + /// calls and then record the number at the end of the first cycle. + fn record_cycle_length(&mut self) { + self.cycle_length = Some(self.rewrite_count); + } + + /// Returns true when the loop has reached the maximum cycle length or when we've received + /// consecutive unchanged tree nodes equal to the total number of rewriters. + fn is_done(&self) -> bool { + // default value indicates we have not completed a cycle + let Some(cycle_length) = self.cycle_length else { + return false; + }; + self.consecutive_unchanged_count >= cycle_length + } + + /// Finishes the iteration by consuming the state and returning a [TreeNode] and + /// [RewriteCycleInfo] + fn finish(self) -> Result<(Node, RewriteCycleInfo)> { + Ok(( + self.node, + RewriteCycleInfo { + cycle_length: self.cycle_length.unwrap_or(self.rewrite_count), + total_iterations: self.rewrite_count, + }, + )) + } + + /// Calls [TreeNode::rewrite] and determines if the rewrite cycle should break or continue + /// based on the current [RewriteCycleState]. + pub fn rewrite + ?Sized>( + mut self, + rewriter: &mut R, + ) -> RewriteCycleControlFlow { + match self.node.rewrite(rewriter) { + Err(e) => ControlFlow::Break(Err(e)), + Ok(Transformed { + data: node, + transformed, + .. + }) => { + self.node = node; + self.rewrite_count += 1; + if transformed { + self.consecutive_unchanged_count = 0; + } else { + self.consecutive_unchanged_count += 1; + } + if self.is_done() { + ControlFlow::Break(Ok(self)) + } else { + ControlFlow::Continue(self) + } + } + } + } +} + +/// Information about a rewrite cycle, such as total number of iterations and number of fully +/// completed cycles. This is useful for testing purposes to ensure that optimzation passes are +/// working as expected. +#[derive(Debug, Clone, Copy)] +pub struct RewriteCycleInfo { + total_iterations: usize, + cycle_length: usize, +} + +impl RewriteCycleInfo { + /// The total number of **fully completed** cycles. + pub fn completed_cycles(&self) -> usize { + self.total_iterations / self.cycle_length + } + + /// The total number of [TreeNode::rewrite] calls. + pub fn total_iterations(&self) -> usize { + self.total_iterations + } + + /// The number of [TreeNode::rewrite] calls within a single cycle. + pub fn cycle_length(&self) -> usize { + self.cycle_length + } +} + +pub type RewriteCycleControlFlow = ControlFlow, T>; +#[cfg(test)] +mod test { + use datafusion_common::{ + tree_node::{Transformed, TreeNodeRewriter}, + Result, ScalarValue, + }; + use datafusion_expr::{lit, BinaryExpr, Expr, Operator}; + + use crate::rewrite_cycle::RewriteCycle; + + /// Rewriter that does not make any change + struct IdentityRewriter {} + impl TreeNodeRewriter for IdentityRewriter { + type Node = Expr; + fn f_up(&mut self, node: Self::Node) -> Result> { + Ok(Transformed::no(node)) + } + } + + /// Rewriter that always sets transformed=true + struct AlwaysTransformedRewriter {} + impl TreeNodeRewriter for AlwaysTransformedRewriter { + type Node = Expr; + fn f_up(&mut self, node: Self::Node) -> Result> { + Ok(Transformed::yes(node)) + } + } + + ///Rewrites a BinaryExpr with operator `op` using a function `f` + #[allow(clippy::type_complexity)] + struct ConstBinaryExprRewriter { + op: Operator, + f: Box Result>>, + } + impl TreeNodeRewriter for ConstBinaryExprRewriter { + type Node = Expr; + fn f_up(&mut self, node: Self::Node) -> Result> { + match node { + Expr::BinaryExpr(BinaryExpr { + ref left, + ref right, + op, + }) if op == self.op => match (left.as_ref(), right.as_ref()) { + (Expr::Literal(left), Expr::Literal(right)) => { + Ok((self.f)(left, right)?) + } + _ => Ok(Transformed::no(node)), + }, + _ => Ok(Transformed::no(node)), + } + } + } + + #[test] + // cycle that makes no changes should complete exactly one cycle + fn rewrite_cycle_identity() { + let expr = lit(true); + let (expr, info) = RewriteCycle::new() + .with_max_cycles(50) + .each_cycle(expr, |cycle_state| { + cycle_state + .rewrite(&mut IdentityRewriter {})? + .rewrite(&mut IdentityRewriter {})? + .rewrite(&mut IdentityRewriter {}) + }) + .unwrap(); + assert_eq!(expr, lit(true)); + assert_eq!(info.completed_cycles(), 1); + assert_eq!(info.total_iterations(), 3); + } + + // rewriter that always transforms should complete all cycles + #[test] + fn rewrite_cycle_always_transforms() { + let expr = lit(true); + let (expr, info) = RewriteCycle::new() + .with_max_cycles(10) + .each_cycle(expr, |cycle_state| { + cycle_state + .rewrite(&mut IdentityRewriter {})? + .rewrite(&mut AlwaysTransformedRewriter {}) + }) + .unwrap(); + assert_eq!(expr, lit(true)); + assert_eq!(info.completed_cycles(), 10); + assert_eq!(info.total_iterations(), 20); + } + + #[test] + // test an example of const evaluation with two rewriters that depend on each other + fn rewrite_cycle_const_evaluation() { + let mut addition_rewriter = ConstBinaryExprRewriter { + op: Operator::Plus, + f: Box::new(|left, right| { + Ok(Transformed::yes(Expr::Literal(left.add(right)?))) + }), + }; + let mut multiplication_rewriter = ConstBinaryExprRewriter { + op: Operator::Multiply, + f: Box::new(|left, right| { + Ok(Transformed::yes(Expr::Literal(left.mul(right)?))) + }), + }; + // Create an expression from constant literals + let expr = lit(6) + (lit(4) * (lit(2) + (lit(3) * lit(5)))); + // Run rewriters in a loop until constant expression is fully evaluated + let (evaluated_expr, info) = RewriteCycle::new() + .with_max_cycles(4) + .each_cycle(expr, |cycle_state| { + cycle_state + .rewrite(&mut addition_rewriter)? + .rewrite(&mut multiplication_rewriter) + }) + .unwrap(); + assert_eq!(evaluated_expr, lit(74)); + assert_eq!(info.completed_cycles(), 3); + assert_eq!(info.total_iterations(), 7); + + // Same expression as before + let expr = lit(6) + (lit(4) * (lit(2) + (lit(3) * lit(5)))); + // Use `with_max_cycles` to end rewriting earlier + let (evaluated_expr, info) = RewriteCycle::new() + .with_max_cycles(2) + .each_cycle(expr, |cycle_state| { + cycle_state + .rewrite(&mut addition_rewriter)? + .rewrite(&mut multiplication_rewriter) + }) + .unwrap(); + assert_eq!(evaluated_expr, lit(6) + lit(68)); + assert_eq!(info.completed_cycles(), 2); + assert_eq!(info.total_iterations(), 4); + } +} diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 024cb74403881..446f7dfe30c3b 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -43,10 +43,13 @@ use datafusion_expr::{ use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; -use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::simplify_expressions::guarantees::GuaranteeRewriter; use crate::simplify_expressions::regex::simplify_regex_expr; use crate::simplify_expressions::SimplifyInfo; +use crate::{ + analyzer::type_coercion::TypeCoercionRewriter, + rewrite_cycle::{RewriteCycle, RewriteCycleInfo}, +}; use super::inlist_simplifier::ShortenInListSimplifier; use super::utils::*; @@ -95,11 +98,11 @@ pub struct ExprSimplifier { /// true canonicalize: bool, /// Maximum number of simplifier cycles - max_simplifier_cycles: u32, + max_simplifier_cycles: usize, } pub const THRESHOLD_INLINE_INLIST: usize = 3; -pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: u32 = 3; +pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: usize = 3; impl ExprSimplifier { /// Create a new `ExprSimplifier` with the given `info` such as an @@ -178,7 +181,7 @@ impl ExprSimplifier { /// assert_eq!(expr, b_lt_2); /// ``` pub fn simplify(&self, expr: Expr) -> Result { - Ok(self.simplify_with_cycle_count(expr)?.0) + Ok(self.simplify_with_cycle_info(expr)?.0) } /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating @@ -188,36 +191,27 @@ impl ExprSimplifier { /// /// See [Self::simplify] for details and usage examples. /// - pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr, u32)> { + pub fn simplify_with_cycle_info( + &self, + mut expr: Expr, + ) -> Result<(Expr, RewriteCycleInfo)> { let mut simplifier = Simplifier::new(&self.info); let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; - let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees); if self.canonicalize { expr = expr.rewrite(&mut Canonicalizer::new()).data()? } - - // Evaluating constants can enable new simplifications and - // simplifications can enable new constant evaluation - // see `Self::with_max_cycles` - let mut num_cycles = 0; - loop { - let Transformed { - data, transformed, .. - } = expr - .rewrite(&mut const_evaluator)? - .transform_data(|expr| expr.rewrite(&mut simplifier))? - .transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?; - expr = data; - num_cycles += 1; - if !transformed || num_cycles >= self.max_simplifier_cycles { - break; - } - } - // shorten inlist should be started after other inlist rules are applied - expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?; - Ok((expr, num_cycles)) + let (mut expr, info) = RewriteCycle::new() + .with_max_cycles(self.max_simplifier_cycles) + .each_cycle(expr, |cycle_state| { + cycle_state + .rewrite(&mut const_evaluator)? + .rewrite(&mut simplifier)? + .rewrite(&mut guarantee_rewriter) + })?; + expr = expr.rewrite(&mut ShortenInListSimplifier::new()).data()?; + Ok((expr, info)) } /// Apply type coercion to an [`Expr`] so that it can be @@ -381,21 +375,15 @@ impl ExprSimplifier { /// // Expression: a IS NOT NULL /// let expr = col("a").is_not_null(); /// - /// // When using default maximum cycles, 2 cycles will be performed. - /// let (simplified_expr, count) = simplifier.simplify_with_cycle_count(expr.clone()).unwrap(); - /// assert_eq!(simplified_expr, lit(true)); - /// // 2 cycles were executed, but only 1 was needed - /// assert_eq!(count, 2); - /// /// // Only 1 simplification pass is necessary here, so we can set the maximum cycles to 1. - /// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count(expr.clone()).unwrap(); + /// let (simplified_expr, info) = simplifier.with_max_cycles(1).simplify_with_cycle_info(expr.clone()).unwrap(); /// // Expression has been rewritten to: (c = a AND b = 1) /// assert_eq!(simplified_expr, lit(true)); /// // Only 1 cycle was executed - /// assert_eq!(count, 1); + /// assert_eq!(info.completed_cycles(), 1); /// /// ``` - pub fn with_max_cycles(mut self, max_simplifier_cycles: u32) -> Self { + pub fn with_max_cycles(mut self, max_simplifier_cycles: usize) -> Self { self.max_simplifier_cycles = max_simplifier_cycles; self } @@ -450,7 +438,6 @@ impl TreeNodeRewriter for Canonicalizer { } } -#[allow(rustdoc::private_intra_doc_links)] /// Partially evaluate `Expr`s so constant subtrees are evaluated at plan time. /// /// Note it does not handle algebraic rewrites such as `(a or false)` @@ -475,8 +462,8 @@ struct ConstEvaluator<'a> { input_batch: RecordBatch, } -#[allow(dead_code)] /// The simplify result of ConstEvaluator +#[allow(dead_code)] enum ConstSimplifyResult { // Expr was simplifed and contains the new expression Simplified(ScalarValue), @@ -2971,17 +2958,17 @@ mod tests { try_simplify(expr).unwrap() } - fn try_simplify_with_cycle_count(expr: Expr) -> Result<(Expr, u32)> { + fn try_simplify_with_cycle_info(expr: Expr) -> Result<(Expr, RewriteCycleInfo)> { let schema = expr_test_schema(); let execution_props = ExecutionProps::new(); let simplifier = ExprSimplifier::new( SimplifyContext::new(&execution_props).with_schema(schema), ); - simplifier.simplify_with_cycle_count(expr) + simplifier.simplify_with_cycle_info(expr) } - fn simplify_with_cycle_count(expr: Expr) -> (Expr, u32) { - try_simplify_with_cycle_count(expr).unwrap() + fn simplify_with_cycle_info(expr: Expr) -> (Expr, RewriteCycleInfo) { + try_simplify_with_cycle_info(expr).unwrap() } fn simplify_with_guarantee( @@ -3697,24 +3684,27 @@ mod tests { // TRUE let expr = lit(true); let expected = lit(true); - let (expr, num_iter) = simplify_with_cycle_count(expr); + let (expr, info) = simplify_with_cycle_info(expr); assert_eq!(expr, expected); - assert_eq!(num_iter, 1); + assert_eq!(info.completed_cycles(), 1); + assert_eq!(info.total_iterations(), 3); // (true != NULL) OR (5 > 10) let expr = lit(true).not_eq(lit_bool_null()).or(lit(5).gt(lit(10))); let expected = lit_bool_null(); - let (expr, num_iter) = simplify_with_cycle_count(expr); + let (expr, info) = simplify_with_cycle_info(expr); assert_eq!(expr, expected); - assert_eq!(num_iter, 2); + assert_eq!(info.completed_cycles(), 1); + assert_eq!(info.total_iterations(), 4); // NOTE: this currently does not simplify // (((c4 - 10) + 10) *100) / 100 let expr = (((col("c4") - lit(10)) + lit(10)) * lit(100)) / lit(100); let expected = expr.clone(); - let (expr, num_iter) = simplify_with_cycle_count(expr); + let (expr, info) = simplify_with_cycle_info(expr); assert_eq!(expr, expected); - assert_eq!(num_iter, 1); + assert_eq!(info.completed_cycles(), 1); + assert_eq!(info.total_iterations(), 3); // ((c4<1 or c3<2) and c3_non_null<3) and false let expr = col("c4") @@ -3723,10 +3713,12 @@ mod tests { .and(col("c3_non_null").lt(lit(3))) .and(lit(false)); let expected = lit(false); - let (expr, num_iter) = simplify_with_cycle_count(expr); + let (expr, info) = simplify_with_cycle_info(expr); assert_eq!(expr, expected); - assert_eq!(num_iter, 2); + assert_eq!(info.completed_cycles(), 1); + assert_eq!(info.total_iterations(), 5); } + #[test] fn test_simplify_udaf() { let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify());