Skip to content

Commit 4e60434

Browse files
committed
feat: support normalized expr in CSE
1 parent f190fc6 commit 4e60434

3 files changed

Lines changed: 288 additions & 30 deletions

File tree

datafusion/common/src/cse.rs

Lines changed: 109 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -50,33 +50,53 @@ impl<T: HashNode + ?Sized> HashNode for Arc<T> {
5050
}
5151
}
5252

53+
/// A trait that defines how to normalize a node.
54+
///
55+
/// This trait is used to normalize nodes before comparing them for CSE. Normalization
56+
/// can be used to ensure that two nodes that are semantically equivalent are considered
57+
/// equal for CSE.
58+
/// For example:`a + b` and `b + a` are semantically equivalent.
59+
pub trait NormalizeNode: Eq {
60+
fn normalize(&self) -> Self;
61+
fn enable_normalized(&self) -> bool;
62+
}
63+
5364
/// Identifier that represents a [`TreeNode`] tree.
5465
///
5566
/// This identifier is designed to be efficient and "hash", "accumulate", "equal" and
5667
/// "have no collision (as low as possible)"
57-
#[derive(Debug, Eq, PartialEq)]
58-
struct Identifier<'n, N> {
68+
#[derive(Debug, Eq)]
69+
struct Identifier<'n, N: NormalizeNode> {
5970
// Hash of `node` built up incrementally during the first, visiting traversal.
6071
// Its value is not necessarily equal to default hash of the node. E.g. it is not
6172
// equal to `expr.hash()` if the node is `Expr`.
6273
hash: u64,
6374
node: &'n N,
6475
}
6576

66-
impl<N> Clone for Identifier<'_, N> {
77+
impl<N: NormalizeNode> Clone for Identifier<'_, N> {
6778
fn clone(&self) -> Self {
6879
*self
6980
}
7081
}
71-
impl<N> Copy for Identifier<'_, N> {}
82+
impl<N: NormalizeNode> Copy for Identifier<'_, N> {}
7283

73-
impl<N> Hash for Identifier<'_, N> {
84+
impl<N: NormalizeNode> Hash for Identifier<'_, N> {
7485
fn hash<H: Hasher>(&self, state: &mut H) {
7586
state.write_u64(self.hash);
7687
}
7788
}
7889

79-
impl<'n, N: HashNode> Identifier<'n, N> {
90+
impl<N: NormalizeNode> PartialEq for Identifier<'_, N> {
91+
fn eq(&self, other: &Self) -> bool {
92+
self.node.normalize() == other.node.normalize()
93+
}
94+
}
95+
96+
impl<'n, N> Identifier<'n, N>
97+
where
98+
N: HashNode + NormalizeNode,
99+
{
80100
fn new(node: &'n N, random_state: &RandomState) -> Self {
81101
let mut hasher = random_state.build_hasher();
82102
node.hash_node(&mut hasher);
@@ -213,7 +233,11 @@ pub enum FoundCommonNodes<N> {
213233
///
214234
/// A [`TreeNode`] without any children (column, literal etc.) will not have identifier
215235
/// because they should not be recognized as common subtree.
216-
struct CSEVisitor<'a, 'n, N, C: CSEController<Node = N>> {
236+
struct CSEVisitor<'a, 'n, N, C>
237+
where
238+
N: NormalizeNode,
239+
C: CSEController<Node = N>,
240+
{
217241
/// statistics of [`TreeNode`]s
218242
node_stats: &'a mut NodeStats<'n, N>,
219243

@@ -244,7 +268,10 @@ struct CSEVisitor<'a, 'n, N, C: CSEController<Node = N>> {
244268
}
245269

246270
/// Record item that used when traversing a [`TreeNode`] tree.
247-
enum VisitRecord<'n, N> {
271+
enum VisitRecord<'n, N>
272+
where
273+
N: NormalizeNode,
274+
{
248275
/// Marks the beginning of [`TreeNode`]. It contains:
249276
/// - The post-order index assigned during the first, visiting traversal.
250277
EnterMark(usize),
@@ -258,7 +285,11 @@ enum VisitRecord<'n, N> {
258285
NodeItem(Identifier<'n, N>, bool),
259286
}
260287

261-
impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_, 'n, N, C> {
288+
impl<'n, N, C> CSEVisitor<'_, 'n, N, C>
289+
where
290+
N: TreeNode + HashNode + NormalizeNode,
291+
C: CSEController<Node = N>,
292+
{
262293
/// Find the first `EnterMark` in the stack, and accumulates every `NodeItem` before
263294
/// it. Returns a tuple that contains:
264295
/// - The pre-order index of the [`TreeNode`] we marked.
@@ -271,17 +302,26 @@ impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_, 'n,
271302
/// information up from children to parents via `visit_stack` during the first,
272303
/// visiting traversal and no need to test the expression's validity beforehand with
273304
/// an extra traversal).
274-
fn pop_enter_mark(&mut self) -> (usize, Option<Identifier<'n, N>>, bool) {
275-
let mut node_id = None;
305+
fn pop_enter_mark(
306+
&mut self,
307+
enable_normalize: bool,
308+
) -> (usize, Option<Identifier<'n, N>>, bool) {
309+
let mut node_ids: Vec<Identifier<'n, N>> = vec![];
276310
let mut is_valid = true;
277311

278312
while let Some(item) = self.visit_stack.pop() {
279313
match item {
280314
VisitRecord::EnterMark(down_index) => {
315+
if enable_normalize {
316+
node_ids.sort_by_key(|i| i.hash);
317+
}
318+
let node_id = node_ids
319+
.into_iter()
320+
.fold(None, |accum, item| Some(item.combine(accum)));
281321
return (down_index, node_id, is_valid);
282322
}
283323
VisitRecord::NodeItem(sub_node_id, sub_node_is_valid) => {
284-
node_id = Some(sub_node_id.combine(node_id));
324+
node_ids.push(sub_node_id);
285325
is_valid &= sub_node_is_valid;
286326
}
287327
}
@@ -290,8 +330,10 @@ impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_, 'n,
290330
}
291331
}
292332

293-
impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisitor<'n>
294-
for CSEVisitor<'_, 'n, N, C>
333+
impl<'n, N, C> TreeNodeVisitor<'n> for CSEVisitor<'_, 'n, N, C>
334+
where
335+
N: TreeNode + HashNode + NormalizeNode,
336+
C: CSEController<Node = N>,
295337
{
296338
type Node = N;
297339

@@ -331,7 +373,8 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
331373
}
332374

333375
fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
334-
let (down_index, sub_node_id, sub_node_is_valid) = self.pop_enter_mark();
376+
let (down_index, sub_node_id, sub_node_is_valid) =
377+
self.pop_enter_mark(node.enable_normalized());
335378

336379
let node_id = Identifier::new(node, self.random_state).combine(sub_node_id);
337380
let is_valid = C::is_valid(node) && sub_node_is_valid;
@@ -369,7 +412,11 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
369412
/// Rewrite a [`TreeNode`] tree by replacing detected common subtrees with the
370413
/// corresponding temporary [`TreeNode`], that column contains the evaluate result of
371414
/// replaced [`TreeNode`] tree.
372-
struct CSERewriter<'a, 'n, N, C: CSEController<Node = N>> {
415+
struct CSERewriter<'a, 'n, N, C>
416+
where
417+
N: NormalizeNode,
418+
C: CSEController<Node = N>,
419+
{
373420
/// statistics of [`TreeNode`]s
374421
node_stats: &'a NodeStats<'n, N>,
375422

@@ -386,8 +433,10 @@ struct CSERewriter<'a, 'n, N, C: CSEController<Node = N>> {
386433
controller: &'a mut C,
387434
}
388435

389-
impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter
390-
for CSERewriter<'_, '_, N, C>
436+
impl<N, C> TreeNodeRewriter for CSERewriter<'_, '_, N, C>
437+
where
438+
N: TreeNode + NormalizeNode,
439+
C: CSEController<Node = N>,
391440
{
392441
type Node = N;
393442

@@ -408,13 +457,30 @@ impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter
408457
self.down_index += 1;
409458
}
410459

411-
let (node, alias) =
412-
self.common_nodes.entry(node_id).or_insert_with(|| {
413-
let node_alias = self.controller.generate_alias();
414-
(node, node_alias)
415-
});
416-
417-
let rewritten = self.controller.rewrite(node, alias);
460+
// We *must* replace all original nodes with same `node_id`, not just the first
461+
// node which is inserted into the common_nodes. This is because nodes with the same
462+
// `node_id` are semantically equivalent, but not exactly the same.
463+
//
464+
// For example, `a + 1` and `1 + a` are semantically equivalent but not identical.
465+
// In this case, we should replace the common expression `1 + a` with a new variable
466+
// (e.g., `__common_cse_1`). So, `a + 1` and `1 + a` would both be replaced by
467+
// `__common_cse_1`.
468+
//
469+
// The final result would be:
470+
// - `__common_cse_1 as a + 1`
471+
// - `__common_cse_1 as 1 + a`
472+
//
473+
// This way, we can efficiently handle semantically equivalent expressions without
474+
// incorrectly treating them as identical.
475+
let rewritten = if let Some((_, alias)) = self.common_nodes.get(&node_id)
476+
{
477+
self.controller.rewrite(&node, alias)
478+
} else {
479+
let node_alias = self.controller.generate_alias();
480+
let rewritten = self.controller.rewrite(&node, &node_alias);
481+
self.common_nodes.insert(node_id, (node, node_alias));
482+
rewritten
483+
};
418484

419485
return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump));
420486
}
@@ -441,7 +507,11 @@ pub struct CSE<N, C: CSEController<Node = N>> {
441507
controller: C,
442508
}
443509

444-
impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C> {
510+
impl<N, C> CSE<N, C>
511+
where
512+
N: TreeNode + HashNode + Clone + NormalizeNode,
513+
C: CSEController<Node = N>,
514+
{
445515
pub fn new(controller: C) -> Self {
446516
Self {
447517
random_state: RandomState::new(),
@@ -557,6 +627,7 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
557627
) -> Result<FoundCommonNodes<N>> {
558628
let mut found_common = false;
559629
let mut node_stats = NodeStats::new();
630+
560631
let id_arrays_list = nodes_list
561632
.iter()
562633
.map(|nodes| {
@@ -596,7 +667,9 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
596667
#[cfg(test)]
597668
mod test {
598669
use crate::alias::AliasGenerator;
599-
use crate::cse::{CSEController, HashNode, IdArray, Identifier, NodeStats, CSE};
670+
use crate::cse::{
671+
CSEController, HashNode, IdArray, Identifier, NodeStats, NormalizeNode, CSE,
672+
};
600673
use crate::tree_node::tests::TestTreeNode;
601674
use crate::Result;
602675
use std::collections::HashSet;
@@ -662,6 +735,15 @@ mod test {
662735
}
663736
}
664737

738+
impl NormalizeNode for TestTreeNode<String> {
739+
fn normalize(&self) -> Self {
740+
self.clone()
741+
}
742+
fn enable_normalized(&self) -> bool {
743+
false
744+
}
745+
}
746+
665747
#[test]
666748
fn id_array_visitor() -> Result<()> {
667749
let alias_generator = AliasGenerator::new();

datafusion/expr/src/expr.rs

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use crate::{
3434
};
3535

3636
use arrow::datatypes::{DataType, FieldRef};
37-
use datafusion_common::cse::HashNode;
37+
use datafusion_common::cse::{HashNode, NormalizeNode};
3838
use datafusion_common::tree_node::{
3939
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
4040
};
@@ -1666,6 +1666,69 @@ impl Expr {
16661666
}
16671667
}
16681668

1669+
impl NormalizeNode for Expr {
1670+
fn enable_normalized(&self) -> bool {
1671+
#[allow(clippy::match_like_matches_macro)]
1672+
match self {
1673+
Expr::BinaryExpr(BinaryExpr {
1674+
op:
1675+
_op @ (Operator::Plus
1676+
| Operator::Multiply
1677+
| Operator::BitwiseAnd
1678+
| Operator::BitwiseOr
1679+
| Operator::BitwiseXor
1680+
| Operator::Eq
1681+
| Operator::NotEq),
1682+
..
1683+
}) => true,
1684+
_ => false,
1685+
}
1686+
}
1687+
1688+
fn normalize(&self) -> Expr {
1689+
match self {
1690+
Expr::BinaryExpr(BinaryExpr {
1691+
ref left,
1692+
ref op,
1693+
ref right,
1694+
}) => {
1695+
let normalized_left = left.normalize();
1696+
let normalized_right = right.normalize();
1697+
let new_binary = if matches!(
1698+
op,
1699+
Operator::Plus
1700+
| Operator::Multiply
1701+
| Operator::BitwiseAnd
1702+
| Operator::BitwiseOr
1703+
| Operator::BitwiseXor
1704+
| Operator::Eq
1705+
| Operator::NotEq
1706+
) {
1707+
let (l_expr, r_expr) =
1708+
if format!("{normalized_left}") < format!("{normalized_right}") {
1709+
(normalized_left, normalized_right)
1710+
} else {
1711+
(normalized_right, normalized_left)
1712+
};
1713+
BinaryExpr {
1714+
left: Box::new(l_expr),
1715+
op: *op,
1716+
right: Box::new(r_expr),
1717+
}
1718+
} else {
1719+
BinaryExpr {
1720+
left: Box::new(normalized_left),
1721+
op: *op,
1722+
right: Box::new(normalized_right),
1723+
}
1724+
};
1725+
Expr::BinaryExpr(new_binary)
1726+
}
1727+
other => other.clone(),
1728+
}
1729+
}
1730+
}
1731+
16691732
impl HashNode for Expr {
16701733
/// As it is pretty easy to forget changing this method when `Expr` changes the
16711734
/// implementation doesn't use wildcard patterns (`..`, `_`) to catch changes

0 commit comments

Comments
 (0)