@@ -50,33 +50,53 @@ impl<T: HashNode + ?Sized> HashNode for Arc<T> {
5050 }
5151}
5252
53+ /// A trait that defines how to normalize a node.
54+ ///
55+ /// This trait is used to normalize nodes before comparing them for CSE. Normalization
56+ /// can be used to ensure that two nodes that are semantically equivalent are considered
57+ /// equal for CSE.
58+ /// For example: `a + b` and `b + a` are semantically equivalent.
59+ pub trait NormalizeNode : Eq {
60+ fn normalize ( & self ) -> Self ; // returns the normalized form; `Identifier` equality compares these forms
61+ fn enable_normalized ( & self ) -> bool ; // when true, the CSE visitor sorts child identifiers by hash before combining them
62+ }
63+
5364/// Identifier that represents a [`TreeNode`] tree.
5465///
5566/// This identifier is designed to be efficient to hash, to accumulate (combine) and to
5667/// compare for equality, with as few hash collisions as possible.
57- #[ derive( Debug , Eq , PartialEq ) ]
58- struct Identifier < ' n , N > {
68+ #[ derive( Debug , Eq ) ]
69+ struct Identifier < ' n , N : NormalizeNode > {
5970 // Hash of `node` built up incrementally during the first, visiting traversal.
6071 // Its value is not necessarily equal to the default hash of the node. E.g. it is not
6172 // equal to `expr.hash()` if the node is `Expr`.
6273 hash : u64 ,
6374 node : & ' n N ,
6475}
6576
66- impl < N > Clone for Identifier < ' _ , N > {
77+ impl < N : NormalizeNode > Clone for Identifier < ' _ , N > {
6778 fn clone ( & self ) -> Self {
6879 * self // copies the `u64` hash and the `&N` reference via the `Copy` impl; no `N: Clone` bound is involved
6980 }
7081}
71- impl < N > Copy for Identifier < ' _ , N > { }
82+ impl < N : NormalizeNode > Copy for Identifier < ' _ , N > { } // both fields (`u64` and `&N`) are `Copy`, so no `N: Copy` bound is declared
7283
73- impl < N > Hash for Identifier < ' _ , N > {
84+ impl < N : NormalizeNode > Hash for Identifier < ' _ , N > {
7485 fn hash < H : Hasher > ( & self , state : & mut H ) {
7586 state. write_u64 ( self . hash ) ; // only the precomputed `hash` field is written; `node` is not hashed again here
7687 }
7788}
7889
79- impl < ' n , N : HashNode > Identifier < ' n , N > {
90+ impl < N : NormalizeNode > PartialEq for Identifier < ' _ , N > {
91+ fn eq ( & self , other : & Self ) -> bool {
92+ self . node . normalize ( ) == other. node . normalize ( )
93+ }
94+ }
95+
96+ impl < ' n , N > Identifier < ' n , N >
97+ where
98+ N : HashNode + NormalizeNode ,
99+ {
80100 fn new ( node : & ' n N , random_state : & RandomState ) -> Self {
81101 let mut hasher = random_state. build_hasher ( ) ;
82102 node. hash_node ( & mut hasher) ;
@@ -213,7 +233,11 @@ pub enum FoundCommonNodes<N> {
213233///
214234/// A [`TreeNode`] without any children (column, literal etc.) will not have identifier
215235/// because they should not be recognized as common subtree.
216- struct CSEVisitor < ' a , ' n , N , C : CSEController < Node = N > > {
236+ struct CSEVisitor < ' a , ' n , N , C >
237+ where
238+ N : NormalizeNode ,
239+ C : CSEController < Node = N > ,
240+ {
217241 /// statistics of [`TreeNode`]s
218242 node_stats : & ' a mut NodeStats < ' n , N > ,
219243
@@ -244,7 +268,10 @@ struct CSEVisitor<'a, 'n, N, C: CSEController<Node = N>> {
244268}
245269
246270/// Record item that used when traversing a [`TreeNode`] tree.
247- enum VisitRecord < ' n , N > {
271+ enum VisitRecord < ' n , N >
272+ where
273+ N : NormalizeNode ,
274+ {
248275 /// Marks the beginning of [`TreeNode`]. It contains:
249276 /// - The post-order index assigned during the first, visiting traversal.
250277 EnterMark ( usize ) ,
@@ -258,7 +285,11 @@ enum VisitRecord<'n, N> {
258285 NodeItem ( Identifier < ' n , N > , bool ) ,
259286}
260287
261- impl < ' n , N : TreeNode + HashNode , C : CSEController < Node = N > > CSEVisitor < ' _ , ' n , N , C > {
288+ impl < ' n , N , C > CSEVisitor < ' _ , ' n , N , C >
289+ where
290+ N : TreeNode + HashNode + NormalizeNode ,
291+ C : CSEController < Node = N > ,
292+ {
262293 /// Find the first `EnterMark` in the stack, and accumulates every `NodeItem` before
263294 /// it. Returns a tuple that contains:
264295 /// - The pre-order index of the [`TreeNode`] we marked.
@@ -271,17 +302,26 @@ impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_, 'n,
271302 /// information up from children to parents via `visit_stack` during the first,
272303 /// visiting traversal and no need to test the expression's validity beforehand with
273304 /// an extra traversal).
274- fn pop_enter_mark ( & mut self ) -> ( usize , Option < Identifier < ' n , N > > , bool ) {
275- let mut node_id = None ;
305+ fn pop_enter_mark (
306+ & mut self ,
307+ enable_normalize : bool ,
308+ ) -> ( usize , Option < Identifier < ' n , N > > , bool ) {
309+ let mut node_ids: Vec < Identifier < ' n , N > > = vec ! [ ] ;
276310 let mut is_valid = true ;
277311
278312 while let Some ( item) = self . visit_stack . pop ( ) {
279313 match item {
280314 VisitRecord :: EnterMark ( down_index) => {
315+ if enable_normalize {
316+ node_ids. sort_by_key ( |i| i. hash ) ;
317+ }
318+ let node_id = node_ids
319+ . into_iter ( )
320+ . fold ( None , |accum, item| Some ( item. combine ( accum) ) ) ;
281321 return ( down_index, node_id, is_valid) ;
282322 }
283323 VisitRecord :: NodeItem ( sub_node_id, sub_node_is_valid) => {
284- node_id = Some ( sub_node_id . combine ( node_id ) ) ;
324+ node_ids . push ( sub_node_id ) ;
285325 is_valid &= sub_node_is_valid;
286326 }
287327 }
@@ -290,8 +330,10 @@ impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_, 'n,
290330 }
291331}
292332
293- impl < ' n , N : TreeNode + HashNode + Eq , C : CSEController < Node = N > > TreeNodeVisitor < ' n >
294- for CSEVisitor < ' _ , ' n , N , C >
333+ impl < ' n , N , C > TreeNodeVisitor < ' n > for CSEVisitor < ' _ , ' n , N , C >
334+ where
335+ N : TreeNode + HashNode + NormalizeNode ,
336+ C : CSEController < Node = N > ,
295337{
296338 type Node = N ;
297339
@@ -331,7 +373,8 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
331373 }
332374
333375 fn f_up ( & mut self , node : & ' n Self :: Node ) -> Result < TreeNodeRecursion > {
334- let ( down_index, sub_node_id, sub_node_is_valid) = self . pop_enter_mark ( ) ;
376+ let ( down_index, sub_node_id, sub_node_is_valid) =
377+ self . pop_enter_mark ( node. enable_normalized ( ) ) ;
335378
336379 let node_id = Identifier :: new ( node, self . random_state ) . combine ( sub_node_id) ;
337380 let is_valid = C :: is_valid ( node) && sub_node_is_valid;
@@ -369,7 +412,11 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
369412/// Rewrite a [`TreeNode`] tree by replacing detected common subtrees with the
370413/// corresponding temporary [`TreeNode`], that column contains the evaluate result of
371414/// replaced [`TreeNode`] tree.
372- struct CSERewriter < ' a , ' n , N , C : CSEController < Node = N > > {
415+ struct CSERewriter < ' a , ' n , N , C >
416+ where
417+ N : NormalizeNode ,
418+ C : CSEController < Node = N > ,
419+ {
373420 /// statistics of [`TreeNode`]s
374421 node_stats : & ' a NodeStats < ' n , N > ,
375422
@@ -386,8 +433,10 @@ struct CSERewriter<'a, 'n, N, C: CSEController<Node = N>> {
386433 controller : & ' a mut C ,
387434}
388435
389- impl < N : TreeNode + Eq , C : CSEController < Node = N > > TreeNodeRewriter
390- for CSERewriter < ' _ , ' _ , N , C >
436+ impl < N , C > TreeNodeRewriter for CSERewriter < ' _ , ' _ , N , C >
437+ where
438+ N : TreeNode + NormalizeNode ,
439+ C : CSEController < Node = N > ,
391440{
392441 type Node = N ;
393442
@@ -408,13 +457,30 @@ impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter
408457 self . down_index += 1 ;
409458 }
410459
411- let ( node, alias) =
412- self . common_nodes . entry ( node_id) . or_insert_with ( || {
413- let node_alias = self . controller . generate_alias ( ) ;
414- ( node, node_alias)
415- } ) ;
416-
417- let rewritten = self . controller . rewrite ( node, alias) ;
460+ // We *must* replace all original nodes with the same `node_id`, not just the first
461+ // node that was inserted into `common_nodes`. This is because nodes with the same
462+ // `node_id` are semantically equivalent, but not necessarily identical.
463+ //
464+ // For example, `a + 1` and `1 + a` are semantically equivalent but not identical.
465+ // In this case, we should replace the common expression `1 + a` with a new variable
466+ // (e.g., `__common_cse_1`). So, `a + 1` and `1 + a` would both be replaced by
467+ // `__common_cse_1`.
468+ //
469+ // The final result would be:
470+ // - `__common_cse_1 as a + 1`
471+ // - `__common_cse_1 as 1 + a`
472+ //
473+ // This way, we can efficiently handle semantically equivalent expressions without
474+ // incorrectly treating them as identical.
475+ let rewritten = if let Some ( ( _, alias) ) = self . common_nodes . get ( & node_id)
476+ {
477+ self . controller . rewrite ( & node, alias)
478+ } else {
479+ let node_alias = self . controller . generate_alias ( ) ;
480+ let rewritten = self . controller . rewrite ( & node, & node_alias) ;
481+ self . common_nodes . insert ( node_id, ( node, node_alias) ) ;
482+ rewritten
483+ } ;
418484
419485 return Ok ( Transformed :: new ( rewritten, true , TreeNodeRecursion :: Jump ) ) ;
420486 }
@@ -441,7 +507,11 @@ pub struct CSE<N, C: CSEController<Node = N>> {
441507 controller : C ,
442508}
443509
444- impl < N : TreeNode + HashNode + Clone + Eq , C : CSEController < Node = N > > CSE < N , C > {
510+ impl < N , C > CSE < N , C >
511+ where
512+ N : TreeNode + HashNode + Clone + NormalizeNode ,
513+ C : CSEController < Node = N > ,
514+ {
445515 pub fn new ( controller : C ) -> Self {
446516 Self {
447517 random_state : RandomState :: new ( ) ,
@@ -557,6 +627,7 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
557627 ) -> Result < FoundCommonNodes < N > > {
558628 let mut found_common = false ;
559629 let mut node_stats = NodeStats :: new ( ) ;
630+
560631 let id_arrays_list = nodes_list
561632 . iter ( )
562633 . map ( |nodes| {
@@ -596,7 +667,9 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
596667#[ cfg( test) ]
597668mod test {
598669 use crate :: alias:: AliasGenerator ;
599- use crate :: cse:: { CSEController , HashNode , IdArray , Identifier , NodeStats , CSE } ;
670+ use crate :: cse:: {
671+ CSEController , HashNode , IdArray , Identifier , NodeStats , NormalizeNode , CSE ,
672+ } ;
600673 use crate :: tree_node:: tests:: TestTreeNode ;
601674 use crate :: Result ;
602675 use std:: collections:: HashSet ;
@@ -662,6 +735,15 @@ mod test {
662735 }
663736 }
664737
738+ impl NormalizeNode for TestTreeNode < String > {
739+ fn normalize ( & self ) -> Self {
740+ self . clone ( ) // the test node is returned unchanged as its own normalized form
741+ }
742+ fn enable_normalized ( & self ) -> bool {
743+ false // `f_up` passes this flag to `pop_enter_mark`, so child identifiers are not sorted for test nodes
744+ }
745+ }
746+
665747 #[ test]
666748 fn id_array_visitor ( ) -> Result < ( ) > {
667749 let alias_generator = AliasGenerator :: new ( ) ;
0 commit comments