From 7a9469083e4044306943a36a5c9cbb0a922db108 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 11 Feb 2026 19:42:25 +0100 Subject: [PATCH 01/29] Replace `KeyGroupedPartitioning` with `KeyedPartitioning`, add new `GroupPartitionsExec` operator, remove old code --- .../scala/org/apache/spark/Partitioner.scala | 2 +- .../plans/physical/partitioning.scala | 235 ++++++++---- ...nternalRowComparableWrapperBenchmark.scala | 8 +- .../execution/KeyGroupedPartitionedScan.scala | 184 --------- .../datasources/v2/BatchScanExec.scala | 79 ++-- .../datasources/v2/ContinuousScanExec.scala | 4 +- .../datasources/v2/DataSourceRDD.scala | 92 ++--- .../v2/DataSourceV2ScanExecBase.scala | 99 +---- .../datasources/v2/DataSourceV2Strategy.scala | 4 +- .../datasources/v2/GroupPartitionsExec.scala | 221 +++++++++++ .../exchange/EnsureRequirements.scala | 234 ++++++------ .../exchange/ShuffleExchangeExec.scala | 12 +- .../joins/StoragePartitionJoinParams.scala | 48 --- .../DistributionAndOrderingSuiteBase.scala | 6 +- .../KeyGroupedPartitioningSuite.scala | 356 +++++++++++------- .../exchange/EnsureRequirementsSuite.scala | 215 ++++++----- 16 files changed, 949 insertions(+), 850 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/StoragePartitionJoinParams.scala diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 7d086f34f6983..7dd6015b2f87e 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -142,7 +142,7 @@ private[spark] class PartitionIdPassthrough(override val numPartitions: Int) ext /** * A [[org.apache.spark.Partitioner]] that partitions all records using partition value map. * The `valueMap` is a map that contains tuples of (partition value, partition id). It is generated - * by [[org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning]], used to partition + * by [[org.apache.spark.sql.catalyst.plans.physical.KeyedPartitioning]], used to partition * the other side of a join to make sure records with same partition value are in the same * partition. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index b0fa4f889cda1..c999a5cb7b938 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.physical +import java.util.Objects + import scala.annotation.tailrec import scala.collection.mutable @@ -346,43 +348,113 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa } /** - * Represents a partitioning where rows are split across partitions based on transforms defined - * by `expressions`. `partitionValues`, if defined, should contain value of partition key(s) in - * ascending order, after evaluated by the transforms in `expressions`, for each input partition. - * In addition, its length must be the same as the number of Spark partitions (and thus is a 1-1 - * mapping), and each row in `partitionValues` must be unique. + * Represents a partitioning where rows are split across partitions based on transforms defined by + * `expressions`. + * + * == Partition Keys == + * This partitioning has two sets of partition keys: + * + * - `partitionKeys`: The current partition key for each partition, in ascending order. May contain + * duplicates when first created from a data source, but becomes unique after grouping. + * + * - `originalPartitionKeys`: The original partition keys from the data source, in ascending order. + * Always preserves the original values, even after grouping. Used to track the original + * distribution for optimization purposes. + * + * == Grouping State == + * A KeyedPartitioning can be in two states: + * + * - '''Ungrouped''' (when `isGrouped == false`): `partitionKeys` contains duplicates. Multiple + * input partitions share the same key. This is the initial state when created from a data source. + * + * - '''Grouped''' (when `isGrouped == true`): `partitionKeys` contains only unique values. Each + * partition has a distinct key. This state is achieved by applying `GroupPartitionsExec`, which + * coalesces partitions with the same key. * - * The `originalPartitionValues`, on the other hand, are partition values from the original input - * splits returned by data sources. It may contain duplicated values. + * == Example == + * Consider a data source with partition transform `[years(ts_col)]` and 4 input splits: * - * For example, if a data source reports partition transform expressions `[years(ts_col)]` with 4 - * input splits whose corresponding partition values are `[0, 1, 2, 2]`, then the `expressions` - * in this case is `[years(ts_col)]`, while `partitionValues` is `[0, 1, 2]`, which - * represents 3 input partitions with distinct partition values. All rows in each partition have - * the same value for column `ts_col` (which is of timestamp type), after being applied by the - * `years` transform. This is generated after combining the two splits with partition value `2` - * into a single Spark partition. + * '''Before GroupPartitionsExec''' (ungrouped): + * {{{ + * expressions: [years(ts_col)] + * partitionKeys: [0, 1, 2, 2] // partition 2 and 3 have the same key + * originalPartitionKeys: [0, 1, 2, 2] + * numPartitions: 4 + * isGrouped: false + * }}} * - * On the other hand, in this example `[0, 1, 2, 2]` is the value of `originalPartitionValues` - * which is calculated from the original input splits. + * '''After GroupPartitionsExec''' (grouped): + * {{{ + * expressions: [years(ts_col)] + * partitionKeys: [0, 1, 2] // duplicates removed, partitions coalesced + * originalPartitionKeys: [0, 1, 2, 2] // unchanged, preserves original distribution + * numPartitions: 3 + * isGrouped: true + * }}} * - * @param expressions partition expressions for the partitioning. - * @param numPartitions the number of partitions - * @param partitionValues the values for the final cluster keys (that is, after applying grouping - * on the input splits according to `expressions`) of the distribution, - * must be in ascending order, and must NOT contain duplicated values. - * @param originalPartitionValues the original input partition values before any grouping has been - * applied, must be in ascending order, and may contain duplicated - * values + * @param expressions Partition transform expressions (e.g., `years(col)`, `bucket(10, col)`). + * @param partitionKeys Current partition keys, one per partition, in ascending order. + * May contain duplicates before grouping. + * @param originalPartitionKeys Original partition keys from the data source, in ascending order. + * Preserves the initial distribution even after grouping. */ -case class KeyGroupedPartitioning( +case class KeyedPartitioning( expressions: Seq[Expression], - numPartitions: Int, - partitionValues: Seq[InternalRow] = Seq.empty, - originalPartitionValues: Seq[InternalRow] = Seq.empty) extends HashPartitioningLike { + partitionKeys: Seq[InternalRow], + originalPartitionKeys: Seq[InternalRow]) extends Expression with Partitioning with Unevaluable { + override val numPartitions = partitionKeys.length + + override def children: Seq[Expression] = expressions + override def nullable: Boolean = false + override def dataType: DataType = IntegerType + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): KeyedPartitioning = + copy(expressions = newChildren) + + @transient private lazy val dataTypes: Seq[DataType] = expressions.map(_.dataType) + + @transient private lazy val comparableWrapperFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) + + @transient private lazy val rowOrdering = RowOrdering.createNaturalAscendingOrdering(dataTypes) + + @transient lazy val isGrouped: Boolean = { + partitionKeys.map(comparableWrapperFactory).distinct.size == partitionKeys.size + } + + def toGrouped: KeyedPartitioning = { + val groupedPartitions = partitionKeys + .map(comparableWrapperFactory) + .distinct + .map(_.row) + .sorted(rowOrdering) + + KeyedPartitioning(expressions, groupedPartitions, originalPartitionKeys) + } + + def projectAndGroup(positions: Seq[Int]): KeyedPartitioning = { + val projectedExpressions = positions.map(expressions) + val projectedDataTypes = projectedExpressions.map(_.dataType) + val projectedPartitionKeys = partitionKeys.map( + KeyedPartitioning.projectKey(_, positions, projectedDataTypes) + ) + val projectedOriginalPartitionKeys = originalPartitionKeys.map( + KeyedPartitioning.projectKey(_, positions, projectedDataTypes) + ) + val internalRowComparableWrapperFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(projectedDataTypes) + val distinctPartitionKeys = projectedPartitionKeys + .map(internalRowComparableWrapperFactory) + .distinct + .map(_.row) + + copy(expressions = projectedExpressions, partitionKeys = distinctPartitionKeys, + originalPartitionKeys = projectedOriginalPartitionKeys) + } override def satisfies0(required: Distribution): Boolean = { - super.satisfies0(required) || { + super.satisfies0(required) || isGrouped && { required match { case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => if (requireAllClusterKeys) { @@ -395,9 +467,9 @@ case class KeyGroupedPartitioning( if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { // check that join keys (required clustering keys) - // overlap with partition keys (KeyGroupedPartitioning attributes) + // overlap with partition keys (KeyedPartitioning attributes) requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) && - expressions.forall(_.collectLeaves().size == 1) + expressions.forall(_.collectLeaves().size == 1) } else { attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) } @@ -416,63 +488,37 @@ case class KeyGroupedPartitioning( val result = KeyGroupedShuffleSpec(this, distribution) if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { // If allowing join keys to be subset of clustering keys, we should create a new - // `KeyGroupedPartitioning` here that is grouped on the join keys instead, and use that as + // `KeyedPartitioning` here that is grouped on the join keys instead, and use that as // the returned shuffle spec. val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2) - val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions, - partitionValues, originalPartitionValues) + val projectedPartitioning = projectAndGroup(joinKeyPositions) result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions)) } else { result } } - lazy val uniquePartitionValues: Seq[InternalRow] = { - val internalRowComparableFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( - expressions.map(_.dataType)) - partitionValues - .map(internalRowComparableFactory) - .distinct - .map(_.row) - } - - override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = - copy(expressions = newChildren) -} + override def equals(that: Any): Boolean = that match { + case k: KeyedPartitioning if this.expressions == k.expressions => + def keysEqual(keys1: Seq[InternalRow], keys2: Seq[InternalRow]): Boolean = { + keys1.size == keys2.size && keys1.zip(keys2).forall { case (l, r) => + comparableWrapperFactory(l).equals(comparableWrapperFactory(r)) + } + } -object KeyGroupedPartitioning { - def apply( - expressions: Seq[Expression], - projectionPositions: Seq[Int], - partitionValues: Seq[InternalRow], - originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = { - val projectedExpressions = projectionPositions.map(expressions(_)) - val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _)) - val projectedOriginalPartitionValues = - originalPartitionValues.map(project(expressions, projectionPositions, _)) - val internalRowComparableFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( - projectedExpressions.map(_.dataType)) - - val finalPartitionValues = projectedPartitionValues - .map(internalRowComparableFactory) - .distinct - .map(_.row) + keysEqual(partitionKeys, k.partitionKeys) && + keysEqual(originalPartitionKeys, k.originalPartitionKeys) - KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length, - finalPartitionValues, projectedOriginalPartitionValues) + case _ => false } - def project( - expressions: Seq[Expression], - positions: Seq[Int], - input: InternalRow): InternalRow = { - val projectedValues: Array[Any] = positions.map(i => input.get(i, expressions(i).dataType)) - .toArray - new GenericInternalRow(projectedValues) + override def hashCode(): Int = { + Objects.hash(expressions, partitionKeys.map(comparableWrapperFactory), + originalPartitionKeys.map(comparableWrapperFactory)) } +} +object KeyedPartitioning { def supportsExpressions(expressions: Seq[Expression]): Boolean = { def isSupportedTransform(transform: TransformExpression): Boolean = { transform.children.size == 1 && isReference(transform.children.head) @@ -491,6 +537,28 @@ object KeyGroupedPartitioning { case _ => false } } + + def projectKey( + key: InternalRow, + positions: Seq[Int], + dataTypes: Seq[DataType]): InternalRow = { + val projectedKey = positions.zip(dataTypes).map { + case (position, dataType) => key.get(position, dataType) + }.toArray[Any] + new GenericInternalRow(projectedKey) + } + + def reduceKey( + key: InternalRow, + reducers: Seq[Option[Reducer[_, _]]], + dataTypes: Seq[DataType]): InternalRow = { + val keyValues = key.toSeq(dataTypes) + val reducedKey = keyValues.zip(reducers).map{ + case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v) + case (v, _) => v + }.toArray + new GenericInternalRow(reducedKey) + } } /** @@ -827,7 +895,7 @@ case class CoalescedHashShuffleSpec( } /** - * [[ShuffleSpec]] created by [[KeyGroupedPartitioning]]. + * [[ShuffleSpec]] created by [[KeyedPartitioning]]. * * @param partitioning key grouped partitioning * @param distribution distribution @@ -835,10 +903,12 @@ case class CoalescedHashShuffleSpec( * This is set if joining on a subset of cluster keys is allowed. */ case class KeyGroupedShuffleSpec( - partitioning: KeyGroupedPartitioning, + partitioning: KeyedPartitioning, distribution: ClusteredDistribution, joinKeyPositions: Option[Seq[Int]] = None) extends ShuffleSpec { + assert(partitioning.isGrouped) + /** * A sequence where each element is a set of positions of the partition expression to the cluster * keys. For instance, if cluster keys are [a, b, b] and partition expressions are @@ -878,7 +948,7 @@ case class KeyGroupedShuffleSpec( partitioning.expressions.map(_.dataType)) distribution.clustering.length == otherDistribution.clustering.length && numPartitions == other.numPartitions && areKeysCompatible(otherSpec) && - partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall { + partitioning.partitionKeys.zip(otherPartitioning.partitionKeys).forall { case (left, right) => internalRowComparableFactory(left).equals(internalRowComparableFactory(right)) } @@ -959,21 +1029,20 @@ case class KeyGroupedShuffleSpec( te.copy(children = te.children.map(_ => clustering(positionSet.head))) case (_, positionSet) => clustering(positionSet.head) } - KeyGroupedPartitioning(newExpressions, - partitioning.numPartitions, - partitioning.partitionValues) + KeyedPartitioning(newExpressions, partitioning.partitionKeys, + partitioning.originalPartitionKeys) } } object KeyGroupedShuffleSpec { - def reducePartitionValue( + def reducePartitionKey( row: InternalRow, reducers: Seq[Option[Reducer[_, _]]], dataTypes: Seq[DataType], internalRowComparableWrapperFactory: InternalRow => InternalRowComparableWrapper ): InternalRowComparableWrapper = { - val partitionVals = row.toSeq(dataTypes) - val reducedRow = partitionVals.zip(reducers).map{ + val partitionKeys = row.toSeq(dataTypes) + val reducedRow = partitionKeys.zip(reducers).map{ case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v) case (v, _) => v }.toArray diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala index 764dac35f6736..87cb212253f32 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning +import org.apache.spark.sql.catalyst.plans.physical.KeyedPartitioning import org.apache.spark.sql.connector.catalog.PartitionInternalRow import org.apache.spark.sql.types.IntegerType @@ -61,10 +61,10 @@ object InternalRowComparableWrapperBenchmark extends BenchmarkBase { // just to mock the data types val expressions = (Seq(Literal(day, IntegerType), Literal(0, IntegerType))) - val leftPartitioning = KeyGroupedPartitioning(expressions, bucketNum, partitions) - val rightPartitioning = KeyGroupedPartitioning(expressions, bucketNum, partitions) + val leftPartitioning = KeyedPartitioning(expressions, partitions, partitions) + val rightPartitioning = KeyedPartitioning(expressions, partitions, partitions) val merged = InternalRowComparableWrapper.mergePartitions( - leftPartitioning.partitionValues, rightPartitioning.partitionValues, expressions) + leftPartitioning.partitionKeys, rightPartitioning.partitionKeys, expressions) assert(merged.size == bucketNum) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala deleted file mode 100644 index cac4a9bc852f6..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala +++ /dev/null @@ -1,184 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.RowOrdering -import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, KeyGroupedShuffleSpec} -import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper -import org.apache.spark.sql.execution.joins.StoragePartitionJoinParams - -/** Base trait for a data source scan capable of producing a key-grouped output. */ -trait KeyGroupedPartitionedScan[T] { - /** - * The output partitioning of this scan after applying any pushed-down SPJ parameters. - * - * @param basePartitioning The original key-grouped partitioning of the scan. - * @param spjParams SPJ parameters for the scan. - */ - def getOutputKeyGroupedPartitioning( - basePartitioning: KeyGroupedPartitioning, - spjParams: StoragePartitionJoinParams): KeyGroupedPartitioning = { - val projectedExpressions = spjParams.joinKeyPositions match { - case Some(projectionPositions) => - projectionPositions.map(i => basePartitioning.expressions(i)) - case _ => basePartitioning.expressions - } - - val newPartValues = spjParams.commonPartitionValues match { - case Some(commonPartValues) => - // We allow duplicated partition values if - // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true - commonPartValues.flatMap { - case (partValue, numSplits) => Seq.fill(numSplits)(partValue) - } - case None => - spjParams.joinKeyPositions match { - case Some(projectionPositions) => - val internalRowComparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( - projectedExpressions.map(_.dataType)) - basePartitioning.partitionValues.map { r => - val projectedRow = KeyGroupedPartitioning.project(basePartitioning.expressions, - projectionPositions, r) - internalRowComparableWrapperFactory(projectedRow) - }.distinct.map(_.row) - case _ => basePartitioning.partitionValues - } - } - basePartitioning.copy(expressions = projectedExpressions, numPartitions = newPartValues.length, - partitionValues = newPartValues) - } - - /** - * Re-groups the input partitions for this scan based on the provided SPJ params, returning a list - * of partitions to be scanned by each scan task. - * - * @param p The output KeyGroupedPartitioning of this scan. - * @param spjParams SPJ parameters for the scan. - * @param filteredPartitions The input partitions (after applying filtering) to be - * re-grouped for this scan, initially grouped by partition value. - * @param partitionValueAccessor Accessor for the partition values (as an [[InternalRow]]) - */ - def getInputPartitionGrouping( - p: KeyGroupedPartitioning, - spjParams: StoragePartitionJoinParams, - filteredPartitions: Seq[Seq[T]], - partitionValueAccessor: T => InternalRow): Seq[Seq[T]] = { - assert(spjParams.keyGroupedPartitioning.isDefined) - val expressions = spjParams.keyGroupedPartitioning.get - - // Re-group the input partitions if we are projecting on a subset of join keys - val (groupedPartitions, partExpressions) = spjParams.joinKeyPositions match { - case Some(projectPositions) => - val projectedExpressions = projectPositions.map(i => expressions(i)) - val internalRowComparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( - projectedExpressions.map(_.dataType)) - val parts = filteredPartitions.flatten.groupBy(part => { - val row = partitionValueAccessor(part) - val projectedRow = KeyGroupedPartitioning.project( - expressions, projectPositions, row) - internalRowComparableWrapperFactory(projectedRow) - }).map { case (wrapper, splits) => (wrapper.row, splits) }.toSeq - (parts, projectedExpressions) - case _ => - val groupedParts = filteredPartitions.map(splits => { - assert(splits.nonEmpty) - (partitionValueAccessor(splits.head), splits) - }) - (groupedParts, expressions) - } - - // Also re-group the partitions if we are reducing compatible partition expressions - val partitionDataTypes = partExpressions.map(_.dataType) - val internalRowComparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(partitionDataTypes) - val finalGroupedPartitions = spjParams.reducers match { - case Some(reducers) => - val result = groupedPartitions.groupBy { case (row, _) => - KeyGroupedShuffleSpec.reducePartitionValue( - row, reducers, partitionDataTypes, internalRowComparableWrapperFactory) - }.map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) }.toSeq - val rowOrdering = RowOrdering.createNaturalAscendingOrdering( - partExpressions.map(_.dataType)) - result.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) - case _ => groupedPartitions - } - - // When partially clustered, the input partitions are not grouped by partition - // values. Here we'll need to check `commonPartitionValues` and decide how to group - // and replicate splits within a partition. - if (spjParams.commonPartitionValues.isDefined && spjParams.applyPartialClustering) { - // A mapping from the common partition values to how many splits the partition - // should contain. - val commonPartValuesMap = spjParams.commonPartitionValues - .get - .map(t => (internalRowComparableWrapperFactory(t._1), t._2)) - .toMap - val filteredGroupedPartitions = finalGroupedPartitions.filter { - case (partValues, _) => - commonPartValuesMap.keySet.contains(internalRowComparableWrapperFactory(partValues)) - } - val nestGroupedPartitions = filteredGroupedPartitions.map { case (partValue, splits) => - // `commonPartValuesMap` should contain the part value since it's the super set. - val numSplits = commonPartValuesMap.get(internalRowComparableWrapperFactory(partValue)) - assert(numSplits.isDefined, s"Partition value $partValue does not exist in " + - "common partition values from Spark plan") - - val newSplits = if (spjParams.replicatePartitions) { - // We need to also replicate partitions according to the other side of join - Seq.fill(numSplits.get)(splits) - } else { - // Not grouping by partition values: this could be the side with partially - // clustered distribution. Because of dynamic filtering, we'll need to check if - // the final number of splits of a partition is smaller than the original - // number, and fill with empty splits if so. This is necessary so that both - // sides of a join will have the same number of partitions & splits. - splits.map(Seq(_)).padTo(numSplits.get, Seq.empty) - } - (internalRowComparableWrapperFactory(partValue), newSplits) - } - - // Now fill missing partition keys with empty partitions - val partitionMapping = nestGroupedPartitions.toMap - spjParams.commonPartitionValues.get.flatMap { - case (partValue, numSplits) => - // Use empty partition for those partition values that are not present. - partitionMapping.getOrElse( - internalRowComparableWrapperFactory(partValue), - Seq.fill(numSplits)(Seq.empty)) - } - } else { - // either `commonPartitionValues` is not defined, or it is defined but - // `applyPartialClustering` is false. - val partitionMapping = finalGroupedPartitions.map { case (partValue, splits) => - internalRowComparableWrapperFactory(partValue) -> splits - }.toMap - - // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there - // could exist duplicated partition values, as partition grouping is not done - // at the beginning and postponed to this method. It is important to use unique - // partition values here so that grouped partitions won't get duplicated. - p.uniquePartitionValues.map { partValue => - // Use empty partition for those partition values that are not present - partitionMapping.getOrElse(internalRowComparableWrapperFactory(partValue), Seq.empty) - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 55866cc858405..a24c90c7a8502 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -24,12 +24,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning, SinglePartition} +import org.apache.spark.sql.catalyst.plans.physical.{KeyedPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper} import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.read._ -import org.apache.spark.sql.execution.KeyGroupedPartitionedScan -import org.apache.spark.sql.execution.joins.StoragePartitionJoinParams import org.apache.spark.util.ArrayImplicits._ /** @@ -41,8 +39,8 @@ case class BatchScanExec( runtimeFilters: Seq[Expression], ordering: Option[Seq[SortOrder]] = None, @transient table: Table, - spjParams: StoragePartitionJoinParams = StoragePartitionJoinParams() - ) extends DataSourceV2ScanExecBase with KeyGroupedPartitionedScan[InputPartition] { + keyGroupedPartitioning: Option[Seq[Expression]] = None + ) extends DataSourceV2ScanExecBase { @transient lazy val batch: Batch = if (scan == null) null else scan.toBatch @@ -50,8 +48,7 @@ case class BatchScanExec( override def equals(other: Any): Boolean = other match { case other: BatchScanExec => this.batch != null && this.batch == other.batch && - this.runtimeFilters == other.runtimeFilters && - this.spjParams == other.spjParams + this.runtimeFilters == other.runtimeFilters case _ => false } @@ -61,15 +58,14 @@ case class BatchScanExec( @transient override lazy val inputPartitions: Seq[InputPartition] = batch.planInputPartitions().toImmutableArraySeq - @transient private lazy val filteredPartitions: Seq[Seq[InputPartition]] = { + @transient private lazy val filteredPartitions: Seq[Option[InputPartition]] = { val dataSourceFilters = runtimeFilters.flatMap { case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e) case _ => None } + val originalPartitioning = outputPartitioning if (dataSourceFilters.nonEmpty) { - val originalPartitioning = outputPartitioning - // the cast is safe as runtime filters are only assigned if the scan can be filtered val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering] filterableScan.filter(dataSourceFilters.toArray) @@ -78,49 +74,65 @@ case class BatchScanExec( val newPartitions = scan.toBatch.planInputPartitions() originalPartitioning match { - case p: KeyGroupedPartitioning => + case p: KeyedPartitioning => if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) { throw new SparkException("Data source must have preserved the original partitioning " + "during runtime filtering: not all partitions implement HasPartitionKey after " + "filtering") } - val newPartitionValues = newPartitions.map(partition => + val newPartitionKeys = newPartitions.map(partition => InternalRowComparableWrapper(partition.asInstanceOf[HasPartitionKey], p.expressions)) .toSet - val oldPartitionValues = p.partitionValues + val oldPartitionKeys = p.partitionKeys .map(partition => InternalRowComparableWrapper(partition, p.expressions)).toSet // We require the new number of partition values to be equal or less than the old number // of partition values here. In the case of less than, empty partitions will be added for // those missing values that are not present in the new input partitions. - if (oldPartitionValues.size < newPartitionValues.size) { + if (oldPartitionKeys.size < newPartitionKeys.size) { throw new SparkException("During runtime filtering, data source must either report " + "the same number of partition values, or a subset of partition values from the " + - s"original. Before: ${oldPartitionValues.size} partition values. " + - s"After: ${newPartitionValues.size} partition values") + s"original. Before: ${oldPartitionKeys.size} partition values. " + + s"After: ${newPartitionKeys.size} partition values") } - if (!newPartitionValues.forall(oldPartitionValues.contains)) { + if (!newPartitionKeys.forall(oldPartitionKeys.contains)) { throw new SparkException("During runtime filtering, data source must not report new " + "partition values that are not present in the original partitioning.") } - groupPartitions(newPartitions.toImmutableArraySeq) - .map(_.groupedParts.map(_.parts)).getOrElse(Seq.empty) + val dataTypes = p.expressions.map(_.dataType) + val rowOrdering = RowOrdering.createNaturalAscendingOrdering(dataTypes) + val internalRowComparableWrapperFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) + + val inputMap = inputPartitions.groupBy(p => + internalRowComparableWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey()) + ).view.mapValues(_.size) + val filteredMap = newPartitions.groupBy(p => + internalRowComparableWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey()) + ) + inputMap.toSeq + .sortBy { case (keyWrapper, _) => keyWrapper.row }(rowOrdering) + .flatMap { case (keyWrapper, size) => + val fps = filteredMap.getOrElse(keyWrapper, Array.empty) + assert(fps.size <= size) + fps.map(Some).padTo(size, None) + } case _ => // no validation is needed as the data source did not report any specific partitioning - newPartitions.map(Seq(_)).toImmutableArraySeq + newPartitions.toSeq.map(Some) } } else { - partitions - } - } - - override def outputPartitioning: Partitioning = { - super.outputPartitioning match { - case k: KeyGroupedPartitioning => getOutputKeyGroupedPartitioning(k, spjParams) - case p => p + (originalPartitioning match { + case p: KeyedPartitioning => + val dataTypes = p.expressions.map(_.dataType) + val rowOrdering = RowOrdering.createNaturalAscendingOrdering(dataTypes) + inputPartitions.sortBy(_.asInstanceOf[HasPartitionKey].partitionKey())(rowOrdering) + + case _ => inputPartitions + }).map(Some) } } @@ -131,22 +143,13 @@ case class BatchScanExec( // return an empty RDD with 1 partition if dynamic filtering removed the only split sparkContext.parallelize(Array.empty[InternalRow].toImmutableArraySeq, 1) } else { - val finalPartitions = outputPartitioning match { - case p: KeyGroupedPartitioning => getInputPartitionGrouping( - p, spjParams, filteredPartitions, p => p.asInstanceOf[HasPartitionKey].partitionKey()) - case _ => filteredPartitions - } - new DataSourceRDD( - sparkContext, finalPartitions, readerFactory, supportsColumnar, customMetrics) + sparkContext, filteredPartitions, readerFactory, supportsColumnar, customMetrics) } postDriverMetrics() rdd } - override def keyGroupedPartitioning: Option[Seq[Expression]] = - spjParams.keyGroupedPartitioning - override def doCanonicalize(): BatchScanExec = { this.copy( output = output.map(QueryPlan.normalizeExpressions(_, output)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala index 288233e691453..e9e5f0f3175cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ContinuousScanExec.scala @@ -52,7 +52,7 @@ case class ContinuousScanExec( } override lazy val inputRDD: RDD[InternalRow] = { - assert(partitions.forall(_.length == 1), "should only contain a single partition") + assert(partitions.forall(_.isDefined), "should contain a partition") EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) @@ -61,7 +61,7 @@ case class ContinuousScanExec( sparkContext, conf.continuousStreamingExecutorQueueSize, conf.continuousStreamingExecutorPollIntervalMs, - partitions.map(_.head), + partitions.map(_.get), schema, readerFactory, customMetrics) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 19a057c72506b..2fedb97e8461e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.ArrayImplicits._ -class DataSourceRDDPartition(val index: Int, val inputPartitions: Seq[InputPartition]) +class DataSourceRDDPartition(val index: Int, val inputPartition: Option[InputPartition]) extends Partition with Serializable /** @@ -50,9 +50,22 @@ private case class ReaderState(reader: PartitionReader[_], iterator: MetricsIter // TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an `RDD[ColumnarBatch]` for // columnar scan. +/** + * An RDD that reads data from a V2 data source. + * + * This RDD handles both row-based and columnar reads, tracks custom metrics from the data source, + * and ensures that task completion listeners are added only once per thread to avoid duplicate + * metric updates and resource cleanup. + * + * @param sc The Spark context + * @param inputPartitions The input partitions to read from + * @param partitionReaderFactory Factory for creating partition readers + * @param columnarReads Whether to use columnar reads + * @param customMetrics Custom metrics defined by the data source + */ class DataSourceRDD( sc: SparkContext, - @transient private val inputPartitions: Seq[Seq[InputPartition]], + @transient private val inputPartitions: Seq[Option[InputPartition]], partitionReaderFactory: PartitionReaderFactory, columnarReads: Boolean, customMetrics: Map[String, SQLMetric]) @@ -65,7 +78,7 @@ class DataSourceRDD( override protected def getPartitions: Array[Partition] = { inputPartitions.zipWithIndex.map { - case (inputPartitions, index) => new DataSourceRDDPartition(index, inputPartitions) + case (inputPartition, index) => new DataSourceRDDPartition(index, inputPartition) }.toArray } @@ -98,62 +111,39 @@ class DataSourceRDD( } } - val iterator = new Iterator[Object] { - private val inputPartitions = castPartition(split).inputPartitions - private var currentIter: Option[Iterator[Object]] = None - private var currentIndex: Int = 0 - - override def hasNext: Boolean = currentIter.exists(_.hasNext) || advanceToNextIter() - - override def next(): Object = { - if (!hasNext) throw new NoSuchElementException("No more elements") - currentIter.get.next() + castPartition(split).inputPartition.iterator.flatMap { inputPartition => + val (iter, reader) = if (columnarReads) { + val batchReader = partitionReaderFactory.createColumnarReader(inputPartition) + val iter = new MetricsBatchIterator( + new PartitionIterator[ColumnarBatch](batchReader, customMetrics)) + (iter, batchReader) + } else { + val rowReader = partitionReaderFactory.createReader(inputPartition) + val iter = new MetricsRowIterator( + new PartitionIterator[InternalRow](rowReader, customMetrics)) + (iter, rowReader) } - private def advanceToNextIter(): Boolean = { - if (currentIndex >= inputPartitions.length) { - false - } else { - val inputPartition = inputPartitions(currentIndex) - currentIndex += 1 - - // TODO: SPARK-25083 remove the type erasure hack in data source scan - val (iter, reader) = if (columnarReads) { - val batchReader = partitionReaderFactory.createColumnarReader(inputPartition) - val iter = new MetricsBatchIterator( - new PartitionIterator[ColumnarBatch](batchReader, customMetrics)) - (iter, batchReader) - } else { - val rowReader = partitionReaderFactory.createReader(inputPartition) - val iter = new MetricsRowIterator( - new PartitionIterator[InternalRow](rowReader, customMetrics)) - (iter, rowReader) - } - - // Flush metrics and close the previous reader before advancing to the next one. - // Pass the accumulated metrics to the new reader so they carry forward correctly. - val prevState = taskReaderStates.get(taskAttemptId) - if (prevState != null) { - val metrics = prevState.reader.currentMetricsValues - CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics) - reader.initMetricsValues(metrics) - prevState.reader.close() - } + // Flush metrics and close the previous reader before advancing to the next one. + // Pass the accumulated metrics to the new reader so they carry forward correctly. + val prevState = taskReaderStates.get(taskAttemptId) + if (prevState != null) { + val metrics = prevState.reader.currentMetricsValues + CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics) + reader.initMetricsValues(metrics) + prevState.reader.close() + } - // Update the map so the completion listener always references the latest reader. - taskReaderStates.put(taskAttemptId, ReaderState(reader, iter)) + // Update the map so the completion listener always references the latest reader. + taskReaderStates.put(taskAttemptId, ReaderState(reader, iter)) - currentIter = Some(iter) - hasNext - } - } + // TODO: SPARK-25083 remove the type erasure hack in data source scan + new InterruptibleIterator(context, iter.asInstanceOf[Iterator[InternalRow]]) } - - new InterruptibleIterator(context, iterator).asInstanceOf[Iterator[InternalRow]] } override def getPreferredLocations(split: Partition): Seq[String] = { - castPartition(split).inputPartitions.flatMap(_.preferredLocations()) + castPartition(split).inputPartition.toSeq.flatMap(_.preferredLocations()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala index 95d85dab5cedc..ac993e18876a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala @@ -21,12 +21,11 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, RowOrdering, SortOrder} import org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning -import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper} +import org.apache.spark.sql.catalyst.plans.physical.KeyedPartitioning +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan} import org.apache.spark.sql.execution.{ExplainUtils, LeafExecNode, SQLExecution} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.connector.SupportsMetadata import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.ArrayImplicits._ @@ -63,9 +62,7 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { redact(result) } - def partitions: Seq[Seq[InputPartition]] = { - groupedPartitions.map(_.groupedParts.map(_.parts)).getOrElse(inputPartitions.map(Seq(_))) - } + def partitions: Seq[Option[InputPartition]] = inputPartitions.map(Some) /** * Shorthand for calling redact() without specifying redacting rules @@ -94,76 +91,24 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { override def outputPartitioning: physical.Partitioning = { keyGroupedPartitioning match { - case Some(exprs) if KeyGroupedPartitioning.supportsExpressions(exprs) => - groupedPartitions - .map { keyGroupedPartsInfo => - val keyGroupedParts = keyGroupedPartsInfo.groupedParts - KeyGroupedPartitioning(exprs, keyGroupedParts.size, keyGroupedParts.map(_.value), - keyGroupedPartsInfo.originalParts.map(_.partitionKey())) - } - .getOrElse(super.outputPartitioning) + case Some(exprs) if conf.v2BucketingEnabled && KeyedPartitioning.supportsExpressions(exprs) && + inputPartitions.length > 0 && inputPartitions.forall(_.isInstanceOf[HasPartitionKey]) => + val dataTypes = exprs.map(_.dataType) + val rowOrdering = RowOrdering.createNaturalAscendingOrdering(dataTypes) + val partitionKeys = + inputPartitions.map(_.asInstanceOf[HasPartitionKey].partitionKey()).sorted(rowOrdering) + KeyedPartitioning(exprs, partitionKeys, partitionKeys) case _ => super.outputPartitioning } } - @transient lazy val groupedPartitions: Option[KeyGroupedPartitionInfo] = { - // Early check if we actually need to materialize the input partitions. - keyGroupedPartitioning match { - case Some(_) => groupPartitions(inputPartitions) - case _ => None - } - } - /** - * Group partition values for all the input partitions. This returns `Some` iff: - * - [[SQLConf.V2_BUCKETING_ENABLED]] is turned on - * - all input partitions implement [[HasPartitionKey]] - * - `keyGroupedPartitioning` is set - * - * The result, if defined, is a [[KeyGroupedPartitionInfo]] which contains a list of - * [[KeyGroupedPartition]], as well as a list of partition values from the original input splits, - * sorted according to the partition keys in ascending order. - * - * A non-empty result means each partition is clustered on a single key and therefore eligible - * for further optimizations to eliminate shuffling in some operations such as join and aggregate. + * Returns the output ordering from the data source if available, otherwise falls back + * to the default (no ordering). This allows data sources to report their natural ordering + * through `SupportsReportOrdering`. */ - def groupPartitions(inputPartitions: Seq[InputPartition]): Option[KeyGroupedPartitionInfo] = { - if (!SQLConf.get.v2BucketingEnabled) return None - - keyGroupedPartitioning.flatMap { expressions => - val results = inputPartitions.takeWhile { - case _: HasPartitionKey => true - case _ => false - }.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), p.asInstanceOf[HasPartitionKey])) - - if (results.length != inputPartitions.length || inputPartitions.isEmpty) { - // Not all of the `InputPartitions` implements `HasPartitionKey`, therefore skip here. - None - } else { - // also sort the input partitions according to their partition key order. This ensures - // a canonical order from both sides of a bucketed join, for example. - val partitionDataTypes = expressions.map(_.dataType) - val rowOrdering = RowOrdering.createNaturalAscendingOrdering(partitionDataTypes) - val sortedKeyToPartitions = results.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) - val sortedGroupedPartitions = sortedKeyToPartitions - .map(t => (InternalRowComparableWrapper(t._1, expressions), t._2)) - .groupBy(_._1) - .toSeq - .map { case (key, s) => KeyGroupedPartition(key.row, s.map(_._2)) } - .sorted(rowOrdering.on((k: KeyGroupedPartition) => k.value)) - - Some(KeyGroupedPartitionInfo(sortedGroupedPartitions, sortedKeyToPartitions.map(_._2))) - } - } - } - - override def outputOrdering: Seq[SortOrder] = { - // when multiple partitions are grouped together, ordering inside partitions is not preserved - val partitioningPreservesOrdering = groupedPartitions - .forall(_.groupedParts.forall(_.parts.length <= 1)) - ordering.filter(_ => partitioningPreservesOrdering).getOrElse(super.outputOrdering) - } + override def outputOrdering: Seq[SortOrder] = ordering.getOrElse(super.outputOrdering) override def supportsColumnar: Boolean = { scan.columnarSupportMode() match { @@ -210,19 +155,3 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { } } } - -/** - * A key-grouped Spark partition, which could consist of multiple input splits - * - * @param value the partition value shared by all the input splits - * @param parts the input splits that are grouped into a single Spark partition - */ -private[v2] case class KeyGroupedPartition(value: InternalRow, parts: Seq[InputPartition]) - -/** - * Information about key-grouped partitions, which contains a list of grouped partitions as well - * as the original input partitions before the grouping. - */ -private[v2] case class KeyGroupedPartitionInfo( - groupedParts: Seq[KeyGroupedPartition], - originalParts: Seq[HasPartitionKey]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 81bc1990404a9..cf6811523d4da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -45,7 +45,6 @@ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors import org.apache.spark.sql.execution.{FilterExec, InSubqueryExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan, SparkStrategy => Strategy} import org.apache.spark.sql.execution.command.CommandUtils import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelationWithTable, PushableColumnAndNestedColumn} -import org.apache.spark.sql.execution.joins.StoragePartitionJoinParams import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH @@ -161,8 +160,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case _ => false } val batchExec = BatchScanExec(relation.output, relation.scan, runtimeFilters, - relation.ordering, relation.relation.table, - StoragePartitionJoinParams(relation.keyGroupedPartitioning)) + relation.ordering, relation.relation.table, relation.keyGroupedPartitioning) DataSourceV2Strategy.withProjectAndFilter( project, postScanFilters, batchExec, !batchExec.supportsColumnar) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala new file mode 100644 index 0000000000000..498c4b7a0f42a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.sql.execution.datasources.v2 + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{Partition, SparkException} +import org.apache.spark.rdd.{CoalescedRDD, PartitionCoalescer, PartitionGroup, RDD} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{KeyedPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper +import org.apache.spark.sql.connector.catalog.functions.Reducer +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.types.DataType + +/** + * Physical operator that groups input partitions by their partition keys. + * + * This operator is used to coalesce partitions from bucketed/partitioned data sources + * where multiple input partitions share the same partition key. It's commonly used in + * storage-partitioned joins to align partitions from different sides of the join. + * + * @param child The child plan providing bucketed/partitioned input + * @param joinKeyPositions Optional projection to select a subset of the partitioning key + * for join compatibility (e.g., when join keys are a subset of + * partition keys) + * @param commonPartitionKeys Optional sequence of expected partition key values and their + * split counts, used for partially clustered data + * @param reducers Optional reducers to apply to partition keys for grouping compatibility + * @param applyPartialClustering Whether to apply partial clustering for skewed data + * @param replicatePartitions Whether to replicate partitions across multiple keys + */ +case class GroupPartitionsExec( + child: SparkPlan, + joinKeyPositions: Option[Seq[Int]] = None, + commonPartitionKeys: Option[Seq[(InternalRow, Int)]] = None, + reducers: Option[Seq[Option[Reducer[_, _]]]] = None, + applyPartialClustering: Boolean = false, + replicatePartitions: Boolean = false + ) extends UnaryExecNode { + + override def outputPartitioning: Partitioning = { + child.outputPartitioning match { + case p: Partitioning with Expression => + p.transform { + case k: KeyedPartitioning => + val projectedExpressions = projectExpressions(k.expressions) + val projectedDataTypes = projectedExpressions.map(_.dataType) + k.copy(expressions = projectedExpressions, + partitionKeys = groupedPartitions.map(_._1), + originalPartitionKeys = projectKeys(k.originalPartitionKeys, projectedDataTypes)) + }.asInstanceOf[Partitioning] + case o => o + } + } + + private def projectExpressions(expressions: Seq[Expression]) = { + joinKeyPositions match { + case Some(projectionPositions) => + projectionPositions.map(expressions) + case _ => expressions + } + } + + private def projectKeys(keys: Seq[InternalRow], dataTypes: Seq[DataType]) = { + joinKeyPositions match { + case Some(projectionPositions) => + keys.map(KeyedPartitioning.projectKey(_, projectionPositions, dataTypes)) + case _ => keys + } + } + + /** + * Extracts the first KeyedPartitioning from the child's output partitioning. + * The child must have a KeyedPartitioning in its partitioning scheme. + */ + lazy val firstKeyedPartitioning = { + child.outputPartitioning.asInstanceOf[Partitioning with Expression].collectFirst { + case k: KeyedPartitioning => k + }.getOrElse( + throw new SparkException("GroupPartitionsExec requires a child with KeyedPartitioning")) + } + + /** + * Computes the grouped partitions by: + * 1. Projecting partition keys if joinKeyPositions is specified + * 2. Reducing keys if reducers are specified + * 3. Grouping input partition indices by their (possibly projected/reduced) keys + * 4. Sorting or distributing based on whether partial clustering is enabled + * + * Returns a sequence of (partitionKey, inputPartitionIndices) pairs representing + * how input partitions should be grouped together. + */ + lazy val groupedPartitions: Seq[(InternalRow, Seq[Int])] = { + // Also sort the input partitions according to their partition key order. This ensures + // a canonical order from both sides of a bucketed join, for example. + + val (projectedDataTypes, projectedKeys) = + joinKeyPositions match { + case Some(projectionPositions) => + val projectedDataTypes = + projectExpressions(firstKeyedPartitioning.expressions).map(_.dataType) + val projectedKeys = firstKeyedPartitioning.partitionKeys + .map(KeyedPartitioning.projectKey(_, projectionPositions, projectedDataTypes)) + (projectedDataTypes, projectedKeys) + + case _ => + val dataTypes = firstKeyedPartitioning.expressions.map(_.dataType) + (dataTypes, firstKeyedPartitioning.partitionKeys) + } + + val reducedKeys = reducers match { + case Some(reducers) => + projectedKeys.map(KeyedPartitioning.reduceKey(_, reducers, projectedDataTypes)) + case _ => projectedKeys + } + + val internalRowComparableWrapperFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(projectedDataTypes) + + val map = reducedKeys.zipWithIndex.groupMap { + case (key, _) => internalRowComparableWrapperFactory(key) + }(_._2) + + // When partially clustered, the input partitions are not grouped by partition + // values. Here we'll need to check `commonPartitionKeys` and decide how to group + // and replicate splits within a partition. + if (commonPartitionKeys.isDefined) { + commonPartitionKeys.get.flatMap { case (key, numSplits) => + val splits = map.getOrElse(internalRowComparableWrapperFactory(key), Seq.empty) + if (applyPartialClustering && !replicatePartitions) { + val paddedSplits = splits.map(Seq(_)).padTo(numSplits, Seq.empty) + paddedSplits.map((key, _)) + } else { + Seq.fill(numSplits)((key, splits)) + } + } + } else { + val rowOrdering = RowOrdering.createNaturalAscendingOrdering(projectedDataTypes) + map.toSeq + .map { case (keyWrapper, v) => (keyWrapper.row, v) } + .sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) + } + } + + override protected def doExecute(): RDD[InternalRow] = { + val partitionCoalescer = new GroupedPartitionCoalescer(groupedPartitions.map(_._2)) + if (groupedPartitions.isEmpty) { + sparkContext.emptyRDD + } else { + new CoalescedRDD(child.execute(), groupedPartitions.size, Some(partitionCoalescer)) + } + } + + override def output: Seq[Attribute] = child.output + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) + + override def outputOrdering: Seq[SortOrder] = { + // when multiple partitions are grouped together, ordering inside partitions is not preserved + if (groupedPartitions.forall(_._2.size <= 1)) { + child.outputOrdering + } else { + super.outputOrdering + } + } +} + +/** + * A PartitionCoalescer that groups partitions according to a pre-computed grouping plan. + * + * Unlike Spark's default coalescer which tries to minimize data movement, this coalescer + * groups partitions based on their partition keys to maintain the grouping semantics + * required for storage-partitioned operations. + * + * @param groupedPartitions Sequence where each element is a sequence of input partition + * indices that should be grouped together + */ +class GroupedPartitionCoalescer( + val groupedPartitions: Seq[Seq[Int]] + ) extends PartitionCoalescer with Serializable { + + override def coalesce(maxPartitions: Int, parent: RDD[_]): Array[PartitionGroup] = { + groupedPartitions.map { partitionIndices => + val partitions = new ArrayBuffer[Partition](partitionIndices.size) + val preferredLocations = new ArrayBuffer[String](partitionIndices.size) + partitionIndices.foreach { partitionIndex => + val partition = parent.partitions(partitionIndex) + partitions += partition + preferredLocations ++= parent.preferredLocations(partition) + } + // Select the most common location as the preferred location + val preferredLocation = preferredLocations + .groupBy(identity) + .view.mapValues(_.size) + .maxByOption(_._2) + .map(_._1) + val partitionGroup = new PartitionGroup(preferredLocation) + partitionGroup.partitions ++= partitions + partitionGroup + }.toArray + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index e239174e40ad4..216a1e7b63523 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper import org.apache.spark.sql.connector.catalog.functions.Reducer import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.execution.datasources.v2.GroupPartitionsExec import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.internal.SQLConf @@ -64,7 +64,17 @@ case class EnsureRequirements( // Ensure that the operator's children satisfy their output distribution requirements. var children = originalChildren.zip(requiredChildDistributions).map { case (child, distribution) if child.outputPartitioning.satisfies(distribution) => - ensureOrdering(child, distribution) + distribution match { + case o: OrderedDistribution => + ensureOrdering(child, child.outputPartitioning, o) + case _ => child + } + case (c @ GroupedPartitions(p), distribution) if p.satisfies(distribution) => + distribution match { + case o: OrderedDistribution => + ensureOrdering(c, p, o) + case _ => GroupPartitionsExec(c) + } case (child, BroadcastDistribution(mode)) => BroadcastExchangeExec(mode, child) case (child, distribution) => @@ -138,17 +148,11 @@ case class EnsureRequirements( !p._2.canCreatePartitioning || children(p._1).isInstanceOf[ShuffleExchangeLike] ) // Choose all the specs that can be used to shuffle other children - val candidateSpecs = specs - .filter(_._2.canCreatePartitioning) - .filter { - // To choose a KeyGroupedShuffleSpec, we must be able to push down SPJ parameters into - // the scan (for join key positions). If these parameters can't be pushed down, this - // spec can't be used to shuffle other children. - case (idx, _: KeyGroupedShuffleSpec) => canPushDownSPJParamsToScan(children(idx)) - case _ => true - } - .filter(p => !shouldConsiderMinParallelism || - children(p._1).outputPartitioning.numPartitions >= conf.defaultNumShufflePartitions) + val candidateSpecs = specs.filter { case (index, spec) => + spec.canCreatePartitioning && + (!shouldConsiderMinParallelism || + children(index).outputPartitioning.numPartitions >= conf.defaultNumShufflePartitions) + } val bestSpecOpt = if (candidateSpecs.isEmpty) { None } else { @@ -206,7 +210,7 @@ case class EnsureRequirements( // partitioned side's BatchScanExec is grouped by join keys to match, // and we do that by pushing down the join keys case Some(KeyGroupedShuffleSpec(_, _, Some(joinKeyPositions))) => - populateJoinKeyPositions(child, Some(joinKeyPositions)) + withJoinKeyPositions(child, joinKeyPositions) case _ => child } } else { @@ -225,6 +229,7 @@ case class EnsureRequirements( child match { case ShuffleExchangeExec(_, c, so, ps) => ShuffleExchangeExec(newPartitioning, c, so, ps) + case GroupPartitionsExec(c, _, _, _, _, _) => ShuffleExchangeExec(newPartitioning, c) case _ => ShuffleExchangeExec(newPartitioning, child) } } @@ -305,21 +310,30 @@ case class EnsureRequirements( } } - private def ensureOrdering(plan: SparkPlan, distribution: Distribution) = { - (plan.outputPartitioning, distribution) match { - case (p @ KeyGroupedPartitioning(expressions, _, partitionValues, _), - d @ OrderedDistribution(ordering)) if p.satisfies(d) => - val attrs = expressions.flatMap(_.collectLeaves()).map(_.asInstanceOf[Attribute]) - val partitionOrdering: Ordering[InternalRow] = { - RowOrdering.create(ordering, attrs) + private def ensureOrdering( + plan: SparkPlan, + partitioning: Partitioning, + distribution: OrderedDistribution) = { + partitioning match { + case p: Partitioning with Expression => + val satisfyingKeyedPartitioning = + p.collectFirst { case k: KeyedPartitioning if k.satisfies(distribution) => k } + satisfyingKeyedPartitioning match { + case Some(k) => + val attrs = k.expressions.flatMap(_.collectLeaves()).map(_.asInstanceOf[Attribute]) + val partitionOrdering: Ordering[InternalRow] = { + RowOrdering.create(distribution.ordering, attrs) + } + // Sort 'commonPartitionKeys' and use this mechanism to ensure BatchScan's output + // partitions are ordered + val sorted = k.partitionKeys.sorted(partitionOrdering) + GroupPartitionsExec(plan, commonPartitionKeys = Some(sorted.map((_, 1)))) + + case _ => plan } - // Sort 'commonPartitionValues' and use this mechanism to ensure BatchScan's - // output partitions are ordered - val sorted = partitionValues.sorted(partitionOrdering) - populateCommonPartitionInfo(plan, sorted.map((_, 1)), - None, None, applyPartialClustering = false, replicatePartitions = false) + case _ => plan - } + } } /** @@ -340,12 +354,12 @@ case class EnsureRequirements( reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, leftPartitioning, None)) - case (Some(KeyGroupedPartitioning(clustering, _, _, _)), _) => + case (Some(KeyedPartitioning(clustering, _, _)), _) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, None, rightPartitioning)) - case (_, Some(KeyGroupedPartitioning(clustering, _, _, _))) => + case (_, Some(KeyedPartitioning(clustering, _, _))) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys) .orElse(reorderJoinKeysRecursively( @@ -409,27 +423,9 @@ case class EnsureRequirements( } } - /** - * Whether SPJ params can be pushed down to the leaf nodes of a physical plan. For a plan to be - * eligible for SPJ parameter pushdown, all leaf nodes must be a KeyGroupedPartitioning-aware - * scan. - * - * Notably, if the leaf of `plan` is an [[RDDScanExec]] created by checkpointing a DSv2 scan, the - * reported partitioning will be a [[KeyGroupedPartitioning]], but this plan will _not_ be - * eligible for SPJ parameter pushdown (as the partitioning is static and can't be easily - * re-grouped or padded with empty partitions according to the partition values on the other side - * of the join). - */ - private def canPushDownSPJParamsToScan(plan: SparkPlan): Boolean = { - plan.collectLeaves().forall { - case _: KeyGroupedPartitionedScan[_] => true - case _ => false - } - } - /** * Checks whether two children, `left` and `right`, of a join operator have compatible - * `KeyGroupedPartitioning`, and can benefit from storage-partitioned join. + * `KeyedPartitioning`, and can benefit from storage-partitioned join. * * Returns the updated new children if the check is successful, otherwise `None`. */ @@ -438,12 +434,6 @@ case class EnsureRequirements( left: SparkPlan, right: SparkPlan, requiredChildDistribution: Seq[Distribution]): Option[Seq[SparkPlan]] = { - // If SPJ params can't be pushed down to either the left or right side, it's unsafe to do an - // SPJ. - if (!canPushDownSPJParamsToScan(left) || !canPushDownSPJParamsToScan(right)) { - return None - } - parent match { case smj: SortMergeJoinExec => checkKeyGroupCompatible(left, right, smj.joinType, requiredChildDistribution) @@ -475,11 +465,21 @@ case class EnsureRequirements( val leftSpec = specs.head val rightSpec = specs(1) - var isCompatible = false - if (!conf.v2BucketingPushPartValuesEnabled && - !conf.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { - isCompatible = leftSpec.isCompatibleWith(rightSpec) - } else { + def contains(partitioning: Partitioning, keyedPartitioning: KeyedPartitioning): Boolean = { + partitioning match { + case k: KeyedPartitioning => k == keyedPartitioning + case PartitioningCollection(partitionings) => + partitionings.exists(contains(_, keyedPartitioning)) + case _ => false + } + } + + var isCompatible = contains(left.outputPartitioning, leftSpec.partitioning) && + contains(right.outputPartitioning, rightSpec.partitioning) && + leftSpec.isCompatibleWith(rightSpec) + if ((!isCompatible || conf.v2BucketingPartiallyClusteredDistributionEnabled) && + (conf.v2BucketingPushPartValuesEnabled || + conf.v2BucketingAllowJoinKeysSubsetOfPartitionKeys)) { logInfo("Pushing common partition values for storage-partitioned join") isCompatible = leftSpec.areKeysCompatible(rightSpec) @@ -499,8 +499,8 @@ case class EnsureRequirements( // just push the common set of partition values: `[0, 1, 2, 3]` down to the two data // sources. if (isCompatible) { - val leftPartValues = leftSpec.partitioning.partitionValues - val rightPartValues = rightSpec.partitioning.partitionValues + val leftPartValues = leftSpec.partitioning.partitionKeys + val rightPartValues = rightSpec.partitioning.partitionKeys val numLeftPartValues = MDC(LogKeys.NUM_LEFT_PARTITION_VALUES, leftPartValues.size) val numRightPartValues = MDC(LogKeys.NUM_RIGHT_PARTITION_VALUES, rightPartValues.size) @@ -517,11 +517,11 @@ case class EnsureRequirements( // in case of compatible but not identical partition expressions, we apply 'reduce' // transforms to group one side's partitions as well as the common partition values val leftReducers = leftSpec.reducers(rightSpec) - val leftParts = reducePartValues(leftSpec.partitioning.partitionValues, + val leftParts = reducePartValues(leftSpec.partitioning.partitionKeys, partitionExprs, leftReducers) val rightReducers = rightSpec.reducers(leftSpec) - val rightParts = reducePartValues(rightSpec.partitioning.partitionValues, + val rightParts = reducePartValues(rightSpec.partitioning.partitionKeys, partitionExprs, rightReducers) @@ -564,8 +564,8 @@ case class EnsureRequirements( logInfo(log"Skipping partially clustered distribution as it cannot be applied for " + log"join type '${MDC(LogKeys.JOIN_TYPE, joinType)}'") } else { - val leftLink = left.logicalLink - val rightLink = right.logicalLink + val leftLink = unwrapGroupPartitions(left).logicalLink + val rightLink = unwrapGroupPartitions(right).logicalLink replicateLeftSide = if ( leftLink.isDefined && rightLink.isDefined && @@ -609,14 +609,13 @@ case class EnsureRequirements( } else { // In partially clustered distribution, we should use un-grouped partition values val spec = if (replicateLeftSide) rightSpec else leftSpec - val partValues = spec.partitioning.originalPartitionValues + val originalPartitionKeys = spec.partitioning.originalPartitionKeys val internalRowComparableWrapperFactory = InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( partitionExprs.map(_.dataType)) - val numExpectedPartitions = partValues - .map(internalRowComparableWrapperFactory) - .groupBy(identity) + val numExpectedPartitions = originalPartitionKeys + .groupBy(internalRowComparableWrapperFactory) .transform((_, v) => v.size) mergedPartValues = mergedPartValues.map { case (partVal, numParts) => @@ -632,9 +631,9 @@ case class EnsureRequirements( } // Now we need to push-down the common partition information to the scan in each child - newLeft = populateCommonPartitionInfo(left, mergedPartValues, leftSpec.joinKeyPositions, + newLeft = applyGroupPartitions(left, leftSpec.joinKeyPositions, mergedPartValues, leftReducers, applyPartialClustering, replicateLeftSide) - newRight = populateCommonPartitionInfo(right, mergedPartValues, rightSpec.joinKeyPositions, + newRight = applyGroupPartitions(right, rightSpec.joinKeyPositions, mergedPartValues, rightReducers, applyPartialClustering, replicateRightSide) } } @@ -670,46 +669,46 @@ case class EnsureRequirements( joinType == LeftAnti || joinType == LeftOuter } - // Populate the common partition information down to the scan nodes - private def populateCommonPartitionInfo( + /** + * Unwraps a GroupPartitionsExec to get the underlying child plan. + */ + private def unwrapGroupPartitions(plan: SparkPlan): SparkPlan = plan match { + case g: GroupPartitionsExec => g.child + case other => other + } + + /** + * Applies or updates GroupPartitionsExec with the given parameters. + */ + private def applyGroupPartitions( plan: SparkPlan, - values: Seq[(InternalRow, Int)], joinKeyPositions: Option[Seq[Int]], + mergedPartValues: Seq[(InternalRow, Int)], reducers: Option[Seq[Option[Reducer[_, _]]]], applyPartialClustering: Boolean, - replicatePartitions: Boolean): SparkPlan = plan match { - case scan: BatchScanExec => - val newScan = scan.copy( - spjParams = scan.spjParams.copy( - commonPartitionValues = Some(values), + replicatePartitions: Boolean): SparkPlan = { + plan match { + case g: GroupPartitionsExec => + g.copy( joinKeyPositions = joinKeyPositions, + commonPartitionKeys = Some(mergedPartValues), reducers = reducers, applyPartialClustering = applyPartialClustering, - replicatePartitions = replicatePartitions - ) - ) - newScan.copyTagsFrom(scan) - newScan - case node => - node.mapChildren(child => populateCommonPartitionInfo( - child, values, joinKeyPositions, reducers, applyPartialClustering, replicatePartitions)) + replicatePartitions = replicatePartitions) + case _ => + GroupPartitionsExec(plan, joinKeyPositions, Some(mergedPartValues), reducers, + applyPartialClustering, replicatePartitions) + } } - - private def populateJoinKeyPositions( - plan: SparkPlan, - joinKeyPositions: Option[Seq[Int]]): SparkPlan = plan match { - case scan: BatchScanExec => - val newScan = scan.copy( - spjParams = scan.spjParams.copy( - joinKeyPositions = joinKeyPositions - ) - ) - newScan.copyTagsFrom(scan) - newScan - case node => - node.mapChildren(child => populateJoinKeyPositions( - child, joinKeyPositions)) + /** + * Applies join key positions to a plan by wrapping or updating GroupPartitionsExec. + */ + private def withJoinKeyPositions(plan: SparkPlan, positions: Seq[Int]): SparkPlan = { + plan match { + case g: GroupPartitionsExec => g.copy(joinKeyPositions = Some(positions)) + case _ => GroupPartitionsExec(plan, joinKeyPositions = Some(positions)) + } } private def reducePartValues( @@ -723,7 +722,7 @@ case class EnsureRequirements( InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( partitionDataTypes) partValues.map { row => - KeyGroupedShuffleSpec.reducePartitionValue( + KeyGroupedShuffleSpec.reducePartitionKey( row, reducers, partitionDataTypes, internalRowComparableWrapperFactory) }.distinct.map(_.row) case _ => partValues @@ -732,13 +731,13 @@ case class EnsureRequirements( /** * Tries to create a [[KeyGroupedShuffleSpec]] from the input partitioning and distribution, if - * the partitioning is a [[KeyGroupedPartitioning]] (either directly or indirectly), and + * the partitioning is a [[KeyedPartitioning]] (either directly or indirectly), and * satisfies the given distribution. */ private def createKeyGroupedShuffleSpec( partitioning: Partitioning, distribution: ClusteredDistribution): Option[KeyGroupedShuffleSpec] = { - def tryCreate(partitioning: KeyGroupedPartitioning): Option[KeyGroupedShuffleSpec] = { + def tryCreate(partitioning: KeyedPartitioning): Option[KeyGroupedShuffleSpec] = { val attributes = partitioning.expressions.flatMap(_.collectLeaves()) val clustering = distribution.clustering @@ -758,10 +757,9 @@ case class EnsureRequirements( } partitioning match { - case p: KeyGroupedPartitioning => tryCreate(p) + case p: KeyedPartitioning => tryCreate(p) case PartitioningCollection(partitionings) => - val specs = partitionings.map(p => createKeyGroupedShuffleSpec(p, distribution)) - specs.filter(_.isDefined).map(_.get).headOption + partitionings.collectFirst(Function.unlift(createKeyGroupedShuffleSpec(_, distribution))) case _ => None } } @@ -856,3 +854,23 @@ case class EnsureRequirements( } } } + +object GroupedPartitions { + def unapply(plan: SparkPlan): Option[Partitioning with Expression] = { + groupPartitions(plan.outputPartitioning) + } + + private def groupPartitions(p: Partitioning): Option[Partitioning with Expression] = { + p match { + case c: PartitioningCollection => + c.partitionings.flatMap(groupPartitions) match { + case Nil => None + case p :: Nil => Some(p) + case ps => Some(PartitioningCollection(ps)) + } + case k: KeyedPartitioning => Some(k.toGrouped) + case _ => None + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 95120039a6f94..f50f6f484ac50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -370,12 +370,14 @@ object ShuffleExchangeExec { ascending = true, samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition) case SinglePartition => new ConstantPartitioner - case k @ KeyGroupedPartitioning(expressions, n, _, _) => - val valueMap = k.uniquePartitionValues.zipWithIndex.map { + case k @ KeyedPartitioning(expressions, _, _) => + val keyGroupedPartitioning = k.toGrouped + val valueMap = keyGroupedPartitioning.partitionKeys.zipWithIndex.map { case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index) }.toMap - new KeyGroupedPartitioner(mutable.Map(valueMap.toSeq: _*), n) - case _ => throw SparkException.internalError(s"Exchange not implemented for $newPartitioning") + new KeyGroupedPartitioner(mutable.Map(valueMap.toSeq: _*), + keyGroupedPartitioning.numPartitions) + case p => throw SparkException.internalError(s"Exchange not implemented for $newPartitioning") // TODO: Handle BroadcastPartitioning. } def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match { @@ -401,7 +403,7 @@ object ShuffleExchangeExec { val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) row => projection(row) case SinglePartition => identity - case KeyGroupedPartitioning(expressions, _, _, _) => + case KeyedPartitioning(expressions, _, _) => row => bindReferences(expressions, outputAttributes).map(_.eval(row)) case s: ShufflePartitionIdPassThrough => // For ShufflePartitionIdPassThrough, the expression directly evaluates to the partition ID diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/StoragePartitionJoinParams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/StoragePartitionJoinParams.scala deleted file mode 100644 index a28eafc5cae5b..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/StoragePartitionJoinParams.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import java.util.Objects - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.connector.catalog.functions.Reducer - -case class StoragePartitionJoinParams( - keyGroupedPartitioning: Option[Seq[Expression]] = None, - joinKeyPositions: Option[Seq[Int]] = None, - commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None, - reducers: Option[Seq[Option[Reducer[_, _]]]] = None, - applyPartialClustering: Boolean = false, - replicatePartitions: Boolean = false) { - override def equals(other: Any): Boolean = other match { - case other: StoragePartitionJoinParams => - this.commonPartitionValues == other.commonPartitionValues && - this.replicatePartitions == other.replicatePartitions && - this.applyPartialClustering == other.applyPartialClustering && - this.joinKeyPositions == other.joinKeyPositions - case _ => - false - } - - override def hashCode(): Int = Objects.hash( - joinKeyPositions: Option[Seq[Int]], - commonPartitionValues: Option[Seq[(InternalRow, Int)]], - applyPartialClustering: java.lang.Boolean, - replicatePartitions: java.lang.Boolean) -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala index 1a0efa7c4aafb..cf9133b0835d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala @@ -51,9 +51,9 @@ abstract class DistributionAndOrderingSuiteBase plan: QueryPlan[T]): Partitioning = partitioning match { case HashPartitioning(exprs, numPartitions) => HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions) - case KeyGroupedPartitioning(clustering, numPartitions, partValues, originalPartValues) => - KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, partValues, - originalPartValues) + case KeyedPartitioning(expressions, partitionKeys, originalPartitionKeys) => + KeyedPartitioning(expressions.map(resolveAttrs(_, plan)), partitionKeys, + originalPartitionKeys) case PartitioningCollection(partitionings) => PartitioningCollection(partitionings.map(resolvePartitioning(_, plan))) case RangePartitioning(ordering, numPartitions) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 56bd028464e54..869fc6a0cca6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -30,8 +30,7 @@ import org.apache.spark.sql.connector.distributions.Distributions import org.apache.spark.sql.connector.expressions._ import org.apache.spark.sql.connector.expressions.Expressions._ import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.datasources.v2.BatchScanExec -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation, GroupPartitionsExec} import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.functions.{col, max} @@ -76,7 +75,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Column.create("dept_id", IntegerType), Column.create("data", StringType)) - test("clustered distribution: output partitioning should be KeyGroupedPartitioning") { + test("clustered distribution: output partitioning should be KeyedPartitioning") { val partitions: Array[Transform] = Array(Expressions.years("ts")) // create a table with 3 partitions, partitioned by `years` transform @@ -89,18 +88,15 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { var df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY ts") val catalystDistribution = physical.ClusteredDistribution( Seq(TransformExpression(YearsFunction, Seq(attr("ts"))))) - val partitionValues = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v))) - val projectedPositions = catalystDistribution.clustering.indices + val partitionKeys = Seq(50L, 51L, 52L).map(v => InternalRow.fromSeq(Seq(v))) checkQueryPlan(df, catalystDistribution, - physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions, - partitionValues, partitionValues)) + physical.KeyedPartitioning(catalystDistribution.clustering, partitionKeys, partitionKeys)) // multiple group keys should work too as long as partition keys are subset of them df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY id, ts") checkQueryPlan(df, catalystDistribution, - physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions, - partitionValues, partitionValues)) + physical.KeyedPartitioning(catalystDistribution.clustering, partitionKeys, partitionKeys)) } test("non-clustered distribution: no partition") { @@ -124,9 +120,9 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32)))) // Has exactly one partition. - val partitionValues = Seq(0).map(v => InternalRow.fromSeq(Seq(v))) + val partitionKeys = Seq(0).map(v => InternalRow.fromSeq(Seq(v))) checkQueryPlan(df, distribution, - physical.KeyGroupedPartitioning(distribution.clustering, 1, partitionValues, partitionValues)) + physical.KeyedPartitioning(distribution.clustering, partitionKeys, partitionKeys)) } test("non-clustered distribution: no V2 catalog") { @@ -275,7 +271,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { private def testWithCustomersAndOrders( customers_partitions: Array[Transform], orders_partitions: Array[Transform], - expectedNumOfShuffleExecs: Int): Unit = { + expectedNumOfShuffleExecs: Int, + expectedGroupPartitionsExecs: Int): Unit = { createTable(customers, customersColumns, customers_partitions) sql(s"INSERT INTO testcat.ns.$customers VALUES " + s"('aaa', 10, 1), ('bbb', 20, 2), ('ccc', 30, 3)") @@ -295,6 +292,9 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.length == expectedNumOfShuffleExecs) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.length == expectedGroupPartitionsExecs) + checkAnswer(df, Seq(Row("aaa", 10, 100.0), Row("aaa", 10, 200.0), Row("bbb", 20, 150.0), Row("bbb", 20, 250.0), Row("bbb", 20, 350.0), Row("ccc", 30, 400.50))) @@ -306,6 +306,12 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } + protected def collectAllGroupPartitions(plan: SparkPlan): Seq[GroupPartitionsExec] = { + collect(plan) { + case g: GroupPartitionsExec => g + } + } + protected def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeLike] = { // here we skip collecting shuffle operators that are not associated with SMJ collect(plan) { @@ -314,7 +320,17 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { collect(smj) { case s: ShuffleExchangeExec => s }) - } + }.toSet.toSeq + + protected def collectGroupPartitions(plan: SparkPlan): Seq[GroupPartitionsExec] = { + // here we skip collecting shuffle operators that are not associated with SMJ + collect(plan) { + case s: SortMergeJoinExec => s + }.flatMap(smj => + collect(smj) { + case g: GroupPartitionsExec => g + }) + }.toSet.toSeq private def collectScans(plan: SparkPlan): Seq[BatchScanExec] = { collect(plan) { case s: BatchScanExec => s } @@ -324,7 +340,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val customers_partitions = Array(bucket(4, "customer_id")) val orders_partitions = Array(bucket(4, "customer_id")) - testWithCustomersAndOrders(customers_partitions, orders_partitions, 0) + testWithCustomersAndOrders(customers_partitions, orders_partitions, 0, 1) } test("partitioned join: number of buckets mismatch should trigger shuffle") { @@ -332,13 +348,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val orders_partitions = Array(bucket(2, "customer_id")) // should shuffle both sides when number of buckets are not the same - testWithCustomersAndOrders(customers_partitions, orders_partitions, 2) + testWithCustomersAndOrders(customers_partitions, orders_partitions, 2, 0) } test("partitioned join: only one side reports partitioning") { val customers_partitions = Array(bucket(4, "customer_id")) - testWithCustomersAndOrders(customers_partitions, Array.empty, 2) + testWithCustomersAndOrders(customers_partitions, Array.empty, 2, 0) } private val items: String = "items" @@ -366,7 +382,10 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val df = sql(s"SELECT MAX(price) AS res FROM testcat.ns.$items GROUP BY id") val shuffles = collectAllShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, - "should contain shuffle when not grouping by partition values") + "should not contain shuffle when grouping by partition values") + val groupPartitions = collectAllGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.size == 1, + "should contain group partitions when grouping by partition values") checkAnswer(df.sort("res"), Seq(Row(10.0), Row(15.5), Row(41.0))) } @@ -390,9 +409,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { if (sortingEnabled) { assert(collectAllShuffles(df.queryExecution.executedPlan).isEmpty, "should contain no shuffle when sorting by partition values") + assert(collectAllGroupPartitions(df.queryExecution.executedPlan).size == 1, + "should contain partition grouping when sorting by partition values") } else { assert(collectAllShuffles(df.queryExecution.executedPlan).size == 1, "should contain one shuffle when optimization is disabled") + assert(collectAllGroupPartitions(df.queryExecution.executedPlan).isEmpty, + "should contain no partition grouping when optimization is disabled") } checkAnswer(df, answer) }: Unit @@ -446,6 +469,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { |""".stripMargin) checkAnswer(df, Seq(Row(1, 1, "aa"), Row(2, 2, "bb"), Row(3, 3, "cc"))) assert(collectShuffles(df.queryExecution.executedPlan).isEmpty) + assert(collectGroupPartitions(df.queryExecution.executedPlan).isEmpty) } } @@ -473,6 +497,9 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time")) val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.size === 2, + "should contain group partitions on both sides of the join") checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(1, "aa", 41.0, 44.0), Row(1, "aa", 41.0, 45.0), Row(2, "bb", 10.0, 11.0), Row(2, "bb", 10.5, 11.0), Row(3, "cc", 15.5, 19.5)) @@ -505,6 +532,9 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time")) val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.size === 2, + "should contain group partitions on both sides of the join") checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(1, "aa", 41.0, 44.0), Row(1, "aa", 41.0, 45.0), Row(2, "bb", 10.0, 11.0), Row(2, "bb", 10.5, 11.0), Row(3, "cc", 15.5, 19.5)) @@ -532,11 +562,16 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { withSQLConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString) { val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time")) val shuffles = collectShuffles(df.queryExecution.executedPlan) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) if (pushDownValues) { assert(shuffles.isEmpty, "should not add shuffle when partition values mismatch") + assert(groupPartitions.size === 2, + "should add group partitions when partition values mismatch") } else { assert(shuffles.nonEmpty, "should add shuffle when partition values mismatch, and " + "pushing down partition values is not enabled") + assert(groupPartitions.isEmpty, "should not add group partition when partition values " + + "mismatch, and pushing down partition values is not enabled") } checkAnswer(df, @@ -566,11 +601,16 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { withSQLConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString) { val df = createJoinTestDF(Seq("id" -> "item_id")) val shuffles = collectShuffles(df.queryExecution.executedPlan) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) if (pushDownValues) { assert(shuffles.isEmpty, "should not add shuffle when partition values mismatch") + assert(groupPartitions.size === 2, + "should add group partitions when partition values mismatch") } else { assert(shuffles.nonEmpty, "should add shuffle when partition values mismatch, and " + "pushing down partition values is not enabled") + assert(groupPartitions.isEmpty, "should not add group partition when partition values " + + "mismatch, and pushing down partition values is not enabled") } checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5))) @@ -598,11 +638,16 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { withSQLConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString) { val df = createJoinTestDF(Seq("id" -> "item_id")) val shuffles = collectShuffles(df.queryExecution.executedPlan) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) if (pushDownValues) { assert(shuffles.isEmpty, "should not add shuffle when partition values mismatch") + assert(groupPartitions.size === 2, + "should add group partitions when partition values mismatch") } else { assert(shuffles.nonEmpty, "should add shuffle when partition values mismatch, and " + "pushing down partition values is not enabled") + assert(groupPartitions.isEmpty, "should not add group partition when partition values " + + "mismatch, and pushing down partition values is not enabled") } checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(2, "bb", 10.0, 19.5))) @@ -629,11 +674,16 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { withSQLConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString) { val df = createJoinTestDF(Seq("id" -> "item_id")) val shuffles = collectShuffles(df.queryExecution.executedPlan) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) if (pushDownValues) { assert(shuffles.isEmpty, "should not add shuffle when partition values mismatch") + assert(groupPartitions.size === 2, + "should add group partitions when partition values mismatch") } else { assert(shuffles.nonEmpty, "should add shuffle when partition values mismatch, and " + "pushing down partition values is not enabled") + assert(groupPartitions.isEmpty, "should not add group partition when partition values " + + "mismatch, and pushing down partition values is not enabled") } checkAnswer(df, Seq.empty) @@ -641,7 +691,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } - test("SPARK-49205: KeyGroupedPartitioning should inherit HashPartitioningLike") { + test("SPARK-49205: KeyedPartitioning should be an Expression") { val items_partitions = Array(days("arrive_time")) createTable(items, itemsColumns, items_partitions) sql(s"INSERT INTO testcat.ns.$items VALUES " + @@ -717,8 +767,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "should not contain any shuffle") if (pushDownValues) { - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == expected)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions == expected)) } checkAnswer(df, Seq(Row(1, "aa", 40.0, 45.0), Row(1, "aa", 40.0, 50.0), Row(2, "bb", 10.0, 15.0), Row(2, "bb", 10.0, 20.0), Row(3, "cc", 15.5, 20.0))) @@ -758,8 +808,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "should not contain any shuffle") if (pushDownValues) { - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == expected)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === expected)) } checkAnswer(df, Seq( Row(1, "aa", 40.0, 45.0), Row(1, "aa", 40.0, 50.0), Row(1, "aa", 40.0, 55.0), @@ -806,8 +856,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) if (pushDownValues) { assert(shuffles.isEmpty, "should not contain any shuffle") - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == expected)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === expected)) } else { assert(shuffles.nonEmpty, "should contain shuffle when not pushing down partition values") @@ -857,8 +907,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) if (pushDownValues) { assert(shuffles.isEmpty, "should not contain any shuffle") - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == expected)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === expected)) } else { assert(shuffles.nonEmpty, "should contain shuffle when not pushing down partition values") @@ -903,8 +953,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) if (pushDownValues) { assert(shuffles.isEmpty, "should not contain any shuffle") - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == expected)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === expected)) } else { assert(shuffles.nonEmpty, "should contain shuffle when not pushing down partition values") @@ -950,9 +1000,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) if (pushDownValues) { assert(shuffles.isEmpty, "should not contain any shuffle") - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == expected), - s"Expected $expected but got ${scans.head.inputRDD.partitions.length}") + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === expected)) } else { assert(shuffles.nonEmpty, "should contain shuffle when not pushing down partition values") @@ -999,10 +1048,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) if (pushDownValues) { assert(shuffles.isEmpty, "should not contain any shuffle") - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.map(_.inputRDD.partitions.length).toSet.size == 1) - assert(scans.forall(_.inputRDD.partitions.length == expected), - s"Expected $expected but got ${scans.head.inputRDD.partitions.length}") + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === expected)) } else { assert(shuffles.nonEmpty, "should contain shuffle when not pushing down partition values") @@ -1047,10 +1094,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) if (pushDownValues) { assert(shuffles.isEmpty, "should not contain any shuffle") - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.map(_.inputRDD.partitions.length).toSet.size == 1) - assert(scans.forall(_.inputRDD.partitions.length == expected), - s"Expected $expected but got ${scans.head.inputRDD.partitions.length}") + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === expected)) } else { assert(shuffles.nonEmpty, "should contain shuffle when not pushing down partition values") @@ -1123,8 +1168,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "should not contain any shuffle") if (pushDownValues) { - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length === 3)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === 3)) } } } @@ -1226,8 +1271,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) if (pushDownValues) { assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == expected)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === expected)) } else { assert(shuffles.nonEmpty, "should contain shuffle when not pushing down partition values") @@ -1495,12 +1540,12 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "SPJ should be triggered") - val scans = collectScans(df.queryExecution.executedPlan) - .map(_.inputRDD.partitions.length) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + .map(_.outputPartitioning.numPartitions) if (partiallyClustered) { - assert(scans == Seq(8, 8)) + assert(groupPartitions == Seq(8, 8)) } else { - assert(scans == Seq(4, 4)) + assert(groupPartitions == Seq(4, 4)) } checkAnswer(df, Seq( Row(3, "dd", "dd"), @@ -1564,23 +1609,23 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { assert(shuffles.nonEmpty, "SPJ should not be triggered") } - val scannedPartitions = collectScans(df.queryExecution.executedPlan) - .map(_.inputRDD.partitions.length) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + .map(_.outputPartitioning.numPartitions) (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered, filter) match { // SPJ, partially-clustered, with filter - case (true, true, true) => assert(scannedPartitions == Seq(6, 6)) + case (true, true, true) => assert(groupPartitions == Seq(6, 6)) // SPJ, partially-clustered, no filter - case (true, true, false) => assert(scannedPartitions == Seq(8, 8)) + case (true, true, false) => assert(groupPartitions == Seq(8, 8)) // SPJ and not partially-clustered, with filter - case (true, false, true) => assert(scannedPartitions == Seq(2, 2)) + case (true, false, true) => assert(groupPartitions == Seq(2, 2)) // SPJ and not partially-clustered, no filter - case (true, false, false) => assert(scannedPartitions == Seq(4, 4)) + case (true, false, false) => assert(groupPartitions == Seq(4, 4)) // No SPJ - case _ => assert(scannedPartitions == Seq(5, 4)) + case _ => assert(groupPartitions == Seq.empty) } checkAnswer(df, Seq( @@ -1703,8 +1748,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "SPJ should be triggered") - val partions = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. - partitions.length) + val partions = collectGroupPartitions(df.queryExecution.executedPlan) + .map(_.outputPartitioning.numPartitions) val expectedBuckets = Math.min(table1buckets1, table2buckets1) * Math.min(table1buckets2, table2buckets2) assert(partions == Seq(expectedBuckets, expectedBuckets)) @@ -1863,13 +1908,12 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "SPJ should be triggered") - val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. - partitions.length) - + val partitions = collectGroupPartitions(df.queryExecution.executedPlan) + .map(_.outputPartitioning.numPartitions) def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt val expectedPartitions = gcd(table1buckets1, table2buckets1) * gcd(table1buckets2, table2buckets2) - assert(scans == Seq(expectedPartitions, expectedPartitions)) + assert(partitions == Seq(expectedPartitions, expectedPartitions)) checkAnswer(df, Seq( Row(0, 0, "aa", "aa"), @@ -2041,12 +2085,12 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "SPJ should be triggered") - val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. - partitions.length) + val partitions = collectGroupPartitions(df.queryExecution.executedPlan) + .map(_.outputPartitioning.numPartitions) val expectedBuckets = Math.min(table1buckets, table2buckets) - assert(scans == Seq(expectedBuckets, expectedBuckets)) + assert(partitions == Seq(expectedBuckets, expectedBuckets)) checkAnswer(df, Seq( Row(0, 6, 0, 0, "aa", "01"), @@ -2105,16 +2149,16 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { |""".stripMargin) val shuffles = collectShuffles(df.queryExecution.executedPlan) - val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. - partitions.length) + val partitions = collectGroupPartitions(df.queryExecution.executedPlan) + .map(_.outputPartitioning.numPartitions) (allowPushDown, partiallyClustered) match { case (true, false) => assert(shuffles.isEmpty, "SPJ should be triggered") - assert(scans == Seq(2, 2)) + assert(partitions == Seq(2, 2)) case (_, _) => assert(shuffles.nonEmpty, "SPJ should not be triggered") - assert(scans == Seq(3, 2)) + assert(partitions.isEmpty) } checkAnswer(df, Seq( @@ -2172,13 +2216,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { assert(shuffles.nonEmpty, "SPJ should not be triggered") } - val scans = collectScans(df.queryExecution.executedPlan) - .map(_.inputRDD.partitions.length) + val partitions = collectGroupPartitions(df.queryExecution.executedPlan) + .map(_.outputPartitioning.numPartitions) (pushDownValues, allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match { // SPJ and partially-clustered - case (true, true, true) => assert(scans == Seq(3, 3)) + case (_, true, _) => assert(partitions == Seq(3, 3)) // non-SPJ or SPJ/partially-clustered - case _ => assert(scans == Seq(3, 3)) + case _ => assert(partitions.isEmpty) } } } @@ -2226,15 +2270,15 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { assert(shuffles.nonEmpty, "SPJ should not be triggered") } - val scans = collectScans(df.queryExecution.executedPlan) - .map(_.inputRDD.partitions.length) + val partitions = collectGroupPartitions(df.queryExecution.executedPlan) + .map(_.outputPartitioning.numPartitions) (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match { // SPJ and partially-clustered - case (true, true) => assert(scans == Seq(5, 5)) + case (true, true) => assert(partitions == Seq(5, 5)) // SPJ and not partially-clustered - case (true, false) => assert(scans == Seq(3, 3)) + case (true, false) => assert(partitions == Seq(3, 3)) // No SPJ - case _ => assert(scans == Seq(4, 4)) + case _ => assert(partitions.isEmpty) } checkAnswer(df, @@ -2466,8 +2510,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(5, "cc", 44.5, 44.0)) ) - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == 2)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions == 2)) } } @@ -2491,8 +2535,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") assert(df.collect().isEmpty, "should return no results") - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == 0)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions == 0)) } } @@ -2523,8 +2567,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Row(1, "aa", 40.0, 40.0)) ) - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == 3)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions == 3)) } } @@ -2556,8 +2600,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Row(1, "aa", 40.0, 40.0)) ) - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == 4)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions == 4)) } } @@ -2588,8 +2632,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Row(4, "aa", 40.0, 42.0)) ) - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == 3)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions == 3)) } } @@ -2623,8 +2667,8 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(5, "cc", 44.5, 44.0)) ) - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans.forall(_.inputRDD.partitions.length == 2)) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions == 2)) } } @@ -2646,10 +2690,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val shuffles = collectAllShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "should not contain shuffle when not grouping by partition values") + val groupPartitions = collectAllGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.size === 1) + assert(groupPartitions.head.outputPartitioning.numPartitions == 3) } } - test("SPARK-53322: checkpointed scans aren't used for SPJ") { + test("SPARK-53322: checkpointed scans are used for SPJ") { withTempDir { dir => spark.sparkContext.setCheckpointDir(dir.getPath) val itemsPartitions = Array(identity("id")) @@ -2688,14 +2735,21 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { df, Seq(Row(1, "aa", 41.0, 40.0), Row(3, "cc", 15.5, 25.5)) ) - // 1 shuffle for SORT and 2 shuffles for JOIN are expected. - assert(collectAllShuffles(df.queryExecution.executedPlan).length === 3) + if (pushdownValues) { + // 1 shuffle for SORT and 2 group partitions for JOIN are expected. + assert(collectAllShuffles(df.queryExecution.executedPlan).length === 1) + assert(collectAllGroupPartitions(df.queryExecution.executedPlan).length === 2) + } else { + // 1 shuffle for SORT and 2 shuffles for JOIN are expected. + assert(collectAllShuffles(df.queryExecution.executedPlan).length === 3) + assert(collectAllGroupPartitions(df.queryExecution.executedPlan).length === 0) + } } } } } - test("SPARK-53322: checkpointed scans can't shuffle other children on SPJ") { + test("SPARK-53322: checkpointed scans can shuffle other children on SPJ") { withTempDir { dir => spark.sparkContext.setCheckpointDir(dir.getPath) val itemsPartitions = Array(identity("id")) @@ -2727,52 +2781,16 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { df, Seq(Row(1, "aa", 41.0, 40.0), Row(3, "cc", 15.5, 25.5)) ) - // 1 shuffle for SORT and 2 shuffles for JOIN are expected. - assert(collectAllShuffles(df.queryExecution.executedPlan).length === 3) + // 1 shuffle for SORT and 1 shuffle for JOIN are expected. + assert(collectAllShuffles(df.queryExecution.executedPlan).length === 2) + // 0 group partitions are expected because both sides of the join are clustered from scans + assert(collectAllGroupPartitions(df.queryExecution.executedPlan).length === 0) } } } } - test("SPARK-53322: checkpointed scans can be shuffled by children on SPJ") { - withTempDir { dir => - spark.sparkContext.setCheckpointDir(dir.getPath) - val itemsPartitions = Array(identity("id")) - createTable(items, itemsColumns, itemsPartitions) - sql(s"INSERT INTO testcat.ns.$items VALUES " + - s"(1, 'aa', 41.0, cast('2020-01-01' as timestamp)), " + - s"(2, 'bb', 10.0, cast('2020-01-02' as timestamp)), " + - s"(3, 'cc', 15.5, cast('2020-01-03' as timestamp))") - - createTable(purchases, purchasesColumns, Array(identity("item_id"))) - sql(s"INSERT INTO testcat.ns.$purchases VALUES " + - s"(1, 40.0, cast('2020-01-01' as timestamp)), " + - s"(3, 25.5, cast('2020-01-03' as timestamp)), " + - s"(4, 20.0, cast('2020-01-04' as timestamp))") - - withSQLConf( - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", - SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true") { - val scanDF1 = spark.read.table(s"testcat.ns.$items").checkpoint().as("i") - val scanDF2 = spark.read.table(s"testcat.ns.$purchases").as("p") - - val df = scanDF1 - .join(scanDF2, col("id") === col("item_id")) - .selectExpr("id", "name", "i.price AS purchase_price", "p.price AS sale_price") - .orderBy("id", "purchase_price", "sale_price") - checkAnswer( - df, - Seq(Row(1, "aa", 41.0, 40.0), Row(3, "cc", 15.5, 25.5)) - ) - - // One shuffle for the sort and one shuffle for one side of the JOIN are expected. - assert(collectAllShuffles(df.queryExecution.executedPlan).length === 2) - } - } - } - - test("SPARK-54439: KeyGroupedPartitioning and join key size mismatch") { + test("SPARK-54439: KeyedPartitioning and join key size mismatch") { val items_partitions = Array(identity("id")) createTable(items, itemsColumns, items_partitions) @@ -2797,7 +2815,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } - test("SPARK-54439: KeyGroupedPartitioning with transform and join key size mismatch") { + test("SPARK-54439: KeyedPartitioning with transform and join key size mismatch") { val items_partitions = Array(years("arrive_time")) createTable(items, itemsColumns, items_partitions) @@ -2832,10 +2850,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { "(4, 'cc', 15.5, cast('2021-02-01' as timestamp))") val metrics = runAndFetchMetrics { - val df = sql(s"SELECT * FROM testcat.ns.$items") - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans(0).inputRDD.partitions.length === 2, "items scan should have 2 partition groups") + val df = sql(s"SELECT id, count(*) FROM testcat.ns.$items GROUP BY id") df.collect() + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans(0).inputRDD.partitions.length === 3, "items scan should have 3 partitions") + val groupPartitions = collectAllGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions(0).outputPartitioning.numPartitions === 2, + "group partitions should have 2 partition groups") } assert(metrics("number of rows read") == "3") } @@ -2892,4 +2913,65 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Row("ccc", 30, 400.50))) } } + + test("SPARK-55092: Don't group partitions for join when not needed") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(4, 'bb', 10.0, cast('2021-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2021-02-01' as timestamp))") + + val purchases_partitions = Array(years("time")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") { + val df = createJoinTestDF(Seq("id" -> "item_id"), extraColumns = Seq("year(p.time)")) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.size == 1, "only shuffle one side not report partitioning") + + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans(0).inputRDD.partitions.length === 3, + "items scan should not group") + assert(scans(1).inputRDD.partitions.length === 2, + "purchases scan should not group") + + checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0, 2020))) + } + + withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "false") { + val df = createJoinTestDF(Seq("id" -> "item_id"), extraColumns = Seq("year(p.time)")) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.size == 2, "only shuffle one side not report partitioning") + + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans(0).inputRDD.partitions.length === 3, + "items scan should not group as it is shuffled") + assert(scans(1).inputRDD.partitions.length === 2, + "purchases scan should not group as it is shuffled") + + checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0, 2020))) + } + } + + test("SPARK-55092: Don't group partitions for aggregate when not needed") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(4, 'bb', 10.0, cast('2021-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2021-02-01' as timestamp))") + + val df = sql(s"SELECT * FROM testcat.ns.$items") + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans(0).inputRDD.partitions.length === 3, + "items scan should not group") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 1cc0d795d74f8..3ad97be1622e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan import org.apache.spark.sql.connector.catalog.functions._ import org.apache.spark.sql.execution.{DummySparkPlan, SortExec} import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, GroupPartitionsExec} import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec import org.apache.spark.sql.execution.window.WindowExec @@ -91,15 +91,15 @@ class EnsureRequirementsSuite extends SharedSparkSession { } } - test("reorder should handle KeyGroupedPartitioning") { + test("reorder should handle KeyedPartitioning") { // partitioning on the left val plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning(Seq( - years(exprA), bucket(4, exprB), days(exprC)), 4) + outputPartitioning = KeyedPartitioning( + Seq(years(exprA), bucket(4, exprB), days(exprC)), Seq.empty, Seq.empty) ) val plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning(Seq( - years(exprB), bucket(4, exprA), days(exprD)), 4) + outputPartitioning = KeyedPartitioning(Seq( + years(exprB), bucket(4, exprA), days(exprD)), Seq.empty, Seq.empty) ) val smjExec = SortMergeJoinExec( exprB :: exprC :: exprA :: Nil, exprA :: exprD :: exprB :: Nil, @@ -107,8 +107,8 @@ class EnsureRequirementsSuite extends SharedSparkSession { ) EnsureRequirements.apply(smjExec) match { case SortMergeJoinExec(leftKeys, rightKeys, _, _, - SortExec(_, _, DummySparkPlan(_, _, _: KeyGroupedPartitioning, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, _: KeyGroupedPartitioning, _, _), _), _) => + SortExec(_, _, DummySparkPlan(_, _, _: KeyedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, _: KeyedPartitioning, _, _), _), _) => assert(leftKeys === Seq(exprA, exprB, exprC)) assert(rightKeys === Seq(exprB, exprA, exprD)) case other => fail(other.toString) @@ -116,8 +116,8 @@ class EnsureRequirementsSuite extends SharedSparkSession { // partitioning on the right val plan3 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning(Seq( - bucket(4, exprD), days(exprA), years(exprC)), 4) + outputPartitioning = KeyedPartitioning(Seq( + bucket(4, exprD), days(exprA), years(exprC)), Seq.empty, Seq.empty) ) val smjExec2 = SortMergeJoinExec( exprB :: exprD :: exprC :: Nil, exprA :: exprC :: exprD :: Nil, @@ -777,18 +777,18 @@ class EnsureRequirementsSuite extends SharedSparkSession { } } - test("Check with KeyGroupedPartitioning") { + test("Check with KeyedPartitioning") { // simplest case: identity transforms var plan1 = new DummySparkPlanWithBatchScanChild( - KeyGroupedPartitioning(exprA :: exprB :: Nil, 5)) + KeyedPartitioning(exprA :: exprB :: Nil, Seq.empty, Seq.empty)) var plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning(exprA :: exprC :: Nil, 5)) + outputPartitioning = KeyedPartitioning(exprA :: exprC :: Nil, Seq.empty, Seq.empty)) var smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + SortExec(_, _, DummySparkPlan(_, _, left: KeyedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, right: KeyedPartitioning, _, _), _), _) => assert(left.expressions === Seq(exprA, exprB)) assert(right.expressions === Seq(exprA, exprC)) case other => fail(other.toString) @@ -796,19 +796,19 @@ class EnsureRequirementsSuite extends SharedSparkSession { // matching bucket transforms from both sides plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprA) :: bucket(16, exprB) :: Nil, 4) + outputPartitioning = KeyedPartitioning( + bucket(4, exprA) :: bucket(16, exprB) :: Nil, Seq.empty, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprA) :: bucket(16, exprC) :: Nil, 4) + outputPartitioning = KeyedPartitioning( + bucket(4, exprA) :: bucket(16, exprC) :: Nil, Seq.empty, Seq.empty) ) smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + SortExec(_, _, DummySparkPlan(_, _, left: KeyedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, right: KeyedPartitioning, _, _), _), _) => assert(left.expressions === Seq(bucket(4, exprA), bucket(16, exprB))) assert(right.expressions === Seq(bucket(4, exprA), bucket(16, exprC))) case other => fail(other.toString) @@ -816,20 +816,20 @@ class EnsureRequirementsSuite extends SharedSparkSession { // partition collections plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprA) :: bucket(16, exprB) :: Nil, 4) + outputPartitioning = KeyedPartitioning( + bucket(4, exprA) :: bucket(16, exprB) :: Nil, Seq.empty, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( outputPartitioning = PartitioningCollection(Seq( - KeyGroupedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, 4), - KeyGroupedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, 4)) + KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, Seq.empty, Seq.empty), + KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, Seq.empty, Seq.empty)) ) ) smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, left: KeyedPartitioning, _, _), _), SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) => assert(left.expressions === Seq(bucket(4, exprA), bucket(16, exprB))) case other => fail(other.toString) @@ -839,24 +839,26 @@ class EnsureRequirementsSuite extends SharedSparkSession { EnsureRequirements.apply(smjExec) match { case SortMergeJoinExec(_, _, _, _, SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + SortExec(_, _, DummySparkPlan(_, _, right: KeyedPartitioning, _, _), _), _) => assert(right.expressions === Seq(bucket(4, exprA), bucket(16, exprB))) case other => fail(other.toString) } // bucket + years transforms from both sides plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning(bucket(4, exprA) :: years(exprB) :: Nil, 4) + outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: years(exprB) :: Nil, + Seq.empty, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning(bucket(4, exprA) :: years(exprC) :: Nil, 4) + outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: years(exprC) :: Nil, + Seq.empty, Seq.empty) ) smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + SortExec(_, _, DummySparkPlan(_, _, left: KeyedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, right: KeyedPartitioning, _, _), _), _) => assert(left.expressions === Seq(bucket(4, exprA), years(exprB))) assert(right.expressions === Seq(bucket(4, exprA), years(exprC))) case other => fail(other.toString) @@ -865,12 +867,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { // by default spark.sql.requireAllClusterKeysForCoPartition is true, so when there isn't // exact match on all partition keys, Spark will fallback to shuffle. plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprA) :: bucket(4, exprB) :: Nil, 4) + outputPartitioning = KeyedPartitioning( + bucket(4, exprA) :: bucket(4, exprB) :: Nil, Seq.empty, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprA) :: bucket(4, exprC) :: Nil, 4) + outputPartitioning = KeyedPartitioning( + bucket(4, exprA) :: bucket(4, exprC) :: Nil, Seq.empty, Seq.empty) ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) @@ -884,14 +886,14 @@ class EnsureRequirementsSuite extends SharedSparkSession { } } - test(s"KeyGroupedPartitioning with ${REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key} = false") { + test(s"KeyedPartitioning with ${REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key} = false") { var plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprB) :: years(exprC) :: Nil, 4) + outputPartitioning = KeyedPartitioning( + bucket(4, exprB) :: years(exprC) :: Nil, Seq.empty, Seq.empty) ) var plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprC) :: years(exprB) :: Nil, 4) + outputPartitioning = KeyedPartitioning( + bucket(4, exprC) :: years(exprB) :: Nil, Seq.empty, Seq.empty) ) // simple case @@ -899,8 +901,8 @@ class EnsureRequirementsSuite extends SharedSparkSession { exprA :: exprB :: exprC :: Nil, exprA :: exprC :: exprB :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + SortExec(_, _, DummySparkPlan(_, _, left: KeyedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, right: KeyedPartitioning, _, _), _), _) => assert(left.expressions === Seq(bucket(4, exprB), years(exprC))) assert(right.expressions === Seq(bucket(4, exprC), years(exprB))) case other => fail(other.toString) @@ -908,19 +910,19 @@ class EnsureRequirementsSuite extends SharedSparkSession { // should also work with distributions with duplicated keys plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprA) :: years(exprB) :: Nil, 4) + outputPartitioning = KeyedPartitioning( + bucket(4, exprA) :: years(exprB) :: Nil, Seq.empty, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprA) :: years(exprC) :: Nil, 4) + outputPartitioning = KeyedPartitioning( + bucket(4, exprA) :: years(exprC) :: Nil, Seq.empty, Seq.empty) ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + SortExec(_, _, DummySparkPlan(_, _, left: KeyedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, right: KeyedPartitioning, _, _), _), _) => assert(left.expressions === Seq(bucket(4, exprA), years(exprB))) assert(right.expressions === Seq(bucket(4, exprA), years(exprC))) case other => fail(other.toString) @@ -928,17 +930,17 @@ class EnsureRequirementsSuite extends SharedSparkSession { // both partitioning and distribution have duplicated keys plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - years(exprA) :: bucket(4, exprB) :: days(exprA) :: Nil, 5)) + outputPartitioning = KeyedPartitioning( + years(exprA) :: bucket(4, exprB) :: days(exprA) :: Nil, Seq.empty, Seq.empty)) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - years(exprA) :: bucket(4, exprC) :: days(exprA) :: Nil, 5)) + outputPartitioning = KeyedPartitioning( + years(exprA) :: bucket(4, exprC) :: days(exprA) :: Nil, Seq.empty, Seq.empty)) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + SortExec(_, _, DummySparkPlan(_, _, left: KeyedPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, right: KeyedPartitioning, _, _), _), _) => assert(left.expressions === Seq(years(exprA), bucket(4, exprB), days(exprA))) assert(right.expressions === Seq(years(exprA), bucket(4, exprC), days(exprA))) case other => fail(other.toString) @@ -946,12 +948,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { // invalid case: partitioning key positions don't match plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprA) :: bucket(4, exprB) :: Nil, 4) + outputPartitioning = KeyedPartitioning( + bucket(4, exprA) :: bucket(4, exprB) :: Nil, Seq.empty, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprB) :: bucket(4, exprC) :: Nil, 4) + outputPartitioning = KeyedPartitioning( + bucket(4, exprB) :: bucket(4, exprC) :: Nil, Seq.empty, Seq.empty) ) smjExec = SortMergeJoinExec( @@ -967,12 +969,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { // invalid case: different number of buckets (we don't support coalescing/repartitioning yet) plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprA) :: bucket(4, exprB) :: Nil, 4) + outputPartitioning = KeyedPartitioning( + bucket(4, exprA) :: bucket(4, exprB) :: Nil, Seq.empty, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - bucket(4, exprA) :: bucket(8, exprC) :: Nil, 4) + outputPartitioning = KeyedPartitioning( + bucket(4, exprA) :: bucket(8, exprC) :: Nil, Seq.empty, Seq.empty) ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) @@ -987,10 +989,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { // invalid case: partition key positions match but with different transforms plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning(years(exprA) :: bucket(4, exprB) :: Nil, 4) + outputPartitioning = KeyedPartitioning(years(exprA) :: bucket(4, exprB) :: Nil, + Seq.empty, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning(days(exprA) :: bucket(4, exprC) :: Nil, 4) + outputPartitioning = KeyedPartitioning(days(exprA) :: bucket(4, exprC) :: Nil, + Seq.empty, Seq.empty) ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) @@ -1006,12 +1010,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { // invalid case: multiple references in transform plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - years(exprA) :: buckets(4, Seq(exprB, exprC)) :: Nil, 4) + outputPartitioning = KeyedPartitioning( + years(exprA) :: buckets(4, Seq(exprB, exprC)) :: Nil, Seq.empty, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning( - years(exprA) :: buckets(4, Seq(exprB, exprC)) :: Nil, 4) + outputPartitioning = KeyedPartitioning( + years(exprA) :: buckets(4, Seq(exprB, exprC)) :: Nil, Seq.empty, Seq.empty) ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) @@ -1032,12 +1036,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { .map(new GenericInternalRow(_)) var plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, - leftPartValues.length, leftPartValues) + outputPartitioning = KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, + leftPartValues, leftPartValues) ) var plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyGroupedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, - rightPartValues.length, rightPartValues) + outputPartitioning = KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, + rightPartValues, rightPartValues) ) // simple case @@ -1045,8 +1049,13 @@ class EnsureRequirementsSuite extends SharedSparkSession { exprA :: exprB :: exprC :: Nil, exprA :: exprC :: exprB :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + SortExec(_, _, + GroupPartitionsExec(DummySparkPlan(_, _, left: KeyedPartitioning, _, _), + _, _, _, _, _), _), + SortExec(_, _, + GroupPartitionsExec(DummySparkPlan(_, _, right: KeyedPartitioning, _, _), + _, _, _, _, _), _), + _) => assert(left.expressions === Seq(bucket(4, exprB), bucket(8, exprC))) assert(right.expressions === Seq(bucket(4, exprC), bucket(8, exprB))) case other => fail(other.toString) @@ -1055,10 +1064,10 @@ class EnsureRequirementsSuite extends SharedSparkSession { // With partition collections plan1 = new DummySparkPlanWithBatchScanChild(outputPartitioning = PartitioningCollection( - Seq(KeyGroupedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, - leftPartValues.length, leftPartValues), - KeyGroupedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, - leftPartValues.length, leftPartValues)) + Seq(KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, + leftPartValues, leftPartValues), + KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, + leftPartValues, leftPartValues)) ) ) @@ -1066,11 +1075,16 @@ class EnsureRequirementsSuite extends SharedSparkSession { exprA :: exprB :: exprC :: Nil, exprA :: exprC :: exprB :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, DummySparkPlan(_, _, left: PartitioningCollection, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, right: KeyGroupedPartitioning, _, _), _), _) => + SortExec(_, _, + GroupPartitionsExec(DummySparkPlan(_, _, left: PartitioningCollection, _, _), + _, _, _, _, _), _), + SortExec(_, _, + GroupPartitionsExec(DummySparkPlan(_, _, right: KeyedPartitioning, _, _), + _, _, _, _, _), _), + _) => assert(left.partitionings.length == 2) - assert(left.partitionings.head.isInstanceOf[KeyGroupedPartitioning]) - assert(left.partitionings.head.asInstanceOf[KeyGroupedPartitioning].expressions == + assert(left.partitionings.head.isInstanceOf[KeyedPartitioning]) + assert(left.partitionings.head.asInstanceOf[KeyedPartitioning].expressions == Seq(bucket(4, exprB), bucket(8, exprC))) assert(right.expressions === Seq(bucket(4, exprC), bucket(8, exprB))) case other => fail(other.toString) @@ -1082,16 +1096,16 @@ class EnsureRequirementsSuite extends SharedSparkSession { Seq( PartitioningCollection( Seq( - KeyGroupedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, - rightPartValues.length, rightPartValues), - KeyGroupedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, - rightPartValues.length, rightPartValues))), + KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, + rightPartValues, rightPartValues), + KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, + rightPartValues, rightPartValues))), PartitioningCollection( Seq( - KeyGroupedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, - rightPartValues.length, rightPartValues), - KeyGroupedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, - rightPartValues.length, rightPartValues))) + KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, + rightPartValues, rightPartValues), + KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, + rightPartValues, rightPartValues))) ) ) ) @@ -1100,11 +1114,16 @@ class EnsureRequirementsSuite extends SharedSparkSession { exprA :: exprB :: exprC :: Nil, exprA :: exprC :: exprB :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { case SortMergeJoinExec(_, _, _, _, - SortExec(_, _, DummySparkPlan(_, _, left: PartitioningCollection, _, _), _), - SortExec(_, _, DummySparkPlan(_, _, right: PartitioningCollection, _, _), _), _) => + SortExec(_, _, + GroupPartitionsExec(DummySparkPlan(_, _, left: PartitioningCollection, _, _), + _, _, _, _, _), _), + SortExec(_, _, + GroupPartitionsExec(DummySparkPlan(_, _, right: PartitioningCollection, _, _), + _, _, _, _, _), _), + _) => assert(left.partitionings.length == 2) - assert(left.partitionings.head.isInstanceOf[KeyGroupedPartitioning]) - assert(left.partitionings.head.asInstanceOf[KeyGroupedPartitioning].expressions == + assert(left.partitionings.head.isInstanceOf[KeyedPartitioning]) + assert(left.partitionings.head.asInstanceOf[KeyedPartitioning].expressions == Seq(bucket(4, exprB), bucket(8, exprC))) assert(right.partitionings.length == 2) assert(right.partitionings.head.isInstanceOf[PartitioningCollection]) @@ -1120,16 +1139,16 @@ class EnsureRequirementsSuite extends SharedSparkSession { val a1 = AttributeReference("a1", IntegerType)() val partitionValue = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v))) - val plan1 = new DummySparkPlanWithBatchScanChild(outputPartitioning = KeyGroupedPartitioning( - identity(a1) :: Nil, 4, partitionValue)) + val plan1 = new DummySparkPlanWithBatchScanChild(outputPartitioning = KeyedPartitioning( + identity(a1) :: Nil, partitionValue, partitionValue)) val plan2 = DummySparkPlan(outputPartitioning = SinglePartition) val smjExec = ShuffledHashJoinExec( a1 :: Nil, a1 :: Nil, Inner, BuildRight, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { case ShuffledHashJoinExec(_, _, _, _, _, - DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), - ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _), + DummySparkPlan(_, _, left: KeyedPartitioning, _, _), + ShuffleExchangeExec(KeyedPartitioning(attrs, pv, _), DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) => assert(left.expressions == a1 :: Nil) assert(attrs == a1 :: Nil) From 2762976e967a895e66c3c8ce7ef3350752a1bb81 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sat, 21 Feb 2026 11:15:42 +0100 Subject: [PATCH 02/29] code cleanup --- .../plans/physical/partitioning.scala | 20 +++ .../datasources/v2/GroupPartitionsExec.scala | 125 +++++++++--------- .../exchange/EnsureRequirements.scala | 59 ++++----- 3 files changed, 114 insertions(+), 90 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index c999a5cb7b938..061d4eee55564 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -548,6 +548,16 @@ object KeyedPartitioning { new GenericInternalRow(projectedKey) } + /** + * Projects a sequence of partition keys by selecting only the specified positions. + */ + def projectKeys( + keys: Seq[InternalRow], + positions: Seq[Int], + dataTypes: Seq[DataType]): Seq[InternalRow] = { + keys.map(projectKey(_, positions, dataTypes)) + } + def reduceKey( key: InternalRow, reducers: Seq[Option[Reducer[_, _]]], @@ -559,6 +569,16 @@ object KeyedPartitioning { }.toArray new GenericInternalRow(reducedKey) } + + /** + * Reduces a sequence of partition keys by applying reducers to each position. + */ + def reduceKeys( + keys: Seq[InternalRow], + reducers: Seq[Option[Reducer[_, _]]], + dataTypes: Seq[DataType]): Seq[InternalRow] = { + keys.map(reduceKey(_, reducers, dataTypes)) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala index 498c4b7a0f42a..098540d0717fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala @@ -61,41 +61,56 @@ case class GroupPartitionsExec( case p: Partitioning with Expression => p.transform { case k: KeyedPartitioning => - val projectedExpressions = projectExpressions(k.expressions) - val projectedDataTypes = projectedExpressions.map(_.dataType) - k.copy(expressions = projectedExpressions, + val (projectedExpressions, projectedDataTypes) = projectExpressions(k.expressions) + val projectedOriginalKeys = joinKeyPositions.fold(k.originalPartitionKeys)( + KeyedPartitioning.projectKeys(k.originalPartitionKeys, _, projectedDataTypes)) + k.copy( + expressions = projectedExpressions, partitionKeys = groupedPartitions.map(_._1), - originalPartitionKeys = projectKeys(k.originalPartitionKeys, projectedDataTypes)) + originalPartitionKeys = projectedOriginalKeys) }.asInstanceOf[Partitioning] case o => o } } - private def projectExpressions(expressions: Seq[Expression]) = { - joinKeyPositions match { - case Some(projectionPositions) => - projectionPositions.map(expressions) - case _ => expressions - } + private def projectExpressions(expressions: Seq[Expression]): (Seq[Expression], Seq[DataType]) = { + val projectedExpressions = joinKeyPositions.fold(expressions)(_.map(expressions)) + val projectedDataTypes = projectedExpressions.map(_.dataType) + + (projectedExpressions, projectedDataTypes) } - private def projectKeys(keys: Seq[InternalRow], dataTypes: Seq[DataType]) = { - joinKeyPositions match { - case Some(projectionPositions) => - keys.map(KeyedPartitioning.projectKey(_, projectionPositions, dataTypes)) - case _ => keys + /** + * Distributes partitions based on `commonPartitionKeys` and clustering mode. + */ + private def distributeByCommonKeys( + keyWrapperMap: Map[InternalRowComparableWrapper, Seq[Int]], + comparableWrapperFactory: InternalRow => InternalRowComparableWrapper + ): Seq[(InternalRow, Seq[Int])] = { + commonPartitionKeys.get.flatMap { case (key, numSplits) => + val splits = keyWrapperMap.getOrElse(comparableWrapperFactory(key), Seq.empty) + if (applyPartialClustering && !replicatePartitions) { + // Distribute splits across expected partitions, padding with empty sequences + val paddedSplits = splits.map(Seq(_)).padTo(numSplits, Seq.empty) + paddedSplits.map((key, _)) + } else { + // Replicate all splits to each expected partition + Seq.fill(numSplits)((key, splits)) + } } } /** - * Extracts the first KeyedPartitioning from the child's output partitioning. - * The child must have a KeyedPartitioning in its partitioning scheme. + * Groups and sorts partitions by their keys in ascending order. */ - lazy val firstKeyedPartitioning = { - child.outputPartitioning.asInstanceOf[Partitioning with Expression].collectFirst { - case k: KeyedPartitioning => k - }.getOrElse( - throw new SparkException("GroupPartitionsExec requires a child with KeyedPartitioning")) + private def groupAndSortByKeys( + keyWrapperMap: Map[InternalRowComparableWrapper, Seq[Int]], + dataTypes: Seq[DataType] + ): Seq[(InternalRow, Seq[Int])] = { + val rowOrdering = RowOrdering.createNaturalAscendingOrdering(dataTypes) + keyWrapperMap.toSeq + .map { case (keyWrapper, indices) => (keyWrapper.row, indices) } + .sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) } /** @@ -109,54 +124,46 @@ case class GroupPartitionsExec( * how input partitions should be grouped together. */ lazy val groupedPartitions: Seq[(InternalRow, Seq[Int])] = { - // Also sort the input partitions according to their partition key order. This ensures - // a canonical order from both sides of a bucketed join, for example. - - val (projectedDataTypes, projectedKeys) = - joinKeyPositions match { - case Some(projectionPositions) => - val projectedDataTypes = - projectExpressions(firstKeyedPartitioning.expressions).map(_.dataType) - val projectedKeys = firstKeyedPartitioning.partitionKeys - .map(KeyedPartitioning.projectKey(_, projectionPositions, projectedDataTypes)) - (projectedDataTypes, projectedKeys) - - case _ => - val dataTypes = firstKeyedPartitioning.expressions.map(_.dataType) - (dataTypes, firstKeyedPartitioning.partitionKeys) - } + // Extract the KeyedPartitioning from child's output partitioning + val keyedPartitioning = child.outputPartitioning + .asInstanceOf[Partitioning with Expression] + .collectFirst { case k: KeyedPartitioning => k } + .getOrElse( + throw new SparkException("GroupPartitionsExec requires a child with KeyedPartitioning")) + + // Project partition keys if join key positions are specified + val (projectedDataTypes, projectedKeys) = joinKeyPositions match { + case Some(positions) => + val (projectedExpressions, projectedDataTypes) = + projectExpressions(keyedPartitioning.expressions) + val projectedKeys = KeyedPartitioning.projectKeys( + keyedPartitioning.partitionKeys, positions, projectedDataTypes) + (projectedDataTypes, projectedKeys) + + case None => + val dataTypes = keyedPartitioning.expressions.map(_.dataType) + (dataTypes, keyedPartitioning.partitionKeys) + } + // Reduce keys if reducers are specified val reducedKeys = reducers match { case Some(reducers) => - projectedKeys.map(KeyedPartitioning.reduceKey(_, reducers, projectedDataTypes)) - case _ => projectedKeys + KeyedPartitioning.reduceKeys(projectedKeys, reducers, projectedDataTypes) + case None => projectedKeys } - val internalRowComparableWrapperFactory = + // Create map from partition keys to their indices + val comparableWrapperFactory = InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(projectedDataTypes) - val map = reducedKeys.zipWithIndex.groupMap { - case (key, _) => internalRowComparableWrapperFactory(key) + val keyWrapperToPartitionIndices = reducedKeys.zipWithIndex.groupMap { + case (key, _) => comparableWrapperFactory(key) }(_._2) - // When partially clustered, the input partitions are not grouped by partition - // values. Here we'll need to check `commonPartitionKeys` and decide how to group - // and replicate splits within a partition. if (commonPartitionKeys.isDefined) { - commonPartitionKeys.get.flatMap { case (key, numSplits) => - val splits = map.getOrElse(internalRowComparableWrapperFactory(key), Seq.empty) - if (applyPartialClustering && !replicatePartitions) { - val paddedSplits = splits.map(Seq(_)).padTo(numSplits, Seq.empty) - paddedSplits.map((key, _)) - } else { - Seq.fill(numSplits)((key, splits)) - } - } + distributeByCommonKeys(keyWrapperToPartitionIndices, comparableWrapperFactory) } else { - val rowOrdering = RowOrdering.createNaturalAscendingOrdering(projectedDataTypes) - map.toSeq - .map { case (keyWrapper, v) => (keyWrapper.row, v) } - .sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) + groupAndSortByKeys(keyWrapperToPartitionIndices, projectedDataTypes) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 216a1e7b63523..8b6379ebbf662 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -499,15 +499,15 @@ case class EnsureRequirements( // just push the common set of partition values: `[0, 1, 2, 3]` down to the two data // sources. if (isCompatible) { - val leftPartValues = leftSpec.partitioning.partitionKeys - val rightPartValues = rightSpec.partitioning.partitionKeys + val leftPartKeys = leftSpec.partitioning.partitionKeys + val rightPartKeys = rightSpec.partitioning.partitionKeys - val numLeftPartValues = MDC(LogKeys.NUM_LEFT_PARTITION_VALUES, leftPartValues.size) - val numRightPartValues = MDC(LogKeys.NUM_RIGHT_PARTITION_VALUES, rightPartValues.size) + val numLeftPartKeys = MDC(LogKeys.NUM_LEFT_PARTITION_VALUES, leftPartKeys.size) + val numRightPartKeys = MDC(LogKeys.NUM_RIGHT_PARTITION_VALUES, rightPartKeys.size) logInfo( log""" - |Left side # of partitions: $numLeftPartValues - |Right side # of partitions: $numRightPartValues + |Left side # of partitions: $numLeftPartKeys + |Right side # of partitions: $numRightPartKeys |""".stripMargin) // As partition keys are compatible, we can pick either left or right as partition @@ -517,20 +517,20 @@ case class EnsureRequirements( // in case of compatible but not identical partition expressions, we apply 'reduce' // transforms to group one side's partitions as well as the common partition values val leftReducers = leftSpec.reducers(rightSpec) - val leftParts = reducePartValues(leftSpec.partitioning.partitionKeys, + val leftParts = reducePartitionKeys(leftSpec.partitioning.partitionKeys, partitionExprs, leftReducers) val rightReducers = rightSpec.reducers(leftSpec) - val rightParts = reducePartValues(rightSpec.partitioning.partitionKeys, + val rightParts = reducePartitionKeys(rightSpec.partitioning.partitionKeys, partitionExprs, rightReducers) // merge values on both sides - var mergedPartValues = mergePartitions(leftParts, rightParts, partitionExprs, joinType) + var mergedPartitionKeys = mergePartitions(leftParts, rightParts, partitionExprs, joinType) .map(v => (v, 1)) logInfo(log"After merging, there are " + - log"${MDC(LogKeys.NUM_PARTITIONS, mergedPartValues.size)} partitions") + log"${MDC(LogKeys.NUM_PARTITIONS, mergedPartitionKeys.size)} partitions") var replicateLeftSide = false var replicateRightSide = false @@ -588,7 +588,7 @@ case class EnsureRequirements( // to apply the grouping & replication of partitions logInfo("Using number of partitions to determine which side of join " + "to fully cluster partition values") - leftPartValues.size < rightPartValues.size + leftPartKeys.size < rightPartKeys.size } replicateRightSide = !replicateLeftSide @@ -616,24 +616,24 @@ case class EnsureRequirements( val numExpectedPartitions = originalPartitionKeys .groupBy(internalRowComparableWrapperFactory) - .transform((_, v) => v.size) + .view.mapValues(_.size) - mergedPartValues = mergedPartValues.map { case (partVal, numParts) => - (partVal, numExpectedPartitions.getOrElse( - internalRowComparableWrapperFactory(partVal), numParts)) + mergedPartitionKeys = mergedPartitionKeys.map { case (key, numParts) => + (key, numExpectedPartitions.getOrElse( + internalRowComparableWrapperFactory(key), numParts)) } logInfo(log"After applying partially clustered distribution, there are " + - log"${MDC(LogKeys.NUM_PARTITIONS, mergedPartValues.map(_._2).sum)} partitions.") + log"${MDC(LogKeys.NUM_PARTITIONS, mergedPartitionKeys.map(_._2).sum)} partitions.") applyPartialClustering = true } } } // Now we need to push-down the common partition information to the scan in each child - newLeft = applyGroupPartitions(left, leftSpec.joinKeyPositions, mergedPartValues, + newLeft = applyGroupPartitions(left, leftSpec.joinKeyPositions, mergedPartitionKeys, leftReducers, applyPartialClustering, replicateLeftSide) - newRight = applyGroupPartitions(right, rightSpec.joinKeyPositions, mergedPartValues, + newRight = applyGroupPartitions(right, rightSpec.joinKeyPositions, mergedPartitionKeys, rightReducers, applyPartialClustering, replicateRightSide) } } @@ -683,7 +683,7 @@ case class EnsureRequirements( private def applyGroupPartitions( plan: SparkPlan, joinKeyPositions: Option[Seq[Int]], - mergedPartValues: Seq[(InternalRow, Int)], + mergedPartitionKeys: Seq[(InternalRow, Int)], reducers: Option[Seq[Option[Reducer[_, _]]]], applyPartialClustering: Boolean, replicatePartitions: Boolean): SparkPlan = { @@ -691,12 +691,12 @@ case class EnsureRequirements( case g: GroupPartitionsExec => g.copy( joinKeyPositions = joinKeyPositions, - commonPartitionKeys = Some(mergedPartValues), + commonPartitionKeys = Some(mergedPartitionKeys), reducers = reducers, applyPartialClustering = applyPartialClustering, replicatePartitions = replicatePartitions) case _ => - GroupPartitionsExec(plan, joinKeyPositions, Some(mergedPartValues), reducers, + GroupPartitionsExec(plan, joinKeyPositions, Some(mergedPartitionKeys), reducers, applyPartialClustering, replicatePartitions) } } @@ -711,21 +711,18 @@ case class EnsureRequirements( } } - private def reducePartValues( - partValues: Seq[InternalRow], + private def reducePartitionKeys( + partitionKeys: Seq[InternalRow], expressions: Seq[Expression], reducers: Option[Seq[Option[Reducer[_, _]]]]) = { reducers match { case Some(reducers) => - val partitionDataTypes = expressions.map(_.dataType) + val dataTypes = expressions.map(_.dataType) val internalRowComparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( - partitionDataTypes) - partValues.map { row => - KeyGroupedShuffleSpec.reducePartitionKey( - row, reducers, partitionDataTypes, internalRowComparableWrapperFactory) - }.distinct.map(_.row) - case _ => partValues + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) + KeyedPartitioning.reduceKeys(partitionKeys, reducers, dataTypes) + .distinctBy(internalRowComparableWrapperFactory) + case _ => partitionKeys } } From 0fb5be69bc4fbae0120a3ac03086e0f404584770 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sat, 21 Feb 2026 19:25:35 +0100 Subject: [PATCH 03/29] more code cleanup --- .../plans/physical/partitioning.scala | 60 +++++++------------ .../datasources/v2/GroupPartitionsExec.scala | 39 +++++------- 2 files changed, 36 insertions(+), 63 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 061d4eee55564..728f43e072545 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -424,33 +424,25 @@ case class KeyedPartitioning( } def toGrouped: KeyedPartitioning = { - val groupedPartitions = partitionKeys - .map(comparableWrapperFactory) - .distinct - .map(_.row) - .sorted(rowOrdering) + val distinctSortedPartitionKeys = + partitionKeys.distinctBy(comparableWrapperFactory).sorted(rowOrdering) - KeyedPartitioning(expressions, groupedPartitions, originalPartitionKeys) + KeyedPartitioning(expressions, distinctSortedPartitionKeys, originalPartitionKeys) } - def projectAndGroup(positions: Seq[Int]): KeyedPartitioning = { + /** + * Projects this partitioning's expressions by selecting only the specified positions. + * Returns both the projected expressions and their data types. + */ + def projectKeys(positions: Seq[Int]): + (Seq[Expression], Seq[DataType], Seq[InternalRow], Seq[InternalRow]) = { val projectedExpressions = positions.map(expressions) val projectedDataTypes = projectedExpressions.map(_.dataType) - val projectedPartitionKeys = partitionKeys.map( - KeyedPartitioning.projectKey(_, positions, projectedDataTypes) - ) - val projectedOriginalPartitionKeys = originalPartitionKeys.map( - KeyedPartitioning.projectKey(_, positions, projectedDataTypes) - ) - val internalRowComparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(projectedDataTypes) - val distinctPartitionKeys = projectedPartitionKeys - .map(internalRowComparableWrapperFactory) - .distinct - .map(_.row) - - copy(expressions = projectedExpressions, partitionKeys = distinctPartitionKeys, - originalPartitionKeys = projectedOriginalPartitionKeys) + val projectedKeys = KeyedPartitioning.projectKeys(partitionKeys, positions, projectedDataTypes) + val projectedOriginalKeys = + KeyedPartitioning.projectKeys(originalPartitionKeys, positions, projectedDataTypes) + + (projectedExpressions, projectedDataTypes, projectedKeys, projectedOriginalKeys) } override def satisfies0(required: Distribution): Boolean = { @@ -491,7 +483,13 @@ case class KeyedPartitioning( // `KeyedPartitioning` here that is grouped on the join keys instead, and use that as // the returned shuffle spec. val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2) - val projectedPartitioning = projectAndGroup(joinKeyPositions) + val (projectedExpressions, projectedDataTypes, projectedKeys, projectedOriginalKeys) = + projectKeys(joinKeyPositions) + val projectedComparableWrapperFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(projectedDataTypes) + val distinctPartitionKeys = projectedKeys.distinctBy(projectedComparableWrapperFactory) + val projectedPartitioning = copy(expressions = projectedExpressions, + partitionKeys = distinctPartitionKeys, originalPartitionKeys = projectedOriginalKeys) result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions)) } else { result @@ -1054,22 +1052,6 @@ case class KeyGroupedShuffleSpec( } } -object KeyGroupedShuffleSpec { - def reducePartitionKey( - row: InternalRow, - reducers: Seq[Option[Reducer[_, _]]], - dataTypes: Seq[DataType], - internalRowComparableWrapperFactory: InternalRow => InternalRowComparableWrapper - ): InternalRowComparableWrapper = { - val partitionKeys = row.toSeq(dataTypes) - val reducedRow = partitionKeys.zip(reducers).map{ - case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v) - case (v, _) => v - }.toArray - internalRowComparableWrapperFactory(new GenericInternalRow(reducedRow)) - } -} - case class ShufflePartitionIdPassThroughSpec( partitioning: ShufflePartitionIdPassThrough, distribution: ClusteredDistribution) extends ShuffleSpec { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala index 098540d0717fa..80dfc4c4cd38d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala @@ -61,9 +61,7 @@ case class GroupPartitionsExec( case p: Partitioning with Expression => p.transform { case k: KeyedPartitioning => - val (projectedExpressions, projectedDataTypes) = projectExpressions(k.expressions) - val projectedOriginalKeys = joinKeyPositions.fold(k.originalPartitionKeys)( - KeyedPartitioning.projectKeys(k.originalPartitionKeys, _, projectedDataTypes)) + val projectedExpressions = joinKeyPositions.fold(k.expressions)(_.map(k.expressions)) k.copy( expressions = projectedExpressions, partitionKeys = groupedPartitions.map(_._1), @@ -73,13 +71,6 @@ case class GroupPartitionsExec( } } - private def projectExpressions(expressions: Seq[Expression]): (Seq[Expression], Seq[DataType]) = { - val projectedExpressions = joinKeyPositions.fold(expressions)(_.map(expressions)) - val projectedDataTypes = projectedExpressions.map(_.dataType) - - (projectedExpressions, projectedDataTypes) - } - /** * Distributes partitions based on `commonPartitionKeys` and clustering mode. */ @@ -123,7 +114,7 @@ case class GroupPartitionsExec( * Returns a sequence of (partitionKey, inputPartitionIndices) pairs representing * how input partitions should be grouped together. */ - lazy val groupedPartitions: Seq[(InternalRow, Seq[Int])] = { + lazy val (groupedPartitions, projectedOriginalKeys) = { // Extract the KeyedPartitioning from child's output partitioning val keyedPartitioning = child.outputPartitioning .asInstanceOf[Partitioning with Expression] @@ -132,18 +123,16 @@ case class GroupPartitionsExec( throw new SparkException("GroupPartitionsExec requires a child with KeyedPartitioning")) // Project partition keys if join key positions are specified - val (projectedDataTypes, projectedKeys) = joinKeyPositions match { - case Some(positions) => - val (projectedExpressions, projectedDataTypes) = - projectExpressions(keyedPartitioning.expressions) - val projectedKeys = KeyedPartitioning.projectKeys( - keyedPartitioning.partitionKeys, positions, projectedDataTypes) - (projectedDataTypes, projectedKeys) - - case None => - val dataTypes = keyedPartitioning.expressions.map(_.dataType) - (dataTypes, keyedPartitioning.partitionKeys) - } + val (projectedDataTypes, projectedKeys, projectedOriginalKeys) = + joinKeyPositions match { + case Some(positions) => + val (_, projectedDataTypes, projectedKeys, projectedOriginalKeys) = + keyedPartitioning.projectKeys(positions) + (projectedDataTypes, projectedKeys, projectedOriginalKeys) + case None => + val dataTypes = keyedPartitioning.expressions.map(_.dataType) + (dataTypes, keyedPartitioning.partitionKeys, keyedPartitioning.originalPartitionKeys) + } // Reduce keys if reducers are specified val reducedKeys = reducers match { @@ -160,11 +149,13 @@ case class GroupPartitionsExec( case (key, _) => comparableWrapperFactory(key) }(_._2) - if (commonPartitionKeys.isDefined) { + val groupedPartitions = if (commonPartitionKeys.isDefined) { distributeByCommonKeys(keyWrapperToPartitionIndices, comparableWrapperFactory) } else { groupAndSortByKeys(keyWrapperToPartitionIndices, projectedDataTypes) } + + (groupedPartitions, projectedOriginalKeys) } override protected def doExecute(): RDD[InternalRow] = { From b8b2faa86dc8e80f374dfface97c097120e11ab4 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 24 Feb 2026 13:48:28 +0100 Subject: [PATCH 04/29] no need to sort partition keys when building grouped `KeyedPartitioning` --- .../sql/catalyst/plans/physical/partitioning.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 728f43e072545..1042e3acf1de1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -419,15 +419,13 @@ case class KeyedPartitioning( @transient private lazy val rowOrdering = RowOrdering.createNaturalAscendingOrdering(dataTypes) - @transient lazy val isGrouped: Boolean = { - partitionKeys.map(comparableWrapperFactory).distinct.size == partitionKeys.size - } + @transient lazy val isGrouped: Boolean = + partitionKeys.distinctBy(comparableWrapperFactory).size == partitionKeys.size def toGrouped: KeyedPartitioning = { - val distinctSortedPartitionKeys = - partitionKeys.distinctBy(comparableWrapperFactory).sorted(rowOrdering) + val groupedPartitionKeys = partitionKeys.distinctBy(comparableWrapperFactory) - KeyedPartitioning(expressions, distinctSortedPartitionKeys, originalPartitionKeys) + KeyedPartitioning(expressions, groupedPartitionKeys) } /** From 48735135a255720e7a14403f0167c2743dd37cd9 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 24 Feb 2026 16:24:46 +0100 Subject: [PATCH 05/29] fix `BatchScanExec.equals()` --- .../spark/sql/execution/datasources/v2/BatchScanExec.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index a24c90c7a8502..e9b60efb9e8ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -48,7 +48,8 @@ case class BatchScanExec( override def equals(other: Any): Boolean = other match { case other: BatchScanExec => this.batch != null && this.batch == other.batch && - this.runtimeFilters == other.runtimeFilters + this.runtimeFilters == other.runtimeFilters && + this.keyGroupedPartitioning == other.keyGroupedPartitioning case _ => false } From 6f7b980bab76bb07a9a7132097a3d2b52e5167e9 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 24 Feb 2026 18:40:57 +0100 Subject: [PATCH 06/29] remove `originalPartitionKeys` from `KeyedPartitioning` because it is always available as `outputPartitioning` of children below `GroupPartitionsExec`s (on the contrary to the `BatchScanExec` based solution where we had to save it explicitly); add more documentation on when keys are sorted, minor refactorings --- .../plans/physical/partitioning.scala | 106 ++++++++------- ...nternalRowComparableWrapperBenchmark.scala | 8 +- .../v2/DataSourceV2ScanExecBase.scala | 2 +- .../datasources/v2/GroupPartitionsExec.scala | 26 ++-- .../exchange/EnsureRequirements.scala | 67 +++++----- .../exchange/ShuffleExchangeExec.scala | 4 +- .../DistributionAndOrderingSuiteBase.scala | 5 +- .../KeyGroupedPartitioningSuite.scala | 6 +- .../exchange/EnsureRequirementsSuite.scala | 126 ++++++++---------- 9 files changed, 165 insertions(+), 185 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 1042e3acf1de1..919d4cae849cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -351,25 +351,37 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa * Represents a partitioning where rows are split across partitions based on transforms defined by * `expressions`. * - * == Partition Keys == - * This partitioning has two sets of partition keys: + * == Usage Forms == + * `KeyedPartitioning` is used in two distinct forms: + * + * 1. '''As outputPartitioning''': When used as a node's output partitioning (e.g., in + * `BatchScanExec` or `GroupPartitionsExec`), the `partitionKeys` are always in sorted order. + * This is how leaf data source nodes produce partition keys originally, and this ordering is + * preserved through `GroupPartitionsExec`. The sorted order is critical for storage-partitioned + * join compatibility. * - * - `partitionKeys`: The current partition key for each partition, in ascending order. May contain - * duplicates when first created from a data source, but becomes unique after grouping. + * 2. '''In KeyGroupedShuffleSpec''': When used within `KeyGroupedShuffleSpec`, the `partitionKeys` + * may not be in sorted order. This occurs because `KeyGroupedShuffleSpec` can project the + * partition keys by join key positions (see `projectKeys` method), reordering them to match the + * join key order rather than the original sorted partition key order. The `EnsureRequirements` + * rule ensures that either the unordered keys from both sides of a join match exactly, or it + * builds a common ordered set of keys and pushes them down to `GroupPartitionsExec` on both + * sides to establish a compatible ordering. * - * - `originalPartitionKeys`: The original partition keys from the data source, in ascending order. - * Always preserves the original values, even after grouping. Used to track the original - * distribution for optimization purposes. + * == Partition Keys == + * - `partitionKeys`: The partition keys, one per partition. May contain duplicates initially + * (ungrouped state), but becomes unique after `GroupPartitionsExec` applies grouping. * * == Grouping State == * A KeyedPartitioning can be in two states: * - * - '''Ungrouped''' (when `isGrouped == false`): `partitionKeys` contains duplicates. Multiple - * input partitions share the same key. This is the initial state when created from a data source. + * - '''Ungrouped''' (when `isGrouped == false`): `partitionKeys` contains duplicates, meaning + * multiple input partitions share the same key. This occurs when a data source has multiple + * splits for the same partition value. * - * - '''Grouped''' (when `isGrouped == true`): `partitionKeys` contains only unique values. Each - * partition has a distinct key. This state is achieved by applying `GroupPartitionsExec`, which - * coalesces partitions with the same key. + * - '''Grouped''' (when `isGrouped == true`): `partitionKeys` contains only unique values, with + * each partition having a distinct key. This occurs when: (1) a data source natively produces + * unique partition keys, or (2) `GroupPartitionsExec` coalesces partitions with duplicate keys. * * == Example == * Consider a data source with partition transform `[years(ts_col)]` and 4 input splits: @@ -377,8 +389,7 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa * '''Before GroupPartitionsExec''' (ungrouped): * {{{ * expressions: [years(ts_col)] - * partitionKeys: [0, 1, 2, 2] // partition 2 and 3 have the same key - * originalPartitionKeys: [0, 1, 2, 2] + * partitionKeys: [0, 1, 2, 2] // partitions 2 and 3 have the same key * numPartitions: 4 * isGrouped: false * }}} @@ -387,21 +398,18 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa * {{{ * expressions: [years(ts_col)] * partitionKeys: [0, 1, 2] // duplicates removed, partitions coalesced - * originalPartitionKeys: [0, 1, 2, 2] // unchanged, preserves original distribution * numPartitions: 3 * isGrouped: true * }}} * * @param expressions Partition transform expressions (e.g., `years(col)`, `bucket(10, col)`). - * @param partitionKeys Current partition keys, one per partition, in ascending order. - * May contain duplicates before grouping. - * @param originalPartitionKeys Original partition keys from the data source, in ascending order. - * Preserves the initial distribution even after grouping. + * @param partitionKeys Partition keys, one per partition. When used as outputPartitioning, + * always in sorted order. When used in KeyGroupedShuffleSpec, may be + * unsorted after projection. May contain duplicates when ungrouped. */ case class KeyedPartitioning( expressions: Seq[Expression], - partitionKeys: Seq[InternalRow], - originalPartitionKeys: Seq[InternalRow]) extends Expression with Partitioning with Unevaluable { + partitionKeys: Seq[InternalRow]) extends Expression with Partitioning with Unevaluable { override val numPartitions = partitionKeys.length override def children: Seq[Expression] = expressions @@ -412,18 +420,16 @@ case class KeyedPartitioning( newChildren: IndexedSeq[Expression]): KeyedPartitioning = copy(expressions = newChildren) - @transient private lazy val dataTypes: Seq[DataType] = expressions.map(_.dataType) + @transient lazy val expressionDataTypes: Seq[DataType] = expressions.map(_.dataType) - @transient private lazy val comparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) - - @transient private lazy val rowOrdering = RowOrdering.createNaturalAscendingOrdering(dataTypes) + @transient lazy val keysComparableWrapperFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(expressionDataTypes) @transient lazy val isGrouped: Boolean = - partitionKeys.distinctBy(comparableWrapperFactory).size == partitionKeys.size + partitionKeys.distinctBy(keysComparableWrapperFactory).size == partitionKeys.size def toGrouped: KeyedPartitioning = { - val groupedPartitionKeys = partitionKeys.distinctBy(comparableWrapperFactory) + val groupedPartitionKeys = partitionKeys.distinctBy(keysComparableWrapperFactory) KeyedPartitioning(expressions, groupedPartitionKeys) } @@ -432,15 +438,21 @@ case class KeyedPartitioning( * Projects this partitioning's expressions by selecting only the specified positions. * Returns both the projected expressions and their data types. */ - def projectKeys(positions: Seq[Int]): - (Seq[Expression], Seq[DataType], Seq[InternalRow], Seq[InternalRow]) = { + def projectKeys(positions: Seq[Int]): (Seq[Expression], Seq[DataType], Seq[InternalRow]) = { val projectedExpressions = positions.map(expressions) val projectedDataTypes = projectedExpressions.map(_.dataType) val projectedKeys = KeyedPartitioning.projectKeys(partitionKeys, positions, projectedDataTypes) - val projectedOriginalKeys = - KeyedPartitioning.projectKeys(originalPartitionKeys, positions, projectedDataTypes) - (projectedExpressions, projectedDataTypes, projectedKeys, projectedOriginalKeys) + (projectedExpressions, projectedDataTypes, projectedKeys) + } + + /** + * Reduces this partitioning's partition keys by applying the given reducers. + * Returns the distinct reduced keys. + */ + def reduceKeys(reducers: Seq[Option[Reducer[_, _]]]): Seq[InternalRow] = { + KeyedPartitioning.reduceKeys(partitionKeys, reducers, expressionDataTypes) + .distinctBy(keysComparableWrapperFactory) } override def satisfies0(required: Distribution): Boolean = { @@ -481,13 +493,12 @@ case class KeyedPartitioning( // `KeyedPartitioning` here that is grouped on the join keys instead, and use that as // the returned shuffle spec. val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2) - val (projectedExpressions, projectedDataTypes, projectedKeys, projectedOriginalKeys) = - projectKeys(joinKeyPositions) + val (projectedExpressions, projectedDataTypes, projectedKeys) = projectKeys(joinKeyPositions) val projectedComparableWrapperFactory = InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(projectedDataTypes) - val distinctPartitionKeys = projectedKeys.distinctBy(projectedComparableWrapperFactory) - val projectedPartitioning = copy(expressions = projectedExpressions, - partitionKeys = distinctPartitionKeys, originalPartitionKeys = projectedOriginalKeys) + val distinctProjectedKeys = projectedKeys.distinctBy(projectedComparableWrapperFactory) + val projectedPartitioning = + copy(expressions = projectedExpressions, partitionKeys = distinctProjectedKeys) result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions)) } else { result @@ -496,22 +507,16 @@ case class KeyedPartitioning( override def equals(that: Any): Boolean = that match { case k: KeyedPartitioning if this.expressions == k.expressions => - def keysEqual(keys1: Seq[InternalRow], keys2: Seq[InternalRow]): Boolean = { - keys1.size == keys2.size && keys1.zip(keys2).forall { case (l, r) => - comparableWrapperFactory(l).equals(comparableWrapperFactory(r)) + partitionKeys.size == k.partitionKeys.size && + partitionKeys.zip(k.partitionKeys).forall { case (l, r) => + keysComparableWrapperFactory(l).equals(keysComparableWrapperFactory(r)) } - } - - keysEqual(partitionKeys, k.partitionKeys) && - keysEqual(originalPartitionKeys, k.originalPartitionKeys) case _ => false } - override def hashCode(): Int = { - Objects.hash(expressions, partitionKeys.map(comparableWrapperFactory), - originalPartitionKeys.map(comparableWrapperFactory)) - } + override def hashCode(): Int = + Objects.hash(expressions, partitionKeys.map(keysComparableWrapperFactory)) } object KeyedPartitioning { @@ -1045,8 +1050,7 @@ case class KeyGroupedShuffleSpec( te.copy(children = te.children.map(_ => clustering(positionSet.head))) case (_, positionSet) => clustering(positionSet.head) } - KeyedPartitioning(newExpressions, partitioning.partitionKeys, - partitioning.originalPartitionKeys) + KeyedPartitioning(newExpressions, partitioning.partitionKeys) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala index 87cb212253f32..a96e58727bd15 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala @@ -41,7 +41,7 @@ object InternalRowComparableWrapperBenchmark extends BenchmarkBase { val partitionNum = 200_000 val bucketNum = 4096 val day = 20240401 - val partitions = (0 until partitionNum).map { i => + val partitionKeys = (0 until partitionNum).map { i => val bucketId = i % bucketNum PartitionInternalRow.apply(Array(day, bucketId)); } @@ -51,7 +51,7 @@ object InternalRowComparableWrapperBenchmark extends BenchmarkBase { val internalRowComparableWrapperFactory = InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( Seq(IntegerType, IntegerType)) - val distinct = partitions + val distinct = partitionKeys .map(internalRowComparableWrapperFactory) .toSet assert(distinct.size == bucketNum) @@ -61,8 +61,8 @@ object InternalRowComparableWrapperBenchmark extends BenchmarkBase { // just to mock the data types val expressions = (Seq(Literal(day, IntegerType), Literal(0, IntegerType))) - val leftPartitioning = KeyedPartitioning(expressions, partitions, partitions) - val rightPartitioning = KeyedPartitioning(expressions, partitions, partitions) + val leftPartitioning = KeyedPartitioning(expressions, partitionKeys) + val rightPartitioning = KeyedPartitioning(expressions, partitionKeys) val merged = InternalRowComparableWrapper.mergePartitions( leftPartitioning.partitionKeys, rightPartitioning.partitionKeys, expressions) assert(merged.size == bucketNum) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala index ac993e18876a9..c4a59df5e1cb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala @@ -97,7 +97,7 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { val rowOrdering = RowOrdering.createNaturalAscendingOrdering(dataTypes) val partitionKeys = inputPartitions.map(_.asInstanceOf[HasPartitionKey].partitionKey()).sorted(rowOrdering) - KeyedPartitioning(exprs, partitionKeys, partitionKeys) + KeyedPartitioning(exprs, partitionKeys) case _ => super.outputPartitioning } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala index 80dfc4c4cd38d..07e72b41ae2a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala @@ -61,11 +61,11 @@ case class GroupPartitionsExec( case p: Partitioning with Expression => p.transform { case k: KeyedPartitioning => + // There can be multiple `KeyedPartitioning` in an output partitioning of a join, but + // they can only differ in `expressions`. `partitionKeys` must match so we can calculate + // it only once via `groupedPartitions`. val projectedExpressions = joinKeyPositions.fold(k.expressions)(_.map(k.expressions)) - k.copy( - expressions = projectedExpressions, - partitionKeys = groupedPartitions.map(_._1), - originalPartitionKeys = projectedOriginalKeys) + k.copy(expressions = projectedExpressions, partitionKeys = groupedPartitions.map(_._1)) }.asInstanceOf[Partitioning] case o => o } @@ -114,8 +114,9 @@ case class GroupPartitionsExec( * Returns a sequence of (partitionKey, inputPartitionIndices) pairs representing * how input partitions should be grouped together. */ - lazy val (groupedPartitions, projectedOriginalKeys) = { - // Extract the KeyedPartitioning from child's output partitioning + lazy val groupedPartitions = { + // There must be a `KeyedPartitioning` in child's output partitioning as a + // `GroupPartitionsExec` node is added to a plan only in that case. val keyedPartitioning = child.outputPartitioning .asInstanceOf[Partitioning with Expression] .collectFirst { case k: KeyedPartitioning => k } @@ -123,15 +124,14 @@ case class GroupPartitionsExec( throw new SparkException("GroupPartitionsExec requires a child with KeyedPartitioning")) // Project partition keys if join key positions are specified - val (projectedDataTypes, projectedKeys, projectedOriginalKeys) = + val (projectedDataTypes, projectedKeys) = joinKeyPositions match { case Some(positions) => - val (_, projectedDataTypes, projectedKeys, projectedOriginalKeys) = - keyedPartitioning.projectKeys(positions) - (projectedDataTypes, projectedKeys, projectedOriginalKeys) + val (_, projectedDataTypes, projectedKeys) = keyedPartitioning.projectKeys(positions) + (projectedDataTypes, projectedKeys) case None => val dataTypes = keyedPartitioning.expressions.map(_.dataType) - (dataTypes, keyedPartitioning.partitionKeys, keyedPartitioning.originalPartitionKeys) + (dataTypes, keyedPartitioning.partitionKeys) } // Reduce keys if reducers are specified @@ -149,13 +149,11 @@ case class GroupPartitionsExec( case (key, _) => comparableWrapperFactory(key) }(_._2) - val groupedPartitions = if (commonPartitionKeys.isDefined) { + if (commonPartitionKeys.isDefined) { distributeByCommonKeys(keyWrapperToPartitionIndices, comparableWrapperFactory) } else { groupAndSortByKeys(keyWrapperToPartitionIndices, projectedDataTypes) } - - (groupedPartitions, projectedOriginalKeys) } override protected def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 8b6379ebbf662..12c25dda64736 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -204,7 +204,7 @@ case class EnsureRequirements( case ((child, dist), idx) => if (bestSpecOpt.isDefined && bestSpecOpt.get.isCompatibleWith(specs(idx))) { bestSpecOpt match { - // If keyGroupCompatible = false, we can still perform SPJ + // If `areChildrenCompatible` is false, we can still perform SPJ // by shuffling the other side based on join keys (see the else case below). // Hence we need to ensure that after this call, the outputPartitioning of the // partitioned side's BatchScanExec is grouped by join keys to match, @@ -354,12 +354,12 @@ case class EnsureRequirements( reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, leftPartitioning, None)) - case (Some(KeyedPartitioning(clustering, _, _)), _) => + case (Some(KeyedPartitioning(clustering, _)), _) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, None, rightPartitioning)) - case (_, Some(KeyedPartitioning(clustering, _, _))) => + case (_, Some(KeyedPartitioning(clustering, _))) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys) .orElse(reorderJoinKeysRecursively( @@ -465,17 +465,22 @@ case class EnsureRequirements( val leftSpec = specs.head val rightSpec = specs(1) - def contains(partitioning: Partitioning, keyedPartitioning: KeyedPartitioning): Boolean = { + def containsPartitioning( + partitioning: Partitioning, + keyedPartitioning: KeyedPartitioning): Boolean = { partitioning match { case k: KeyedPartitioning => k == keyedPartitioning case PartitioningCollection(partitionings) => - partitionings.exists(contains(_, keyedPartitioning)) + partitionings.exists(containsPartitioning(_, keyedPartitioning)) case _ => false } } - var isCompatible = contains(left.outputPartitioning, leftSpec.partitioning) && - contains(right.outputPartitioning, rightSpec.partitioning) && + // We don't need to add alter or add any `GroupPartitionsExec` when the child partitionings are + // not modified (projected) in specs and left and right side partitionings are compatible with + // each other. + var isCompatible = containsPartitioning(left.outputPartitioning, leftSpec.partitioning) && + containsPartitioning(right.outputPartitioning, rightSpec.partitioning) && leftSpec.isCompatibleWith(rightSpec) if ((!isCompatible || conf.v2BucketingPartiallyClusteredDistributionEnabled) && (conf.v2BucketingPushPartValuesEnabled || @@ -517,16 +522,15 @@ case class EnsureRequirements( // in case of compatible but not identical partition expressions, we apply 'reduce' // transforms to group one side's partitions as well as the common partition values val leftReducers = leftSpec.reducers(rightSpec) - val leftParts = reducePartitionKeys(leftSpec.partitioning.partitionKeys, - partitionExprs, - leftReducers) + val leftReducedKeys = leftReducers.fold(leftSpec.partitioning.partitionKeys)( + leftSpec.partitioning.reduceKeys) val rightReducers = rightSpec.reducers(leftSpec) - val rightParts = reducePartitionKeys(rightSpec.partitioning.partitionKeys, - partitionExprs, - rightReducers) + val rightReducedKeys = rightReducers.fold(rightSpec.partitioning.partitionKeys)( + rightSpec.partitioning.reduceKeys) // merge values on both sides - var mergedPartitionKeys = mergePartitions(leftParts, rightParts, partitionExprs, joinType) + var mergedPartitionKeys = + mergePartitions(leftReducedKeys, rightReducedKeys, partitionExprs, joinType) .map(v => (v, 1)) logInfo(log"After merging, there are " + @@ -608,13 +612,23 @@ case class EnsureRequirements( replicateRightSide = false } else { // In partially clustered distribution, we should use un-grouped partition values - val spec = if (replicateLeftSide) rightSpec else leftSpec - val originalPartitionKeys = spec.partitioning.originalPartitionKeys + val (replicatedChild, replicatedSpec) = if (replicateLeftSide) { + (right, rightSpec) + } else { + (left, leftSpec) + } + val originalPartitioning = (replicatedChild match { + case g: GroupPartitionsExec => g.child + case o => o + }).outputPartitioning.asInstanceOf[KeyedPartitioning] + val dataTypes = partitionExprs.map(_.dataType) + val projectedOriginalPartitionKeys = + replicatedSpec.joinKeyPositions.fold(originalPartitioning.partitionKeys)( + KeyedPartitioning.projectKeys(originalPartitioning.partitionKeys, _, dataTypes)) val internalRowComparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( - partitionExprs.map(_.dataType)) + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) - val numExpectedPartitions = originalPartitionKeys + val numExpectedPartitions = projectedOriginalPartitionKeys .groupBy(internalRowComparableWrapperFactory) .view.mapValues(_.size) @@ -711,21 +725,6 @@ case class EnsureRequirements( } } - private def reducePartitionKeys( - partitionKeys: Seq[InternalRow], - expressions: Seq[Expression], - reducers: Option[Seq[Option[Reducer[_, _]]]]) = { - reducers match { - case Some(reducers) => - val dataTypes = expressions.map(_.dataType) - val internalRowComparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) - KeyedPartitioning.reduceKeys(partitionKeys, reducers, dataTypes) - .distinctBy(internalRowComparableWrapperFactory) - case _ => partitionKeys - } - } - /** * Tries to create a [[KeyGroupedShuffleSpec]] from the input partitioning and distribution, if * the partitioning is a [[KeyedPartitioning]] (either directly or indirectly), and diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index f50f6f484ac50..eef3f87f3fe03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -370,7 +370,7 @@ object ShuffleExchangeExec { ascending = true, samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition) case SinglePartition => new ConstantPartitioner - case k @ KeyedPartitioning(expressions, _, _) => + case k @ KeyedPartitioning(expressions, _) => val keyGroupedPartitioning = k.toGrouped val valueMap = keyGroupedPartitioning.partitionKeys.zipWithIndex.map { case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index) @@ -403,7 +403,7 @@ object ShuffleExchangeExec { val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) row => projection(row) case SinglePartition => identity - case KeyedPartitioning(expressions, _, _) => + case KeyedPartitioning(expressions, _) => row => bindReferences(expressions, outputAttributes).map(_.eval(row)) case s: ShufflePartitionIdPassThrough => // For ShufflePartitionIdPassThrough, the expression directly evaluates to the partition ID diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala index cf9133b0835d3..37394c3b071c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala @@ -51,9 +51,8 @@ abstract class DistributionAndOrderingSuiteBase plan: QueryPlan[T]): Partitioning = partitioning match { case HashPartitioning(exprs, numPartitions) => HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions) - case KeyedPartitioning(expressions, partitionKeys, originalPartitionKeys) => - KeyedPartitioning(expressions.map(resolveAttrs(_, plan)), partitionKeys, - originalPartitionKeys) + case KeyedPartitioning(expressions, partitionKeys) => + KeyedPartitioning(expressions.map(resolveAttrs(_, plan)), partitionKeys) case PartitioningCollection(partitionings) => PartitioningCollection(partitionings.map(resolvePartitioning(_, plan))) case RangePartitioning(ordering, numPartitions) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 869fc6a0cca6e..ebb4598150c01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -91,12 +91,12 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val partitionKeys = Seq(50L, 51L, 52L).map(v => InternalRow.fromSeq(Seq(v))) checkQueryPlan(df, catalystDistribution, - physical.KeyedPartitioning(catalystDistribution.clustering, partitionKeys, partitionKeys)) + physical.KeyedPartitioning(catalystDistribution.clustering, partitionKeys)) // multiple group keys should work too as long as partition keys are subset of them df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY id, ts") checkQueryPlan(df, catalystDistribution, - physical.KeyedPartitioning(catalystDistribution.clustering, partitionKeys, partitionKeys)) + physical.KeyedPartitioning(catalystDistribution.clustering, partitionKeys)) } test("non-clustered distribution: no partition") { @@ -122,7 +122,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { // Has exactly one partition. val partitionKeys = Seq(0).map(v => InternalRow.fromSeq(Seq(v))) checkQueryPlan(df, distribution, - physical.KeyedPartitioning(distribution.clustering, partitionKeys, partitionKeys)) + physical.KeyedPartitioning(distribution.clustering, partitionKeys)) } test("non-clustered distribution: no V2 catalog") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 3ad97be1622e2..0e5581d82a099 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -94,12 +94,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { test("reorder should handle KeyedPartitioning") { // partitioning on the left val plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning( - Seq(years(exprA), bucket(4, exprB), days(exprC)), Seq.empty, Seq.empty) + outputPartitioning = + KeyedPartitioning(Seq(years(exprA), bucket(4, exprB), days(exprC)), Seq.empty) ) val plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning(Seq( - years(exprB), bucket(4, exprA), days(exprD)), Seq.empty, Seq.empty) + outputPartitioning = + KeyedPartitioning(Seq(years(exprB), bucket(4, exprA), days(exprD)), Seq.empty) ) val smjExec = SortMergeJoinExec( exprB :: exprC :: exprA :: Nil, exprA :: exprD :: exprB :: Nil, @@ -116,8 +116,8 @@ class EnsureRequirementsSuite extends SharedSparkSession { // partitioning on the right val plan3 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning(Seq( - bucket(4, exprD), days(exprA), years(exprC)), Seq.empty, Seq.empty) + outputPartitioning = + KeyedPartitioning(Seq(bucket(4, exprD), days(exprA), years(exprC)), Seq.empty) ) val smjExec2 = SortMergeJoinExec( exprB :: exprD :: exprC :: Nil, exprA :: exprC :: exprD :: Nil, @@ -780,9 +780,9 @@ class EnsureRequirementsSuite extends SharedSparkSession { test("Check with KeyedPartitioning") { // simplest case: identity transforms var plan1 = new DummySparkPlanWithBatchScanChild( - KeyedPartitioning(exprA :: exprB :: Nil, Seq.empty, Seq.empty)) + KeyedPartitioning(exprA :: exprB :: Nil, Seq.empty)) var plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning(exprA :: exprC :: Nil, Seq.empty, Seq.empty)) + outputPartitioning = KeyedPartitioning(exprA :: exprC :: Nil, Seq.empty)) var smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) EnsureRequirements.apply(smjExec) match { @@ -796,12 +796,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { // matching bucket transforms from both sides plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning( - bucket(4, exprA) :: bucket(16, exprB) :: Nil, Seq.empty, Seq.empty) + outputPartitioning = + KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprB) :: Nil, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning( - bucket(4, exprA) :: bucket(16, exprC) :: Nil, Seq.empty, Seq.empty) + outputPartitioning = + KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, Seq.empty) ) smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) @@ -816,13 +816,13 @@ class EnsureRequirementsSuite extends SharedSparkSession { // partition collections plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning( - bucket(4, exprA) :: bucket(16, exprB) :: Nil, Seq.empty, Seq.empty) + outputPartitioning = + KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprB) :: Nil, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( outputPartitioning = PartitioningCollection(Seq( - KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, Seq.empty, Seq.empty), - KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, Seq.empty, Seq.empty)) + KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, Seq.empty), + KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, Seq.empty)) ) ) smjExec = SortMergeJoinExec( @@ -846,12 +846,10 @@ class EnsureRequirementsSuite extends SharedSparkSession { // bucket + years transforms from both sides plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: years(exprB) :: Nil, - Seq.empty, Seq.empty) + outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: years(exprB) :: Nil, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: years(exprC) :: Nil, - Seq.empty, Seq.empty) + outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: years(exprC) :: Nil, Seq.empty) ) smjExec = SortMergeJoinExec( exprA :: exprB :: Nil, exprA :: exprC :: Nil, Inner, None, plan1, plan2) @@ -867,12 +865,10 @@ class EnsureRequirementsSuite extends SharedSparkSession { // by default spark.sql.requireAllClusterKeysForCoPartition is true, so when there isn't // exact match on all partition keys, Spark will fallback to shuffle. plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning( - bucket(4, exprA) :: bucket(4, exprB) :: Nil, Seq.empty, Seq.empty) + outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: bucket(4, exprB) :: Nil, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning( - bucket(4, exprA) :: bucket(4, exprC) :: Nil, Seq.empty, Seq.empty) + outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: bucket(4, exprC) :: Nil, Seq.empty) ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) @@ -888,12 +884,10 @@ class EnsureRequirementsSuite extends SharedSparkSession { test(s"KeyedPartitioning with ${REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key} = false") { var plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning( - bucket(4, exprB) :: years(exprC) :: Nil, Seq.empty, Seq.empty) + outputPartitioning = KeyedPartitioning(bucket(4, exprB) :: years(exprC) :: Nil, Seq.empty) ) var plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning( - bucket(4, exprC) :: years(exprB) :: Nil, Seq.empty, Seq.empty) + outputPartitioning = KeyedPartitioning(bucket(4, exprC) :: years(exprB) :: Nil, Seq.empty) ) // simple case @@ -910,12 +904,10 @@ class EnsureRequirementsSuite extends SharedSparkSession { // should also work with distributions with duplicated keys plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning( - bucket(4, exprA) :: years(exprB) :: Nil, Seq.empty, Seq.empty) + outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: years(exprB) :: Nil, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning( - bucket(4, exprA) :: years(exprC) :: Nil, Seq.empty, Seq.empty) + outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: years(exprC) :: Nil, Seq.empty) ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) @@ -930,11 +922,11 @@ class EnsureRequirementsSuite extends SharedSparkSession { // both partitioning and distribution have duplicated keys plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning( - years(exprA) :: bucket(4, exprB) :: days(exprA) :: Nil, Seq.empty, Seq.empty)) + outputPartitioning = + KeyedPartitioning(years(exprA) :: bucket(4, exprB) :: days(exprA) :: Nil, Seq.empty)) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning( - years(exprA) :: bucket(4, exprC) :: days(exprA) :: Nil, Seq.empty, Seq.empty)) + outputPartitioning = + KeyedPartitioning(years(exprA) :: bucket(4, exprC) :: days(exprA) :: Nil, Seq.empty)) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) applyEnsureRequirementsWithSubsetKeys(smjExec) match { @@ -948,12 +940,10 @@ class EnsureRequirementsSuite extends SharedSparkSession { // invalid case: partitioning key positions don't match plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning( - bucket(4, exprA) :: bucket(4, exprB) :: Nil, Seq.empty, Seq.empty) + outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: bucket(4, exprB) :: Nil, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning( - bucket(4, exprB) :: bucket(4, exprC) :: Nil, Seq.empty, Seq.empty) + outputPartitioning = KeyedPartitioning(bucket(4, exprB) :: bucket(4, exprC) :: Nil, Seq.empty) ) smjExec = SortMergeJoinExec( @@ -969,12 +959,10 @@ class EnsureRequirementsSuite extends SharedSparkSession { // invalid case: different number of buckets (we don't support coalescing/repartitioning yet) plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning( - bucket(4, exprA) :: bucket(4, exprB) :: Nil, Seq.empty, Seq.empty) + outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: bucket(4, exprB) :: Nil, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning( - bucket(4, exprA) :: bucket(8, exprC) :: Nil, Seq.empty, Seq.empty) + outputPartitioning = KeyedPartitioning(bucket(4, exprA) :: bucket(8, exprC) :: Nil, Seq.empty) ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) @@ -989,12 +977,10 @@ class EnsureRequirementsSuite extends SharedSparkSession { // invalid case: partition key positions match but with different transforms plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning(years(exprA) :: bucket(4, exprB) :: Nil, - Seq.empty, Seq.empty) + outputPartitioning = KeyedPartitioning(years(exprA) :: bucket(4, exprB) :: Nil, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning(days(exprA) :: bucket(4, exprC) :: Nil, - Seq.empty, Seq.empty) + outputPartitioning = KeyedPartitioning(days(exprA) :: bucket(4, exprC) :: Nil, Seq.empty) ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) @@ -1010,12 +996,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { // invalid case: multiple references in transform plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning( - years(exprA) :: buckets(4, Seq(exprB, exprC)) :: Nil, Seq.empty, Seq.empty) + outputPartitioning = + KeyedPartitioning(years(exprA) :: buckets(4, Seq(exprB, exprC)) :: Nil, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning( - years(exprA) :: buckets(4, Seq(exprB, exprC)) :: Nil, Seq.empty, Seq.empty) + outputPartitioning = + KeyedPartitioning(years(exprA) :: buckets(4, Seq(exprB, exprC)) :: Nil, Seq.empty) ) smjExec = SortMergeJoinExec( exprA :: exprB :: exprB :: Nil, exprA :: exprC :: exprC :: Nil, Inner, None, plan1, plan2) @@ -1036,12 +1022,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { .map(new GenericInternalRow(_)) var plan1 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, - leftPartValues, leftPartValues) + outputPartitioning = + KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, leftPartValues) ) var plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, - rightPartValues, rightPartValues) + outputPartitioning = + KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues) ) // simple case @@ -1064,10 +1050,8 @@ class EnsureRequirementsSuite extends SharedSparkSession { // With partition collections plan1 = new DummySparkPlanWithBatchScanChild(outputPartitioning = PartitioningCollection( - Seq(KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, - leftPartValues, leftPartValues), - KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, - leftPartValues, leftPartValues)) + Seq(KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, leftPartValues), + KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, leftPartValues)) ) ) @@ -1096,16 +1080,12 @@ class EnsureRequirementsSuite extends SharedSparkSession { Seq( PartitioningCollection( Seq( - KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, - rightPartValues, rightPartValues), - KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, - rightPartValues, rightPartValues))), + KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues), + KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues))), PartitioningCollection( Seq( - KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, - rightPartValues, rightPartValues), - KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, - rightPartValues, rightPartValues))) + KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues), + KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues))) ) ) ) @@ -1138,9 +1118,9 @@ class EnsureRequirementsSuite extends SharedSparkSession { val a1 = AttributeReference("a1", IntegerType)() - val partitionValue = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v))) - val plan1 = new DummySparkPlanWithBatchScanChild(outputPartitioning = KeyedPartitioning( - identity(a1) :: Nil, partitionValue, partitionValue)) + val partitionKeys = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v))) + val plan1 = new DummySparkPlanWithBatchScanChild( + outputPartitioning = KeyedPartitioning(identity(a1) :: Nil, partitionKeys)) val plan2 = DummySparkPlan(outputPartitioning = SinglePartition) val smjExec = ShuffledHashJoinExec( @@ -1148,11 +1128,11 @@ class EnsureRequirementsSuite extends SharedSparkSession { EnsureRequirements.apply(smjExec) match { case ShuffledHashJoinExec(_, _, _, _, _, DummySparkPlan(_, _, left: KeyedPartitioning, _, _), - ShuffleExchangeExec(KeyedPartitioning(attrs, pv, _), + ShuffleExchangeExec(KeyedPartitioning(attrs, pks), DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) => assert(left.expressions == a1 :: Nil) assert(attrs == a1 :: Nil) - assert(partitionValue == pv) + assert(partitionKeys == pks) case other => fail(other.toString) } } From dc72b0d69cd827a7f94fedcc46576f280bb00ce9 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 24 Feb 2026 20:57:32 +0100 Subject: [PATCH 07/29] minor name and docs fix --- .../spark/sql/catalyst/plans/physical/partitioning.scala | 9 ++++----- .../sql/execution/exchange/EnsureRequirements.scala | 6 +++--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 919d4cae849cd..a29b306323a0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -362,11 +362,10 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa * * 2. '''In KeyGroupedShuffleSpec''': When used within `KeyGroupedShuffleSpec`, the `partitionKeys` * may not be in sorted order. This occurs because `KeyGroupedShuffleSpec` can project the - * partition keys by join key positions (see `projectKeys` method), reordering them to match the - * join key order rather than the original sorted partition key order. The `EnsureRequirements` - * rule ensures that either the unordered keys from both sides of a join match exactly, or it - * builds a common ordered set of keys and pushes them down to `GroupPartitionsExec` on both - * sides to establish a compatible ordering. + * partition keys by join key positions. The `EnsureRequirements` rule ensures that either the + * unordered keys from both sides of a join match exactly, or it builds a common ordered set of + * keys and pushes them down to `GroupPartitionsExec` on both sides to establish a compatible + * ordering. * * == Partition Keys == * - `partitionKeys`: The partition keys, one per partition. May contain duplicates initially diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 12c25dda64736..7cd2c7d207ad2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -612,18 +612,18 @@ case class EnsureRequirements( replicateRightSide = false } else { // In partially clustered distribution, we should use un-grouped partition values - val (replicatedChild, replicatedSpec) = if (replicateLeftSide) { + val (partiallyClusteredChild, partiallyClusteredSpec) = if (replicateLeftSide) { (right, rightSpec) } else { (left, leftSpec) } - val originalPartitioning = (replicatedChild match { + val originalPartitioning = (partiallyClusteredChild match { case g: GroupPartitionsExec => g.child case o => o }).outputPartitioning.asInstanceOf[KeyedPartitioning] val dataTypes = partitionExprs.map(_.dataType) val projectedOriginalPartitionKeys = - replicatedSpec.joinKeyPositions.fold(originalPartitioning.partitionKeys)( + partiallyClusteredSpec.joinKeyPositions.fold(originalPartitioning.partitionKeys)( KeyedPartitioning.projectKeys(originalPartitioning.partitionKeys, _, dataTypes)) val internalRowComparableWrapperFactory = InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) From 84c3afae4f328a04a37c656f8576e85b81a06603 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 25 Feb 2026 10:36:20 +0100 Subject: [PATCH 08/29] fix `BatchScanExec` canonicalization --- .../spark/sql/execution/datasources/v2/BatchScanExec.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index e9b60efb9e8ca..810b88a7d47f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -156,7 +156,8 @@ case class BatchScanExec( output = output.map(QueryPlan.normalizeExpressions(_, output)), runtimeFilters = QueryPlan.normalizePredicates( runtimeFilters.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)), - output)) + output), + keyGroupedPartitioning = keyGroupedPartitioning.map(QueryPlan.normalizePredicates(_, output))) } override def simpleString(maxFields: Int): String = { From 5b7677c9395d37bcc03a6c899cb5f8701d05afd9 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 25 Feb 2026 14:40:58 +0100 Subject: [PATCH 09/29] more code cleanup and docs fixes --- .../plans/physical/partitioning.scala | 87 +++++++++---------- .../datasources/v2/BatchScanExec.scala | 16 ++-- .../datasources/v2/GroupPartitionsExec.scala | 28 +++--- .../exchange/EnsureRequirements.scala | 27 +++--- 4 files changed, 67 insertions(+), 91 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index a29b306323a0b..f70d743790911 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -421,28 +421,26 @@ case class KeyedPartitioning( @transient lazy val expressionDataTypes: Seq[DataType] = expressions.map(_.dataType) - @transient lazy val keysComparableWrapperFactory = + @transient lazy val comparableKeyWrapperFactory = InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(expressionDataTypes) + @transient lazy val keyOrdering = RowOrdering.createNaturalAscendingOrdering(expressionDataTypes) + @transient lazy val isGrouped: Boolean = - partitionKeys.distinctBy(keysComparableWrapperFactory).size == partitionKeys.size + partitionKeys.distinctBy(comparableKeyWrapperFactory).size == partitionKeys.size def toGrouped: KeyedPartitioning = { - val groupedPartitionKeys = partitionKeys.distinctBy(keysComparableWrapperFactory) + val groupedPartitionKeys = partitionKeys.distinctBy(comparableKeyWrapperFactory) KeyedPartitioning(expressions, groupedPartitionKeys) } /** * Projects this partitioning's expressions by selecting only the specified positions. - * Returns both the projected expressions and their data types. + * Returns the projected expressions and their data types together with the projected keys. */ - def projectKeys(positions: Seq[Int]): (Seq[Expression], Seq[DataType], Seq[InternalRow]) = { - val projectedExpressions = positions.map(expressions) - val projectedDataTypes = projectedExpressions.map(_.dataType) - val projectedKeys = KeyedPartitioning.projectKeys(partitionKeys, positions, projectedDataTypes) - - (projectedExpressions, projectedDataTypes, projectedKeys) + def projectKeys(positions: Seq[Int]): (Seq[DataType], Seq[InternalRow]) = { + KeyedPartitioning.projectKeys(partitionKeys, expressionDataTypes, positions) } /** @@ -450,8 +448,8 @@ case class KeyedPartitioning( * Returns the distinct reduced keys. */ def reduceKeys(reducers: Seq[Option[Reducer[_, _]]]): Seq[InternalRow] = { - KeyedPartitioning.reduceKeys(partitionKeys, reducers, expressionDataTypes) - .distinctBy(keysComparableWrapperFactory) + KeyedPartitioning.reduceKeys(partitionKeys, expressionDataTypes, reducers) + .distinctBy(comparableKeyWrapperFactory) } override def satisfies0(required: Distribution): Boolean = { @@ -492,10 +490,11 @@ case class KeyedPartitioning( // `KeyedPartitioning` here that is grouped on the join keys instead, and use that as // the returned shuffle spec. val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2) - val (projectedExpressions, projectedDataTypes, projectedKeys) = projectKeys(joinKeyPositions) - val projectedComparableWrapperFactory = + val projectedExpressions = joinKeyPositions.map(expressions) + val (projectedDataTypes, projectedKeys) = projectKeys(joinKeyPositions) + val comparableKeyWrapperFactory = InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(projectedDataTypes) - val distinctProjectedKeys = projectedKeys.distinctBy(projectedComparableWrapperFactory) + val distinctProjectedKeys = projectedKeys.distinctBy(comparableKeyWrapperFactory) val projectedPartitioning = copy(expressions = projectedExpressions, partitionKeys = distinctProjectedKeys) result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions)) @@ -508,14 +507,14 @@ case class KeyedPartitioning( case k: KeyedPartitioning if this.expressions == k.expressions => partitionKeys.size == k.partitionKeys.size && partitionKeys.zip(k.partitionKeys).forall { case (l, r) => - keysComparableWrapperFactory(l).equals(keysComparableWrapperFactory(r)) + comparableKeyWrapperFactory(l).equals(comparableKeyWrapperFactory(r)) } case _ => false } override def hashCode(): Int = - Objects.hash(expressions, partitionKeys.map(keysComparableWrapperFactory)) + Objects.hash(expressions, partitionKeys.map(comparableKeyWrapperFactory)) } object KeyedPartitioning { @@ -538,36 +537,23 @@ object KeyedPartitioning { } } - def projectKey( - key: InternalRow, - positions: Seq[Int], - dataTypes: Seq[DataType]): InternalRow = { - val projectedKey = positions.zip(dataTypes).map { - case (position, dataType) => key.get(position, dataType) - }.toArray[Any] - new GenericInternalRow(projectedKey) - } - /** * Projects a sequence of partition keys by selecting only the specified positions. */ def projectKeys( keys: Seq[InternalRow], - positions: Seq[Int], - dataTypes: Seq[DataType]): Seq[InternalRow] = { - keys.map(projectKey(_, positions, dataTypes)) - } + dataTypes: Seq[DataType], + positions: Seq[Int]): (Seq[DataType], Seq[InternalRow]) = { + val projectedDataTypes = positions.map(dataTypes) + val positionsWithTypes = positions.zip(projectedDataTypes) + val projectedKeys = keys.map { key => + val projectedKey = positionsWithTypes.map { + case (position, dataType) => key.get(position, dataType) + }.toArray[Any] + new GenericInternalRow(projectedKey) + } - def reduceKey( - key: InternalRow, - reducers: Seq[Option[Reducer[_, _]]], - dataTypes: Seq[DataType]): InternalRow = { - val keyValues = key.toSeq(dataTypes) - val reducedKey = keyValues.zip(reducers).map{ - case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v) - case (v, _) => v - }.toArray - new GenericInternalRow(reducedKey) + (projectedDataTypes, projectedKeys) } /** @@ -575,9 +561,16 @@ object KeyedPartitioning { */ def reduceKeys( keys: Seq[InternalRow], - reducers: Seq[Option[Reducer[_, _]]], - dataTypes: Seq[DataType]): Seq[InternalRow] = { - keys.map(reduceKey(_, reducers, dataTypes)) + dataTypes: Seq[DataType], + reducers: Seq[Option[Reducer[_, _]]]): Seq[InternalRow] = { + keys.map { key => + val keyValues = key.toSeq(dataTypes) + val reducedKey = keyValues.zip(reducers).map { + case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v) + case (v, _) => v + }.toArray + new GenericInternalRow(reducedKey) + } } } @@ -963,14 +956,12 @@ case class KeyGroupedShuffleSpec( // transform functions. // 4. the partition values from both sides are following the same order. case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) => - lazy val internalRowComparableFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( - partitioning.expressions.map(_.dataType)) distribution.clustering.length == otherDistribution.clustering.length && numPartitions == other.numPartitions && areKeysCompatible(otherSpec) && partitioning.partitionKeys.zip(otherPartitioning.partitionKeys).forall { case (left, right) => - internalRowComparableFactory(left).equals(internalRowComparableFactory(right)) + partitioning.comparableKeyWrapperFactory(left) + .equals(partitioning.comparableKeyWrapperFactory(right)) } case ShuffleSpecCollection(specs) => specs.exists(isCompatibleWith) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 810b88a7d47f6..a459917237fc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -101,19 +101,15 @@ case class BatchScanExec( "partition values that are not present in the original partitioning.") } - val dataTypes = p.expressions.map(_.dataType) - val rowOrdering = RowOrdering.createNaturalAscendingOrdering(dataTypes) - val internalRowComparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) - + val comparableKeyWrapperFactory = p.comparableKeyWrapperFactory val inputMap = inputPartitions.groupBy(p => - internalRowComparableWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey()) + comparableKeyWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey()) ).view.mapValues(_.size) val filteredMap = newPartitions.groupBy(p => - internalRowComparableWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey()) + comparableKeyWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey()) ) inputMap.toSeq - .sortBy { case (keyWrapper, _) => keyWrapper.row }(rowOrdering) + .sortBy { case (keyWrapper, _) => keyWrapper.row }(p.keyOrdering) .flatMap { case (keyWrapper, size) => val fps = filteredMap.getOrElse(keyWrapper, Array.empty) assert(fps.size <= size) @@ -128,9 +124,7 @@ case class BatchScanExec( } else { (originalPartitioning match { case p: KeyedPartitioning => - val dataTypes = p.expressions.map(_.dataType) - val rowOrdering = RowOrdering.createNaturalAscendingOrdering(dataTypes) - inputPartitions.sortBy(_.asInstanceOf[HasPartitionKey].partitionKey())(rowOrdering) + inputPartitions.sortBy(_.asInstanceOf[HasPartitionKey].partitionKey())(p.keyOrdering) case _ => inputPartitions }).map(Some) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala index 07e72b41ae2a8..58800a9188e03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala @@ -76,10 +76,10 @@ case class GroupPartitionsExec( */ private def distributeByCommonKeys( keyWrapperMap: Map[InternalRowComparableWrapper, Seq[Int]], - comparableWrapperFactory: InternalRow => InternalRowComparableWrapper + comparableKeyWrapperFactory: InternalRow => InternalRowComparableWrapper ): Seq[(InternalRow, Seq[Int])] = { commonPartitionKeys.get.flatMap { case (key, numSplits) => - val splits = keyWrapperMap.getOrElse(comparableWrapperFactory(key), Seq.empty) + val splits = keyWrapperMap.getOrElse(comparableKeyWrapperFactory(key), Seq.empty) if (applyPartialClustering && !replicatePartitions) { // Distribute splits across expected partitions, padding with empty sequences val paddedSplits = splits.map(Seq(_)).padTo(numSplits, Seq.empty) @@ -125,32 +125,24 @@ case class GroupPartitionsExec( // Project partition keys if join key positions are specified val (projectedDataTypes, projectedKeys) = - joinKeyPositions match { - case Some(positions) => - val (_, projectedDataTypes, projectedKeys) = keyedPartitioning.projectKeys(positions) - (projectedDataTypes, projectedKeys) - case None => - val dataTypes = keyedPartitioning.expressions.map(_.dataType) - (dataTypes, keyedPartitioning.partitionKeys) - } + joinKeyPositions.fold( + (keyedPartitioning.expressionDataTypes, keyedPartitioning.partitionKeys) + )(keyedPartitioning.projectKeys) // Reduce keys if reducers are specified - val reducedKeys = reducers match { - case Some(reducers) => - KeyedPartitioning.reduceKeys(projectedKeys, reducers, projectedDataTypes) - case None => projectedKeys - } + val reducedKeys = reducers.fold(projectedKeys)( + KeyedPartitioning.reduceKeys(projectedKeys, projectedDataTypes, _)) // Create map from partition keys to their indices - val comparableWrapperFactory = + val comparableKeyWrapperFactory = InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(projectedDataTypes) val keyWrapperToPartitionIndices = reducedKeys.zipWithIndex.groupMap { - case (key, _) => comparableWrapperFactory(key) + case (key, _) => comparableKeyWrapperFactory(key) }(_._2) if (commonPartitionKeys.isDefined) { - distributeByCommonKeys(keyWrapperToPartitionIndices, comparableWrapperFactory) + distributeByCommonKeys(keyWrapperToPartitionIndices, comparableKeyWrapperFactory) } else { groupAndSortByKeys(keyWrapperToPartitionIndices, projectedDataTypes) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 7cd2c7d207ad2..8b0ff8db6dafd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -476,9 +476,9 @@ case class EnsureRequirements( } } - // We don't need to add alter or add any `GroupPartitionsExec` when the child partitionings are - // not modified (projected) in specs and left and right side partitionings are compatible with - // each other. + // We don't need to alter the existing or add new `GroupPartitionsExec` when the child + // partitionings are not modified (projected) in specs and left and right side partitionings are + // compatible with each other. var isCompatible = containsPartitioning(left.outputPartitioning, leftSpec.partitioning) && containsPartitioning(right.outputPartitioning, rightSpec.partitioning) && leftSpec.isCompatibleWith(rightSpec) @@ -522,8 +522,8 @@ case class EnsureRequirements( // in case of compatible but not identical partition expressions, we apply 'reduce' // transforms to group one side's partitions as well as the common partition values val leftReducers = leftSpec.reducers(rightSpec) - val leftReducedKeys = leftReducers.fold(leftSpec.partitioning.partitionKeys)( - leftSpec.partitioning.reduceKeys) + val leftReducedKeys = + leftReducers.fold(leftSpec.partitioning.partitionKeys)(leftSpec.partitioning.reduceKeys) val rightReducers = rightSpec.reducers(leftSpec) val rightReducedKeys = rightReducers.fold(rightSpec.partitioning.partitionKeys)( rightSpec.partitioning.reduceKeys) @@ -621,20 +621,19 @@ case class EnsureRequirements( case g: GroupPartitionsExec => g.child case o => o }).outputPartitioning.asInstanceOf[KeyedPartitioning] - val dataTypes = partitionExprs.map(_.dataType) - val projectedOriginalPartitionKeys = - partiallyClusteredSpec.joinKeyPositions.fold(originalPartitioning.partitionKeys)( - KeyedPartitioning.projectKeys(originalPartitioning.partitionKeys, _, dataTypes)) - val internalRowComparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) + val (projectedDataTypes, projectedOriginalPartitionKeys) = + partiallyClusteredSpec.joinKeyPositions.fold( + (originalPartitioning.expressionDataTypes, originalPartitioning.partitionKeys) + )(originalPartitioning.projectKeys) + val comparableKeyWrapperFactory = InternalRowComparableWrapper + .getInternalRowComparableWrapperFactory(projectedDataTypes) val numExpectedPartitions = projectedOriginalPartitionKeys - .groupBy(internalRowComparableWrapperFactory) + .groupBy(comparableKeyWrapperFactory) .view.mapValues(_.size) mergedPartitionKeys = mergedPartitionKeys.map { case (key, numParts) => - (key, numExpectedPartitions.getOrElse( - internalRowComparableWrapperFactory(key), numParts)) + (key, numExpectedPartitions.getOrElse(comparableKeyWrapperFactory(key), numParts)) } logInfo(log"After applying partially clustered distribution, there are " + From 25866203ac5736d5cb21f8706b80893f13109266 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 25 Feb 2026 18:16:29 +0100 Subject: [PATCH 10/29] partially clustered distribution no longer requires `canApplyPartialClusteredDistribution()` check as partition grouping happens right under the join, add granular grouping test --- .../exchange/EnsureRequirements.scala | 46 +++----- .../KeyGroupedPartitioningSuite.scala | 103 ++++++++++++++++++ 2 files changed, 117 insertions(+), 32 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 8b0ff8db6dafd..438a6d6e0dfd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -403,26 +403,6 @@ case class EnsureRequirements( } } - /** - * Whether partial clustering can be applied to a given child query plan. This is true if the plan - * consists only of a sequence of unary nodes where each node does not use the scan's key-grouped - * partitioning to satisfy its required distribution. Otherwise, partially clustering could be - * applied to a key-grouped partitioning unrelated to this join. - */ - private def canApplyPartialClusteredDistribution(plan: SparkPlan): Boolean = { - !plan.exists { - // Unary nodes are safe as long as they don't have a required distribution (for example, a - // project or filter). If they have a required distribution, then we should assume that this - // plan can't be partially clustered (since the key-grouped partitioning may be needed to - // satisfy this distribution unrelated to this JOIN). - case u if u.children.length == 1 => - u.requiredChildDistribution.head != UnspecifiedDistribution - // Only allow a non-unary node if it's a leaf node - key-grouped partitionings other binary - // nodes (like another JOIN) aren't safe to partially cluster. - case other => other.children.nonEmpty - } - } - /** * Checks whether two children, `left` and `right`, of a join operator have compatible * `KeyedPartitioning`, and can benefit from storage-partitioned join. @@ -553,16 +533,9 @@ case class EnsureRequirements( // whether partially clustered distribution can be applied. For instance, the // optimization cannot be applied to a left outer join, where the left hand // side is chosen as the side to replicate partitions according to stats. - // Similarly, the partially clustered distribution cannot be applied if the - // partially clustered side must use the scan's key-grouped partitioning to - // satisfy some unrelated required distribution in its plan (for example, for an aggregate - // or window function), as this will give incorrect results (for example, duplicate - // row_number() values). // Otherwise, query result could be incorrect. - val canReplicateLeft = canReplicateLeftSide(joinType) && - canApplyPartialClusteredDistribution(right) - val canReplicateRight = canReplicateRightSide(joinType) && - canApplyPartialClusteredDistribution(left) + val canReplicateLeft = canReplicateLeftSide(joinType) + val canReplicateRight = canReplicateRightSide(joinType) if (!canReplicateLeft && !canReplicateRight) { logInfo(log"Skipping partially clustered distribution as it cannot be applied for " + @@ -617,14 +590,23 @@ case class EnsureRequirements( } else { (left, leftSpec) } + // Original `KeyedPartitioning` can be obtained from the child directly if the child + // satisfied the distribution requirement; or from the child's child if it didn't as + // the child must be a `GroupPartitionsExec` inserted by `EnsureRequirement` + // to satisfy the distribution requirement. val originalPartitioning = (partiallyClusteredChild match { case g: GroupPartitionsExec => g.child case o => o - }).outputPartitioning.asInstanceOf[KeyedPartitioning] + }).outputPartitioning.asInstanceOf[Partitioning with Expression] + // `originalPartitioning` can be a collection, but there must be `KeyedPartitioning` + // in it. + val originalKeyedPartitioning = + originalPartitioning.collectFirst { case k: KeyedPartitioning => k }.get val (projectedDataTypes, projectedOriginalPartitionKeys) = partiallyClusteredSpec.joinKeyPositions.fold( - (originalPartitioning.expressionDataTypes, originalPartitioning.partitionKeys) - )(originalPartitioning.projectKeys) + (originalKeyedPartitioning.expressionDataTypes, + originalKeyedPartitioning.partitionKeys) + )(originalKeyedPartitioning.projectKeys) val comparableKeyWrapperFactory = InternalRowComparableWrapper .getInternalRowComparableWrapperFactory(projectedDataTypes) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index ebb4598150c01..b9e6fc145b8e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -370,6 +370,12 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Column.create("price", FloatType), Column.create("time", TimestampType)) + private val details: String = "details" + private val detailsColumns: Array[Column] = Array( + Column.create("item_id", LongType), + Column.create("description", StringType), + Column.create("updated", TimestampType)) + test("SPARK-48655: group by on partition keys should not introduce additional shuffle") { val items_partitions = Array(identity("id")) createTable(items, itemsColumns, items_partitions) @@ -2974,4 +2980,101 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { assert(scans(0).inputRDD.partitions.length === 3, "items scan should not group") } + + test("SPARK-55092: Multi table join granular partition grouping") { + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + val items_partitions = Array(identity("id"), years("arrive_time")) + createTable(items, itemsColumns, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 10.0, cast('2021-01-01' as timestamp)), " + + "(1, 'aa', 20.0, cast('2022-01-01' as timestamp)), " + + "(2, 'aa', 30.0, cast('2021-01-01' as timestamp)), " + + "(2, 'aa', 40.0, cast('2022-01-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id"), years("time")) + createTable(purchases, purchasesColumns, purchases_partitions) + + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(2, 10.0, cast('2021-01-01' as timestamp)), " + + "(2, 20.0, cast('2022-01-01' as timestamp)), " + + "(3, 30.0, cast('2021-01-01' as timestamp)), " + + "(3, 40.0, cast('2022-01-01' as timestamp))") + + val details_partitions = Array(identity("item_id")) + createTable(details, detailsColumns, details_partitions) + + sql(s"INSERT INTO testcat.ns.$details VALUES " + + "(2, 'cc', cast('2021-01-01' as timestamp)), " + + "(3, 'cc', cast('2022-01-01' as timestamp))") + + val df = sql( + s""" + |SELECT i.id, i.arrive_time, p.item_id, d.item_id + |FROM testcat.ns.$items i + |JOIN testcat.ns.$purchases p ON p.item_id = i.id AND p.time = i.arrive_time + |JOIN testcat.ns.$details d ON d.item_id = i.id + |""".stripMargin) + + checkAnswer(df, Seq( + Row(2, Timestamp.valueOf("2021-01-01 00:00:00"), 2, 2), + Row(2, Timestamp.valueOf("2022-01-01 00:00:00"), 2, 2))) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not contain any shuffle") + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + // Expect 6 partitions in the inner join node legs because partitioning uses 2 attributes. + // Expect 3 partitions in the outer join node legs because partitioning uses 1 attributes. + assert(groupPartitions.map(_.outputPartitioning.numPartitions) === Seq(3, 6, 6, 3)) + } + } + + test("SPARK-55092: Multi table join partial clustering") { + withSQLConf(SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "true") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 10.0, cast('2021-01-01' as timestamp)), " + + "(1, 'aa', 20.0, cast('2022-01-01' as timestamp)), " + + "(2, 'aa', 30.0, cast('2021-01-01' as timestamp)), " + + "(2, 'aa', 40.0, cast('2022-01-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(2, 10.0, cast('2021-01-01' as timestamp)), " + + "(3, 20.0, cast('2022-01-01' as timestamp))") + + val details_partitions = Array(identity("item_id")) + createTable(details, detailsColumns, details_partitions) + + sql(s"INSERT INTO testcat.ns.$details VALUES " + + "(2, 'cc', cast('2021-01-01' as timestamp)), " + + "(4, 'cc', cast('2022-01-01' as timestamp))") + + val df = sql( + s""" + |SELECT i.id, i.price, p.price, d.description + |FROM testcat.ns.$items i + |JOIN testcat.ns.$purchases p ON p.item_id = i.id + |JOIN testcat.ns.$details d ON d.item_id = i.id + |""".stripMargin) + + checkAnswer(df, Seq( + Row(2, 30.0, 10.0, "cc"), + Row(2, 40.0, 10.0, "cc"))) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not contain any shuffle") + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + // Expect 5 partitions in the inner join node legs because 4 from the partially clustered + // items table and 1 new from clustered purchases table. + // Expect 6 partitions in the outer join node legs because 5 from the partially clustered + // inner join result and 1 new from clustered details table. + assert(groupPartitions.map(_.outputPartitioning.numPartitions) === Seq(6, 5, 5, 6)) + } + } } From 90a49a3014f4df71f40f58e25e07f5b36e884fc1 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 26 Feb 2026 11:03:13 +0100 Subject: [PATCH 11/29] more code cleanup and comments --- .../exchange/EnsureRequirements.scala | 39 ++++++++----------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 438a6d6e0dfd0..4c4a922b96012 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -445,22 +445,14 @@ case class EnsureRequirements( val leftSpec = specs.head val rightSpec = specs(1) - def containsPartitioning( - partitioning: Partitioning, - keyedPartitioning: KeyedPartitioning): Boolean = { - partitioning match { - case k: KeyedPartitioning => k == keyedPartitioning - case PartitioningCollection(partitionings) => - partitionings.exists(containsPartitioning(_, keyedPartitioning)) - case _ => false - } - } - // We don't need to alter the existing or add new `GroupPartitionsExec` when the child // partitionings are not modified (projected) in specs and left and right side partitionings are // compatible with each other. - var isCompatible = containsPartitioning(left.outputPartitioning, leftSpec.partitioning) && - containsPartitioning(right.outputPartitioning, rightSpec.partitioning) && + // Left and right `outputPartitioning` is a `PartitioningCollection` or a `KeyedPartitioning` + // otherwise `createKeyGroupedShuffleSpec()` would have returned `None`. + var isCompatible = + left.outputPartitioning.asInstanceOf[Expression].exists(_ == leftSpec.partitioning) && + right.outputPartitioning.asInstanceOf[Expression].exists(_ == rightSpec.partitioning) && leftSpec.isCompatibleWith(rightSpec) if ((!isCompatible || conf.v2BucketingPartiallyClusteredDistributionEnabled) && (conf.v2BucketingPushPartValuesEnabled || @@ -541,8 +533,11 @@ case class EnsureRequirements( logInfo(log"Skipping partially clustered distribution as it cannot be applied for " + log"join type '${MDC(LogKeys.JOIN_TYPE, joinType)}'") } else { - val leftLink = unwrapGroupPartitions(left).logicalLink - val rightLink = unwrapGroupPartitions(right).logicalLink + val unwrappedLeft = unwrapGroupPartitions(left) + val unwrappedRight = unwrapGroupPartitions(right) + + val leftLink = unwrappedLeft.logicalLink + val rightLink = unwrappedRight.logicalLink replicateLeftSide = if ( leftLink.isDefined && rightLink.isDefined && @@ -586,20 +581,18 @@ case class EnsureRequirements( } else { // In partially clustered distribution, we should use un-grouped partition values val (partiallyClusteredChild, partiallyClusteredSpec) = if (replicateLeftSide) { - (right, rightSpec) + (unwrappedRight, rightSpec) } else { - (left, leftSpec) + (unwrappedLeft, leftSpec) } // Original `KeyedPartitioning` can be obtained from the child directly if the child // satisfied the distribution requirement; or from the child's child if it didn't as // the child must be a `GroupPartitionsExec` inserted by `EnsureRequirement` // to satisfy the distribution requirement. - val originalPartitioning = (partiallyClusteredChild match { - case g: GroupPartitionsExec => g.child - case o => o - }).outputPartitioning.asInstanceOf[Partitioning with Expression] - // `originalPartitioning` can be a collection, but there must be `KeyedPartitioning` - // in it. + val originalPartitioning = + partiallyClusteredChild.outputPartitioning.asInstanceOf[Expression] + // `outputPartitioning` is either a `PartitioningCollection` or a `KeyedPartitioning` + // otherwise `createKeyGroupedShuffleSpec()` would have returned `None`. val originalKeyedPartitioning = originalPartitioning.collectFirst { case k: KeyedPartitioning => k }.get val (projectedDataTypes, projectedOriginalPartitionKeys) = From a0a8a3d3e00984f6ce25b27131ba94f17a7adab7 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 27 Feb 2026 10:55:43 +0100 Subject: [PATCH 12/29] BatchScanExec code cleanup --- .../datasources/v2/BatchScanExec.scala | 52 ++++++++----------- 1 file changed, 23 insertions(+), 29 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index a459917237fc1..16f27a10666e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{KeyedPartitioning, SinglePartition} -import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.read._ import org.apache.spark.util.ArrayImplicits._ @@ -75,44 +75,38 @@ case class BatchScanExec( val newPartitions = scan.toBatch.planInputPartitions() originalPartitioning match { - case p: KeyedPartitioning => + case k: KeyedPartitioning => if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) { throw new SparkException("Data source must have preserved the original partitioning " + "during runtime filtering: not all partitions implement HasPartitionKey after " + "filtering") } - val newPartitionKeys = newPartitions.map(partition => - InternalRowComparableWrapper(partition.asInstanceOf[HasPartitionKey], p.expressions)) - .toSet - val oldPartitionKeys = p.partitionKeys - .map(partition => InternalRowComparableWrapper(partition, p.expressions)).toSet - // We require the new number of partition values to be equal or less than the old number - // of partition values here. In the case of less than, empty partitions will be added for - // those missing values that are not present in the new input partitions. - if (oldPartitionKeys.size < newPartitionKeys.size) { - throw new SparkException("During runtime filtering, data source must either report " + - "the same number of partition values, or a subset of partition values from the " + - s"original. Before: ${oldPartitionKeys.size} partition values. " + - s"After: ${newPartitionKeys.size} partition values") - } - if (!newPartitionKeys.forall(oldPartitionKeys.contains)) { + val inputMap = inputPartitions.groupBy( + p => k.comparableKeyWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey()) + ).view.mapValues(_.size) + val filteredMap = newPartitions.groupBy( + p => k.comparableKeyWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey()) + ) + + if (!filteredMap.keySet.subsetOf(inputMap.keySet)) { throw new SparkException("During runtime filtering, data source must not report new " + - "partition values that are not present in the original partitioning.") + "partition keys that are not present in the original partitioning.") } - val comparableKeyWrapperFactory = p.comparableKeyWrapperFactory - val inputMap = inputPartitions.groupBy(p => - comparableKeyWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey()) - ).view.mapValues(_.size) - val filteredMap = newPartitions.groupBy(p => - comparableKeyWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey()) - ) inputMap.toSeq - .sortBy { case (keyWrapper, _) => keyWrapper.row }(p.keyOrdering) + .sortBy { case (keyWrapper, _) => keyWrapper.row }(k.keyOrdering) .flatMap { case (keyWrapper, size) => + // We require the new number of partitions to be equal or less than the old number of + // partitions for a given key. In the case of less than, empty partitions are added. val fps = filteredMap.getOrElse(keyWrapper, Array.empty) - assert(fps.size <= size) + + if (fps.size > size) { + throw new SparkException("During runtime filtering, data source must not report " + + s"new partitions for a given key. Before: $size partitions. " + + s"After: ${fps.size} partitions") + } + fps.map(Some).padTo(size, None) } @@ -123,8 +117,8 @@ case class BatchScanExec( } else { (originalPartitioning match { - case p: KeyedPartitioning => - inputPartitions.sortBy(_.asInstanceOf[HasPartitionKey].partitionKey())(p.keyOrdering) + case k: KeyedPartitioning => + inputPartitions.sortBy(_.asInstanceOf[HasPartitionKey].partitionKey())(k.keyOrdering) case _ => inputPartitions }).map(Some) From 853e6b72fcdb7e528f2a394c7d658091facbf9f1 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 3 Mar 2026 10:27:45 +0100 Subject: [PATCH 13/29] make `isGrouped` precomputed --- .../plans/physical/partitioning.scala | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index f70d743790911..3e331e6c5feed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -405,10 +405,13 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa * @param partitionKeys Partition keys, one per partition. When used as outputPartitioning, * always in sorted order. When used in KeyGroupedShuffleSpec, may be * unsorted after projection. May contain duplicates when ungrouped. + * @param isGrouped Whether partition keys are unique (no duplicates). Computed on first + * creation, then preserved through copy operations to avoid recomputation. */ case class KeyedPartitioning( expressions: Seq[Expression], - partitionKeys: Seq[InternalRow]) extends Expression with Partitioning with Unevaluable { + partitionKeys: Seq[InternalRow], + isGrouped: Boolean = false) extends Expression with Partitioning with Unevaluable { override val numPartitions = partitionKeys.length override def children: Seq[Expression] = expressions @@ -426,13 +429,10 @@ case class KeyedPartitioning( @transient lazy val keyOrdering = RowOrdering.createNaturalAscendingOrdering(expressionDataTypes) - @transient lazy val isGrouped: Boolean = - partitionKeys.distinctBy(comparableKeyWrapperFactory).size == partitionKeys.size - def toGrouped: KeyedPartitioning = { val groupedPartitionKeys = partitionKeys.distinctBy(comparableKeyWrapperFactory) - KeyedPartitioning(expressions, groupedPartitionKeys) + KeyedPartitioning(expressions, groupedPartitionKeys, isGrouped = true) } /** @@ -518,6 +518,20 @@ case class KeyedPartitioning( } object KeyedPartitioning { + /** + * Creates a KeyedPartitioning with isGrouped computed from the partition keys. + * Use this when creating a new KeyedPartitioning from scratch (e.g., from a data source). + */ + def apply( + expressions: Seq[Expression], + partitionKeys: Seq[InternalRow]): KeyedPartitioning = { + val dataTypes = expressions.map(_.dataType) + val comparableKeyWrapperFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) + val isGrouped = partitionKeys.distinctBy(comparableKeyWrapperFactory).size == partitionKeys.size + new KeyedPartitioning(expressions, partitionKeys, isGrouped) + } + def supportsExpressions(expressions: Seq[Expression]): Boolean = { def isSupportedTransform(transform: TransformExpression): Boolean = { transform.children.size == 1 && isReference(transform.children.head) @@ -1040,7 +1054,7 @@ case class KeyGroupedShuffleSpec( te.copy(children = te.children.map(_ => clustering(positionSet.head))) case (_, positionSet) => clustering(positionSet.head) } - KeyedPartitioning(newExpressions, partitioning.partitionKeys) + KeyedPartitioning(newExpressions, partitioning.partitionKeys, partitioning.isGrouped) } } From f5baf7670b6cba639b9b032586d161d7462f8efc Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 3 Mar 2026 10:31:00 +0100 Subject: [PATCH 14/29] change `toGrouped` to sort partition keys --- .../spark/sql/catalyst/plans/physical/partitioning.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 3e331e6c5feed..f51678b7dc28c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -430,7 +430,8 @@ case class KeyedPartitioning( @transient lazy val keyOrdering = RowOrdering.createNaturalAscendingOrdering(expressionDataTypes) def toGrouped: KeyedPartitioning = { - val groupedPartitionKeys = partitionKeys.distinctBy(comparableKeyWrapperFactory) + val groupedPartitionKeys = + partitionKeys.distinctBy(comparableKeyWrapperFactory).sorted(keyOrdering) KeyedPartitioning(expressions, groupedPartitionKeys, isGrouped = true) } From 174fe90dbf27cbff2523d804ba8f2087bfd936e4 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 3 Mar 2026 15:00:28 +0100 Subject: [PATCH 15/29] use `InternalRowComparableWrapper` partition keys in `KeyedPartitioning` --- .../plans/physical/partitioning.scala | 75 ++++++------ .../util/InternalRowComparableWrapper.scala | 27 ----- ...nternalRowComparableWrapperBenchmark.scala | 26 ++-- .../datasources/v2/BatchScanExec.scala | 18 +-- .../datasources/v2/GroupPartitionsExec.scala | 58 +++++---- .../exchange/EnsureRequirements.scala | 112 +++++++++--------- .../exchange/ShuffleExchangeExec.scala | 9 +- .../DistributionAndOrderingSuiteBase.scala | 4 +- .../exchange/EnsureRequirementsSuite.scala | 2 +- 9 files changed, 143 insertions(+), 188 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index f51678b7dc28c..475e6a52f3ce7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -402,16 +402,17 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa * }}} * * @param expressions Partition transform expressions (e.g., `years(col)`, `bucket(10, col)`). - * @param partitionKeys Partition keys, one per partition. When used as outputPartitioning, - * always in sorted order. When used in KeyGroupedShuffleSpec, may be - * unsorted after projection. May contain duplicates when ungrouped. + * @param partitionKeys Partition keys wrapped in InternalRowComparableWrapper for efficient + * comparison and grouping. One per partition. When used as outputPartitioning, + * always in sorted order. When used in KeyGroupedShuffleSpec, may be unsorted + * after projection. May contain duplicates when ungrouped. * @param isGrouped Whether partition keys are unique (no duplicates). Computed on first * creation, then preserved through copy operations to avoid recomputation. */ case class KeyedPartitioning( expressions: Seq[Expression], - partitionKeys: Seq[InternalRow], - isGrouped: Boolean = false) extends Expression with Partitioning with Unevaluable { + @transient partitionKeys: Seq[InternalRowComparableWrapper], + isGrouped: Boolean) extends Expression with Partitioning with Unevaluable { override val numPartitions = partitionKeys.length override def children: Seq[Expression] = expressions @@ -424,34 +425,30 @@ case class KeyedPartitioning( @transient lazy val expressionDataTypes: Seq[DataType] = expressions.map(_.dataType) - @transient lazy val comparableKeyWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(expressionDataTypes) + @transient lazy val keyRowOrdering = + RowOrdering.createNaturalAscendingOrdering(expressionDataTypes) - @transient lazy val keyOrdering = RowOrdering.createNaturalAscendingOrdering(expressionDataTypes) + @transient lazy val keyOrdering = keyRowOrdering.on((t: InternalRowComparableWrapper) => t.row) def toGrouped: KeyedPartitioning = { - val groupedPartitionKeys = - partitionKeys.distinctBy(comparableKeyWrapperFactory).sorted(keyOrdering) + val groupedPartitionKeys = partitionKeys.distinct.sorted(keyOrdering) - KeyedPartitioning(expressions, groupedPartitionKeys, isGrouped = true) + new KeyedPartitioning(expressions, groupedPartitionKeys, isGrouped = true) } /** * Projects this partitioning's expressions by selecting only the specified positions. * Returns the projected expressions and their data types together with the projected keys. */ - def projectKeys(positions: Seq[Int]): (Seq[DataType], Seq[InternalRow]) = { + def projectKeys(positions: Seq[Int]): (Seq[DataType], Seq[InternalRowComparableWrapper]) = KeyedPartitioning.projectKeys(partitionKeys, expressionDataTypes, positions) - } /** * Reduces this partitioning's partition keys by applying the given reducers. * Returns the distinct reduced keys. */ - def reduceKeys(reducers: Seq[Option[Reducer[_, _]]]): Seq[InternalRow] = { - KeyedPartitioning.reduceKeys(partitionKeys, expressionDataTypes, reducers) - .distinctBy(comparableKeyWrapperFactory) - } + def reduceKeys(reducers: Seq[Option[Reducer[_, _]]]): Seq[InternalRowComparableWrapper] = + KeyedPartitioning.reduceKeys(partitionKeys, expressionDataTypes, reducers).distinct override def satisfies0(required: Distribution): Boolean = { super.satisfies0(required) || isGrouped && { @@ -492,10 +489,8 @@ case class KeyedPartitioning( // the returned shuffle spec. val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2) val projectedExpressions = joinKeyPositions.map(expressions) - val (projectedDataTypes, projectedKeys) = projectKeys(joinKeyPositions) - val comparableKeyWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(projectedDataTypes) - val distinctProjectedKeys = projectedKeys.distinctBy(comparableKeyWrapperFactory) + val projectedKeys = projectKeys(joinKeyPositions)._2 + val distinctProjectedKeys = projectedKeys.distinct val projectedPartitioning = copy(expressions = projectedExpressions, partitionKeys = distinctProjectedKeys) result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions)) @@ -506,16 +501,13 @@ case class KeyedPartitioning( override def equals(that: Any): Boolean = that match { case k: KeyedPartitioning if this.expressions == k.expressions => - partitionKeys.size == k.partitionKeys.size && - partitionKeys.zip(k.partitionKeys).forall { case (l, r) => - comparableKeyWrapperFactory(l).equals(comparableKeyWrapperFactory(r)) - } + this.partitionKeys == k.partitionKeys case _ => false } override def hashCode(): Int = - Objects.hash(expressions, partitionKeys.map(comparableKeyWrapperFactory)) + Objects.hash(expressions, partitionKeys) } object KeyedPartitioning { @@ -529,8 +521,9 @@ object KeyedPartitioning { val dataTypes = expressions.map(_.dataType) val comparableKeyWrapperFactory = InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) - val isGrouped = partitionKeys.distinctBy(comparableKeyWrapperFactory).size == partitionKeys.size - new KeyedPartitioning(expressions, partitionKeys, isGrouped) + val comparablePartitionKeys = partitionKeys.map(comparableKeyWrapperFactory) + val isGrouped = comparablePartitionKeys.distinct.size == comparablePartitionKeys.size + new KeyedPartitioning(expressions, comparablePartitionKeys, isGrouped) } def supportsExpressions(expressions: Seq[Expression]): Boolean = { @@ -556,16 +549,18 @@ object KeyedPartitioning { * Projects a sequence of partition keys by selecting only the specified positions. */ def projectKeys( - keys: Seq[InternalRow], + keys: Seq[InternalRowComparableWrapper], dataTypes: Seq[DataType], - positions: Seq[Int]): (Seq[DataType], Seq[InternalRow]) = { + positions: Seq[Int]): (Seq[DataType], Seq[InternalRowComparableWrapper]) = { val projectedDataTypes = positions.map(dataTypes) + val comparableKeyWrapperFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(projectedDataTypes) val positionsWithTypes = positions.zip(projectedDataTypes) val projectedKeys = keys.map { key => val projectedKey = positionsWithTypes.map { - case (position, dataType) => key.get(position, dataType) + case (position, dataType) => key.row.get(position, dataType) }.toArray[Any] - new GenericInternalRow(projectedKey) + comparableKeyWrapperFactory(new GenericInternalRow(projectedKey)) } (projectedDataTypes, projectedKeys) @@ -575,16 +570,18 @@ object KeyedPartitioning { * Reduces a sequence of partition keys by applying reducers to each position. */ def reduceKeys( - keys: Seq[InternalRow], + keys: Seq[InternalRowComparableWrapper], dataTypes: Seq[DataType], - reducers: Seq[Option[Reducer[_, _]]]): Seq[InternalRow] = { + reducers: Seq[Option[Reducer[_, _]]]): Seq[InternalRowComparableWrapper] = { + val comparableKeyWrapperFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes) keys.map { key => - val keyValues = key.toSeq(dataTypes) + val keyValues = key.row.toSeq(dataTypes) val reducedKey = keyValues.zip(reducers).map { case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v) case (v, _) => v }.toArray - new GenericInternalRow(reducedKey) + comparableKeyWrapperFactory(new GenericInternalRow(reducedKey)) } } } @@ -973,11 +970,7 @@ case class KeyGroupedShuffleSpec( case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) => distribution.clustering.length == otherDistribution.clustering.length && numPartitions == other.numPartitions && areKeysCompatible(otherSpec) && - partitioning.partitionKeys.zip(otherPartitioning.partitionKeys).forall { - case (left, right) => - partitioning.comparableKeyWrapperFactory(left) - .equals(partitioning.comparableKeyWrapperFactory(right)) - } + partitioning.partitionKeys == otherPartitioning.partitionKeys case ShuffleSpecCollection(specs) => specs.exists(isCompatibleWith) case _ => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala index b9935d40ed985..217d12710a6a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.util -import scala.collection.mutable - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BaseOrdering, Expression, Murmur3HashFunction, RowOrdering} import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition} @@ -101,31 +99,6 @@ object InternalRowComparableWrapper { new InternalRowComparableWrapper(partitionRow, partitionExpression.map(_.dataType)) } - def mergePartitions( - leftPartitioning: Seq[InternalRow], - rightPartitioning: Seq[InternalRow], - partitionExpression: Seq[Expression], - intersect: Boolean = false): Seq[InternalRowComparableWrapper] = { - val partitionDataTypes = partitionExpression.map(_.dataType) - val leftPartitionSet = new mutable.HashSet[InternalRowComparableWrapper] - val internalRowComparableWrapperFactory = - getInternalRowComparableWrapperFactory(partitionDataTypes) - leftPartitioning - .map(internalRowComparableWrapperFactory) - .foreach(partition => leftPartitionSet.add(partition)) - val rightPartitionSet = new mutable.HashSet[InternalRowComparableWrapper] - rightPartitioning - .map(internalRowComparableWrapperFactory) - .foreach(partition => rightPartitionSet.add(partition)) - - val result = if (intersect) { - leftPartitionSet.intersect(rightPartitionSet) - } else { - leftPartitionSet.union(rightPartitionSet) - } - result.toSeq - } - /** Creates a shared factory method for a given row schema to avoid excessive cache lookups. */ def getInternalRowComparableWrapperFactory( dataTypes: Seq[DataType]): InternalRow => InternalRowComparableWrapper = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala index a96e58727bd15..4f431e6171b28 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.catalyst.util +import scala.collection.mutable + import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} -import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.catalyst.plans.physical.KeyedPartitioning import org.apache.spark.sql.connector.catalog.PartitionInternalRow import org.apache.spark.sql.types.IntegerType @@ -47,24 +47,22 @@ object InternalRowComparableWrapperBenchmark extends BenchmarkBase { } val benchmark = new Benchmark("internal row comparable wrapper", partitionNum, output = output) + val comparableKeyWrapperFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( + Seq(IntegerType, IntegerType)) + val comparablePartitionKeys = partitionKeys.map(comparableKeyWrapperFactory) + benchmark.addCase("toSet") { _ => - val internalRowComparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( - Seq(IntegerType, IntegerType)) - val distinct = partitionKeys - .map(internalRowComparableWrapperFactory) - .toSet + val distinct = comparablePartitionKeys.toSet + assert(distinct.size == bucketNum) } benchmark.addCase("mergePartitions") { _ => - // just to mock the data types - val expressions = (Seq(Literal(day, IntegerType), Literal(0, IntegerType))) + val leftKeySet = mutable.HashSet.from(comparablePartitionKeys) + val rightKeySet = mutable.HashSet.from(comparablePartitionKeys) + val merged = leftKeySet.union(rightKeySet) - val leftPartitioning = KeyedPartitioning(expressions, partitionKeys) - val rightPartitioning = KeyedPartitioning(expressions, partitionKeys) - val merged = InternalRowComparableWrapper.mergePartitions( - leftPartitioning.partitionKeys, rightPartitioning.partitionKeys, expressions) assert(merged.size == bucketNum) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 16f27a10666e5..a2a85a1b50e55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{KeyedPartitioning, SinglePartition} -import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper} import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.read._ import org.apache.spark.util.ArrayImplicits._ @@ -82,11 +82,11 @@ case class BatchScanExec( "filtering") } - val inputMap = inputPartitions.groupBy( - p => k.comparableKeyWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey()) - ).view.mapValues(_.size) + val inputMap = k.partitionKeys.groupBy(identity).view.mapValues(_.size) + val comparableKeyWrapperFactory = InternalRowComparableWrapper + .getInternalRowComparableWrapperFactory(k.expressionDataTypes) val filteredMap = newPartitions.groupBy( - p => k.comparableKeyWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey()) + p => comparableKeyWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey()) ) if (!filteredMap.keySet.subsetOf(inputMap.keySet)) { @@ -95,11 +95,11 @@ case class BatchScanExec( } inputMap.toSeq - .sortBy { case (keyWrapper, _) => keyWrapper.row }(k.keyOrdering) - .flatMap { case (keyWrapper, size) => + .sortBy(_._1)(k.keyOrdering) + .flatMap { case (key, size) => // We require the new number of partitions to be equal or less than the old number of // partitions for a given key. In the case of less than, empty partitions are added. - val fps = filteredMap.getOrElse(keyWrapper, Array.empty) + val fps = filteredMap.getOrElse(key, Array.empty) if (fps.size > size) { throw new SparkException("During runtime filtering, data source must not report " + @@ -118,7 +118,7 @@ case class BatchScanExec( } else { (originalPartitioning match { case k: KeyedPartitioning => - inputPartitions.sortBy(_.asInstanceOf[HasPartitionKey].partitionKey())(k.keyOrdering) + inputPartitions.sortBy(_.asInstanceOf[HasPartitionKey].partitionKey())(k.keyRowOrdering) case _ => inputPartitions }).map(Some) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala index 58800a9188e03..d274f3fb567d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala @@ -49,21 +49,31 @@ import org.apache.spark.sql.types.DataType */ case class GroupPartitionsExec( child: SparkPlan, - joinKeyPositions: Option[Seq[Int]] = None, - commonPartitionKeys: Option[Seq[(InternalRow, Int)]] = None, - reducers: Option[Seq[Option[Reducer[_, _]]]] = None, - applyPartialClustering: Boolean = false, - replicatePartitions: Boolean = false + @transient joinKeyPositions: Option[Seq[Int]] = None, + @transient commonPartitionKeys: Option[Seq[(InternalRowComparableWrapper, Int)]] = None, + @transient reducers: Option[Seq[Option[Reducer[_, _]]]] = None, + @transient applyPartialClustering: Boolean = false, + @transient replicatePartitions: Boolean = false ) extends UnaryExecNode { override def outputPartitioning: Partitioning = { child.outputPartitioning match { case p: Partitioning with Expression => + // There can be multiple `KeyedPartitioning` in an output partitioning of a join, but they + // can only differ in `expressions`. `partitionKeys` must match so we can calculate it only + // once via `groupedPartitions`. + + val keyedPartitionings = p.collect { case k: KeyedPartitioning => k } + if (keyedPartitionings.size > 1) { + val first = keyedPartitionings.head + keyedPartitionings.tail.foreach { k => + assert(k.partitionKeys == first.partitionKeys, + "All KeyedPartitioning nodes must have identical partition keys") + } + } + p.transform { case k: KeyedPartitioning => - // There can be multiple `KeyedPartitioning` in an output partitioning of a join, but - // they can only differ in `expressions`. `partitionKeys` must match so we can calculate - // it only once via `groupedPartitions`. val projectedExpressions = joinKeyPositions.fold(k.expressions)(_.map(k.expressions)) k.copy(expressions = projectedExpressions, partitionKeys = groupedPartitions.map(_._1)) }.asInstanceOf[Partitioning] @@ -74,12 +84,9 @@ case class GroupPartitionsExec( /** * Distributes partitions based on `commonPartitionKeys` and clustering mode. */ - private def distributeByCommonKeys( - keyWrapperMap: Map[InternalRowComparableWrapper, Seq[Int]], - comparableKeyWrapperFactory: InternalRow => InternalRowComparableWrapper - ): Seq[(InternalRow, Seq[Int])] = { + private def distributeByCommonKeys(keyMap: Map[InternalRowComparableWrapper, Seq[Int]]) = { commonPartitionKeys.get.flatMap { case (key, numSplits) => - val splits = keyWrapperMap.getOrElse(comparableKeyWrapperFactory(key), Seq.empty) + val splits = keyMap.getOrElse(key, Seq.empty) if (applyPartialClustering && !replicatePartitions) { // Distribute splits across expected partitions, padding with empty sequences val paddedSplits = splits.map(Seq(_)).padTo(numSplits, Seq.empty) @@ -95,13 +102,10 @@ case class GroupPartitionsExec( * Groups and sorts partitions by their keys in ascending order. */ private def groupAndSortByKeys( - keyWrapperMap: Map[InternalRowComparableWrapper, Seq[Int]], - dataTypes: Seq[DataType] - ): Seq[(InternalRow, Seq[Int])] = { - val rowOrdering = RowOrdering.createNaturalAscendingOrdering(dataTypes) - keyWrapperMap.toSeq - .map { case (keyWrapper, indices) => (keyWrapper.row, indices) } - .sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) + keyMap: Map[InternalRowComparableWrapper, Seq[Int]], + dataTypes: Seq[DataType]) = { + val keyOrdering = RowOrdering.createNaturalAscendingOrdering(dataTypes) + keyMap.toSeq.sorted(keyOrdering.on((t: (InternalRowComparableWrapper, _)) => t._1.row)) } /** @@ -114,7 +118,7 @@ case class GroupPartitionsExec( * Returns a sequence of (partitionKey, inputPartitionIndices) pairs representing * how input partitions should be grouped together. */ - lazy val groupedPartitions = { + @transient lazy val groupedPartitions = { // There must be a `KeyedPartitioning` in child's output partitioning as a // `GroupPartitionsExec` node is added to a plan only in that case. val keyedPartitioning = child.outputPartitioning @@ -133,18 +137,12 @@ case class GroupPartitionsExec( val reducedKeys = reducers.fold(projectedKeys)( KeyedPartitioning.reduceKeys(projectedKeys, projectedDataTypes, _)) - // Create map from partition keys to their indices - val comparableKeyWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(projectedDataTypes) - - val keyWrapperToPartitionIndices = reducedKeys.zipWithIndex.groupMap { - case (key, _) => comparableKeyWrapperFactory(key) - }(_._2) + val keyToPartitionIndices = reducedKeys.zipWithIndex.groupMap(_._1)(_._2) if (commonPartitionKeys.isDefined) { - distributeByCommonKeys(keyWrapperToPartitionIndices, comparableKeyWrapperFactory) + distributeByCommonKeys(keyToPartitionIndices) } else { - groupAndSortByKeys(keyWrapperToPartitionIndices, projectedDataTypes) + groupAndSortByKeys(keyToPartitionIndices, projectedDataTypes) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 4c4a922b96012..4cfed4f629c09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -21,7 +21,6 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.internal.{LogKeys} -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -321,12 +320,11 @@ case class EnsureRequirements( satisfyingKeyedPartitioning match { case Some(k) => val attrs = k.expressions.flatMap(_.collectLeaves()).map(_.asInstanceOf[Attribute]) - val partitionOrdering: Ordering[InternalRow] = { - RowOrdering.create(distribution.ordering, attrs) - } + val keyRowOrdering = RowOrdering.create(distribution.ordering, attrs) + val keyOrdering = keyRowOrdering.on((t: InternalRowComparableWrapper) => t.row) // Sort 'commonPartitionKeys' and use this mechanism to ensure BatchScan's output // partitions are ordered - val sorted = k.partitionKeys.sorted(partitionOrdering) + val sorted = k.partitionKeys.sorted(keyOrdering) GroupPartitionsExec(plan, commonPartitionKeys = Some(sorted.map((_, 1)))) case _ => plan @@ -354,12 +352,12 @@ case class EnsureRequirements( reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, leftPartitioning, None)) - case (Some(KeyedPartitioning(clustering, _)), _) => + case (Some(KeyedPartitioning(clustering, _, _)), _) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, None, rightPartitioning)) - case (_, Some(KeyedPartitioning(clustering, _))) => + case (_, Some(KeyedPartitioning(clustering, _, _))) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys) .orElse(reorderJoinKeysRecursively( @@ -444,6 +442,8 @@ case class EnsureRequirements( val leftSpec = specs.head val rightSpec = specs(1) + val leftPartitioning = leftSpec.partitioning + val rightPartitioning = rightSpec.partitioning // We don't need to alter the existing or add new `GroupPartitionsExec` when the child // partitionings are not modified (projected) in specs and left and right side partitionings are @@ -451,8 +451,8 @@ case class EnsureRequirements( // Left and right `outputPartitioning` is a `PartitioningCollection` or a `KeyedPartitioning` // otherwise `createKeyGroupedShuffleSpec()` would have returned `None`. var isCompatible = - left.outputPartitioning.asInstanceOf[Expression].exists(_ == leftSpec.partitioning) && - right.outputPartitioning.asInstanceOf[Expression].exists(_ == rightSpec.partitioning) && + left.outputPartitioning.asInstanceOf[Expression].exists(_ == leftPartitioning) && + right.outputPartitioning.asInstanceOf[Expression].exists(_ == rightPartitioning) && leftSpec.isCompatibleWith(rightSpec) if ((!isCompatible || conf.v2BucketingPartiallyClusteredDistributionEnabled) && (conf.v2BucketingPushPartValuesEnabled || @@ -476,8 +476,8 @@ case class EnsureRequirements( // just push the common set of partition values: `[0, 1, 2, 3]` down to the two data // sources. if (isCompatible) { - val leftPartKeys = leftSpec.partitioning.partitionKeys - val rightPartKeys = rightSpec.partitioning.partitionKeys + val leftPartKeys = leftPartitioning.partitionKeys + val rightPartKeys = leftPartitioning.partitionKeys val numLeftPartKeys = MDC(LogKeys.NUM_LEFT_PARTITION_VALUES, leftPartKeys.size) val numRightPartKeys = MDC(LogKeys.NUM_RIGHT_PARTITION_VALUES, rightPartKeys.size) @@ -487,23 +487,19 @@ case class EnsureRequirements( |Right side # of partitions: $numRightPartKeys |""".stripMargin) - // As partition keys are compatible, we can pick either left or right as partition - // expressions - val partitionExprs = leftSpec.partitioning.expressions - // in case of compatible but not identical partition expressions, we apply 'reduce' // transforms to group one side's partitions as well as the common partition values val leftReducers = leftSpec.reducers(rightSpec) val leftReducedKeys = - leftReducers.fold(leftSpec.partitioning.partitionKeys)(leftSpec.partitioning.reduceKeys) + leftReducers.fold(leftPartitioning.partitionKeys)(leftPartitioning.reduceKeys) val rightReducers = rightSpec.reducers(leftSpec) - val rightReducedKeys = rightReducers.fold(rightSpec.partitioning.partitionKeys)( - rightSpec.partitioning.reduceKeys) + val rightReducedKeys = + rightReducers.fold(rightPartitioning.partitionKeys)(rightPartitioning.reduceKeys) // merge values on both sides var mergedPartitionKeys = - mergePartitions(leftReducedKeys, rightReducedKeys, partitionExprs, joinType) - .map(v => (v, 1)) + mergePartitions(leftReducedKeys, rightReducedKeys, joinType, leftPartitioning.keyOrdering) + .map((_, 1)) logInfo(log"After merging, there are " + log"${MDC(LogKeys.NUM_PARTITIONS, mergedPartitionKeys.size)} partitions") @@ -595,20 +591,15 @@ case class EnsureRequirements( // otherwise `createKeyGroupedShuffleSpec()` would have returned `None`. val originalKeyedPartitioning = originalPartitioning.collectFirst { case k: KeyedPartitioning => k }.get - val (projectedDataTypes, projectedOriginalPartitionKeys) = - partiallyClusteredSpec.joinKeyPositions.fold( - (originalKeyedPartitioning.expressionDataTypes, - originalKeyedPartitioning.partitionKeys) - )(originalKeyedPartitioning.projectKeys) - val comparableKeyWrapperFactory = InternalRowComparableWrapper - .getInternalRowComparableWrapperFactory(projectedDataTypes) - - val numExpectedPartitions = projectedOriginalPartitionKeys - .groupBy(comparableKeyWrapperFactory) - .view.mapValues(_.size) + val projectedOriginalPartitionKeys = partiallyClusteredSpec.joinKeyPositions + .fold(originalKeyedPartitioning.partitionKeys)( + originalKeyedPartitioning.projectKeys(_)._2) + + val numExpectedPartitions = + projectedOriginalPartitionKeys.groupBy(identity).view.mapValues(_.size) mergedPartitionKeys = mergedPartitionKeys.map { case (key, numParts) => - (key, numExpectedPartitions.getOrElse(comparableKeyWrapperFactory(key), numParts)) + (key, numExpectedPartitions.getOrElse(key, numParts)) } logInfo(log"After applying partially clustered distribution, there are " + @@ -671,7 +662,7 @@ case class EnsureRequirements( private def applyGroupPartitions( plan: SparkPlan, joinKeyPositions: Option[Seq[Int]], - mergedPartitionKeys: Seq[(InternalRow, Int)], + mergedPartitionKeys: Seq[(InternalRowComparableWrapper, Int)], reducers: Option[Seq[Option[Reducer[_, _]]]], applyPartialClustering: Boolean, replicatePartitions: Boolean): SparkPlan = { @@ -735,44 +726,47 @@ case class EnsureRequirements( } /** - * Merge and sort partitions values for SPJ and optionally enable partition filtering. - * Both sides must have - * matching partition expressions. - * @param leftPartitioning left side partition values - * @param rightPartitioning right side partition values - * @param partitionExpression partition expressions + * Merge and sort partitions keys for SPJ and optionally enable partition filtering. + * Both sides must have matching partition expressions. + * @param leftPartitionKeys left side partition keys + * @param rightPartitionKeys right side partition keys * @param joinType join type for optional partition filtering + * @keyOrdering ordering to sort partition keys * @return merged and sorted partition values */ - private def mergePartitions( - leftPartitioning: Seq[InternalRow], - rightPartitioning: Seq[InternalRow], - partitionExpression: Seq[Expression], - joinType: JoinType): Seq[InternalRow] = { - val internalRowComparableWrapperFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( - partitionExpression.map(_.dataType)) - + def mergePartitions( + leftPartitionKeys: Seq[InternalRowComparableWrapper], + rightPartitionKeys: Seq[InternalRowComparableWrapper], + joinType: JoinType, + keyOrdering: Ordering[InternalRowComparableWrapper]): Seq[InternalRowComparableWrapper] = { val merged = if (SQLConf.get.getConf(SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED)) { joinType match { - case Inner => InternalRowComparableWrapper.mergePartitions( - leftPartitioning, rightPartitioning, partitionExpression, intersect = true) - case LeftOuter => leftPartitioning.map(internalRowComparableWrapperFactory) - case RightOuter => rightPartitioning.map(internalRowComparableWrapperFactory) - case _ => InternalRowComparableWrapper.mergePartitions(leftPartitioning, - rightPartitioning, partitionExpression) + case Inner => mergePartitionKeys(leftPartitionKeys, rightPartitionKeys, intersect = true) + case LeftOuter => leftPartitionKeys + case RightOuter => rightPartitionKeys + case _ => mergePartitionKeys(leftPartitionKeys, rightPartitionKeys) } } else { - InternalRowComparableWrapper.mergePartitions(leftPartitioning, rightPartitioning, - partitionExpression) + mergePartitionKeys(leftPartitionKeys, rightPartitionKeys) } // SPARK-41471: We keep to order of partitions to make sure the order of // partitions is deterministic in different case. - val partitionOrdering: Ordering[InternalRow] = { - RowOrdering.createNaturalAscendingOrdering(partitionExpression.map(_.dataType)) + merged.sorted(keyOrdering) + } + + private def mergePartitionKeys( + leftPartitionKeys: Seq[InternalRowComparableWrapper], + rightPartitionKeys: Seq[InternalRowComparableWrapper], + intersect: Boolean = false) = { + val leftKeySet = mutable.HashSet.from(leftPartitionKeys) + val rightKeySet = mutable.HashSet.from(rightPartitionKeys) + val result = if (intersect) { + leftKeySet.intersect(rightKeySet) + } else { + leftKeySet.union(rightKeySet) } - merged.map(_.row).sorted(partitionOrdering) + result.toSeq } def apply(plan: SparkPlan): SparkPlan = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index eef3f87f3fe03..4ff69ae73a874 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -370,13 +370,12 @@ object ShuffleExchangeExec { ascending = true, samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition) case SinglePartition => new ConstantPartitioner - case k @ KeyedPartitioning(expressions, _) => + case k @ KeyedPartitioning(expressions, _, _) => val keyGroupedPartitioning = k.toGrouped val valueMap = keyGroupedPartitioning.partitionKeys.zipWithIndex.map { - case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index) + case (key, index) => (key.row.toSeq(expressions.map(_.dataType)), index) }.toMap - new KeyGroupedPartitioner(mutable.Map(valueMap.toSeq: _*), - keyGroupedPartitioning.numPartitions) + new KeyGroupedPartitioner(mutable.Map.from(valueMap), keyGroupedPartitioning.numPartitions) case p => throw SparkException.internalError(s"Exchange not implemented for $newPartitioning") // TODO: Handle BroadcastPartitioning. } @@ -403,7 +402,7 @@ object ShuffleExchangeExec { val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) row => projection(row) case SinglePartition => identity - case KeyedPartitioning(expressions, _) => + case KeyedPartitioning(expressions, _, _) => row => bindReferences(expressions, outputAttributes).map(_.eval(row)) case s: ShufflePartitionIdPassThrough => // For ShufflePartitionIdPassThrough, the expression directly evaluates to the partition ID diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala index 37394c3b071c1..d88a610f94b6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala @@ -51,8 +51,8 @@ abstract class DistributionAndOrderingSuiteBase plan: QueryPlan[T]): Partitioning = partitioning match { case HashPartitioning(exprs, numPartitions) => HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions) - case KeyedPartitioning(expressions, partitionKeys) => - KeyedPartitioning(expressions.map(resolveAttrs(_, plan)), partitionKeys) + case KeyedPartitioning(expressions, partitionKeys, isGrouped) => + KeyedPartitioning(expressions.map(resolveAttrs(_, plan)), partitionKeys, isGrouped) case PartitioningCollection(partitionings) => PartitioningCollection(partitionings.map(resolvePartitioning(_, plan))) case RangePartitioning(ordering, numPartitions) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 0e5581d82a099..96469e7a35e7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -1128,7 +1128,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { EnsureRequirements.apply(smjExec) match { case ShuffledHashJoinExec(_, _, _, _, _, DummySparkPlan(_, _, left: KeyedPartitioning, _, _), - ShuffleExchangeExec(KeyedPartitioning(attrs, pks), + ShuffleExchangeExec(KeyedPartitioning(attrs, pks, _), DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) => assert(left.expressions == a1 :: Nil) assert(attrs == a1 :: Nil) From fa432914e4981ea15fbe9d3463c2adcc16d05593 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 3 Mar 2026 15:09:54 +0100 Subject: [PATCH 16/29] fix dead variable --- .../spark/sql/execution/exchange/ShuffleExchangeExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 4ff69ae73a874..8849da4f70224 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -376,7 +376,7 @@ object ShuffleExchangeExec { case (key, index) => (key.row.toSeq(expressions.map(_.dataType)), index) }.toMap new KeyGroupedPartitioner(mutable.Map.from(valueMap), keyGroupedPartitioning.numPartitions) - case p => throw SparkException.internalError(s"Exchange not implemented for $newPartitioning") + case _ => throw SparkException.internalError(s"Exchange not implemented for $newPartitioning") // TODO: Handle BroadcastPartitioning. } def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match { From 1b6bb2988a4ae9fca5f3050bf999be43a7ffd0ca Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 3 Mar 2026 15:18:34 +0100 Subject: [PATCH 17/29] `GroupPartitionsExec` support columnar execution --- .../datasources/v2/GroupPartitionsExec.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala index d274f3fb567d5..0d78acde2b33b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper import org.apache.spark.sql.connector.catalog.functions.Reducer import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.ColumnarBatch /** * Physical operator that groups input partitions by their partition keys. @@ -155,6 +156,17 @@ case class GroupPartitionsExec( } } + override def supportsColumnar: Boolean = child.supportsColumnar + + override protected def doExecuteColumnar(): RDD[ColumnarBatch] = { + val partitionCoalescer = new GroupedPartitionCoalescer(groupedPartitions.map(_._2)) + if (groupedPartitions.isEmpty) { + sparkContext.emptyRDD + } else { + new CoalescedRDD(child.executeColumnar(), groupedPartitions.size, Some(partitionCoalescer)) + } + } + override def output: Seq[Attribute] = child.output override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = From 94054cb7c9029c0aa5120ffe5b7283aeebb725b2 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 3 Mar 2026 17:22:03 +0100 Subject: [PATCH 18/29] fix `KeyedPartitioning.isGrouped` calculations --- .../apache/spark/sql/catalyst/plans/physical/partitioning.scala | 2 +- .../sql/execution/datasources/v2/GroupPartitionsExec.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 475e6a52f3ce7..2b48bcdaafda8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -492,7 +492,7 @@ case class KeyedPartitioning( val projectedKeys = projectKeys(joinKeyPositions)._2 val distinctProjectedKeys = projectedKeys.distinct val projectedPartitioning = - copy(expressions = projectedExpressions, partitionKeys = distinctProjectedKeys) + KeyedPartitioning(projectedExpressions, distinctProjectedKeys, isGrouped = true) result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions)) } else { result diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala index 0d78acde2b33b..8fbdb1f9f7af9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala @@ -76,7 +76,7 @@ case class GroupPartitionsExec( p.transform { case k: KeyedPartitioning => val projectedExpressions = joinKeyPositions.fold(k.expressions)(_.map(k.expressions)) - k.copy(expressions = projectedExpressions, partitionKeys = groupedPartitions.map(_._1)) + KeyedPartitioning(projectedExpressions, groupedPartitions.map(_._1), isGrouped = true) }.asInstanceOf[Partitioning] case o => o } From b43017d1736b20f20c9d27135ec2259d1fedc2ab Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 3 Mar 2026 17:44:45 +0100 Subject: [PATCH 19/29] fix `applyGroupPartitions` documentation --- .../spark/sql/execution/exchange/EnsureRequirements.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 4cfed4f629c09..8c01e20f1f81c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -609,7 +609,7 @@ case class EnsureRequirements( } } - // Now we need to push-down the common partition information to the scan in each child + // Now we need to push-down the common partition information to the `GroupPartitionsExec`s. newLeft = applyGroupPartitions(left, leftSpec.joinKeyPositions, mergedPartitionKeys, leftReducers, applyPartialClustering, replicateLeftSide) newRight = applyGroupPartitions(right, rightSpec.joinKeyPositions, mergedPartitionKeys, @@ -657,7 +657,11 @@ case class EnsureRequirements( } /** - * Applies or updates GroupPartitionsExec with the given parameters. + * Applies or updates `GroupPartitionsExec` with the given parameters. + * + * `GroupPartitionsExec` can be either the given plan node (child of the join inserted by + * `EnsureRequirement`) if the original child didn't satisfy the distribution requirement; or we + * can create a new one specifically for this join. */ private def applyGroupPartitions( plan: SparkPlan, From 490f78279757187df8a83d392fc649801c5d02a2 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 3 Mar 2026 19:02:14 +0100 Subject: [PATCH 20/29] address review findings --- .../datasources/v2/GroupPartitionsExec.scala | 16 ++++++++-------- .../execution/exchange/EnsureRequirements.scala | 15 ++++++++++----- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala index 8fbdb1f9f7af9..cc40dfc49823e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala @@ -42,8 +42,8 @@ import org.apache.spark.sql.vectorized.ColumnarBatch * @param joinKeyPositions Optional projection to select a subset of the partitioning key * for join compatibility (e.g., when join keys are a subset of * partition keys) - * @param commonPartitionKeys Optional sequence of expected partition key values and their - * split counts, used for partially clustered data + * @param expectedPartitionKeys Optional sequence of expected partition key values and their + * split counts * @param reducers Optional reducers to apply to partition keys for grouping compatibility * @param applyPartialClustering Whether to apply partial clustering for skewed data * @param replicatePartitions Whether to replicate partitions across multiple keys @@ -51,7 +51,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch case class GroupPartitionsExec( child: SparkPlan, @transient joinKeyPositions: Option[Seq[Int]] = None, - @transient commonPartitionKeys: Option[Seq[(InternalRowComparableWrapper, Int)]] = None, + @transient expectedPartitionKeys: Option[Seq[(InternalRowComparableWrapper, Int)]] = None, @transient reducers: Option[Seq[Option[Reducer[_, _]]]] = None, @transient applyPartialClustering: Boolean = false, @transient replicatePartitions: Boolean = false @@ -83,10 +83,10 @@ case class GroupPartitionsExec( } /** - * Distributes partitions based on `commonPartitionKeys` and clustering mode. + * Aligns partitions based on `expectedPartitionKeys` and clustering mode. */ - private def distributeByCommonKeys(keyMap: Map[InternalRowComparableWrapper, Seq[Int]]) = { - commonPartitionKeys.get.flatMap { case (key, numSplits) => + private def alignToExpectedKeys(keyMap: Map[InternalRowComparableWrapper, Seq[Int]]) = { + expectedPartitionKeys.get.flatMap { case (key, numSplits) => val splits = keyMap.getOrElse(key, Seq.empty) if (applyPartialClustering && !replicatePartitions) { // Distribute splits across expected partitions, padding with empty sequences @@ -140,8 +140,8 @@ case class GroupPartitionsExec( val keyToPartitionIndices = reducedKeys.zipWithIndex.groupMap(_._1)(_._2) - if (commonPartitionKeys.isDefined) { - distributeByCommonKeys(keyToPartitionIndices) + if (expectedPartitionKeys.isDefined) { + alignToExpectedKeys(keyToPartitionIndices) } else { groupAndSortByKeys(keyToPartitionIndices, projectedDataTypes) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 8c01e20f1f81c..20f0fbbbd1553 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -322,10 +322,10 @@ case class EnsureRequirements( val attrs = k.expressions.flatMap(_.collectLeaves()).map(_.asInstanceOf[Attribute]) val keyRowOrdering = RowOrdering.create(distribution.ordering, attrs) val keyOrdering = keyRowOrdering.on((t: InternalRowComparableWrapper) => t.row) - // Sort 'commonPartitionKeys' and use this mechanism to ensure BatchScan's output + // Sort 'expectedPartitionKeys' and use this mechanism to ensure BatchScan's output // partitions are ordered val sorted = k.partitionKeys.sorted(keyOrdering) - GroupPartitionsExec(plan, commonPartitionKeys = Some(sorted.map((_, 1)))) + GroupPartitionsExec(plan, expectedPartitionKeys = Some(sorted.map((_, 1)))) case _ => plan } @@ -672,12 +672,14 @@ case class EnsureRequirements( replicatePartitions: Boolean): SparkPlan = { plan match { case g: GroupPartitionsExec => - g.copy( + val newGroupPartitions = g.copy( joinKeyPositions = joinKeyPositions, - commonPartitionKeys = Some(mergedPartitionKeys), + expectedPartitionKeys = Some(mergedPartitionKeys), reducers = reducers, applyPartialClustering = applyPartialClustering, replicatePartitions = replicatePartitions) + newGroupPartitions.copyTagsFrom(g) + newGroupPartitions case _ => GroupPartitionsExec(plan, joinKeyPositions, Some(mergedPartitionKeys), reducers, applyPartialClustering, replicatePartitions) @@ -689,7 +691,10 @@ case class EnsureRequirements( */ private def withJoinKeyPositions(plan: SparkPlan, positions: Seq[Int]): SparkPlan = { plan match { - case g: GroupPartitionsExec => g.copy(joinKeyPositions = Some(positions)) + case g: GroupPartitionsExec => + val newGroupPartitions = g.copy(joinKeyPositions = Some(positions)) + newGroupPartitions.copyTagsFrom(g) + newGroupPartitions case _ => GroupPartitionsExec(plan, joinKeyPositions = Some(positions)) } } From 2a7e0b347f58e967335fd1759cef36d110dadb41 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 3 Mar 2026 20:09:45 +0100 Subject: [PATCH 21/29] Update sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala Co-authored-by: Liang-Chi Hsieh --- .../spark/sql/execution/exchange/EnsureRequirements.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 20f0fbbbd1553..011cbaf24ab40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -477,7 +477,7 @@ case class EnsureRequirements( // sources. if (isCompatible) { val leftPartKeys = leftPartitioning.partitionKeys - val rightPartKeys = leftPartitioning.partitionKeys + val rightPartKeys = rightPartitioning.partitionKeys val numLeftPartKeys = MDC(LogKeys.NUM_LEFT_PARTITION_VALUES, leftPartKeys.size) val numRightPartKeys = MDC(LogKeys.NUM_RIGHT_PARTITION_VALUES, rightPartKeys.size) From 200fdc076e20e0ff27e7ffe3072093b32b868ae8 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 4 Mar 2026 09:20:51 +0100 Subject: [PATCH 22/29] minor test fix --- .../spark/sql/execution/exchange/EnsureRequirementsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 96469e7a35e7d..7512cbe7f90b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -1132,7 +1132,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) => assert(left.expressions == a1 :: Nil) assert(attrs == a1 :: Nil) - assert(partitionKeys == pks) + assert(partitionKeys == pks.map(_.row)) case other => fail(other.toString) } } From 4a904ad98a3870e99136cf88a797903feb6496b8 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 4 Mar 2026 10:22:27 +0100 Subject: [PATCH 23/29] empty partitioned table test --- .../v2/DataSourceV2ScanExecBase.scala | 2 +- .../KeyGroupedPartitioningSuite.scala | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala index c4a59df5e1cb9..877e65341c1c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala @@ -92,7 +92,7 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { override def outputPartitioning: physical.Partitioning = { keyGroupedPartitioning match { case Some(exprs) if conf.v2BucketingEnabled && KeyedPartitioning.supportsExpressions(exprs) && - inputPartitions.length > 0 && inputPartitions.forall(_.isInstanceOf[HasPartitionKey]) => + inputPartitions.nonEmpty && inputPartitions.forall(_.isInstanceOf[HasPartitionKey]) => val dataTypes = exprs.map(_.dataType) val rowOrdering = RowOrdering.createNaturalAscendingOrdering(dataTypes) val partitionKeys = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index b9e6fc145b8e2..701128243fc77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -3077,4 +3077,23 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { assert(groupPartitions.map(_.outputPartitioning.numPartitions) === Seq(6, 5, 5, 6)) } } + + test("SPARK-55092: Empty partitioned table") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + + val df = createJoinTestDF(Seq("id" -> "item_id")) + checkAnswer(df, Seq.empty) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.size === 2, "empty tables should not report KeyedPartitioning") + + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.isEmpty, "empty tables should not report KeyedPartitioning") + } + } } From 5a4ecd1815b2f5af5700d3a1d65107d2ac67b928 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 4 Mar 2026 11:18:46 +0100 Subject: [PATCH 24/29] additional checks for runtime filter tests --- .../datasources/v2/BatchScanExec.scala | 3 +- .../KeyGroupedPartitioningSuite.scala | 69 ++++++++++++++++++- 2 files changed, 69 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index a2a85a1b50e55..bdecea2d4d085 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -59,7 +59,8 @@ case class BatchScanExec( @transient override lazy val inputPartitions: Seq[InputPartition] = batch.planInputPartitions().toImmutableArraySeq - @transient private lazy val filteredPartitions: Seq[Option[InputPartition]] = { + // Visible for testing + @transient private[sql] lazy val filteredPartitions: Seq[Option[InputPartition]] = { val dataSourceFilters = runtimeFilters.flatMap { case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e) case _ => None diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 701128243fc77..e168eae84907b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -336,6 +336,35 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { collect(plan) { case s: BatchScanExec => s } } + /** + * Helper method to verify that filteredPartitions contains the expected number of + * Some and None values. This is used to verify that dynamic partition filtering + * properly fills filtered-out partitions with None. + */ + private def assertFilteredPartitions( + scans: Seq[BatchScanExec], + expectedTotalPartitions: Seq[Int], + expectedFilteredOutPartitions: Seq[Int]): Unit = { + assert(scans.size === expectedTotalPartitions.size, + s"Expected ${expectedTotalPartitions.size} scans but got ${scans.size}") + + scans.zip(expectedTotalPartitions).zip(expectedFilteredOutPartitions).foreach { + case ((scan, expectedTotal), expectedFiltered) => + val filtered = scan.filteredPartitions + assert(filtered.size === expectedTotal, + s"Expected $expectedTotal total partitions but got ${filtered.size}") + + val noneCount = filtered.count(_.isEmpty) + assert(noneCount === expectedFiltered, + s"Expected $expectedFiltered None values but got $noneCount") + + val someCount = filtered.count(_.isDefined) + assert(someCount === (expectedTotal - expectedFiltered), + s"Expected ${expectedTotal - expectedFiltered} Some values but got $someCount") + } + } + + test("partitioned join: exact distribution (same number of buckets) from both sides") { val customers_partitions = Array(bucket(4, "customer_id")) val orders_partitions = Array(bucket(4, "customer_id")) @@ -1212,15 +1241,40 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { // with empty partitions and the job should still succeed var df = sql(s"SELECT sum(p.price) from testcat.ns.$items i, testcat.ns.$purchases p " + "WHERE i.id = p.item_id AND i.price > 40.0") + + var shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + var scans = collectScans(df.queryExecution.executedPlan) + assert(scans.forall(_.outputPartitioning.numPartitions === 5)) + var groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === 3)) + checkAnswer(df, Seq(Row(131))) + // Verify that filteredPartitions contains None for filtered-out partitions. + // After DPF with filter i.price > 40.0, only id=1 survives on items side. + // The purchases side should be pruned to only item_id=1. + // purchases: 5 total partitions (3 for id=1, 1 for id=2, 1 for id=3) + // After DPF: 3 Some (id=1), 2 None (id=2, id=3) + assertFilteredPartitions(scans, Seq(5, 5), Seq(0, 2)) + // dynamic filtering doesn't change partitioning so storage-partitioned join should kick // in df = sql(s"SELECT sum(p.price) from testcat.ns.$items i, testcat.ns.$purchases p " + "WHERE i.id = p.item_id AND i.price >= 10.0") - val shuffles = collectShuffles(df.queryExecution.executedPlan) + + shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + scans = collectScans(df.queryExecution.executedPlan) + assert(scans.forall(_.outputPartitioning.numPartitions === 5)) + groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions === 3)) + checkAnswer(df, Seq(Row(303.5))) + + // With filter i.price >= 10.0, all ids (1, 2, 3) survive, + // so no partitions should be filtered out + assertFilteredPartitions(scans, Seq(5, 5), Seq(0, 0)) } } } @@ -1275,14 +1329,25 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { checkAnswer(df, Seq(Row(213.5))) val shuffles = collectShuffles(df.queryExecution.executedPlan) + val scans = collectScans(df.queryExecution.executedPlan) + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(scans.map(_.outputPartitioning.numPartitions) === Seq(14, 6)) if (pushDownValues) { assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") - val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) assert(groupPartitions.forall(_.outputPartitioning.numPartitions === expected)) } else { assert(shuffles.nonEmpty, "should contain shuffle when not pushing down partition values") + assert(groupPartitions.isEmpty) } + + // Verify filteredPartitions for DPF. + // After filter p.price < 45.0, purchases has item_ids {1, 2, 3, 5}. + // Items side should be pruned to these ids. Since items has {1, 2, 3, 4}, + // id=4 should be filtered out. + // purchases: 14 total, all kept (0 None) - no DPF on probe side + // items: 6 total, id=4 filtered (1 None) + assertFilteredPartitions(scans, Seq(14, 6), Seq(0, 1)) } } } From b3d34efa0560fd4617f3969d91a40bd535e7a545 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 4 Mar 2026 19:51:09 +0100 Subject: [PATCH 25/29] minor improvements --- .../sql/execution/datasources/v2/GroupPartitionsExec.scala | 4 ++-- .../spark/sql/execution/exchange/ShuffleExchangeExec.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala index cc40dfc49823e..997a723cda1c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala @@ -148,10 +148,10 @@ case class GroupPartitionsExec( } override protected def doExecute(): RDD[InternalRow] = { - val partitionCoalescer = new GroupedPartitionCoalescer(groupedPartitions.map(_._2)) if (groupedPartitions.isEmpty) { sparkContext.emptyRDD } else { + val partitionCoalescer = new GroupedPartitionCoalescer(groupedPartitions.map(_._2)) new CoalescedRDD(child.execute(), groupedPartitions.size, Some(partitionCoalescer)) } } @@ -159,10 +159,10 @@ case class GroupPartitionsExec( override def supportsColumnar: Boolean = child.supportsColumnar override protected def doExecuteColumnar(): RDD[ColumnarBatch] = { - val partitionCoalescer = new GroupedPartitionCoalescer(groupedPartitions.map(_._2)) if (groupedPartitions.isEmpty) { sparkContext.emptyRDD } else { + val partitionCoalescer = new GroupedPartitionCoalescer(groupedPartitions.map(_._2)) new CoalescedRDD(child.executeColumnar(), groupedPartitions.size, Some(partitionCoalescer)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 8849da4f70224..7dcbf3779b93d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -370,10 +370,10 @@ object ShuffleExchangeExec { ascending = true, samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition) case SinglePartition => new ConstantPartitioner - case k @ KeyedPartitioning(expressions, _, _) => + case k: KeyedPartitioning => val keyGroupedPartitioning = k.toGrouped val valueMap = keyGroupedPartitioning.partitionKeys.zipWithIndex.map { - case (key, index) => (key.row.toSeq(expressions.map(_.dataType)), index) + case (key, index) => (key.row.toSeq(keyGroupedPartitioning.expressionDataTypes), index) }.toMap new KeyGroupedPartitioner(mutable.Map.from(valueMap), keyGroupedPartitioning.numPartitions) case _ => throw SparkException.internalError(s"Exchange not implemented for $newPartitioning") From 326915bda8917f0b9f81a9fe9e1e8426edbfe95d Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 5 Mar 2026 13:51:58 +0100 Subject: [PATCH 26/29] refactor `GroupedPartitions`, document what `KeyedPartitioning.satisfies()` means, partitioning of `KeyGroupedShuffleSpec` don't need to be grouped --- .../plans/physical/partitioning.scala | 75 +++++--- .../exchange/EnsureRequirements.scala | 161 ++++++++++-------- 2 files changed, 143 insertions(+), 93 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 2b48bcdaafda8..99ef23e54c74b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -382,6 +382,27 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa * each partition having a distinct key. This occurs when: (1) a data source natively produces * unique partition keys, or (2) `GroupPartitionsExec` coalesces partitions with duplicate keys. * + * == Distribution Satisfaction and Grouping == + * The `satisfies()` method returns true if this partitioning can satisfy a distribution, + * regardless of whether the partitioning is actually grouped. The method delegates to: + * - `nonGroupedSatisfies()`: Returns true for basic distributions (UnspecifiedDistribution, + * AllTuples when single partition) + * - `groupedSatisfies()`: Returns true for distributions requiring grouped partitioning + * (ClusteredDistribution, OrderedDistribution) + * + * If `satisfies()` returns true but `isGrouped == false`, the partitioning does NOT actually + * satisfy the distribution yet. The `EnsureRequirements` rule must insert `GroupPartitionsExec` to + * coalesce duplicate partition keys before the distribution requirement is truly satisfied. + * + * For example, an ungrouped KeyedPartitioning with keys `[1, 2, 2, 3]` will return + * `satisfies(ClusteredDistribution(...)) == true` because it can satisfy the distribution after + * grouping. However, `EnsureRequirements` must add `GroupPartitionsExec` to produce grouped keys + * `[1, 2, 3]` before the distribution is actually satisfied. + * + * Similarly, for `OrderedDistribution`, even if `satisfies()` returns true, `GroupPartitionsExec` + * must be added to both group the partitions AND sort the partition keys according to the + * ordering requirement. + * * == Example == * Consider a data source with partition transform `[years(ts_col)]` and 4 input splits: * @@ -391,6 +412,7 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa * partitionKeys: [0, 1, 2, 2] // partitions 2 and 3 have the same key * numPartitions: 4 * isGrouped: false + * satisfies(ClusteredDistribution(...)) == true // CAN satisfy after grouping * }}} * * '''After GroupPartitionsExec''' (grouped): @@ -399,6 +421,7 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa * partitionKeys: [0, 1, 2] // duplicates removed, partitions coalesced * numPartitions: 3 * isGrouped: true + * satisfies(ClusteredDistribution(...)) == true // ACTUALLY satisfies now * }}} * * @param expressions Partition transform expressions (e.g., `years(col)`, `bucket(10, col)`). @@ -451,33 +474,37 @@ case class KeyedPartitioning( KeyedPartitioning.reduceKeys(partitionKeys, expressionDataTypes, reducers).distinct override def satisfies0(required: Distribution): Boolean = { - super.satisfies0(required) || isGrouped && { - required match { - case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => - if (requireAllClusterKeys) { - // Checks whether this partitioning is partitioned on exactly same clustering keys of - // `ClusteredDistribution`. - c.areAllClusterKeysMatched(expressions) + nonGroupedSatisfies(required) || groupedSatisfies(required) + } + + def nonGroupedSatisfies(required: Distribution): Boolean = super.satisfies0(required) + + def groupedSatisfies(required: Distribution): Boolean = { + required match { + case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => + if (requireAllClusterKeys) { + // Checks whether this partitioning is partitioned on exactly same clustering keys of + // `ClusteredDistribution`. + c.areAllClusterKeysMatched(expressions) + } else { + // We'll need to find leaf attributes from the partition expressions first. + val attributes = expressions.flatMap(_.collectLeaves()) + + if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { + // check that join keys (required clustering keys) + // overlap with partition keys (KeyedPartitioning attributes) + requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) && + expressions.forall(_.collectLeaves().size == 1) } else { - // We'll need to find leaf attributes from the partition expressions first. - val attributes = expressions.flatMap(_.collectLeaves()) - - if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { - // check that join keys (required clustering keys) - // overlap with partition keys (KeyedPartitioning attributes) - requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) && - expressions.forall(_.collectLeaves().size == 1) - } else { - attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) - } + attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) } + } - case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting => - o.areAllClusterKeysMatched(expressions) + case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting => + o.areAllClusterKeysMatched(expressions) - case _ => - false - } + case _ => + false } } @@ -932,8 +959,6 @@ case class KeyGroupedShuffleSpec( distribution: ClusteredDistribution, joinKeyPositions: Option[Seq[Int]] = None) extends ShuffleSpec { - assert(partitioning.isGrouped) - /** * A sequence where each element is a set of positions of the partition expression to the cluster * keys. For instance, if cluster keys are [a, b, b] and partition expressions are diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 011cbaf24ab40..6592b6ba60c2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -62,32 +62,63 @@ case class EnsureRequirements( assert(requiredChildOrderings.length == originalChildren.length) // Ensure that the operator's children satisfy their output distribution requirements. var children = originalChildren.zip(requiredChildDistributions).map { - case (child, distribution) if child.outputPartitioning.satisfies(distribution) => - distribution match { - case o: OrderedDistribution => - ensureOrdering(child, child.outputPartitioning, o) - case _ => child - } - case (c @ GroupedPartitions(p), distribution) if p.satisfies(distribution) => - distribution match { - case o: OrderedDistribution => - ensureOrdering(c, p, o) - case _ => GroupPartitionsExec(c) - } case (child, BroadcastDistribution(mode)) => BroadcastExchangeExec(mode, child) + case (child, distribution) => - val numPartitions = distribution.requiredNumPartitions - .getOrElse(conf.numShufflePartitions) - distribution match { - case _: StatefulOpClusteredDistribution => - ShuffleExchangeExec( - distribution.createPartitioning(numPartitions), child, - REQUIRED_BY_STATEFUL_OPERATOR) - - case _ => - ShuffleExchangeExec( - distribution.createPartitioning(numPartitions), child, shuffleOrigin) + // Split child's partitioning into categories + val (other, grouped, nonGrouped) = splitKeyedPartitionings(child.outputPartitioning) + + // If non-KeyedPartitioning already satisfies, no changes needed + if (other.exists(_.satisfies(distribution))) { + child + } else { + // Check KeyedPartitioning satisfaction conditions + val groupedSatisfies = grouped.exists(_.satisfies(distribution)) + val nonGroupedSatisfiesAsIs = nonGrouped.exists(_.nonGroupedSatisfies(distribution)) + val nonGroupedSatisfiesWhenGrouped = nonGrouped.exists(_.groupedSatisfies(distribution)) + + // Check if any KeyedPartitioning satisfies the distribution + if (groupedSatisfies || nonGroupedSatisfiesAsIs || nonGroupedSatisfiesWhenGrouped) { + distribution match { + case o: OrderedDistribution => + // OrderedDistribution requires grouped KeyedPartitioning with sorted keys. + // Find any KeyedPartitioning that satisfies via groupedSatisfies. + val satisfyingKeyedPartitioning = + (grouped ++ nonGrouped).find(_.groupedSatisfies(distribution)).get + val attrs = satisfyingKeyedPartitioning.expressions.flatMap(_.collectLeaves()) + .map(_.asInstanceOf[Attribute]) + val keyRowOrdering = RowOrdering.create(o.ordering, attrs) + val keyOrdering = keyRowOrdering.on((t: InternalRowComparableWrapper) => t.row) + val sorted = satisfyingKeyedPartitioning.partitionKeys.sorted(keyOrdering) + GroupPartitionsExec(child, expectedPartitionKeys = Some(sorted.map((_, 1)))) + + case _ if groupedSatisfies => + // Grouped KeyedPartitioning already satisfies + child + + case _ if nonGroupedSatisfiesAsIs => + // Non-grouped KeyedPartitioning satisfies without grouping + child + + case _ => + // Non-grouped KeyedPartitioning satisfies only after grouping + GroupPartitionsExec(child) + } + } else { + // No partitioning satisfies - need shuffle + val numPartitions = distribution.requiredNumPartitions + .getOrElse(conf.numShufflePartitions) + distribution match { + case _: StatefulOpClusteredDistribution => + ShuffleExchangeExec( + distribution.createPartitioning(numPartitions), child, + REQUIRED_BY_STATEFUL_OPERATOR) + case _ => + ShuffleExchangeExec( + distribution.createPartitioning(numPartitions), child, shuffleOrigin) + } + } } } @@ -309,31 +340,6 @@ case class EnsureRequirements( } } - private def ensureOrdering( - plan: SparkPlan, - partitioning: Partitioning, - distribution: OrderedDistribution) = { - partitioning match { - case p: Partitioning with Expression => - val satisfyingKeyedPartitioning = - p.collectFirst { case k: KeyedPartitioning if k.satisfies(distribution) => k } - satisfyingKeyedPartitioning match { - case Some(k) => - val attrs = k.expressions.flatMap(_.collectLeaves()).map(_.asInstanceOf[Attribute]) - val keyRowOrdering = RowOrdering.create(distribution.ordering, attrs) - val keyOrdering = keyRowOrdering.on((t: InternalRowComparableWrapper) => t.row) - // Sort 'expectedPartitionKeys' and use this mechanism to ensure BatchScan's output - // partitions are ordered - val sorted = k.partitionKeys.sorted(keyOrdering) - GroupPartitionsExec(plan, expectedPartitionKeys = Some(sorted.map((_, 1)))) - - case _ => plan - } - - case _ => plan - } - } - /** * Recursively reorders the join keys based on partitioning. It starts reordering the * join keys to match HashPartitioning on either side, followed by PartitioningCollection. @@ -778,6 +784,45 @@ case class EnsureRequirements( result.toSeq } + /** + * Splits a partitioning into three categories: + * 1. Non-KeyedPartitioning (HashPartitioning, RangePartitioning, etc.) + * 2. Grouped KeyedPartitioning (isGrouped = true) + * 3. Non-grouped KeyedPartitioning (isGrouped = false) + * + * @param partitioning The partitioning to split + * @return A tuple of (other, grouped, nonGrouped) where: + * - other: Option containing non-KeyedPartitioning(s) + * - grouped: Seq of grouped KeyedPartitionings + * - nonGrouped: Seq of non-grouped KeyedPartitionings + */ + private def splitKeyedPartitionings(partitioning: Partitioning) = { + val otherPartitionings = ArrayBuffer.empty[Partitioning] + val groupedKeyedPartitionings = ArrayBuffer.empty[KeyedPartitioning] + val nonGroupedKeyedPartitionings = ArrayBuffer.empty[KeyedPartitioning] + + def split(p: Partitioning): Unit = p match { + case c: PartitioningCollection => c.partitionings.foreach(split) + case k: KeyedPartitioning => + if (k.isGrouped) { + groupedKeyedPartitionings += k + } else { + nonGroupedKeyedPartitionings += k + } + case o => otherPartitionings += o + } + + split(partitioning) + + val other = otherPartitionings.length match { + case 0 => None + case 1 => Some(otherPartitionings.head) + case _ => Some(PartitioningCollection(otherPartitionings.toSeq)) + } + + (other, groupedKeyedPartitionings.toSeq, nonGroupedKeyedPartitionings.toSeq) + } + def apply(plan: SparkPlan): SparkPlan = { val newPlan = plan.transformUp { case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin, _) @@ -827,23 +872,3 @@ case class EnsureRequirements( } } } - -object GroupedPartitions { - def unapply(plan: SparkPlan): Option[Partitioning with Expression] = { - groupPartitions(plan.outputPartitioning) - } - - private def groupPartitions(p: Partitioning): Option[Partitioning with Expression] = { - p match { - case c: PartitioningCollection => - c.partitionings.flatMap(groupPartitions) match { - case Nil => None - case p :: Nil => Some(p) - case ps => Some(PartitioningCollection(ps)) - } - case k: KeyedPartitioning => Some(k.toGrouped) - case _ => None - } - } -} - From 8526dc10e24f99ef7292c707d540574bccff9626 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 5 Mar 2026 15:25:48 +0100 Subject: [PATCH 27/29] fix `KeyedPartitioning.isGrouped` when `expectedPartitionKeys` is set --- .../datasources/v2/GroupPartitionsExec.scala | 23 ++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala index 997a723cda1c8..9910c4eb788cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala @@ -76,7 +76,8 @@ case class GroupPartitionsExec( p.transform { case k: KeyedPartitioning => val projectedExpressions = joinKeyPositions.fold(k.expressions)(_.map(k.expressions)) - KeyedPartitioning(projectedExpressions, groupedPartitions.map(_._1), isGrouped = true) + KeyedPartitioning(projectedExpressions, groupedPartitions.map(_._1), + isGrouped = isGrouped) }.asInstanceOf[Partitioning] case o => o } @@ -86,7 +87,9 @@ case class GroupPartitionsExec( * Aligns partitions based on `expectedPartitionKeys` and clustering mode. */ private def alignToExpectedKeys(keyMap: Map[InternalRowComparableWrapper, Seq[Int]]) = { - expectedPartitionKeys.get.flatMap { case (key, numSplits) => + var isGrouped = true + val alignedPartitions = expectedPartitionKeys.get.flatMap { case (key, numSplits) => + if (numSplits > 1) isGrouped = false val splits = keyMap.getOrElse(key, Seq.empty) if (applyPartialClustering && !replicatePartitions) { // Distribute splits across expected partitions, padding with empty sequences @@ -97,6 +100,7 @@ case class GroupPartitionsExec( Seq.fill(numSplits)((key, splits)) } } + (alignedPartitions, isGrouped) } /** @@ -116,10 +120,12 @@ case class GroupPartitionsExec( * 3. Grouping input partition indices by their (possibly projected/reduced) keys * 4. Sorting or distributing based on whether partial clustering is enabled * - * Returns a sequence of (partitionKey, inputPartitionIndices) pairs representing - * how input partitions should be grouped together. + * Returns a tuple of (partitions, isGrouped) where: + * - partitions: sequence of (partitionKey, inputPartitionIndices) pairs representing + * how input partitions should be grouped together + * - isGrouped: whether the output partitioning is grouped (no duplicates in partition keys) */ - @transient lazy val groupedPartitions = { + @transient private lazy val groupedPartitionsTuple = { // There must be a `KeyedPartitioning` in child's output partitioning as a // `GroupPartitionsExec` node is added to a plan only in that case. val keyedPartitioning = child.outputPartitioning @@ -143,10 +149,15 @@ case class GroupPartitionsExec( if (expectedPartitionKeys.isDefined) { alignToExpectedKeys(keyToPartitionIndices) } else { - groupAndSortByKeys(keyToPartitionIndices, projectedDataTypes) + (groupAndSortByKeys(keyToPartitionIndices, projectedDataTypes), true) } } + @transient lazy val groupedPartitions: Seq[(InternalRowComparableWrapper, Seq[Int])] = + groupedPartitionsTuple._1 + + @transient lazy val isGrouped: Boolean = groupedPartitionsTuple._2 + override protected def doExecute(): RDD[InternalRow] = { if (groupedPartitions.isEmpty) { sparkContext.emptyRDD From 32b563fcd4f93652c227ce6ec9c6d31ccf9aee0b Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 5 Mar 2026 20:35:01 +0100 Subject: [PATCH 28/29] add empty groupPartitions test case, fix test spark tags, cleanup SPARK-55092 tests --- .../KeyGroupedPartitioningSuite.scala | 96 ++++++++++--------- 1 file changed, 51 insertions(+), 45 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index e168eae84907b..61384bf9f1fca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -2985,7 +2985,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } - test("SPARK-55092: Don't group partitions for join when not needed") { + test("SPARK-55092: Scans should not group partitions") { val items_partitions = Array(identity("id")) createTable(items, itemsColumns, items_partitions) @@ -2996,57 +2996,35 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val purchases_partitions = Array(years("time")) createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + "(1, 42.0, cast('2020-01-01' as timestamp)), " + "(3, 19.5, cast('2020-02-01' as timestamp))") - withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") { - val df = createJoinTestDF(Seq("id" -> "item_id"), extraColumns = Seq("year(p.time)")) - - val shuffles = collectShuffles(df.queryExecution.executedPlan) - assert(shuffles.size == 1, "only shuffle one side not report partitioning") - - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans(0).inputRDD.partitions.length === 3, - "items scan should not group") - assert(scans(1).inputRDD.partitions.length === 2, - "purchases scan should not group") - - checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0, 2020))) - } + val df = sql(s"SELECT * FROM testcat.ns.$items") + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans(0).inputRDD.partitions.length === 3, + "items scan should not group partitions") - withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "false") { - val df = createJoinTestDF(Seq("id" -> "item_id"), extraColumns = Seq("year(p.time)")) + Seq((true, 1), (false, 2)).foreach { case (bucketingShuffle, expectedShuffleCount) => + withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> bucketingShuffle.toString) { + val df = createJoinTestDF(Seq("id" -> "item_id")) - val shuffles = collectShuffles(df.queryExecution.executedPlan) - assert(shuffles.size == 2, "only shuffle one side not report partitioning") + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.size == expectedShuffleCount) - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans(0).inputRDD.partitions.length === 3, - "items scan should not group as it is shuffled") - assert(scans(1).inputRDD.partitions.length === 2, - "purchases scan should not group as it is shuffled") + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans(0).inputRDD.partitions.length === 3, + "items scan should not group partitions") + assert(scans(1).inputRDD.partitions.length === 2, + "purchases scan should not group partitions") - checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0, 2020))) + checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0))) + } } } - test("SPARK-55092: Don't group partitions for aggregate when not needed") { - val items_partitions = Array(identity("id")) - createTable(items, itemsColumns, items_partitions) - - sql(s"INSERT INTO testcat.ns.$items VALUES " + - "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + - "(4, 'bb', 10.0, cast('2021-01-01' as timestamp)), " + - "(4, 'cc', 15.5, cast('2021-02-01' as timestamp))") - - val df = sql(s"SELECT * FROM testcat.ns.$items") - val scans = collectScans(df.queryExecution.executedPlan) - assert(scans(0).inputRDD.partitions.length === 3, - "items scan should not group") - } - - test("SPARK-55092: Multi table join granular partition grouping") { + test("SPARK-55535: Multi table join granular partition grouping") { withSQLConf( SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true", @@ -3096,7 +3074,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } - test("SPARK-55092: Multi table join partial clustering") { + test("SPARK-55535: Multi table join partial clustering") { withSQLConf(SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "true") { val items_partitions = Array(identity("id")) createTable(items, itemsColumns, items_partitions) @@ -3143,7 +3121,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } - test("SPARK-55092: Empty partitioned table") { + test("SPARK-55535: Empty partitioned table") { withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { val items_partitions = Array(identity("id")) createTable(items, itemsColumns, items_partitions) @@ -3155,10 +3133,38 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { checkAnswer(df, Seq.empty) val shuffles = collectShuffles(df.queryExecution.executedPlan) - assert(shuffles.size === 2, "empty tables should not report KeyedPartitioning") + assert(shuffles.size === 2, + "both legs should be shuffled as empty tables should not report KeyedPartitioning") + + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.isEmpty, + "no legs should be grouped as empty tables should not report KeyedPartitioning") + } + } + + test("SPARK-55535: Empty group partitions due filtered partitions") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(1, 'aa', 39.0, cast('2020-01-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(2, 42.0, cast('2020-01-01' as timestamp))") + + withSQLConf(SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> "true") { + val df = createJoinTestDF(Seq("id" -> "item_id")) + checkAnswer(df, Seq.empty) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "no legs should be shuffled") val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) - assert(groupPartitions.isEmpty, "empty tables should not report KeyedPartitioning") + assert(groupPartitions.forall(_.outputPartitioning.numPartitions == 0), + "group partitions should not have any (common) partitions") } } } From 7951dc60a223fcbcb6f6e1a99330e08bde9b0d45 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 6 Mar 2026 10:32:34 +0100 Subject: [PATCH 29/29] fix BroadcastDistribution in EnsureRequirements --- .../spark/sql/execution/exchange/EnsureRequirements.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 6592b6ba60c2f..39da546256132 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -62,9 +62,6 @@ case class EnsureRequirements( assert(requiredChildOrderings.length == originalChildren.length) // Ensure that the operator's children satisfy their output distribution requirements. var children = originalChildren.zip(requiredChildDistributions).map { - case (child, BroadcastDistribution(mode)) => - BroadcastExchangeExec(mode, child) - case (child, distribution) => // Split child's partitioning into categories val (other, grouped, nonGrouped) = splitKeyedPartitionings(child.outputPartitioning) @@ -106,10 +103,12 @@ case class EnsureRequirements( GroupPartitionsExec(child) } } else { - // No partitioning satisfies - need shuffle + // No partitioning satisfies - need broadcast or shuffle val numPartitions = distribution.requiredNumPartitions .getOrElse(conf.numShufflePartitions) distribution match { + case BroadcastDistribution(mode) => + BroadcastExchangeExec(mode, child) case _: StatefulOpClusteredDistribution => ShuffleExchangeExec( distribution.createPartitioning(numPartitions), child,