From 4df62792822944dc4f31d53323fef34514ce51e2 Mon Sep 17 00:00:00 2001 From: Anurag Mantripragada Date: Fri, 17 Apr 2026 09:26:10 -0700 Subject: [PATCH 1/5] [SPARK-56599][SQL] Add scan narrowing for column-level UPDATEs in DSv2 --- .../connector/write/RowLevelOperation.java | 41 ++ .../write/RowLevelOperationInfo.java | 17 + .../analysis/RewriteRowLevelCommand.scala | 42 +- .../analysis/RewriteUpdateTable.scala | 242 ++++++- .../catalyst/plans/logical/v2Commands.scala | 61 +- .../write/RowLevelOperationInfoImpl.scala | 8 +- .../InMemoryRowLevelOperationTable.scala | 355 ++++++++- ...wLevelOperationRuntimeGroupFiltering.scala | 13 +- .../DeltaBasedColumnUpdateTableSuite.scala | 677 ++++++++++++++++++ .../DeltaBasedUpdateTableSuiteBase.scala | 68 ++ .../RowLevelOperationSuiteBase.scala | 12 + 11 files changed, 1480 insertions(+), 56 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedColumnUpdateTableSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java index 844734ff7ccb7..8c8affdd7c098 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java @@ -105,4 +105,45 @@ default String description() { default NamedReference[] requiredMetadataAttributes() { return new NamedReference[0]; } + + + /** + * Controls whether to send only the required data columns to the connector rather than the + * full row. + *
<p>
+ * When true, Spark narrows the data column schema ({@link LogicalWriteInfo#schema()}) to only + * the columns declared via {@link #requiredDataAttributes()}. Metadata columns (from + * {@link #requiredMetadataAttributes()}) and row ID columns (from + * {@link SupportsDelta#rowId()}) are unaffected and always projected separately. + *
<p>
+ * If {@link #requiredDataAttributes()} returns a non-empty array, the write schema is exactly + * those columns in declared order. The connector must include all columns it wants to receive, + * including the columns being updated. If {@link #requiredDataAttributes()} returns an empty + * array, Spark sends only the non-identity assigned columns (heuristic path). + * + * @since 4.2.0 + */ + default boolean supportsColumnUpdates() { + return false; + } + + /** + * Returns data column references required to perform this row-level operation. + *
<p>
+ * This method is only consulted by Spark when {@link #supportsColumnUpdates()} returns + * {@code true}. If {@code supportsColumnUpdates()} returns {@code false}, the returned array + * is ignored and the full table row is sent (the default behavior). + *
<p>
+ * When non-empty, the returned columns become the write schema in declared order. + * The connector must declare all columns it wants to receive, including the columns being + * updated. Use {@link RowLevelOperationInfo#updatedColumns()} to learn which columns are being + * assigned, then add any extra columns needed for row lookup or routing (e.g., primary key). + *
<p>
+ * When empty (the default), Spark falls back to sending only the non-identity assigned columns. + * + * @since 4.2.0 + */ + default NamedReference[] requiredDataAttributes() { + return new NamedReference[0]; + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperationInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperationInfo.java index e3d7397aed91b..77bb5b31e28bc 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperationInfo.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperationInfo.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.connector.write; import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.connector.expressions.NamedReference; import org.apache.spark.sql.connector.write.RowLevelOperation.Command; import org.apache.spark.sql.util.CaseInsensitiveStringMap; @@ -37,4 +38,20 @@ public interface RowLevelOperationInfo { * Returns the row-level SQL command (e.g. DELETE, UPDATE, MERGE). */ Command command(); + + /** + * Returns the columns being updated in an UPDATE statement, as non-identity assignments. + * + *
<p>
For DELETE and MERGE, returns an empty array. + * + *
<p>
Connectors can use this to decide what {@link RowLevelOperation#requiredDataAttributes()} + * to declare. For instance, a connector that needs its primary key for row lookup can check + * whether pk is already in the updated columns list and, if not, add it to + * requiredDataAttributes(). + * + * @since 4.2.0 + */ + default NamedReference[] updatedColumns() { + return new NamedReference[0]; + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala index 48c48eb323bd7..98d73225515a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala @@ -22,12 +22,13 @@ import scala.collection.mutable import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ProjectingInternalRow import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Expression, ExprId, If, Literal, MetadataAttribute, NamedExpression, V2ExpressionUtils} +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, LogicalPlan, MergeRows, Project, Union} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.{ReplaceDataProjections, WriteDeltaProjections} import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations -import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference} import org.apache.spark.sql.connector.write.{RowLevelOperation, RowLevelOperationInfoImpl, RowLevelOperationTable, SupportsDelta} import org.apache.spark.sql.connector.write.RowLevelOperation.Command 
import org.apache.spark.sql.errors.QueryCompilationErrors @@ -50,20 +51,35 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { protected def buildOperationTable( table: SupportsRowLevelOperations, command: Command, - options: CaseInsensitiveStringMap): RowLevelOperationTable = { - val info = RowLevelOperationInfoImpl(command, options) + options: CaseInsensitiveStringMap, + updatedColumns: Seq[NamedReference] = Nil): RowLevelOperationTable = { + val info = RowLevelOperationInfoImpl(command, options, updatedColumns) val operation = table.newRowLevelOperationBuilder(info).build() RowLevelOperationTable(table, operation) } + // Builds a DataSourceV2Relation for a row-level operation, optionally narrowing its output. + // + // When dataAttrs is non-empty, the relation output is narrowed to include only columns + // required for a column-update write. When dataAttrs is empty, the full relation.output is + // preserved. protected def buildRelationWithAttrs( relation: DataSourceV2Relation, table: RowLevelOperationTable, metadataAttrs: Seq[AttributeReference], - rowIdAttrs: Seq[AttributeReference] = Nil): DataSourceV2Relation = { - - val attrs = dedupAttrs(relation.output ++ rowIdAttrs ++ metadataAttrs) - relation.copy(table = table, output = attrs) + rowIdAttrs: Seq[AttributeReference] = Nil, + dataAttrs: Seq[AttributeReference] = Nil, + cond: Expression = TrueLiteral): DataSourceV2Relation = { + + if (dataAttrs.nonEmpty) { + val required = + AttributeSet(dataAttrs) ++ AttributeSet(Seq(cond)) ++ AttributeSet(rowIdAttrs) + val narrowOutput = relation.output.filter(required.contains) + relation.copy(table = table, output = dedupAttrs(narrowOutput ++ rowIdAttrs ++ metadataAttrs)) + } else { + val attrs = dedupAttrs(relation.output ++ rowIdAttrs ++ metadataAttrs) + relation.copy(table = table, output = attrs) + } } protected def dedupAttrs(attrs: Seq[AttributeReference]): Seq[AttributeReference] = { @@ -87,6 +103,14 @@ trait RewriteRowLevelCommand extends 
Rule[LogicalPlan] { relation) } + protected def resolveRequiredDataAttrs( + relation: DataSourceV2Relation, + operation: RowLevelOperation): Seq[AttributeReference] = { + V2ExpressionUtils.resolveRefs[AttributeReference]( + operation.requiredDataAttributes.toImmutableArraySeq, + relation) + } + protected def resolveRowIdAttrs( relation: DataSourceV2Relation, operation: SupportsDelta): Seq[AttributeReference] = { @@ -211,11 +235,13 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { metadataAttrs: Seq[Attribute]): WriteDeltaProjections = { val outputs = extractOutputs(plan) + // Always produce Some(rowProjection) even for empty rowAttrs (identity-only column updates). + // Physical execution calls rowProjection.project(row) unconditionally; None causes NPE. val rowProjection = if (rowAttrs.nonEmpty) { val outputsWithRow = filterOutputs(outputs, OPERATIONS_WITH_ROW) Some(newLazyProjection(plan, outputsWithRow, rowAttrs)) } else { - None + Some(ProjectingInternalRow(StructType(Nil), Nil)) } val outputsWithRowId = filterOutputs(outputs, OPERATIONS_WITH_ROW_ID) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala index f235374bd5d6f..b4ef89617c82d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, EqualNullSafe, Expression, If, Literal, MetadataAttribute, Not, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, EqualNullSafe, Expression, If, Literal, MetadataAttribute, Not, SubqueryExpression} import 
org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, Filter, LogicalPlan, Project, ReplaceData, Union, UpdateTable, WriteDelta} import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations +import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.connector.write.{RowLevelOperationTable, SupportsDelta} import org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, ExtractV2Table} @@ -41,7 +42,13 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { EliminateSubqueryAliases(aliasedTable) match { case r @ ExtractV2Table(tbl: SupportsRowLevelOperations) => - val table = buildOperationTable(tbl, UPDATE, CaseInsensitiveStringMap.empty()) + val updatedCols = assignments.collect { + case Assignment(key: AttributeReference, value) + if !isIdentityAssignment(key, value) => + FieldReference(key.name) + } + val table = buildOperationTable(tbl, UPDATE, CaseInsensitiveStringMap.empty(), + updatedCols) val updateCond = cond.getOrElse(TrueLiteral) table.operation match { case _: SupportsDelta => @@ -65,18 +72,15 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { assignments: Seq[Assignment], cond: Expression): ReplaceData = { - // resolve all required metadata attrs that may be used for grouping data on write - val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) - - // construct a read relation and include all required metadata columns - val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) + val (readRelation, rowAttrs) = buildCoWReadSetup(relation, operationTable, assignments, cond) - // build a plan with updated and copied over records - val query = buildReplaceDataUpdateProjection(readRelation, assignments, cond) + val 
updatedAndRemainingRowsPlan = buildReplaceDataUpdateProjection( + readRelation, assignments, cond) - // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) - val projections = buildReplaceDataProjections(query, relation.output, metadataAttrs) + val query = updatedAndRemainingRowsPlan + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) + val projections = buildReplaceDataProjections(query, rowAttrs, metadataAttrs) val groupFilterCond = if (groupFilterEnabled) Some(cond) else None ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond) } @@ -89,13 +93,7 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { assignments: Seq[Assignment], cond: Expression): ReplaceData = { - // resolve all required metadata attrs that may be used for grouping data on write - val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) - - // construct a read relation and include all required metadata columns - // the same read relation will be used to read records that must be updated and copied over - // the analyzer will take care of duplicated attr IDs - val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) + val (readRelation, rowAttrs) = buildCoWReadSetup(relation, operationTable, assignments, cond) // build a plan for updated records that match the condition val matchedRowsPlan = Filter(cond, readRelation) @@ -106,38 +104,92 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { val remainingRowsPlan = addOperationColumn(COPY_OPERATION, Filter(remainingRowFilter, readRelation)) - // the new state is a union of updated and copied over records - val query = Union(updatedRowsPlan, remainingRowsPlan) + val updatedAndRemainingRowsPlan = Union(updatedRowsPlan, remainingRowsPlan) - // build a plan to replace read groups in the table val writeRelation = relation.copy(table = operationTable) - val projections = 
buildReplaceDataProjections(query, relation.output, metadataAttrs) + val query = updatedAndRemainingRowsPlan + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) + val projections = buildReplaceDataProjections(query, rowAttrs, metadataAttrs) val groupFilterCond = if (groupFilterEnabled) Some(cond) else None ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond) } + // Common read-relation setup shared by both CoW plan builders. + // + // When the connector supports column updates and declares required data attributes, + // the read relation is narrowed at analysis time so that + // GroupBasedRowLevelOperationScanPlanning uses only the needed columns for the scan. + // Otherwise the full relation output is used. + private def buildCoWReadSetup( + relation: DataSourceV2Relation, + operationTable: RowLevelOperationTable, + assignments: Seq[Assignment], + cond: Expression): (DataSourceV2Relation, Seq[Attribute]) = { + + val operation = operationTable.operation + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operation) + val connectorDataAttrs = resolveRequiredDataAttrs(relation, operation) + val isNarrow = operation.supportsColumnUpdates() && connectorDataAttrs.nonEmpty + + // CoW scan narrowing must be done manually at analysis time. + // GroupBasedRowLevelOperationScanPlanning (an optimizer rule that fires after analysis) + // always reads relation.output directly when building the physical scan -- it does not + // observe Project nodes above the relation, so optimizer-driven column pruning has no + // effect on CoW scans. We narrow DataSourceV2Relation.output here so that rule picks + // up the narrow set. 
+ val readRelation = if (isNarrow) { + val allRequired = (connectorDataAttrs ++ computeAssignedAttrs(assignments)).distinct + buildRelationWithAttrs(relation, operationTable, metadataAttrs, dataAttrs = allRequired, + cond = cond) + } else { + buildRelationWithAttrs(relation, operationTable, metadataAttrs) + } + + // CoW write schema (two paths only, no heuristic for CoW): + // - Narrow path (connectorDataAttrs declared): exactly connector-declared cols in declared + // order. The connector must declare ALL columns it wants to receive. + // - Full path (connectorDataAttrs empty OR supportsColumnUpdates=false): full table output. + // Unlike MOR, CoW does not have a heuristic assigned-only path because + // GroupBasedRowLevelOperationScanPlanning needs explicit column declarations to narrow. + val rowAttrs: Seq[Attribute] = if (isNarrow) connectorDataAttrs else relation.output + + (readRelation, rowAttrs) + } + // this method assumes the assignments have been already aligned before + // + // Works for both the full-scan and narrow-scan CoW paths. In the narrow case, + // readRelation.output is already restricted by buildCoWReadSetup, so projecting + // all plan.output gives the correct narrow write schema. 
private def buildReplaceDataUpdateProjection( plan: LogicalPlan, assignments: Seq[Assignment], cond: Expression = TrueLiteral): LogicalPlan = { - // the plan output may include metadata columns at the end - // that's why the number of assignments may not match the number of plan output columns - val assignedValues = assignments.map(_.value) - val updatedValues = plan.output.zipWithIndex.map { case (attr, index) => - if (index < assignments.size) { - val assignedExpr = assignedValues(index) - val updatedValue = If(cond, assignedExpr, attr) - Alias(updatedValue, attr.name)() - } else { - assert(MetadataAttribute.isValid(attr.metadata)) + // Build a name-keyed map via AttributeMap (compares by exprId internally) so we can look + // up each plan column's assignment without relying on positional ordering. This is more + // robust than position-based indexing and works correctly for any plan output layout. + val assignmentMap = AttributeMap(assignments.collect { + case Assignment(key: Attribute, value) => key -> value + }) + + val updatedValues = plan.output.map { attr => + if (MetadataAttribute.isValid(attr.metadata)) { if (MetadataAttribute.isPreservedOnUpdate(attr)) { attr } else { val updatedValue = If(cond, Literal(null, attr.dataType), attr) Alias(updatedValue, attr.name)(explicitMetadata = Some(attr.metadata)) } + } else { + assignmentMap.get(attr) match { + case Some(assignedExpr) => + Alias(If(cond, assignedExpr, attr), attr.name)() + case None => + // Column is present in the scan but has no assignment -- pass through unchanged. + // In the narrow CoW path these are connector-declared columns not being updated. + attr + } } } @@ -154,30 +206,150 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { cond: Expression): WriteDelta = { val operation = operationTable.operation.asInstanceOf[SupportsDelta] + // Column-update support applies to the standard delta path and the delete+reinsert path. 
+ // When representUpdateAsDeleteAndInsert is true, the REINSERT leg of the Expand already + // uses only assigned values, so the narrow effectiveRowAttrs applies correctly. + val supportsColumnUpdate = operation.supportsColumnUpdates() // resolve all needed attrs (e.g. row ID and any required metadata attrs) - val rowAttrs = relation.output val rowIdAttrs = resolveRowIdAttrs(relation, operation) val metadataAttrs = resolveRequiredMetadataAttrs(relation, operation) - // construct a read relation and include all required metadata columns + // Connector-declared data attrs used to determine pass-through columns in the write plan. + val connectorDataAttrs = if (supportsColumnUpdate) { + resolveRequiredDataAttrs(relation, operation) + } else Nil + + // MOR uses a full-schema scan; ColumnPruning narrows it via Project references. val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs, rowIdAttrs) + // Connector-required attrs that are NOT being assigned are added as pass-throughs in the + // plan so that ColumnPruning keeps them in the physical scan AND the connector receives + // their current values via DeltaWriter.update's row argument. 
+ val assignedAttrs = if (supportsColumnUpdate) computeAssignedAttrs(assignments) + else relation.output + val connectorExtraAttrs: Seq[AttributeReference] = if (connectorDataAttrs.nonEmpty) { + val assignedAttrSet = AttributeSet(assignedAttrs) + connectorDataAttrs.filterNot(assignedAttrSet.contains) + } else Nil + // build a plan for updated records that match the condition val matchedRowsPlan = Filter(cond, readRelation) val rowDeltaPlan = if (operation.representUpdateAsDeleteAndInsert) { buildDeletesAndInserts(matchedRowsPlan, assignments, rowIdAttrs) + } else if (supportsColumnUpdate) { + buildColumnUpdateProjection( + matchedRowsPlan, assignments, rowIdAttrs, metadataAttrs, connectorExtraAttrs) } else { buildWriteDeltaUpdateProjection(matchedRowsPlan, assignments, rowIdAttrs) } + // Effective row write schema: + // - Narrow path (connectorDataAttrs declared): exactly connector-declared cols in declared + // order. The connector must declare ALL columns it wants to receive (including updated + // ones). This mirrors the metadata pattern and enables strict areCompatible validation. + // - Heuristic path (connectorDataAttrs empty): only the assigned (changed) columns. + // - Full path (no column-update support): full table output. + val effectiveRowAttrs = if (supportsColumnUpdate && connectorDataAttrs.nonEmpty) { + connectorDataAttrs + } else if (supportsColumnUpdate) { + assignedAttrs + } else { + relation.output + } + // build a plan to write the row delta to the table val writeRelation = relation.copy(table = operationTable) - val projections = buildWriteDeltaProjections(rowDeltaPlan, rowAttrs, rowIdAttrs, metadataAttrs) + val projections = buildWriteDeltaProjections( + rowDeltaPlan, effectiveRowAttrs, rowIdAttrs, metadataAttrs) val groupFilterCond = if (groupFilterEnabled) Some(cond) else None WriteDelta(writeRelation, cond, rowDeltaPlan, relation, projections, groupFilterCond) } + // Builds the row delta projection for the column update path. 
+ // + // The resulting Project references only: + // - assigned column values (new values being written) + // - connector pass-through values (connector declared but not assigned) + // - metadata columns (nulled or preserved) + // - row ID columns (for delta identification) + // - original row ID values (only when a row ID column is being reassigned) + // + // ColumnPruning observes exactly these references and narrows the physical scan accordingly. + // Connectors that need additional columns in the scan (e.g., partition columns for + // distribution) should declare them in requiredDataAttributes(). + // + // Note: AlignUpdateAssignments guarantees all assignment keys are top-level + // AttributeReferences even for nested field updates (e.g., SET col1.field = 'x' becomes + // Assignment(col1: AttributeReference, CreateNamedStruct(...))), so isIdentityAssignment + // correctly identifies non-updating assignments. + private def buildColumnUpdateProjection( + plan: LogicalPlan, + assignments: Seq[Assignment], + rowIdAttrs: Seq[Attribute], + metadataAttrs: Seq[Attribute], + connectorExtraAttrs: Seq[AttributeReference] = Nil): LogicalPlan = { + + // only emit values for non-identity assignments (the narrow write schema) + val assignedValues = assignments.collect { + case Assignment(key: Attribute, value) if !isIdentityAssignment(key, value) => + Alias(value, key.name)() + } + + // Connector-required data attrs that are not being assigned are passed through as-is + // so that (a) ColumnPruning keeps them in the physical scan, and (b) the connector + // receives their current values via DeltaWriter.update's row argument. 
+ val connectorExtraAttrSet = AttributeSet(connectorExtraAttrs) + val connectorPassThroughValues = plan.output.filter { a => + connectorExtraAttrSet.contains(a) && !MetadataAttribute.isValid(a.metadata) + } + + // pass through or null out metadata columns present in the scan + val metadataAttrSet = AttributeSet(metadataAttrs) + val metadataValues = plan.output.filter(metadataAttrSet.contains).map { attr => + if (MetadataAttribute.isPreservedOnUpdate(attr)) { + attr + } else { + Alias(Literal(null, attr.dataType), attr.name)(explicitMetadata = Some(attr.metadata)) + } + } + + // pass through row ID columns from the scan + val rowIdAttrSet = AttributeSet(rowIdAttrs) + val rowIdValues = plan.output.filter(rowIdAttrSet.contains) + + val originalRowIdValues = buildOriginalRowIdValues(rowIdAttrs, assignments) + val operationType = Alias(Literal(UPDATE_OPERATION), OPERATION_COLUMN)() + + Project( + Seq(operationType) ++ assignedValues ++ connectorPassThroughValues ++ + metadataValues ++ rowIdValues ++ originalRowIdValues, + plan) + } + + // Returns the table attributes that are genuinely updated (non-identity) in this UPDATE. + // Strips Alias/Cast wrappers introduced during assignment alignment before doing the + // AttributeSet membership check (which uses exprId equality internally). + private def computeAssignedAttrs(assignments: Seq[Assignment]): Seq[AttributeReference] = { + assignments.collect { + case Assignment(key: AttributeReference, value) if !isIdentityAssignment(key, value) => key + } + } + + private def isIdentityAssignment(key: Attribute, value: Expression): Boolean = { + stripAliasesAndCasts(value) match { + case attr: Attribute => AttributeSet(Seq(key)).contains(attr) + case _ => false + } + } + + // Recursively strips Alias and Cast wrappers introduced during assignment alignment. 
+ private def stripAliasesAndCasts(expr: Expression): Expression = expr match { + case Alias(child, _) => stripAliasesAndCasts(child) + case Cast(child, _, _, _) => stripAliasesAndCasts(child) + case other => other + } + // this method assumes the assignments have been already aligned before private def buildWriteDeltaUpdateProjection( plan: LogicalPlan, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 40cf5009b97dc..9cc3d67b73753 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -364,6 +364,14 @@ trait RowLevelWrite extends V2WriteCommand with SupportsSubquery { operation.requiredMetadataAttributes.toImmutableArraySeq, originalTable) } + + // Resolves the connector-declared data attributes against the original table. + // Symmetric with projectedMetadataAttrs; used in narrow-schema validation. + protected def projectedDataAttrs: Seq[Attribute] = { + V2ExpressionUtils.resolveRefs[AttributeReference]( + operation.requiredDataAttributes.toImmutableArraySeq, + originalTable) + } } /** @@ -417,7 +425,35 @@ case class ReplaceData( // validates row projection output is compatible with table attributes private def rowAttrsResolved: Boolean = { val inRowAttrs = DataTypeUtils.toAttributes(projections.rowProjection.schema) - table.skipSchemaResolution || areCompatible(inRowAttrs, table.output) + table.skipSchemaResolution || + areCompatible(inRowAttrs, table.output) || + dataAttrsResolved(inRowAttrs) + } + + // Validates the narrow-write-schema row projection output. + // + // When the connector declares specific data attributes via requiredDataAttributes(), the + // write schema must exactly match projectedDataAttrs (same columns, same order). 
This is + // symmetric with metadataAttrsResolved: the connector's declared attrs define the write schema. + // + // When requiredDataAttributes() is empty (heuristic path), the write schema contains only + // the assigned columns. We validate each one exists in the table with a compatible type. + private def dataAttrsResolved(inRowAttrs: Seq[Attribute]): Boolean = { + if (!operation.supportsColumnUpdates()) { return false } + val outDataAttrs = projectedDataAttrs + if (outDataAttrs.nonEmpty) { + areCompatible(inRowAttrs, outDataAttrs) + } else { + inRowAttrs.forall { inAttr => + table.output.exists { outAttr => + val inType = CharVarcharUtils.getRawType(inAttr.metadata).getOrElse(inAttr.dataType) + val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType) + inAttr.name == outAttr.name && + DataType.equalsIgnoreCompatibleNullability(inType, outType) && + (outAttr.nullable || !inAttr.nullable) + } + } + } } // validates metadata projection output is compatible with metadata attributes @@ -509,7 +545,28 @@ case class WriteDelta( case Some(projection) => DataTypeUtils.toAttributes(projection.schema) case None => Nil } - table.skipSchemaResolution || areCompatible(inRowAttrs, outRowAttrs) + table.skipSchemaResolution || + areCompatible(inRowAttrs, outRowAttrs) || + dataAttrsResolved(inRowAttrs) + } + + // Validates the narrow-write-schema row projection. Symmetric with ReplaceData. 
+ private def dataAttrsResolved(inRowAttrs: Seq[Attribute]): Boolean = { + if (!operation.supportsColumnUpdates()) { return false } + val outDataAttrs = projectedDataAttrs + if (outDataAttrs.nonEmpty) { + areCompatible(inRowAttrs, outDataAttrs) + } else { + inRowAttrs.forall { inAttr => + table.output.exists { outAttr => + val inType = CharVarcharUtils.getRawType(inAttr.metadata).getOrElse(inAttr.dataType) + val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType) + inAttr.name == outAttr.name && + DataType.equalsIgnoreCompatibleNullability(inType, outType) && + (outAttr.nullable || !inAttr.nullable) + } + } + } } // validates row ID projection output is compatible with row ID attributes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationInfoImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationInfoImpl.scala index 9d499cdef361b..a84e0230cd8d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationInfoImpl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationInfoImpl.scala @@ -17,9 +17,15 @@ package org.apache.spark.sql.connector.write +import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.connector.write.RowLevelOperation.Command import org.apache.spark.sql.util.CaseInsensitiveStringMap private[sql] case class RowLevelOperationInfoImpl( command: Command, - options: CaseInsensitiveStringMap) extends RowLevelOperationInfo + options: CaseInsensitiveStringMap, + private val updatedCols: Seq[NamedReference] = Nil) + extends RowLevelOperationInfo { + + override def updatedColumns(): Array[NamedReference] = updatedCols.toArray +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala index 5c0bc0b143f3d..8635ac92173ed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala @@ -70,18 +70,49 @@ class InMemoryRowLevelOperationTable private ( private final val SPLIT_UPDATES = "split-updates" private final val NO_METADATA = "no-metadata" private final val noMetadata = properties.getOrDefault(NO_METADATA, "false") == "true" + private final val COLUMN_UPDATE = "column-update" + private final val COLUMN_UPDATE_REQ_ATTRS = "column-update-req-attrs" + // Selects PartitionBasedColumnUpdateOperation: CoW connector with supportsColumnUpdates=true + // and requiredDataAttributes=[pk,dep]. + private final val COLUMN_UPDATE_COW = "column-update-cow" + // Selects DeltaBasedColumnUpdateOperationFromInfo: connector that derives + // requiredDataAttributes() dynamically from RowLevelOperationInfo.updatedColumns(). + // Always adds "pk" for row lookup plus whatever Spark reports as updated. + private final val COLUMN_UPDATE_FROM_INFO = "column-update-from-info" + // Selects DeltaBasedColumnUpdateSplitOperation: delta connector with + // representUpdateAsDeleteAndInsert=true AND supportsColumnUpdates=true. + // Used to verify Point 7: the restriction on column updates for the delete+reinsert path + // has been lifted. 
+  private final val COLUMN_UPDATE_SPLIT = "column-update-split"

   // used in row-level operation tests to verify replaced partitions
   var replacedPartitions: Seq[Seq[Any]] = Seq.empty
   // used in row-level operation tests to verify reported write schema
   var lastWriteInfo: LogicalWriteInfo = _
+  // used in column-update tests to verify the scan projection was narrowed correctly
+  var lastScanSchema: StructType = _
+  // used in column-update tests to verify that Spark passed the correct updated column list
+  // to the connector via RowLevelOperationInfo.updatedColumns()
+  var lastUpdatedColumns: Array[NamedReference] = Array.empty
   // used in row-level operation tests to verify passed records
   // (operation, id, metadata, row)
   var lastWriteLog: Seq[InternalRow] = Seq.empty

   override def newRowLevelOperationBuilder(
       info: RowLevelOperationInfo): RowLevelOperationBuilder = {
-    if (properties.getOrDefault(SUPPORTS_DELTAS, "false") == "true") {
+    lastUpdatedColumns = info.updatedColumns()
+    if (properties.getOrDefault(COLUMN_UPDATE, "false") == "true") {
+      () => new DeltaBasedColumnUpdateOperation(info.command)
+    } else if (properties.containsKey(COLUMN_UPDATE_REQ_ATTRS)) {
+      val reqCols = properties.get(COLUMN_UPDATE_REQ_ATTRS).split(",").map(_.trim)
+      () => new DeltaBasedColumnUpdateOperationWithReqAttrs(info.command, reqCols)
+    } else if (properties.getOrDefault(COLUMN_UPDATE_FROM_INFO, "false") == "true") {
+      () => new DeltaBasedColumnUpdateOperationFromInfo(info.command, info.updatedColumns().toSeq)
+    } else if (properties.getOrDefault(COLUMN_UPDATE_COW, "false") == "true") {
+      () => new PartitionBasedColumnUpdateOperation(info.command, info.updatedColumns().toSeq)
+    } else if (properties.getOrDefault(COLUMN_UPDATE_SPLIT, "false") == "true") {
+      () => new DeltaBasedColumnUpdateSplitOperation(info.command, info.updatedColumns().toSeq)
+    } else if (properties.getOrDefault(SUPPORTS_DELTAS, "false") == "true") {
       () => DeltaBasedOperation(info.command)
     } else {
       () => PartitionBasedOperation(info.command)
@@ -208,6 +239,328 @@ class InMemoryRowLevelOperationTable private (
     }
   }

+  // A delta-based operation that supports column-level updates: Spark sends only the
+  // assigned/changed columns in the row projection instead of the full row schema.
+  class DeltaBasedColumnUpdateOperation(command: Command)
+    extends DeltaBasedOperation(command) {
+    override def representUpdateAsDeleteAndInsert(): Boolean = false
+    override def supportsColumnUpdates(): Boolean = true
+
+    // Override newScanBuilder to record the schema that Spark actually requests from the
+    // connector after column pruning, so tests can assert on scan narrowing.
+    override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+      new InMemoryScanBuilder(schema, options) {
+        override def build(): Scan = {
+          val scan = super.build()
+          lastScanSchema = scan.readSchema()
+          scan
+        }
+      }
+    }
+
+    // Records `info` in lastWriteInfo for test assertions and returns a builder whose
+    // batch write expands the narrow rows back to full table rows on commit.
+    override def newWriteBuilder(info: LogicalWriteInfo): DeltaWriteBuilder = {
+      lastWriteInfo = info
+      // Captured here so the writer factory and the commit use the schema of *this* write,
+      // not whatever the mutable table-level lastWriteInfo holds by the time they run.
+      val writeSchema = info.schema()
+      new DeltaWriteBuilder {
+        override def build(): DeltaWrite =
+          new DeltaWrite with RequiresDistributionAndOrdering {
+
+            override def requiredDistribution(): Distribution = {
+              Distributions.clustered(Array(PARTITION_COLUMN_REF))
+            }
+
+            override def requiredOrdering(): Array[SortOrder] = {
+              Array[SortOrder](
+                LogicalExpressions.sort(
+                  PARTITION_COLUMN_REF,
+                  SortDirection.ASCENDING,
+                  SortDirection.ASCENDING.defaultNullOrdering())
+              )
+            }
+
+            override def toBatch: DeltaBatchWrite =
+              new TestBatchWrite with DeltaBatchWrite {
+                override def createBatchWriterFactory(
+                    physicalInfo: PhysicalWriteInfo): DeltaWriterFactory = {
+                  new DeltaBufferedRowsWriterFactory(writeSchema)
+                }
+
+                // For column-update writes, rows contain only the assigned columns
+                // (narrow schema from LogicalWriteInfo). We expand each row to the full table
+                // schema by overlaying write-schema columns on the base row found by pk.
+                override protected def doCommit(messages: Array[WriterCommitMessage]): Unit =
+                  dataMap.synchronized {
+                    val newData = messages.map(_.asInstanceOf[BufferedRows])
+                    val writeFieldIdx = writeSchema.fieldNames.zipWithIndex.toMap
+                    val updateOpName = UTF8String.fromString(Update.toString)
+                    // Both are lazy so nothing is resolved unless an UPDATE row is seen.
+                    // The index keeps rows in iteration order per pk, so `.head` is the same
+                    // row a linear first-match search over the table would return.
+                    lazy val pkOrdinal = schema.fieldIndex("pk")
+                    lazy val baseRowsByPk = dataMap.values.iterator.flatten
+                      .flatMap(_.rows)
+                      .toIndexedSeq
+                      .groupBy(_.getInt(pkOrdinal))
+
+                    val mergedData = newData.map { buf =>
+                      val merged = new BufferedRows(buf.key, schema)
+                      buf.log.foreach { logRow =>
+                        // log layout: (operation, id, metadata, row)
+                        if (logRow.getUTF8String(0) == updateOpName) {
+                          val pk = logRow.getInt(1)
+                          val narrowRow = logRow.get(3, writeSchema).asInstanceOf[InternalRow]
+                          val baseRow = baseRowsByPk.get(pk).map(_.head)
+                          val fullRow = new GenericInternalRow(schema.length)
+                          // start from the existing row; if none is found the unassigned
+                          // columns stay null
+                          baseRow.foreach { base =>
+                            for (i <- schema.fields.indices) {
+                              fullRow.update(i, base.get(i, schema(i).dataType))
+                            }
+                          }
+                          // then overlay every column present in the narrow write row
+                          schema.fields.zipWithIndex.foreach { case (field, i) =>
+                            writeFieldIdx.get(field.name).foreach { j =>
+                              fullRow.update(i, narrowRow.get(j, field.dataType))
+                            }
+                          }
+                          merged.rows.append(fullRow)
+                        }
+                      }
+                      merged
+                    }
+
+                    withDeletes(newData)
+                    withData(mergedData)
+                    lastWriteLog = newData.flatMap(buffer => buffer.log).toIndexedSeq
+                  }
+
+                override def abort(messages: Array[WriterCommitMessage]): Unit = {}
+              }
+          }
+      }
+    }
+  }
+
+  // A variant of DeltaBasedColumnUpdateOperation that overrides requiredDataAttributes()
+  // to declare a fixed set of data columns the connector needs in the scan. This exercises
+  // the connector-driven scan-narrowing path (as opposed to the heuristic path).
+  class DeltaBasedColumnUpdateOperationWithReqAttrs(command: Command, reqCols: Array[String])
+    extends DeltaBasedColumnUpdateOperation(command) {
+    // one reference per configured column name, in the configured order
+    override def requiredDataAttributes(): Array[NamedReference] =
+      reqCols.map(colName => FieldReference(colName))
+  }
+
+  // Column-update connector whose requiredDataAttributes() is computed from the updated
+  // columns Spark reported through RowLevelOperationInfo.updatedColumns().
+  //
+  // Pattern being modelled:
+  //   1. Spark reports the columns being updated via updatedColumns().
+  //   2. The connector adds the columns it always needs (here: "pk" for row lookup).
+  //   3. requiredDataAttributes() returns the combined list so Spark narrows the scan.
+  //
+  // When "pk" is itself one of the updated columns it is not added a second time.
+  class DeltaBasedColumnUpdateOperationFromInfo(
+      command: Command,
+      updatedCols: Seq[NamedReference])
+    extends DeltaBasedColumnUpdateOperation(command) {
+
+    private val PK_REF: NamedReference = FieldReference("pk")
+
+    override def requiredDataAttributes(): Array[NamedReference] = {
+      val pkAlreadyUpdated = updatedCols.exists(_.describe() == "pk")
+      val required = if (pkAlreadyUpdated) updatedCols else PK_REF +: updatedCols
+      required.toArray
+    }
+  }
+
+  // A delta-based operation that combines representUpdateAsDeleteAndInsert=true with
+  // supportsColumnUpdates()=true. This verifies that the restriction which previously
+  // blocked column-level updates on the delete+reinsert path has been lifted.
+  //
+  // The connector declares "pk" plus any columns being updated (via updatedCols).
+  // The write schema = requiredDataAttributes() in declared order.
+  // The REINSERT leg receives the narrow write row; the DELETE leg uses row ID only.
+  class DeltaBasedColumnUpdateSplitOperation(
+      command: Command,
+      updatedCols: Seq[NamedReference] = Nil)
+    extends DeltaBasedColumnUpdateOperation(command) {
+    override def representUpdateAsDeleteAndInsert(): Boolean = true
+
+    private val PK_REF: NamedReference = FieldReference("pk")
+    // "pk" followed by the updated columns, or just the updated columns when they already
+    // contain "pk"
+    override def requiredDataAttributes(): Array[NamedReference] = {
+      val updatedNames = updatedCols.map(_.describe()).toSet
+      if (updatedNames.contains("pk")) updatedCols.toArray
+      else (Array(PK_REF) ++ updatedCols).toArray
+    }
+
+    // Records `info` in lastWriteInfo for test assertions and returns a builder whose
+    // batch write expands the narrow REINSERT rows back to full table rows on commit.
+    override def newWriteBuilder(info: LogicalWriteInfo): DeltaWriteBuilder = {
+      lastWriteInfo = info
+      // Captured here so the writer factory and the commit use the schema of *this* write,
+      // not whatever the mutable table-level lastWriteInfo holds by the time they run.
+      val writeSchema = info.schema()
+      new DeltaWriteBuilder {
+        override def build(): DeltaWrite =
+          new DeltaWrite with RequiresDistributionAndOrdering {
+            override def requiredDistribution(): Distribution =
+              Distributions.clustered(Array(PARTITION_COLUMN_REF))
+            override def requiredOrdering(): Array[SortOrder] = Array[SortOrder](
+              LogicalExpressions.sort(
+                PARTITION_COLUMN_REF,
+                SortDirection.ASCENDING,
+                SortDirection.ASCENDING.defaultNullOrdering()))
+            override def toBatch: DeltaBatchWrite =
+              new TestBatchWrite with DeltaBatchWrite {
+                override def createBatchWriterFactory(
+                    physicalInfo: PhysicalWriteInfo): DeltaWriterFactory =
+                  new DeltaBufferedRowsWriterFactory(writeSchema)
+
+                // For delete+reinsert with narrow writes, the REINSERT row has only the
+                // connector-declared columns (requiredDataAttributes order).
+                // pk is located by name in the write schema; its position depends on
+                // whether pk is itself one of the updated columns.
+                // Reconstruct the full row by overlaying the narrow row onto the original.
+                override protected def doCommit(messages: Array[WriterCommitMessage]): Unit =
+                  dataMap.synchronized {
+                    val newData = messages.map(_.asInstanceOf[BufferedRows])
+                    val writeFieldIdx = writeSchema.fieldNames.zipWithIndex.toMap
+                    val reinsertOpName = UTF8String.fromString(Reinsert.toString)
+                    val pkIdx = writeFieldIdx("pk")
+                    // Both are lazy so nothing is resolved unless a REINSERT row is seen.
+                    // The index keeps rows in iteration order per pk, so `.head` is the same
+                    // row a linear first-match search over the table would return.
+                    lazy val pkOrdinal = schema.fieldIndex("pk")
+                    lazy val baseRowsByPk = dataMap.values.iterator.flatten
+                      .flatMap(_.rows)
+                      .toIndexedSeq
+                      .groupBy(_.getInt(pkOrdinal))
+
+                    val expandedData = newData.map { buf =>
+                      val expanded = new BufferedRows(buf.key, schema)
+                      buf.log.foreach { logRow =>
+                        // log layout: (operation, id, metadata, row)
+                        if (logRow.getUTF8String(0) == reinsertOpName) {
+                          val narrowRow = logRow.get(3, writeSchema).asInstanceOf[InternalRow]
+                          val pk = narrowRow.getInt(pkIdx)
+                          val baseRow = baseRowsByPk.get(pk).map(_.head)
+                          val fullRow = new GenericInternalRow(schema.length)
+                          // start from the existing row; if none is found the undeclared
+                          // columns stay null
+                          baseRow.foreach { base =>
+                            for (i <- schema.fields.indices) {
+                              fullRow.update(i, base.get(i, schema(i).dataType))
+                            }
+                          }
+                          // then overlay every column present in the narrow write row
+                          schema.fields.zipWithIndex.foreach { case (field, i) =>
+                            writeFieldIdx.get(field.name).foreach { j =>
+                              fullRow.update(i, narrowRow.get(j, field.dataType))
+                            }
+                          }
+                          expanded.rows.append(fullRow)
+                        }
+                      }
+                      expanded
+                    }
+
+                    withDeletes(newData)
+                    withData(expandedData)
+                    lastWriteLog = newData.flatMap(buffer => buffer.log).toIndexedSeq
+                  }
+
+                override def abort(messages: Array[WriterCommitMessage]): Unit = {}
+              }
+          }
+      }
+    }
+  }
+
+  // A CoW operation that supports column-level updates. The connector declares it needs
+  // "pk" and "dep" for partition routing, plus any columns the user is updating (via
+  // updatedCols from RowLevelOperationInfo). supportsColumnUpdates()=true so Spark narrows
+  // the scan and write schema to exactly requiredDataAttributes().
+  // The commit logic reconstructs full rows from the original scan data using pk as a key.
+  // `command` is a val parameter so that it implements RowLevelOperation.command() itself,
+  // instead of a same-named method whose body refers back to the name `command`.
+  class PartitionBasedColumnUpdateOperation(
+      val command: Command,
+      updatedCols: Seq[NamedReference] = Nil) extends RowLevelOperation {
+    // Assigned when the scan builder's build() runs; toBatch hands the current value to the
+    // batch write.
+    var configuredScan: InMemoryBatchScan = _
+
+    override def supportsColumnUpdates(): Boolean = true
+
+    override def requiredDataAttributes(): Array[NamedReference] = {
+      // Always need pk (for row lookup) and dep (partition key).
+      // Also include any columns being updated so Spark sends their new values.
+      val base = Seq(FieldReference("pk"), FieldReference("dep"))
+      val baseNames = base.map(_.describe()).toSet
+      (base ++ updatedCols.filterNot(r => baseNames.contains(r.describe()))).toArray
+    }
+
+    override def requiredMetadataAttributes(): Array[NamedReference] =
+      Array(PARTITION_COLUMN_REF, INDEX_COLUMN_REF)
+
+    // Remembers the built scan and its read schema so the write side and the tests can
+    // inspect them.
+    override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+      new InMemoryScanBuilder(schema, options) {
+        override def build(): Scan = {
+          val scan = super.build()
+          configuredScan = scan.asInstanceOf[InMemoryBatchScan]
+          lastScanSchema = scan.readSchema()
+          scan
+        }
+      }
+    }
+
+    // Records `info` in lastWriteInfo and returns a builder that writes through
+    // PartitionBasedNarrowReplaceData using this write's own schema.
+    override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+      lastWriteInfo = info
+      new WriteBuilder {
+        override def build(): Write = new Write with RequiresDistributionAndOrdering {
+          override def requiredDistribution: Distribution =
+            Distributions.clustered(Array(PARTITION_COLUMN_REF))
+
+          override def requiredOrdering: Array[SortOrder] = Array[SortOrder](
+            LogicalExpressions.sort(
+              PARTITION_COLUMN_REF,
+              SortDirection.ASCENDING,
+              SortDirection.ASCENDING.defaultNullOrdering()))
+
+          override def toBatch: BatchWrite =
+            PartitionBasedNarrowReplaceData(configuredScan, info.schema())
+
+          override def description: String = "InMemoryNarrowCoWWrite"
+        }
+      }
+    }
+
+    override def description(): String = "InMemoryPartitionColumnUpdateOperation"
+  }
+
+  // CoW write handler for narrow column-update writes.
+  // Receives rows with only the connector-declared and assigned columns.
+  // Reconstructs full rows by looking up the original row by pk and overlaying received columns.
+  private case class PartitionBasedNarrowReplaceData(
+      scan: InMemoryBatchScan,
+      writeSchema: StructType) extends TestBatchWrite {
+
+    // Drops every partition the scan read, then writes back the received rows expanded to
+    // the full table schema.
+    override protected def doCommit(
+        messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
+      val newData = messages.map(_.asInstanceOf[BufferedRows])
+      val readRows = scan.data.flatMap(_.asInstanceOf[BufferedRows].rows)
+      val readPartitions = readRows.map(r => getKey(r, schema)).distinct
+      dataMap --= readPartitions
+      replacedPartitions = readPartitions
+
+      val writeFieldIdx = writeSchema.fieldNames.zipWithIndex.toMap
+      val pkIdxInWrite = writeFieldIdx("pk")
+      val pkIdxInFull = schema.fieldIndex("pk")
+      // Every written row needs its original, so index the read rows once instead of
+      // searching them per row. groupBy keeps encounter order within a key, so `.head` is
+      // the row a first-match search would return. Lazy: untouched when nothing is written.
+      lazy val readRowsByPk = readRows.groupBy(_.getInt(pkIdxInFull))
+
+      val expandedData = newData.map { buf =>
+        val expanded = new BufferedRows(buf.key, schema)
+        buf.rows.foreach { narrowRow =>
+          val pk = narrowRow.getInt(pkIdxInWrite)
+          val origRow = readRowsByPk.get(pk).map(_.head)
+          val fullRow = new GenericInternalRow(schema.length)
+          // start from the original row; if none is found the undeclared columns stay null
+          origRow.foreach { base =>
+            for (i <- schema.fields.indices) {
+              fullRow.update(i, base.get(i, schema(i).dataType))
+            }
+          }
+          // then overlay every column present in the narrow write row
+          schema.fields.zipWithIndex.foreach { case (field, i) =>
+            writeFieldIdx.get(field.name).foreach { j =>
+              fullRow.update(i, narrowRow.get(j, field.dataType))
+            }
+          }
+          expanded.rows.append(fullRow)
+        }
+        expanded
+      }
+
+      withData(expandedData, schema)
+      lastWriteLog = newData.flatMap(buffer => buffer.log).toImmutableArraySeq
+    }
+  }
+
   private object TestDeltaBatchWrite extends TestBatchWrite with DeltaBatchWrite {
     override def createBatchWriterFactory(info: PhysicalWriteInfo): DeltaWriterFactory = {
       new DeltaBufferedRowsWriterFactory(CatalogV2Util.v2ColumnsToStructType(columns()))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala index 9f8409efa360e..e00a1b7d9b219 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.dynamicpruning -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, DynamicPruningExpression, Expression, InSubquery, ListQuery, PredicateHelper, V2ExpressionUtils} import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.optimizer.RewritePredicateSubquery @@ -139,17 +138,13 @@ class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPla tableAttrs: Seq[Attribute], scanAttrs: Seq[Attribute]): AttributeMap[Attribute] = { - val attrMapping = tableAttrs.map { tableAttr => + // For column-level updates, the scan may be narrowed to exclude columns that the + // connector does not need. Skip table attributes that are absent from the scan + // instead of throwing -- they cannot appear in the condition if they were pruned. 
+ val attrMapping = tableAttrs.flatMap { tableAttr => scanAttrs .find(scanAttr => conf.resolver(scanAttr.name, tableAttr.name)) .map(scanAttr => tableAttr -> scanAttr) - .getOrElse { - throw new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_3075", - messageParameters = Map( - "tableAttr" -> tableAttr.toString, - "scanAttrs" -> scanAttrs.mkString(","))) - } } AttributeMap(attrMapping) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedColumnUpdateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedColumnUpdateTableSuite.scala new file mode 100644 index 0000000000000..b22f2743a87ee --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedColumnUpdateTableSuite.scala @@ -0,0 +1,677 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector + +import org.apache.spark.sql.Row +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, TableInfo} +import org.apache.spark.sql.connector.expressions.LogicalExpressions.{identity, reference} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} + +/** + * Tests for UPDATE statements targeting connectors that return true from + * [[org.apache.spark.sql.connector.write.RowLevelOperation#supportsColumnUpdates]]. + * + * When a connector supports column updates, Spark narrows the row projection + * (LogicalWriteInfo.schema()) to contain only the assigned/changed columns rather than + * the full table row. + */ +class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { + + override protected lazy val extraTableProps: java.util.Map[String, String] = { + val props = new java.util.HashMap[String, String]() + props.put("column-update", "true") + props + } + + // --- Schema narrowing: verify LogicalWriteInfo.schema() is narrow --- + + test("column-update: rowSchema contains only the single assigned column") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |{ "pk": 3, "id": 3, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = -1 WHERE pk = 1") + + // Only the assigned column (id) should appear in the row schema -- not pk or dep + checkLastWriteInfo( + expectedRowSchema = StructType(Seq( + StructField("id", IntegerType, nullable = false) + )), + expectedRowIdSchema = Some(StructType(Array(PK_FIELD))), + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + } + + test("column-update: rowSchema contains multiple assigned columns") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, 
"dep": "software" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = -1, dep = 'engineering' WHERE pk = 1") + + // Both assigned columns (id, dep) should appear -- but NOT pk (unassigned) + checkLastWriteInfo( + expectedRowSchema = StructType(Seq( + StructField("id", IntegerType, nullable = false), + StructField("dep", StringType, nullable = false) + )), + expectedRowIdSchema = Some(StructType(Array(PK_FIELD))), + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + } + + test("column-update: rowSchema is empty for a full identity update") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = id, dep = dep WHERE pk = 1") + + checkLastWriteInfo( + expectedRowSchema = new StructType(), + expectedRowIdSchema = Some(StructType(Array(PK_FIELD))), + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + } + + test("column-update: row filter condition is orthogonal to column narrowing") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |{ "pk": 3, "id": 3, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET dep = 'engineering' WHERE pk IN (1, 3)") + + checkLastWriteInfo( + expectedRowSchema = StructType(Seq( + StructField("dep", StringType, nullable = false) + )), + expectedRowIdSchema = Some(StructType(Array(PK_FIELD))), + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + } + + test("column-update: update all rows (no WHERE clause)") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET salary = salary * 2") + + 
checkLastWriteInfo(
+      expectedRowSchema = StructType(Seq(
+        StructField("salary", IntegerType, nullable = true)
+      )),
+      expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
+      expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+  }
+
+  // --- Identity assignment filtering ---
+
+  test("column-update: rowSchema excludes identity assignments in a mixed UPDATE") {
+    createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+      """{ "pk": 1, "id": 1, "dep": "hr" }
+        |{ "pk": 2, "id": 2, "dep": "software" }
+        |""".stripMargin)
+
+    // id = id is identity -- should be excluded from rowSchema
+    // dep = 'engineering' is a real assignment -- should be included
+    sql(s"UPDATE $tableNameAsString SET id = id, dep = 'engineering' WHERE pk = 1")
+
+    checkLastWriteInfo(
+      expectedRowSchema = StructType(Seq(
+        StructField("dep", StringType, nullable = false)
+      )),
+      expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
+      expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+  }
+
+  // Mirror of the test above with the SET order swapped and the other column as the
+  // identity one. The statement contains no cross-column assignment, so the name
+  // describes what is actually exercised.
+  test("column-update: rowSchema excludes an identity assignment listed before a real one") {
+    createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+      """{ "pk": 1, "id": 1, "dep": "hr" }
+        |{ "pk": 2, "id": 2, "dep": "software" }
+        |""".stripMargin)
+
+    // dep = dep is identity; id = -1 is a real assignment -- only id should appear
+    sql(s"UPDATE $tableNameAsString SET dep = dep, id = -1 WHERE pk = 1")
+
+    checkLastWriteInfo(
+      expectedRowSchema = StructType(Seq(
+        StructField("id", IntegerType, nullable = false)
+      )),
+      expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
+      expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+  }
+
+  // --- updatedColumns in RowLevelOperationInfo ---
+
+  test("column-update: updatedColumns contains non-identity assigned columns") {
+    createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
+      """{ "pk": 1, "id": 1, "dep": "hr" }
+        |""".stripMargin)
+
+    sql(s"UPDATE 
$tableNameAsString SET id = -1, dep = 'eng' WHERE pk = 1") + + val updatedNames = table.lastUpdatedColumns.map(_.describe()).toSet + assert(updatedNames == Set("id", "dep"), + s"expected [id, dep] in updatedColumns but got: $updatedNames") + } + + test("column-update: updatedColumns excludes identity assignments") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |""".stripMargin) + + // dep = dep is identity; only id should appear in updatedColumns + sql(s"UPDATE $tableNameAsString SET id = -1, dep = dep WHERE pk = 1") + + val updatedNames = table.lastUpdatedColumns.map(_.describe()).toSet + assert(updatedNames == Set("id"), + s"expected only [id] in updatedColumns but got: $updatedNames") + } + + test("column-update: updatedColumns is empty for a full identity update") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = id, dep = dep WHERE pk = 1") + + assert(table.lastUpdatedColumns.isEmpty, + s"expected empty updatedColumns but got: ${table.lastUpdatedColumns.mkString(", ")}") + } + + test("column-update: updatedColumns is empty for DELETE (Javadoc contract)") { + // DELETE never has updated columns -- verify that the default empty array is passed + // through RowLevelOperationInfo even when a column-update connector handles the DELETE. + // Use a partition-column condition so the InMemory table can process the filter. 
+ createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |""".stripMargin) + + sql(s"DELETE FROM $tableNameAsString WHERE dep = 'hr'") + + assert(table.lastUpdatedColumns.isEmpty, + s"DELETE must pass empty updatedColumns but got: ${table.lastUpdatedColumns.mkString(", ")}") + } + + // --- Data correctness --- + + test("column-update: data correctness -- single column update") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |{ "pk": 3, "id": 3, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = -1 WHERE pk = 1") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString ORDER BY pk"), + Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil) + } + + test("column-update: data correctness -- update all rows") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET salary = salary * 2") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString ORDER BY pk"), + Row(1, 200, "hr") :: Row(2, 400, "software") :: Row(3, 600, "hr") :: Nil) + } + + test("column-update: data correctness -- mixed identity and real assignments") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |{ "pk": 3, "id": 3, "dep": "hr" } + |""".stripMargin) + + // Only dep changes; id stays as-is even though id = id is in the SET list. 
+ sql(s"UPDATE $tableNameAsString SET id = id, dep = 'engineering' WHERE pk = 1") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString ORDER BY pk"), + Row(1, 1, "engineering") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil) + } + + // --- Scan narrowing: verify the connector only receives the columns it needs --- + + test("column-update: scan excludes the assigned column when SET to a literal") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |""".stripMargin) + + // id is the target of a literal assignment -- its current value is not needed. + // pk is needed for the WHERE condition and as rowId. + sql(s"UPDATE $tableNameAsString SET id = -1 WHERE pk = 1") + + val scanSchema = table.lastScanSchema + assert(!scanSchema.fieldNames.contains("id"), s"id should be excluded from scan: $scanSchema") + assert(scanSchema.fieldNames.contains("pk"), s"pk must be in scan: $scanSchema") + } + + test("column-update: scan includes the assigned column when its current value is the RHS") { + createAndInitTable("pk INT NOT NULL, salary INT, bonus INT, dep STRING", + """{ "pk": 1, "salary": 100, "bonus": 10, "dep": "hr" } + |{ "pk": 2, "salary": 200, "bonus": 20, "dep": "software" } + |""".stripMargin) + + // salary appears on the RHS (salary * 2) so it must be scanned. + // bonus is not referenced anywhere -- excluded. 
+ sql(s"UPDATE $tableNameAsString SET salary = salary * 2") + + val scanSchema = table.lastScanSchema + assert(scanSchema.fieldNames.contains("salary"), s"salary must be in scan: $scanSchema") + assert(!scanSchema.fieldNames.contains("bonus"), s"bonus should be excluded: $scanSchema") + } + + test("column-update: scan excludes non-referenced columns for literal assignment") { + createAndInitTable("pk INT NOT NULL, id INT, salary INT, dep STRING", + """{ "pk": 1, "id": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "id": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + // dep is a literal assignment; id and salary are not referenced -- only pk needed. + sql(s"UPDATE $tableNameAsString SET dep = 'engineering' WHERE pk = 1") + + val scanSchema = table.lastScanSchema + assert(scanSchema.fieldNames.contains("pk"), s"pk must be in scan: $scanSchema") + assert(!scanSchema.fieldNames.contains("id"), s"id should be excluded: $scanSchema") + assert(!scanSchema.fieldNames.contains("salary"), s"salary should be excluded: $scanSchema") + } + + test("column-update: scan includes condition columns even when not assigned") { + createAndInitTable("pk INT NOT NULL, salary INT, bonus INT, dep STRING", + """{ "pk": 1, "salary": 100, "bonus": 10, "dep": "hr" } + |{ "pk": 2, "salary": 200, "bonus": 20, "dep": "software" } + |""".stripMargin) + + // dep appears in the WHERE clause -- must be scanned even though it is not assigned. + // bonus is neither assigned nor in the condition -- excluded. + // salary is set to a literal -- current value not needed. 
+ sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") + + val scanSchema = table.lastScanSchema + assert(scanSchema.fieldNames.contains("dep"), + s"dep must be in scan (WHERE clause): $scanSchema") + assert(!scanSchema.fieldNames.contains("bonus"), s"bonus should be excluded: $scanSchema") + assert(!scanSchema.fieldNames.contains("salary"), + s"salary should be excluded (literal assignment): $scanSchema") + } + + // --------------------------------------------------------------------------- + // Connector-driven scan narrowing via requiredDataAttributes() + // --------------------------------------------------------------------------- + + // Creates a table backed by DeltaBasedColumnUpdateOperationWithReqAttrs, which overrides + // requiredDataAttributes() to return the given comma-separated column names. + private def createAndInitTableWithReqAttrs( + reqAttrs: String, + schemaString: String, + jsonData: String): Unit = { + val props = new java.util.HashMap[String, String]() + props.put("column-update-req-attrs", reqAttrs) + val columns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL(schemaString)) + val transforms = Array[Transform](identity(reference(Seq("dep")))) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withPartitions(transforms) + .withProperties(props) + .build() + catalog.createTable(ident, tableInfo) + append(schemaString, jsonData) + } + + test("column-update: requiredDataAttributes forces connector-declared column into scan") { + // Connector declares it always needs "dep". + // SQL assigns "id" (literal) with condition on "pk". + // Connector-driven scan = {pk, dep} (dep from connector declaration; pk from condition). + // id is NOT in scan: literal assignment + not declared by connector. 
+ createAndInitTableWithReqAttrs("dep", "pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = -1 WHERE pk = 1") + + val scanSchema = table.lastScanSchema + assert(scanSchema.fieldNames.contains("dep"), + s"dep must be in scan (connector required): $scanSchema") + assert(scanSchema.fieldNames.contains("pk"), s"pk must be in scan: $scanSchema") + assert(!scanSchema.fieldNames.contains("id"), + s"id should be excluded (literal assignment, not declared): $scanSchema") + } + + test("column-update: requiredDataAttributes - data correctness") { + // Connector declares "dep,id" so it receives both the new id value and dep for routing. + // The write schema is exactly requiredDataAttributes = {dep, id} (declared order). + createAndInitTableWithReqAttrs("dep,id", "pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |{ "pk": 3, "id": 3, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = -1 WHERE pk = 1") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString ORDER BY pk"), + Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil) + } + + test("column-update: empty requiredDataAttributes falls back to heuristic") { + // "column-update" uses DeltaBasedColumnUpdateOperation whose requiredDataAttributes() + // returns the default empty array. + // With the optimizer-driven approach for MOR, the scan is narrowed by V2ScanRelationPushDown + // which observes what columns the write plan actually references. + // SET id = -1 (literal assignment): id is not referenced from the scan, so it is pruned. + // dep is the partitioning column; since it is not declared in requiredDataAttributes() + // and is not referenced by the WHERE condition (pk = 1), it may be pruned from the scan. 
+ createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = -1 WHERE pk = 1") + + val scanSchema = table.lastScanSchema + assert(!scanSchema.fieldNames.contains("id"), + s"id must NOT be in scan (literal assignment, no scan reference): $scanSchema") + assert(scanSchema.fieldNames.contains("pk"), s"pk must be in scan (condition): $scanSchema") + } + + // --------------------------------------------------------------------------- + // Connector uses RowLevelOperationInfo.updatedColumns() to derive its own + // requiredDataAttributes() dynamically. + // DeltaBasedColumnUpdateOperationFromInfo always adds "pk" (for row lookup) to + // whatever Spark reports as updated columns. + // --------------------------------------------------------------------------- + + private def createAndInitTableFromInfo(schemaString: String, jsonData: String): Unit = { + val props = new java.util.HashMap[String, String]() + props.put("column-update-from-info", "true") + val columns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL(schemaString)) + val transforms = Array[Transform](identity(reference(Seq("dep")))) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withPartitions(transforms) + .withProperties(props) + .build() + catalog.createTable(ident, tableInfo) + append(schemaString, jsonData) + } + + test("column-update from-info: connector adds pk to updatedColumns for requiredDataAttributes") { + // Connector receives updatedColumns=[salary], adds pk for row lookup. + // requiredDataAttributes() = [pk, salary]. + // + // salary = -1 is a LITERAL assignment: the write plan references Literal(-1) not the + // scan's salary column. Since salary is in assignedAttrs, it is not a connectorExtraAttr + // pass-through either. V2ScanRelationPushDown therefore does not see salary referenced + // and prunes it from the scan. 
+ // + // The scan contains: pk (connector pass-through), dep (partitioning + WHERE condition). + // The scan excludes: salary (literal assignment), id and bonus (not declared, not in cond). + createAndInitTableFromInfo("pk INT NOT NULL, salary INT, id INT, bonus INT, dep STRING", + """{ "pk": 1, "salary": 100, "id": 10, "bonus": 5, "dep": "hr" } + |{ "pk": 2, "salary": 200, "id": 20, "bonus": 6, "dep": "software" } + |{ "pk": 3, "salary": 300, "id": 30, "bonus": 7, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") + + val scanSchema = table.lastScanSchema + assert(scanSchema.fieldNames.contains("pk"), + s"pk must be in scan (connector pass-through via connectorExtraAttrs): $scanSchema") + assert(scanSchema.fieldNames.contains("dep"), + s"dep must be in scan (partitioning + WHERE): $scanSchema") + assert(!scanSchema.fieldNames.contains("id"), + s"id must be excluded (not declared, not assigned, not in condition): $scanSchema") + assert(!scanSchema.fieldNames.contains("bonus"), + s"bonus must be excluded (not declared, not assigned, not in condition): $scanSchema") + } + + test("column-update from-info: write schema is updatedColumns + pk pass-through") { + // requiredDataAttributes = [pk, salary] (pk always added; salary because it's assigned). + // Write schema = requiredDataAttributes in declared order = {pk, salary}. 
+ createAndInitTableFromInfo("pk INT NOT NULL, salary INT, id INT, dep STRING", + """{ "pk": 1, "salary": 100, "id": 10, "dep": "hr" } + |{ "pk": 2, "salary": 200, "id": 20, "dep": "software" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE pk = 1") + + val writeSchema = table.lastWriteInfo.schema() + assert(writeSchema.fieldNames.contains("salary"), + s"salary must be in write schema (assigned): $writeSchema") + assert(writeSchema.fieldNames.contains("pk"), + s"pk must be in write schema " + + s"(connector pass-through via requiredDataAttributes): $writeSchema") + assert(!writeSchema.fieldNames.contains("id"), + s"id must not be in write schema: $writeSchema") + assert(!writeSchema.fieldNames.contains("dep"), + s"dep must not be in write schema (partitioning, not a data column to write): $writeSchema") + } + + test("column-update from-info: pk already in updatedColumns is not duplicated") { + // When the user updates pk itself, updatedColumns=[pk, salary]. + // Connector sees pk already present -> requiredDataAttributes=[pk, salary] (no dup). 
+ createAndInitTableFromInfo("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET pk = pk + 10, salary = -1 WHERE dep = 'hr'") + + val writeSchema = table.lastWriteInfo.schema() + val pkCount = writeSchema.fieldNames.count(_ == "pk") + assert(pkCount == 1, s"pk must appear exactly once in write schema: $writeSchema") + } + + test("column-update from-info: data correctness") { + createAndInitTableFromInfo("pk INT NOT NULL, salary INT, id INT, dep STRING", + """{ "pk": 1, "salary": 100, "id": 10, "dep": "hr" } + |{ "pk": 2, "salary": 200, "id": 20, "dep": "software" } + |{ "pk": 3, "salary": 300, "id": 30, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") + + // salary updated for hr rows; id preserved (not in write schema, connector uses pk lookup) + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString ORDER BY pk"), + Row(1, -1, 10, "hr") :: + Row(2, 200, 20, "software") :: + Row(3, -1, 30, "hr") :: Nil) + } + + // --------------------------------------------------------------------------- + // CoW connector with supportsColumnUpdates() on RowLevelOperation. + // PartitionBasedColumnUpdateOperation declares requiredDataAttributes() = [pk, dep] and + // supportsColumnUpdates() = true. Spark narrows the scan to connector-declared + assigned + // columns; bonus is excluded. The connector reconstructs full rows via pk lookup. 
+ // --------------------------------------------------------------------------- + + private def createAndInitTableCoW(schemaString: String, jsonData: String): Unit = { + val props = new java.util.HashMap[String, String]() + props.put("column-update-cow", "true") + val columns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL(schemaString)) + val transforms = Array[Transform](identity(reference(Seq("dep")))) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withPartitions(transforms) + .withProperties(props) + .build() + catalog.createTable(ident, tableInfo) + append(schemaString, jsonData) + } + + test("column-update CoW: scan excludes columns not declared and not assigned") { + // Connector declares [pk, dep]. SET salary = -1. + // Narrow scan = pk (declared) + dep (declared + condition + partitioning) + // + salary (assigned LHS). bonus is neither declared nor assigned -> excluded. + createAndInitTableCoW("pk INT NOT NULL, salary INT, bonus INT, dep STRING", + """{ "pk": 1, "salary": 100, "bonus": 10, "dep": "hr" } + |{ "pk": 2, "salary": 200, "bonus": 20, "dep": "software" } + |{ "pk": 3, "salary": 300, "bonus": 30, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") + + val scanSchema = table.lastScanSchema + assert(scanSchema.fieldNames.contains("pk"), s"pk must be in scan: $scanSchema") + assert(scanSchema.fieldNames.contains("dep"), s"dep must be in scan: $scanSchema") + assert(scanSchema.fieldNames.contains("salary"), + s"salary must be in scan (assigned LHS): $scanSchema") + assert(!scanSchema.fieldNames.contains("bonus"), s"bonus must be excluded: $scanSchema") + } + + test("column-update CoW: write schema contains only declared + assigned columns") { + createAndInitTableCoW("pk INT NOT NULL, salary INT, bonus INT, dep STRING", + """{ "pk": 1, "salary": 100, "bonus": 10, "dep": "hr" } + |{ "pk": 2, "salary": 200, "bonus": 20, "dep": "software" } + |""".stripMargin) + + 
sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") + + val writeSchema = table.lastWriteInfo.schema() + assert(writeSchema.fieldNames.contains("pk"), s"pk must be in write schema: $writeSchema") + assert(writeSchema.fieldNames.contains("dep"), s"dep must be in write schema: $writeSchema") + assert(writeSchema.fieldNames.contains("salary"), + s"salary must be in write schema: $writeSchema") + assert(!writeSchema.fieldNames.contains("bonus"), + s"bonus must not be in write schema: $writeSchema") + } + + test("column-update CoW: data correctness -- bonus preserved, salary updated") { + // bonus is not in the write schema; the connector must preserve it from the original row. + createAndInitTableCoW("pk INT NOT NULL, salary INT, bonus INT, dep STRING", + """{ "pk": 1, "salary": 100, "bonus": 10, "dep": "hr" } + |{ "pk": 2, "salary": 200, "bonus": 20, "dep": "software" } + |{ "pk": 3, "salary": 300, "bonus": 30, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString ORDER BY pk"), + Row(1, -1, 10, "hr") :: + Row(2, 200, 20, "software") :: + Row(3, -1, 30, "hr") :: Nil) + } + + test("column-update CoW: narrow scan + subquery WHERE condition") { + // Exercises buildReplaceDataWithUnionPlan + narrow scan + the flatMap change in + // RowLevelOperationRuntimeGroupFiltering.buildTableToScanAttrMap. + // The subquery forces the UNION path (updated rows + remaining rows). + // bonus is not declared and not assigned, must be excluded from scan and write + // but the subquery-based filter must still work correctly with the narrow scan. 
+ createAndInitTableCoW("pk INT NOT NULL, salary INT, bonus INT, dep STRING", + """{ "pk": 1, "salary": 100, "bonus": 10, "dep": "hr" } + |{ "pk": 2, "salary": 200, "bonus": 20, "dep": "software" } + |{ "pk": 3, "salary": 300, "bonus": 30, "dep": "hr" } + |""".stripMargin) + + import testImplicits._ + val subqueryDF = Seq("hr").toDF() + subqueryDF.createOrReplaceTempView("target_deps") + + sql( + s"""UPDATE $tableNameAsString + |SET salary = -1 + |WHERE dep IN (SELECT * FROM target_deps) + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString ORDER BY pk"), + Row(1, -1, 10, "hr") :: + Row(2, 200, 20, "software") :: + Row(3, -1, 30, "hr") :: Nil) + } + + // --------------------------------------------------------------------------- + // Delta connector with representUpdateAsDeleteAndInsert=true AND supportsColumnUpdates=true. + // + // Point 7: The restriction that blocked column-level updates on the delete+reinsert path + // has been removed. The REINSERT leg of the Expand uses only assigned values (the narrow + // write schema from effectiveRowAttrs), and the DELETE leg uses row ID only. + // --------------------------------------------------------------------------- + + private def createAndInitTableSplit(schemaString: String, jsonData: String): Unit = { + val props = new java.util.HashMap[String, String]() + props.put("column-update-split", "true") + val columns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL(schemaString)) + val transforms = Array[Transform](identity(reference(Seq("dep")))) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withPartitions(transforms) + .withProperties(props) + .build() + catalog.createTable(ident, tableInfo) + append(schemaString, jsonData) + } + + test("column-update split: write schema is narrow (assigned + pk pass-through)") { + // representUpdateAsDeleteAndInsert=true + supportsColumnUpdates=true. 
+ // requiredDataAttributes() = [pk, id] (pk always declared; id because it's being updated). + // The write schema = requiredDataAttributes() in declared order = {pk, id}. + // dep is NOT in the write schema (not declared, not assigned). + createAndInitTableSplit("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = -1 WHERE pk = 1") + + // Write schema is exactly requiredDataAttributes = {pk, id} in declared order. + checkLastWriteInfo( + expectedRowSchema = StructType(Seq( + StructField("pk", IntegerType, nullable = false), + StructField("id", IntegerType, nullable = false) + )), + expectedRowIdSchema = Some(StructType(Array(PK_FIELD))), + expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) + } + + test("column-update split: data correctness") { + // representUpdateAsDeleteAndInsert=true + supportsColumnUpdates=true. + // The connector receives narrow REINSERT rows and must reconstruct full rows. 
+ createAndInitTableSplit("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |{ "pk": 3, "id": 3, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = -1 WHERE dep = 'hr'") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString ORDER BY pk"), + Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, -1, "hr") :: Nil) + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala index 49e586535a0d0..6c7dfc25d3a9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala @@ -23,6 +23,74 @@ abstract class DeltaBasedUpdateTableSuiteBase extends UpdateTableSuiteBase { override protected def deltaUpdate: Boolean = true + // --------------------------------------------------------------------------- + // RowLevelOperationInfo.updatedColumns() -- Spark informs the connector which + // columns are genuinely being updated (non-identity assignments only). 
+ // --------------------------------------------------------------------------- + + test("updatedColumns: single non-identity assignment") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = -1 WHERE pk = 1") + + checkLastUpdatedColumns("id") + } + + test("updatedColumns: multiple non-identity assignments") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = -1, dep = 'eng' WHERE pk = 1") + + checkLastUpdatedColumns("id", "dep") + } + + test("updatedColumns: identity assignments are excluded") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |""".stripMargin) + + // dep = dep is an identity assignment and must NOT appear in updatedColumns + sql(s"UPDATE $tableNameAsString SET id = -1, dep = dep WHERE pk = 1") + + checkLastUpdatedColumns("id") + } + + test("updatedColumns: empty when all assignments are identity") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET id = id, dep = dep WHERE pk = 1") + + checkLastUpdatedColumns() // expects empty + } + + test("updatedColumns: no WHERE clause still reports assigned columns") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |{ "pk": 2, "id": 2, "dep": "software" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET dep = 'eng'") + + checkLastUpdatedColumns("dep") + } + + test("updatedColumns: cross-column assignment is not treated as identity") { + createAndInitTable("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |""".stripMargin) + + // id = 0 is a literal (not cross-column) update; dep = dep is identity, so only id is reported + sql(s"UPDATE $tableNameAsString 
SET id = 0, dep = dep WHERE pk = 1") + + checkLastUpdatedColumns("id") + } + test("nullable row ID attrs") { createAndInitTable("pk INT, salary INT, dep STRING", """{ "pk": 1, "salary": 300, "dep": 'hr' } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala index 379e7ba755d9d..a677945c48363 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala @@ -245,6 +245,18 @@ abstract class RowLevelOperationSuiteBase assert(actualMetadataSchema == expectedMetadataSchema, "metadata schema must match") } + /** + * Asserts that the column names in RowLevelOperationInfo.updatedColumns() received by the + * last operation match exactly the expected set. Order is ignored. + */ + protected def checkLastUpdatedColumns(expectedNames: String*): Unit = { + val actual = table.lastUpdatedColumns.map(_.describe()).toSet + val expected = expectedNames.toSet + assert(actual == expected, + s"updatedColumns mismatch: expected ${expected.mkString("[", ", ", "]")} " + + s"but got ${actual.mkString("[", ", ", "]")}") + } + protected def checkLastWriteLog(expectedEntries: WriteLogEntry*): Unit = { val entryType = new StructType() .add(StructField("operation", StringType)) From d4babf7a5466db4304a48d5a6c9fd576f969a11c Mon Sep 17 00:00:00 2001 From: Anurag Mantripragada Date: Tue, 5 May 2026 16:02:18 -0700 Subject: [PATCH 2/5] Address review comments and rebase --- .../sql/catalyst/analysis/RewriteRowLevelCommand.scala | 5 ++--- .../connector/catalog/InMemoryRowLevelOperationTable.scala | 6 +++--- .../scala/org/apache/spark/sql/connector/catalog/txns.scala | 3 +++ 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala index 98d73225515a6..b6909a6dfd670 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala @@ -72,9 +72,8 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { cond: Expression = TrueLiteral): DataSourceV2Relation = { if (dataAttrs.nonEmpty) { - val required = - AttributeSet(dataAttrs) ++ AttributeSet(Seq(cond)) ++ AttributeSet(rowIdAttrs) - val narrowOutput = relation.output.filter(required.contains) + val required = (dataAttrs ++ cond.references.toSeq).map(_.exprId).toSet + val narrowOutput = relation.output.filter(a => required.contains(a.exprId)) relation.copy(table = table, output = dedupAttrs(narrowOutput ++ rowIdAttrs ++ metadataAttrs)) } else { val attrs = dedupAttrs(relation.output ++ rowIdAttrs ++ metadataAttrs) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala index 8635ac92173ed..22f28f8390472 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala @@ -307,7 +307,7 @@ class InMemoryRowLevelOperationTable private ( val fullRow = new GenericInternalRow(schema.length) baseRow.foreach { base => for (i <- schema.fields.indices) { - fullRow.update(i, base.get(i, schema(i).dataType)) + fullRow.update(i, base.get(i, schema.fields(i).dataType)) } } schema.fields.zipWithIndex.foreach { case (field, i) => @@ -430,7 +430,7 @@ class InMemoryRowLevelOperationTable private ( val fullRow = new GenericInternalRow(schema.length) baseRow.foreach { base => for (i <- 
schema.fields.indices) { - fullRow.update(i, base.get(i, schema(i).dataType)) + fullRow.update(i, base.get(i, schema.fields(i).dataType)) } } schema.fields.zipWithIndex.foreach { case (field, i) => @@ -543,7 +543,7 @@ class InMemoryRowLevelOperationTable private ( val fullRow = new GenericInternalRow(schema.length) origRow.foreach { base => for (i <- schema.fields.indices) { - fullRow.update(i, base.get(i, schema(i).dataType)) + fullRow.update(i, base.get(i, schema.fields(i).dataType)) } } schema.fields.zipWithIndex.foreach { case (field, i) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 4b9dff5c3d780..c204146cdc388 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -112,6 +112,7 @@ class TxnTable( override def newRowLevelOperationBuilder( info: RowLevelOperationInfo): RowLevelOperationBuilder = { catalog.writeTarget = this + delegate.lastUpdatedColumns = info.updatedColumns() super.newRowLevelOperationBuilder(info) } @@ -128,6 +129,8 @@ class TxnTable( delegate.replacedPartitions = replacedPartitions delegate.lastWriteInfo = lastWriteInfo delegate.lastWriteLog = lastWriteLog + delegate.lastUpdatedColumns = lastUpdatedColumns + delegate.lastScanSchema = lastScanSchema delegate.commits ++= commits delegate.increaseVersion() } From f33797f33115a88d134b72184a5eb626a6ed24a4 Mon Sep 17 00:00:00 2001 From: Anurag Mantripragada Date: Tue, 5 May 2026 17:29:53 -0700 Subject: [PATCH 3/5] Remove narrowing of CAST() --- .../catalyst/analysis/RewriteUpdateTable.scala | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala index b4ef89617c82d..9ea5cfcbf9752 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, EqualNullSafe, Expression, If, Literal, MetadataAttribute, Not, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference, AttributeSet, EqualNullSafe, Expression, If, Literal, MetadataAttribute, Not, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Expand, Filter, LogicalPlan, Project, ReplaceData, Union, UpdateTable, WriteDelta} import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ @@ -328,8 +328,6 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { } // Returns the table attributes that are genuinely updated (non-identity) in this UPDATE. - // Strips Alias/Cast wrappers introduced during assignment alignment before doing the - // AttributeSet membership check (which uses exprId equality internally). 
private def computeAssignedAttrs(assignments: Seq[Assignment]): Seq[AttributeReference] = { assignments.collect { case Assignment(key: AttributeReference, value) if !isIdentityAssignment(key, value) => key @@ -337,19 +335,16 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { } private def isIdentityAssignment(key: Attribute, value: Expression): Boolean = { - stripAliasesAndCasts(value) match { + val unwrapped = value match { + case Alias(child, _) => child + case other => other + } + unwrapped match { case attr: Attribute => AttributeSet(Seq(key)).contains(attr) case _ => false } } - // Recursively strips Alias and Cast wrappers introduced during assignment alignment. - private def stripAliasesAndCasts(expr: Expression): Expression = expr match { - case Alias(child, _) => stripAliasesAndCasts(child) - case Cast(child, _, _, _) => stripAliasesAndCasts(child) - case other => other - } - // this method assumes the assignments have been already aligned before private def buildWriteDeltaUpdateProjection( plan: LogicalPlan, From e3c253e6dafc543f3a6d23750342f5ff7afc82ee Mon Sep 17 00:00:00 2001 From: Anurag Mantripragada Date: Wed, 6 May 2026 18:06:03 -0700 Subject: [PATCH 4/5] Address review comments and clean up --- .../connector/write/RowLevelOperation.java | 9 +- .../write/RowLevelOperationInfo.java | 2 +- .../analysis/RewriteRowLevelCommand.scala | 12 +- .../analysis/RewriteUpdateTable.scala | 118 +++------ .../catalyst/plans/logical/v2Commands.scala | 33 ++- .../InMemoryRowLevelOperationTable.scala | 35 --- ...wLevelOperationRuntimeGroupFiltering.scala | 6 +- .../DeltaBasedColumnUpdateTableSuite.scala | 244 +++--------------- .../DeltaBasedUpdateTableSuiteBase.scala | 8 +- .../GroupBasedColumnUpdateTableSuite.scala | 115 +++++++++ 10 files changed, 229 insertions(+), 353 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedColumnUpdateTableSuite.scala diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java index 8c8affdd7c098..26a21cb50f92d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java @@ -106,7 +106,6 @@ default NamedReference[] requiredMetadataAttributes() { return new NamedReference[0]; } - /** * Controls whether to send only the required data columns to the connector rather than the * full row. @@ -120,8 +119,10 @@ default NamedReference[] requiredMetadataAttributes() { * those columns in declared order. The connector must include all columns it wants to receive, * including the columns being updated. If {@link #requiredDataAttributes()} returns an empty * array, Spark sends only the non-identity assigned columns (heuristic path). + *

<p>
+ * Currently only consulted for UPDATE operations. * - * @since 4.2.0 + * @since 4.3.0 */ default boolean supportsColumnUpdates() { return false; @@ -140,8 +141,10 @@ default boolean supportsColumnUpdates() { * assigned, then add any extra columns needed for row lookup or routing (e.g., primary key). *

<p>
* When empty (the default), Spark falls back to sending only the non-identity assigned columns. + *

<p>
+ * Currently only consulted for UPDATE operations. * - * @since 4.2.0 + * @since 4.3.0 */ default NamedReference[] requiredDataAttributes() { return new NamedReference[0]; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperationInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperationInfo.java index 77bb5b31e28bc..e0211661170cf 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperationInfo.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperationInfo.java @@ -49,7 +49,7 @@ public interface RowLevelOperationInfo { * whether pk is already in the updated columns list and, if not, add it to * requiredDataAttributes(). * - * @since 4.2.0 + * @since 4.3.0 */ default NamedReference[] updatedColumns() { return new NamedReference[0]; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala index b6909a6dfd670..4c08e6fc3856c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala @@ -58,11 +58,13 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { RowLevelOperationTable(table, operation) } - // Builds a DataSourceV2Relation for a row-level operation, optionally narrowing its output. - // - // When dataAttrs is non-empty, the relation output is narrowed to include only columns - // required for a column-update write. When dataAttrs is empty, the full relation.output is - // preserved. + /** + * Builds a DataSourceV2Relation for a row-level operation, optionally narrowing its output. + * + * When dataAttrs is non-empty, the relation output is narrowed to include only columns + * required for a column-update write. 
When dataAttrs is empty, the full relation.output is + * preserved. + */ protected def buildRelationWithAttrs( relation: DataSourceV2Relation, table: RowLevelOperationTable, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala index 9ea5cfcbf9752..8e91316281c3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala @@ -43,8 +43,7 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { EliminateSubqueryAliases(aliasedTable) match { case r @ ExtractV2Table(tbl: SupportsRowLevelOperations) => val updatedCols = assignments.collect { - case Assignment(key: AttributeReference, value) - if !isIdentityAssignment(key, value) => + case Assignment(key: AttributeReference, value) if !isIdentityAssignment(key, value) => FieldReference(key.name) } val table = buildOperationTable(tbl, UPDATE, CaseInsensitiveStringMap.empty(), @@ -72,17 +71,18 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { assignments: Seq[Assignment], cond: Expression): ReplaceData = { - val (readRelation, rowAttrs) = buildCoWReadSetup(relation, operationTable, assignments, cond) + val (readRelation, rowAttrs, metadataAttrs) = + buildReplaceDataReadRelation(relation, operationTable, assignments, cond) val updatedAndRemainingRowsPlan = buildReplaceDataUpdateProjection( readRelation, assignments, cond) val writeRelation = relation.copy(table = operationTable) - val query = updatedAndRemainingRowsPlan - val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) - val projections = buildReplaceDataProjections(query, rowAttrs, metadataAttrs) + val projections = buildReplaceDataProjections(updatedAndRemainingRowsPlan, rowAttrs, + metadataAttrs) val groupFilterCond = if 
(groupFilterEnabled) Some(cond) else None - ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond) + ReplaceData(writeRelation, cond, updatedAndRemainingRowsPlan, relation, projections, + groupFilterCond) } // build a rewrite plan for sources that support replacing groups of data (e.g. files, partitions) @@ -93,7 +93,8 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { assignments: Seq[Assignment], cond: Expression): ReplaceData = { - val (readRelation, rowAttrs) = buildCoWReadSetup(relation, operationTable, assignments, cond) + val (readRelation, rowAttrs, metadataAttrs) = + buildReplaceDataReadRelation(relation, operationTable, assignments, cond) // build a plan for updated records that match the condition val matchedRowsPlan = Filter(cond, readRelation) @@ -107,36 +108,29 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { val updatedAndRemainingRowsPlan = Union(updatedRowsPlan, remainingRowsPlan) val writeRelation = relation.copy(table = operationTable) - val query = updatedAndRemainingRowsPlan - val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) - val projections = buildReplaceDataProjections(query, rowAttrs, metadataAttrs) + val projections = buildReplaceDataProjections(updatedAndRemainingRowsPlan, rowAttrs, + metadataAttrs) val groupFilterCond = if (groupFilterEnabled) Some(cond) else None - ReplaceData(writeRelation, cond, query, relation, projections, groupFilterCond) + ReplaceData(writeRelation, cond, updatedAndRemainingRowsPlan, relation, projections, + groupFilterCond) } - // Common read-relation setup shared by both CoW plan builders. - // - // When the connector supports column updates and declares required data attributes, - // the read relation is narrowed at analysis time so that - // GroupBasedRowLevelOperationScanPlanning uses only the needed columns for the scan. - // Otherwise the full relation output is used. 
- private def buildCoWReadSetup( + /** + * When the connector supports column updates and declares required data attributes, + * the read relation is narrowed at analysis time so that GroupBasedRowLevelOperationScanPlanning + * uses only the needed columns for the scan. Otherwise, the full relation output is used. + */ + private def buildReplaceDataReadRelation( relation: DataSourceV2Relation, operationTable: RowLevelOperationTable, assignments: Seq[Assignment], - cond: Expression): (DataSourceV2Relation, Seq[Attribute]) = { + cond: Expression): (DataSourceV2Relation, Seq[Attribute], Seq[AttributeReference]) = { val operation = operationTable.operation val metadataAttrs = resolveRequiredMetadataAttrs(relation, operation) val connectorDataAttrs = resolveRequiredDataAttrs(relation, operation) val isNarrow = operation.supportsColumnUpdates() && connectorDataAttrs.nonEmpty - // CoW scan narrowing must be done manually at analysis time. - // GroupBasedRowLevelOperationScanPlanning (an optimizer rule that fires after analysis) - // always reads relation.output directly when building the physical scan -- it does not - // observe Project nodes above the relation, so optimizer-driven column pruning has no - // effect on CoW scans. We narrow DataSourceV2Relation.output here so that rule picks - // up the narrow set. val readRelation = if (isNarrow) { val allRequired = (connectorDataAttrs ++ computeAssignedAttrs(assignments)).distinct buildRelationWithAttrs(relation, operationTable, metadataAttrs, dataAttrs = allRequired, @@ -145,30 +139,24 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { buildRelationWithAttrs(relation, operationTable, metadataAttrs) } - // CoW write schema (two paths only, no heuristic for CoW): - // - Narrow path (connectorDataAttrs declared): exactly connector-declared cols in declared - // order. The connector must declare ALL columns it wants to receive. 
- // - Full path (connectorDataAttrs empty OR supportsColumnUpdates=false): full table output. - // Unlike MOR, CoW does not have a heuristic assigned-only path because - // GroupBasedRowLevelOperationScanPlanning needs explicit column declarations to narrow. val rowAttrs: Seq[Attribute] = if (isNarrow) connectorDataAttrs else relation.output - (readRelation, rowAttrs) + (readRelation, rowAttrs, metadataAttrs) } - // this method assumes the assignments have been already aligned before - // - // Works for both the full-scan and narrow-scan CoW paths. In the narrow case, - // readRelation.output is already restricted by buildCoWReadSetup, so projecting - // all plan.output gives the correct narrow write schema. + /** + * Builds the update projection for ReplaceData plans. Assumes assignments are already aligned. + * + * plan.output may be narrowed by buildReplaceDataReadRelation, so only columns present in the + * plan are projected. + */ private def buildReplaceDataUpdateProjection( plan: LogicalPlan, assignments: Seq[Assignment], cond: Expression = TrueLiteral): LogicalPlan = { - // Build a name-keyed map via AttributeMap (compares by exprId internally) so we can look - // up each plan column's assignment without relying on positional ordering. This is more - // robust than position-based indexing and works correctly for any plan output layout. + // plan.output may be narrowed (fewer columns than assignments) or reordered, + // so assignments are matched by exprId. val assignmentMap = AttributeMap(assignments.collect { case Assignment(key: Attribute, value) => key -> value }) @@ -186,8 +174,7 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { case Some(assignedExpr) => Alias(If(cond, assignedExpr, attr), attr.name)() case None => - // Column is present in the scan but has no assignment -- pass through unchanged. - // In the narrow CoW path these are connector-declared columns not being updated. 
+ // Column in relation.output with no matching assignment; pass through unchanged. attr } } @@ -206,26 +193,19 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { cond: Expression): WriteDelta = { val operation = operationTable.operation.asInstanceOf[SupportsDelta] - // Column-update support applies to the standard delta path and the delete+reinsert path. - // When representUpdateAsDeleteAndInsert is true, the REINSERT leg of the Expand already - // uses only assigned values, so the narrow effectiveRowAttrs applies correctly. val supportsColumnUpdate = operation.supportsColumnUpdates() - // resolve all needed attrs (e.g. row ID and any required metadata attrs) val rowIdAttrs = resolveRowIdAttrs(relation, operation) val metadataAttrs = resolveRequiredMetadataAttrs(relation, operation) - // Connector-declared data attrs used to determine pass-through columns in the write plan. val connectorDataAttrs = if (supportsColumnUpdate) { resolveRequiredDataAttrs(relation, operation) } else Nil - // MOR uses a full-schema scan; ColumnPruning narrows it via Project references. val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs, rowIdAttrs) - // Connector-required attrs that are NOT being assigned are added as pass-throughs in the - // plan so that ColumnPruning keeps them in the physical scan AND the connector receives - // their current values via DeltaWriter.update's row argument. + // Connector-declared attrs not being assigned are passed through so ColumnPruning + // keeps them in the scan and the connector receives their current values. 
val assignedAttrs = if (supportsColumnUpdate) computeAssignedAttrs(assignments) else relation.output val connectorExtraAttrs: Seq[AttributeReference] = if (connectorDataAttrs.nonEmpty) { @@ -244,12 +224,6 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { buildWriteDeltaUpdateProjection(matchedRowsPlan, assignments, rowIdAttrs) } - // Effective row write schema: - // - Narrow path (connectorDataAttrs declared): exactly connector-declared cols in declared - // order. The connector must declare ALL columns it wants to receive (including updated - // ones). This mirrors the metadata pattern and enables strict areCompatible validation. - // - Heuristic path (connectorDataAttrs empty): only the assigned (changed) columns. - // - Full path (no column-update support): full table output. val effectiveRowAttrs = if (supportsColumnUpdate && connectorDataAttrs.nonEmpty) { connectorDataAttrs } else if (supportsColumnUpdate) { @@ -266,23 +240,11 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { WriteDelta(writeRelation, cond, rowDeltaPlan, relation, projections, groupFilterCond) } - // Builds the row delta projection for the column update path. - // - // The resulting Project references only: - // - assigned column values (new values being written) - // - connector pass-through values (connector declared but not assigned) - // - metadata columns (nulled or preserved) - // - row ID columns (for delta identification) - // - original row ID values (only when a row ID column is being reassigned) - // - // ColumnPruning observes exactly these references and narrows the physical scan accordingly. - // Connectors that need additional columns in the scan (e.g., partition columns for - // distribution) should declare them in requiredDataAttributes(). 
- // - // Note: AlignUpdateAssignments guarantees all assignment keys are top-level - // AttributeReferences even for nested field updates (e.g., SET col1.field = 'x' becomes - // Assignment(col1: AttributeReference, CreateNamedStruct(...))), so isIdentityAssignment - // correctly identifies non-updating assignments. + /** + * Builds the WriteDelta projection for the column update path. The resulting Project + * references only the columns needed for the write, so ColumnPruning narrows the scan + * to match. + */ private def buildColumnUpdateProjection( plan: LogicalPlan, assignments: Seq[Assignment], @@ -290,21 +252,16 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { metadataAttrs: Seq[Attribute], connectorExtraAttrs: Seq[AttributeReference] = Nil): LogicalPlan = { - // only emit values for non-identity assignments (the narrow write schema) val assignedValues = assignments.collect { case Assignment(key: Attribute, value) if !isIdentityAssignment(key, value) => Alias(value, key.name)() } - // Connector-required data attrs that are not being assigned are passed through as-is - // so that (a) ColumnPruning keeps them in the physical scan, and (b) the connector - // receives their current values via DeltaWriter.update's row argument. 
val connectorExtraAttrSet = AttributeSet(connectorExtraAttrs) val connectorPassThroughValues = plan.output.filter { a => connectorExtraAttrSet.contains(a) && !MetadataAttribute.isValid(a.metadata) } - // pass through or null out metadata columns present in the scan val metadataAttrSet = AttributeSet(metadataAttrs) val metadataValues = plan.output.filter(metadataAttrSet.contains).map { attr => if (MetadataAttribute.isPreservedOnUpdate(attr)) { @@ -314,7 +271,6 @@ object RewriteUpdateTable extends RewriteRowLevelCommand { } } - // pass through row ID columns from the scan val rowIdAttrSet = AttributeSet(rowIdAttrs) val rowIdValues = plan.output.filter(rowIdAttrSet.contains) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 9cc3d67b73753..d9a9e586f32c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -425,19 +425,19 @@ case class ReplaceData( // validates row projection output is compatible with table attributes private def rowAttrsResolved: Boolean = { val inRowAttrs = DataTypeUtils.toAttributes(projections.rowProjection.schema) - table.skipSchemaResolution || - areCompatible(inRowAttrs, table.output) || + table.skipSchemaResolution || areCompatible(inRowAttrs, table.output) || dataAttrsResolved(inRowAttrs) } - // Validates the narrow-write-schema row projection output. - // - // When the connector declares specific data attributes via requiredDataAttributes(), the - // write schema must exactly match projectedDataAttrs (same columns, same order). This is - // symmetric with metadataAttrsResolved: the connector's declared attrs define the write schema. 
- // - // When requiredDataAttributes() is empty (heuristic path), the write schema contains only - // the assigned columns. We validate each one exists in the table with a compatible type. + /** + * Validates the narrow-write-schema row projection output. + * + * When the connector declares specific data attributes via requiredDataAttributes(), the + * write schema must exactly match projectedDataAttrs (same columns, same order). + * + * When requiredDataAttributes() is empty, the write schema contains only + * the assigned columns. We validate each one exists in the table with a compatible type. + */ private def dataAttrsResolved(inRowAttrs: Seq[Attribute]): Boolean = { if (!operation.supportsColumnUpdates()) { return false } val outDataAttrs = projectedDataAttrs @@ -545,12 +545,19 @@ case class WriteDelta( case Some(projection) => DataTypeUtils.toAttributes(projection.schema) case None => Nil } - table.skipSchemaResolution || - areCompatible(inRowAttrs, outRowAttrs) || + table.skipSchemaResolution || areCompatible(inRowAttrs, outRowAttrs) || dataAttrsResolved(inRowAttrs) } - // Validates the narrow-write-schema row projection. Symmetric with ReplaceData. + /** + * Validates the narrow-write-schema row projection output. + * + * When the connector declares specific data attributes via requiredDataAttributes(), the + * write schema must exactly match projectedDataAttrs (same columns, same order). + * + * When requiredDataAttributes() is empty, the write schema contains only + * the assigned columns. We validate each one exists in the table with a compatible type. 
+ */ private def dataAttrsResolved(inRowAttrs: Seq[Attribute]): Boolean = { if (!operation.supportsColumnUpdates()) { return false } val outDataAttrs = projectedDataAttrs diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala index 22f28f8390472..69b2be68e7e10 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala @@ -72,17 +72,8 @@ class InMemoryRowLevelOperationTable private ( private final val noMetadata = properties.getOrDefault(NO_METADATA, "false") == "true" private final val COLUMN_UPDATE = "column-update" private final val COLUMN_UPDATE_REQ_ATTRS = "column-update-req-attrs" - // Selects PartitionBasedColumnUpdateOperation: CoW connector with supportsColumnUpdates=true - // and requiredDataAttributes=[pk,dep]. private final val COLUMN_UPDATE_COW = "column-update-cow" - // Selects DeltaBasedColumnUpdateOperationFromInfo: connector that derives - // requiredDataAttributes() dynamically from RowLevelOperationInfo.updatedColumns(). - // Always adds "pk" for row lookup plus whatever Spark reports as updated. private final val COLUMN_UPDATE_FROM_INFO = "column-update-from-info" - // Selects DeltaBasedColumnUpdateSplitOperation: delta connector with - // representUpdateAsDeleteAndInsert=true AND supportsColumnUpdates=true. - // Used to verify Point 7: the restriction on column updates for the delete+reinsert path - // has been lifted. 
private final val COLUMN_UPDATE_SPLIT = "column-update-split" // used in row-level operation tests to verify replaced partitions @@ -333,23 +324,11 @@ class InMemoryRowLevelOperationTable private ( } } - // A variant of DeltaBasedColumnUpdateOperation that overrides requiredDataAttributes() - // to declare a fixed set of data columns the connector needs in the scan. This exercises - // the connector-driven scan-narrowing path (as opposed to the heuristic path). class DeltaBasedColumnUpdateOperationWithReqAttrs(command: Command, reqCols: Array[String]) extends DeltaBasedColumnUpdateOperation(command) { override def requiredDataAttributes(): Array[NamedReference] = reqCols.map(FieldReference(_)) } - // A delta-based column-update connector that derives requiredDataAttributes() dynamically - // from RowLevelOperationInfo.updatedColumns(). - // - // This models the common connector pattern: - // 1. Spark tells the connector which columns are being updated via updatedColumns(). - // 2. The connector adds any extra columns it always needs (here: "pk" for row lookup). - // 3. The combined set is returned from requiredDataAttributes() so Spark narrows the scan. - // - // If "pk" is already in updatedColumns (the user is updating pk itself), it is not duplicated. class DeltaBasedColumnUpdateOperationFromInfo( command: Command, updatedCols: Seq[NamedReference]) @@ -367,13 +346,6 @@ class InMemoryRowLevelOperationTable private ( } } - // A delta-based operation that combines representUpdateAsDeleteAndInsert=true with - // supportsColumnUpdates()=true. This verifies that the restriction which previously - // blocked column-level updates on the delete+reinsert path has been lifted. - // - // The connector declares "pk" plus any columns being updated (via updatedCols). - // The write schema = requiredDataAttributes() in declared order. - // The REINSERT leg receives the narrow write row; the DELETE leg uses row ID only. 
class DeltaBasedColumnUpdateSplitOperation( command: Command, updatedCols: Seq[NamedReference] = Nil) @@ -456,11 +428,6 @@ class InMemoryRowLevelOperationTable private ( } } - // A CoW operation that supports column-level updates. The connector declares it needs - // "pk" and "dep" for partition routing, plus any columns the user is updating (via - // updatedCols from RowLevelOperationInfo). supportsColumnUpdates()=true so Spark narrows - // the scan and write schema to exactly requiredDataAttributes(). - // The commit logic reconstructs full rows from the original scan data using pk as a key. class PartitionBasedColumnUpdateOperation( command: Command, updatedCols: Seq[NamedReference] = Nil) extends RowLevelOperation { @@ -471,8 +438,6 @@ class InMemoryRowLevelOperationTable private ( override def supportsColumnUpdates(): Boolean = true override def requiredDataAttributes(): Array[NamedReference] = { - // Always need pk (for row lookup) and dep (partition key). - // Also include any columns being updated so Spark sends their new values. 
val base = Seq(FieldReference("pk"), FieldReference("dep")) val baseNames = base.map(_.describe()).toSet (base ++ updatedCols.filterNot(r => baseNames.contains(r.describe()))).toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala index e00a1b7d9b219..16e68fc7320cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala @@ -138,9 +138,9 @@ class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPla tableAttrs: Seq[Attribute], scanAttrs: Seq[Attribute]): AttributeMap[Attribute] = { - // For column-level updates, the scan may be narrowed to exclude columns that the - // connector does not need. Skip table attributes that are absent from the scan - // instead of throwing -- they cannot appear in the condition if they were pruned. + // The scan may be narrowed to exclude columns not needed by the connector. + // Attributes absent from the scan are skipped here; the caller must ensure + // that any attribute referenced in the condition is present in the scan. 
val attrMapping = tableAttrs.flatMap { tableAttr => scanAttrs .find(scanAttr => conf.resolver(scanAttr.name, tableAttr.name)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedColumnUpdateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedColumnUpdateTableSuite.scala index b22f2743a87ee..15bf2bab08932 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedColumnUpdateTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedColumnUpdateTableSuite.scala @@ -39,8 +39,6 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { props } - // --- Schema narrowing: verify LogicalWriteInfo.schema() is narrow --- - test("column-update: rowSchema contains only the single assigned column") { createAndInitTable("pk INT NOT NULL, id INT, dep STRING", """{ "pk": 1, "id": 1, "dep": "hr" } @@ -50,7 +48,6 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { sql(s"UPDATE $tableNameAsString SET id = -1 WHERE pk = 1") - // Only the assigned column (id) should appear in the row schema -- not pk or dep checkLastWriteInfo( expectedRowSchema = StructType(Seq( StructField("id", IntegerType, nullable = false) @@ -67,7 +64,6 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { sql(s"UPDATE $tableNameAsString SET id = -1, dep = 'engineering' WHERE pk = 1") - // Both assigned columns (id, dep) should appear -- but NOT pk (unassigned) checkLastWriteInfo( expectedRowSchema = StructType(Seq( StructField("id", IntegerType, nullable = false), @@ -124,16 +120,12 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) } - // --- Identity assignment filtering --- - test("column-update: rowSchema excludes identity assignments in a mixed UPDATE") { createAndInitTable("pk INT NOT NULL, id INT, dep STRING", """{ 
"pk": 1, "id": 1, "dep": "hr" } |{ "pk": 2, "id": 2, "dep": "software" } |""".stripMargin) - // id = id is identity -- should be excluded from rowSchema - // dep = 'engineering' is a real assignment -- should be included sql(s"UPDATE $tableNameAsString SET id = id, dep = 'engineering' WHERE pk = 1") checkLastWriteInfo( @@ -150,7 +142,6 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { |{ "pk": 2, "id": 2, "dep": "software" } |""".stripMargin) - // dep = dep is identity; id = -1 is a real assignment -- only id should appear sql(s"UPDATE $tableNameAsString SET dep = dep, id = -1 WHERE pk = 1") checkLastWriteInfo( @@ -161,7 +152,23 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE)))) } - // --- updatedColumns in RowLevelOperationInfo --- + test("column-update: nested struct field update narrows to the root struct column") { + createAndInitTable("pk INT NOT NULL, s STRUCT, dep STRING", + """{ "pk": 1, "s": { "c1": 1, "c2": 2 }, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET s.c1 = -1 WHERE pk = 1") + + val updatedNames = table.lastUpdatedColumns.map(_.describe()).toSet + assert(updatedNames == Set("s"), + s"expected [s] in updatedColumns (root struct) but got: $updatedNames") + + val writeSchema = table.lastWriteInfo.schema() + assert(writeSchema.fieldNames.contains("s"), + s"s must be in write schema: $writeSchema") + assert(!writeSchema.fieldNames.contains("dep"), + s"dep must not be in write schema: $writeSchema") + } test("column-update: updatedColumns contains non-identity assigned columns") { createAndInitTable("pk INT NOT NULL, id INT, dep STRING", @@ -180,7 +187,6 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { """{ "pk": 1, "id": 1, "dep": "hr" } |""".stripMargin) - // dep = dep is identity; only id should appear in updatedColumns sql(s"UPDATE $tableNameAsString 
SET id = -1, dep = dep WHERE pk = 1") val updatedNames = table.lastUpdatedColumns.map(_.describe()).toSet @@ -200,9 +206,6 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { } test("column-update: updatedColumns is empty for DELETE (Javadoc contract)") { - // DELETE never has updated columns -- verify that the default empty array is passed - // through RowLevelOperationInfo even when a column-update connector handles the DELETE. - // Use a partition-column condition so the InMemory table can process the filter. createAndInitTable("pk INT NOT NULL, id INT, dep STRING", """{ "pk": 1, "id": 1, "dep": "hr" } |{ "pk": 2, "id": 2, "dep": "software" } @@ -214,8 +217,6 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { s"DELETE must pass empty updatedColumns but got: ${table.lastUpdatedColumns.mkString(", ")}") } - // --- Data correctness --- - test("column-update: data correctness -- single column update") { createAndInitTable("pk INT NOT NULL, id INT, dep STRING", """{ "pk": 1, "id": 1, "dep": "hr" } @@ -251,7 +252,6 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { |{ "pk": 3, "id": 3, "dep": "hr" } |""".stripMargin) - // Only dep changes; id stays as-is even though id = id is in the SET list. sql(s"UPDATE $tableNameAsString SET id = id, dep = 'engineering' WHERE pk = 1") checkAnswer( @@ -259,16 +259,12 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { Row(1, 1, "engineering") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil) } - // --- Scan narrowing: verify the connector only receives the columns it needs --- - test("column-update: scan excludes the assigned column when SET to a literal") { createAndInitTable("pk INT NOT NULL, id INT, dep STRING", """{ "pk": 1, "id": 1, "dep": "hr" } |{ "pk": 2, "id": 2, "dep": "software" } |""".stripMargin) - // id is the target of a literal assignment -- its current value is not needed. 
- // pk is needed for the WHERE condition and as rowId. sql(s"UPDATE $tableNameAsString SET id = -1 WHERE pk = 1") val scanSchema = table.lastScanSchema @@ -282,8 +278,6 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { |{ "pk": 2, "salary": 200, "bonus": 20, "dep": "software" } |""".stripMargin) - // salary appears on the RHS (salary * 2) so it must be scanned. - // bonus is not referenced anywhere -- excluded. sql(s"UPDATE $tableNameAsString SET salary = salary * 2") val scanSchema = table.lastScanSchema @@ -297,7 +291,6 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { |{ "pk": 2, "id": 2, "salary": 200, "dep": "software" } |""".stripMargin) - // dep is a literal assignment; id and salary are not referenced -- only pk needed. sql(s"UPDATE $tableNameAsString SET dep = 'engineering' WHERE pk = 1") val scanSchema = table.lastScanSchema @@ -312,9 +305,7 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { |{ "pk": 2, "salary": 200, "bonus": 20, "dep": "software" } |""".stripMargin) - // dep appears in the WHERE clause -- must be scanned even though it is not assigned. - // bonus is neither assigned nor in the condition -- excluded. - // salary is set to a literal -- current value not needed. 
+ // dep is in WHERE but not assigned -- must still be scanned sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") val scanSchema = table.lastScanSchema @@ -325,12 +316,6 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { s"salary should be excluded (literal assignment): $scanSchema") } - // --------------------------------------------------------------------------- - // Connector-driven scan narrowing via requiredDataAttributes() - // --------------------------------------------------------------------------- - - // Creates a table backed by DeltaBasedColumnUpdateOperationWithReqAttrs, which overrides - // requiredDataAttributes() to return the given comma-separated column names. private def createAndInitTableWithReqAttrs( reqAttrs: String, schemaString: String, @@ -349,10 +334,6 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { } test("column-update: requiredDataAttributes forces connector-declared column into scan") { - // Connector declares it always needs "dep". - // SQL assigns "id" (literal) with condition on "pk". - // Connector-driven scan = {pk, dep} (dep from connector declaration; pk from condition). - // id is NOT in scan: literal assignment + not declared by connector. 
createAndInitTableWithReqAttrs("dep", "pk INT NOT NULL, id INT, dep STRING", """{ "pk": 1, "id": 1, "dep": "hr" } |{ "pk": 2, "id": 2, "dep": "software" } @@ -365,12 +346,10 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { s"dep must be in scan (connector required): $scanSchema") assert(scanSchema.fieldNames.contains("pk"), s"pk must be in scan: $scanSchema") assert(!scanSchema.fieldNames.contains("id"), - s"id should be excluded (literal assignment, not declared): $scanSchema") + s"id should be excluded (literal, not declared): $scanSchema") } test("column-update: requiredDataAttributes - data correctness") { - // Connector declares "dep,id" so it receives both the new id value and dep for routing. - // The write schema is exactly requiredDataAttributes = {dep, id} (declared order). createAndInitTableWithReqAttrs("dep,id", "pk INT NOT NULL, id INT, dep STRING", """{ "pk": 1, "id": 1, "dep": "hr" } |{ "pk": 2, "id": 2, "dep": "software" } @@ -385,13 +364,6 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { } test("column-update: empty requiredDataAttributes falls back to heuristic") { - // "column-update" uses DeltaBasedColumnUpdateOperation whose requiredDataAttributes() - // returns the default empty array. - // With the optimizer-driven approach for MOR, the scan is narrowed by V2ScanRelationPushDown - // which observes what columns the write plan actually references. - // SET id = -1 (literal assignment): id is not referenced from the scan, so it is pruned. - // dep is the partitioning column; since it is not declared in requiredDataAttributes() - // and is not referenced by the WHERE condition (pk = 1), it may be pruned from the scan. 
createAndInitTable("pk INT NOT NULL, id INT, dep STRING", """{ "pk": 1, "id": 1, "dep": "hr" } |{ "pk": 2, "id": 2, "dep": "software" } @@ -401,16 +373,21 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { val scanSchema = table.lastScanSchema assert(!scanSchema.fieldNames.contains("id"), - s"id must NOT be in scan (literal assignment, no scan reference): $scanSchema") + s"id must NOT be in scan (literal assignment): $scanSchema") assert(scanSchema.fieldNames.contains("pk"), s"pk must be in scan (condition): $scanSchema") } - // --------------------------------------------------------------------------- - // Connector uses RowLevelOperationInfo.updatedColumns() to derive its own - // requiredDataAttributes() dynamically. - // DeltaBasedColumnUpdateOperationFromInfo always adds "pk" (for row lookup) to - // whatever Spark reports as updated columns. - // --------------------------------------------------------------------------- + test("column-update: requiredDataAttributes throws AnalysisException for invalid column") { + createAndInitTableWithReqAttrs("nonexistent_col", "pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 1, "dep": "hr" } + |""".stripMargin) + + val ex = intercept[org.apache.spark.sql.AnalysisException] { + sql(s"UPDATE $tableNameAsString SET id = -1 WHERE pk = 1") + } + assert(ex.getMessage.contains("nonexistent_col"), + s"Expected error about unresolvable column but got: ${ex.getMessage}") + } private def createAndInitTableFromInfo(schemaString: String, jsonData: String): Unit = { val props = new java.util.HashMap[String, String]() @@ -427,16 +404,6 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { } test("column-update from-info: connector adds pk to updatedColumns for requiredDataAttributes") { - // Connector receives updatedColumns=[salary], adds pk for row lookup. - // requiredDataAttributes() = [pk, salary]. 
- // - // salary = -1 is a LITERAL assignment: the write plan references Literal(-1) not the - // scan's salary column. Since salary is in assignedAttrs, it is not a connectorExtraAttr - // pass-through either. V2ScanRelationPushDown therefore does not see salary referenced - // and prunes it from the scan. - // - // The scan contains: pk (connector pass-through), dep (partitioning + WHERE condition). - // The scan excludes: salary (literal assignment), id and bonus (not declared, not in cond). createAndInitTableFromInfo("pk INT NOT NULL, salary INT, id INT, bonus INT, dep STRING", """{ "pk": 1, "salary": 100, "id": 10, "bonus": 5, "dep": "hr" } |{ "pk": 2, "salary": 200, "id": 20, "bonus": 6, "dep": "software" } @@ -446,19 +413,13 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") val scanSchema = table.lastScanSchema - assert(scanSchema.fieldNames.contains("pk"), - s"pk must be in scan (connector pass-through via connectorExtraAttrs): $scanSchema") - assert(scanSchema.fieldNames.contains("dep"), - s"dep must be in scan (partitioning + WHERE): $scanSchema") - assert(!scanSchema.fieldNames.contains("id"), - s"id must be excluded (not declared, not assigned, not in condition): $scanSchema") - assert(!scanSchema.fieldNames.contains("bonus"), - s"bonus must be excluded (not declared, not assigned, not in condition): $scanSchema") + assert(scanSchema.fieldNames.contains("pk"), s"pk must be in scan: $scanSchema") + assert(scanSchema.fieldNames.contains("dep"), s"dep must be in scan (WHERE): $scanSchema") + assert(!scanSchema.fieldNames.contains("id"), s"id must be excluded: $scanSchema") + assert(!scanSchema.fieldNames.contains("bonus"), s"bonus must be excluded: $scanSchema") } test("column-update from-info: write schema is updatedColumns + pk pass-through") { - // requiredDataAttributes = [pk, salary] (pk always added; salary because it's assigned). 
- // Write schema = requiredDataAttributes in declared order = {pk, salary}. createAndInitTableFromInfo("pk INT NOT NULL, salary INT, id INT, dep STRING", """{ "pk": 1, "salary": 100, "id": 10, "dep": "hr" } |{ "pk": 2, "salary": 200, "id": 20, "dep": "software" } @@ -467,20 +428,14 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE pk = 1") val writeSchema = table.lastWriteInfo.schema() - assert(writeSchema.fieldNames.contains("salary"), - s"salary must be in write schema (assigned): $writeSchema") - assert(writeSchema.fieldNames.contains("pk"), - s"pk must be in write schema " + - s"(connector pass-through via requiredDataAttributes): $writeSchema") - assert(!writeSchema.fieldNames.contains("id"), - s"id must not be in write schema: $writeSchema") + assert(writeSchema.fieldNames.contains("salary"), s"salary must be in write schema: $writeSchema") + assert(writeSchema.fieldNames.contains("pk"), s"pk must be in write schema: $writeSchema") + assert(!writeSchema.fieldNames.contains("id"), s"id must not be in write schema: $writeSchema") assert(!writeSchema.fieldNames.contains("dep"), - s"dep must not be in write schema (partitioning, not a data column to write): $writeSchema") + s"dep must not be in write schema: $writeSchema") } test("column-update from-info: pk already in updatedColumns is not duplicated") { - // When the user updates pk itself, updatedColumns=[pk, salary]. - // Connector sees pk already present -> requiredDataAttributes=[pk, salary] (no dup). 
createAndInitTableFromInfo("pk INT NOT NULL, salary INT, dep STRING", """{ "pk": 1, "salary": 100, "dep": "hr" } |{ "pk": 2, "salary": 200, "dep": "software" } @@ -502,82 +457,6 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") - // salary updated for hr rows; id preserved (not in write schema, connector uses pk lookup) - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString ORDER BY pk"), - Row(1, -1, 10, "hr") :: - Row(2, 200, 20, "software") :: - Row(3, -1, 30, "hr") :: Nil) - } - - // --------------------------------------------------------------------------- - // CoW connector with supportsColumnUpdates() on RowLevelOperation. - // PartitionBasedColumnUpdateOperation declares requiredDataAttributes() = [pk, dep] and - // supportsColumnUpdates() = true. Spark narrows the scan to connector-declared + assigned - // columns; bonus is excluded. The connector reconstructs full rows via pk lookup. - // --------------------------------------------------------------------------- - - private def createAndInitTableCoW(schemaString: String, jsonData: String): Unit = { - val props = new java.util.HashMap[String, String]() - props.put("column-update-cow", "true") - val columns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL(schemaString)) - val transforms = Array[Transform](identity(reference(Seq("dep")))) - val tableInfo = new TableInfo.Builder() - .withColumns(columns) - .withPartitions(transforms) - .withProperties(props) - .build() - catalog.createTable(ident, tableInfo) - append(schemaString, jsonData) - } - - test("column-update CoW: scan excludes columns not declared and not assigned") { - // Connector declares [pk, dep]. SET salary = -1. - // Narrow scan = pk (declared) + dep (declared + condition + partitioning) - // + salary (assigned LHS). bonus is neither declared nor assigned -> excluded. 
- createAndInitTableCoW("pk INT NOT NULL, salary INT, bonus INT, dep STRING", - """{ "pk": 1, "salary": 100, "bonus": 10, "dep": "hr" } - |{ "pk": 2, "salary": 200, "bonus": 20, "dep": "software" } - |{ "pk": 3, "salary": 300, "bonus": 30, "dep": "hr" } - |""".stripMargin) - - sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") - - val scanSchema = table.lastScanSchema - assert(scanSchema.fieldNames.contains("pk"), s"pk must be in scan: $scanSchema") - assert(scanSchema.fieldNames.contains("dep"), s"dep must be in scan: $scanSchema") - assert(scanSchema.fieldNames.contains("salary"), - s"salary must be in scan (assigned LHS): $scanSchema") - assert(!scanSchema.fieldNames.contains("bonus"), s"bonus must be excluded: $scanSchema") - } - - test("column-update CoW: write schema contains only declared + assigned columns") { - createAndInitTableCoW("pk INT NOT NULL, salary INT, bonus INT, dep STRING", - """{ "pk": 1, "salary": 100, "bonus": 10, "dep": "hr" } - |{ "pk": 2, "salary": 200, "bonus": 20, "dep": "software" } - |""".stripMargin) - - sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") - - val writeSchema = table.lastWriteInfo.schema() - assert(writeSchema.fieldNames.contains("pk"), s"pk must be in write schema: $writeSchema") - assert(writeSchema.fieldNames.contains("dep"), s"dep must be in write schema: $writeSchema") - assert(writeSchema.fieldNames.contains("salary"), - s"salary must be in write schema: $writeSchema") - assert(!writeSchema.fieldNames.contains("bonus"), - s"bonus must not be in write schema: $writeSchema") - } - - test("column-update CoW: data correctness -- bonus preserved, salary updated") { - // bonus is not in the write schema; the connector must preserve it from the original row. 
- createAndInitTableCoW("pk INT NOT NULL, salary INT, bonus INT, dep STRING", - """{ "pk": 1, "salary": 100, "bonus": 10, "dep": "hr" } - |{ "pk": 2, "salary": 200, "bonus": 20, "dep": "software" } - |{ "pk": 3, "salary": 300, "bonus": 30, "dep": "hr" } - |""".stripMargin) - - sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") - checkAnswer( sql(s"SELECT * FROM $tableNameAsString ORDER BY pk"), Row(1, -1, 10, "hr") :: @@ -585,43 +464,6 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { Row(3, -1, 30, "hr") :: Nil) } - test("column-update CoW: narrow scan + subquery WHERE condition") { - // Exercises buildReplaceDataWithUnionPlan + narrow scan + the flatMap change in - // RowLevelOperationRuntimeGroupFiltering.buildTableToScanAttrMap. - // The subquery forces the UNION path (updated rows + remaining rows). - // bonus is not declared and not assigned, must be excluded from scan and write - // but the subquery-based filter must still work correctly with the narrow scan. - createAndInitTableCoW("pk INT NOT NULL, salary INT, bonus INT, dep STRING", - """{ "pk": 1, "salary": 100, "bonus": 10, "dep": "hr" } - |{ "pk": 2, "salary": 200, "bonus": 20, "dep": "software" } - |{ "pk": 3, "salary": 300, "bonus": 30, "dep": "hr" } - |""".stripMargin) - - import testImplicits._ - val subqueryDF = Seq("hr").toDF() - subqueryDF.createOrReplaceTempView("target_deps") - - sql( - s"""UPDATE $tableNameAsString - |SET salary = -1 - |WHERE dep IN (SELECT * FROM target_deps) - |""".stripMargin) - - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString ORDER BY pk"), - Row(1, -1, 10, "hr") :: - Row(2, 200, 20, "software") :: - Row(3, -1, 30, "hr") :: Nil) - } - - // --------------------------------------------------------------------------- - // Delta connector with representUpdateAsDeleteAndInsert=true AND supportsColumnUpdates=true. 
- // - // Point 7: The restriction that blocked column-level updates on the delete+reinsert path - // has been removed. The REINSERT leg of the Expand uses only assigned values (the narrow - // write schema from effectiveRowAttrs), and the DELETE leg uses row ID only. - // --------------------------------------------------------------------------- - private def createAndInitTableSplit(schemaString: String, jsonData: String): Unit = { val props = new java.util.HashMap[String, String]() props.put("column-update-split", "true") @@ -637,10 +479,6 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { } test("column-update split: write schema is narrow (assigned + pk pass-through)") { - // representUpdateAsDeleteAndInsert=true + supportsColumnUpdates=true. - // requiredDataAttributes() = [pk, id] (pk always declared; id because it's being updated). - // The write schema = requiredDataAttributes() in declared order = {pk, id}. - // dep is NOT in the write schema (not declared, not assigned). createAndInitTableSplit("pk INT NOT NULL, id INT, dep STRING", """{ "pk": 1, "id": 1, "dep": "hr" } |{ "pk": 2, "id": 2, "dep": "software" } @@ -648,7 +486,6 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { sql(s"UPDATE $tableNameAsString SET id = -1 WHERE pk = 1") - // Write schema is exactly requiredDataAttributes = {pk, id} in declared order. checkLastWriteInfo( expectedRowSchema = StructType(Seq( StructField("pk", IntegerType, nullable = false), @@ -659,8 +496,6 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { } test("column-update split: data correctness") { - // representUpdateAsDeleteAndInsert=true + supportsColumnUpdates=true. - // The connector receives narrow REINSERT rows and must reconstruct full rows. 
createAndInitTableSplit("pk INT NOT NULL, id INT, dep STRING", """{ "pk": 1, "id": 1, "dep": "hr" } |{ "pk": 2, "id": 2, "dep": "software" } @@ -674,4 +509,3 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, -1, "hr") :: Nil) } } - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala index 6c7dfc25d3a9a..13af3b0c65146 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuiteBase.scala @@ -23,11 +23,6 @@ abstract class DeltaBasedUpdateTableSuiteBase extends UpdateTableSuiteBase { override protected def deltaUpdate: Boolean = true - // --------------------------------------------------------------------------- - // RowLevelOperationInfo.updatedColumns() -- Spark informs the connector which - // columns are genuinely being updated (non-identity assignments only). 
- // --------------------------------------------------------------------------- - test("updatedColumns: single non-identity assignment") { createAndInitTable("pk INT NOT NULL, id INT, dep STRING", """{ "pk": 1, "id": 1, "dep": "hr" } @@ -53,7 +48,6 @@ abstract class DeltaBasedUpdateTableSuiteBase extends UpdateTableSuiteBase { """{ "pk": 1, "id": 1, "dep": "hr" } |""".stripMargin) - // dep = dep is an identity assignment and must NOT appear in updatedColumns sql(s"UPDATE $tableNameAsString SET id = -1, dep = dep WHERE pk = 1") checkLastUpdatedColumns("id") @@ -85,7 +79,7 @@ abstract class DeltaBasedUpdateTableSuiteBase extends UpdateTableSuiteBase { """{ "pk": 1, "id": 1, "dep": "hr" } |""".stripMargin) - // SET id = dep assigns a different column's value to id -- not identity + // id = 0 is a non-identity assignment; dep = dep is identity and must be excluded sql(s"UPDATE $tableNameAsString SET id = 0, dep = dep WHERE pk = 1") checkLastUpdatedColumns("id") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedColumnUpdateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedColumnUpdateTableSuite.scala new file mode 100644 index 0000000000000..fe4522dfdc42e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedColumnUpdateTableSuite.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import org.apache.spark.sql.Row +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, TableInfo} +import org.apache.spark.sql.connector.expressions.LogicalExpressions.{identity, reference} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.types.StructType + +class GroupBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { + + private def createAndInitTableReplaceData(schemaString: String, jsonData: String): Unit = { + val props = new java.util.HashMap[String, String]() + props.put("column-update-cow", "true") + val columns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL(schemaString)) + val transforms = Array[Transform](identity(reference(Seq("dep")))) + val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withPartitions(transforms) + .withProperties(props) + .build() + catalog.createTable(ident, tableInfo) + append(schemaString, jsonData) + } + + test("column-update ReplaceData: scan excludes columns not declared and not assigned") { + createAndInitTableReplaceData("pk INT NOT NULL, salary INT, bonus INT, dep STRING", + """{ "pk": 1, "salary": 100, "bonus": 10, "dep": "hr" } + |{ "pk": 2, "salary": 200, "bonus": 20, "dep": "software" } + |{ "pk": 3, "salary": 300, "bonus": 30, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") + + val scanSchema = table.lastScanSchema + assert(scanSchema.fieldNames.contains("pk"), s"pk must be in scan: $scanSchema") + 
assert(scanSchema.fieldNames.contains("dep"), s"dep must be in scan: $scanSchema") + assert(scanSchema.fieldNames.contains("salary"), + s"salary must be in scan (assigned): $scanSchema") + assert(!scanSchema.fieldNames.contains("bonus"), s"bonus must be excluded: $scanSchema") + } + + test("column-update ReplaceData: write schema contains only declared + assigned columns") { + createAndInitTableReplaceData("pk INT NOT NULL, salary INT, bonus INT, dep STRING", + """{ "pk": 1, "salary": 100, "bonus": 10, "dep": "hr" } + |{ "pk": 2, "salary": 200, "bonus": 20, "dep": "software" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") + + val writeSchema = table.lastWriteInfo.schema() + assert(writeSchema.fieldNames.contains("pk"), s"pk must be in write schema: $writeSchema") + assert(writeSchema.fieldNames.contains("dep"), s"dep must be in write schema: $writeSchema") + assert(writeSchema.fieldNames.contains("salary"), + s"salary must be in write schema: $writeSchema") + assert(!writeSchema.fieldNames.contains("bonus"), + s"bonus must not be in write schema: $writeSchema") + } + + test("column-update ReplaceData: data correctness -- bonus preserved, salary updated") { + createAndInitTableReplaceData("pk INT NOT NULL, salary INT, bonus INT, dep STRING", + """{ "pk": 1, "salary": 100, "bonus": 10, "dep": "hr" } + |{ "pk": 2, "salary": 200, "bonus": 20, "dep": "software" } + |{ "pk": 3, "salary": 300, "bonus": 30, "dep": "hr" } + |""".stripMargin) + + sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE dep = 'hr'") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString ORDER BY pk"), + Row(1, -1, 10, "hr") :: + Row(2, 200, 20, "software") :: + Row(3, -1, 30, "hr") :: Nil) + } + + test("column-update ReplaceData: narrow scan + subquery WHERE condition") { + createAndInitTableReplaceData("pk INT NOT NULL, salary INT, bonus INT, dep STRING", + """{ "pk": 1, "salary": 100, "bonus": 10, "dep": "hr" } + |{ "pk": 2, "salary": 200, 
"bonus": 20, "dep": "software" } + |{ "pk": 3, "salary": 300, "bonus": 30, "dep": "hr" } + |""".stripMargin) + + import testImplicits._ + val subqueryDF = Seq("hr").toDF() + subqueryDF.createOrReplaceTempView("target_deps") + + sql( + s"""UPDATE $tableNameAsString + |SET salary = -1 + |WHERE dep IN (SELECT * FROM target_deps) + |""".stripMargin) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString ORDER BY pk"), + Row(1, -1, 10, "hr") :: + Row(2, 200, 20, "software") :: + Row(3, -1, 30, "hr") :: Nil) + } +} From 4060cbf592149ab97d309aace560098fc90adf71 Mon Sep 17 00:00:00 2001 From: Anurag Mantripragada Date: Wed, 6 May 2026 21:53:45 -0700 Subject: [PATCH 5/5] Scala fmt --- .../sql/connector/DeltaBasedColumnUpdateTableSuite.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedColumnUpdateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedColumnUpdateTableSuite.scala index 15bf2bab08932..1440f481d982e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedColumnUpdateTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedColumnUpdateTableSuite.scala @@ -428,9 +428,12 @@ class DeltaBasedColumnUpdateTableSuite extends RowLevelOperationSuiteBase { sql(s"UPDATE $tableNameAsString SET salary = -1 WHERE pk = 1") val writeSchema = table.lastWriteInfo.schema() - assert(writeSchema.fieldNames.contains("salary"), s"salary must be in write schema: $writeSchema") - assert(writeSchema.fieldNames.contains("pk"), s"pk must be in write schema: $writeSchema") - assert(!writeSchema.fieldNames.contains("id"), s"id must not be in write schema: $writeSchema") + assert(writeSchema.fieldNames.contains("salary"), + s"salary must be in write schema: $writeSchema") + assert(writeSchema.fieldNames.contains("pk"), + s"pk must be in write schema: $writeSchema") + assert(!writeSchema.fieldNames.contains("id"), + 
s"id must not be in write schema: $writeSchema") assert(!writeSchema.fieldNames.contains("dep"), s"dep must not be in write schema: $writeSchema") }