From b0ca6c83704cb13915f088ce6b8e4ac6dfc25005 Mon Sep 17 00:00:00 2001 From: bootjp Date: Thu, 2 Jul 2026 01:29:43 +0900 Subject: [PATCH 01/14] raft: prevent snapshot disk exhaustion --- internal/raftengine/engine.go | 5 + internal/raftengine/etcd/engine.go | 40 +++++++ internal/raftengine/etcd/fsm_snapshot_file.go | 107 ++++++++++++++++++ .../raftengine/etcd/fsm_snapshot_file_test.go | 27 +++++ internal/raftengine/etcd/grpc_transport.go | 45 ++++++-- main.go | 36 ++++++ main_raft_lifecycle_test.go | 48 ++++++++ 7 files changed, 300 insertions(+), 8 deletions(-) create mode 100644 main_raft_lifecycle_test.go diff --git a/internal/raftengine/engine.go b/internal/raftengine/engine.go index 332779e1c..9db7d3783 100644 --- a/internal/raftengine/engine.go +++ b/internal/raftengine/engine.go @@ -218,6 +218,11 @@ type HealthReader interface { CheckServing(ctx context.Context) error } +type Lifecycle interface { + Done() <-chan struct{} + Err() error +} + type Admin interface { LeaderView StatusReader diff --git a/internal/raftengine/etcd/engine.go b/internal/raftengine/etcd/engine.go index 8c0aedb45..8b4e1e00a 100644 --- a/internal/raftengine/etcd/engine.go +++ b/internal/raftengine/etcd/engine.go @@ -666,6 +666,7 @@ func (e *Engine) initTransport(cfg OpenConfig) { e.dispatchStopCh = make(chan struct{}) e.transport.SetSpoolDir(cfg.DataDir) e.transport.SetFSMSnapDir(e.fsmSnapDir) + e.transport.SetFSMSnapshotPrepare(e.prepareFSMSnapshotWriteLocked) e.transport.SetFSMPayloadReader(e.readFSMPayloadLocked) e.transport.SetFSMPayloadOpener(e.openFSMPayloadLocked) e.transport.SetHandler(e.handleTransportMessage) @@ -734,6 +735,22 @@ func (e *Engine) Close() error { return nil } +func (e *Engine) Done() <-chan struct{} { + if e == nil { + done := make(chan struct{}) + close(done) + return done + } + return e.doneCh +} + +func (e *Engine) Err() error { + if e == nil { + return nil + } + return e.currentError() +} + func (e *Engine) Propose(ctx context.Context, data []byte) (*raftengine.ProposalResult, error) { return e.propose(ctx, data) } @@ -2749,6 +2766,17 @@ func (e *Engine) openFSMPayloadLocked(index uint64) (io.ReadCloser, error) { return openFSMPayloadFromFD(f) } +func (e *Engine) prepareFSMSnapshotWriteLocked(index uint64) error { + e.snapshotMu.Lock() + defer e.snapshotMu.Unlock() + return e.prepareFSMSnapshotWrite(index) +} + +func (e *Engine) prepareFSMSnapshotWrite(index uint64) error { + snapDir := filepath.Join(e.dataDir, snapDirName) + return prepareFSMSnapshotWrite(snapDir, e.fsmSnapDir, index) +} + // snapshotPayload takes a FSM snapshot for the given index, writes it to the // .fsm file on disk, and returns the 17-byte token for raftpb.Snapshot.Data. // If fsmSnapDir is not set (e.g., engines created directly in unit tests), @@ -2765,6 +2793,12 @@ func (e *Engine) snapshotPayload(index uint64) ([]byte, error) { if err != nil { return nil, errors.WithStack(err) } + if err := e.prepareFSMSnapshotWrite(index); err != nil { + slog.Warn("failed to prepare fsm snapshot write", + "index", index, + "error", err, + ) + } crc32c, writeErr := writeFSMSnapshotFile(snapshot, e.fsmSnapDir, index) closeErr := snapshot.Close() if writeErr != nil { @@ -4187,6 +4221,12 @@ func (e *Engine) persistLocalSnapshot(req snapshotRequest) error { } return e.persistLocalSnapshotPayload(req.index, payload) } + if err := e.prepareFSMSnapshotWriteLocked(req.index); err != nil { + slog.Warn("failed to prepare fsm snapshot write", + "index", req.index, + "error", err, + ) + } crc32c, writeErr := writeFSMSnapshotFile(req.snapshot, e.fsmSnapDir, req.index) closeErr := req.snapshot.Close() if writeErr != nil { diff --git a/internal/raftengine/etcd/fsm_snapshot_file.go b/internal/raftengine/etcd/fsm_snapshot_file.go index 9ed5825a4..5a3a43649 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file.go +++ b/internal/raftengine/etcd/fsm_snapshot_file.go @@ -22,6 +22,7 @@ const ( snapFileExt = ".snap" snapshotTokenSize = 17 // 4 (magic) + 1 (version) + 8 (index) + 4 (crc32c) snapshotTokenVersion = byte(0x01) + prewriteSnapKeep = 1 // fsmFooterSize is the size of the CRC32C footer appended to each .fsm file. fsmFooterSize = 4 @@ -523,6 +524,112 @@ func cleanupStaleFSMSnaps(snapDir, fsmSnapDir string, disableStartupCRCCheck boo return removeStaleFSMFiles(fsmSnapDir, liveIndexes, disableStartupCRCCheck) } +// prepareFSMSnapshotWrite frees space before writing a new large .fsm payload. +// It keeps the newest prior snap/fsm pair so a failed write still leaves a +// restartable snapshot, then removes older pairs and stale pre-next-index FSM +// files. Success-path purgeOldSnapshotFiles runs after raft publishes the new +// token; this prewrite pass prevents ENOSPC before that success path can run. +func prepareFSMSnapshotWrite(snapDir, fsmSnapDir string, nextIndex uint64) error { + if fsmSnapDir == "" || nextIndex == 0 { + return nil + } + if err := os.MkdirAll(fsmSnapDir, defaultDirPerm); err != nil { + return errors.WithStack(err) + } + + var combined error + combined = errors.CombineErrors(combined, removeFSMTmpOrphans(fsmSnapDir)) + if snapDir == "" { + combined = errors.CombineErrors(combined, syncDirIfExists(fsmSnapDir)) + return errors.WithStack(combined) + } + combined = errors.CombineErrors(combined, purgeOlderSnapshotPairsBeforeWrite(snapDir, fsmSnapDir, nextIndex)) + + liveIndexes, err := collectLiveSnapIndexes(snapDir) + if err != nil { + combined = errors.CombineErrors(combined, err) + } else if liveIndexes != nil { + combined = errors.CombineErrors(combined, removeStaleFSMFilesBeforeIndex(fsmSnapDir, liveIndexes, nextIndex)) + } + combined = errors.CombineErrors(combined, syncDirIfExists(snapDir)) + combined = errors.CombineErrors(combined, syncDirIfExists(fsmSnapDir)) + return errors.WithStack(combined) +} + +type snapFileCandidate struct { + name string + index uint64 +} + +func purgeOlderSnapshotPairsBeforeWrite(snapDir, fsmSnapDir string, nextIndex uint64) error { + if snapDir == "" { + return nil + } + entries, err := os.ReadDir(snapDir) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return errors.WithStack(err) + } + + candidates := collectPrewriteSnapCandidates(entries, nextIndex) + if len(candidates) <= prewriteSnapKeep { + return nil + } + + sort.Slice(candidates, func(i, j int) bool { + if candidates[i].index == candidates[j].index { + return candidates[i].name < candidates[j].name + } + return candidates[i].index < candidates[j].index + }) + + var combined error + for _, candidate := range candidates[:len(candidates)-prewriteSnapKeep] { + if err := purgeSnapPair(snapDir, fsmSnapDir, candidate.name); err != nil { + combined = errors.CombineErrors(combined, err) + } + } + return errors.WithStack(combined) +} + +func collectPrewriteSnapCandidates(entries []os.DirEntry, nextIndex uint64) []snapFileCandidate { + candidates := make([]snapFileCandidate, 0, len(entries)) + for _, e := range entries { + if e.IsDir() || filepath.Ext(e.Name()) != snapFileExt { + continue + } + index := parseSnapFileIndex(e.Name()) + if index == 0 || index >= nextIndex { + continue + } + candidates = append(candidates, snapFileCandidate{name: e.Name(), index: index}) + } + return candidates +} + +func removeStaleFSMFilesBeforeIndex(fsmSnapDir string, liveIndexes map[uint64]bool, nextIndex uint64) error { + fsmEntries, err := os.ReadDir(fsmSnapDir) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return errors.WithStack(err) + } + for _, e := range fsmEntries { + if e.IsDir() || filepath.Ext(e.Name()) != ".fsm" { + continue + } + idx, err := strconv.ParseUint(strings.TrimSuffix(e.Name(), ".fsm"), 16, 64) + if err != nil || idx >= nextIndex || liveIndexes[idx] { + continue + } + removeWithWarn(filepath.Join(fsmSnapDir, e.Name()), "orphan fsm snapshot") + } + return nil +} + func removeFSMTmpOrphans(fsmSnapDir string) error { // Use os.ReadDir + strings.HasSuffix instead of filepath.Glob to avoid // misinterpretation of special characters (e.g. '[', ']') in fsmSnapDir diff --git a/internal/raftengine/etcd/fsm_snapshot_file_test.go b/internal/raftengine/etcd/fsm_snapshot_file_test.go index 810c2bbb8..b1c2eb639 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file_test.go +++ b/internal/raftengine/etcd/fsm_snapshot_file_test.go @@ -332,6 +332,33 @@ func TestPurgeOldSnapshotFilesOrdering(t *testing.T) { require.Len(t, fsms, 3) } +func TestPrepareFSMSnapshotWritePrunesOldPairsAndOrphans(t *testing.T) { + snapDir := t.TempDir() + fsmSnapDir := t.TempDir() + payload := []byte("payload") + + for _, index := range []uint64{100, 200, 300} { + createSnapFile(t, snapDir, index) + writeFSMFileForTest(t, fsmSnapDir, index, payload) + } + writeFSMFileForTest(t, fsmSnapDir, 150, payload) + writeFSMFileForTest(t, fsmSnapDir, 500, payload) + require.NoError(t, os.WriteFile(filepath.Join(fsmSnapDir, "leftover.fsm.tmp"), []byte("tmp"), 0o600)) + + require.NoError(t, prepareFSMSnapshotWrite(snapDir, fsmSnapDir, 400)) + + require.NoFileExists(t, filepath.Join(snapDir, "0000000000000001-0000000000000064.snap")) + require.NoFileExists(t, filepath.Join(snapDir, "0000000000000001-00000000000000c8.snap")) + require.FileExists(t, filepath.Join(snapDir, "0000000000000001-000000000000012c.snap")) + + require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 100)) + require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 150)) + require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 200)) + require.FileExists(t, fsmSnapPath(fsmSnapDir, 300)) + require.FileExists(t, fsmSnapPath(fsmSnapDir, 500)) + require.NoFileExists(t, filepath.Join(fsmSnapDir, "leftover.fsm.tmp")) +} + // --- writeFSMSnapshotFile integration --- func TestWriteFSMSnapshotFileRoundTrip(t *testing.T) { diff --git a/internal/raftengine/etcd/grpc_transport.go b/internal/raftengine/etcd/grpc_transport.go index d63a5c558..da9d6f240 100644 --- a/internal/raftengine/etcd/grpc_transport.go +++ b/internal/raftengine/etcd/grpc_transport.go @@ -51,6 +51,7 @@ type GRPCTransport struct { snapshotChunkSize int spoolDir string fsmSnapDir string + prepareFSMWrite func(index uint64) error // readFSMPayload is the fallback bridge callback that materialises the full // FSM payload into memory. Used only when openFSMPayload is not set. readFSMPayload func(index uint64) ([]byte, error) @@ -121,6 +122,15 @@ func (t *GRPCTransport) SetFSMSnapDir(dir string) { t.fsmSnapDir = dir } +func (t *GRPCTransport) SetFSMSnapshotPrepare(fn func(index uint64) error) { + if t == nil { + return + } + t.mu.Lock() + defer t.mu.Unlock() + t.prepareFSMWrite = fn +} + func (t *GRPCTransport) SetFSMPayloadReader(fn func(index uint64) ([]byte, error)) { if t == nil { return @@ -727,25 +737,26 @@ func (t *GRPCTransport) handle(ctx context.Context, msg raftpb.Message) error { return errors.WithStack(handler(ctx, msg)) } -// snapshotSpoolPlacement returns (spoolDir, fsmSnapDir) under the transport -// lock. When fsmSnapDir is wired, the spool itself is placed inside it so +// snapshotSpoolPlacement returns the snapshot receive paths/callback under the +// transport lock. When fsmSnapDir is wired, the spool itself is placed inside it so // FinalizeAsFSMFile's rename stays intra-filesystem and cannot fail with // EXDEV. Standard engine wiring puts both under cfg.DataDir, but the // receive code should not assume that. The legacy fallback path // (fsmSnapDir == "") keeps the spool in spoolDir because it never renames // — Bytes() materializes the payload in place. -func (t *GRPCTransport) snapshotSpoolPlacement() (placement, fsmSnapDir string) { +func (t *GRPCTransport) snapshotSpoolPlacement() (placement, fsmSnapDir string, prepareFn func(uint64) error) { t.mu.RLock() defer t.mu.RUnlock() fsmSnapDir = t.fsmSnapDir + prepareFn = t.prepareFSMWrite if fsmSnapDir != "" { - return fsmSnapDir, fsmSnapDir + return fsmSnapDir, fsmSnapDir, prepareFn } - return t.spoolDir, "" + return t.spoolDir, "", prepareFn } func (t *GRPCTransport) receiveSnapshotStream(stream pb.EtcdRaft_SendSnapshotServer) (raftpb.Message, error) { - spoolPlacement, fsmSnapDir := t.snapshotSpoolPlacement() + spoolPlacement, fsmSnapDir, prepareFn := t.snapshotSpoolPlacement() spool, err := newSnapshotSpool(spoolPlacement) if err != nil { return raftpb.Message{}, err @@ -764,7 +775,7 @@ func (t *GRPCTransport) receiveSnapshotStream(stream pb.EtcdRaft_SendSnapshotSer } }() - msg, payloadBytes, err := drainSnapshotChunks(stream, spool, fsmSnapDir) + msg, payloadBytes, err := drainSnapshotChunks(stream, spool, fsmSnapDir, prepareFn) if err != nil { return raftpb.Message{}, err } @@ -792,6 +803,7 @@ func drainSnapshotChunks( stream pb.EtcdRaft_SendSnapshotServer, spool *snapshotSpool, fsmSnapDir string, + prepareFn func(uint64) error, ) (raftpb.Message, int64, error) { var metadata raftpb.Message seenMetadata := false @@ -819,7 +831,7 @@ func drainSnapshotChunks( } seenMetadata = seen if chunk.Final { - msg, err := finalizeReceivedSnapshot(metadata, spool, crcWriter.Sum32(), fsmSnapDir, seenMetadata) + msg, err := finalizeReceivedSnapshot(metadata, spool, crcWriter.Sum32(), fsmSnapDir, prepareFn, seenMetadata) if err != nil { return raftpb.Message{}, 0, err } @@ -842,6 +854,7 @@ func finalizeReceivedSnapshot( spool *snapshotSpool, crc32c uint32, fsmSnapDir string, + prepareFn func(uint64) error, seenMetadata bool, ) (raftpb.Message, error) { if !seenMetadata || metadata.Snapshot == nil { @@ -849,6 +862,7 @@ func finalizeReceivedSnapshot( } index := metadata.Snapshot.Metadata.Index if fsmSnapDir != "" && index > 0 { + prepareReceivedFSMSnapshotWrite(fsmSnapDir, index, prepareFn) if err := spool.FinalizeAsFSMFile(fsmSnapDir, index, crc32c); err != nil { return raftpb.Message{}, err } @@ -861,6 +875,21 @@ func finalizeReceivedSnapshot( return buildSnapshotMessage(metadata, spool, seenMetadata) } +func prepareReceivedFSMSnapshotWrite(fsmSnapDir string, index uint64, prepareFn func(uint64) error) { + var err error + if prepareFn != nil { + err = prepareFn(index) + } else { + err = prepareFSMSnapshotWrite("", fsmSnapDir, index) + } + if err != nil { + slog.Warn("failed to prepare received fsm snapshot write", + "index", index, + "error", err, + ) + } +} + // snapshotDataFormatLabel exists purely for the structured log line on the // receiver — it lets an operator distinguish a streaming-token receive // (small heap, payload on disk) from a legacy materialization (heap holds diff --git a/main.go b/main.go index 41a405b02..414016186 100644 --- a/main.go +++ b/main.go @@ -441,6 +441,7 @@ func run() error { sqsAdvertisesHTFIFO(), slog.Default()) cleanup.Add(leadershipRefusalDeregister) eg, runCtx := errgroup.WithContext(ctx) + startRaftEngineLifecycleWatchers(runCtx, eg, runtimes) // setupDistributionCatalog + the Stage 7a process-start registration // gate are bundled so run() has a single startup-fault path: a // registry-read / behind-epoch failure fails the process @@ -511,6 +512,41 @@ func run() error { return nil } +func startRaftEngineLifecycleWatchers(ctx context.Context, eg *errgroup.Group, runtimes []*raftGroupRuntime) { + for _, runtime := range runtimes { + + if runtime == nil { + continue + } + engine := runtime.snapshotEngine() + lifecycle, ok := engine.(raftengine.Lifecycle) + if !ok { + continue + } + done := lifecycle.Done() + if done == nil { + continue + } + groupID := runtime.spec.id + eg.Go(func() error { + select { + case <-ctx.Done(): + return nil + case <-done: + select { + case <-ctx.Done(): + return nil + default: + } + if err := lifecycle.Err(); err != nil { + return errors.Wrapf(err, "raft group %d engine stopped", groupID) + } + return errors.Errorf("raft group %d engine stopped", groupID) + } + }) + } +} + func resolveRuntimeInputs() (runtimeConfig, raftEngineType, []raftengine.Server, bool, error) { if *raftId == "" { return runtimeConfig{}, "", nil, false, errors.New("flag --raftId is required") diff --git a/main_raft_lifecycle_test.go b/main_raft_lifecycle_test.go new file mode 100644 index 000000000..a7f982741 --- /dev/null +++ b/main_raft_lifecycle_test.go @@ -0,0 +1,48 @@ +package main + +import ( + "context" + "errors" + "testing" + + "github.com/bootjp/elastickv/internal/raftengine" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +type lifecycleEngineStub struct { + raftengine.Engine + done chan struct{} + err error +} + +func (e *lifecycleEngineStub) Done() <-chan struct{} { return e.done } +func (e *lifecycleEngineStub) Err() error { return e.err } + +func TestStartRaftEngineLifecycleWatchersReportsEngineFailure(t *testing.T) { + t.Parallel() + engineErr := errors.New("snapshot write failed") + engine := &lifecycleEngineStub{done: make(chan struct{}), err: engineErr} + runtimes := []*raftGroupRuntime{{spec: groupSpec{id: 7}, engine: engine}} + + eg, ctx := errgroup.WithContext(context.Background()) + startRaftEngineLifecycleWatchers(ctx, eg, runtimes) + close(engine.done) + + err := eg.Wait() + require.ErrorIs(t, err, engineErr) + require.Contains(t, err.Error(), "raft group 7 engine stopped") +} + +func TestStartRaftEngineLifecycleWatchersIgnoresContextCancellation(t *testing.T) { + t.Parallel() + engine := &lifecycleEngineStub{done: make(chan struct{})} + runtimes := []*raftGroupRuntime{{spec: groupSpec{id: 8}, engine: engine}} + ctx, cancel := context.WithCancel(context.Background()) + eg, runCtx := errgroup.WithContext(ctx) + startRaftEngineLifecycleWatchers(runCtx, eg, runtimes) + + cancel() + + require.NoError(t, eg.Wait()) +} From 9c8aaaa6156feb139a2269619eec376ab9ccc169 Mon Sep 17 00:00:00 2001 From: bootjp Date: Thu, 2 Jul 2026 22:48:37 +0900 Subject: [PATCH 02/14] raft: harden snapshot prewrite cleanup --- internal/raftengine/etcd/fsm_snapshot_file.go | 81 ++++++++++--------- .../raftengine/etcd/fsm_snapshot_file_test.go | 17 ++-- internal/raftengine/etcd/grpc_transport.go | 42 ++++++++-- .../raftengine/etcd/grpc_transport_test.go | 48 +++++++++++ main.go | 9 +-- 5 files changed, 137 insertions(+), 60 deletions(-) diff --git a/internal/raftengine/etcd/fsm_snapshot_file.go b/internal/raftengine/etcd/fsm_snapshot_file.go index 5a3a43649..d959f78f8 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file.go +++ b/internal/raftengine/etcd/fsm_snapshot_file.go @@ -20,6 +20,8 @@ import ( const ( fsmSnapDirName = "fsm-snap" snapFileExt = ".snap" + fsmFileExt = ".fsm" + fsmTmpFileSuffix = ".fsm.tmp" snapshotTokenSize = 17 // 4 (magic) + 1 (version) + 8 (index) + 4 (crc32c) snapshotTokenVersion = byte(0x01) prewriteSnapKeep = 1 @@ -130,7 +132,7 @@ func decodeSnapshotToken(data []byte) (snapshotToken, error) { // fsmSnapPath returns the canonical zero-padded hex path for a .fsm snapshot file. func fsmSnapPath(fsmSnapDir string, index uint64) string { - return filepath.Join(fsmSnapDir, fmt.Sprintf("%016x.fsm", index)) + return filepath.Join(fsmSnapDir, fmt.Sprintf("%016x%s", index, fsmFileExt)) } // parseSnapFileIndex extracts the applied index from an etcd snapshotter filename. @@ -526,9 +528,9 @@ func cleanupStaleFSMSnaps(snapDir, fsmSnapDir string, disableStartupCRCCheck boo // prepareFSMSnapshotWrite frees space before writing a new large .fsm payload. // It keeps the newest prior snap/fsm pair so a failed write still leaves a -// restartable snapshot, then removes older pairs and stale pre-next-index FSM -// files. Success-path purgeOldSnapshotFiles runs after raft publishes the new -// token; this prewrite pass prevents ENOSPC before that success path can run. +// restartable snapshot, then removes older pairs. Success-path +// purgeOldSnapshotFiles runs after raft publishes the new token; this prewrite +// pass prevents ENOSPC before that success path can run. func prepareFSMSnapshotWrite(snapDir, fsmSnapDir string, nextIndex uint64) error { if fsmSnapDir == "" || nextIndex == 0 { return nil @@ -538,27 +540,20 @@ func prepareFSMSnapshotWrite(snapDir, fsmSnapDir string, nextIndex uint64) error } var combined error - combined = errors.CombineErrors(combined, removeFSMTmpOrphans(fsmSnapDir)) if snapDir == "" { combined = errors.CombineErrors(combined, syncDirIfExists(fsmSnapDir)) return errors.WithStack(combined) } combined = errors.CombineErrors(combined, purgeOlderSnapshotPairsBeforeWrite(snapDir, fsmSnapDir, nextIndex)) - - liveIndexes, err := collectLiveSnapIndexes(snapDir) - if err != nil { - combined = errors.CombineErrors(combined, err) - } else if liveIndexes != nil { - combined = errors.CombineErrors(combined, removeStaleFSMFilesBeforeIndex(fsmSnapDir, liveIndexes, nextIndex)) - } combined = errors.CombineErrors(combined, syncDirIfExists(snapDir)) combined = errors.CombineErrors(combined, syncDirIfExists(fsmSnapDir)) return errors.WithStack(combined) } type snapFileCandidate struct { - name string - index uint64 + name string + index uint64 + hasFSM bool } func purgeOlderSnapshotPairsBeforeWrite(snapDir, fsmSnapDir string, nextIndex uint64) error { @@ -573,7 +568,7 @@ func purgeOlderSnapshotPairsBeforeWrite(snapDir, fsmSnapDir string, nextIndex ui return errors.WithStack(err) } - candidates := collectPrewriteSnapCandidates(entries, nextIndex) + candidates := collectPrewriteSnapCandidates(entries, fsmSnapDir, nextIndex) if len(candidates) <= prewriteSnapKeep { return nil } @@ -585,8 +580,12 @@ func purgeOlderSnapshotPairsBeforeWrite(snapDir, fsmSnapDir string, nextIndex ui return candidates[i].index < candidates[j].index }) + keep := keepRestorablePrewriteSnapshots(candidates) var combined error - for _, candidate := range candidates[:len(candidates)-prewriteSnapKeep] { + for _, candidate := range candidates { + if keep[candidate.name] { + continue + } if err := purgeSnapPair(snapDir, fsmSnapDir, candidate.name); err != nil { combined = errors.CombineErrors(combined, err) } @@ -594,7 +593,20 @@ func purgeOlderSnapshotPairsBeforeWrite(snapDir, fsmSnapDir string, nextIndex ui return errors.WithStack(combined) } -func collectPrewriteSnapCandidates(entries []os.DirEntry, nextIndex uint64) []snapFileCandidate { +func keepRestorablePrewriteSnapshots(candidates []snapFileCandidate) map[string]bool { + keep := make(map[string]bool, prewriteSnapKeep) + for i := len(candidates) - 1; i >= 0 && len(keep) < prewriteSnapKeep; i-- { + if candidates[i].hasFSM { + keep[candidates[i].name] = true + } + } + if len(keep) == 0 && len(candidates) > 0 { + keep[candidates[len(candidates)-1].name] = true + } + return keep +} + +func collectPrewriteSnapCandidates(entries []os.DirEntry, fsmSnapDir string, nextIndex uint64) []snapFileCandidate { candidates := make([]snapFileCandidate, 0, len(entries)) for _, e := range entries { if e.IsDir() || filepath.Ext(e.Name()) != snapFileExt { @@ -604,30 +616,21 @@ func collectPrewriteSnapCandidates(entries []os.DirEntry, nextIndex uint64) []sn if index == 0 || index >= nextIndex { continue } - candidates = append(candidates, snapFileCandidate{name: e.Name(), index: index}) + candidates = append(candidates, snapFileCandidate{ + name: e.Name(), + index: index, + hasFSM: fsmSnapshotFileExists(fsmSnapDir, index), + }) } return candidates } -func removeStaleFSMFilesBeforeIndex(fsmSnapDir string, liveIndexes map[uint64]bool, nextIndex uint64) error { - fsmEntries, err := os.ReadDir(fsmSnapDir) - if err != nil { - if os.IsNotExist(err) { - return nil - } - return errors.WithStack(err) - } - for _, e := range fsmEntries { - if e.IsDir() || filepath.Ext(e.Name()) != ".fsm" { - continue - } - idx, err := strconv.ParseUint(strings.TrimSuffix(e.Name(), ".fsm"), 16, 64) - if err != nil || idx >= nextIndex || liveIndexes[idx] { - continue - } - removeWithWarn(filepath.Join(fsmSnapDir, e.Name()), "orphan fsm snapshot") +func fsmSnapshotFileExists(fsmSnapDir string, index uint64) bool { + if fsmSnapDir == "" { + return false } - return nil + info, err := os.Stat(fsmSnapPath(fsmSnapDir, index)) + return err == nil && !info.IsDir() } func removeFSMTmpOrphans(fsmSnapDir string) error { @@ -643,7 +646,7 @@ func removeFSMTmpOrphans(fsmSnapDir string) error { } var combined error for _, e := range entries { - if !e.IsDir() && strings.HasSuffix(e.Name(), ".fsm.tmp") { + if !e.IsDir() && strings.HasSuffix(e.Name(), fsmTmpFileSuffix) { if removeErr := os.Remove(filepath.Join(fsmSnapDir, e.Name())); removeErr != nil && !os.IsNotExist(removeErr) { combined = errors.CombineErrors(combined, errors.WithStack(removeErr)) } @@ -680,7 +683,7 @@ func removeStaleFSMFiles(fsmSnapDir string, liveIndexes map[uint64]bool, disable return errors.WithStack(err) } for _, e := range fsmEntries { - if e.IsDir() || filepath.Ext(e.Name()) != ".fsm" { + if e.IsDir() || filepath.Ext(e.Name()) != fsmFileExt { continue } removeStaleFSMFile(fsmSnapDir, e.Name(), liveIndexes, disableStartupCRCCheck) @@ -689,7 +692,7 @@ func removeStaleFSMFiles(fsmSnapDir string, liveIndexes map[uint64]bool, disable } func removeStaleFSMFile(fsmSnapDir, name string, liveIndexes map[uint64]bool, disableStartupCRCCheck bool) { - idx, err := strconv.ParseUint(strings.TrimSuffix(name, ".fsm"), 16, 64) + idx, err := strconv.ParseUint(strings.TrimSuffix(name, fsmFileExt), 16, 64) if err != nil { return } diff --git a/internal/raftengine/etcd/fsm_snapshot_file_test.go b/internal/raftengine/etcd/fsm_snapshot_file_test.go index b1c2eb639..b5e162513 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file_test.go +++ b/internal/raftengine/etcd/fsm_snapshot_file_test.go @@ -332,15 +332,16 @@ func TestPurgeOldSnapshotFilesOrdering(t *testing.T) { require.Len(t, fsms, 3) } -func TestPrepareFSMSnapshotWritePrunesOldPairsAndOrphans(t *testing.T) { +func TestPrepareFSMSnapshotWriteKeepsNewestRestorablePair(t *testing.T) { snapDir := t.TempDir() fsmSnapDir := t.TempDir() payload := []byte("payload") - for _, index := range []uint64{100, 200, 300} { + for _, index := range []uint64{100, 200} { createSnapFile(t, snapDir, index) writeFSMFileForTest(t, fsmSnapDir, index, payload) } + createSnapFile(t, snapDir, 300) writeFSMFileForTest(t, fsmSnapDir, 150, payload) writeFSMFileForTest(t, fsmSnapDir, 500, payload) require.NoError(t, os.WriteFile(filepath.Join(fsmSnapDir, "leftover.fsm.tmp"), []byte("tmp"), 0o600)) @@ -348,15 +349,15 @@ func TestPrepareFSMSnapshotWritePrunesOldPairsAndOrphans(t *testing.T) { require.NoError(t, prepareFSMSnapshotWrite(snapDir, fsmSnapDir, 400)) require.NoFileExists(t, filepath.Join(snapDir, "0000000000000001-0000000000000064.snap")) - require.NoFileExists(t, filepath.Join(snapDir, "0000000000000001-00000000000000c8.snap")) - require.FileExists(t, filepath.Join(snapDir, "0000000000000001-000000000000012c.snap")) + require.FileExists(t, filepath.Join(snapDir, "0000000000000001-00000000000000c8.snap")) + require.NoFileExists(t, filepath.Join(snapDir, "0000000000000001-000000000000012c.snap")) require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 100)) - require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 150)) - require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 200)) - require.FileExists(t, fsmSnapPath(fsmSnapDir, 300)) + require.FileExists(t, fsmSnapPath(fsmSnapDir, 150)) + require.FileExists(t, fsmSnapPath(fsmSnapDir, 200)) + require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 300)) require.FileExists(t, fsmSnapPath(fsmSnapDir, 500)) - require.NoFileExists(t, filepath.Join(fsmSnapDir, "leftover.fsm.tmp")) + require.FileExists(t, filepath.Join(fsmSnapDir, "leftover.fsm.tmp")) } // --- writeFSMSnapshotFile integration --- diff --git a/internal/raftengine/etcd/grpc_transport.go b/internal/raftengine/etcd/grpc_transport.go index da9d6f240..a0d214f4c 100644 --- a/internal/raftengine/etcd/grpc_transport.go +++ b/internal/raftengine/etcd/grpc_transport.go @@ -816,6 +816,7 @@ func drainSnapshotChunks( crcWriter := newCRC32CWriter(spool) var payloadBytes int64 + preparedFSMWrite := false for { chunk, err := stream.Recv() if err != nil { @@ -824,14 +825,20 @@ func drainSnapshotChunks( } return raftpb.Message{}, 0, errors.WithStack(err) } - payloadBytes += int64(len(chunk.Chunk)) - seen, err := appendSnapshotChunk(&metadata, crcWriter, chunk, seenMetadata) + seen, err := appendSnapshotChunkMetadata(&metadata, chunk, seenMetadata) if err != nil { return raftpb.Message{}, 0, err } seenMetadata = seen + if !preparedFSMWrite { + preparedFSMWrite = maybePrepareReceivedFSMSnapshotWrite(metadata, fsmSnapDir, prepareFn, seenMetadata) + } + if err := writeSnapshotChunkPayload(crcWriter, chunk); err != nil { + return raftpb.Message{}, 0, err + } + payloadBytes += int64(len(chunk.Chunk)) if chunk.Final { - msg, err := finalizeReceivedSnapshot(metadata, spool, crcWriter.Sum32(), fsmSnapDir, prepareFn, seenMetadata) + msg, err := finalizeReceivedSnapshot(metadata, spool, crcWriter.Sum32(), fsmSnapDir, seenMetadata) if err != nil { return raftpb.Message{}, 0, err } @@ -854,7 +861,6 @@ func finalizeReceivedSnapshot( spool *snapshotSpool, crc32c uint32, fsmSnapDir string, - prepareFn func(uint64) error, seenMetadata bool, ) (raftpb.Message, error) { if !seenMetadata || metadata.Snapshot == nil { @@ -862,7 +868,6 @@ func finalizeReceivedSnapshot( } index := metadata.Snapshot.Metadata.Index if fsmSnapDir != "" && index > 0 { - prepareReceivedFSMSnapshotWrite(fsmSnapDir, index, prepareFn) if err := spool.FinalizeAsFSMFile(fsmSnapDir, index, crc32c); err != nil { return raftpb.Message{}, err } @@ -875,6 +880,23 @@ func finalizeReceivedSnapshot( return buildSnapshotMessage(metadata, spool, seenMetadata) } +func maybePrepareReceivedFSMSnapshotWrite( + metadata raftpb.Message, + fsmSnapDir string, + prepareFn func(uint64) error, + seenMetadata bool, +) bool { + if fsmSnapDir == "" || !seenMetadata || metadata.Snapshot == nil { + return false + } + index := metadata.Snapshot.Metadata.Index + if index == 0 { + return false + } + prepareReceivedFSMSnapshotWrite(fsmSnapDir, index, prepareFn) + return true +} + func prepareReceivedFSMSnapshotWrite(fsmSnapDir string, index uint64, prepareFn func(uint64) error) { var err error if prepareFn != nil { @@ -904,7 +926,7 @@ func snapshotDataFormatLabel(snap *raftpb.Snapshot) string { return "inline" } -func appendSnapshotChunk(metadata *raftpb.Message, payload io.Writer, chunk *pb.EtcdRaftSnapshotChunk, seenMetadata bool) (bool, error) { +func appendSnapshotChunkMetadata(metadata *raftpb.Message, chunk *pb.EtcdRaftSnapshotChunk, seenMetadata bool) (bool, error) { if len(chunk.Metadata) > 0 { if seenMetadata { return false, errors.WithStack(errSnapshotMetadataDuplicate) @@ -914,12 +936,16 @@ func appendSnapshotChunk(metadata *raftpb.Message, payload io.Writer, chunk *pb. } seenMetadata = true } + return seenMetadata, nil +} + +func writeSnapshotChunkPayload(payload io.Writer, chunk *pb.EtcdRaftSnapshotChunk) error { if len(chunk.Chunk) > 0 { if _, err := payload.Write(chunk.Chunk); err != nil { - return false, errors.WithStack(err) + return errors.WithStack(err) } } - return seenMetadata, nil + return nil } func buildSnapshotMessage(metadata raftpb.Message, spool *snapshotSpool, seenMetadata bool) (raftpb.Message, error) { diff --git a/internal/raftengine/etcd/grpc_transport_test.go b/internal/raftengine/etcd/grpc_transport_test.go index cc262417e..bd988d424 100644 --- a/internal/raftengine/etcd/grpc_transport_test.go +++ b/internal/raftengine/etcd/grpc_transport_test.go @@ -222,6 +222,54 @@ func TestReceiveSnapshotStream_StreamingTokenWhenFSMSnapDirSet(t *testing.T) { require.Equal(t, senderFSM.Applied(), receiverFSM.Applied()) } +func TestDrainSnapshotChunksPreparesBeforePayloadWrite(t *testing.T) { + const index = uint64(124) + payload := []byte("payload written after prepare") + metadata := raftpb.Message{ + Type: raftpb.MsgSnap, + From: 1, + To: 2, + Snapshot: &raftpb.Snapshot{ + Metadata: raftpb.SnapshotMetadata{Index: index, Term: 1}, + }, + } + raw, err := metadata.Marshal() + require.NoError(t, err) + + fsmSnapDir := t.TempDir() + spool, err := newSnapshotSpool(fsmSnapDir) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, spool.Close()) + }) + + prepareCalls := 0 + prepareFn := func(got uint64) error { + prepareCalls++ + require.Equal(t, index, got) + info, statErr := os.Stat(spool.path) + require.NoError(t, statErr) + require.Zero(t, info.Size(), "prepare must run before the first payload byte is spooled") + return nil + } + stream := &testSendSnapshotServer{ + chunks: []*pb.EtcdRaftSnapshotChunk{{ + Metadata: raw, + Chunk: payload, + Final: true, + }}, + } + + msg, payloadBytes, err := drainSnapshotChunks(stream, spool, fsmSnapDir, prepareFn) + require.NoError(t, err) + require.Equal(t, int64(len(payload)), payloadBytes) + require.Equal(t, 1, prepareCalls) + require.True(t, isSnapshotToken(msg.Snapshot.Data)) + got, err := readFSMSnapshotPayload(fsmSnapPath(fsmSnapDir, index)) + require.NoError(t, err) + require.Equal(t, payload, got) +} + // TestReceiveSnapshotStream_SpoolPlacedInFSMSnapDir pins the EXDEV-avoidance // fix from PR #747 round-3 (Codex P1): when fsmSnapDir is wired, the spool // file MUST be created inside fsmSnapDir (not spoolDir), so that the diff --git a/main.go b/main.go index 414016186..bb77cb01d 100644 --- a/main.go +++ b/main.go @@ -513,12 +513,11 @@ func run() error { } func startRaftEngineLifecycleWatchers(ctx context.Context, eg *errgroup.Group, runtimes []*raftGroupRuntime) { - for _, runtime := range runtimes { - - if runtime == nil { + for _, rt := range runtimes { + if rt == nil { continue } - engine := runtime.snapshotEngine() + engine := rt.snapshotEngine() lifecycle, ok := engine.(raftengine.Lifecycle) if !ok { continue @@ -527,7 +526,7 @@ func startRaftEngineLifecycleWatchers(ctx context.Context, eg *errgroup.Group, r if done == nil { continue } - groupID := runtime.spec.id + groupID := rt.spec.id eg.Go(func() error { select { case <-ctx.Done(): From 0d0b5d4eb93b4d630e4da9cfa73a1b6297cc2759 Mon Sep 17 00:00:00 2001 From: bootjp Date: Thu, 2 Jul 2026 23:07:14 +0900 Subject: [PATCH 03/14] raft: tighten snapshot prewrite safety --- internal/raftengine/etcd/fsm_snapshot_file.go | 126 +++++++++++++----- .../raftengine/etcd/fsm_snapshot_file_test.go | 6 +- internal/raftengine/etcd/grpc_transport.go | 3 + .../raftengine/etcd/grpc_transport_test.go | 29 ++++ 4 files changed, 131 insertions(+), 33 deletions(-) diff --git a/internal/raftengine/etcd/fsm_snapshot_file.go b/internal/raftengine/etcd/fsm_snapshot_file.go index d959f78f8..52685e4c3 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file.go +++ b/internal/raftengine/etcd/fsm_snapshot_file.go @@ -527,10 +527,13 @@ func cleanupStaleFSMSnaps(snapDir, fsmSnapDir string, disableStartupCRCCheck boo } // prepareFSMSnapshotWrite frees space before writing a new large .fsm payload. -// It keeps the newest prior snap/fsm pair so a failed write still leaves a -// restartable snapshot, then removes older pairs. Success-path -// purgeOldSnapshotFiles runs after raft publishes the new token; this prewrite -// pass prevents ENOSPC before that success path can run. +// It keeps the newest prior verified snap/fsm pair so a failed write still +// leaves a restartable snapshot, then removes older pairs and older unpaired +// .fsm files. Runtime prewrite cleanup intentionally leaves *.fsm.tmp alone: +// local snapshot writers use that suffix for active temp files, and startup +// cleanup owns crash leftovers. Success-path purgeOldSnapshotFiles runs after +// raft publishes the new token; this prewrite pass prevents ENOSPC before that +// success path can run. func prepareFSMSnapshotWrite(snapDir, fsmSnapDir string, nextIndex uint64) error { if fsmSnapDir == "" || nextIndex == 0 { return nil @@ -551,9 +554,14 @@ func prepareFSMSnapshotWrite(snapDir, fsmSnapDir string, nextIndex uint64) error } type snapFileCandidate struct { - name string - index uint64 - hasFSM bool + name string + index uint64 + restorable bool +} + +type prewriteSnapshotRetention struct { + keep map[string]bool + restorableFloor uint64 } func purgeOlderSnapshotPairsBeforeWrite(snapDir, fsmSnapDir string, nextIndex uint64) error { @@ -569,10 +577,6 @@ func purgeOlderSnapshotPairsBeforeWrite(snapDir, fsmSnapDir string, nextIndex ui } candidates := collectPrewriteSnapCandidates(entries, fsmSnapDir, nextIndex) - if len(candidates) <= prewriteSnapKeep { - return nil - } - sort.Slice(candidates, func(i, j int) bool { if candidates[i].index == candidates[j].index { return candidates[i].name < candidates[j].name @@ -580,30 +584,65 @@ func purgeOlderSnapshotPairsBeforeWrite(snapDir, fsmSnapDir string, nextIndex ui return candidates[i].index < candidates[j].index }) - keep := keepRestorablePrewriteSnapshots(candidates) + retention := keepRestorablePrewriteSnapshots(candidates) var combined error - for _, candidate := range candidates { - if keep[candidate.name] { - continue - } - if err := purgeSnapPair(snapDir, fsmSnapDir, candidate.name); err != nil { - combined = errors.CombineErrors(combined, err) + combined = errors.CombineErrors(combined, purgeUnretainedPrewriteSnapshots(snapDir, fsmSnapDir, candidates, retention)) + combined = errors.CombineErrors(combined, removePrewriteFSMOrphansBelowRetention(snapDir, fsmSnapDir, retention)) + return errors.WithStack(combined) +} + +func purgeUnretainedPrewriteSnapshots( + snapDir string, + fsmSnapDir string, + candidates []snapFileCandidate, + retention prewriteSnapshotRetention, +) error { + var combined error + if len(candidates) > prewriteSnapKeep { + for _, candidate := range candidates { + if retention.keep[candidate.name] { + continue + } + if err := purgeSnapPair(snapDir, fsmSnapDir, candidate.name); err != nil { + combined = errors.CombineErrors(combined, err) + } } } return errors.WithStack(combined) } -func keepRestorablePrewriteSnapshots(candidates []snapFileCandidate) map[string]bool { - keep := make(map[string]bool, prewriteSnapKeep) - for i := len(candidates) - 1; i >= 0 && len(keep) < prewriteSnapKeep; i-- { - if candidates[i].hasFSM { - keep[candidates[i].name] = true +func removePrewriteFSMOrphansBelowRetention( + snapDir string, + fsmSnapDir string, + retention prewriteSnapshotRetention, +) error { + if retention.restorableFloor > 0 { + liveIndexes, err := collectLiveSnapIndexes(snapDir) + if err != nil { + return errors.WithStack(err) + } else if liveIndexes != nil { + return removeStaleFSMFilesBelowIndex(fsmSnapDir, liveIndexes, retention.restorableFloor) + } + } + return nil +} + +func keepRestorablePrewriteSnapshots(candidates []snapFileCandidate) prewriteSnapshotRetention { + retention := prewriteSnapshotRetention{ + keep: make(map[string]bool, prewriteSnapKeep), + } + for i := len(candidates) - 1; i >= 0 && len(retention.keep) < prewriteSnapKeep; i-- { + if candidates[i].restorable { + retention.keep[candidates[i].name] = true + if retention.restorableFloor == 0 || candidates[i].index < retention.restorableFloor { + retention.restorableFloor = candidates[i].index + } } } - if len(keep) == 0 && len(candidates) > 0 { - keep[candidates[len(candidates)-1].name] = true + if len(retention.keep) == 0 && len(candidates) > 0 { + retention.keep[candidates[len(candidates)-1].name] = true } - return keep + return retention } func collectPrewriteSnapCandidates(entries []os.DirEntry, fsmSnapDir string, nextIndex uint64) []snapFileCandidate { @@ -617,20 +656,43 @@ func collectPrewriteSnapCandidates(entries []os.DirEntry, fsmSnapDir string, nex continue } candidates = append(candidates, snapFileCandidate{ - name: e.Name(), - index: index, - hasFSM: fsmSnapshotFileExists(fsmSnapDir, index), + name: e.Name(), + index: index, + restorable: fsmSnapshotFileRestorable(fsmSnapDir, index), }) } return candidates } -func fsmSnapshotFileExists(fsmSnapDir string, index uint64) bool { +func fsmSnapshotFileRestorable(fsmSnapDir string, index uint64) bool { if fsmSnapDir == "" { return false } - info, err := os.Stat(fsmSnapPath(fsmSnapDir, index)) - return err == nil && !info.IsDir() + return verifyFSMSnapshotFile(fsmSnapPath(fsmSnapDir, index), 0) == nil +} + +func removeStaleFSMFilesBelowIndex(fsmSnapDir string, liveIndexes map[uint64]bool, maxIndex uint64) error { + if maxIndex == 0 { + return nil + } + fsmEntries, err := os.ReadDir(fsmSnapDir) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return errors.WithStack(err) + } + for _, e := range fsmEntries { + if e.IsDir() || filepath.Ext(e.Name()) != fsmFileExt { + continue + } + idx, err := strconv.ParseUint(strings.TrimSuffix(e.Name(), fsmFileExt), 16, 64) + if err != nil || idx >= maxIndex || liveIndexes[idx] { + continue + } + removeWithWarn(filepath.Join(fsmSnapDir, e.Name()), "old orphan fsm snapshot") + } + return nil } func removeFSMTmpOrphans(fsmSnapDir string) error { diff --git a/internal/raftengine/etcd/fsm_snapshot_file_test.go b/internal/raftengine/etcd/fsm_snapshot_file_test.go index b5e162513..7079da054 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file_test.go +++ b/internal/raftengine/etcd/fsm_snapshot_file_test.go @@ -342,6 +342,8 @@ func TestPrepareFSMSnapshotWriteKeepsNewestRestorablePair(t *testing.T) { writeFSMFileForTest(t, fsmSnapDir, index, payload) } createSnapFile(t, snapDir, 300) + require.NoError(t, os.WriteFile(fsmSnapPath(fsmSnapDir, 300), []byte{0x01, 0x02}, 0o600)) + createSnapFile(t, snapDir, 350) writeFSMFileForTest(t, fsmSnapDir, 150, payload) writeFSMFileForTest(t, fsmSnapDir, 500, payload) require.NoError(t, os.WriteFile(filepath.Join(fsmSnapDir, "leftover.fsm.tmp"), []byte("tmp"), 0o600)) @@ -351,11 +353,13 @@ func TestPrepareFSMSnapshotWriteKeepsNewestRestorablePair(t *testing.T) { require.NoFileExists(t, filepath.Join(snapDir, "0000000000000001-0000000000000064.snap")) require.FileExists(t, filepath.Join(snapDir, "0000000000000001-00000000000000c8.snap")) require.NoFileExists(t, filepath.Join(snapDir, "0000000000000001-000000000000012c.snap")) + require.NoFileExists(t, filepath.Join(snapDir, "0000000000000001-000000000000015e.snap")) require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 100)) - require.FileExists(t, fsmSnapPath(fsmSnapDir, 150)) + require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 150)) require.FileExists(t, fsmSnapPath(fsmSnapDir, 200)) require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 300)) + require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 350)) require.FileExists(t, fsmSnapPath(fsmSnapDir, 500)) require.FileExists(t, filepath.Join(fsmSnapDir, "leftover.fsm.tmp")) } diff --git a/internal/raftengine/etcd/grpc_transport.go b/internal/raftengine/etcd/grpc_transport.go index a0d214f4c..8cd133951 100644 --- a/internal/raftengine/etcd/grpc_transport.go +++ b/internal/raftengine/etcd/grpc_transport.go @@ -936,6 +936,9 @@ func appendSnapshotChunkMetadata(metadata *raftpb.Message, chunk *pb.EtcdRaftSna } seenMetadata = true } + if !seenMetadata && len(chunk.Chunk) > 0 { + return false, errors.WithStack(errSnapshotMetadataNil) + } return seenMetadata, nil } diff --git a/internal/raftengine/etcd/grpc_transport_test.go b/internal/raftengine/etcd/grpc_transport_test.go index bd988d424..dd3374eab 100644 --- a/internal/raftengine/etcd/grpc_transport_test.go +++ b/internal/raftengine/etcd/grpc_transport_test.go @@ -270,6 +270,35 @@ func TestDrainSnapshotChunksPreparesBeforePayloadWrite(t *testing.T) { require.Equal(t, payload, got) } +func TestDrainSnapshotChunksRejectsPayloadBeforeMetadata(t *testing.T) { + fsmSnapDir := t.TempDir() + spool, err := newSnapshotSpool(fsmSnapDir) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, spool.Close()) + }) + + prepareCalls := 0 + prepareFn := func(uint64) error { + prepareCalls++ + return nil + } + stream := &testSendSnapshotServer{ + chunks: []*pb.EtcdRaftSnapshotChunk{{ + Chunk: []byte("payload before metadata"), + Final: true, + }}, + } + + _, payloadBytes, err := drainSnapshotChunks(stream, spool, fsmSnapDir, prepareFn) + require.ErrorIs(t, err, errSnapshotMetadataNil) + require.Zero(t, payloadBytes) + require.Zero(t, prepareCalls) + info, statErr := os.Stat(spool.path) + require.NoError(t, statErr) + require.Zero(t, info.Size()) +} + // TestReceiveSnapshotStream_SpoolPlacedInFSMSnapDir pins the EXDEV-avoidance // fix from PR #747 round-3 (Codex P1): when fsmSnapDir is wired, the spool // file MUST be created inside fsmSnapDir (not spoolDir), so that the From 2c1c5533229a56815765ca675c8f8cccd1880a8d Mon Sep 17 00:00:00 2001 From: bootjp Date: Thu, 2 Jul 2026 23:08:42 +0900 Subject: [PATCH 04/14] raft: reject malformed snapshot chunks early --- internal/raftengine/etcd/grpc_transport.go | 24 ++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/internal/raftengine/etcd/grpc_transport.go b/internal/raftengine/etcd/grpc_transport.go index 8cd133951..a7b83fdc0 100644 --- a/internal/raftengine/etcd/grpc_transport.go +++ b/internal/raftengine/etcd/grpc_transport.go @@ -818,18 +818,18 @@ func drainSnapshotChunks( var payloadBytes int64 preparedFSMWrite := false for { - chunk, err := stream.Recv() + chunk, err := recvSnapshotChunk(stream) if err != nil { - if errors.Is(err, io.EOF) { - return raftpb.Message{}, 0, errors.WithStack(errSnapshotStreamShort) - } - return raftpb.Message{}, 0, errors.WithStack(err) + return raftpb.Message{}, 0, err } seen, err := appendSnapshotChunkMetadata(&metadata, chunk, seenMetadata) if err != nil { return raftpb.Message{}, 0, err } seenMetadata = seen + if !seenMetadata && len(chunk.Chunk) > 0 { + return raftpb.Message{}, 0, errors.WithStack(errSnapshotMetadataNil) + } if !preparedFSMWrite { preparedFSMWrite = maybePrepareReceivedFSMSnapshotWrite(metadata, fsmSnapDir, prepareFn, seenMetadata) } @@ -847,6 +847,17 @@ func drainSnapshotChunks( } } +func recvSnapshotChunk(stream pb.EtcdRaft_SendSnapshotServer) (*pb.EtcdRaftSnapshotChunk, error) { + chunk, err := stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + return nil, errors.WithStack(errSnapshotStreamShort) + } + return nil, errors.WithStack(err) + } + return chunk, nil +} + // finalizeReceivedSnapshot picks between the streaming-token path (when an // fsmSnapDir is wired and the snapshot's metadata index is non-zero) and the // legacy in-memory path. The streaming path renames the spool file in place @@ -936,9 +947,6 @@ func appendSnapshotChunkMetadata(metadata *raftpb.Message, chunk *pb.EtcdRaftSna } seenMetadata = true } - if !seenMetadata && len(chunk.Chunk) > 0 { - return false, errors.WithStack(errSnapshotMetadataNil) - } return seenMetadata, nil } From bbf7185875965a16973b108087592605665693eb Mon Sep 17 00:00:00 2001 From: bootjp Date: Thu, 2 Jul 2026 23:39:14 +0900 Subject: [PATCH 05/14] raft: protect snapshot prewrite cleanup --- internal/raftengine/etcd/engine.go | 59 ++++++++++++--- .../etcd/engine_applied_index_test.go | 75 ++++++++++++++++++- internal/raftengine/etcd/fsm_snapshot_file.go | 9 ++- .../raftengine/etcd/fsm_snapshot_file_test.go | 2 + 4 files changed, 130 insertions(+), 15 deletions(-) diff --git a/internal/raftengine/etcd/engine.go b/internal/raftengine/etcd/engine.go index 8b4e1e00a..54bdaf5d3 100644 --- a/internal/raftengine/etcd/engine.go +++ b/internal/raftengine/etcd/engine.go @@ -1817,6 +1817,9 @@ func (e *Engine) persistReady(rd etcdraft.Ready) error { if !readyNeedsPersistence(rd) { return nil } + if !etcdraft.IsEmptySnap(rd.Snapshot) { + return e.persistReadyWithSnapshotLocked(rd) + } if err := e.applyReady(rd); err != nil { return err } @@ -1826,6 +1829,19 @@ func (e *Engine) persistReady(rd etcdraft.Ready) error { return persistReadyToWAL(e.persist, rd) } +func (e *Engine) persistReadyWithSnapshotLocked(rd etcdraft.Ready) error { + e.snapshotMu.Lock() + defer e.snapshotMu.Unlock() + + if err := e.applyReadyLocked(rd); err != nil { + return err + } + if e.persist == nil { + return nil + } + return persistReadyToWAL(e.persist, rd) +} + func (e *Engine) applyReady(rd etcdraft.Ready) error { if err := e.applyReadySnapshot(rd.Snapshot); err != nil { return err @@ -1836,6 +1852,16 @@ func (e *Engine) applyReady(rd etcdraft.Ready) error { return e.applyReadyHardState(rd.HardState) } +func (e *Engine) applyReadyLocked(rd etcdraft.Ready) error { + if err := e.applyReadySnapshotLocked(rd.Snapshot); err != nil { + return err + } + if err := e.applyReadyEntries(rd.Entries); err != nil { + return err + } + return e.applyReadyHardState(rd.HardState) +} + func (e *Engine) handleStep(msg raftpb.Message) { if e.rawNode == nil { return @@ -2031,6 +2057,15 @@ func (e *Engine) selectDispatchLane(pd *peerQueues, msgType raftpb.MessageType) } func (e *Engine) applyReadySnapshot(snapshot raftpb.Snapshot) error { + if etcdraft.IsEmptySnap(snapshot) { + return nil + } + e.snapshotMu.Lock() + defer e.snapshotMu.Unlock() + return e.applyReadySnapshotLocked(snapshot) +} + +func (e *Engine) applyReadySnapshotLocked(snapshot raftpb.Snapshot) error { if etcdraft.IsEmptySnap(snapshot) { return nil } @@ -2043,9 +2078,6 @@ func (e *Engine) applyReadySnapshot(snapshot raftpb.Snapshot) error { // Snapshot application is intentionally synchronous with the raft loop: the // local FSM must reflect the incoming raft snapshot before Ready can advance // and later committed entries can be applied safely. - e.snapshotMu.Lock() - defer e.snapshotMu.Unlock() - if isSnapshotToken(snapshot.Data) { tok, err := decodeSnapshotToken(snapshot.Data) if err != nil { @@ -2678,7 +2710,7 @@ func (e *Engine) persistConfigSnapshot(index uint64, confState raftpb.ConfState) e.snapshotMu.Lock() defer e.snapshotMu.Unlock() - payload, err := e.snapshotPayload(index) + payload, err := e.snapshotPayloadLocked(index) if err != nil { return err } @@ -2714,7 +2746,7 @@ func (e *Engine) persistConfigState(index uint64, confState raftpb.ConfState, pe return nil } - payload, err := e.snapshotPayload(index) + payload, err := e.snapshotPayloadLocked(index) if err != nil { return err } @@ -2777,11 +2809,13 @@ func (e *Engine) prepareFSMSnapshotWrite(index uint64) error { return prepareFSMSnapshotWrite(snapDir, e.fsmSnapDir, index) } -// snapshotPayload takes a FSM snapshot for the given index, writes it to the +// snapshotPayloadLocked takes a FSM snapshot for the given index, writes it to the // .fsm file on disk, and returns the 17-byte token for raftpb.Snapshot.Data. +// Caller must hold snapshotMu when fsmSnapDir is set: prewrite cleanup may +// delete orphaned .fsm files below index. // If fsmSnapDir is not set (e.g., engines created directly in unit tests), // falls back to the legacy in-memory []byte path. -func (e *Engine) snapshotPayload(index uint64) ([]byte, error) { +func (e *Engine) snapshotPayloadLocked(index uint64) ([]byte, error) { if e.fsmSnapDir == "" { snapshot, err := e.fsm.Snapshot() if err != nil { @@ -4221,7 +4255,10 @@ func (e *Engine) persistLocalSnapshot(req snapshotRequest) error { } return e.persistLocalSnapshotPayload(req.index, payload) } - if err := e.prepareFSMSnapshotWriteLocked(req.index); err != nil { + e.snapshotMu.Lock() + defer e.snapshotMu.Unlock() + + if err := e.prepareFSMSnapshotWrite(req.index); err != nil { slog.Warn("failed to prepare fsm snapshot write", "index", req.index, "error", err, @@ -4236,7 +4273,7 @@ func (e *Engine) persistLocalSnapshot(req snapshotRequest) error { return errors.WithStack(closeErr) } token := encodeSnapshotToken(req.index, crc32c) - return e.persistLocalSnapshotPayload(req.index, token) + return e.persistLocalSnapshotPayloadLocked(req.index, token) } func (e *Engine) persistLocalSnapshotPayload(index uint64, payload []byte) error { @@ -4246,6 +4283,10 @@ func (e *Engine) persistLocalSnapshotPayload(index uint64, payload []byte) error e.snapshotMu.Lock() defer e.snapshotMu.Unlock() + return e.persistLocalSnapshotPayloadLocked(index, payload) +} + +func (e *Engine) persistLocalSnapshotPayloadLocked(index uint64, payload []byte) error { current, err := e.storage.Snapshot() if err != nil { return errors.WithStack(err) diff --git a/internal/raftengine/etcd/engine_applied_index_test.go b/internal/raftengine/etcd/engine_applied_index_test.go index 437e10f77..16bf45731 100644 --- a/internal/raftengine/etcd/engine_applied_index_test.go +++ b/internal/raftengine/etcd/engine_applied_index_test.go @@ -4,6 +4,7 @@ import ( "io" "sync" "testing" + "time" "github.com/bootjp/elastickv/internal/raftengine" "github.com/coreos/go-semver/semver" @@ -67,11 +68,22 @@ func (f *recordingAppliedIndexFSM) SetDurableAppliedIndex(idx uint64) error { // that records SaveSnap calls into the shared recorder. The hook // only calls SaveSnap + Release; the rest are stubs. type recordingPersistStorage struct { - rec *applyIndexOrderRecorder + rec *applyIndexOrderRecorder + saveStarted chan struct{} + saveStartedOnce sync.Once + saveRelease <-chan struct{} } func (p *recordingPersistStorage) SaveSnap(snap raftpb.Snapshot) error { - p.rec.record("save", snap.Metadata.Index) + if p.saveStarted != nil { + p.saveStartedOnce.Do(func() { close(p.saveStarted) }) + } + if p.saveRelease != nil { + <-p.saveRelease + } + if p.rec != nil { + p.rec.record("save", snap.Metadata.Index) + } return nil } @@ -81,6 +93,65 @@ func (p *recordingPersistStorage) Sync() error func (p *recordingPersistStorage) Close() error { return nil } func (p *recordingPersistStorage) MinimalEtcdVersion() *semver.Version { return nil } +func TestPersistReadyWithSnapshotHoldsSnapshotMuThroughSaveSnap(t *testing.T) { + saveStarted := make(chan struct{}) + releaseSave := make(chan struct{}) + e := &Engine{ + storage: etcdraft.NewMemoryStorage(), + fsm: &recordingAppliedIndexFSM{}, + persist: &recordingPersistStorage{saveStarted: saveStarted, saveRelease: releaseSave}, + dataDir: t.TempDir(), + fsmSnapDir: t.TempDir(), + } + rd := etcdraft.Ready{ + Snapshot: raftpb.Snapshot{ + Data: []byte("payload"), + Metadata: raftpb.SnapshotMetadata{ + ConfState: raftpb.ConfState{Voters: []uint64{1}}, + Index: 7, + Term: 1, + }, + }, + } + + persistDone := make(chan error, 1) + go func() { + persistDone <- e.persistReady(rd) + }() + + select { + case <-saveStarted: + case <-time.After(time.Second): + t.Fatal("SaveSnap did not start") + } + + prepareDone := make(chan error, 1) + go func() { + prepareDone <- e.prepareFSMSnapshotWriteLocked(8) + }() + + select { + case err := <-prepareDone: + t.Fatalf("snapshot prepare finished before SaveSnap released snapshotMu: %v", err) + case <-time.After(100 * time.Millisecond): + } + + close(releaseSave) + + select { + case err := <-persistDone: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("persistReady did not finish after SaveSnap was released") + } + select { + case err := <-prepareDone: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("snapshot prepare did not finish after SaveSnap was released") + } +} + // TestRecordingFSM_SatisfiesAppliedIndexWriter is a compile-time- // adjacent assertion: the recording FSM MUST satisfy the writer // seam so the engine hook actually fires for it. diff --git a/internal/raftengine/etcd/fsm_snapshot_file.go b/internal/raftengine/etcd/fsm_snapshot_file.go index 52685e4c3..42b8c79af 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file.go +++ b/internal/raftengine/etcd/fsm_snapshot_file.go @@ -587,7 +587,7 @@ func purgeOlderSnapshotPairsBeforeWrite(snapDir, fsmSnapDir string, nextIndex ui retention := keepRestorablePrewriteSnapshots(candidates) var combined error combined = errors.CombineErrors(combined, purgeUnretainedPrewriteSnapshots(snapDir, fsmSnapDir, candidates, retention)) - combined = errors.CombineErrors(combined, removePrewriteFSMOrphansBelowRetention(snapDir, fsmSnapDir, retention)) + combined = errors.CombineErrors(combined, removePrewriteFSMOrphansBeforeIndex(snapDir, fsmSnapDir, retention, nextIndex)) return errors.WithStack(combined) } @@ -611,17 +611,18 @@ func purgeUnretainedPrewriteSnapshots( return errors.WithStack(combined) } -func removePrewriteFSMOrphansBelowRetention( +func removePrewriteFSMOrphansBeforeIndex( snapDir string, fsmSnapDir string, retention prewriteSnapshotRetention, + nextIndex uint64, ) error { if retention.restorableFloor > 0 { liveIndexes, err := collectLiveSnapIndexes(snapDir) if err != nil { return errors.WithStack(err) } else if liveIndexes != nil { - return removeStaleFSMFilesBelowIndex(fsmSnapDir, liveIndexes, retention.restorableFloor) + return removeStaleFSMFilesBelowIndex(fsmSnapDir, liveIndexes, nextIndex) } } return nil @@ -849,7 +850,7 @@ func syncDirIfExists(dir string) error { if dir == "" { return nil } - if err := syncDir(dir); err != nil && !os.IsNotExist(err) { + if err := syncDir(dir); err != nil && !os.IsNotExist(errors.UnwrapAll(err)) { return err } return nil diff --git a/internal/raftengine/etcd/fsm_snapshot_file_test.go b/internal/raftengine/etcd/fsm_snapshot_file_test.go index 7079da054..38d1da99c 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file_test.go +++ b/internal/raftengine/etcd/fsm_snapshot_file_test.go @@ -345,6 +345,7 @@ func TestPrepareFSMSnapshotWriteKeepsNewestRestorablePair(t *testing.T) { require.NoError(t, os.WriteFile(fsmSnapPath(fsmSnapDir, 300), []byte{0x01, 0x02}, 0o600)) createSnapFile(t, snapDir, 350) writeFSMFileForTest(t, fsmSnapDir, 150, payload) + writeFSMFileForTest(t, fsmSnapDir, 250, payload) writeFSMFileForTest(t, fsmSnapDir, 500, payload) require.NoError(t, os.WriteFile(filepath.Join(fsmSnapDir, "leftover.fsm.tmp"), []byte("tmp"), 0o600)) @@ -358,6 +359,7 @@ func TestPrepareFSMSnapshotWriteKeepsNewestRestorablePair(t *testing.T) { require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 100)) require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 150)) require.FileExists(t, fsmSnapPath(fsmSnapDir, 200)) + require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 250)) require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 300)) require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 350)) require.FileExists(t, fsmSnapPath(fsmSnapDir, 500)) From af915e1ca0c600e37ba8da95a3790da7d3bcfe85 Mon Sep 17 00:00:00 2001 From: bootjp Date: Thu, 2 Jul 2026 23:57:34 +0900 Subject: [PATCH 06/14] raft: preserve received snapshot files --- internal/raftengine/etcd/engine.go | 75 ++++++++++++++++++- .../etcd/engine_applied_index_test.go | 2 + internal/raftengine/etcd/fsm_snapshot_file.go | 55 +++++++++++--- .../raftengine/etcd/fsm_snapshot_file_test.go | 21 ++++++ internal/raftengine/etcd/grpc_transport.go | 58 ++++++++++---- .../raftengine/etcd/grpc_transport_test.go | 68 ++++++++++++++++- 6 files changed, 245 insertions(+), 34 deletions(-) diff --git a/internal/raftengine/etcd/engine.go b/internal/raftengine/etcd/engine.go index 54bdaf5d3..f4e42ce6b 100644 --- a/internal/raftengine/etcd/engine.go +++ b/internal/raftengine/etcd/engine.go @@ -310,7 +310,8 @@ type Engine struct { // Restore swaps the underlying store state and must not race with the short // critical section that publishes a newly persisted local snapshot. - snapshotMu sync.Mutex + snapshotMu sync.Mutex + protectedReceivedFSMSnaps map[uint64]int dispatchDropCount atomic.Uint64 dispatchErrorCount atomic.Uint64 @@ -667,6 +668,7 @@ func (e *Engine) initTransport(cfg OpenConfig) { e.transport.SetSpoolDir(cfg.DataDir) e.transport.SetFSMSnapDir(e.fsmSnapDir) e.transport.SetFSMSnapshotPrepare(e.prepareFSMSnapshotWriteLocked) + e.transport.SetFSMSnapshotProtection(e.protectReceivedFSMSnapshot, e.unprotectReceivedFSMSnapshot) e.transport.SetFSMPayloadReader(e.readFSMPayloadLocked) e.transport.SetFSMPayloadOpener(e.openFSMPayloadLocked) e.transport.SetHandler(e.handleTransportMessage) @@ -1837,9 +1839,14 @@ func (e *Engine) persistReadyWithSnapshotLocked(rd etcdraft.Ready) error { return err } if e.persist == nil { + e.releaseProtectedReceivedFSMSnapshotsUpToLocked(rd.Snapshot.Metadata.Index) return nil } - return persistReadyToWAL(e.persist, rd) + if err := persistReadyToWAL(e.persist, rd); err != nil { + return err + } + e.releaseProtectedReceivedFSMSnapshotsUpToLocked(rd.Snapshot.Metadata.Index) + return nil } func (e *Engine) applyReady(rd etcdraft.Ready) error { @@ -2771,6 +2778,7 @@ func (e *Engine) persistConfigSnapshotPayloadLocked(index uint64, confState raft if err := e.persistCreatedSnapshot(snap); err != nil { return err } + e.releaseProtectedReceivedFSMSnapshotsUpToLocked(index) return nil } @@ -2806,7 +2814,62 @@ func (e *Engine) prepareFSMSnapshotWriteLocked(index uint64) error { func (e *Engine) prepareFSMSnapshotWrite(index uint64) error { snapDir := filepath.Join(e.dataDir, snapDirName) - return prepareFSMSnapshotWrite(snapDir, e.fsmSnapDir, index) + return prepareFSMSnapshotWriteProtected(snapDir, e.fsmSnapDir, index, e.protectedReceivedFSMSnapshotIndexesLocked()) +} + +func (e *Engine) protectReceivedFSMSnapshot(index uint64) { + if index == 0 || index <= e.appliedIndex.Load() { + return + } + e.snapshotMu.Lock() + defer e.snapshotMu.Unlock() + if e.protectedReceivedFSMSnaps == nil { + e.protectedReceivedFSMSnaps = make(map[uint64]int, 1) + } + e.protectedReceivedFSMSnaps[index]++ +} + +func (e *Engine) unprotectReceivedFSMSnapshot(index uint64) { + if index == 0 { + return + } + e.snapshotMu.Lock() + defer e.snapshotMu.Unlock() + e.unprotectReceivedFSMSnapshotLocked(index) +} + +func (e *Engine) unprotectReceivedFSMSnapshotLocked(index uint64) { + if e.protectedReceivedFSMSnaps == nil { + return + } + count := e.protectedReceivedFSMSnaps[index] + if count <= 1 { + delete(e.protectedReceivedFSMSnaps, index) + return + } + e.protectedReceivedFSMSnaps[index] = count - 1 +} + +func (e *Engine) releaseProtectedReceivedFSMSnapshotsUpToLocked(index uint64) { + if e.protectedReceivedFSMSnaps == nil { + return + } + for protectedIndex := range e.protectedReceivedFSMSnaps { + if protectedIndex <= index { + delete(e.protectedReceivedFSMSnaps, protectedIndex) + } + } +} + +func (e *Engine) protectedReceivedFSMSnapshotIndexesLocked() map[uint64]bool { + if len(e.protectedReceivedFSMSnaps) == 0 { + return nil + } + indexes := make(map[uint64]bool, len(e.protectedReceivedFSMSnaps)) + for index := range e.protectedReceivedFSMSnaps { + indexes[index] = true + } + return indexes } // snapshotPayloadLocked takes a FSM snapshot for the given index, writes it to the @@ -4300,7 +4363,11 @@ func (e *Engine) persistLocalSnapshotPayloadLocked(index uint64, payload []byte) } _, err = persistLocalSnapshotPayload(e.storage, e.persist, index, payload) - return e.handleLocalSnapshotPersistResult(err) + if err := e.handleLocalSnapshotPersistResult(err); err != nil { + return err + } + e.releaseProtectedReceivedFSMSnapshotsUpToLocked(index) + return nil } // handleLocalSnapshotPersistResult collapses the post-SaveSnap error diff --git a/internal/raftengine/etcd/engine_applied_index_test.go b/internal/raftengine/etcd/engine_applied_index_test.go index 16bf45731..a25ea6b52 100644 --- a/internal/raftengine/etcd/engine_applied_index_test.go +++ b/internal/raftengine/etcd/engine_applied_index_test.go @@ -103,6 +103,7 @@ func TestPersistReadyWithSnapshotHoldsSnapshotMuThroughSaveSnap(t *testing.T) { dataDir: t.TempDir(), fsmSnapDir: t.TempDir(), } + e.protectReceivedFSMSnapshot(7) rd := etcdraft.Ready{ Snapshot: raftpb.Snapshot{ Data: []byte("payload"), @@ -150,6 +151,7 @@ func TestPersistReadyWithSnapshotHoldsSnapshotMuThroughSaveSnap(t *testing.T) { case <-time.After(time.Second): t.Fatal("snapshot prepare did not finish after SaveSnap was released") } + require.Empty(t, e.protectedReceivedFSMSnaps) } // TestRecordingFSM_SatisfiesAppliedIndexWriter is a compile-time- diff --git a/internal/raftengine/etcd/fsm_snapshot_file.go b/internal/raftengine/etcd/fsm_snapshot_file.go index 42b8c79af..069315dad 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file.go +++ b/internal/raftengine/etcd/fsm_snapshot_file.go @@ -535,6 +535,15 @@ func cleanupStaleFSMSnaps(snapDir, fsmSnapDir string, disableStartupCRCCheck boo // raft publishes the new token; this prewrite pass prevents ENOSPC before that // success path can run. func prepareFSMSnapshotWrite(snapDir, fsmSnapDir string, nextIndex uint64) error { + return prepareFSMSnapshotWriteProtected(snapDir, fsmSnapDir, nextIndex, nil) +} + +func prepareFSMSnapshotWriteProtected( + snapDir string, + fsmSnapDir string, + nextIndex uint64, + protectedIndexes map[uint64]bool, +) error { if fsmSnapDir == "" || nextIndex == 0 { return nil } @@ -547,7 +556,7 @@ func prepareFSMSnapshotWrite(snapDir, fsmSnapDir string, nextIndex uint64) error combined = errors.CombineErrors(combined, syncDirIfExists(fsmSnapDir)) return errors.WithStack(combined) } - combined = errors.CombineErrors(combined, purgeOlderSnapshotPairsBeforeWrite(snapDir, fsmSnapDir, nextIndex)) + combined = errors.CombineErrors(combined, purgeOlderSnapshotPairsBeforeWrite(snapDir, fsmSnapDir, nextIndex, protectedIndexes)) combined = errors.CombineErrors(combined, syncDirIfExists(snapDir)) combined = errors.CombineErrors(combined, syncDirIfExists(fsmSnapDir)) return errors.WithStack(combined) @@ -564,7 +573,12 @@ type prewriteSnapshotRetention struct { restorableFloor uint64 } -func purgeOlderSnapshotPairsBeforeWrite(snapDir, fsmSnapDir string, nextIndex uint64) error { +func purgeOlderSnapshotPairsBeforeWrite( + snapDir string, + fsmSnapDir string, + nextIndex uint64, + protectedIndexes map[uint64]bool, +) error { if snapDir == "" { return nil } @@ -587,7 +601,13 @@ func purgeOlderSnapshotPairsBeforeWrite(snapDir, fsmSnapDir string, nextIndex ui retention := keepRestorablePrewriteSnapshots(candidates) var combined error combined = errors.CombineErrors(combined, purgeUnretainedPrewriteSnapshots(snapDir, fsmSnapDir, candidates, retention)) - combined = errors.CombineErrors(combined, removePrewriteFSMOrphansBeforeIndex(snapDir, fsmSnapDir, retention, nextIndex)) + combined = errors.CombineErrors(combined, removePrewriteFSMOrphansBeforeIndex( + snapDir, + fsmSnapDir, + retention, + protectedIndexes, + nextIndex, + )) return errors.WithStack(combined) } @@ -615,6 +635,7 @@ func removePrewriteFSMOrphansBeforeIndex( snapDir string, fsmSnapDir string, retention prewriteSnapshotRetention, + protectedIndexes map[uint64]bool, nextIndex uint64, ) error { if retention.restorableFloor > 0 { @@ -622,7 +643,7 @@ func removePrewriteFSMOrphansBeforeIndex( if err != nil { return errors.WithStack(err) } else if liveIndexes != nil { - return removeStaleFSMFilesBelowIndex(fsmSnapDir, liveIndexes, nextIndex) + return removeStaleFSMFilesBelowIndex(fsmSnapDir, liveIndexes, protectedIndexes, nextIndex) } } return nil @@ -672,7 +693,12 @@ func fsmSnapshotFileRestorable(fsmSnapDir string, index uint64) bool { return verifyFSMSnapshotFile(fsmSnapPath(fsmSnapDir, index), 0) == nil } -func removeStaleFSMFilesBelowIndex(fsmSnapDir string, liveIndexes map[uint64]bool, maxIndex uint64) error { +func removeStaleFSMFilesBelowIndex( + fsmSnapDir string, + liveIndexes map[uint64]bool, + protectedIndexes map[uint64]bool, + maxIndex uint64, +) error { if maxIndex == 0 { return nil } @@ -684,11 +710,7 @@ func removeStaleFSMFilesBelowIndex(fsmSnapDir string, liveIndexes map[uint64]boo return errors.WithStack(err) } for _, e := range fsmEntries { - if e.IsDir() || filepath.Ext(e.Name()) != fsmFileExt { - continue - } - idx, err := strconv.ParseUint(strings.TrimSuffix(e.Name(), fsmFileExt), 16, 64) - if err != nil || idx >= maxIndex || liveIndexes[idx] { + if !shouldRemoveStaleFSMBelowIndex(e, liveIndexes, protectedIndexes, maxIndex) { continue } removeWithWarn(filepath.Join(fsmSnapDir, e.Name()), "old orphan fsm snapshot") @@ -696,6 +718,19 @@ func removeStaleFSMFilesBelowIndex(fsmSnapDir string, liveIndexes map[uint64]boo return nil } +func shouldRemoveStaleFSMBelowIndex( + entry os.DirEntry, + liveIndexes map[uint64]bool, + protectedIndexes map[uint64]bool, + maxIndex uint64, +) bool { + if entry.IsDir() || filepath.Ext(entry.Name()) != fsmFileExt { + return false + } + idx, err := strconv.ParseUint(strings.TrimSuffix(entry.Name(), fsmFileExt), 16, 64) + return err == nil && idx < maxIndex && !liveIndexes[idx] && !protectedIndexes[idx] +} + func removeFSMTmpOrphans(fsmSnapDir string) error { // Use os.ReadDir + strings.HasSuffix instead of filepath.Glob to avoid // misinterpretation of special characters (e.g. '[', ']') in fsmSnapDir diff --git a/internal/raftengine/etcd/fsm_snapshot_file_test.go b/internal/raftengine/etcd/fsm_snapshot_file_test.go index 38d1da99c..30b9f34d5 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file_test.go +++ b/internal/raftengine/etcd/fsm_snapshot_file_test.go @@ -366,6 +366,27 @@ func TestPrepareFSMSnapshotWriteKeepsNewestRestorablePair(t *testing.T) { require.FileExists(t, filepath.Join(fsmSnapDir, "leftover.fsm.tmp")) } +func TestPrepareFSMSnapshotWritePreservesProtectedReceivedFSM(t *testing.T) { + snapDir := t.TempDir() + fsmSnapDir := t.TempDir() + payload := []byte("payload") + + createSnapFile(t, snapDir, 200) + writeFSMFileForTest(t, fsmSnapDir, 200, payload) + writeFSMFileForTest(t, fsmSnapDir, 250, payload) + writeFSMFileForTest(t, fsmSnapDir, 300, payload) + + protected := map[uint64]bool{300: true} + require.NoError(t, prepareFSMSnapshotWriteProtected(snapDir, fsmSnapDir, 400, protected)) + + require.FileExists(t, fsmSnapPath(fsmSnapDir, 200)) + require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 250)) + require.FileExists(t, fsmSnapPath(fsmSnapDir, 300)) + + require.NoError(t, prepareFSMSnapshotWriteProtected(snapDir, fsmSnapDir, 400, nil)) + require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 300)) +} + // --- writeFSMSnapshotFile integration --- func TestWriteFSMSnapshotFileRoundTrip(t *testing.T) { diff --git a/internal/raftengine/etcd/grpc_transport.go b/internal/raftengine/etcd/grpc_transport.go index a7b83fdc0..0b69da9b3 100644 --- a/internal/raftengine/etcd/grpc_transport.go +++ b/internal/raftengine/etcd/grpc_transport.go @@ -52,6 +52,8 @@ type GRPCTransport struct { spoolDir string fsmSnapDir string prepareFSMWrite func(index uint64) error + protectFSMWrite func(index uint64) + unprotectFSMWrite func(index uint64) // readFSMPayload is the fallback bridge callback that materialises the full // FSM payload into memory. Used only when openFSMPayload is not set. readFSMPayload func(index uint64) ([]byte, error) @@ -131,6 +133,16 @@ func (t *GRPCTransport) SetFSMSnapshotPrepare(fn func(index uint64) error) { t.prepareFSMWrite = fn } +func (t *GRPCTransport) SetFSMSnapshotProtection(protectFn, unprotectFn func(index uint64)) { + if t == nil { + return + } + t.mu.Lock() + defer t.mu.Unlock() + t.protectFSMWrite = protectFn + t.unprotectFSMWrite = unprotectFn +} + func (t *GRPCTransport) SetFSMPayloadReader(fn func(index uint64) ([]byte, error)) { if t == nil { return @@ -390,12 +402,12 @@ func (t *GRPCTransport) SendSnapshot(stream pb.EtcdRaft_SendSnapshotServer) erro } if err := t.handle(stream.Context(), msg); err != nil { // If receive finalized the snapshot as a .fsm file (token in - // Snapshot.Data), the engine refused to apply it — likely a - // transient context cancel or raft error. Remove the on-disk - // file so retries at later indexes don't leak orphan .fsm - // payloads into fsmSnapDir until the next startup runs - // cleanupStaleFSMSnaps. Same-index retries are already safe - // because os.Rename atomically replaces the prior file. + // Snapshot.Data), the engine refused to accept it into raft — + // likely a transient context cancel or closed engine. Remove the + // on-disk file so retries at later indexes don't leak orphan .fsm + // payloads into fsmSnapDir until the next startup runs cleanup. + // Same-index retries are already safe because os.Rename atomically + // replaces the prior file. t.removeOrphanedFSMSnapshot(msg) return err } @@ -404,10 +416,9 @@ func (t *GRPCTransport) SendSnapshot(stream pb.EtcdRaft_SendSnapshotServer) erro // removeOrphanedFSMSnapshot deletes the .fsm file that // receiveSnapshotStream finalized for `msg`, if any. Used by -// SendSnapshot when the engine apply (`t.handle`) fails after the -// receive succeeded — the engine has NOT applied the snapshot (apply is -// synchronous to t.handle, so a non-nil return means applied_index was -// not advanced), so the file is unreferenced and safe to remove. +// SendSnapshot when the engine handler (`t.handle`) fails after the +// receive succeeded. A non-nil return means the snapshot was not accepted into +// raft, so the file is unreferenced and safe to remove. // // Best-effort: a cleanup failure here is logged but not returned because // the original apply error is the actionable signal; orphans get swept @@ -423,7 +434,11 @@ func (t *GRPCTransport) removeOrphanedFSMSnapshot(msg raftpb.Message) { } t.mu.RLock() fsmSnapDir := t.fsmSnapDir + unprotectFn := t.unprotectFSMWrite t.mu.RUnlock() + if unprotectFn != nil { + defer unprotectFn(tok.Index) + } if fsmSnapDir == "" { return } @@ -744,19 +759,25 @@ func (t *GRPCTransport) handle(ctx context.Context, msg raftpb.Message) error { // receive code should not assume that. The legacy fallback path // (fsmSnapDir == "") keeps the spool in spoolDir because it never renames // — Bytes() materializes the payload in place. -func (t *GRPCTransport) snapshotSpoolPlacement() (placement, fsmSnapDir string, prepareFn func(uint64) error) { +func (t *GRPCTransport) snapshotSpoolPlacement() ( + placement string, + fsmSnapDir string, + prepareFn func(uint64) error, + protectFn func(uint64), +) { t.mu.RLock() defer t.mu.RUnlock() fsmSnapDir = t.fsmSnapDir prepareFn = t.prepareFSMWrite + protectFn = t.protectFSMWrite if fsmSnapDir != "" { - return fsmSnapDir, fsmSnapDir, prepareFn + return fsmSnapDir, fsmSnapDir, prepareFn, protectFn } - return t.spoolDir, "", prepareFn + return t.spoolDir, "", prepareFn, protectFn } func (t *GRPCTransport) receiveSnapshotStream(stream pb.EtcdRaft_SendSnapshotServer) (raftpb.Message, error) { - spoolPlacement, fsmSnapDir, prepareFn := t.snapshotSpoolPlacement() + spoolPlacement, fsmSnapDir, prepareFn, protectFn := t.snapshotSpoolPlacement() spool, err := newSnapshotSpool(spoolPlacement) if err != nil { return raftpb.Message{}, err @@ -775,7 +796,7 @@ func (t *GRPCTransport) receiveSnapshotStream(stream pb.EtcdRaft_SendSnapshotSer } }() - msg, payloadBytes, err := drainSnapshotChunks(stream, spool, fsmSnapDir, prepareFn) + msg, payloadBytes, err := drainSnapshotChunks(stream, spool, fsmSnapDir, prepareFn, protectFn) if err != nil { return raftpb.Message{}, err } @@ -804,6 +825,7 @@ func drainSnapshotChunks( spool *snapshotSpool, fsmSnapDir string, prepareFn func(uint64) error, + protectFn func(uint64), ) (raftpb.Message, int64, error) { var metadata raftpb.Message seenMetadata := false @@ -838,7 +860,7 @@ func drainSnapshotChunks( } payloadBytes += int64(len(chunk.Chunk)) if chunk.Final { - msg, err := finalizeReceivedSnapshot(metadata, spool, crcWriter.Sum32(), fsmSnapDir, seenMetadata) + msg, err := finalizeReceivedSnapshot(metadata, spool, crcWriter.Sum32(), fsmSnapDir, protectFn, seenMetadata) if err != nil { return raftpb.Message{}, 0, err } @@ -872,6 +894,7 @@ func finalizeReceivedSnapshot( spool *snapshotSpool, crc32c uint32, fsmSnapDir string, + protectFn func(uint64), seenMetadata bool, ) (raftpb.Message, error) { if !seenMetadata || metadata.Snapshot == nil { @@ -883,6 +906,9 @@ func finalizeReceivedSnapshot( return raftpb.Message{}, err } metadata.Snapshot.Data = encodeSnapshotToken(index, crc32c) + if protectFn != nil { + protectFn(index) + } return metadata, nil } // Legacy fallback: full materialization. Used by tests that don't wire an diff --git a/internal/raftengine/etcd/grpc_transport_test.go b/internal/raftengine/etcd/grpc_transport_test.go index dd3374eab..66b2887e0 100644 --- a/internal/raftengine/etcd/grpc_transport_test.go +++ b/internal/raftengine/etcd/grpc_transport_test.go @@ -222,6 +222,58 @@ func TestReceiveSnapshotStream_StreamingTokenWhenFSMSnapDirSet(t *testing.T) { require.Equal(t, senderFSM.Applied(), receiverFSM.Applied()) } +func TestSendSnapshotProtectsFinalizedFSMFileUntilEngineRelease(t *testing.T) { + const index = uint64(91) + + senderFSM := &testStateMachine{} + senderFSM.Apply([]byte("entry-for-protection-test")) + snap, err := senderFSM.Snapshot() + require.NoError(t, err) + var buf bytes.Buffer + _, err = snap.WriteTo(&buf) + require.NoError(t, err) + require.NoError(t, snap.Close()) + + metadata := raftpb.Message{ + Type: raftpb.MsgSnap, + From: 1, + To: 2, + Snapshot: &raftpb.Snapshot{ + Metadata: raftpb.SnapshotMetadata{Index: index, Term: 1}, + }, + } + raw, err := metadata.Marshal() + require.NoError(t, err) + + fsmSnapDir := t.TempDir() + var protected []uint64 + var unprotected []uint64 + transport := NewGRPCTransport(nil) + transport.SetSpoolDir(t.TempDir()) + transport.SetFSMSnapDir(fsmSnapDir) + transport.SetFSMSnapshotProtection( + func(index uint64) { protected = append(protected, index) }, + func(index uint64) { unprotected = append(unprotected, index) }, + ) + transport.SetHandler(func(_ context.Context, msg raftpb.Message) error { + require.NotNil(t, msg.Snapshot) + require.True(t, isSnapshotToken(msg.Snapshot.Data)) + return nil + }) + + stream := &testSendSnapshotServer{ + chunks: []*pb.EtcdRaftSnapshotChunk{ + {Metadata: raw}, + {Chunk: buf.Bytes(), Final: true}, + }, + } + + require.NoError(t, transport.SendSnapshot(stream)) + require.Equal(t, []uint64{index}, protected) + require.Empty(t, unprotected) + require.FileExists(t, fsmSnapPath(fsmSnapDir, index)) +} + func TestDrainSnapshotChunksPreparesBeforePayloadWrite(t *testing.T) { const index = uint64(124) payload := []byte("payload written after prepare") @@ -260,7 +312,7 @@ func TestDrainSnapshotChunksPreparesBeforePayloadWrite(t *testing.T) { }}, } - msg, payloadBytes, err := drainSnapshotChunks(stream, spool, fsmSnapDir, prepareFn) + msg, payloadBytes, err := drainSnapshotChunks(stream, spool, fsmSnapDir, prepareFn, nil) require.NoError(t, err) require.Equal(t, int64(len(payload)), payloadBytes) require.Equal(t, 1, prepareCalls) @@ -290,7 +342,7 @@ func TestDrainSnapshotChunksRejectsPayloadBeforeMetadata(t *testing.T) { }}, } - _, payloadBytes, err := drainSnapshotChunks(stream, spool, fsmSnapDir, prepareFn) + _, payloadBytes, err := drainSnapshotChunks(stream, spool, fsmSnapDir, prepareFn, nil) require.ErrorIs(t, err, errSnapshotMetadataNil) require.Zero(t, payloadBytes) require.Zero(t, prepareCalls) @@ -363,8 +415,8 @@ func TestReceiveSnapshotStream_SpoolPlacedInFSMSnapDir(t *testing.T) { // TestSendSnapshot_ApplyFailureRemovesFinalizedFSMFile pins the // orphan-cleanup behaviour from PR #747 round-4 (Codex P2): when the // receive path successfully finalizes the snapshot as -// fsmSnapDir/.fsm but the engine's apply (t.handle) then fails -// — transient context cancel, raft error, etc. — the finalized .fsm +// fsmSnapDir/.fsm but the engine handler (t.handle) then fails +// — transient context cancel, closed engine, etc. — the finalized .fsm // file MUST be removed. Otherwise retries at later snapshot indexes // accumulate orphan .fsm payloads in fsmSnapDir until startup runs // cleanupStaleFSMSnaps. Same-index retries are already safe via @@ -401,6 +453,12 @@ func TestSendSnapshot_ApplyFailureRemovesFinalizedFSMFile(t *testing.T) { transport := NewGRPCTransport(nil) transport.SetSpoolDir(t.TempDir()) transport.SetFSMSnapDir(fsmSnapDir) + var protected []uint64 + var unprotected []uint64 + transport.SetFSMSnapshotProtection( + func(index uint64) { protected = append(protected, index) }, + func(index uint64) { unprotected = append(unprotected, index) }, + ) // Wire a handler that always fails so SendSnapshot exercises the // orphan-cleanup branch. @@ -419,6 +477,8 @@ func TestSendSnapshot_ApplyFailureRemovesFinalizedFSMFile(t *testing.T) { err = transport.SendSnapshot(stream) require.Error(t, err) require.ErrorIs(t, err, applyErr, "SendSnapshot must surface the apply failure") + require.Equal(t, []uint64{index}, protected) + require.Equal(t, []uint64{index}, unprotected) // THE point: the .fsm file at the canonical path MUST have been // removed. Without the cleanup, leader retries at later indexes From fadd997e8d93ac4c12eb3bb0276af17586869b13 Mon Sep 17 00:00:00 2001 From: bootjp Date: Fri, 3 Jul 2026 00:19:05 +0900 Subject: [PATCH 07/14] raft: close received snapshot protection races --- internal/raftengine/etcd/engine.go | 43 ++++++ .../etcd/engine_applied_index_test.go | 55 +++++++ internal/raftengine/etcd/grpc_transport.go | 104 +++++++++++-- .../raftengine/etcd/grpc_transport_test.go | 139 +++++++++++++++++- 4 files changed, 328 insertions(+), 13 deletions(-) diff --git a/internal/raftengine/etcd/engine.go b/internal/raftengine/etcd/engine.go index f4e42ce6b..6b96c8504 100644 --- a/internal/raftengine/etcd/engine.go +++ b/internal/raftengine/etcd/engine.go @@ -1803,6 +1803,7 @@ func (e *Engine) drainReady() error { if err := e.applyCommitted(rd.CommittedEntries); err != nil { return err } + e.releaseProtectedReceivedFSMSnapshotsUpTo(e.appliedIndex.Load()) e.handleReadStates(rd.ReadStates) e.rawNode.Advance(rd) if err := e.maybePersistLocalSnapshot(); err != nil { @@ -1877,10 +1878,13 @@ func (e *Engine) handleStep(msg raftpb.Message) { e.recordQuorumAck(msg) if err := e.rawNode.Step(msg); err != nil { if errors.Is(err, etcdraft.ErrStepPeerNotFound) { + e.unprotectReceivedFSMSnapshotToken(msg) return } e.fail(errors.WithStack(err)) + return } + e.unprotectReceivedFSMSnapshotTokenIfApplied(msg) } // recordQuorumAck updates the per-peer last-response time when msg is @@ -2823,6 +2827,9 @@ func (e *Engine) protectReceivedFSMSnapshot(index uint64) { } e.snapshotMu.Lock() defer e.snapshotMu.Unlock() + if index <= e.appliedIndex.Load() { + return + } if e.protectedReceivedFSMSnaps == nil { e.protectedReceivedFSMSnaps = make(map[uint64]int, 1) } @@ -2861,6 +2868,42 @@ func (e *Engine) releaseProtectedReceivedFSMSnapshotsUpToLocked(index uint64) { } } +func (e *Engine) releaseProtectedReceivedFSMSnapshotsUpTo(index uint64) { + if index == 0 { + return + } + e.snapshotMu.Lock() + defer e.snapshotMu.Unlock() + e.releaseProtectedReceivedFSMSnapshotsUpToLocked(index) +} + +func (e *Engine) unprotectReceivedFSMSnapshotTokenIfApplied(msg raftpb.Message) { + index, ok := receivedFSMSnapshotTokenIndex(msg) + if !ok || index > e.appliedIndex.Load() { + return + } + e.unprotectReceivedFSMSnapshot(index) +} + +func (e *Engine) unprotectReceivedFSMSnapshotToken(msg raftpb.Message) { + index, ok := receivedFSMSnapshotTokenIndex(msg) + if !ok { + return + } + e.unprotectReceivedFSMSnapshot(index) +} + +func receivedFSMSnapshotTokenIndex(msg raftpb.Message) (uint64, bool) { + if msg.Type != raftpb.MsgSnap || msg.Snapshot == nil || !isSnapshotToken(msg.Snapshot.Data) { + return 0, false + } + tok, err := decodeSnapshotToken(msg.Snapshot.Data) + if err != nil || tok.Index == 0 { + return 0, false + } + return tok.Index, true +} + func (e *Engine) protectedReceivedFSMSnapshotIndexesLocked() map[uint64]bool { if len(e.protectedReceivedFSMSnaps) == 0 { return nil diff --git a/internal/raftengine/etcd/engine_applied_index_test.go b/internal/raftengine/etcd/engine_applied_index_test.go index a25ea6b52..3f30685a4 100644 --- a/internal/raftengine/etcd/engine_applied_index_test.go +++ b/internal/raftengine/etcd/engine_applied_index_test.go @@ -154,6 +154,61 @@ func TestPersistReadyWithSnapshotHoldsSnapshotMuThroughSaveSnap(t *testing.T) { require.Empty(t, e.protectedReceivedFSMSnaps) } +func TestProtectReceivedFSMSnapshotRechecksAppliedIndexUnderLock(t *testing.T) { + e := &Engine{} + e.snapshotMu.Lock() + done := make(chan struct{}) + go func() { + defer close(done) + e.protectReceivedFSMSnapshot(9) + }() + + time.Sleep(10 * time.Millisecond) + e.appliedIndex.Store(9) + e.snapshotMu.Unlock() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("protectReceivedFSMSnapshot did not return") + } + require.Empty(t, e.protectedReceivedFSMSnaps) +} + +func TestUnprotectReceivedFSMSnapshotTokenIfApplied(t *testing.T) { + e := &Engine{ + protectedReceivedFSMSnaps: map[uint64]int{9: 1}, + } + e.appliedIndex.Store(9) + msg := raftpb.Message{ + Type: raftpb.MsgSnap, + Snapshot: &raftpb.Snapshot{ + Data: encodeSnapshotToken(9, 0), + }, + } + + e.unprotectReceivedFSMSnapshotTokenIfApplied(msg) + + require.Empty(t, e.protectedReceivedFSMSnaps) +} + +func TestUnprotectReceivedFSMSnapshotTokenIfAppliedKeepsFutureSnapshot(t *testing.T) { + e := &Engine{ + protectedReceivedFSMSnaps: map[uint64]int{10: 1}, + } + e.appliedIndex.Store(9) + msg := raftpb.Message{ + Type: raftpb.MsgSnap, + Snapshot: &raftpb.Snapshot{ + Data: encodeSnapshotToken(10, 0), + }, + } + + e.unprotectReceivedFSMSnapshotTokenIfApplied(msg) + + require.Equal(t, map[uint64]int{10: 1}, e.protectedReceivedFSMSnaps) +} + // TestRecordingFSM_SatisfiesAppliedIndexWriter is a compile-time- // adjacent assertion: the recording FSM MUST satisfy the writer // seam so the engine hook actually fires for it. diff --git a/internal/raftengine/etcd/grpc_transport.go b/internal/raftengine/etcd/grpc_transport.go index 0b69da9b3..72e428837 100644 --- a/internal/raftengine/etcd/grpc_transport.go +++ b/internal/raftengine/etcd/grpc_transport.go @@ -764,20 +764,26 @@ func (t *GRPCTransport) snapshotSpoolPlacement() ( fsmSnapDir string, prepareFn func(uint64) error, protectFn func(uint64), + unprotectFn func(uint64), ) { t.mu.RLock() defer t.mu.RUnlock() fsmSnapDir = t.fsmSnapDir prepareFn = t.prepareFSMWrite protectFn = t.protectFSMWrite + unprotectFn = t.unprotectFSMWrite if fsmSnapDir != "" { - return fsmSnapDir, fsmSnapDir, prepareFn, protectFn + return fsmSnapDir, fsmSnapDir, prepareFn, protectFn, unprotectFn } - return t.spoolDir, "", prepareFn, protectFn + return t.spoolDir, "", prepareFn, protectFn, unprotectFn } func (t *GRPCTransport) receiveSnapshotStream(stream pb.EtcdRaft_SendSnapshotServer) (raftpb.Message, error) { - spoolPlacement, fsmSnapDir, prepareFn, protectFn := t.snapshotSpoolPlacement() + spoolPlacement, fsmSnapDir, prepareFn, protectFn, unprotectFn := t.snapshotSpoolPlacement() + metadata, firstPayloadChunk, preparedFSMWrite, err := receiveSnapshotMetadata(stream, fsmSnapDir, prepareFn) + if err != nil { + return raftpb.Message{}, err + } spool, err := newSnapshotSpool(spoolPlacement) if err != nil { return raftpb.Message{}, err @@ -796,7 +802,17 @@ func (t *GRPCTransport) receiveSnapshotStream(stream pb.EtcdRaft_SendSnapshotSer } }() - msg, payloadBytes, err := drainSnapshotChunks(stream, spool, fsmSnapDir, prepareFn, protectFn) + msg, payloadBytes, err := drainSnapshotChunksFrom( + stream, + spool, + fsmSnapDir, + prepareFn, + protectFn, + unprotectFn, + metadata, + firstPayloadChunk, + preparedFSMWrite, + ) if err != nil { return raftpb.Message{}, err } @@ -826,9 +842,24 @@ func drainSnapshotChunks( fsmSnapDir string, prepareFn func(uint64) error, protectFn func(uint64), + unprotectFn func(uint64), ) (raftpb.Message, int64, error) { var metadata raftpb.Message - seenMetadata := false + return drainSnapshotChunksFrom(stream, spool, fsmSnapDir, prepareFn, protectFn, unprotectFn, metadata, nil, false) +} + +func drainSnapshotChunksFrom( + stream pb.EtcdRaft_SendSnapshotServer, + spool *snapshotSpool, + fsmSnapDir string, + prepareFn func(uint64) error, + protectFn func(uint64), + unprotectFn func(uint64), + metadata raftpb.Message, + firstPayloadChunk *pb.EtcdRaftSnapshotChunk, + preparedFSMWrite bool, +) (raftpb.Message, int64, error) { + seenMetadata := metadata.Snapshot != nil // Wrap spool with crc32CWriter so the CRC accumulates as bytes hit // disk. The CRC is only meaningful when we have an fsmSnapDir to // finalize into; the legacy fallback path discards it. Cost is @@ -838,9 +869,8 @@ func drainSnapshotChunks( crcWriter := newCRC32CWriter(spool) var payloadBytes int64 - preparedFSMWrite := false for { - chunk, err := recvSnapshotChunk(stream) + chunk, err := nextSnapshotChunk(stream, &firstPayloadChunk) if err != nil { return raftpb.Message{}, 0, err } @@ -860,7 +890,7 @@ func drainSnapshotChunks( } payloadBytes += int64(len(chunk.Chunk)) if chunk.Final { - msg, err := finalizeReceivedSnapshot(metadata, spool, crcWriter.Sum32(), fsmSnapDir, protectFn, seenMetadata) + msg, err := finalizeReceivedSnapshot(metadata, spool, crcWriter.Sum32(), fsmSnapDir, protectFn, unprotectFn, seenMetadata) if err != nil { return raftpb.Message{}, 0, err } @@ -869,6 +899,52 @@ func drainSnapshotChunks( } } +func nextSnapshotChunk( + stream pb.EtcdRaft_SendSnapshotServer, + firstPayloadChunk **pb.EtcdRaftSnapshotChunk, +) (*pb.EtcdRaftSnapshotChunk, error) { + if *firstPayloadChunk != nil { + chunk := *firstPayloadChunk + *firstPayloadChunk = nil + return chunk, nil + } + return recvSnapshotChunk(stream) +} + +func receiveSnapshotMetadata( + stream pb.EtcdRaft_SendSnapshotServer, + fsmSnapDir string, + prepareFn func(uint64) error, +) (raftpb.Message, *pb.EtcdRaftSnapshotChunk, bool, error) { + var metadata raftpb.Message + seenMetadata := false + for { + chunk, err := recvSnapshotChunk(stream) + if err != nil { + return raftpb.Message{}, nil, false, err + } + seen, err := appendSnapshotChunkMetadata(&metadata, chunk, seenMetadata) + if err != nil { + return raftpb.Message{}, nil, false, err + } + seenMetadata = seen + if !seenMetadata && len(chunk.Chunk) > 0 { + return raftpb.Message{}, nil, false, errors.WithStack(errSnapshotMetadataNil) + } + if seenMetadata { + prepared := maybePrepareReceivedFSMSnapshotWrite(metadata, fsmSnapDir, prepareFn, true) + firstPayloadChunk := &pb.EtcdRaftSnapshotChunk{ + Chunk: chunk.Chunk, + Final: chunk.Final, + } + return metadata, firstPayloadChunk, prepared, nil + } + if chunk.Final { + return raftpb.Message{}, nil, false, errors.WithStack(errSnapshotMetadataNil) + } + } +} + func recvSnapshotChunk(stream pb.EtcdRaft_SendSnapshotServer) (*pb.EtcdRaftSnapshotChunk, error) { chunk, err := stream.Recv() if err != nil { @@ -895,6 +971,7 @@ func finalizeReceivedSnapshot( crc32c uint32, fsmSnapDir string, protectFn func(uint64), + unprotectFn func(uint64), seenMetadata bool, ) (raftpb.Message, error) { if !seenMetadata || metadata.Snapshot == nil { @@ -902,13 +979,18 @@ func finalizeReceivedSnapshot( } index := metadata.Snapshot.Metadata.Index if fsmSnapDir != "" && index > 0 { + protected := false + if protectFn != nil { + protectFn(index) + protected = true + } if err := spool.FinalizeAsFSMFile(fsmSnapDir, index, crc32c); err != nil { + if protected && unprotectFn != nil { + unprotectFn(index) + } return raftpb.Message{}, err } metadata.Snapshot.Data = encodeSnapshotToken(index, crc32c) - if protectFn != nil { - protectFn(index) - } return metadata, nil } // Legacy fallback: full materialization. Used by tests that don't wire an diff --git a/internal/raftengine/etcd/grpc_transport_test.go b/internal/raftengine/etcd/grpc_transport_test.go index 66b2887e0..9c670862a 100644 --- a/internal/raftengine/etcd/grpc_transport_test.go +++ b/internal/raftengine/etcd/grpc_transport_test.go @@ -222,6 +222,49 @@ func TestReceiveSnapshotStream_StreamingTokenWhenFSMSnapDirSet(t *testing.T) { require.Equal(t, senderFSM.Applied(), receiverFSM.Applied()) } +func TestReceiveSnapshotStreamPreparesBeforeSpoolCreation(t *testing.T) { + const index = uint64(124) + payload := []byte("payload written after cleanup") + metadata := raftpb.Message{ + Type: raftpb.MsgSnap, + From: 1, + To: 2, + Snapshot: &raftpb.Snapshot{ + Metadata: raftpb.SnapshotMetadata{Index: index, Term: 1}, + }, + } + raw, err := metadata.Marshal() + require.NoError(t, err) + + fsmSnapDir := t.TempDir() + transport := NewGRPCTransport(nil) + transport.SetSpoolDir(t.TempDir()) + transport.SetFSMSnapDir(fsmSnapDir) + prepareCalls := 0 + transport.SetFSMSnapshotPrepare(func(got uint64) error { + prepareCalls++ + require.Equal(t, index, got) + matches, globErr := filepath.Glob(filepath.Join(fsmSnapDir, snapshotSpoolPattern)) + require.NoError(t, globErr) + require.Empty(t, matches, "prewrite cleanup must run before creating the receive spool") + return nil + }) + + stream := &testSendSnapshotServer{ + chunks: []*pb.EtcdRaftSnapshotChunk{{ + Metadata: raw, + Chunk: payload, + Final: true, + }}, + } + + msg, err := transport.receiveSnapshotStream(stream) + require.NoError(t, err) + require.Equal(t, 1, prepareCalls) + require.True(t, isSnapshotToken(msg.Snapshot.Data)) + require.FileExists(t, fsmSnapPath(fsmSnapDir, index)) +} + func TestSendSnapshotProtectsFinalizedFSMFileUntilEngineRelease(t *testing.T) { const index = uint64(91) @@ -274,6 +317,98 @@ func TestSendSnapshotProtectsFinalizedFSMFileUntilEngineRelease(t *testing.T) { require.FileExists(t, fsmSnapPath(fsmSnapDir, index)) } +func TestDrainSnapshotChunksProtectsBeforePublishingFSMFile(t *testing.T) { + const index = uint64(125) + payload := []byte("payload protected before final rename") + metadata := raftpb.Message{ + Type: raftpb.MsgSnap, + From: 1, + To: 2, + Snapshot: &raftpb.Snapshot{ + Metadata: raftpb.SnapshotMetadata{Index: index, Term: 1}, + }, + } + raw, err := metadata.Marshal() + require.NoError(t, err) + + fsmSnapDir := t.TempDir() + spool, err := newSnapshotSpool(fsmSnapDir) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, spool.Close()) + }) + var protected []uint64 + protectFn := func(got uint64) { + protected = append(protected, got) + _, statErr := os.Stat(fsmSnapPath(fsmSnapDir, got)) + require.True(t, os.IsNotExist(statErr), "protection must be registered before the final .fsm path is visible") + } + stream := &testSendSnapshotServer{ + chunks: []*pb.EtcdRaftSnapshotChunk{{ + Metadata: raw, + Chunk: payload, + Final: true, + }}, + } + + msg, payloadBytes, err := drainSnapshotChunks(stream, spool, fsmSnapDir, func(uint64) error { return nil }, protectFn, nil) + require.NoError(t, err) + require.Equal(t, int64(len(payload)), payloadBytes) + require.Equal(t, []uint64{index}, protected) + require.True(t, isSnapshotToken(msg.Snapshot.Data)) + require.FileExists(t, fsmSnapPath(fsmSnapDir, index)) +} + +func TestDrainSnapshotChunksUnprotectsWhenFinalizeFails(t *testing.T) { + const index = uint64(126) + payload := []byte("payload whose final rename fails") + metadata := raftpb.Message{ + Type: raftpb.MsgSnap, + From: 1, + To: 2, + Snapshot: &raftpb.Snapshot{ + Metadata: raftpb.SnapshotMetadata{Index: index, Term: 1}, + }, + } + raw, err := metadata.Marshal() + require.NoError(t, err) + + spool, err := newSnapshotSpool(t.TempDir()) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, spool.Close()) + }) + fsmSnapDir := t.TempDir() + var protected []uint64 + var unprotected []uint64 + syncErr := errors.New("simulated directory sync failure") + oldSnapshotSyncDir := snapshotSyncDir + snapshotSyncDir = func(string) error { return syncErr } + t.Cleanup(func() { + snapshotSyncDir = oldSnapshotSyncDir + }) + stream := &testSendSnapshotServer{ + chunks: []*pb.EtcdRaftSnapshotChunk{{ + Metadata: raw, + Chunk: payload, + Final: true, + }}, + } + + _, _, err = drainSnapshotChunks( + stream, + spool, + fsmSnapDir, + func(uint64) error { return nil }, + func(got uint64) { protected = append(protected, got) }, + func(got uint64) { unprotected = append(unprotected, got) }, + ) + require.Error(t, err) + require.ErrorIs(t, err, syncErr) + require.Equal(t, []uint64{index}, protected) + require.Equal(t, []uint64{index}, unprotected) +} + func TestDrainSnapshotChunksPreparesBeforePayloadWrite(t *testing.T) { const index = uint64(124) payload := []byte("payload written after prepare") @@ -312,7 +447,7 @@ func TestDrainSnapshotChunksPreparesBeforePayloadWrite(t *testing.T) { }}, } - msg, payloadBytes, err := drainSnapshotChunks(stream, spool, fsmSnapDir, prepareFn, nil) + msg, payloadBytes, err := drainSnapshotChunks(stream, spool, fsmSnapDir, prepareFn, nil, nil) require.NoError(t, err) require.Equal(t, int64(len(payload)), payloadBytes) require.Equal(t, 1, prepareCalls) @@ -342,7 +477,7 @@ func TestDrainSnapshotChunksRejectsPayloadBeforeMetadata(t *testing.T) { }}, } - _, payloadBytes, err := drainSnapshotChunks(stream, spool, fsmSnapDir, prepareFn, nil) + _, payloadBytes, err := drainSnapshotChunks(stream, spool, fsmSnapDir, prepareFn, nil, nil) require.ErrorIs(t, err, errSnapshotMetadataNil) require.Zero(t, payloadBytes) require.Zero(t, prepareCalls) From aebbbc2185bc47f6bf2ccb86a6ade59f7f03cc05 Mon Sep 17 00:00:00 2001 From: bootjp Date: Fri, 3 Jul 2026 00:25:04 +0900 Subject: [PATCH 08/14] raft: retain wal-backed snapshot fallback --- internal/raftengine/etcd/engine.go | 16 ++++++ internal/raftengine/etcd/engine_test.go | 35 +++++++++++++ internal/raftengine/etcd/fsm_snapshot_file.go | 49 +++++++++++++++++-- .../raftengine/etcd/fsm_snapshot_file_test.go | 26 ++++++++++ 4 files changed, 123 insertions(+), 3 deletions(-) diff --git a/internal/raftengine/etcd/engine.go b/internal/raftengine/etcd/engine.go index 6b96c8504..a49e9ce27 100644 --- a/internal/raftengine/etcd/engine.go +++ b/internal/raftengine/etcd/engine.go @@ -1884,6 +1884,13 @@ func (e *Engine) handleStep(msg raftpb.Message) { e.fail(errors.WithStack(err)) return } + if e.unprotectReceivedFSMSnapshotTokenIfCommitted(msg) { + return + } + if !e.rawNode.HasReady() { + e.unprotectReceivedFSMSnapshotToken(msg) + return + } e.unprotectReceivedFSMSnapshotTokenIfApplied(msg) } @@ -2885,6 +2892,15 @@ func (e *Engine) unprotectReceivedFSMSnapshotTokenIfApplied(msg raftpb.Message) e.unprotectReceivedFSMSnapshot(index) } +func (e *Engine) unprotectReceivedFSMSnapshotTokenIfCommitted(msg raftpb.Message) bool { + index, ok := receivedFSMSnapshotTokenIndex(msg) + if !ok || e.rawNode == nil || index > e.rawNode.Status().Commit { + return false + } + e.unprotectReceivedFSMSnapshot(index) + return true +} + func (e *Engine) unprotectReceivedFSMSnapshotToken(msg raftpb.Message) { index, ok := receivedFSMSnapshotTokenIndex(msg) if !ok { diff --git a/internal/raftengine/etcd/engine_test.go b/internal/raftengine/etcd/engine_test.go index 80525c608..1f9eda983 100644 --- a/internal/raftengine/etcd/engine_test.go +++ b/internal/raftengine/etcd/engine_test.go @@ -557,6 +557,41 @@ func TestHandleStepIgnoresPeerNotFoundResponses(t *testing.T) { require.NoError(t, engine.currentError()) } +func TestHandleStepUnprotectsSnapshotTokenWhenCommittedAlreadyCoversIt(t *testing.T) { + storage := etcdraft.NewMemoryStorage() + require.NoError(t, storage.ApplySnapshot(raftpb.Snapshot{ + Metadata: raftpb.SnapshotMetadata{ + ConfState: raftpb.ConfState{Voters: []uint64{1}}, + Index: 10, + Term: 1, + }, + })) + engine := &Engine{ + rawNode: mustRawNode(t, storage, 1), + protectedReceivedFSMSnaps: map[uint64]int{ + 9: 1, + }, + } + require.False(t, engine.rawNode.HasReady()) + + engine.handleStep(raftpb.Message{ + Type: raftpb.MsgSnap, + From: 2, + To: 1, + Snapshot: &raftpb.Snapshot{ + Data: encodeSnapshotToken(9, 0), + Metadata: raftpb.SnapshotMetadata{ + ConfState: raftpb.ConfState{Voters: []uint64{1}}, + Index: 9, + Term: 1, + }, + }, + }) + + require.NoError(t, engine.currentError()) + require.Empty(t, engine.protectedReceivedFSMSnaps) +} + func TestApplyReadySnapshotAdvancesAppliedIndex(t *testing.T) { engine := &Engine{ storage: etcdraft.NewMemoryStorage(), diff --git a/internal/raftengine/etcd/fsm_snapshot_file.go b/internal/raftengine/etcd/fsm_snapshot_file.go index 069315dad..6fc83d090 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file.go +++ b/internal/raftengine/etcd/fsm_snapshot_file.go @@ -15,6 +15,8 @@ import ( "strings" "github.com/cockroachdb/errors" + "go.etcd.io/etcd/server/v3/storage/wal" + "go.uber.org/zap" ) const ( @@ -566,6 +568,7 @@ type snapFileCandidate struct { name string index uint64 restorable bool + walValid bool } type prewriteSnapshotRetention struct { @@ -590,7 +593,11 @@ func purgeOlderSnapshotPairsBeforeWrite( return errors.WithStack(err) } - candidates := collectPrewriteSnapCandidates(entries, fsmSnapDir, nextIndex) + walValidIndexes, err := prewriteWALSnapshotIndexes(snapDir) + if err != nil { + return errors.WithStack(err) + } + candidates := collectPrewriteSnapCandidates(entries, fsmSnapDir, nextIndex, walValidIndexes) sort.Slice(candidates, func(i, j int) bool { if candidates[i].index == candidates[j].index { return candidates[i].name < candidates[j].name @@ -653,8 +660,9 @@ func keepRestorablePrewriteSnapshots(candidates []snapFileCandidate) prewriteSna retention := prewriteSnapshotRetention{ keep: make(map[string]bool, prewriteSnapKeep), } + walFiltered := prewriteCandidatesHaveWALFilter(candidates) for i := len(candidates) - 1; i >= 0 && len(retention.keep) < prewriteSnapKeep; i-- { - if candidates[i].restorable { + if candidates[i].restorable && (!walFiltered || candidates[i].walValid) { retention.keep[candidates[i].name] = true if retention.restorableFloor == 0 || candidates[i].index < retention.restorableFloor { retention.restorableFloor = candidates[i].index @@ -667,7 +675,41 @@ func keepRestorablePrewriteSnapshots(candidates []snapFileCandidate) prewriteSna return retention } -func collectPrewriteSnapCandidates(entries []os.DirEntry, fsmSnapDir string, nextIndex uint64) []snapFileCandidate { +func prewriteCandidatesHaveWALFilter(candidates []snapFileCandidate) bool { + for _, candidate := range candidates { + if candidate.walValid { + return true + } + } + return false +} + +var prewriteWALSnapshotIndexes = loadPrewriteWALSnapshotIndexes + +func loadPrewriteWALSnapshotIndexes(snapDir string) (map[uint64]bool, error) { + walDir := filepath.Join(filepath.Dir(snapDir), walDirName) + if !wal.Exist(walDir) { + return nil, nil + } + walSnaps, err := wal.ValidSnapshotEntries(zap.NewNop(), walDir) + if err != nil { + return nil, errors.WithStack(err) + } + indexes := make(map[uint64]bool, len(walSnaps)) + for _, snap := range walSnaps { + if snap.Index > 0 { + indexes[snap.Index] = true + } + } + return indexes, nil +} + +func collectPrewriteSnapCandidates( + entries []os.DirEntry, + fsmSnapDir string, + nextIndex uint64, + walValidIndexes map[uint64]bool, +) []snapFileCandidate { candidates := make([]snapFileCandidate, 0, len(entries)) for _, e := range entries { if e.IsDir() || filepath.Ext(e.Name()) != snapFileExt { @@ -681,6 +723,7 @@ func collectPrewriteSnapCandidates(entries []os.DirEntry, fsmSnapDir string, nex name: e.Name(), index: index, restorable: fsmSnapshotFileRestorable(fsmSnapDir, index), + walValid: walValidIndexes == nil || walValidIndexes[index], }) } return candidates diff --git a/internal/raftengine/etcd/fsm_snapshot_file_test.go b/internal/raftengine/etcd/fsm_snapshot_file_test.go index 30b9f34d5..9b2211c1d 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file_test.go +++ b/internal/raftengine/etcd/fsm_snapshot_file_test.go @@ -366,6 +366,32 @@ func TestPrepareFSMSnapshotWriteKeepsNewestRestorablePair(t *testing.T) { require.FileExists(t, filepath.Join(fsmSnapDir, "leftover.fsm.tmp")) } +func TestPrepareFSMSnapshotWriteKeepsWALValidFallbackPair(t *testing.T) { + snapDir := t.TempDir() + fsmSnapDir := t.TempDir() + payload := []byte("payload") + + createSnapFile(t, snapDir, 100) + writeFSMFileForTest(t, fsmSnapDir, 100, payload) + createSnapFile(t, snapDir, 200) + writeFSMFileForTest(t, fsmSnapDir, 200, payload) + + oldPrewriteWALSnapshotIndexes := prewriteWALSnapshotIndexes + prewriteWALSnapshotIndexes = func(string) (map[uint64]bool, error) { + return map[uint64]bool{100: true}, nil + } + t.Cleanup(func() { + prewriteWALSnapshotIndexes = oldPrewriteWALSnapshotIndexes + }) + + require.NoError(t, prepareFSMSnapshotWrite(snapDir, fsmSnapDir, 300)) + + require.FileExists(t, filepath.Join(snapDir, "0000000000000001-0000000000000064.snap")) + require.NoFileExists(t, filepath.Join(snapDir, "0000000000000001-00000000000000c8.snap")) + require.FileExists(t, fsmSnapPath(fsmSnapDir, 100)) + require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 200)) +} + func TestPrepareFSMSnapshotWritePreservesProtectedReceivedFSM(t *testing.T) { snapDir := t.TempDir() fsmSnapDir := t.TempDir() From ea55b717c9d77c397ae2fa71471d706630d23933 Mon Sep 17 00:00:00 2001 From: bootjp Date: Fri, 3 Jul 2026 00:52:25 +0900 Subject: [PATCH 09/14] raft: keep accepted snapshots protected --- internal/raftengine/etcd/engine.go | 7 ++-- internal/raftengine/etcd/engine_test.go | 36 +++++++++++++++++++ internal/raftengine/etcd/fsm_snapshot_file.go | 17 ++++----- .../raftengine/etcd/fsm_snapshot_file_test.go | 15 ++++++++ 4 files changed, 62 insertions(+), 13 deletions(-) diff --git a/internal/raftengine/etcd/engine.go b/internal/raftengine/etcd/engine.go index a49e9ce27..a2f134b44 100644 --- a/internal/raftengine/etcd/engine.go +++ b/internal/raftengine/etcd/engine.go @@ -1876,6 +1876,7 @@ func (e *Engine) handleStep(msg raftpb.Message) { } e.recordLeaderContact(msg) e.recordQuorumAck(msg) + commitBeforeStep := e.rawNode.Status().Commit if err := e.rawNode.Step(msg); err != nil { if errors.Is(err, etcdraft.ErrStepPeerNotFound) { e.unprotectReceivedFSMSnapshotToken(msg) @@ -1884,7 +1885,7 @@ func (e *Engine) handleStep(msg raftpb.Message) { e.fail(errors.WithStack(err)) return } - if e.unprotectReceivedFSMSnapshotTokenIfCommitted(msg) { + if e.unprotectReceivedFSMSnapshotTokenIfCommitted(msg, commitBeforeStep) { return } if !e.rawNode.HasReady() { @@ -2892,9 +2893,9 @@ func (e *Engine) unprotectReceivedFSMSnapshotTokenIfApplied(msg raftpb.Message) e.unprotectReceivedFSMSnapshot(index) } -func (e *Engine) unprotectReceivedFSMSnapshotTokenIfCommitted(msg raftpb.Message) bool { +func (e *Engine) unprotectReceivedFSMSnapshotTokenIfCommitted(msg raftpb.Message, committedIndex uint64) bool { index, ok := receivedFSMSnapshotTokenIndex(msg) - if !ok || e.rawNode == nil || index > e.rawNode.Status().Commit { + if !ok || index > committedIndex { return false } e.unprotectReceivedFSMSnapshot(index) diff --git a/internal/raftengine/etcd/engine_test.go b/internal/raftengine/etcd/engine_test.go index 1f9eda983..ec26c3ee7 100644 --- a/internal/raftengine/etcd/engine_test.go +++ b/internal/raftengine/etcd/engine_test.go @@ -592,6 +592,42 @@ func TestHandleStepUnprotectsSnapshotTokenWhenCommittedAlreadyCoversIt(t *testin require.Empty(t, engine.protectedReceivedFSMSnaps) } +func TestHandleStepKeepsAcceptedSnapshotTokenProtectedUntilReady(t *testing.T) { + storage := etcdraft.NewMemoryStorage() + require.NoError(t, storage.ApplySnapshot(raftpb.Snapshot{ + Metadata: raftpb.SnapshotMetadata{ + ConfState: raftpb.ConfState{Voters: []uint64{1}}, + Index: 10, + Term: 1, + }, + })) + engine := &Engine{ + rawNode: mustRawNode(t, storage, 1), + protectedReceivedFSMSnaps: map[uint64]int{ + 11: 1, + }, + } + require.False(t, engine.rawNode.HasReady()) + + engine.handleStep(raftpb.Message{ + Type: raftpb.MsgSnap, + From: 2, + To: 1, + Snapshot: &raftpb.Snapshot{ + Data: encodeSnapshotToken(11, 0), + Metadata: raftpb.SnapshotMetadata{ + ConfState: raftpb.ConfState{Voters: []uint64{1}}, + Index: 11, + Term: 2, + }, + }, + }) + + require.NoError(t, engine.currentError()) + require.True(t, engine.rawNode.HasReady()) + require.Equal(t, map[uint64]int{11: 1}, engine.protectedReceivedFSMSnaps) +} + func TestApplyReadySnapshotAdvancesAppliedIndex(t *testing.T) { engine := &Engine{ storage: etcdraft.NewMemoryStorage(), diff --git a/internal/raftengine/etcd/fsm_snapshot_file.go b/internal/raftengine/etcd/fsm_snapshot_file.go index 6fc83d090..e3f8ff445 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file.go +++ b/internal/raftengine/etcd/fsm_snapshot_file.go @@ -611,7 +611,6 @@ func purgeOlderSnapshotPairsBeforeWrite( combined = errors.CombineErrors(combined, removePrewriteFSMOrphansBeforeIndex( snapDir, fsmSnapDir, - retention, protectedIndexes, nextIndex, )) @@ -641,19 +640,17 @@ func purgeUnretainedPrewriteSnapshots( func removePrewriteFSMOrphansBeforeIndex( snapDir string, fsmSnapDir string, - retention prewriteSnapshotRetention, protectedIndexes map[uint64]bool, nextIndex uint64, ) error { - if retention.restorableFloor > 0 { - liveIndexes, err := collectLiveSnapIndexes(snapDir) - if err != nil { - return errors.WithStack(err) - } else if liveIndexes != nil { - return removeStaleFSMFilesBelowIndex(fsmSnapDir, liveIndexes, protectedIndexes, nextIndex) - } + liveIndexes, err := collectLiveSnapIndexes(snapDir) + if err != nil { + return errors.WithStack(err) } - return nil + if liveIndexes == nil { + return nil + } + return removeStaleFSMFilesBelowIndex(fsmSnapDir, liveIndexes, protectedIndexes, nextIndex) } func keepRestorablePrewriteSnapshots(candidates []snapFileCandidate) prewriteSnapshotRetention { diff --git a/internal/raftengine/etcd/fsm_snapshot_file_test.go b/internal/raftengine/etcd/fsm_snapshot_file_test.go index 9b2211c1d..89499a04d 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file_test.go +++ b/internal/raftengine/etcd/fsm_snapshot_file_test.go @@ -392,6 +392,21 @@ func TestPrepareFSMSnapshotWriteKeepsWALValidFallbackPair(t *testing.T) { require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 200)) } +func TestPrepareFSMSnapshotWriteRemovesOrphansWithoutRetainedSnapshot(t *testing.T) { + snapDir := t.TempDir() + fsmSnapDir := t.TempDir() + payload := []byte("payload") + + writeFSMFileForTest(t, fsmSnapDir, 100, payload) + writeFSMFileForTest(t, fsmSnapDir, 200, payload) + + protected := map[uint64]bool{200: true} + require.NoError(t, prepareFSMSnapshotWriteProtected(snapDir, fsmSnapDir, 300, protected)) + + require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 100)) + require.FileExists(t, fsmSnapPath(fsmSnapDir, 200)) +} + func TestPrepareFSMSnapshotWritePreservesProtectedReceivedFSM(t *testing.T) { snapDir := t.TempDir() fsmSnapDir := t.TempDir() From 7f884f4ea8e3e5c38f973b9db0cfc00f5de81389 Mon Sep 17 00:00:00 2001 From: bootjp Date: Fri, 3 Jul 2026 01:15:44 +0900 Subject: [PATCH 10/14] raft: release ignored snapshot protections --- internal/raftengine/etcd/engine.go | 37 ++++++++++- .../etcd/engine_applied_index_test.go | 30 +++++++++ internal/raftengine/etcd/fsm_snapshot_file.go | 62 +++++++++++++++---- .../raftengine/etcd/fsm_snapshot_file_test.go | 36 ++++++++++- main.go | 2 +- main_raft_lifecycle_test.go | 12 ++++ 6 files changed, 161 insertions(+), 18 deletions(-) diff --git a/internal/raftengine/etcd/engine.go b/internal/raftengine/etcd/engine.go index a2f134b44..2bc14a027 100644 --- a/internal/raftengine/etcd/engine.go +++ b/internal/raftengine/etcd/engine.go @@ -310,8 +310,9 @@ type Engine struct { // Restore swaps the underlying store state and must not race with the short // critical section that publishes a newly persisted local snapshot. - snapshotMu sync.Mutex - protectedReceivedFSMSnaps map[uint64]int + snapshotMu sync.Mutex + protectedReceivedFSMSnaps map[uint64]int + pendingReceivedFSMSnapshotStep map[uint64]int dispatchDropCount atomic.Uint64 dispatchErrorCount atomic.Uint64 @@ -1797,6 +1798,7 @@ func (e *Engine) drainReady() error { if err := e.persistReady(rd); err != nil { return err } + e.releaseIgnoredReceivedFSMSnapshotSteps(rd) if err := e.sendMessages(rd.Messages); err != nil { return err } @@ -1892,6 +1894,7 @@ func (e *Engine) handleStep(msg raftpb.Message) { e.unprotectReceivedFSMSnapshotToken(msg) return } + e.trackReceivedFSMSnapshotStep(msg) e.unprotectReceivedFSMSnapshotTokenIfApplied(msg) } @@ -2902,6 +2905,36 @@ func (e *Engine) unprotectReceivedFSMSnapshotTokenIfCommitted(msg raftpb.Message return true } +func (e *Engine) trackReceivedFSMSnapshotStep(msg raftpb.Message) { + index, ok := receivedFSMSnapshotTokenIndex(msg) + if !ok || index <= e.appliedIndex.Load() { + return + } + if e.pendingReceivedFSMSnapshotStep == nil { + e.pendingReceivedFSMSnapshotStep = make(map[uint64]int, 1) + } + e.pendingReceivedFSMSnapshotStep[index]++ +} + +func (e *Engine) releaseIgnoredReceivedFSMSnapshotSteps(rd etcdraft.Ready) { + if len(e.pendingReceivedFSMSnapshotStep) == 0 { + return + } + snapshotIndex := uint64(0) + if !etcdraft.IsEmptySnap(rd.Snapshot) { + snapshotIndex = rd.Snapshot.Metadata.Index + } + for index, count := range e.pendingReceivedFSMSnapshotStep { + delete(e.pendingReceivedFSMSnapshotStep, index) + if index == snapshotIndex { + continue + } + for i := 0; i < count; i++ { + e.unprotectReceivedFSMSnapshot(index) + } + } +} + func (e *Engine) unprotectReceivedFSMSnapshotToken(msg raftpb.Message) { index, ok := receivedFSMSnapshotTokenIndex(msg) if !ok { diff --git a/internal/raftengine/etcd/engine_applied_index_test.go b/internal/raftengine/etcd/engine_applied_index_test.go index 3f30685a4..123425f55 100644 --- a/internal/raftengine/etcd/engine_applied_index_test.go +++ b/internal/raftengine/etcd/engine_applied_index_test.go @@ -209,6 +209,36 @@ func TestUnprotectReceivedFSMSnapshotTokenIfAppliedKeepsFutureSnapshot(t *testin require.Equal(t, map[uint64]int{10: 1}, e.protectedReceivedFSMSnaps) } +func TestReleaseIgnoredReceivedFSMSnapshotStepsUnprotectsNonSnapshotReady(t *testing.T) { + e := &Engine{ + protectedReceivedFSMSnaps: map[uint64]int{10: 1}, + pendingReceivedFSMSnapshotStep: map[uint64]int{ + 10: 1, + }, + } + + e.releaseIgnoredReceivedFSMSnapshotSteps(etcdraft.Ready{}) + + require.Empty(t, e.protectedReceivedFSMSnaps) + require.Empty(t, e.pendingReceivedFSMSnapshotStep) +} + +func TestReleaseIgnoredReceivedFSMSnapshotStepsKeepsSnapshotReadyProtected(t *testing.T) { + e := &Engine{ + protectedReceivedFSMSnaps: map[uint64]int{10: 1}, + pendingReceivedFSMSnapshotStep: map[uint64]int{ + 10: 1, + }, + } + + e.releaseIgnoredReceivedFSMSnapshotSteps(etcdraft.Ready{ + Snapshot: raftpb.Snapshot{Metadata: raftpb.SnapshotMetadata{Index: 10, Term: 1}}, + }) + + require.Equal(t, map[uint64]int{10: 1}, e.protectedReceivedFSMSnaps) + require.Empty(t, e.pendingReceivedFSMSnapshotStep) +} + // TestRecordingFSM_SatisfiesAppliedIndexWriter is a compile-time- // adjacent assertion: the recording FSM MUST satisfy the writer // seam so the engine hook actually fires for it. diff --git a/internal/raftengine/etcd/fsm_snapshot_file.go b/internal/raftengine/etcd/fsm_snapshot_file.go index e3f8ff445..49f66fb85 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file.go +++ b/internal/raftengine/etcd/fsm_snapshot_file.go @@ -141,16 +141,25 @@ func fsmSnapPath(fsmSnapDir string, index uint64) string { // Snap files are named "{term:016x}-{index:016x}.snap". // Returns 0 on parse failure. func parseSnapFileIndex(name string) uint64 { + _, index := parseSnapFileTermIndex(name) + return index +} + +func parseSnapFileTermIndex(name string) (uint64, uint64) { base := strings.TrimSuffix(name, snapFileExt) idx := strings.LastIndex(base, "-") if idx < 0 { - return 0 + return 0, 0 + } + term, err := strconv.ParseUint(base[:idx], 16, 64) + if err != nil { + return 0, 0 } index, err := strconv.ParseUint(base[idx+1:], 16, 64) if err != nil { - return 0 + return 0, 0 } - return index + return term, index } // crc32CWriter wraps an io.Writer and accumulates a CRC32C checksum over all @@ -629,7 +638,13 @@ func purgeUnretainedPrewriteSnapshots( if retention.keep[candidate.name] { continue } - if err := purgeSnapPair(snapDir, fsmSnapDir, candidate.name); err != nil { + var err error + if retainedPrewriteSnapshotIndex(candidates, retention, candidate.index) { + err = purgeSnapFile(snapDir, candidate.name) + } else { + err = purgeSnapPair(snapDir, fsmSnapDir, candidate.name) + } + if err != nil { combined = errors.CombineErrors(combined, err) } } @@ -637,6 +652,15 @@ func purgeUnretainedPrewriteSnapshots( return errors.WithStack(combined) } +func retainedPrewriteSnapshotIndex(candidates []snapFileCandidate, retention prewriteSnapshotRetention, index uint64) bool { + for _, candidate := range candidates { + if candidate.index == index && retention.keep[candidate.name] { + return true + } + } + return false +} + func removePrewriteFSMOrphansBeforeIndex( snapDir string, fsmSnapDir string, @@ -683,7 +707,7 @@ func prewriteCandidatesHaveWALFilter(candidates []snapFileCandidate) bool { var prewriteWALSnapshotIndexes = loadPrewriteWALSnapshotIndexes -func loadPrewriteWALSnapshotIndexes(snapDir string) (map[uint64]bool, error) { +func loadPrewriteWALSnapshotIndexes(snapDir string) (map[walSnapshotKey]bool, error) { walDir := filepath.Join(filepath.Dir(snapDir), walDirName) if !wal.Exist(walDir) { return nil, nil @@ -692,27 +716,32 @@ func loadPrewriteWALSnapshotIndexes(snapDir string) (map[uint64]bool, error) { if err != nil { return nil, errors.WithStack(err) } - indexes := make(map[uint64]bool, len(walSnaps)) + indexes := make(map[walSnapshotKey]bool, len(walSnaps)) for _, snap := range walSnaps { if snap.Index > 0 { - indexes[snap.Index] = true + indexes[walSnapshotKey{term: snap.Term, index: snap.Index}] = true } } return indexes, nil } +type walSnapshotKey struct { + term uint64 + index uint64 +} + func collectPrewriteSnapCandidates( entries []os.DirEntry, fsmSnapDir string, nextIndex uint64, - walValidIndexes map[uint64]bool, + walValidIndexes map[walSnapshotKey]bool, ) []snapFileCandidate { candidates := make([]snapFileCandidate, 0, len(entries)) for _, e := range entries { if e.IsDir() || filepath.Ext(e.Name()) != snapFileExt { continue } - index := parseSnapFileIndex(e.Name()) + term, index := parseSnapFileTermIndex(e.Name()) if index == 0 || index >= nextIndex { continue } @@ -720,7 +749,7 @@ func collectPrewriteSnapCandidates( name: e.Name(), index: index, restorable: fsmSnapshotFileRestorable(fsmSnapDir, index), - walValid: walValidIndexes == nil || walValidIndexes[index], + walValid: walValidIndexes == nil || walValidIndexes[walSnapshotKey{term: term, index: index}], }) } return candidates @@ -904,9 +933,8 @@ func purgeSnapPair(snapDir, fsmSnapDir, snapName string) error { idx := parseSnapFileIndex(snapName) // Remove the .snap file first; skip .fsm removal if snap removal fails. - snapPath := filepath.Join(snapDir, snapName) - if removeErr := os.Remove(snapPath); removeErr != nil && !os.IsNotExist(removeErr) { - return errors.WithStack(removeErr) + if err := purgeSnapFile(snapDir, snapName); err != nil { + return err } // Remove the corresponding .fsm file second. @@ -921,6 +949,14 @@ func purgeSnapPair(snapDir, fsmSnapDir, snapName string) error { return nil } +func purgeSnapFile(snapDir, snapName string) error { + snapPath := filepath.Join(snapDir, snapName) + if removeErr := os.Remove(snapPath); removeErr != nil && !os.IsNotExist(removeErr) { + return errors.WithStack(removeErr) + } + return nil +} + func syncDirIfExists(dir string) error { if dir == "" { return nil diff --git a/internal/raftengine/etcd/fsm_snapshot_file_test.go b/internal/raftengine/etcd/fsm_snapshot_file_test.go index 89499a04d..09062f67b 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file_test.go +++ b/internal/raftengine/etcd/fsm_snapshot_file_test.go @@ -3,6 +3,7 @@ package etcd import ( "bytes" "encoding/binary" + "fmt" "hash/crc32" "io" "math" @@ -377,8 +378,8 @@ func TestPrepareFSMSnapshotWriteKeepsWALValidFallbackPair(t *testing.T) { writeFSMFileForTest(t, fsmSnapDir, 200, payload) oldPrewriteWALSnapshotIndexes := prewriteWALSnapshotIndexes - prewriteWALSnapshotIndexes = func(string) (map[uint64]bool, error) { - return map[uint64]bool{100: true}, nil + prewriteWALSnapshotIndexes = func(string) (map[walSnapshotKey]bool, error) { + return map[walSnapshotKey]bool{{term: 1, index: 100}: true}, nil } t.Cleanup(func() { prewriteWALSnapshotIndexes = oldPrewriteWALSnapshotIndexes @@ -392,6 +393,30 @@ func TestPrepareFSMSnapshotWriteKeepsWALValidFallbackPair(t *testing.T) { require.NoFileExists(t, fsmSnapPath(fsmSnapDir, 200)) } +func TestPrepareFSMSnapshotWriteKeepsWALTermMatchingFallbackPair(t *testing.T) { + snapDir := t.TempDir() + fsmSnapDir := t.TempDir() + payload := []byte("payload") + + createSnapFileWithTerm(t, snapDir, 1, 100) + createSnapFileWithTerm(t, snapDir, 2, 100) + writeFSMFileForTest(t, fsmSnapDir, 100, payload) + + oldPrewriteWALSnapshotIndexes := prewriteWALSnapshotIndexes + prewriteWALSnapshotIndexes = func(string) (map[walSnapshotKey]bool, error) { + return map[walSnapshotKey]bool{{term: 1, index: 100}: true}, nil + } + t.Cleanup(func() { + prewriteWALSnapshotIndexes = oldPrewriteWALSnapshotIndexes + }) + + require.NoError(t, prepareFSMSnapshotWrite(snapDir, fsmSnapDir, 200)) + + require.FileExists(t, filepath.Join(snapDir, "0000000000000001-0000000000000064.snap")) + require.NoFileExists(t, filepath.Join(snapDir, "0000000000000002-0000000000000064.snap")) + require.FileExists(t, fsmSnapPath(fsmSnapDir, 100)) +} + func TestPrepareFSMSnapshotWriteRemovesOrphansWithoutRetainedSnapshot(t *testing.T) { snapDir := t.TempDir() fsmSnapDir := t.TempDir() @@ -407,6 +432,13 @@ func TestPrepareFSMSnapshotWriteRemovesOrphansWithoutRetainedSnapshot(t *testing require.FileExists(t, fsmSnapPath(fsmSnapDir, 200)) } +func createSnapFileWithTerm(t *testing.T, dir string, term uint64, index uint64) { + t.Helper() + name := fmt.Sprintf("%016x-%016x.snap", term, index) + path := filepath.Join(dir, name) + require.NoError(t, os.WriteFile(path, []byte("fake"), 0o600)) +} + func TestPrepareFSMSnapshotWritePreservesProtectedReceivedFSM(t *testing.T) { snapDir := t.TempDir() fsmSnapDir := t.TempDir() diff --git a/main.go b/main.go index bb77cb01d..742af112d 100644 --- a/main.go +++ b/main.go @@ -540,7 +540,7 @@ func startRaftEngineLifecycleWatchers(ctx context.Context, eg *errgroup.Group, r if err := lifecycle.Err(); err != nil { return errors.Wrapf(err, "raft group %d engine stopped", groupID) } - return errors.Errorf("raft group %d engine stopped", groupID) + return nil } }) } diff --git a/main_raft_lifecycle_test.go b/main_raft_lifecycle_test.go index a7f982741..7d1bcdb06 100644 --- a/main_raft_lifecycle_test.go +++ b/main_raft_lifecycle_test.go @@ -34,6 +34,18 @@ func TestStartRaftEngineLifecycleWatchersReportsEngineFailure(t *testing.T) { require.Contains(t, err.Error(), "raft group 7 engine stopped") } +func TestStartRaftEngineLifecycleWatchersIgnoresCleanEngineStop(t *testing.T) { + t.Parallel() + engine := &lifecycleEngineStub{done: make(chan struct{})} + runtimes := []*raftGroupRuntime{{spec: groupSpec{id: 9}, engine: engine}} + + eg, ctx := errgroup.WithContext(context.Background()) + startRaftEngineLifecycleWatchers(ctx, eg, runtimes) + close(engine.done) + + require.NoError(t, eg.Wait()) +} + func TestStartRaftEngineLifecycleWatchersIgnoresContextCancellation(t *testing.T) { t.Parallel() engine := &lifecycleEngineStub{done: make(chan struct{})} From 5d1f5f0ac3b9d2990489f607500f5b0518bcaff8 Mon Sep 17 00:00:00 2001 From: bootjp Date: Fri, 3 Jul 2026 01:30:58 +0900 Subject: [PATCH 11/14] raft: verify retained snapshot tokens --- internal/raftengine/etcd/fsm_snapshot_file.go | 35 +++++++++-- .../raftengine/etcd/fsm_snapshot_file_test.go | 63 ++++++++++++++----- 2 files changed, 77 insertions(+), 21 deletions(-) diff --git a/internal/raftengine/etcd/fsm_snapshot_file.go b/internal/raftengine/etcd/fsm_snapshot_file.go index 49f66fb85..8b6ffcb31 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file.go +++ b/internal/raftengine/etcd/fsm_snapshot_file.go @@ -15,6 +15,7 @@ import ( "strings" "github.com/cockroachdb/errors" + etcdsnap "go.etcd.io/etcd/server/v3/etcdserver/api/snap" "go.etcd.io/etcd/server/v3/storage/wal" "go.uber.org/zap" ) @@ -354,6 +355,10 @@ func restoreAndComputeCRC(f *os.File, fileSize int64, fsm StateMachine) (uint32, // verifyFSMSnapshotFile performs a read-only CRC check without restoring the FSM. // Used for startup orphan detection. Pass tokenCRC=0 to skip the token comparison. func verifyFSMSnapshotFile(path string, tokenCRC uint32) error { + return verifyFSMSnapshotFileWithToken(path, tokenCRC, tokenCRC != 0) +} + +func verifyFSMSnapshotFileWithToken(path string, tokenCRC uint32, checkToken bool) error { // Open before stat to eliminate the TOCTOU window between path lookup // and file open (consistent with openAndRestoreFSMSnapshot). f, err := os.Open(path) @@ -386,7 +391,7 @@ func verifyFSMSnapshotFile(path string, tokenCRC uint32) error { return errors.Wrapf(ErrFSMSnapshotFileCRC, "path=%s footer=%08x computed=%08x", path, footer, computed) } - if tokenCRC != 0 && computed != tokenCRC { + if checkToken && computed != tokenCRC { return errors.Wrapf(ErrFSMSnapshotTokenCRC, "path=%s footer=%08x token=%08x", path, footer, tokenCRC) } @@ -606,7 +611,7 @@ func purgeOlderSnapshotPairsBeforeWrite( if err != nil { return errors.WithStack(err) } - candidates := collectPrewriteSnapCandidates(entries, fsmSnapDir, nextIndex, walValidIndexes) + candidates := collectPrewriteSnapCandidates(entries, snapDir, fsmSnapDir, nextIndex, walValidIndexes) sort.Slice(candidates, func(i, j int) bool { if candidates[i].index == candidates[j].index { return candidates[i].name < candidates[j].name @@ -732,6 +737,7 @@ type walSnapshotKey struct { func collectPrewriteSnapCandidates( entries []os.DirEntry, + snapDir string, fsmSnapDir string, nextIndex uint64, walValidIndexes map[walSnapshotKey]bool, @@ -748,18 +754,37 @@ func collectPrewriteSnapCandidates( candidates = append(candidates, snapFileCandidate{ name: e.Name(), index: index, - restorable: fsmSnapshotFileRestorable(fsmSnapDir, index), + restorable: fsmSnapshotPairRestorable(snapDir, fsmSnapDir, e.Name(), term, index), walValid: walValidIndexes == nil || walValidIndexes[walSnapshotKey{term: term, index: index}], }) } return candidates } -func fsmSnapshotFileRestorable(fsmSnapDir string, index uint64) bool { +func fsmSnapshotPairRestorable(snapDir, fsmSnapDir, snapName string, term, index uint64) bool { if fsmSnapDir == "" { return false } - return verifyFSMSnapshotFile(fsmSnapPath(fsmSnapDir, index), 0) == nil + tok, ok := snapshotTokenFromSnapFile(snapDir, snapName, term, index) + if !ok { + return false + } + return verifyFSMSnapshotFileWithToken(fsmSnapPath(fsmSnapDir, index), tok.CRC32C, true) == nil +} + +func snapshotTokenFromSnapFile(snapDir, snapName string, term, index uint64) (snapshotToken, bool) { + snapshot, err := etcdsnap.Read(zap.NewNop(), filepath.Join(snapDir, snapName)) + if err != nil { + return snapshotToken{}, false + } + if snapshot.Metadata.Term != term || snapshot.Metadata.Index != index || !isSnapshotToken(snapshot.Data) { + return snapshotToken{}, false + } + tok, err := decodeSnapshotToken(snapshot.Data) + if err != nil || tok.Index != index { + return snapshotToken{}, false + } + return tok, true } func removeStaleFSMFilesBelowIndex( diff --git a/internal/raftengine/etcd/fsm_snapshot_file_test.go b/internal/raftengine/etcd/fsm_snapshot_file_test.go index 09062f67b..2307c7ff0 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file_test.go +++ b/internal/raftengine/etcd/fsm_snapshot_file_test.go @@ -3,7 +3,6 @@ package etcd import ( "bytes" "encoding/binary" - "fmt" "hash/crc32" "io" "math" @@ -12,6 +11,9 @@ import ( "testing" "github.com/stretchr/testify/require" + etcdsnap "go.etcd.io/etcd/server/v3/etcdserver/api/snap" + raftpb "go.etcd.io/raft/v3/raftpb" + "go.uber.org/zap" ) // --- Token encode/decode tests --- @@ -339,12 +341,12 @@ func TestPrepareFSMSnapshotWriteKeepsNewestRestorablePair(t *testing.T) { payload := []byte("payload") for _, index := range []uint64{100, 200} { - createSnapFile(t, snapDir, index) - writeFSMFileForTest(t, fsmSnapDir, index, payload) + crc, _ := writeFSMFileForTest(t, fsmSnapDir, index, payload) + createTokenSnapFileWithTerm(t, snapDir, 1, index, crc) } - createSnapFile(t, snapDir, 300) + createTokenSnapFileWithTerm(t, snapDir, 1, 300, 0x12345678) require.NoError(t, os.WriteFile(fsmSnapPath(fsmSnapDir, 300), []byte{0x01, 0x02}, 0o600)) - createSnapFile(t, snapDir, 350) + createTokenSnapFileWithTerm(t, snapDir, 1, 350, 0x87654321) writeFSMFileForTest(t, fsmSnapDir, 150, payload) writeFSMFileForTest(t, fsmSnapDir, 250, payload) writeFSMFileForTest(t, fsmSnapDir, 500, payload) @@ -372,10 +374,10 @@ func TestPrepareFSMSnapshotWriteKeepsWALValidFallbackPair(t *testing.T) { fsmSnapDir := t.TempDir() payload := []byte("payload") - createSnapFile(t, snapDir, 100) - writeFSMFileForTest(t, fsmSnapDir, 100, payload) - createSnapFile(t, snapDir, 200) - writeFSMFileForTest(t, fsmSnapDir, 200, payload) + crc100, _ := writeFSMFileForTest(t, fsmSnapDir, 100, payload) + createTokenSnapFileWithTerm(t, snapDir, 1, 100, crc100) + crc200, _ := writeFSMFileForTest(t, fsmSnapDir, 200, payload) + createTokenSnapFileWithTerm(t, snapDir, 1, 200, crc200) oldPrewriteWALSnapshotIndexes := prewriteWALSnapshotIndexes prewriteWALSnapshotIndexes = func(string) (map[walSnapshotKey]bool, error) { @@ -398,9 +400,9 @@ func TestPrepareFSMSnapshotWriteKeepsWALTermMatchingFallbackPair(t *testing.T) { fsmSnapDir := t.TempDir() payload := []byte("payload") - createSnapFileWithTerm(t, snapDir, 1, 100) - createSnapFileWithTerm(t, snapDir, 2, 100) - writeFSMFileForTest(t, fsmSnapDir, 100, payload) + crc, _ := writeFSMFileForTest(t, fsmSnapDir, 100, payload) + createTokenSnapFileWithTerm(t, snapDir, 1, 100, crc) + createTokenSnapFileWithTerm(t, snapDir, 2, 100, crc) oldPrewriteWALSnapshotIndexes := prewriteWALSnapshotIndexes prewriteWALSnapshotIndexes = func(string) (map[walSnapshotKey]bool, error) { @@ -417,6 +419,33 @@ func TestPrepareFSMSnapshotWriteKeepsWALTermMatchingFallbackPair(t *testing.T) { require.FileExists(t, fsmSnapPath(fsmSnapDir, 100)) } +func TestPrepareFSMSnapshotWriteKeepsTokenMatchingFallbackPair(t *testing.T) { + snapDir := t.TempDir() + fsmSnapDir := t.TempDir() + payload := []byte("payload") + + crc, _ := writeFSMFileForTest(t, fsmSnapDir, 100, payload) + createTokenSnapFileWithTerm(t, snapDir, 1, 100, crc) + createTokenSnapFileWithTerm(t, snapDir, 2, 100, crc^0xffffffff) + + oldPrewriteWALSnapshotIndexes := prewriteWALSnapshotIndexes + prewriteWALSnapshotIndexes = func(string) (map[walSnapshotKey]bool, error) { + return map[walSnapshotKey]bool{ + {term: 1, index: 100}: true, + {term: 2, index: 100}: true, + }, nil + } + t.Cleanup(func() { + prewriteWALSnapshotIndexes = oldPrewriteWALSnapshotIndexes + }) + + require.NoError(t, prepareFSMSnapshotWrite(snapDir, fsmSnapDir, 200)) + + require.FileExists(t, filepath.Join(snapDir, "0000000000000001-0000000000000064.snap")) + require.NoFileExists(t, filepath.Join(snapDir, "0000000000000002-0000000000000064.snap")) + require.FileExists(t, fsmSnapPath(fsmSnapDir, 100)) +} + func TestPrepareFSMSnapshotWriteRemovesOrphansWithoutRetainedSnapshot(t *testing.T) { snapDir := t.TempDir() fsmSnapDir := t.TempDir() @@ -432,11 +461,13 @@ func TestPrepareFSMSnapshotWriteRemovesOrphansWithoutRetainedSnapshot(t *testing require.FileExists(t, fsmSnapPath(fsmSnapDir, 200)) } -func createSnapFileWithTerm(t *testing.T, dir string, term uint64, index uint64) { +func createTokenSnapFileWithTerm(t *testing.T, dir string, term uint64, index uint64, crc32c uint32) { t.Helper() - name := fmt.Sprintf("%016x-%016x.snap", term, index) - path := filepath.Join(dir, name) - require.NoError(t, os.WriteFile(path, []byte("fake"), 0o600)) + snapshot := raftpb.Snapshot{ + Metadata: raftpb.SnapshotMetadata{Term: term, Index: index}, + Data: encodeSnapshotToken(index, crc32c), + } + require.NoError(t, etcdsnap.New(zap.NewNop(), dir).SaveSnap(snapshot)) } func TestPrepareFSMSnapshotWritePreservesProtectedReceivedFSM(t *testing.T) { From 3c5e99427426ddcd82c910e9c43478424cda7fce Mon Sep 17 00:00:00 2001 From: bootjp Date: Fri, 3 Jul 2026 01:35:14 +0900 Subject: [PATCH 12/14] raft: prefer restorable snapshot fallback --- internal/raftengine/etcd/fsm_snapshot_file.go | 30 ++++++++++++++----- .../raftengine/etcd/fsm_snapshot_file_test.go | 24 +++++++++++++++ 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/internal/raftengine/etcd/fsm_snapshot_file.go b/internal/raftengine/etcd/fsm_snapshot_file.go index 8b6ffcb31..7a12c814f 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file.go +++ b/internal/raftengine/etcd/fsm_snapshot_file.go @@ -687,13 +687,13 @@ func keepRestorablePrewriteSnapshots(candidates []snapFileCandidate) prewriteSna keep: make(map[string]bool, prewriteSnapKeep), } walFiltered := prewriteCandidatesHaveWALFilter(candidates) - for i := len(candidates) - 1; i >= 0 && len(retention.keep) < prewriteSnapKeep; i-- { - if candidates[i].restorable && (!walFiltered || candidates[i].walValid) { - retention.keep[candidates[i].name] = true - if retention.restorableFloor == 0 || candidates[i].index < retention.restorableFloor { - retention.restorableFloor = candidates[i].index - } - } + keepNewestMatchingPrewriteSnapshots(candidates, &retention, func(candidate snapFileCandidate) bool { + return candidate.restorable && (!walFiltered || candidate.walValid) + }) + if len(retention.keep) == 0 { + keepNewestMatchingPrewriteSnapshots(candidates, &retention, func(candidate snapFileCandidate) bool { + return candidate.restorable + }) } if len(retention.keep) == 0 && len(candidates) > 0 { retention.keep[candidates[len(candidates)-1].name] = true @@ -701,6 +701,22 @@ func keepRestorablePrewriteSnapshots(candidates []snapFileCandidate) prewriteSna return retention } +func keepNewestMatchingPrewriteSnapshots( + candidates []snapFileCandidate, + retention *prewriteSnapshotRetention, + match func(snapFileCandidate) bool, +) { + for i := len(candidates) - 1; i >= 0 && len(retention.keep) < prewriteSnapKeep; i-- { + if !match(candidates[i]) { + continue + } + retention.keep[candidates[i].name] = true + if retention.restorableFloor == 0 || candidates[i].index < retention.restorableFloor { + retention.restorableFloor = candidates[i].index + } + } +} + func prewriteCandidatesHaveWALFilter(candidates []snapFileCandidate) bool { for _, candidate := range candidates { if candidate.walValid { diff --git a/internal/raftengine/etcd/fsm_snapshot_file_test.go b/internal/raftengine/etcd/fsm_snapshot_file_test.go index 2307c7ff0..8d0785d7d 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file_test.go +++ b/internal/raftengine/etcd/fsm_snapshot_file_test.go @@ -446,6 +446,30 @@ func TestPrepareFSMSnapshotWriteKeepsTokenMatchingFallbackPair(t *testing.T) { require.FileExists(t, fsmSnapPath(fsmSnapDir, 100)) } +func TestPrepareFSMSnapshotWriteKeepsRestorablePairWhenWALValidCandidateIsBroken(t *testing.T) { + snapDir := t.TempDir() + fsmSnapDir := t.TempDir() + payload := []byte("payload") + + crc, _ := writeFSMFileForTest(t, fsmSnapDir, 100, payload) + createTokenSnapFileWithTerm(t, snapDir, 1, 100, crc) + createTokenSnapFileWithTerm(t, snapDir, 1, 200, 0x12345678) + + oldPrewriteWALSnapshotIndexes := prewriteWALSnapshotIndexes + prewriteWALSnapshotIndexes = func(string) (map[walSnapshotKey]bool, error) { + return map[walSnapshotKey]bool{{term: 1, index: 200}: true}, nil + } + t.Cleanup(func() { + prewriteWALSnapshotIndexes = oldPrewriteWALSnapshotIndexes + }) + + require.NoError(t, prepareFSMSnapshotWrite(snapDir, fsmSnapDir, 300)) + + require.FileExists(t, filepath.Join(snapDir, "0000000000000001-0000000000000064.snap")) + require.NoFileExists(t, filepath.Join(snapDir, "0000000000000001-00000000000000c8.snap")) + require.FileExists(t, fsmSnapPath(fsmSnapDir, 100)) +} + func TestPrepareFSMSnapshotWriteRemovesOrphansWithoutRetainedSnapshot(t *testing.T) { snapDir := t.TempDir() fsmSnapDir := t.TempDir() From f24c7312155a9a1071c2d4336fa4cfe7804b0c4b Mon Sep 17 00:00:00 2001 From: bootjp Date: Fri, 3 Jul 2026 01:49:50 +0900 Subject: [PATCH 13/14] raft: preserve wal loadable fallback snapshots --- internal/raftengine/etcd/fsm_snapshot_file.go | 24 +++++++++++++------ .../raftengine/etcd/fsm_snapshot_file_test.go | 4 ++-- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/internal/raftengine/etcd/fsm_snapshot_file.go b/internal/raftengine/etcd/fsm_snapshot_file.go index 7a12c814f..75f2a48a2 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file.go +++ b/internal/raftengine/etcd/fsm_snapshot_file.go @@ -687,11 +687,20 @@ func keepRestorablePrewriteSnapshots(candidates []snapFileCandidate) prewriteSna keep: make(map[string]bool, prewriteSnapKeep), } walFiltered := prewriteCandidatesHaveWALFilter(candidates) - keepNewestMatchingPrewriteSnapshots(candidates, &retention, func(candidate snapFileCandidate) bool { - return candidate.restorable && (!walFiltered || candidate.walValid) - }) - if len(retention.keep) == 0 { - keepNewestMatchingPrewriteSnapshots(candidates, &retention, func(candidate snapFileCandidate) bool { + if walFiltered { + keepNewestMatchingPrewriteSnapshots(candidates, &retention, prewriteSnapKeep, func(candidate snapFileCandidate) bool { + return candidate.restorable && candidate.walValid + }) + if len(retention.keep) == 0 { + keepNewestMatchingPrewriteSnapshots(candidates, &retention, prewriteSnapKeep, func(candidate snapFileCandidate) bool { + return candidate.walValid + }) + keepNewestMatchingPrewriteSnapshots(candidates, &retention, prewriteSnapKeep+1, func(candidate snapFileCandidate) bool { + return candidate.restorable + }) + } + } else { + keepNewestMatchingPrewriteSnapshots(candidates, &retention, prewriteSnapKeep, func(candidate snapFileCandidate) bool { return candidate.restorable }) } @@ -704,10 +713,11 @@ func keepRestorablePrewriteSnapshots(candidates []snapFileCandidate) prewriteSna func keepNewestMatchingPrewriteSnapshots( candidates []snapFileCandidate, retention *prewriteSnapshotRetention, + maxKeep int, match func(snapFileCandidate) bool, ) { - for i := len(candidates) - 1; i >= 0 && len(retention.keep) < prewriteSnapKeep; i-- { - if !match(candidates[i]) { + for i := len(candidates) - 1; i >= 0 && len(retention.keep) < maxKeep; i-- { + if retention.keep[candidates[i].name] || !match(candidates[i]) { continue } retention.keep[candidates[i].name] = true diff --git a/internal/raftengine/etcd/fsm_snapshot_file_test.go b/internal/raftengine/etcd/fsm_snapshot_file_test.go index 8d0785d7d..29f92f103 100644 --- a/internal/raftengine/etcd/fsm_snapshot_file_test.go +++ b/internal/raftengine/etcd/fsm_snapshot_file_test.go @@ -446,7 +446,7 @@ func TestPrepareFSMSnapshotWriteKeepsTokenMatchingFallbackPair(t *testing.T) { require.FileExists(t, fsmSnapPath(fsmSnapDir, 100)) } -func TestPrepareFSMSnapshotWriteKeepsRestorablePairWhenWALValidCandidateIsBroken(t *testing.T) { +func TestPrepareFSMSnapshotWriteKeepsWALValidAndRestorableFallbacksWhenWALValidCandidateIsBroken(t *testing.T) { snapDir := t.TempDir() fsmSnapDir := t.TempDir() payload := []byte("payload") @@ -466,7 +466,7 @@ func TestPrepareFSMSnapshotWriteKeepsRestorablePairWhenWALValidCandidateIsBroken require.NoError(t, prepareFSMSnapshotWrite(snapDir, fsmSnapDir, 300)) require.FileExists(t, filepath.Join(snapDir, "0000000000000001-0000000000000064.snap")) - require.NoFileExists(t, filepath.Join(snapDir, "0000000000000001-00000000000000c8.snap")) + require.FileExists(t, filepath.Join(snapDir, "0000000000000001-00000000000000c8.snap")) require.FileExists(t, fsmSnapPath(fsmSnapDir, 100)) } From 282304e93fd1d9f79cc73fa1c327a8599ebd3241 Mon Sep 17 00:00:00 2001 From: bootjp Date: Fri, 3 Jul 2026 02:10:41 +0900 Subject: [PATCH 14/14] raft: reject stale received snapshots --- internal/raftengine/etcd/engine.go | 9 +-- .../etcd/engine_applied_index_test.go | 34 ++++++++-- internal/raftengine/etcd/grpc_transport.go | 47 +++++++------ .../raftengine/etcd/grpc_transport_test.go | 67 +++++++++++++++++-- 4 files changed, 123 insertions(+), 34 deletions(-) diff --git a/internal/raftengine/etcd/engine.go b/internal/raftengine/etcd/engine.go index 2bc14a027..76eb79c38 100644 --- a/internal/raftengine/etcd/engine.go +++ b/internal/raftengine/etcd/engine.go @@ -2832,19 +2832,20 @@ func (e *Engine) prepareFSMSnapshotWrite(index uint64) error { return prepareFSMSnapshotWriteProtected(snapDir, e.fsmSnapDir, index, e.protectedReceivedFSMSnapshotIndexesLocked()) } -func (e *Engine) protectReceivedFSMSnapshot(index uint64) { - if index == 0 || index <= e.appliedIndex.Load() { - return +func (e *Engine) protectReceivedFSMSnapshot(index uint64) bool { + if index == 0 { + return false } e.snapshotMu.Lock() defer e.snapshotMu.Unlock() if index <= e.appliedIndex.Load() { - return + return false } if e.protectedReceivedFSMSnaps == nil { e.protectedReceivedFSMSnaps = make(map[uint64]int, 1) } e.protectedReceivedFSMSnaps[index]++ + return true } func (e *Engine) unprotectReceivedFSMSnapshot(index uint64) { diff --git a/internal/raftengine/etcd/engine_applied_index_test.go b/internal/raftengine/etcd/engine_applied_index_test.go index 123425f55..2cd88b6c3 100644 --- a/internal/raftengine/etcd/engine_applied_index_test.go +++ b/internal/raftengine/etcd/engine_applied_index_test.go @@ -157,10 +157,9 @@ func TestPersistReadyWithSnapshotHoldsSnapshotMuThroughSaveSnap(t *testing.T) { func TestProtectReceivedFSMSnapshotRechecksAppliedIndexUnderLock(t *testing.T) { e := &Engine{} e.snapshotMu.Lock() - done := make(chan struct{}) + accepted := make(chan bool, 1) go func() { - defer close(done) - e.protectReceivedFSMSnapshot(9) + accepted <- e.protectReceivedFSMSnapshot(9) }() time.Sleep(10 * time.Millisecond) @@ -168,7 +167,34 @@ func TestProtectReceivedFSMSnapshotRechecksAppliedIndexUnderLock(t *testing.T) { e.snapshotMu.Unlock() select { - case <-done: + case got := <-accepted: + require.False(t, got) + case <-time.After(time.Second): + t.Fatal("protectReceivedFSMSnapshot did not return") + } + require.Empty(t, e.protectedReceivedFSMSnaps) +} + +func TestProtectReceivedFSMSnapshotWaitsOnSnapshotMuForAlreadyAppliedIndex(t *testing.T) { + e := &Engine{} + e.appliedIndex.Store(9) + e.snapshotMu.Lock() + accepted := make(chan bool, 1) + go func() { + accepted <- e.protectReceivedFSMSnapshot(9) + }() + + select { + case got := <-accepted: + t.Fatalf("protectReceivedFSMSnapshot returned before snapshotMu was released: %v", got) + case <-time.After(100 * time.Millisecond): + } + + e.snapshotMu.Unlock() + + select { + case got := <-accepted: + require.False(t, got) case <-time.After(time.Second): t.Fatal("protectReceivedFSMSnapshot did not return") } diff --git a/internal/raftengine/etcd/grpc_transport.go b/internal/raftengine/etcd/grpc_transport.go index 72e428837..3ae886b64 100644 --- a/internal/raftengine/etcd/grpc_transport.go +++ b/internal/raftengine/etcd/grpc_transport.go @@ -34,6 +34,7 @@ var ( errSnapshotMetadataDuplicate = errors.New("etcd raft snapshot metadata was sent more than once") errSnapshotMessageNil = errors.New("etcd raft snapshot message is required") errSnapshotStreamShort = errors.New("etcd raft snapshot stream closed before final chunk") + errReceivedFSMSnapshotStale = errors.New("etcd raft received fsm snapshot is stale") ) var grpcNewClient = grpc.NewClient @@ -52,7 +53,7 @@ type GRPCTransport struct { spoolDir string fsmSnapDir string prepareFSMWrite func(index uint64) error - protectFSMWrite func(index uint64) + protectFSMWrite func(index uint64) bool unprotectFSMWrite func(index uint64) // readFSMPayload is the fallback bridge callback that materialises the full // FSM payload into memory. Used only when openFSMPayload is not set. @@ -133,7 +134,7 @@ func (t *GRPCTransport) SetFSMSnapshotPrepare(fn func(index uint64) error) { t.prepareFSMWrite = fn } -func (t *GRPCTransport) SetFSMSnapshotProtection(protectFn, unprotectFn func(index uint64)) { +func (t *GRPCTransport) SetFSMSnapshotProtection(protectFn func(index uint64) bool, unprotectFn func(index uint64)) { if t == nil { return } @@ -763,7 +764,7 @@ func (t *GRPCTransport) snapshotSpoolPlacement() ( placement string, fsmSnapDir string, prepareFn func(uint64) error, - protectFn func(uint64), + protectFn func(uint64) bool, unprotectFn func(uint64), ) { t.mu.RLock() @@ -841,7 +842,7 @@ func drainSnapshotChunks( spool *snapshotSpool, fsmSnapDir string, prepareFn func(uint64) error, - protectFn func(uint64), + protectFn func(uint64) bool, unprotectFn func(uint64), ) (raftpb.Message, int64, error) { var metadata raftpb.Message @@ -853,7 +854,7 @@ func drainSnapshotChunksFrom( spool *snapshotSpool, fsmSnapDir string, prepareFn func(uint64) error, - protectFn func(uint64), + protectFn func(uint64) bool, unprotectFn func(uint64), metadata raftpb.Message, firstPayloadChunk *pb.EtcdRaftSnapshotChunk, @@ -970,7 +971,7 @@ func finalizeReceivedSnapshot( spool *snapshotSpool, crc32c uint32, fsmSnapDir string, - protectFn func(uint64), + protectFn func(uint64) bool, unprotectFn func(uint64), seenMetadata bool, ) (raftpb.Message, error) { @@ -978,25 +979,27 @@ func finalizeReceivedSnapshot( return raftpb.Message{}, errors.WithStack(errSnapshotMetadataNil) } index := metadata.Snapshot.Metadata.Index - if fsmSnapDir != "" && index > 0 { - protected := false - if protectFn != nil { - protectFn(index) - protected = true + if fsmSnapDir == "" || index == 0 { + // Legacy fallback: full materialization. Used by tests that don't wire an + // fsmSnapDir and by the index=0 edge case (no canonical filename to + // rename to). + return buildSnapshotMessage(metadata, spool, seenMetadata) + } + protected := false + if protectFn != nil { + if !protectFn(index) { + return raftpb.Message{}, errors.WithStack(errReceivedFSMSnapshotStale) } - if err := spool.FinalizeAsFSMFile(fsmSnapDir, index, crc32c); err != nil { - if protected && unprotectFn != nil { - unprotectFn(index) - } - return raftpb.Message{}, err + protected = true + } + if err := spool.FinalizeAsFSMFile(fsmSnapDir, index, crc32c); err != nil { + if protected && unprotectFn != nil { + unprotectFn(index) } - metadata.Snapshot.Data = encodeSnapshotToken(index, crc32c) - return metadata, nil + return raftpb.Message{}, err } - // Legacy fallback: full materialization. Used by tests that don't wire an - // fsmSnapDir and by the index=0 edge case (no canonical filename to - // rename to). - return buildSnapshotMessage(metadata, spool, seenMetadata) + metadata.Snapshot.Data = encodeSnapshotToken(index, crc32c) + return metadata, nil } func maybePrepareReceivedFSMSnapshotWrite( diff --git a/internal/raftengine/etcd/grpc_transport_test.go b/internal/raftengine/etcd/grpc_transport_test.go index 9c670862a..a2429c2d1 100644 --- a/internal/raftengine/etcd/grpc_transport_test.go +++ b/internal/raftengine/etcd/grpc_transport_test.go @@ -295,7 +295,10 @@ func TestSendSnapshotProtectsFinalizedFSMFileUntilEngineRelease(t *testing.T) { transport.SetSpoolDir(t.TempDir()) transport.SetFSMSnapDir(fsmSnapDir) transport.SetFSMSnapshotProtection( - func(index uint64) { protected = append(protected, index) }, + func(index uint64) bool { + protected = append(protected, index) + return true + }, func(index uint64) { unprotected = append(unprotected, index) }, ) transport.SetHandler(func(_ context.Context, msg raftpb.Message) error { @@ -338,10 +341,11 @@ func TestDrainSnapshotChunksProtectsBeforePublishingFSMFile(t *testing.T) { require.NoError(t, spool.Close()) }) var protected []uint64 - protectFn := func(got uint64) { + protectFn := func(got uint64) bool { protected = append(protected, got) _, statErr := os.Stat(fsmSnapPath(fsmSnapDir, got)) require.True(t, os.IsNotExist(statErr), "protection must be registered before the final .fsm path is visible") + return true } stream := &testSendSnapshotServer{ chunks: []*pb.EtcdRaftSnapshotChunk{{ @@ -359,6 +363,55 @@ func TestDrainSnapshotChunksProtectsBeforePublishingFSMFile(t *testing.T) { require.FileExists(t, fsmSnapPath(fsmSnapDir, index)) } +func TestDrainSnapshotChunksRejectsStaleFSMProtection(t *testing.T) { + const index = uint64(127) + payload := []byte("stale payload must not be published") + metadata := raftpb.Message{ + Type: raftpb.MsgSnap, + From: 1, + To: 2, + Snapshot: &raftpb.Snapshot{ + Metadata: raftpb.SnapshotMetadata{Index: index, Term: 1}, + }, + } + raw, err := metadata.Marshal() + require.NoError(t, err) + + fsmSnapDir := t.TempDir() + spool, err := newSnapshotSpool(fsmSnapDir) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, spool.Close()) + }) + var protected []uint64 + var unprotected []uint64 + stream := &testSendSnapshotServer{ + chunks: []*pb.EtcdRaftSnapshotChunk{{ + Metadata: raw, + Chunk: payload, + Final: true, + }}, + } + + _, payloadBytes, err := drainSnapshotChunks( + stream, + spool, + fsmSnapDir, + func(uint64) error { return nil }, + func(got uint64) bool { + protected = append(protected, got) + return false + }, + func(got uint64) { unprotected = append(unprotected, got) }, + ) + require.Error(t, err) + require.ErrorIs(t, err, errReceivedFSMSnapshotStale) + require.Zero(t, payloadBytes) + require.Equal(t, []uint64{index}, protected) + require.Empty(t, unprotected) + require.NoFileExists(t, fsmSnapPath(fsmSnapDir, index)) +} + func TestDrainSnapshotChunksUnprotectsWhenFinalizeFails(t *testing.T) { const index = uint64(126) payload := []byte("payload whose final rename fails") @@ -400,7 +453,10 @@ func TestDrainSnapshotChunksUnprotectsWhenFinalizeFails(t *testing.T) { spool, fsmSnapDir, func(uint64) error { return nil }, - func(got uint64) { protected = append(protected, got) }, + func(got uint64) bool { + protected = append(protected, got) + return true + }, func(got uint64) { unprotected = append(unprotected, got) }, ) require.Error(t, err) @@ -591,7 +647,10 @@ func TestSendSnapshot_ApplyFailureRemovesFinalizedFSMFile(t *testing.T) { var protected []uint64 var unprotected []uint64 transport.SetFSMSnapshotProtection( - func(index uint64) { protected = append(protected, index) }, + func(index uint64) bool { + protected = append(protected, index) + return true + }, func(index uint64) { unprotected = append(unprotected, index) }, )