Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,7 @@ case class KeyGroupedPartitioning(
if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
// check that join keys (required clustering keys)
// overlap with partition keys (KeyGroupedPartitioning attributes)
requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
expressions.forall(_.collectLeaves().size == 1)
requiredClustering.exists(x => attributes.exists(_.semanticEquals(x)))
} else {
attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
}
Expand Down Expand Up @@ -457,14 +456,7 @@ object KeyGroupedPartitioning {

def supportsExpressions(expressions: Seq[Expression]): Boolean = {
def isSupportedTransform(transform: TransformExpression): Boolean = {
transform.children.size == 1 && isReference(transform.children.head)
}

@tailrec
def isReference(e: Expression): Boolean = e match {
case _: Attribute => true
case g: GetStructField => isReference(g.child)
case _ => false
transform.children.count(isReference) == 1
}

expressions.forall {
Expand All @@ -473,6 +465,13 @@ object KeyGroupedPartitioning {
case _ => false
}
}

/**
 * Tells whether `e` is a plain column reference.
 *
 * An expression qualifies when it is an [[Attribute]], or a [[GetStructField]] whose
 * (possibly nested) child eventually resolves to an [[Attribute]].
 *
 * @param e the expression to inspect
 * @return true for an attribute or a struct-field access chain rooted at an attribute,
 *         false for every other expression
 */
@tailrec
def isReference(e: Expression): Boolean = e match {
  // Unwrap one level of struct-field access and inspect what it reads from.
  case fieldAccess: GetStructField => isReference(fieldAccess.child)
  case _: Attribute => true
  case _ => false
}
}

/**
Expand Down Expand Up @@ -792,8 +791,15 @@ case class KeyGroupedShuffleSpec(
}
partitioning.expressions.map { e =>
val leaves = e.collectLeaves()
assert(leaves.size == 1, s"Expected exactly one child from $e, but found ${leaves.size}")
distKeyToPos.getOrElse(leaves.head.canonicalized, mutable.BitSet.empty)
val attrs = leaves.filter(KeyGroupedPartitioning.isReference)
assert(leaves.size == 1 || attrs.size == 1,
s"Expected exactly one reference or child from $e, but found ${leaves.size}")
val head = if (attrs.size == 1) {
attrs.head
} else {
leaves.head
}
distKeyToPos.getOrElse(head.canonicalized, mutable.BitSet.empty)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -647,8 +647,7 @@ case class PartitionInternalRow(keys: Array[Any])
if (!other.isInstanceOf[PartitionInternalRow]) {
return false
}
// Just compare by reference, not by value
this.keys == other.asInstanceOf[PartitionInternalRow].keys
this.keys sameElements other.asInstanceOf[PartitionInternalRow].keys

@szehon-ho szehon-ho Dec 17, 2024

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fixes the 'non-clustered distribution: V2 function with multiple args' test below, which now needs to compare PartitionInternalRow instances by value rather than by reference.

}
override def hashCode: Int = {
Objects.hashCode(keys)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, TransformExpression}
import org.apache.spark.sql.catalyst.plans.physical
import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryTableCatalog}
import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning
import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryTableCatalog, PartitionInternalRow}
import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.connector.distributions.Distributions
import org.apache.spark.sql.connector.expressions._
Expand All @@ -37,6 +38,7 @@ import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
private val functions = Seq(
Expand Down Expand Up @@ -195,10 +197,16 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
s"(2, 'ccc', CAST('2020-01-01' AS timestamp))")

val df = sql(s"SELECT * FROM testcat.ns.$table")
val distribution = physical.ClusteredDistribution(
Seq(TransformExpression(TruncateFunction, Seq(attr("data"), Literal(2)))))
val transformExpression = Seq(TransformExpression(
TruncateFunction, Seq(attr("data"), Literal(2))))
val distribution = physical.ClusteredDistribution(transformExpression)
val partValues = Seq(
PartitionInternalRow(Array(UTF8String.fromString("aa"))),
PartitionInternalRow(Array(UTF8String.fromString("bb"))),
PartitionInternalRow(Array(UTF8String.fromString("cc"))))
val partitioning = new KeyGroupedPartitioning(transformExpression, 3, partValues, partValues)

checkQueryPlan(df, distribution, physical.UnknownPartitioning(0))
checkQueryPlan(df, distribution, partitioning)
}

/**
Expand Down Expand Up @@ -2504,4 +2512,45 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
assert(scans.forall(_.inputRDD.partitions.length == 2))
}
}

test("SPARK-50593: Support truncate transform") {
  // Both tables use the same partition spec: truncate(data, 2).
  val truncateSpec: Array[Transform] = Array(
    Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(2)))

  // First table: three rows, each with a distinct `data` value.
  val leftRows = Seq(
    "(0, 'aaa', CAST('2022-01-01' AS timestamp))",
    "(1, 'bbb', CAST('2021-01-01' AS timestamp))",
    "(2, 'ccc', CAST('2020-01-01' AS timestamp))")
  createTable("table", columns, truncateSpec)
  sql("INSERT INTO testcat.ns.table VALUES " + leftRows.mkString(", "))

  // Second table: four rows, two of which share the `data` value 'bbb'.
  val rightRows = Seq(
    "(1, 5, 'aaa')",
    "(5, 10, 'bbb')",
    "(20, 40, 'bbb')",
    "(40, 80, 'ddd')")
  createTable("table2", columns2, truncateSpec)
  sql("INSERT INTO testcat.ns.table2 VALUES " + rightRows.mkString(","))

  withSQLConf(
    SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
    SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
    SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {

    // Join on the raw `data` column, i.e. the argument of the truncate transform.
    val joinQuery = selectWithMergeJoinHint("table", "table2") +
      "id, store_id, dept_id " +
      "FROM testcat.ns.table JOIN testcat.ns.table2 " +
      "ON table.data = table2.data " +
      "SORT BY id, store_id, dept_id"
    val joined = sql(joinQuery)

    // Neither side of the join should need a shuffle.
    val shuffles = collectShuffles(joined.queryExecution.executedPlan)
    assert(shuffles.isEmpty, "should not add shuffle for both sides of the join")

    // Only 'aaa' and 'bbb' appear in both tables; 'bbb' matches two rows of table2.
    val expectedRows = Seq(Row(0, 1, 5), Row(1, 5, 10), Row(1, 20, 40))
    checkAnswer(joined, expectedRows)

    // Every scan is expected to expose 4 input partitions
    // (presumably one per distinct truncated value across both tables: aa, bb, cc, dd).
    val scans = collectScans(joined.queryExecution.executedPlan)
    assert(scans.forall(scan => scan.inputRDD.partitions.length == 4))
  }
}
}