Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
Original file line number Diff line number Diff line change
Expand Up @@ -1923,10 +1923,6 @@ def conf(cls):
cfg.set("spark.sql.streaming.stateStore.checkpointFormatVersion", "2")
return cfg

# TODO(SPARK-53332): Add test back when checkpoint v2 support exists for snapshotStartBatchId
def test_transform_with_value_state_metadata(self):
pass


class TransformWithStateInPySparkWithCheckpointV2TestsMixin(TransformWithStateInPySparkTestsMixin):
@classmethod
Expand All @@ -1935,10 +1931,6 @@ def conf(cls):
cfg.set("spark.sql.streaming.stateStore.checkpointFormatVersion", "2")
return cfg

# TODO(SPARK-53332): Add test back when checkpoint v2 support exists for snapshotStartBatchId
def test_transform_with_value_state_metadata(self):
pass


class TransformWithStateInPandasTests(TransformWithStateInPandasTestsMixin, ReusedSQLTestCase):
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -610,14 +610,6 @@ object StateSourceOptions extends DataSourceOptions {
)
}

if (startOperatorStateUniqueIds.isDefined) {
if (fromSnapshotOptions.isDefined) {
throw StateDataSourceErrors.invalidOptionValue(
SNAPSHOT_START_BATCH_ID,
"Snapshot reading is currently not supported with checkpoint v2.")
}
}

StateSourceOptions(
resolvedCpLocation, batchId.get, operatorId, storeName, joinSide,
readChangeFeed, fromSnapshotOptions, readChangeFeedOptions,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,9 @@ class StatePartitionReader(
provider.asInstanceOf[SupportsFineGrainedReplay]
.replayReadStateFromSnapshot(
fromSnapshotOptions.snapshotStartBatchId + 1,
partition.sourceOptions.batchId + 1)
partition.sourceOptions.batchId + 1,
getStartStoreUniqueId,
getEndStoreUniqueId)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, Par
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorStateInfo
import org.apache.spark.sql.execution.streaming.operators.stateful.join.{JoinStateManagerStoreGenerator, SymmetricHashJoinStateManager}
import org.apache.spark.sql.execution.streaming.operators.stateful.join.{JoinStateManagerStoreGenerator, SnapshotOptions, SymmetricHashJoinStateManager}
import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.{JoinSide, LeftSide, RightSide}
import org.apache.spark.sql.execution.streaming.state.StateStoreConf
import org.apache.spark.sql.types.{BooleanType, StructType}
Expand Down Expand Up @@ -78,22 +78,40 @@ class StreamStreamJoinStatePartitionReader(

private val startStateStoreCheckpointIds =
SymmetricHashJoinStateManager.getStateStoreCheckpointIds(
partition.partition,
partition.sourceOptions.startOperatorStateUniqueIds,
usesVirtualColumnFamilies)
partition.partition,
partition.sourceOptions.startOperatorStateUniqueIds,
usesVirtualColumnFamilies)

private val keyToNumValuesStateStoreCkptId = if (joinSide == LeftSide) {
private val endStateStoreCheckpointIds =
SymmetricHashJoinStateManager.getStateStoreCheckpointIds(
partition.partition,
partition.sourceOptions.endOperatorStateUniqueIds,
usesVirtualColumnFamilies)

private val startKeyToNumValuesStateStoreCkptId = if (joinSide == LeftSide) {
startStateStoreCheckpointIds.left.keyToNumValues
} else {
startStateStoreCheckpointIds.right.keyToNumValues
}

private val keyWithIndexToValueStateStoreCkptId = if (joinSide == LeftSide) {
private val startKeyWithIndexToValueStateStoreCkptId = if (joinSide == LeftSide) {
startStateStoreCheckpointIds.left.keyWithIndexToValue
} else {
startStateStoreCheckpointIds.right.keyWithIndexToValue
}

private val endKeyToNumValuesStateStoreCkptId = if (joinSide == LeftSide) {
endStateStoreCheckpointIds.left.keyToNumValues
} else {
endStateStoreCheckpointIds.right.keyToNumValues
}

private val endKeyWithIndexToValueStateStoreCkptId = if (joinSide == LeftSide) {
endStateStoreCheckpointIds.left.keyWithIndexToValue
} else {
endStateStoreCheckpointIds.right.keyWithIndexToValue
}

/*
* This is to handle the difference of schema across state format versions. The major difference
* is whether we have added new field(s) in addition to the fields from input schema.
Expand Down Expand Up @@ -150,13 +168,19 @@ class StreamStreamJoinStatePartitionReader(
storeConf = storeConf,
hadoopConf = hadoopConf.value,
partitionId = partition.partition,
keyToNumValuesStateStoreCkptId = keyToNumValuesStateStoreCkptId,
keyWithIndexToValueStateStoreCkptId = keyWithIndexToValueStateStoreCkptId,
keyToNumValuesStateStoreCkptId = startKeyToNumValuesStateStoreCkptId,
keyWithIndexToValueStateStoreCkptId = startKeyWithIndexToValueStateStoreCkptId,
formatVersion,
skippedNullValueCount = None,
useStateStoreCoordinator = false,
snapshotStartVersion =
partition.sourceOptions.fromSnapshotOptions.map(_.snapshotStartBatchId + 1),
snapshotOptions =
partition.sourceOptions.fromSnapshotOptions.map(opts => SnapshotOptions(
snapshotVersion = opts.snapshotStartBatchId + 1,
endVersion = partition.sourceOptions.batchId + 1,
startKeyToNumValuesStateStoreCkptId = startKeyToNumValuesStateStoreCkptId,
startKeyWithIndexToValueStateStoreCkptId = startKeyWithIndexToValueStateStoreCkptId,
endKeyToNumValuesStateStoreCkptId = endKeyToNumValuesStateStoreCkptId,
endKeyWithIndexToValueStateStoreCkptId = endKeyWithIndexToValueStateStoreCkptId)),
joinStoreGenerator = new JoinStateManagerStoreGenerator()
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ import org.apache.spark.util.NextIterator
* store providers being used in this class. If true, Spark will
* take care of management for state store providers, e.g. running
* maintenance task for these providers.
* @param snapshotOptions Options controlling snapshot-based state replay for the state data
* source reader.
* @param joinStoreGenerator The generator to create state store instances, re-using the same
* instance when the join implementation uses virtual column families
* for join version 3.
Expand Down Expand Up @@ -95,15 +97,20 @@ abstract class SymmetricHashJoinStateManager(
stateFormatVersion: Int,
skippedNullValueCount: Option[SQLMetric] = None,
useStateStoreCoordinator: Boolean = true,
snapshotStartVersion: Option[Long] = None,
snapshotOptions: Option[SnapshotOptions] = None,
joinStoreGenerator: JoinStateManagerStoreGenerator) extends Logging {
import SymmetricHashJoinStateManager._

protected val keySchema = StructType(
joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) })
protected val keyAttributes = toAttributes(keySchema)
protected val keyToNumValues = new KeyToNumValuesStore(stateFormatVersion)
protected val keyWithIndexToValue = new KeyWithIndexToValueStore(stateFormatVersion)

protected val keyToNumValues = new KeyToNumValuesStore(
stateFormatVersion,
snapshotOptions.map(_.getKeyToNumValuesHandlerOpts()))
protected val keyWithIndexToValue = new KeyWithIndexToValueStore(
stateFormatVersion,
snapshotOptions.map(_.getKeyWithIndexToValueHandlerOpts()))

/*
=====================================================
Expand Down Expand Up @@ -456,7 +463,8 @@ abstract class SymmetricHashJoinStateManager(
/** Helper trait for invoking common functionalities of a state store. */
protected abstract class StateStoreHandler(
stateStoreType: StateStoreType,
stateStoreCkptId: Option[String]) extends Logging {
stateStoreCkptId: Option[String],
handlerSnapshotOptions: Option[HandlerSnapshotOptions] = None) extends Logging {
private var stateStoreProvider: StateStoreProvider = _

/** StateStore that the subclasses of this class is going to operate on */
Expand Down Expand Up @@ -497,7 +505,7 @@ abstract class SymmetricHashJoinStateManager(
}
val storeProviderId = StateStoreProviderId(stateInfo.get, partitionId, storeName)
val store = if (useStateStoreCoordinator) {
assert(snapshotStartVersion.isEmpty, "Should not use state store coordinator " +
assert(handlerSnapshotOptions.isEmpty, "Should not use state store coordinator " +
"when reading state as data source.")
joinStoreGenerator.getStore(
storeProviderId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema),
Expand All @@ -509,13 +517,19 @@ abstract class SymmetricHashJoinStateManager(
storeProviderId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema),
useColumnFamilies = useVirtualColumnFamilies, storeConf, hadoopConf,
useMultipleValuesPerKey = false, stateSchemaProvider = None)
if (snapshotStartVersion.isDefined) {
if (handlerSnapshotOptions.isDefined) {
if (!stateStoreProvider.isInstanceOf[SupportsFineGrainedReplay]) {
throw StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay(
stateStoreProvider.getClass.toString)
}
val opts = handlerSnapshotOptions.get
stateStoreProvider.asInstanceOf[SupportsFineGrainedReplay]
.replayStateFromSnapshot(snapshotStartVersion.get, stateInfo.get.storeVersion)
.replayStateFromSnapshot(
opts.snapshotVersion,
opts.endVersion,
readOnly = true,
opts.startStateStoreCkptId,
opts.endStateStoreCkptId)
} else {
stateStoreProvider.getStore(stateInfo.get.storeVersion, stateStoreCkptId)
}
Expand All @@ -539,9 +553,12 @@ abstract class SymmetricHashJoinStateManager(


/** A wrapper around a [[StateStore]] that stores [key -> number of values]. */
protected class KeyToNumValuesStore(val stateFormatVersion: Int)
extends StateStoreHandler(KeyToNumValuesType, keyToNumValuesStateStoreCkptId) {

protected class KeyToNumValuesStore(
val stateFormatVersion: Int,
val handlerSnapshotOptions: Option[HandlerSnapshotOptions] = None)
extends StateStoreHandler(
KeyToNumValuesType, keyToNumValuesStateStoreCkptId, handlerSnapshotOptions) {
SnapshotOptions
private val useVirtualColumnFamilies = stateFormatVersion == 3
private val longValueSchema = new StructType().add("value", "long")
private val longToUnsafeRow = UnsafeProjection.create(longValueSchema)
Expand Down Expand Up @@ -707,8 +724,11 @@ abstract class SymmetricHashJoinStateManager(
* A wrapper around a [[StateStore]] that stores the mapping; the mapping depends on the
* state format version - please refer implementations of [[KeyWithIndexToValueRowConverter]].
*/
protected class KeyWithIndexToValueStore(stateFormatVersion: Int)
extends StateStoreHandler(KeyWithIndexToValueType, keyWithIndexToValueStateStoreCkptId) {
protected class KeyWithIndexToValueStore(
stateFormatVersion: Int,
handlerSnapshotOptions: Option[HandlerSnapshotOptions] = None)
extends StateStoreHandler(
KeyWithIndexToValueType, keyWithIndexToValueStateStoreCkptId, handlerSnapshotOptions) {

private val useVirtualColumnFamilies = stateFormatVersion == 3
private val keyWithIndexExprs = keyAttributes :+ Literal(1L)
Expand Down Expand Up @@ -848,11 +868,11 @@ class SymmetricHashJoinStateManagerV1(
stateFormatVersion: Int,
skippedNullValueCount: Option[SQLMetric] = None,
useStateStoreCoordinator: Boolean = true,
snapshotStartVersion: Option[Long] = None,
snapshotOptions: Option[SnapshotOptions] = None,
joinStoreGenerator: JoinStateManagerStoreGenerator) extends SymmetricHashJoinStateManager(
joinSide, inputValueAttributes, joinKeys, stateInfo, storeConf, hadoopConf,
partitionId, keyToNumValuesStateStoreCkptId, keyWithIndexToValueStateStoreCkptId,
stateFormatVersion, skippedNullValueCount, useStateStoreCoordinator, snapshotStartVersion,
stateFormatVersion, skippedNullValueCount, useStateStoreCoordinator, snapshotOptions,
joinStoreGenerator) {

/** Commit all the changes to all the state stores */
Expand Down Expand Up @@ -927,11 +947,11 @@ class SymmetricHashJoinStateManagerV2(
stateFormatVersion: Int,
skippedNullValueCount: Option[SQLMetric] = None,
useStateStoreCoordinator: Boolean = true,
snapshotStartVersion: Option[Long] = None,
snapshotOptions: Option[SnapshotOptions] = None,
joinStoreGenerator: JoinStateManagerStoreGenerator) extends SymmetricHashJoinStateManager(
joinSide, inputValueAttributes, joinKeys, stateInfo, storeConf, hadoopConf,
partitionId, keyToNumValuesStateStoreCkptId, keyWithIndexToValueStateStoreCkptId,
stateFormatVersion, skippedNullValueCount, useStateStoreCoordinator, snapshotStartVersion,
stateFormatVersion, skippedNullValueCount, useStateStoreCoordinator, snapshotOptions,
joinStoreGenerator) {

/** Commit all the changes to the state store */
Expand Down Expand Up @@ -1034,20 +1054,20 @@ object SymmetricHashJoinStateManager {
stateFormatVersion: Int,
skippedNullValueCount: Option[SQLMetric] = None,
useStateStoreCoordinator: Boolean = true,
snapshotStartVersion: Option[Long] = None,
snapshotOptions: Option[SnapshotOptions] = None,
joinStoreGenerator: JoinStateManagerStoreGenerator): SymmetricHashJoinStateManager = {
if (stateFormatVersion == 3) {
new SymmetricHashJoinStateManagerV2(
joinSide, inputValueAttributes, joinKeys, stateInfo, storeConf, hadoopConf,
partitionId, keyToNumValuesStateStoreCkptId, keyWithIndexToValueStateStoreCkptId,
stateFormatVersion, skippedNullValueCount, useStateStoreCoordinator, snapshotStartVersion,
stateFormatVersion, skippedNullValueCount, useStateStoreCoordinator, snapshotOptions,
joinStoreGenerator
)
} else {
new SymmetricHashJoinStateManagerV1(
joinSide, inputValueAttributes, joinKeys, stateInfo, storeConf, hadoopConf,
partitionId, keyToNumValuesStateStoreCkptId, keyWithIndexToValueStateStoreCkptId,
stateFormatVersion, skippedNullValueCount, useStateStoreCoordinator, snapshotStartVersion,
stateFormatVersion, skippedNullValueCount, useStateStoreCoordinator, snapshotOptions,
joinStoreGenerator
)
}
Expand Down Expand Up @@ -1280,3 +1300,36 @@ object SymmetricHashJoinStateManager {
}
}
}

/**
 * Snapshot replay settings handed to the join state manager by the state data source reader.
 *
 * Bundles the version range to replay together with the optional state store checkpoint IDs
 * of the two join state stores (keyToNumValues and keyWithIndexToValue) at both ends of that
 * range. Each accessor below narrows this bundle down to the options for one store handler.
 *
 * @param snapshotVersion version of the snapshot to start replaying from
 * @param endVersion version at which replay ends
 * @param startKeyToNumValuesStateStoreCkptId checkpoint ID of the keyToNumValues store at
 *                                            the snapshot version, if any
 * @param startKeyWithIndexToValueStateStoreCkptId checkpoint ID of the keyWithIndexToValue
 *                                                 store at the snapshot version, if any
 * @param endKeyToNumValuesStateStoreCkptId checkpoint ID of the keyToNumValues store at the
 *                                          end version, if any
 * @param endKeyWithIndexToValueStateStoreCkptId checkpoint ID of the keyWithIndexToValue
 *                                               store at the end version, if any
 */
case class SnapshotOptions(
    snapshotVersion: Long,
    endVersion: Long,
    startKeyToNumValuesStateStoreCkptId: Option[String] = None,
    startKeyWithIndexToValueStateStoreCkptId: Option[String] = None,
    endKeyToNumValuesStateStoreCkptId: Option[String] = None,
    endKeyWithIndexToValueStateStoreCkptId: Option[String] = None) {

  /** Pairs the shared version range with one store's start/end checkpoint IDs. */
  private def handlerOptsFor(
      startCkptId: Option[String],
      endCkptId: Option[String]): HandlerSnapshotOptions =
    HandlerSnapshotOptions(
      snapshotVersion = snapshotVersion,
      endVersion = endVersion,
      startStateStoreCkptId = startCkptId,
      endStateStoreCkptId = endCkptId)

  /** Options for the handler of the keyToNumValues store. */
  def getKeyToNumValuesHandlerOpts(): HandlerSnapshotOptions =
    handlerOptsFor(
      startKeyToNumValuesStateStoreCkptId,
      endKeyToNumValuesStateStoreCkptId)

  /** Options for the handler of the keyWithIndexToValue store. */
  def getKeyWithIndexToValueHandlerOpts(): HandlerSnapshotOptions =
    handlerOptsFor(
      startKeyWithIndexToValueStateStoreCkptId,
      endKeyWithIndexToValueStateStoreCkptId)
}

/**
 * Snapshot options specialized for a single state store handler.
 *
 * Instances are derived from [[SnapshotOptions]], which holds the checkpoint IDs for both
 * join state stores; this class carries only the pair that belongs to one store.
 *
 * @param snapshotVersion version of the snapshot to start replaying from
 * @param endVersion version at which replay ends
 * @param startStateStoreCkptId state store checkpoint ID at the snapshot version, if any
 * @param endStateStoreCkptId state store checkpoint ID at the end version, if any
 */
private[join] case class HandlerSnapshotOptions(
snapshotVersion: Long,
endVersion: Long,
startStateStoreCkptId: Option[String],
endStateStoreCkptId: Option[String])
Original file line number Diff line number Diff line change
Expand Up @@ -972,10 +972,22 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
*
* @param snapshotVersion checkpoint version of the snapshot to start with
* @param endVersion checkpoint version to end with
* @param readOnly whether the state store should be read-only
* @param snapshotVersionStateStoreCkptId state store checkpoint ID of the snapshot version
* @param endVersionStateStoreCkptId state store checkpoint ID of the end version
* @return [[HDFSBackedStateStore]]
*/
override def replayStateFromSnapshot(
snapshotVersion: Long, endVersion: Long, readOnly: Boolean): StateStore = {
snapshotVersion: Long,
endVersion: Long,
readOnly: Boolean,
snapshotVersionStateStoreCkptId: Option[String] = None,
endVersionStateStoreCkptId: Option[String] = None): StateStore = {
if (snapshotVersionStateStoreCkptId.isDefined || endVersionStateStoreCkptId.isDefined) {
throw StateStoreErrors.stateStoreCheckpointIdsNotSupported(
"HDFSBackedStateStoreProvider does not support checkpointFormatVersion > 1 " +
"but a state store checkpointID is passed in")
}
val newMap = replayLoadedMapFromSnapshot(snapshotVersion, endVersion)
logInfo(log"Retrieved snapshot at version " +
log"${MDC(LogKeys.STATE_STORE_VERSION, snapshotVersion)} and apply delta files to version " +
Expand All @@ -990,10 +1002,21 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
*
* @param snapshotVersion checkpoint version of the snapshot to start with
* @param endVersion checkpoint version to end with
* @param snapshotVersionStateStoreCkptId state store checkpoint ID of the snapshot version
* @param endVersionStateStoreCkptId state store checkpoint ID of the end version
* @return [[HDFSBackedReadStateStore]]
*/
override def replayReadStateFromSnapshot(snapshotVersion: Long, endVersion: Long):
override def replayReadStateFromSnapshot(
snapshotVersion: Long,
endVersion: Long,
snapshotVersionStateStoreCkptId: Option[String] = None,
endVersionStateStoreCkptId: Option[String] = None):
ReadStateStore = {
if (snapshotVersionStateStoreCkptId.isDefined || endVersionStateStoreCkptId.isDefined) {
throw StateStoreErrors.stateStoreCheckpointIdsNotSupported(
"HDFSBackedStateStoreProvider does not support checkpointFormatVersion > 1 " +
"but a state store checkpointID is passed in")
}
val newMap = replayLoadedMapFromSnapshot(snapshotVersion, endVersion)
logInfo(log"Retrieved snapshot at version " +
log"${MDC(LogKeys.STATE_STORE_VERSION, snapshotVersion)} and apply delta files to version " +
Expand Down
Loading