From 619a686b39e42820df429d66bf74de3545606593 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 14 May 2026 21:50:28 +0000 Subject: [PATCH] [SPARK-56869][SQL] Speed up TreeNode transforms when rule doesn't match MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Gate `CurrentOrigin.withOrigin(origin) { rule.applyOrElse(...) }` in four `TreeNode` rule-driven transform methods behind `rule.isDefinedAt(...)`, so the ThreadLocal wrap only runs when the rule actually fires: - `TreeNode.transformDownWithPruning` - `TreeNode.transformUpWithPruning` - `TreeNode.transformUpWithBeforeAndAfterRuleOnChildren` - `TreeNode.multiTransformDownWithPruning` (also drops a side-effecting default closure that became unnecessary) ### Why are the changes needed? `CurrentOrigin.withOrigin` is only observable when the rule constructs new nodes — they pick up `CurrentOrigin.get` in their `override val origin` field. On nodes the rule doesn't match, the wrap is pure overhead: two `ThreadLocal` writes plus a `try`/`finally` per node visit. JFR profiling (60s sample, 1.1M iterations of `transformDown` over a 1024-leaf balanced `Add` tree with a non-matching rule) shows: - 66% of CPU samples in `ThreadLocalMap.set` (line 486) - 13% in `ThreadLocalMap.getEntryAfterMiss` - 9% more in `ThreadLocalMap.set` (line 493) Total: ~88% of transform CPU spent inside `CurrentOrigin.withOrigin` for nodes the rule never matched. Microbenchmark (best time per N iterations, JDK 17, Xeon 8175M @ 2.50GHz, baseline = `upstream/master`): | case | baseline | optimized | speedup | |--------------------------------------------|---------:|----------:|--------:| | transformDown deep chain(5000) no-op | 12 ms | 6 ms | 2.0x | | transformDown deep chain(5000) rewrite leaf| 20109 ms| 15850 ms | 1.27x | | transformDown wide(100) no-op | 5 ms | 3 ms | 1.7x | | transformDown balanced(1024) no-op | 7 ms | 2 ms | 3.5x | | transformDown balanced(4096) no-op | 34 ms | 15 ms | 2.3x | | transformUp deep chain(1000) no-op | 3 ms | 1 ms | 3.0x | | transformUp deep chain(5000) no-op | 15 ms | 6 ms | 2.5x | | transformUp balanced(1024) no-op | 8 ms | 3 ms | 2.7x | | transformUp balanced(4096) no-op | 39 ms | 25 ms | 1.6x | Rewrite-heavy cases are unchanged because `withOrigin` still runs when the rule fires. Real Spark workloads (analyzer/optimizer batches running many rules across many nodes, each rule matching a small subset) are dominated by the no-match case, so the savings compound. ### Does this PR introduce _any_ user-facing change? No. The change is internal to `TreeNode` rule machinery and preserves all observable semantics: - `CurrentOrigin` is still set before any rule body that constructs new nodes runs. - `markRuleAsIneffective` / `isRuleIneffective` bookkeeping unchanged. - `copyTagsFrom` ordering on rule-replacement unchanged. - `fastEquals` short-circuit unchanged. - Result identity (`this eq result` on a no-op transform) preserved. ### How was this patch tested? - `build/sbt 'catalyst/testOnly *TreeNodeSuite'` — 36/36 pass. - `build/sbt 'catalyst/test'` — 9272/9272 pass across 352 suites (~7 min). ### Was this patch authored or co-authored using generative AI tooling? Generated-by: Claude Code (Anthropic Claude Opus 4.7) Co-authored-by: Isaac --- .../spark/sql/catalyst/trees/TreeNode.scala | 70 +++++++++++++------ 1 file changed, 47 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index e82e6a30b9bba..362c7d592295f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -491,12 +491,29 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] if (!cond.apply(this) || isRuleIneffective(ruleId)) { return this } - val afterRule = CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(this, identity[BaseType]) - } - - // Check if unchanged and then possibly return old copy to avoid gc churn. - if (this fastEquals afterRule) { + // CurrentOrigin.withOrigin is only observable when the rule constructs new nodes + // (they pick up origin via `override val origin = CurrentOrigin.get`). When the rule + // doesn't fire, the wrapping is pure ThreadLocal thrash — profiling shows ~88% of + // transform CPU goes to ThreadLocal.set / get on no-match nodes. Skip it. + if (rule.isDefinedAt(this)) { + val afterRule = CurrentOrigin.withOrigin(origin) { + rule.apply(this) + } + // Check if unchanged and then possibly return old copy to avoid gc churn. + if (this fastEquals afterRule) { + val rewritten_plan = mapChildren(_.transformDownWithPruning(cond, ruleId)(rule)) + if (this eq rewritten_plan) { + markRuleAsIneffective(ruleId) + this + } else { + rewritten_plan + } + } else { + // If the transform function replaces this node with a new one, carry over the tags. + afterRule.copyTagsFrom(this) + afterRule.mapChildren(_.transformDownWithPruning(cond, ruleId)(rule)) + } + } else { val rewritten_plan = mapChildren(_.transformDownWithPruning(cond, ruleId)(rule)) if (this eq rewritten_plan) { markRuleAsIneffective(ruleId) @@ -504,10 +521,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] } else { rewritten_plan } - } else { - // If the transform function replaces this node with a new one, carry over the tags. - afterRule.copyTagsFrom(this) - afterRule.mapChildren(_.transformDownWithPruning(cond, ruleId)(rule)) } } @@ -544,14 +557,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] return this } val afterRuleOnChildren = mapChildren(_.transformUpWithPruning(cond, ruleId)(rule)) - val newNode = if (this fastEquals afterRuleOnChildren) { + // Skip the CurrentOrigin.withOrigin wrap when the rule doesn't match — see + // transformDownWithPruning above for the same optimization. + val target = if (this fastEquals afterRuleOnChildren) this else afterRuleOnChildren + val newNode = if (rule.isDefinedAt(target)) { CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(this, identity[BaseType]) + rule.apply(target) } } else { - CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(afterRuleOnChildren, identity[BaseType]) - } + target } if (this eq newNode) { markRuleAsIneffective(ruleId) @@ -588,8 +602,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] } val afterRuleOnChildren = mapChildren(_.transformUpWithBeforeAndAfterRuleOnChildren(cond, ruleId)(rule)) - val newNode = CurrentOrigin.withOrigin(origin) { - rule.applyOrElse((this, afterRuleOnChildren), { t: (BaseType, BaseType) => t._2 }) + // Skip the CurrentOrigin.withOrigin wrap when the rule doesn't match — see + // transformDownWithPruning above. The default would return afterRuleOnChildren anyway. + val key = (this, afterRuleOnChildren) + val newNode = if (rule.isDefinedAt(key)) { + CurrentOrigin.withOrigin(origin) { + rule.apply(key) + } + } else { + afterRuleOnChildren } if (this eq newNode) { this.markRuleAsIneffective(ruleId) @@ -685,12 +706,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] // alternatives. I.e. the "multiTransformDown is lazy" test case in `TreeNodeSuite` would fail. // Please note that this behaviour has a downside as well that we can only mark the rule on the // original node ineffective if the rule didn't match. - var ruleApplied = true - val afterRules = CurrentOrigin.withOrigin(origin) { - rule.applyOrElse(this, (_: BaseType) => { - ruleApplied = false - Seq.empty - }) + // Skip the CurrentOrigin.withOrigin wrap when the rule doesn't match — see + // transformDownWithPruning above. Also lets us drop the side-effecting default. + val ruleApplied = rule.isDefinedAt(this) + val afterRules: Seq[BaseType] = if (ruleApplied) { + CurrentOrigin.withOrigin(origin) { + rule.apply(this) + } + } else { + Seq.empty } val afterRulesLazyList = if (afterRules.isEmpty) {