diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java index 844734ff7ccb7..26a21cb50f92d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java @@ -105,4 +105,48 @@ default String description() { default NamedReference[] requiredMetadataAttributes() { return new NamedReference[0]; } + + /** + * Controls whether to send only the required data columns to the connector rather than the + * full row. + *
+ * When true, Spark narrows the data column schema ({@link LogicalWriteInfo#schema()}) to only + * the columns declared via {@link #requiredDataAttributes()}. Metadata columns (from + * {@link #requiredMetadataAttributes()}) and row ID columns (from + * {@link SupportsDelta#rowId()}) are unaffected and always projected separately. + *
+ * If {@link #requiredDataAttributes()} returns a non-empty array, the write schema is exactly + * those columns in declared order. The connector must include all columns it wants to receive, + * including the columns being updated. If {@link #requiredDataAttributes()} returns an empty + * array, Spark sends only the non-identity assigned columns (heuristic path). + *
+ * Currently only consulted for UPDATE operations. + * + * @since 4.3.0 + */ + default boolean supportsColumnUpdates() { + return false; + } + + /** + * Returns data column references required to perform this row-level operation. + *
+ * This method is only consulted by Spark when {@link #supportsColumnUpdates()} returns + * {@code true}. If {@code supportsColumnUpdates()} returns {@code false}, the returned array + * is ignored and the full table row is sent (the default behavior). + *
+ * When non-empty, the returned columns become the write schema in declared order. + * The connector must declare all columns it wants to receive, including the columns being + * updated. Use {@link RowLevelOperationInfo#updatedColumns()} to learn which columns are being + * assigned, then add any extra columns needed for row lookup or routing (e.g., primary key). + *
+ * When empty (the default), Spark falls back to sending only the non-identity assigned columns. + *
+ * Currently only consulted for UPDATE operations. + * + * @since 4.3.0 + */ + default NamedReference[] requiredDataAttributes() { + return new NamedReference[0]; + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperationInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperationInfo.java index e3d7397aed91b..e0211661170cf 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperationInfo.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperationInfo.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.connector.write; import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.connector.expressions.NamedReference; import org.apache.spark.sql.connector.write.RowLevelOperation.Command; import org.apache.spark.sql.util.CaseInsensitiveStringMap; @@ -37,4 +38,20 @@ public interface RowLevelOperationInfo { * Returns the row-level SQL command (e.g. DELETE, UPDATE, MERGE). */ Command command(); + + /** + * Returns the columns being updated in an UPDATE statement, as non-identity assignments. + * + *
For DELETE and MERGE, returns an empty array. + * + *
Connectors can use this to decide what {@link RowLevelOperation#requiredDataAttributes()}
+ * to declare. For instance, a connector that needs its primary key for row lookup can check
+ * whether pk is already in the updated columns list and, if not, add it to
+ * requiredDataAttributes().
+ *
+ * @since 4.3.0
+ */
+ default NamedReference[] updatedColumns() {
+ return new NamedReference[0];
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala
index 48c48eb323bd7..4c08e6fc3856c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala
@@ -22,12 +22,13 @@ import scala.collection.mutable
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.ProjectingInternalRow
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Expression, ExprId, If, Literal, MetadataAttribute, NamedExpression, V2ExpressionUtils}
+import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, LogicalPlan, MergeRows, Project, Union}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.{ReplaceDataProjections, WriteDeltaProjections}
import org.apache.spark.sql.catalyst.util.RowDeltaUtils._
import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
-import org.apache.spark.sql.connector.expressions.FieldReference
+import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}
import org.apache.spark.sql.connector.write.{RowLevelOperation, RowLevelOperationInfoImpl, RowLevelOperationTable, SupportsDelta}
import org.apache.spark.sql.connector.write.RowLevelOperation.Command
import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -50,20 +51,36 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
protected def buildOperationTable(
table: SupportsRowLevelOperations,
command: Command,
- options: CaseInsensitiveStringMap): RowLevelOperationTable = {
- val info = RowLevelOperationInfoImpl(command, options)
+ options: CaseInsensitiveStringMap,
+ updatedColumns: Seq[NamedReference] = Nil): RowLevelOperationTable = {
+ val info = RowLevelOperationInfoImpl(command, options, updatedColumns)
val operation = table.newRowLevelOperationBuilder(info).build()
RowLevelOperationTable(table, operation)
}
+ /**
+ * Builds a DataSourceV2Relation for a row-level operation, optionally narrowing its output.
+ *
+ * When dataAttrs is non-empty, the relation output is narrowed to include only columns
+ * required for a column-update write. When dataAttrs is empty, the full relation.output is
+ * preserved.
+ */
protected def buildRelationWithAttrs(
relation: DataSourceV2Relation,
table: RowLevelOperationTable,
metadataAttrs: Seq[AttributeReference],
- rowIdAttrs: Seq[AttributeReference] = Nil): DataSourceV2Relation = {
-
- val attrs = dedupAttrs(relation.output ++ rowIdAttrs ++ metadataAttrs)
- relation.copy(table = table, output = attrs)
+ rowIdAttrs: Seq[AttributeReference] = Nil,
+ dataAttrs: Seq[AttributeReference] = Nil,
+ cond: Expression = TrueLiteral): DataSourceV2Relation = {
+
+ if (dataAttrs.nonEmpty) {
+ val required = (dataAttrs ++ cond.references.toSeq).map(_.exprId).toSet
+ val narrowOutput = relation.output.filter(a => required.contains(a.exprId))
+ relation.copy(table = table, output = dedupAttrs(narrowOutput ++ rowIdAttrs ++ metadataAttrs))
+ } else {
+ val attrs = dedupAttrs(relation.output ++ rowIdAttrs ++ metadataAttrs)
+ relation.copy(table = table, output = attrs)
+ }
}
protected def dedupAttrs(attrs: Seq[AttributeReference]): Seq[AttributeReference] = {
@@ -87,6 +104,14 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
relation)
}
+ protected def resolveRequiredDataAttrs(
+ relation: DataSourceV2Relation,
+ operation: RowLevelOperation): Seq[AttributeReference] = {
+ V2ExpressionUtils.resolveRefs[AttributeReference](
+ operation.requiredDataAttributes.toImmutableArraySeq,
+ relation)
+ }
+
protected def resolveRowIdAttrs(
relation: DataSourceV2Relation,
operation: SupportsDelta): Seq[AttributeReference] = {
@@ -211,11 +236,13 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
metadataAttrs: Seq[Attribute]): WriteDeltaProjections = {
val outputs = extractOutputs(plan)
+ // Always produce Some(rowProjection) even for empty rowAttrs (identity-only column updates).
+ // Physical execution calls rowProjection.project(row) unconditionally; None causes NPE.
val rowProjection = if (rowAttrs.nonEmpty) {
val outputsWithRow = filterOutputs(outputs, OPERATIONS_WITH_ROW)
Some(newLazyProjection(plan, outputsWithRow, rowAttrs))
} else {
- None
+ Some(ProjectingInternalRow(StructType(Nil), Nil))
}
val outputsWithRowId = filterOutputs(outputs, OPERATIONS_WITH_ROW_ID)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala
index f235374bd5d6f..8e91316281c3f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala
@@ -17,11 +17,12 @@
package org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, EqualNullSafe, Expression, If, Literal, MetadataAttribute, Not, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference, AttributeSet, EqualNullSafe, Expression, If, Literal, MetadataAttribute, Not, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, Filter, LogicalPlan, Project, ReplaceData, Union, UpdateTable, WriteDelta}
import org.apache.spark.sql.catalyst.util.RowDeltaUtils._
import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
+import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.write.{RowLevelOperationTable, SupportsDelta}
import org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, ExtractV2Table}
@@ -41,7 +42,12 @@ object RewriteUpdateTable extends RewriteRowLevelCommand {
EliminateSubqueryAliases(aliasedTable) match {
case r @ ExtractV2Table(tbl: SupportsRowLevelOperations) =>
- val table = buildOperationTable(tbl, UPDATE, CaseInsensitiveStringMap.empty())
+ val updatedCols = assignments.collect {
+ case Assignment(key: AttributeReference, value) if !isIdentityAssignment(key, value) =>
+ FieldReference(key.name)
+ }
+ val table = buildOperationTable(tbl, UPDATE, CaseInsensitiveStringMap.empty(),
+ updatedCols)
val updateCond = cond.getOrElse(TrueLiteral)
table.operation match {
case _: SupportsDelta =>
@@ -65,20 +71,18 @@ object RewriteUpdateTable extends RewriteRowLevelCommand {
assignments: Seq[Assignment],
cond: Expression): ReplaceData = {
- // resolve all required metadata attrs that may be used for grouping data on write
- val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation)
-
- // construct a read relation and include all required metadata columns
- val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs)
+ val (readRelation, rowAttrs, metadataAttrs) =
+ buildReplaceDataReadRelation(relation, operationTable, assignments, cond)
- // build a plan with updated and copied over records
- val query = buildReplaceDataUpdateProjection(readRelation, assignments, cond)
+ val updatedAndRemainingRowsPlan = buildReplaceDataUpdateProjection(
+ readRelation, assignments, cond)
- // build a plan to replace read groups in the table
val writeRelation = relation.copy(table = operationTable)
- val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs)
+ val projections = buildReplaceDataProjections(updatedAndRemainingRowsPlan, rowAttrs,
+ metadataAttrs)
val groupFilterCond = if (groupFilterEnabled) Some(cond) else None
- ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond)
+ ReplaceData(writeRelation, cond, updatedAndRemainingRowsPlan, relation, projections,
+ groupFilterCond)
}
// build a rewrite plan for sources that support replacing groups of data (e.g. files, partitions)
@@ -89,13 +93,8 @@ object RewriteUpdateTable extends RewriteRowLevelCommand {
assignments: Seq[Assignment],
cond: Expression): ReplaceData = {
- // resolve all required metadata attrs that may be used for grouping data on write
- val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation)
-
- // construct a read relation and include all required metadata columns
- // the same read relation will be used to read records that must be updated and copied over
- // the analyzer will take care of duplicated attr IDs
- val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs)
+ val (readRelation, rowAttrs, metadataAttrs) =
+ buildReplaceDataReadRelation(relation, operationTable, assignments, cond)
// build a plan for updated records that match the condition
val matchedRowsPlan = Filter(cond, readRelation)
@@ -106,38 +105,78 @@ object RewriteUpdateTable extends RewriteRowLevelCommand {
val remainingRowsPlan = addOperationColumn(COPY_OPERATION,
Filter(remainingRowFilter, readRelation))
- // the new state is a union of updated and copied over records
- val query = Union(updatedRowsPlan, remainingRowsPlan)
+ val updatedAndRemainingRowsPlan = Union(updatedRowsPlan, remainingRowsPlan)
- // build a plan to replace read groups in the table
val writeRelation = relation.copy(table = operationTable)
- val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs)
+ val projections = buildReplaceDataProjections(updatedAndRemainingRowsPlan, rowAttrs,
+ metadataAttrs)
val groupFilterCond = if (groupFilterEnabled) Some(cond) else None
- ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond)
+ ReplaceData(writeRelation, cond, updatedAndRemainingRowsPlan, relation, projections,
+ groupFilterCond)
}
- // this method assumes the assignments have been already aligned before
+ /**
+ * When the connector supports column updates and declares required data attributes,
+ * the read relation is narrowed at analysis time so that GroupBasedRowLevelOperationScanPlanning
+ * uses only the needed columns for the scan. Otherwise, the full relation output is used.
+ */
+ private def buildReplaceDataReadRelation(
+ relation: DataSourceV2Relation,
+ operationTable: RowLevelOperationTable,
+ assignments: Seq[Assignment],
+ cond: Expression): (DataSourceV2Relation, Seq[Attribute], Seq[AttributeReference]) = {
+
+ val operation = operationTable.operation
+ val metadataAttrs = resolveRequiredMetadataAttrs(relation, operation)
+ val connectorDataAttrs = resolveRequiredDataAttrs(relation, operation)
+ val isNarrow = operation.supportsColumnUpdates() && connectorDataAttrs.nonEmpty
+
+ val readRelation = if (isNarrow) {
+ val allRequired = (connectorDataAttrs ++ computeAssignedAttrs(assignments)).distinct
+ buildRelationWithAttrs(relation, operationTable, metadataAttrs, dataAttrs = allRequired,
+ cond = cond)
+ } else {
+ buildRelationWithAttrs(relation, operationTable, metadataAttrs)
+ }
+
+ val rowAttrs: Seq[Attribute] = if (isNarrow) connectorDataAttrs else relation.output
+
+ (readRelation, rowAttrs, metadataAttrs)
+ }
+
+ /**
+ * Builds the update projection for ReplaceData plans. Assumes assignments are already aligned.
+ *
+ * plan.output may be narrowed by buildReplaceDataReadRelation, so only columns present in the
+ * plan are projected.
+ */
private def buildReplaceDataUpdateProjection(
plan: LogicalPlan,
assignments: Seq[Assignment],
cond: Expression = TrueLiteral): LogicalPlan = {
- // the plan output may include metadata columns at the end
- // that's why the number of assignments may not match the number of plan output columns
- val assignedValues = assignments.map(_.value)
- val updatedValues = plan.output.zipWithIndex.map { case (attr, index) =>
- if (index < assignments.size) {
- val assignedExpr = assignedValues(index)
- val updatedValue = If(cond, assignedExpr, attr)
- Alias(updatedValue, attr.name)()
- } else {
- assert(MetadataAttribute.isValid(attr.metadata))
+ // plan.output may be narrowed (fewer columns than assignments) or reordered,
+ // so assignments are matched by exprId.
+ val assignmentMap = AttributeMap(assignments.collect {
+ case Assignment(key: Attribute, value) => key -> value
+ })
+
+ val updatedValues = plan.output.map { attr =>
+ if (MetadataAttribute.isValid(attr.metadata)) {
if (MetadataAttribute.isPreservedOnUpdate(attr)) {
attr
} else {
val updatedValue = If(cond, Literal(null, attr.dataType), attr)
Alias(updatedValue, attr.name)(explicitMetadata = Some(attr.metadata))
}
+ } else {
+ assignmentMap.get(attr) match {
+ case Some(assignedExpr) =>
+ Alias(If(cond, assignedExpr, attr), attr.name)()
+ case None =>
+ // Column in relation.output with no matching assignment; pass through unchanged.
+ attr
+ }
}
}
@@ -154,30 +193,114 @@ object RewriteUpdateTable extends RewriteRowLevelCommand {
cond: Expression): WriteDelta = {
val operation = operationTable.operation.asInstanceOf[SupportsDelta]
+ val supportsColumnUpdate = operation.supportsColumnUpdates()
- // resolve all needed attrs (e.g. row ID and any required metadata attrs)
- val rowAttrs = relation.output
val rowIdAttrs = resolveRowIdAttrs(relation, operation)
val metadataAttrs = resolveRequiredMetadataAttrs(relation, operation)
- // construct a read relation and include all required metadata columns
+ val connectorDataAttrs = if (supportsColumnUpdate) {
+ resolveRequiredDataAttrs(relation, operation)
+ } else Nil
+
val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs, rowIdAttrs)
+ // Connector-declared attrs not being assigned are passed through so ColumnPruning
+ // keeps them in the scan and the connector receives their current values.
+ val assignedAttrs = if (supportsColumnUpdate) computeAssignedAttrs(assignments)
+ else relation.output
+ val connectorExtraAttrs: Seq[AttributeReference] = if (connectorDataAttrs.nonEmpty) {
+ val assignedAttrSet = AttributeSet(assignedAttrs)
+ connectorDataAttrs.filterNot(assignedAttrSet.contains)
+ } else Nil
+
// build a plan for updated records that match the condition
val matchedRowsPlan = Filter(cond, readRelation)
val rowDeltaPlan = if (operation.representUpdateAsDeleteAndInsert) {
buildDeletesAndInserts(matchedRowsPlan, assignments, rowIdAttrs)
+ } else if (supportsColumnUpdate) {
+ buildColumnUpdateProjection(
+ matchedRowsPlan, assignments, rowIdAttrs, metadataAttrs, connectorExtraAttrs)
} else {
buildWriteDeltaUpdateProjection(matchedRowsPlan, assignments, rowIdAttrs)
}
+ val effectiveRowAttrs = if (supportsColumnUpdate && connectorDataAttrs.nonEmpty) {
+ connectorDataAttrs
+ } else if (supportsColumnUpdate) {
+ assignedAttrs
+ } else {
+ relation.output
+ }
+
// build a plan to write the row delta to the table
val writeRelation = relation.copy(table = operationTable)
- val projections = buildWriteDeltaProjections(rowDeltaPlan, rowAttrs, rowIdAttrs, metadataAttrs)
+ val projections = buildWriteDeltaProjections(
+ rowDeltaPlan, effectiveRowAttrs, rowIdAttrs, metadataAttrs)
val groupFilterCond = if (groupFilterEnabled) Some(cond) else None
WriteDelta(writeRelation, cond, rowDeltaPlan, relation, projections, groupFilterCond)
}
+ /**
+ * Builds the WriteDelta projection for the column update path. The resulting Project
+ * references only the columns needed for the write, so ColumnPruning narrows the scan
+ * to match.
+ */
+ private def buildColumnUpdateProjection(
+ plan: LogicalPlan,
+ assignments: Seq[Assignment],
+ rowIdAttrs: Seq[Attribute],
+ metadataAttrs: Seq[Attribute],
+ connectorExtraAttrs: Seq[AttributeReference] = Nil): LogicalPlan = {
+
+ val assignedValues = assignments.collect {
+ case Assignment(key: Attribute, value) if !isIdentityAssignment(key, value) =>
+ Alias(value, key.name)()
+ }
+
+ val connectorExtraAttrSet = AttributeSet(connectorExtraAttrs)
+ val connectorPassThroughValues = plan.output.filter { a =>
+ connectorExtraAttrSet.contains(a) && !MetadataAttribute.isValid(a.metadata)
+ }
+
+ val metadataAttrSet = AttributeSet(metadataAttrs)
+ val metadataValues = plan.output.filter(metadataAttrSet.contains).map { attr =>
+ if (MetadataAttribute.isPreservedOnUpdate(attr)) {
+ attr
+ } else {
+ Alias(Literal(null, attr.dataType), attr.name)(explicitMetadata = Some(attr.metadata))
+ }
+ }
+
+ val rowIdAttrSet = AttributeSet(rowIdAttrs)
+ val rowIdValues = plan.output.filter(rowIdAttrSet.contains)
+
+ val originalRowIdValues = buildOriginalRowIdValues(rowIdAttrs, assignments)
+ val operationType = Alias(Literal(UPDATE_OPERATION), OPERATION_COLUMN)()
+
+ Project(
+ Seq(operationType) ++ assignedValues ++ connectorPassThroughValues ++
+ metadataValues ++ rowIdValues ++ originalRowIdValues,
+ plan)
+ }
+
+ // Returns the table attributes that are genuinely updated (non-identity) in this UPDATE.
+ private def computeAssignedAttrs(assignments: Seq[Assignment]): Seq[AttributeReference] = {
+ assignments.collect {
+ case Assignment(key: AttributeReference, value) if !isIdentityAssignment(key, value) => key
+ }
+ }
+
+ private def isIdentityAssignment(key: Attribute, value: Expression): Boolean = {
+ val unwrapped = value match {
+ case Alias(child, _) => child
+ case other => other
+ }
+ unwrapped match {
+ case attr: Attribute => AttributeSet(Seq(key)).contains(attr)
+ case _ => false
+ }
+ }
+
// this method assumes the assignments have been already aligned before
private def buildWriteDeltaUpdateProjection(
plan: LogicalPlan,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 40cf5009b97dc..d9a9e586f32c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -364,6 +364,14 @@ trait RowLevelWrite extends V2WriteCommand with SupportsSubquery {
operation.requiredMetadataAttributes.toImmutableArraySeq,
originalTable)
}
+
+ // Resolves the connector-declared data attributes against the original table.
+ // Symmetric with projectedMetadataAttrs; used in narrow-schema validation.
+ protected def projectedDataAttrs: Seq[Attribute] = {
+ V2ExpressionUtils.resolveRefs[AttributeReference](
+ operation.requiredDataAttributes.toImmutableArraySeq,
+ originalTable)
+ }
}
/**
@@ -417,7 +425,35 @@ case class ReplaceData(
// validates row projection output is compatible with table attributes
private def rowAttrsResolved: Boolean = {
val inRowAttrs = DataTypeUtils.toAttributes(projections.rowProjection.schema)
- table.skipSchemaResolution || areCompatible(inRowAttrs, table.output)
+ table.skipSchemaResolution || areCompatible(inRowAttrs, table.output) ||
+ dataAttrsResolved(inRowAttrs)
+ }
+
+ /**
+ * Validates the narrow-write-schema row projection output.
+ *
+ * When the connector declares specific data attributes via requiredDataAttributes(), the
+ * write schema must exactly match projectedDataAttrs (same columns, same order).
+ *
+ * When requiredDataAttributes() is empty, the write schema contains only
+ * the assigned columns. We validate each one exists in the table with a compatible type.
+ */
+ private def dataAttrsResolved(inRowAttrs: Seq[Attribute]): Boolean = {
+ if (!operation.supportsColumnUpdates()) { return false }
+ val outDataAttrs = projectedDataAttrs
+ if (outDataAttrs.nonEmpty) {
+ areCompatible(inRowAttrs, outDataAttrs)
+ } else {
+ inRowAttrs.forall { inAttr =>
+ table.output.exists { outAttr =>
+ val inType = CharVarcharUtils.getRawType(inAttr.metadata).getOrElse(inAttr.dataType)
+ val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType)
+ inAttr.name == outAttr.name &&
+ DataType.equalsIgnoreCompatibleNullability(inType, outType) &&
+ (outAttr.nullable || !inAttr.nullable)
+ }
+ }
+ }
}
// validates metadata projection output is compatible with metadata attributes
@@ -509,7 +545,35 @@ case class WriteDelta(
case Some(projection) => DataTypeUtils.toAttributes(projection.schema)
case None => Nil
}
- table.skipSchemaResolution || areCompatible(inRowAttrs, outRowAttrs)
+ table.skipSchemaResolution || areCompatible(inRowAttrs, outRowAttrs) ||
+ dataAttrsResolved(inRowAttrs)
+ }
+
+ /**
+ * Validates the narrow-write-schema row projection output.
+ *
+ * When the connector declares specific data attributes via requiredDataAttributes(), the
+ * write schema must exactly match projectedDataAttrs (same columns, same order).
+ *
+ * When requiredDataAttributes() is empty, the write schema contains only
+ * the assigned columns. We validate each one exists in the table with a compatible type.
+ */
+ private def dataAttrsResolved(inRowAttrs: Seq[Attribute]): Boolean = {
+ if (!operation.supportsColumnUpdates()) { return false }
+ val outDataAttrs = projectedDataAttrs
+ if (outDataAttrs.nonEmpty) {
+ areCompatible(inRowAttrs, outDataAttrs)
+ } else {
+ inRowAttrs.forall { inAttr =>
+ table.output.exists { outAttr =>
+ val inType = CharVarcharUtils.getRawType(inAttr.metadata).getOrElse(inAttr.dataType)
+ val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType)
+ inAttr.name == outAttr.name &&
+ DataType.equalsIgnoreCompatibleNullability(inType, outType) &&
+ (outAttr.nullable || !inAttr.nullable)
+ }
+ }
+ }
}
// validates row ID projection output is compatible with row ID attributes
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationInfoImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationInfoImpl.scala
index 9d499cdef361b..a84e0230cd8d7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationInfoImpl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationInfoImpl.scala
@@ -17,9 +17,15 @@
package org.apache.spark.sql.connector.write
+import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.connector.write.RowLevelOperation.Command
import org.apache.spark.sql.util.CaseInsensitiveStringMap
private[sql] case class RowLevelOperationInfoImpl(
command: Command,
- options: CaseInsensitiveStringMap) extends RowLevelOperationInfo
+ options: CaseInsensitiveStringMap,
+ private val updatedCols: Seq[NamedReference] = Nil)
+ extends RowLevelOperationInfo {
+
+ override def updatedColumns(): Array[NamedReference] = updatedCols.toArray
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
index 5c0bc0b143f3d..69b2be68e7e10 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
@@ -70,18 +70,40 @@ class InMemoryRowLevelOperationTable private (
private final val SPLIT_UPDATES = "split-updates"
private final val NO_METADATA = "no-metadata"
private final val noMetadata = properties.getOrDefault(NO_METADATA, "false") == "true"
+ private final val COLUMN_UPDATE = "column-update"
+ private final val COLUMN_UPDATE_REQ_ATTRS = "column-update-req-attrs"
+ private final val COLUMN_UPDATE_COW = "column-update-cow"
+ private final val COLUMN_UPDATE_FROM_INFO = "column-update-from-info"
+ private final val COLUMN_UPDATE_SPLIT = "column-update-split"
// used in row-level operation tests to verify replaced partitions
var replacedPartitions: Seq[Seq[Any]] = Seq.empty
// used in row-level operation tests to verify reported write schema
var lastWriteInfo: LogicalWriteInfo = _
+ // used in column-update tests to verify the scan projection was narrowed correctly
+ var lastScanSchema: StructType = _
+ // used in column-update tests to verify that Spark passed the correct updated column list
+ // to the connector via RowLevelOperationInfo.updatedColumns()
+ var lastUpdatedColumns: Array[NamedReference] = Array.empty
// used in row-level operation tests to verify passed records
// (operation, id, metadata, row)
var lastWriteLog: Seq[InternalRow] = Seq.empty
override def newRowLevelOperationBuilder(
info: RowLevelOperationInfo): RowLevelOperationBuilder = {
- if (properties.getOrDefault(SUPPORTS_DELTAS, "false") == "true") {
+ lastUpdatedColumns = info.updatedColumns()
+ if (properties.getOrDefault(COLUMN_UPDATE, "false") == "true") {
+ () => new DeltaBasedColumnUpdateOperation(info.command)
+ } else if (properties.containsKey(COLUMN_UPDATE_REQ_ATTRS)) {
+ val reqCols = properties.get(COLUMN_UPDATE_REQ_ATTRS).split(",").map(_.trim)
+ () => new DeltaBasedColumnUpdateOperationWithReqAttrs(info.command, reqCols)
+ } else if (properties.getOrDefault(COLUMN_UPDATE_FROM_INFO, "false") == "true") {
+ () => new DeltaBasedColumnUpdateOperationFromInfo(info.command, info.updatedColumns().toSeq)
+ } else if (properties.getOrDefault(COLUMN_UPDATE_COW, "false") == "true") {
+ () => new PartitionBasedColumnUpdateOperation(info.command, info.updatedColumns().toSeq)
+ } else if (properties.getOrDefault(COLUMN_UPDATE_SPLIT, "false") == "true") {
+ () => new DeltaBasedColumnUpdateSplitOperation(info.command, info.updatedColumns().toSeq)
+ } else if (properties.getOrDefault(SUPPORTS_DELTAS, "false") == "true") {
() => DeltaBasedOperation(info.command)
} else {
() => PartitionBasedOperation(info.command)
@@ -208,6 +230,302 @@ class InMemoryRowLevelOperationTable private (
}
}
+ // A delta-based operation that supports column-level updates: Spark sends only the
+ // assigned/changed columns in the row projection instead of the full row schema.
+ class DeltaBasedColumnUpdateOperation(command: Command)
+ extends DeltaBasedOperation(command) {
+ override def representUpdateAsDeleteAndInsert(): Boolean = false
+ override def supportsColumnUpdates(): Boolean = true
+
+ // Override newScanBuilder to record the schema that Spark actually requests from the
+ // connector after column pruning, so tests can assert on scan narrowing.
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+ new InMemoryScanBuilder(schema, options) {
+ override def build(): Scan = {
+ val scan = super.build()
+ lastScanSchema = scan.readSchema()
+ scan
+ }
+ }
+ }
+
+ override def newWriteBuilder(info: LogicalWriteInfo): DeltaWriteBuilder = {
+ lastWriteInfo = info
+ new DeltaWriteBuilder {
+ override def build(): DeltaWrite =
+ new DeltaWrite with RequiresDistributionAndOrdering {
+
+ override def requiredDistribution(): Distribution = {
+ Distributions.clustered(Array(PARTITION_COLUMN_REF))
+ }
+
+ override def requiredOrdering(): Array[SortOrder] = {
+ Array[SortOrder](
+ LogicalExpressions.sort(
+ PARTITION_COLUMN_REF,
+ SortDirection.ASCENDING,
+ SortDirection.ASCENDING.defaultNullOrdering())
+ )
+ }
+
+ override def toBatch: DeltaBatchWrite =
+ new TestBatchWrite with DeltaBatchWrite {
+ override def createBatchWriterFactory(
+ info: PhysicalWriteInfo): DeltaWriterFactory = {
+ new DeltaBufferedRowsWriterFactory(lastWriteInfo.schema())
+ }
+
+ // For column-update writes, rows contain only the assigned columns
+ // (narrow schema from LogicalWriteInfo). We expand each row to the full table
+ // schema by overlaying write-schema columns on the base row found by pk.
+ override protected def doCommit(messages: Array[WriterCommitMessage]): Unit =
+ dataMap.synchronized {
+ val newData = messages.map(_.asInstanceOf[BufferedRows])
+ val writeSchema = lastWriteInfo.schema()
+ val writeFieldIdx = writeSchema.fieldNames.zipWithIndex.toMap
+
+ val mergedData = newData.map { buf =>
+ val merged = new BufferedRows(buf.key, schema)
+ val updateOpName = UTF8String.fromString(Update.toString)
+ buf.log.foreach { logRow =>
+ val opName = logRow.getUTF8String(0)
+ if (opName == updateOpName) {
+ val pk = logRow.getInt(1)
+ val narrowRow = logRow.get(3, writeSchema).asInstanceOf[InternalRow]
+ val baseRow = dataMap.values.iterator.flatten
+ .flatMap(_.rows)
+ .find(r => r.getInt(schema.fieldIndex("pk")) == pk)
+ val fullRow = new GenericInternalRow(schema.length)
+ baseRow.foreach { base =>
+ for (i <- schema.fields.indices) {
+ fullRow.update(i, base.get(i, schema.fields(i).dataType))
+ }
+ }
+ schema.fields.zipWithIndex.foreach { case (field, i) =>
+ writeFieldIdx.get(field.name).foreach { j =>
+ fullRow.update(i, narrowRow.get(j, field.dataType))
+ }
+ }
+ merged.rows.append(fullRow)
+ }
+ }
+ merged
+ }
+
+ withDeletes(newData)
+ withData(mergedData)
+ lastWriteLog = newData.flatMap(buffer => buffer.log).toIndexedSeq
+ }
+
+ override def abort(messages: Array[WriterCommitMessage]): Unit = {}
+ }
+ }
+ }
+ }
+ }
+
+ class DeltaBasedColumnUpdateOperationWithReqAttrs(command: Command, reqCols: Array[String])
+ extends DeltaBasedColumnUpdateOperation(command) {
+ override def requiredDataAttributes(): Array[NamedReference] = reqCols.map(FieldReference(_))
+ }
+
+ class DeltaBasedColumnUpdateOperationFromInfo(
+ command: Command,
+ updatedCols: Seq[NamedReference])
+ extends DeltaBasedColumnUpdateOperation(command) {
+
+ private val PK_REF: NamedReference = FieldReference("pk")
+
+ override def requiredDataAttributes(): Array[NamedReference] = {
+ val updatedNames = updatedCols.map(_.describe()).toSet
+ if (updatedNames.contains("pk")) {
+ updatedCols.toArray
+ } else {
+ (Array(PK_REF) ++ updatedCols).toArray
+ }
+ }
+ }
+
+ class DeltaBasedColumnUpdateSplitOperation(
+ command: Command,
+ updatedCols: Seq[NamedReference] = Nil)
+ extends DeltaBasedColumnUpdateOperation(command) {
+ override def representUpdateAsDeleteAndInsert(): Boolean = true
+
+ private val PK_REF: NamedReference = FieldReference("pk")
+ override def requiredDataAttributes(): Array[NamedReference] = {
+ val updatedNames = updatedCols.map(_.describe()).toSet
+ if (updatedNames.contains("pk")) updatedCols.toArray
+ else (Array(PK_REF) ++ updatedCols).toArray
+ }
+
+ override def newWriteBuilder(info: LogicalWriteInfo): DeltaWriteBuilder = {
+ lastWriteInfo = info
+ new DeltaWriteBuilder {
+ override def build(): DeltaWrite =
+ new DeltaWrite with RequiresDistributionAndOrdering {
+ override def requiredDistribution(): Distribution =
+ Distributions.clustered(Array(PARTITION_COLUMN_REF))
+ override def requiredOrdering(): Array[SortOrder] = Array[SortOrder](
+ LogicalExpressions.sort(
+ PARTITION_COLUMN_REF,
+ SortDirection.ASCENDING,
+ SortDirection.ASCENDING.defaultNullOrdering()))
+ override def toBatch: DeltaBatchWrite =
+ new TestBatchWrite with DeltaBatchWrite {
+ override def createBatchWriterFactory(
+ info: PhysicalWriteInfo): DeltaWriterFactory =
+ new DeltaBufferedRowsWriterFactory(lastWriteInfo.schema())
+
+ // For delete+reinsert with narrow writes, the REINSERT row has only the
+ // connector-declared columns (requiredDataAttributes order).
+ // pk is the first field in the write schema (declared before updatedCols).
+ // Reconstruct the full row by overlaying the narrow row onto the original.
+ override protected def doCommit(messages: Array[WriterCommitMessage]): Unit =
+ dataMap.synchronized {
+ val newData = messages.map(_.asInstanceOf[BufferedRows])
+ val writeSchema = lastWriteInfo.schema()
+ val writeFieldIdx = writeSchema.fieldNames.zipWithIndex.toMap
+ val reinsertOpName = UTF8String.fromString(Reinsert.toString)
+ val pkIdx = writeFieldIdx("pk")
+
+ val expandedData = newData.map { buf =>
+ val expanded = new BufferedRows(buf.key, schema)
+ buf.log.foreach { logRow =>
+ val opName = logRow.getUTF8String(0)
+ if (opName == reinsertOpName) {
+ val narrowRow = logRow.get(3, writeSchema).asInstanceOf[InternalRow]
+ val pk = narrowRow.getInt(pkIdx)
+ val baseRow = dataMap.values.iterator.flatten
+ .flatMap(_.rows)
+ .find(r => r.getInt(schema.fieldIndex("pk")) == pk)
+ val fullRow = new GenericInternalRow(schema.length)
+ baseRow.foreach { base =>
+ for (i <- schema.fields.indices) {
+ fullRow.update(i, base.get(i, schema.fields(i).dataType))
+ }
+ }
+ schema.fields.zipWithIndex.foreach { case (field, i) =>
+ writeFieldIdx.get(field.name).foreach { j =>
+ fullRow.update(i, narrowRow.get(j, field.dataType))
+ }
+ }
+ expanded.rows.append(fullRow)
+ }
+ }
+ expanded
+ }
+
+ withDeletes(newData)
+ withData(expandedData)
+ lastWriteLog = newData.flatMap(buffer => buffer.log).toIndexedSeq
+ }
+
+ override def abort(messages: Array[WriterCommitMessage]): Unit = {}
+ }
+ }
+ }
+ }
+ }
+
+ class PartitionBasedColumnUpdateOperation(
+ command: Command,
+ updatedCols: Seq[NamedReference] = Nil) extends RowLevelOperation {
+ var configuredScan: InMemoryBatchScan = _
+
+ override def command(): Command = command
+
+ override def supportsColumnUpdates(): Boolean = true
+
+ override def requiredDataAttributes(): Array[NamedReference] = {
+ val base = Seq(FieldReference("pk"), FieldReference("dep"))
+ val baseNames = base.map(_.describe()).toSet
+ (base ++ updatedCols.filterNot(r => baseNames.contains(r.describe()))).toArray
+ }
+
+ override def requiredMetadataAttributes(): Array[NamedReference] =
+ Array(PARTITION_COLUMN_REF, INDEX_COLUMN_REF)
+
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+ new InMemoryScanBuilder(schema, options) {
+ override def build(): Scan = {
+ val scan = super.build()
+ configuredScan = scan.asInstanceOf[InMemoryBatchScan]
+ lastScanSchema = scan.readSchema()
+ scan
+ }
+ }
+ }
+
+ override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+ lastWriteInfo = info
+ new WriteBuilder {
+ override def build(): Write = new Write with RequiresDistributionAndOrdering {
+ override def requiredDistribution: Distribution =
+ Distributions.clustered(Array(PARTITION_COLUMN_REF))
+
+ override def requiredOrdering: Array[SortOrder] = Array[SortOrder](
+ LogicalExpressions.sort(
+ PARTITION_COLUMN_REF,
+ SortDirection.ASCENDING,
+ SortDirection.ASCENDING.defaultNullOrdering()))
+
+ override def toBatch: BatchWrite =
+ PartitionBasedNarrowReplaceData(configuredScan, info.schema())
+
+ override def description: String = "InMemoryNarrowCoWWrite"
+ }
+ }
+ }
+
+ override def description(): String = "InMemoryPartitionColumnUpdateOperation"
+ }
+
+ // CoW write handler for narrow column-update writes.
+ // Receives rows with only the connector-declared + assigned columns.
+ // Reconstructs full rows by looking up the original row by pk and overlaying received columns.
+ private case class PartitionBasedNarrowReplaceData(
+ scan: InMemoryBatchScan,
+ writeSchema: StructType) extends TestBatchWrite {
+
+ override protected def doCommit(
+ messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
+ val newData = messages.map(_.asInstanceOf[BufferedRows])
+ val readRows = scan.data.flatMap(_.asInstanceOf[BufferedRows].rows)
+ val readPartitions = readRows.map(r => getKey(r, schema)).distinct
+ dataMap --= readPartitions
+ replacedPartitions = readPartitions
+
+ val writeFieldIdx = writeSchema.fieldNames.zipWithIndex.toMap
+ val pkIdxInWrite = writeFieldIdx("pk")
+ val pkIdxInFull = schema.fieldIndex("pk")
+
+ val expandedData = newData.map { buf =>
+ val expanded = new BufferedRows(buf.key, schema)
+ buf.rows.foreach { narrowRow =>
+ val pk = narrowRow.getInt(pkIdxInWrite)
+ val origRow = readRows.find(r => r.getInt(pkIdxInFull) == pk)
+ val fullRow = new GenericInternalRow(schema.length)
+ origRow.foreach { base =>
+ for (i <- schema.fields.indices) {
+ fullRow.update(i, base.get(i, schema.fields(i).dataType))
+ }
+ }
+ schema.fields.zipWithIndex.foreach { case (field, i) =>
+ writeFieldIdx.get(field.name).foreach { j =>
+ fullRow.update(i, narrowRow.get(j, field.dataType))
+ }
+ }
+ expanded.rows.append(fullRow)
+ }
+ expanded
+ }
+
+ withData(expandedData, schema)
+ lastWriteLog = newData.flatMap(buffer => buffer.log).toImmutableArraySeq
+ }
+ }
+
private object TestDeltaBatchWrite extends TestBatchWrite with DeltaBatchWrite {
override def createBatchWriterFactory(info: PhysicalWriteInfo): DeltaWriterFactory = {
new DeltaBufferedRowsWriterFactory(CatalogV2Util.v2ColumnsToStructType(columns()))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala
index 4b9dff5c3d780..c204146cdc388 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala
@@ -112,6 +112,7 @@ class TxnTable(
override def newRowLevelOperationBuilder(
info: RowLevelOperationInfo): RowLevelOperationBuilder = {
catalog.writeTarget = this
+ delegate.lastUpdatedColumns = info.updatedColumns()
super.newRowLevelOperationBuilder(info)
}
@@ -128,6 +129,8 @@ class TxnTable(
delegate.replacedPartitions = replacedPartitions
delegate.lastWriteInfo = lastWriteInfo
delegate.lastWriteLog = lastWriteLog
+ delegate.lastUpdatedColumns = lastUpdatedColumns
+ delegate.lastScanSchema = lastScanSchema
delegate.commits ++= commits
delegate.increaseVersion()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala
index 9f8409efa360e..16e68fc7320cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.execution.dynamicpruning
-import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, DynamicPruningExpression, Expression, InSubquery, ListQuery, PredicateHelper, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.optimizer.RewritePredicateSubquery
@@ -139,17 +138,13 @@ class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPla
tableAttrs: Seq[Attribute],
scanAttrs: Seq[Attribute]): AttributeMap[Attribute] = {
- val attrMapping = tableAttrs.map { tableAttr =>
+ // The scan may be narrowed to exclude columns not needed by the connector.
+ // Attributes absent from the scan are skipped here; the caller must ensure
+ // that any attribute referenced in the condition is present in the scan.
+ val attrMapping = tableAttrs.flatMap { tableAttr =>
scanAttrs
.find(scanAttr => conf.resolver(scanAttr.name, tableAttr.name))
.map(scanAttr => tableAttr -> scanAttr)
- .getOrElse {
- throw new AnalysisException(
- errorClass = "_LEGACY_ERROR_TEMP_3075",
- messageParameters = Map(
- "tableAttr" -> tableAttr.toString,
- "scanAttrs" -> scanAttrs.mkString(",")))
- }
}
AttributeMap(attrMapping)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedColumnUpdateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedColumnUpdateTableSuite.scala
new file mode 100644
index 0000000000000..1440f481d982e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedColumnUpdateTableSuite.scala
@@ -0,0 +1,514 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, TableInfo}
+import org.apache.spark.sql.connector.expressions.LogicalExpressions.{identity, reference}
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+
+/**
+ * Tests for UPDATE statements targeting connectors that return true from
+ * [[org.apache.spark.sql.connector.write.RowLevelOperation#supportsColumnUpdates]].
+ *
+ * When a connector supports column updates, Spark narrows the row projection
+ * (LogicalWriteInfo.schema()) to contain only the assigned/changed columns rather than
+ * the full table row.
+ */
+class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase {
+
+ override protected lazy val extraTableProps: java.util.Map[String, String] = {
+ val props = new java.util.HashMap[String, String]()
+ props.put("column-update", "true")
+ props
+ }
+
+ test("column-update: rowSchema contains only the single assigned column") {
+ createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+ """{ "pk": 1, "id": 1, "dep": "hr" }
+ |{ "pk": 2, "id": 2, "dep": "software" }
+ |{ "pk": 3, "id": 3, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"UPDATE $tableNameAsString SET id = -1 WHERE pk = 1")
+
+ checkLastWriteInfo(
+ expectedRowSchema = StructType(Seq(
+ StructField("id", IntegerType, nullable = false)
+ )),
+ expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
+ expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+ }
+
+ test("column-update: rowSchema contains multiple assigned columns") {
+ createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+ """{ "pk": 1, "id": 1, "dep": "hr" }
+ |{ "pk": 2, "id": 2, "dep": "software" }
+ |""".stripMargin)
+
+ sql(s"UPDATE $tableNameAsString SET id = -1, dep = 'engineering' WHERE pk = 1")
+
+ checkLastWriteInfo(
+ expectedRowSchema = StructType(Seq(
+ StructField("id", IntegerType, nullable = false),
+ StructField("dep", StringType, nullable = false)
+ )),
+ expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
+ expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+ }
+
+ test("column-update: rowSchema is empty for a full identity update") {
+ createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+ """{ "pk": 1, "id": 1, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"UPDATE $tableNameAsString SET id = id, dep = dep WHERE pk = 1")
+
+ checkLastWriteInfo(
+ expectedRowSchema = new StructType(),
+ expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
+ expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+ }
+
+ test("column-update: row filter condition is orthogonal to column narrowing") {
+ createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+ """{ "pk": 1, "id": 1, "dep": "hr" }
+ |{ "pk": 2, "id": 2, "dep": "software" }
+ |{ "pk": 3, "id": 3, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"UPDATE $tableNameAsString SET dep = 'engineering' WHERE pk IN (1, 3)")
+
+ checkLastWriteInfo(
+ expectedRowSchema = StructType(Seq(
+ StructField("dep", StringType, nullable = false)
+ )),
+ expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
+ expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+ }
+
+ test("column-update: update all rows (no WHERE clause)") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ sql(s"UPDATE $tableNameAsString SET salary = salary * 2")
+
+ checkLastWriteInfo(
+ expectedRowSchema = StructType(Seq(
+ StructField("salary", IntegerType, nullable = true)
+ )),
+ expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
+ expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+ }
+
+ test("column-update: rowSchema excludes identity assignments in a mixed UPDATE") {
+ createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+ """{ "pk": 1, "id": 1, "dep": "hr" }
+ |{ "pk": 2, "id": 2, "dep": "software" }
+ |""".stripMargin)
+
+ sql(s"UPDATE $tableNameAsString SET id = id, dep = 'engineering' WHERE pk = 1")
+
+ checkLastWriteInfo(
+ expectedRowSchema = StructType(Seq(
+ StructField("dep", StringType, nullable = false)
+ )),
+ expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
+ expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+ }
+
+ test("column-update: cross-column assignment is not treated as identity") {
+ createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+ """{ "pk": 1, "id": 1, "dep": "hr" }
+ |{ "pk": 2, "id": 2, "dep": "software" }
+ |""".stripMargin)
+
+ sql(s"UPDATE $tableNameAsString SET dep = dep, id = -1 WHERE pk = 1")
+
+ checkLastWriteInfo(
+ expectedRowSchema = StructType(Seq(
+ StructField("id", IntegerType, nullable = false)
+ )),
+ expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
+ expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+ }
+
+ test("column-update: nested struct field update narrows to the root struct column") {
+ createAndInitTable("pk INT NOT NULL, s STRUCT