From 359a1a0ce2f787e9e3674e26359b7694a130b7c6 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 25 May 2026 12:04:04 +0200 Subject: [PATCH] [SPARK-57038][SQL] Use `Expression.references` in SPJ planning MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? - Document `AttributeSet`'s iteration-order contract on the class scaladoc: iteration via `iterator` / `foreach` / `flatMap` visits elements in insertion order (driven by the underlying `LinkedHashSet`). `toSeq` is called out as the explicit exception — it sorts by `(name, exprId.id)` for codegen stability (SPARK-18394). - Migrate seven uses of `_.collectLeaves()` in the storage-partitioned join (SPJ) planning code in `partitioning.scala` and `EnsureRequirements.scala` to `_.references` / `AttributeSet.fromAttributeSets(...)`. Drop the now-redundant `.map(_.asInstanceOf[Attribute])` cast at the `EnsureRequirements:89` site. - Update `EnsureRequirementsSuite` synthetic fixtures (`exprA..D`) from `Literal(1..4)` to `AttributeReference`s. The literals were stand-ins for columns; under the migration's `_.references`-based attribute extraction, literal children produce empty `AttributeSet`s and trip the planner's "exactly one attribute per partition expression" assertions. Real partitionings can't reach those assertions with literal-only transforms because `KeyedPartitioning.supportsExpressions`'s `isReference` check filters them out — so this is a fixture-only update that preserves test intent. ### Why are the changes needed? `TreeNode.collectLeaves()` returns every node in the tree where `children.isEmpty`, including `Literal`s. SPJ planning has always wanted attributes only, but with the existing partition expression layout (`TransformExpression.children = [col]`, parameters carried in a sidecar `numBucketsOpt: Option[Int]` field), the difference didn't surface. Follow-up work (e.g. 
SPARK-50593 / #55885) that puts literal parameters directly into `TransformExpression.children` (`bucket(Literal(numBuckets), col)`, `truncate(col, Literal(width))`) would otherwise force `TransformExpression` to override `collectLeaves` to filter literals, breaking the universal `TreeNode.collectLeaves` contract for one expression type. `Expression.references` already returns attributes only (filtering literals and other non-attribute leaves), and its insertion-ordered iteration is exactly what positional binding (`RowOrdering.create`, `reorder`, `attributes.zip(clustering)`) requires. The per-partition-expression single-column rule (enforced by `KeyedPartitioning.supportsExpressions`) ensures that deduplication within a single expression never matters here. Documenting the iteration-order contract lets these call sites rely on the order without implicitly depending on an implementation detail. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Covered by existing test suites that exercise the migrated call sites: `KeyGroupedPartitioningSuite`, `EnsureRequirementsSuite`, `ProjectedOrderingAndPartitioningSuite`. ### Was this patch authored or co-authored using generative AI tooling? 
Generated-by: Claude Opus 4.7 --- .../sql/catalyst/expressions/AttributeSet.scala | 4 ++++ .../catalyst/plans/physical/partitioning.scala | 10 +++++----- .../execution/exchange/EnsureRequirements.scala | 17 ++++++++++++----- .../exchange/EnsureRequirementsSuite.scala | 8 ++++---- 4 files changed, 25 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index 236380b2c030b..d958cba27933e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -65,6 +65,10 @@ object AttributeSet { * `AttributeReference("a"...) == AttributeReference("b", ...)`. This tactic leads to broken tests, * and also makes doing transformations hard (we always try keep older trees instead of new ones * when the transformation was a no-op). + * + * Iteration via [[iterator]], [[foreach]], or [[Iterable]]-derived combinators (`flatMap`, etc.) + * visits elements in insertion order. Note: [[toSeq]] is an explicit exception -- it sorts by + * `(name, exprId.id)` for stable codegen output, see SPARK-18394. 
*/ class AttributeSet private (private val baseSet: mutable.LinkedHashSet[AttributeEquals]) extends Iterable[Attribute] with Serializable { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index e92bb0f7c0d69..aeacdaec7a8de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -578,13 +578,13 @@ case class KeyedPartitioning( c.areAllClusterKeysMatched(expressions) } else { // We'll need to find leaf attributes from the partition expressions first. - lazy val attributes = expressions.flatMap(_.collectLeaves()) + lazy val attributes = AttributeSet.fromAttributeSets(expressions.map(_.references)) if (SQLConf.get.v2BucketingAllowKeysSubsetOfPartitionKeys) { // check that operation keys (required clustering keys) // overlap with partition keys (KeyedPartitioning attributes) requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) && - expressions.forall(_.collectLeaves().size == 1) + expressions.forall(_.references.size == 1) } else if (isNarrowed && !isGrouped) { // A narrowed, non-grouped partitioning carries the same skew risk as using a subset of // partition keys for a join: GroupPartitionsExec will merge partitions that held @@ -1218,9 +1218,9 @@ case class KeyedShuffleSpec( distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos) } partitioning.expressions.map { e => - val leaves = e.collectLeaves() - assert(leaves.size == 1, s"Expected exactly one child from $e, but found ${leaves.size}") - distKeyToPos.getOrElse(leaves.head.canonicalized, mutable.BitSet.empty) + val refs = e.references + assert(refs.size == 1, s"Expected exactly one child from $e, but found ${refs.size}") + 
distKeyToPos.getOrElse(refs.head.canonicalized, mutable.BitSet.empty) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index c5a08e983e610..c632b3d841e61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -86,8 +86,9 @@ case class EnsureRequirements( // Find any KeyedPartitioning that satisfies via groupedSatisfies. val satisfyingKeyedPartitioning = groupedSatisfies.orElse(nonGroupedSatisfiesWhenGrouped).get - val attrs = satisfyingKeyedPartitioning.expressions.flatMap(_.collectLeaves()) - .map(_.asInstanceOf[Attribute]) + // The single-column invariant in KeyedPartitioning.supportsExpressions guarantees + // one attribute per partition expression. + val attrs = satisfyingKeyedPartitioning.expressions.flatMap(_.references) val keyRowOrdering = RowOrdering.create(o.ordering, attrs) val keyOrdering = keyRowOrdering.on((t: InternalRowComparableWrapper) => t.row) if (satisfyingKeyedPartitioning.partitionKeys.sliding(2).forall { @@ -409,12 +410,16 @@ case class EnsureRequirements( .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, leftPartitioning, None)) case (Some(KeyedPartitioning(clustering, _, _, _)), _) => - val leafExprs = clustering.flatMap(_.collectLeaves()) + // The single-column invariant in KeyedPartitioning.supportsExpressions guarantees one + // attribute per partition expression. 
+ val leafExprs = clustering.flatMap(_.references) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, None, rightPartitioning)) case (_, Some(KeyedPartitioning(clustering, _, _, _))) => - val leafExprs = clustering.flatMap(_.collectLeaves()) + // The single-column invariant in KeyedPartitioning.supportsExpressions guarantees one + // attribute per partition expression. + val leafExprs = clustering.flatMap(_.references) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, leftPartitioning, None)) @@ -777,7 +782,9 @@ case class EnsureRequirements( partitioning: Partitioning, distribution: ClusteredDistribution): Option[KeyedShuffleSpec] = { def tryCreate(partitioning: KeyedPartitioning): Option[KeyedShuffleSpec] = { - val attributes = partitioning.expressions.flatMap(_.collectLeaves()) + // The single-column invariant in KeyedPartitioning.supportsExpressions guarantees one + // attribute per partition expression. 
+ val attributes = partitioning.expressions.flatMap(_.references) val clustering = distribution.clustering val satisfies = if (SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 74b706bce34f1..17d00ec055e07 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -40,10 +40,10 @@ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class EnsureRequirementsSuite extends SharedSparkSession { - private val exprA = Literal(1) - private val exprB = Literal(2) - private val exprC = Literal(3) - private val exprD = Literal(4) + private val exprA = AttributeReference("a", IntegerType)() + private val exprB = AttributeReference("b", IntegerType)() + private val exprC = AttributeReference("c", IntegerType)() + private val exprD = AttributeReference("d", IntegerType)() private val EnsureRequirements = new EnsureRequirements()