From 6d68cd76d234d9b62a7e38e211141d20030d8ffb Mon Sep 17 00:00:00 2001 From: Pekka Heikura Date: Wed, 10 Jun 2026 22:32:12 +0300 Subject: [PATCH 1/6] perf(cuda): #207 single-user spec-decode batched verify on the dense CUDA path CudaForwardPass.BatchVerify(tokens, startPos): one packed k-token pass over the owned cache at contiguous positions [P, P+k) reusing BatchForwardMulti's trunk (every row bound to the same cache; ragged append-then-attend is exactly packed causal attention), so every trunk matmul routes through BatchDecodeMatMul and the #194 WS kernels (or opt-in #201 decode MMQ) amortize weight HBM k x. SupportsBatchVerify gates on the dense batching-capable config + uncompacted cache; rollback stays TruncateTo(P + accepted). SpeculativeDecoder: replace the `is ForwardPass` CPU type-check with the IForwardPass.SupportsBatchVerify capability check (promoted to the interface with a default-throw BatchVerify; CPU ForwardPass opts in, TQ/gemma4 excluded). Kill-switch SHARPI_SPEC_BATCH_VERIFY=0 -> sequential fallback. Adds draft/verify/commit phase timing for bench reporting. CLI: --draft-model now supported with full CUDA offload of a dense model (draft loads on its OWN CudaBackend - graph state is per backend); both spec runners generalized to IForwardPass and prefill via Prefill instead of the per-token Forward loop. download-model.ps1 gains qwen3-0.6b (draft for the Qwen3-8B bench pair). Tests: pass-level BatchVerify-vs-sequential parity at k=4/6 (argmax + top-5 + maxAbs, SnapKV pinned off), rollback TruncateTo+commit oracle, e2e 48-token greedy parity (CUDA 8B target + CPU 0.6B draft, exact match), and scripted-mock routing/kill-switch/rejection coverage. CudaBatchForwardMultiTests + ContinuousBatchingTests stay green. Co-Authored-By: Claude Fable 5 --- scripts/download-model.ps1 | 11 +- src/SharpInference.Cli/RunCommand.cs | 72 +++-- src/SharpInference.Core/IForwardPass.cs | 27 +- src/SharpInference.Engine/CudaForwardPass.cs | 81 ++++++ src/SharpInference.Engine/ForwardPass.cs | 9 + .../SpeculativeDecoder.cs | 41 ++- .../CudaSpecBatchVerifyTests.cs | 266 ++++++++++++++++++ .../SpeculativeDecoderTests.cs | 138 +++++++++ 8 files changed, 605 insertions(+), 40 deletions(-) create mode 100644 tests/SharpInference.Tests.ForwardPass/CudaSpecBatchVerifyTests.cs diff --git a/scripts/download-model.ps1 b/scripts/download-model.ps1 index 16dcd5d..b20d7a5 100644 --- a/scripts/download-model.ps1 +++ b/scripts/download-model.ps1 @@ -29,7 +29,7 @@ .\download-model.ps1 -Model realesrgan-x4 # Real-ESRGAN x4plus upscaler (67 MB) #> param( - [ValidateSet("smollm2", "qwen3-8b", "olmoe-1b-7b", "llama31-70b", "qwen3-coder-30b-a3b", "qwen36-35b-a3b", + [ValidateSet("smollm2", "qwen3-8b", "qwen3-0.6b", "olmoe-1b-7b", "llama31-70b", "qwen3-coder-30b-a3b", "qwen36-35b-a3b", "qwen36-27b-mtp", "qwen36-27b-mtp-q5", "qwen36-35b-a3b-mtp", "carnice-35b-a3b-mtp", "gemma4-12b-qat", "gemma4-12b-q4km", "llama4-scout", "z-image-turbo", "z-image-turbo-q8", "realesrgan-x4")] @@ -50,6 +50,15 @@ $Models = @{ Size = "4.9 GB" Phase = "2b-3" } + # Qwen3-0.6B Q8_0 — speculative-decoding draft for Qwen3-8B (issue #207). Same + # tokenizer/vocab (151936) as Qwen3-8B; Q8_0 keeps draft quality high so the + # acceptance rate (and thus the spec-decode speedup) stays in the alpha 0.7-0.8 band. + "qwen3-0.6b" = @{ + Files = @("Qwen3-0.6B-Q8_0.gguf") + Urls = @("https://huggingface.co/Qwen/Qwen3-0.6B-GGUF/resolve/main/Qwen3-0.6B-Q8_0.gguf") + Size = "~0.6 GB" + Phase = "spec-decode draft (issue #207)" + } # Smallest MoE model that fits in 12 GB VRAM for full-offload kernel validation. # OLMoE arch (allenai) — 7B total params, 1B active, 64 experts × 8 active, softmax routing. # ModelGraph maps "olmoe" → NEOX RoPE, GQA, no shared expert. Used to validate diff --git a/src/SharpInference.Cli/RunCommand.cs b/src/SharpInference.Cli/RunCommand.cs index 02f90d6..8580c79 100644 --- a/src/SharpInference.Cli/RunCommand.cs +++ b/src/SharpInference.Cli/RunCommand.cs @@ -633,16 +633,21 @@ protected override int Execute(CommandContext context, Settings settings, Cancel }; var rng = settings.Seed >= 0 ? new Random(settings.Seed) : new Random(); - // Speculative decoding path (requires --draft-model and --temp 0) + // Speculative decoding path (requires --draft-model and --temp 0). Supported + // targets: pure CPU (-g 0) and full CUDA offload of a dense model (issue #207 — + // packed k-token verify via CudaForwardPass.BatchVerify). Vulkan and the partial- + // offload hybrids fall back to normal generation: without a batched verify, + // speculation costs k sequential target forwards per step and is never a win. if (settings.DraftModelPath is not null) { + bool cudaSpecTarget = gpuFwd is CudaForwardPass { SupportsBatchVerify: true }; if (settings.Temperature > 0f) { AnsiConsole.MarkupLine("[yellow]Warning:[/] Speculative decoding requires greedy sampling (--temp 0). Falling back to normal generation."); } - else if (nGpuLayers != 0) + else if (nGpuLayers != 0 && !cudaSpecTarget) { - AnsiConsole.MarkupLine("[yellow]Warning:[/] Speculative decoding is only supported for CPU (--n-gpu-layers 0). Falling back to normal generation."); + AnsiConsole.MarkupLine("[yellow]Warning:[/] Speculative decoding requires pure CPU (-g 0) or full CUDA offload of a dense model. Falling back to normal generation."); } else if (!File.Exists(settings.DraftModelPath)) { @@ -656,13 +661,28 @@ protected override int Execute(CommandContext context, Settings settings, Cancel AnsiConsole.MarkupLine($"[dim]Loading draft model:[/] {settings.DraftModelPath}"); using var draftModel = GgufModel.Open(settings.DraftModelPath); var draftHp = ModelHyperparams.FromGgufMetadata(draftModel.Metadata, draftModel); - using var draftCpuBackend = new CpuBackend(); - using var draftFwd = new ForwardPass(draftModel, draftCpuBackend, draftHp); - AnsiConsole.MarkupLine($"[dim]Draft model: {draftHp.NumLayers}L, {draftHp.EmbeddingDim}d | Lookahead k={settings.SpecLookahead}[/]"); - - if (settings.Prompt is not null) - return RunSpeculativeSinglePrompt(settings, fwd!, draftFwd, tokenizer, sp); - return RunSpeculativeInteractive(settings, fwd!, draftFwd, tokenizer, sp); + if (cudaSpecTarget) + { + // The draft gets its OWN CudaBackend: graph capture state is one + // exec graph per backend instance, so sharing the target's backend + // would have the draft's decode graph clobber the target's. + using var draftCuda = CudaBackend.Create(); + using var draftFwd = new CudaForwardPass(draftModel, draftCuda, draftHp, ctxSize); + AnsiConsole.MarkupLine($"[dim]Draft model: {draftHp.NumLayers}L, {draftHp.EmbeddingDim}d ([green]CUDA[/]) | Lookahead k={settings.SpecLookahead}[/]"); + var target = (CudaForwardPass)gpuFwd!; + if (settings.Prompt is not null) + return RunSpeculativeSinglePrompt(settings, target, draftFwd, tokenizer, sp); + return RunSpeculativeInteractive(settings, target, draftFwd, tokenizer, sp); + } + else + { + using var draftCpuBackend = new CpuBackend(); + using var draftFwd = new ForwardPass(draftModel, draftCpuBackend, draftHp); + AnsiConsole.MarkupLine($"[dim]Draft model: {draftHp.NumLayers}L, {draftHp.EmbeddingDim}d ([blue]CPU[/]) | Lookahead k={settings.SpecLookahead}[/]"); + if (settings.Prompt is not null) + return RunSpeculativeSinglePrompt(settings, fwd!, draftFwd, tokenizer, sp); + return RunSpeculativeInteractive(settings, fwd!, draftFwd, tokenizer, sp); + } } catch (Exception ex) { @@ -700,7 +720,7 @@ protected override int Execute(CommandContext context, Settings settings, Cancel } private static int RunSpeculativeSinglePrompt(Settings s, - ForwardPass target, ForwardPass draft, + IForwardPass target, IForwardPass draft, GgufTokenizer tok, SamplingParams sp) { var prompt = FormatPrompt(s.Prompt!, s.SystemPrompt, enableThinking: !s_noThinking); @@ -710,14 +730,10 @@ private static int RunSpeculativeSinglePrompt(Settings s, Console.Write(s.Prompt); var sw = Stopwatch.StartNew(); - // Prefill both models with the same prompt - ReadOnlySpan targetLogits = default; - ReadOnlySpan draftLogits = default; - for (int i = 0; i < tokens.Count; i++) - { - targetLogits = target.Forward(tokens[i], i); - draftLogits = draft.Forward(tokens[i], i); - } + // Prefill both models with the same prompt (batched-trunk path on both — the + // per-token Forward loop this replaces was ~30× slower on the CUDA target). + ReadOnlySpan targetLogits = target.Prefill(tokens); + ReadOnlySpan draftLogits = draft.Prefill(tokens); var prefillMs = sw.Elapsed.TotalMilliseconds; var spec = new SpeculativeDecoder(target, draft, s.SpecLookahead); @@ -743,12 +759,13 @@ private static int RunSpeculativeSinglePrompt(Settings s, AnsiConsole.MarkupLine($"\n[dim]Prefill: {tokens.Count} tokens, {tokens.Count / (prefillMs / 1000):F1} t/s | " + $"Decode: {totalDecoded} tokens, {totalDecoded / (decodeMs / 1000):F1} t/s" + (totalDecoded > generated ? $" ({generated} visible, {totalDecoded - generated} thinking)" : "") + - $" | Acceptance rate: {spec.AcceptanceRate:P0}[/]"); + $" | Acceptance rate: {spec.AcceptanceRate:P0} | " + + $"draft {spec.DraftMs:F0}ms / verify {spec.VerifyMs:F0}ms / commit {spec.CommitMs:F0}ms[/]"); return 0; } private static int RunSpeculativeInteractive(Settings s, - ForwardPass target, ForwardPass draft, + IForwardPass target, IForwardPass draft, GgufTokenizer tok, SamplingParams sp) { AnsiConsole.MarkupLine("[green]Interactive chat (speculative decoding).[/] Type a message, or [yellow]/exit[/] to quit.\n"); @@ -764,17 +781,12 @@ private static int RunSpeculativeInteractive(Settings s, var prompt = FormatPrompt(input, s.SystemPrompt, enableThinking: !s_noThinking); var tokens = tok.Encode(prompt); - target.Cache.Reset(); - draft.Cache.Reset(); + target.ResetCache(); + draft.ResetCache(); - ReadOnlySpan targetLogits = default; - ReadOnlySpan draftLogits = default; var sw = Stopwatch.StartNew(); - for (int i = 0; i < tokens.Count; i++) - { - targetLogits = target.Forward(tokens[i], i); - draftLogits = draft.Forward(tokens[i], i); - } + ReadOnlySpan targetLogits = target.Prefill(tokens); + ReadOnlySpan draftLogits = draft.Prefill(tokens); spec.Initialize(tokens.Count, targetLogits, draftLogits); diff --git a/src/SharpInference.Core/IForwardPass.cs b/src/SharpInference.Core/IForwardPass.cs index 1d20275..778bb92 100644 --- a/src/SharpInference.Core/IForwardPass.cs +++ b/src/SharpInference.Core/IForwardPass.cs @@ -152,13 +152,32 @@ ReadOnlySpan MtpForward(int token, int position, ReadOnlySpan prev void PrefillMtp(IReadOnlyList tokens, int startPos = 0) { } /// - /// True when this pass implements a batched two-token verify path (issue #30). - /// Callers () dispatch to on - /// the hybrid GDN passes where it pays off; everything else stays on the - /// sequential N=1 algorithm. + /// True when this pass implements a batched verify path. Two consumers, two methods: + /// the MTP decoder dispatches to (two-token self-speculative + /// verify, issue #30 — implemented by the hybrid GDN passes), and the speculative + /// decoder dispatches to (k-token draft verification, + /// issue #207 — implemented by the rewindable dense passes). The consumers' own gates + /// keep the two method sets disjoint: the MTP decoder requires + /// (GDN hybrids only), the speculative decoder requires + /// (which GDN hybrids never report). A pass that + /// returns true here must implement whichever method its reachable consumer calls. /// bool SupportsBatchVerify => false; + /// + /// Batched verification for speculative decoding (issue #207): process + /// as one packed pass over the current sequence starting at + /// (the cache must hold exactly + /// positions), returning result[i] = logits after tokens[i]. All k K/V + /// entries are appended to the cache; the caller rewinds rejected tokens via + /// . Amortizes the weight reads k× vs sequential + /// calls on memory-bound decode paths. + /// + float[][] BatchVerify(int[] tokens, int startPos) => + throw new NotSupportedException( + $"{GetType().Name} does not implement BatchVerify. " + + "Check SupportsBatchVerify before calling."); + /// /// Last completed 's token-1 pre-output-norm hidden. /// Used by the MTP commit step on the batched verify path. Empty when no batched diff --git a/src/SharpInference.Engine/CudaForwardPass.cs b/src/SharpInference.Engine/CudaForwardPass.cs index 3995b1c..9599803 100644 --- a/src/SharpInference.Engine/CudaForwardPass.cs +++ b/src/SharpInference.Engine/CudaForwardPass.cs @@ -396,6 +396,12 @@ private void AttentionSwaKv(Tensor q, Tensor kCache, Tensor vCache, Tensor outpu private Tensor[][]? _raggedKLayers; private Tensor[][]? _raggedVLayers; + // Issue #207: non-owning per-sequence-cache view over the OWNED K/V tensors, so the + // single-user speculative-decode BatchVerify can drive BatchForwardMulti's packed trunk + // against the owned cache. Never disposed (the owned tensors are freed by Dispose); + // every layer is marked aliased so even an accidental Dispose can't free them. + private CudaSequenceKvCache? _ownedCacheView; + // Ragged attention spill scratch [N × numHeads × maxSeqLen] — the ragged kernel // spills per-(sequence, head) score rows when a sequence's length exceeds the // 4096-slot shared-memory fast path. Lazily allocated only when such a length @@ -3437,6 +3443,81 @@ internal float[][] BatchForwardMulti(int[] tokens, int[] positions, CudaSequence } } + // ── Speculative-decode batched verify (issue #207) ────────────────────────────────── + + /// + /// Whether can run: the dense continuous-batching-capable + /// configuration ( — non-MoE, non-Gemma-4, no + /// TurboQuant, no active SnapKV budget, no final-logit softcap, GEMM-N-batchable weights) + /// with an uncompacted cache. The eviction term is redundant today (eviction needs an + /// active budget, which is already excluded) but mirrors the GDN passes' gate so a future + /// relaxation of the budget exclusion can't silently re-enable verify on a compacted + /// cache (physical slot != position). + /// + public bool SupportsBatchVerify => _kvEvictedCount == 0 && SupportsContinuousBatching; + + /// + /// Batched k-token verify for single-user speculative decoding (issue #207): one packed + /// pass over the OWNED cache at contiguous positions [, + /// + k), returning result[i] = logits after + /// tokens[i]. Reuses 's trunk with every row bound + /// to the same cache: the ragged kernels append all k K/V rows before any row attends, + /// and row i attends over [0, startPos+i] — i.e. packed causal attention (the legacy + /// per-sequence fallback loop appends-then-attends in ascending row order, equally + /// causal). Every matmul routes through , so the #194 + /// weight-stationary kernels (or the opt-in #201 decode MMQ) amortize the weight HBM + /// reads k×. Each row keeps the per-token kernels' reduction chains (#194/#197), so the + /// default WS path is expected bit-identical to k sequential calls; + /// the opt-in compute-bound/MMQ toggles are argmax-stable only. All k K/V entries land in + /// the cache; the caller rewinds rejected tokens via . Issues + /// direct launches only — the per-token decode CUDA graph (owned-cache pointers) stays + /// valid for the surrounding Forward steps. + /// + public float[][] BatchVerify(int[] tokens, int startPos) + { + ArgumentNullException.ThrowIfNull(tokens); + if (!SupportsBatchVerify) + throw new NotSupportedException( + "BatchVerify requires the dense batching-capable configuration (no MoE / " + + "Gemma-4 / TurboQuant / SnapKV / softcap, GEMM-N-batchable weights) and an " + + "uncompacted cache. Check SupportsBatchVerify before calling."); + int k = tokens.Length; + if (k == 0) return Array.Empty(); + if (startPos < 0 || startPos + k > _maxSeqLen) + throw new ArgumentOutOfRangeException(nameof(startPos), + $"BatchVerify range [{startPos}, {startPos + k}) exceeds the context window (maxSeqLen={_maxSeqLen})."); + + if (k == 1) + { + // A single token amortizes nothing — the per-token Forward (CUDA-graph + // replayable) is strictly better. Mirrors the CPU BatchVerify fallback. + var logits = Forward(tokens[0], startPos); + var seq = new float[1][]; + seq[0] = new float[_hp.VocabSize]; + logits.CopyTo(seq[0]); + return seq; + } + + if (_ownedCacheView is null) + { + var all = new HashSet(); + for (int l = 0; l < _hp.NumLayers; l++) all.Add(l); + _ownedCacheView = new CudaSequenceKvCache(_gpu, _ownedKCache, _ownedVCache, all); + } + _ownedCacheView.Length = startPos; + + var positions = new int[k]; + for (int i = 0; i < k; i++) positions[i] = startPos + i; + var caches = new CudaSequenceKvCache[k]; + Array.Fill(caches, _ownedCacheView); + + var result = BatchForwardMulti(tokens, positions, caches); + // Mirror what k sequential Forward calls would leave behind; the speculative + // decoder's TruncateTo(startPos + accepted) then rewinds the rejected tail. + _kvLength = Math.Max(_kvLength, startPos + k); + return result; + } + /// (Re)allocate the batched-decode logits buffer [ × vocab] /// and its host download buffer when the decode batch size changes. private void EnsureDecodeLogits(int n) diff --git a/src/SharpInference.Engine/ForwardPass.cs b/src/SharpInference.Engine/ForwardPass.cs index 8975ed5..5d7a6c0 100644 --- a/src/SharpInference.Engine/ForwardPass.cs +++ b/src/SharpInference.Engine/ForwardPass.cs @@ -1161,6 +1161,15 @@ private ReadOnlySpan PrefillCoreTq(IReadOnlyList tokens, int startPo } } + /// + /// Whether can run (issue #207): everything except the two + /// configurations it throws for — the TurboQuant KV cache (compressed ring can't take + /// the batched appends) and gemma4-style per-layer head_dim (not wired into the batched + /// trunk). MoE stays true: itself falls back to + /// sequential calls for MoE, which is still correct. + /// + public bool SupportsBatchVerify => _tqKvCache is null && _layerHeadDim is null; + /// /// Batched verification for speculative decoding: processes starting /// at using the existing KV cache as context. diff --git a/src/SharpInference.Engine/SpeculativeDecoder.cs b/src/SharpInference.Engine/SpeculativeDecoder.cs index 9dc64dd..30cd992 100644 --- a/src/SharpInference.Engine/SpeculativeDecoder.cs +++ b/src/SharpInference.Engine/SpeculativeDecoder.cs @@ -19,6 +19,7 @@ public sealed class SpeculativeDecoder { private readonly IForwardPass _target; private readonly IForwardPass _draft; + private readonly bool _batchVerify; private int _lookahead; // Generation state @@ -30,6 +31,10 @@ public sealed class SpeculativeDecoder private long _totalAccepted; private long _totalEmitted; + // Phase timing (issue #207 bench reporting): cumulative wall time spent drafting, + // batch-verifying, and committing (truncate + correction forwards) across all steps. + private readonly System.Diagnostics.Stopwatch _phaseSw = new(); + public SpeculativeDecoder(IForwardPass target, IForwardPass draft, int lookahead = 4) { if (target.VocabSize != draft.VocabSize) @@ -54,6 +59,12 @@ public SpeculativeDecoder(IForwardPass target, IForwardPass draft, int lookahead nameof(draft)); _target = target; _draft = draft; + // Kill-switch (issue #207): SHARPI_SPEC_BATCH_VERIFY=0 forces the sequential + // verify fallback even when the target implements BatchVerify. Read once at + // construction (same pattern as the forward passes' decode toggles); the + // capability itself is re-checked per step — it can flip after construction + // (e.g. ForwardPass.EnableTurboQuant). + _batchVerify = Environment.GetEnvironmentVariable("SHARPI_SPEC_BATCH_VERIFY") != "0"; _lookahead = Math.Max(1, lookahead); _savedTargetLogits = new float[target.VocabSize]; _savedDraftLogits = new float[draft.VocabSize]; @@ -69,6 +80,15 @@ public int Lookahead /// Running acceptance rate (accepted tokens / total emitted tokens). public float AcceptanceRate => _totalEmitted > 0 ? (float)_totalAccepted / _totalEmitted : 0f; + /// Cumulative milliseconds spent in the draft phase (k−1 draft forwards per step). + public double DraftMs { get; private set; } + + /// Cumulative milliseconds spent batch-verifying with the target. + public double VerifyMs { get; private set; } + + /// Cumulative milliseconds spent in cache truncation + correction-token forwards. + public double CommitMs { get; private set; } + /// /// Initialize the decoder after both models have processed the prompt. /// Call after prefilling both target and draft with the same prompt tokens. @@ -83,6 +103,9 @@ public void Initialize(int prefillLength, ReadOnlySpan targetLogits, Read draftLogits.CopyTo(_savedDraftLogits); _totalAccepted = 0; _totalEmitted = 0; + DraftMs = 0; + VerifyMs = 0; + CommitMs = 0; } /// @@ -123,6 +146,7 @@ private int[] Step(int k) // ── Draft phase ────────────────────────────────────────────────────────── // d[0] is free: argmax of saved draft logits (no forward pass needed). // d[1..k-1] require k-1 draft Forward calls, appending d[0..k-2] to draft cache. + _phaseSw.Restart(); var draftTokens = new int[k]; var draftLogitsPerPos = new float[k][]; @@ -137,12 +161,15 @@ private int[] Step(int k) draftTokens[i] = ArgMax(draftLogitsPerPos[i]); } // Draft cache is now at P + k - 1 (appended d[0..k-2]). + DraftMs += _phaseSw.Elapsed.TotalMilliseconds; // ── Target batch-verify ────────────────────────────────────────────────── // Process d[0..k-1] in one batched forward pass. // targetLogitsFromBatch[i] = P_target(·|ctx + d[0..i]) (logits AFTER d[i]). // After this call, target cache is at P + k. + _phaseSw.Restart(); float[][] targetLogitsFromBatch = BatchVerifyTarget(draftTokens, P); + VerifyMs += _phaseSw.Elapsed.TotalMilliseconds; // targetLogits[0] = saved (before d[0]) // targetLogits[i+1] = after d[i] (from batch) @@ -172,6 +199,7 @@ private int[] Step(int k) // ── Truncate caches to accepted position ───────────────────────────────── // Target is at P+k; truncate to P+accepted. + _phaseSw.Restart(); _target.TruncateTo(P + accepted); // Draft is at P+k-1; truncate to P+accepted. @@ -191,6 +219,7 @@ private int[] Step(int k) int commitPos = accepted == k ? P + k : P + accepted; var newTargetLogits = _target.Forward(correction, commitPos); var newDraftLogits = _draft.Forward(correction, commitPos); + CommitMs += _phaseSw.Elapsed.TotalMilliseconds; // ── Update state ───────────────────────────────────────────────────────── _nextPos = commitPos + 1; @@ -205,14 +234,16 @@ private int[] Step(int k) } /// - /// Batch-verify draft tokens with the target model. - /// If target is a ForwardPass, uses its batched BatchVerify for efficiency. - /// Falls back to sequential Forward calls otherwise. + /// Batch-verify draft tokens with the target model. Targets that report + /// (CPU , dense + /// CudaForwardPass — issue #207) take the packed k-token pass, which amortizes + /// the weight reads k× on memory-bound decode paths. Everything else (and + /// SHARPI_SPEC_BATCH_VERIFY=0) falls back to k sequential Forward calls. /// private float[][] BatchVerifyTarget(int[] draftTokens, int startPos) { - if (_target is ForwardPass cpuTarget) - return cpuTarget.BatchVerify(draftTokens, startPos); + if (_batchVerify && _target.SupportsBatchVerify) + return _target.BatchVerify(draftTokens, startPos); // Generic fallback: sequential Forward calls var result = new float[draftTokens.Length][]; diff --git a/tests/SharpInference.Tests.ForwardPass/CudaSpecBatchVerifyTests.cs b/tests/SharpInference.Tests.ForwardPass/CudaSpecBatchVerifyTests.cs new file mode 100644 index 0000000..b4e8b93 --- /dev/null +++ b/tests/SharpInference.Tests.ForwardPass/CudaSpecBatchVerifyTests.cs @@ -0,0 +1,266 @@ +using SharpInference.Core; +using SharpInference.Cpu; +using SharpInference.Engine; +using SharpInference.Cuda; + +namespace SharpInference.Tests.ForwardPass; + +/// +/// Issue #207: single-user speculative decoding on the dense CUDA path. +/// runs one packed k-token pass over the OWNED +/// cache at contiguous positions [P, P+k) — 's +/// trunk with every row bound to the same cache — so the weight HBM reads are amortized k× +/// vs the k sequential calls it replaces. +/// +/// Correctness contract (chunked-prefill class): argmax-stable vs sequential Forward at every +/// verified position, asserted with the maxAbs/top-5 tolerances of +/// . The default WS path keeps the per-token kernels' +/// reduction chains (#194/#197), so the e2e greedy spec output is expected to EXACTLY match +/// the non-spec greedy baseline. +/// +/// One ~5 GB Qwen3-8B instance per test; the sequential reference runs first and BatchVerify +/// follows after a soft rewind — deliberately the +/// production flow (verify overwrites the stale rewound K/V slots). Silent-skips when CUDA +/// or the GGUF is absent — mirrors . +/// +public sealed class CudaSpecBatchVerifyTests +{ + private const string TargetModelFile = "Qwen3-8B-Q4_K_M.gguf"; + private const string DraftModelFile = "Qwen3-0.6B-Q8_0.gguf"; + + private static readonly int[] Prompt = { 9707, 11, 1879, 0, 358, 1079, 264, 4108, 1614, 13 }; + + private static CudaBackend? TryCreate() + { + if (!CudaBackend.IsAvailable()) return null; + try { return CudaBackend.Create(); } + catch { return null; } + } + + // SnapKV pinned off: BatchVerify is unsupported under an active SnapKV budget, and + // VRAM-scaled auto-SnapKV could otherwise engage on a smaller GPU and flip + // SupportsBatchVerify to false (same pinning as CudaBatchForwardMultiTests.NewFwd). + private static CudaForwardPass NewFwd(GgufModel model, CudaBackend gpu, ModelHyperparams hp, int ctx = 512) + { + var prev = Environment.GetEnvironmentVariable("SHARPI_SNAPKV_BUDGET"); + Environment.SetEnvironmentVariable("SHARPI_SNAPKV_BUDGET", "0"); + try { return new CudaForwardPass(model, gpu, hp, maxContextLength: ctx); } + finally { Environment.SetEnvironmentVariable("SHARPI_SNAPKV_BUDGET", prev); } + } + + private static string? FindModelPath(string file) + { + string[] absolute = { $@"E:\models\{file}", $@"C:\p\sharpi\models\{file}" }; + foreach (var p in absolute) + if (File.Exists(p)) return p; + var dir = Directory.GetCurrentDirectory(); + for (int i = 0; i < 8; i++) + { + var p = Path.Combine(dir, "models", file); + if (File.Exists(p)) return p; + var parent = Directory.GetParent(dir); + if (parent is null) break; + dir = parent.FullName; + } + return null; + } + + private static int Argmax(ReadOnlySpan logits) + { + int best = 0; + float bestVal = logits[0]; + for (int i = 1; i < logits.Length; i++) + if (logits[i] > bestVal) { bestVal = logits[i]; best = i; } + return best; + } + + private static HashSet TopKSet(ReadOnlySpan logits, int k) + { + var idx = new int[logits.Length]; + for (int i = 0; i < idx.Length; i++) idx[i] = i; + var arr = logits.ToArray(); + Array.Sort(idx, (a, b) => arr[b].CompareTo(arr[a])); + var set = new HashSet(); + for (int i = 0; i < k && i < idx.Length; i++) set.Add(idx[i]); + return set; + } + + private static (float maxAbs, int overlap) Compare(float[] reference, float[] candidate) + { + Assert.Equal(reference.Length, candidate.Length); + float maxAbs = 0f; + for (int i = 0; i < reference.Length; i++) + maxAbs = MathF.Max(maxAbs, MathF.Abs(reference[i] - candidate[i])); + var refTop = TopKSet(reference, 5); + var candTop = TopKSet(candidate, 5); + int overlap = 0; + foreach (var t in candTop) if (refTop.Contains(t)) overlap++; + return (maxAbs, overlap); + } + + /// + /// Headline pass-level oracle: BatchVerify's per-position logits for k packed tokens + /// must reproduce k sequential Forward calls at every position (argmax equal + + /// maxAbs/top-5 within the cross-path tolerance). Run at k=4 and k=6 — 6 is not a + /// capacity-stamped WS kernel size (2/4/8/16), so it also exercises the pad-to-capacity + /// dispatch. + /// + [Theory] + [InlineData(4)] + [InlineData(6)] + public void Qwen3_8B_BatchVerify_MatchesSequentialForward(int k) + { + using var gpu = TryCreate(); + if (gpu is null) return; + var path = FindModelPath(TargetModelFile); + if (path is null) return; + + using var model = GgufModel.Open(path); + var hp = ModelHyperparams.FromGgufMetadata(model.Metadata, model); + Assert.Null(hp.LayerHeadDim); + Assert.False(hp.IsMoE); + + using var fwd = NewFwd(model, gpu, hp); + Assert.True(fwd.SupportsBatchVerify, + "Dense Qwen3-8B Q4_K_M must report SupportsBatchVerify on the CUDA path."); + + fwd.ResetCache(); + var prefillLogits = fwd.Prefill(Prompt); + int P = Prompt.Length; + + // Greedy-chain k tokens so the verified positions carry realistic activations. + var tokens = new int[k]; + tokens[0] = Argmax(prefillLogits); + + // Sequential reference: k Forward calls, capturing logits at every position. + var reference = new float[k][]; + for (int i = 0; i < k; i++) + { + var logits = fwd.Forward(tokens[i], P + i); + reference[i] = logits.ToArray(); + if (i + 1 < k) tokens[i + 1] = Argmax(logits); + } + + // Rewind (soft — stale K/V stays and must be overwritten) and batch-verify. + fwd.TruncateTo(P); + float[][] batch = fwd.BatchVerify(tokens, P); + + Assert.Equal(k, batch.Length); + for (int i = 0; i < k; i++) + { + var (maxAbs, overlap) = Compare(reference[i], batch[i]); + Assert.Equal(Argmax(reference[i]), Argmax(batch[i])); + Assert.True(overlap >= 4, + $"Position {i}: batched top-5 overlaps the sequential reference in only {overlap}/5 slots (maxAbs={maxAbs})."); + Assert.True(maxAbs < 1.0f, + $"Position {i}: batched vs sequential logits diverged beyond tolerance: maxAbs={maxAbs}."); + } + } + + /// + /// Rollback oracle — the full speculative step shape: BatchVerify k tokens (some + /// deliberately wrong), TruncateTo(P+accepted), then Forward the correction at + /// P+accepted. The post-rollback logits must match the sequential trajectory that + /// never saw the rejected tokens — catches stale-KV leaks past the truncation point + /// (the rejected rows' K/V stays in the cache and must be masked by seqLen and + /// overwritten by the commit). + /// + [Fact] + public void Qwen3_8B_BatchVerify_TruncateAndCommit_MatchesSequential() + { + using var gpu = TryCreate(); + if (gpu is null) return; + var path = FindModelPath(TargetModelFile); + if (path is null) return; + + using var model = GgufModel.Open(path); + var hp = ModelHyperparams.FromGgufMetadata(model.Metadata, model); + + using var fwd = NewFwd(model, gpu, hp); + Assert.True(fwd.SupportsBatchVerify); + + fwd.ResetCache(); + var prefillLogits = fwd.Prefill(Prompt); + int P = Prompt.Length; + int t0 = Argmax(prefillLogits); + + // Sequential reference trajectory: accept t0, then the correction t1. + int t1 = Argmax(fwd.Forward(t0, P)); + float[] reference = fwd.Forward(t1, P + 1).ToArray(); + + // Spec-step shape: rewind to P, verify [t0, junk, junk, junk] (junk = off-chain + // tokens that will be rejected), accept only t0, commit t1. + fwd.TruncateTo(P); + int junk = (t0 + 7919) % hp.VocabSize; + float[][] batch = fwd.BatchVerify([t0, junk, junk, junk], P); + Assert.Equal(t1, Argmax(batch[0])); // verify logits after t0 must still pick t1 + + fwd.TruncateTo(P + 1); + float[] committed = fwd.Forward(t1, P + 1).ToArray(); + + var (maxAbs, overlap) = Compare(reference, committed); + Assert.Equal(Argmax(reference), Argmax(committed)); + Assert.True(overlap >= 4, + $"Post-rollback commit top-5 overlap {overlap}/5 (maxAbs={maxAbs})."); + Assert.True(maxAbs < 1.0f, + $"Post-rollback commit diverged from the sequential trajectory: maxAbs={maxAbs}."); + } + + /// + /// E2E greedy parity: SpeculativeDecoder with a CUDA Qwen3-8B target and a CPU + /// Qwen3-0.6B draft must emit EXACTLY the target's own non-spec greedy continuation — + /// the spec-decode invariant (the draft only proposes; every emitted token is argmax of + /// target logits). The default WS verify keeps per-token reduction chains, so unlike the + /// pass-level tolerance oracles this asserts exact token equality over 48 tokens; a + /// mismatch means a real verify/rollback bug (or an FP-borderline token — investigate + /// before weakening, per the SnapKV-parity precedent). + /// + [Fact] + public void Qwen3_8B_SpecDecode_GreedyParity_E2E() + { + using var gpu = TryCreate(); + if (gpu is null) return; + var targetPath = FindModelPath(TargetModelFile); + var draftPath = FindModelPath(DraftModelFile); + if (targetPath is null || draftPath is null) return; + + const int DecodeTokens = 48; + + using var targetModel = GgufModel.Open(targetPath); + var targetHp = ModelHyperparams.FromGgufMetadata(targetModel.Metadata, targetModel); + using var target = NewFwd(targetModel, gpu, targetHp); + Assert.True(target.SupportsBatchVerify); + + // Non-spec greedy baseline on the target alone. + target.ResetCache(); + var logits = target.Prefill(Prompt); + int P = Prompt.Length; + var baseline = new List(); + int tok = Argmax(logits); + for (int i = 0; i < DecodeTokens; i++) + { + baseline.Add(tok); + logits = target.Forward(tok, P + i); + tok = Argmax(logits); + } + + // Spec decode with the 0.6B CPU draft (same Qwen3 tokenizer/vocab). + using var draftModel = GgufModel.Open(draftPath); + var draftHp = ModelHyperparams.FromGgufMetadata(draftModel.Metadata, draftModel); + Assert.Equal(targetHp.VocabSize, draftHp.VocabSize); + using var cpu = new CpuBackend(); + using var draft = new SharpInference.Engine.ForwardPass(draftModel, cpu, draftHp); + + target.ResetCache(); + var targetLogits = target.Prefill(Prompt).ToArray(); + var draftLogits = draft.Prefill(Prompt).ToArray(); + + var spec = new SpeculativeDecoder(target, draft, lookahead: 4); + spec.Initialize(P, targetLogits, draftLogits); + + var emitted = new List(); + spec.Decode(DecodeTokens, [], emitted.Add); + + Assert.Equal(baseline, emitted); + } +} diff --git a/tests/SharpInference.Tests.ForwardPass/SpeculativeDecoderTests.cs b/tests/SharpInference.Tests.ForwardPass/SpeculativeDecoderTests.cs index 16a6616..c8f0d47 100644 --- a/tests/SharpInference.Tests.ForwardPass/SpeculativeDecoderTests.cs +++ b/tests/SharpInference.Tests.ForwardPass/SpeculativeDecoderTests.cs @@ -55,6 +55,144 @@ public void Ctor_VocabSizeMismatch_ThrowsArgumentExceptionRegardlessOfRewindSupp Assert.Contains("vocab", ex.Message, StringComparison.OrdinalIgnoreCase); } + // ── BatchVerify routing (issue #207) ──────────────────────────────────────────── + // + // The decoder dispatches verification through the IForwardPass.SupportsBatchVerify + // capability (replacing the old `is ForwardPass` CPU type-check), with + // SHARPI_SPEC_BATCH_VERIFY=0 as the kill-switch back to sequential Forward calls. + // Scripted chain models (next = token+1 mod vocab) make the greedy accept/reject + // logic fully deterministic so the call counts can be asserted exactly. + + [Fact] + public void Decode_BatchVerifyCapableTarget_RoutesThroughBatchVerify() + { + var target = new ChainForwardPass(vocab: 16, supportsBatchVerify: true); + var draft = new ChainForwardPass(vocab: 16, supportsBatchVerify: false); + + var spec = new SpeculativeDecoder(target, draft, lookahead: 3); + spec.Initialize(prefillLength: 1, ChainForwardPass.Logits(16, next: 2), ChainForwardPass.Logits(16, next: 2)); + + var emitted = new List(); + spec.Decode(6, [], emitted.Add); + + // Draft and target agree everywhere (same chain), so both steps accept fully: + // step 1 (k=3) emits 3+1 tokens, step 2 (k=min(3, remaining=2)=2) emits 2+1. + Assert.Equal(new[] { 2, 3, 4, 5, 6, 7 }, emitted); + Assert.Equal(2, target.BatchVerifyCalls); + // Target Forward runs only for the per-step correction commit, never for verify. + Assert.Equal(2, target.ForwardCalls); + // Full acceptance: 5 accepted out of 7 emitted (corrections always count against + // the rate — k/(k+1) per step is the ceiling). + Assert.Equal(5f / 7f, spec.AcceptanceRate, 3); + } + + [Fact] + public void Decode_KillSwitch_FallsBackToSequentialForward() + { + var prev = Environment.GetEnvironmentVariable("SHARPI_SPEC_BATCH_VERIFY"); + Environment.SetEnvironmentVariable("SHARPI_SPEC_BATCH_VERIFY", "0"); + SpeculativeDecoder spec; + var target = new ChainForwardPass(vocab: 16, supportsBatchVerify: true); + var draft = new ChainForwardPass(vocab: 16, supportsBatchVerify: false); + try + { + // The kill-switch is read once at construction. + spec = new SpeculativeDecoder(target, draft, lookahead: 3); + } + finally + { + Environment.SetEnvironmentVariable("SHARPI_SPEC_BATCH_VERIFY", prev); + } + + spec.Initialize(prefillLength: 1, ChainForwardPass.Logits(16, next: 2), ChainForwardPass.Logits(16, next: 2)); + var emitted = new List(); + spec.Decode(6, [], emitted.Add); + + // Identical emitted sequence, but verification ran as k sequential Forwards plus + // the per-step correction commit: step 1 (k=3) 3+1, step 2 (k=2) 2+1 → 7 total. + Assert.Equal(new[] { 2, 3, 4, 5, 6, 7 }, emitted); + Assert.Equal(0, target.BatchVerifyCalls); + Assert.Equal(7, target.ForwardCalls); + } + + [Fact] + public void Decode_DraftDiverges_RejectionEmitsCorrectionFromBatchLogits() + { + var target = new ChainForwardPass(vocab: 16, supportsBatchVerify: true); + // Draft diverges from the chain at its second proposal of each step. + var draft = new ChainForwardPass(vocab: 16, supportsBatchVerify: false, divergeEvery: 2); + + var spec = new SpeculativeDecoder(target, draft, lookahead: 3); + spec.Initialize(prefillLength: 1, ChainForwardPass.Logits(16, next: 2), ChainForwardPass.Logits(16, next: 2)); + + var emitted = new List(); + spec.Decode(4, [], emitted.Add); + + // The emitted sequence must still be the target's greedy chain regardless of + // where the draft diverged — corrections come from the verify logits. + Assert.Equal(new[] { 2, 3, 4, 5 }, emitted); + Assert.True(spec.AcceptanceRate < 1f); + Assert.True(target.BatchVerifyCalls > 0); + } + + /// + /// Deterministic "chain" model: greedy next token is always (token+1) mod vocab. + /// Tracks Forward/BatchVerify call counts; divergeEvery > 0 makes every + /// Nth Forward propose (token+2) instead, simulating a draft that goes off-chain. + /// + private sealed class ChainForwardPass : IForwardPass + { + private readonly bool _supportsBatchVerify; + private readonly int _divergeEvery; + + public int ForwardCalls; + public int BatchVerifyCalls; + + public ChainForwardPass(int vocab, bool supportsBatchVerify, int divergeEvery = 0) + { + VocabSize = vocab; + _supportsBatchVerify = supportsBatchVerify; + _divergeEvery = divergeEvery; + } + + public int VocabSize { get; } + public int MaxSeqLen => 4096; + public bool SupportsPartialRewind => true; + public bool SupportsBatchVerify => _supportsBatchVerify; + + public static float[] Logits(int vocab, int next) + { + var l = new float[vocab]; + l[next] = 1f; + return l; + } + + public ReadOnlySpan Forward(int token, int position) + { + ForwardCalls++; + int next = (token + 1) % VocabSize; + if (_divergeEvery > 0 && ForwardCalls % _divergeEvery == 0) + next = (token + 2) % VocabSize; + return Logits(VocabSize, next); + } + + public float[][] BatchVerify(int[] tokens, int startPos) + { + if (!_supportsBatchVerify) + throw new NotSupportedException("BatchVerify called on a non-capable mock."); + BatchVerifyCalls++; + var result = new float[tokens.Length][]; + for (int i = 0; i < tokens.Length; i++) + result[i] = Logits(VocabSize, (tokens[i] + 1) % VocabSize); + return result; + } + + public ReadOnlySpan Prefill(IReadOnlyList tokens, int startPos = 0) => new float[VocabSize]; + public void TruncateTo(int length) { } + public void ResetCache() { } + public void Dispose() { } + } + private sealed class MockForwardPass : IForwardPass { public MockForwardPass(int vocabSize, bool supportsPartialRewind) From cdbceb484039d6cf3cd9785f04c5a3f08a1d5e02 Mon Sep 17 00:00:00 2001 From: Pekka Heikura Date: Wed, 10 Jun 2026 22:50:22 +0300 Subject: [PATCH 2/6] perf(engine): #207 fold the commit forward into the verify batch; fix spec VRAM sizing SpeculativeDecoder.Step now packs the CERTAIN next token (argmax of saved target logits) with k-1 draft proposals into ONE batched target pass (the llama.cpp formulation): the batch yields both the verification logits and the next step's saved logits, so the separate 13.2 ms correction-commit Forward disappears - the target runs exactly one batched pass per step. Saved draft logits are no longer decoder state (Initialize keeps the parameter for call-site compat). CudaForwardPass: SupportsBatchVerify gets its own gate - a CONFIGURED SnapKV budget no longer disables verify, only an actual prefill-time eviction does (_kvEvictedCount > 0, the #130 GDN pattern); ThrowIfBatchingUnsupported gains a decodeOnly mode for BatchForwardMulti (decode never evicts; per-seq caches are only obtainable via CreateCache, which still rejects budgets). CLI: cap the CUDA draft's KV ring at min(target.MaxSeqLen, 4096) unless -c is explicit - sizing it from post-target free VRAM oversubscribed the 12 GB card (34K-ctx/7GB ring: decode 75->13 t/s; even a target-matched 12K ring paged the draft weights every step: 34 t/s). Generation is bounded by BOTH windows. Qwen3-8B Q4_K_M + Qwen3-0.6B Q8_0 draft, 4070 Ti, 200-token greedy decode (baseline 74.8 t/s): k=3 96.1, k=4 99.3 (1.33x), k=6 86.1, k=8 77.2; SHARPI_BATCH_DECODE_MMQ=1: k=5 90.4, k=6 96.2, k=8 92.1. Verify cost at k=4 is ~19.5 ms/step (matches the #201 N=4 packed-step reference). Spec output text identical to non-spec baseline. Co-Authored-By: Claude Fable 5 --- src/SharpInference.Cli/RunCommand.cs | 33 +++- src/SharpInference.Engine/CudaForwardPass.cs | 73 +++++---- .../SpeculativeDecoder.cs | 145 +++++++++--------- .../SpeculativeDecoderTests.cs | 19 +-- 4 files changed, 151 insertions(+), 119 deletions(-) diff --git a/src/SharpInference.Cli/RunCommand.cs b/src/SharpInference.Cli/RunCommand.cs index 8580c79..8c28d30 100644 --- a/src/SharpInference.Cli/RunCommand.cs +++ b/src/SharpInference.Cli/RunCommand.cs @@ -98,7 +98,7 @@ public sealed class Settings : CommandSettings [Description("Path to a smaller draft model for speculative decoding (greedy only, requires --temp 0)")] public string? DraftModelPath { get; init; } - [CommandOption("--spec-lookahead")] + [CommandOption("--spec-lookahead|--draft-tokens")] [Description("Number of draft tokens per speculative step with --draft-model (default: 4)")] [DefaultValue(4)] public int SpecLookahead { get; init; } @@ -663,13 +663,26 @@ protected override int Execute(CommandContext context, Settings settings, Cancel var draftHp = ModelHyperparams.FromGgufMetadata(draftModel.Metadata, draftModel); if (cudaSpecTarget) { + var target = (CudaForwardPass)gpuFwd!; // The draft gets its OWN CudaBackend: graph capture state is one // exec graph per backend instance, so sharing the target's backend // would have the draft's decode graph clobber the target's. + // + // Clamp the draft's context: the decoder advances both passes in + // lockstep, so the draft never sees a position past the target's + // window — and unless the user pinned -c explicitly, cap it at 4096 + // (the decode runners bound generation by BOTH windows, so a smaller + // draft ring only caps session length, never indexes out of range). + // Passing 0 would size the draft's KV from the VRAM left AFTER the + // target loaded — measured on the 12 GB 4070 Ti: the 0.6B draft + // grabbed a 34K-ctx / ~7 GB ring next to the 8B target (decode + // 75 → 13 t/s, WDDM paging); even a target-matched 12K fp32 ring + // (~2.8 GB) left so little headroom that the draft's weights paged + // in and out every step (draft forward 2.9 → ~15 ms, decode 34 t/s). + int draftCtx = ctxSize > 0 ? target.MaxSeqLen : Math.Min(target.MaxSeqLen, 4096); using var draftCuda = CudaBackend.Create(); - using var draftFwd = new CudaForwardPass(draftModel, draftCuda, draftHp, ctxSize); + using var draftFwd = new CudaForwardPass(draftModel, draftCuda, draftHp, draftCtx); AnsiConsole.MarkupLine($"[dim]Draft model: {draftHp.NumLayers}L, {draftHp.EmbeddingDim}d ([green]CUDA[/]) | Lookahead k={settings.SpecLookahead}[/]"); - var target = (CudaForwardPass)gpuFwd!; if (settings.Prompt is not null) return RunSpeculativeSinglePrompt(settings, target, draftFwd, tokenizer, sp); return RunSpeculativeInteractive(settings, target, draftFwd, tokenizer, sp); @@ -739,13 +752,20 @@ private static int RunSpeculativeSinglePrompt(Settings s, var spec = new SpeculativeDecoder(target, draft, s.SpecLookahead); spec.Initialize(tokens.Count, targetLogits, draftLogits); + // Bound generation by BOTH context windows (the draft's may be smaller — the CUDA + // spec path caps its KV ring), leaving lookahead headroom for the last spec step. + int maxNew = Math.Min(sp.MaxNewTokens, + Math.Min(target.MaxSeqLen, draft.MaxSeqLen) - tokens.Count - s.SpecLookahead - 1); + if (maxNew < sp.MaxNewTokens) + AnsiConsole.MarkupLine($"[yellow]Note:[/] generation capped at {maxNew} tokens by the context window."); + sw.Restart(); int generated = 0; int totalDecoded = 0; bool inThinking = false; var streamDec = new Utf8StreamDecoder(); bool hideThinking = s.HideThinking; - spec.Decode(sp.MaxNewTokens, sp.StopTokenIds ?? [], token => + spec.Decode(maxNew, sp.StopTokenIds ?? [], token => { if (EmitToken(token, tok, streamDec, ref inThinking, hideThinking)) generated++; totalDecoded++; @@ -790,13 +810,16 @@ private static int RunSpeculativeInteractive(Settings s, spec.Initialize(tokens.Count, targetLogits, draftLogits); + int maxNew = Math.Min(sp.MaxNewTokens, + Math.Min(target.MaxSeqLen, draft.MaxSeqLen) - tokens.Count - s.SpecLookahead - 1); + sw.Restart(); int generated = 0; int totalDecoded = 0; bool inThinking = false; var streamDec = new Utf8StreamDecoder(); bool hideThinking = s.HideThinking; - spec.Decode(sp.MaxNewTokens, sp.StopTokenIds ?? [], token => + spec.Decode(maxNew, sp.StopTokenIds ?? [], token => { if (EmitToken(token, tok, streamDec, ref inThinking, hideThinking)) generated++; totalDecoded++; diff --git a/src/SharpInference.Engine/CudaForwardPass.cs b/src/SharpInference.Engine/CudaForwardPass.cs index 9599803..d7967c5 100644 --- a/src/SharpInference.Engine/CudaForwardPass.cs +++ b/src/SharpInference.Engine/CudaForwardPass.cs @@ -3009,24 +3009,32 @@ private bool DecodeBatchable(Tensor t) => /// decode — keep them out until a batched softcap is wired), and every trunk + output weight /// in a GEMM-N-batchable dtype (excludes Q4_0). /// - public bool SupportsContinuousBatching + public bool SupportsContinuousBatching => _snapKvEffectiveBudget == 0 && DenseBatchedDecodeSupported(); + + /// + /// The arch/dtype gate shared by and + /// : dense (non-MoE, non-Gemma-4), no TurboQuant, no + /// final-logit softcap, every trunk + output weight GEMM-N-batchable. SnapKV terms are + /// applied by the callers — continuous batching excludes any configured budget (prefill + /// into a per-sequence cache could evict), while spec-decode verify only excludes an + /// actually-compacted owned cache (decode never evicts; an unevicted budget keeps + /// slot == position). + /// + private bool DenseBatchedDecodeSupported() { - get + if (_isMoE || _isGemma4Like || _tqEnabled) return false; + if (_hp.FinalLogitSoftcap > 0f) return false; + for (int i = 0; i < _hp.NumLayers; i++) { - if (_isMoE || _isGemma4Like || _tqEnabled || _snapKvEffectiveBudget > 0) return false; - if (_hp.FinalLogitSoftcap > 0f) return false; - for (int i = 0; i < _hp.NumLayers; i++) - { - if (!DecodeBatchable(_wq[i]) || !DecodeBatchable(_wk[i]) || - !DecodeBatchable(_wo[i]) || !DecodeBatchable(_wGate[i]) || - !DecodeBatchable(_wUp[i]) || !DecodeBatchable(_wDown[i])) - return false; - // Dense layers must own a separate V projection (k_eq_v is Gemma-4-only); - // a null _wv would NRE in BatchForwardMulti, so disable batching defensively. - if (_wv[i] is not { } wv || !DecodeBatchable(wv)) return false; - } - return DecodeBatchable(_wOutput); + if (!DecodeBatchable(_wq[i]) || !DecodeBatchable(_wk[i]) || + !DecodeBatchable(_wo[i]) || !DecodeBatchable(_wGate[i]) || + !DecodeBatchable(_wUp[i]) || !DecodeBatchable(_wDown[i])) + return false; + // Dense layers must own a separate V projection (k_eq_v is Gemma-4-only); + // a null _wv would NRE in BatchForwardMulti, so disable batching defensively. + if (_wv[i] is not { } wv || !DecodeBatchable(wv)) return false; } + return DecodeBatchable(_wOutput); } /// @@ -3035,7 +3043,7 @@ public bool SupportsContinuousBatching /// machine (TQ ring, SnapKV eviction), per-layer geometry (Gemma 4), a logit softcap, or a /// weight dtype the GEMM-N matvec can't drive (Q4_0) is out of scope. /// - private void ThrowIfBatchingUnsupported() + private void ThrowIfBatchingUnsupported(bool decodeOnly = false) { if (_isMoE) throw new NotSupportedException( @@ -3046,14 +3054,20 @@ private void ThrowIfBatchingUnsupported() if (_tqEnabled) throw new NotSupportedException( "CUDA continuous batching is not supported with the TurboQuant KV cache."); - if (_snapKvEffectiveBudget > 0) - throw new NotSupportedException( - "CUDA continuous batching is not supported with an active SnapKV budget (eviction runs only on a whole-prompt prefill of the owned cache)."); + // Prefill-capable entry points (CreateCache / PrefillWithCache / PrefillPackedMulti) + // reject any configured SnapKV budget — a whole-prompt prefill into a bound cache + // could evict. Decode-only batching (BatchForwardMulti, incl. the issue-#207 + // BatchVerify wrapper) never evicts, so it only rejects an ALREADY-compacted owned + // cache, where logical position != physical slot. + if (decodeOnly ? _kvEvictedCount > 0 : _snapKvEffectiveBudget > 0) + throw new NotSupportedException(decodeOnly + ? "CUDA batched decode is not supported on a SnapKV-compacted cache (physical slot != logical position)." + : "CUDA continuous batching is not supported with an active SnapKV budget (eviction runs only on a whole-prompt prefill of the owned cache)."); if (_hp.FinalLogitSoftcap > 0f) throw new NotSupportedException( "CUDA continuous batching is not supported with a final-logit softcap (the batched decode finisher does not apply it)."); // Reaching here, only the weight-dtype loop can make it unsupported. - if (!SupportsContinuousBatching) + if (!DenseBatchedDecodeSupported()) throw new NotSupportedException( "CUDA continuous batching requires every trunk + output weight in a GEMM-N-batchable " + "dtype (Q4_K/Q5_K/Q6_K/Q8_0/F32); a Q4_0 weight has no batched-decode matvec kernel."); @@ -3227,7 +3241,7 @@ internal float[][] BatchForwardMulti(int[] tokens, int[] positions, CudaSequence ArgumentNullException.ThrowIfNull(tokens); ArgumentNullException.ThrowIfNull(positions); ArgumentNullException.ThrowIfNull(caches); - ThrowIfBatchingUnsupported(); + ThrowIfBatchingUnsupported(decodeOnly: true); int N = tokens.Length; if (N == 0) return Array.Empty(); if (positions.Length != N || caches.Length != N) @@ -3446,15 +3460,16 @@ internal float[][] BatchForwardMulti(int[] tokens, int[] positions, CudaSequence // ── Speculative-decode batched verify (issue #207) ────────────────────────────────── /// - /// Whether can run: the dense continuous-batching-capable - /// configuration ( — non-MoE, non-Gemma-4, no - /// TurboQuant, no active SnapKV budget, no final-logit softcap, GEMM-N-batchable weights) - /// with an uncompacted cache. The eviction term is redundant today (eviction needs an - /// active budget, which is already excluded) but mirrors the GDN passes' gate so a future - /// relaxation of the budget exclusion can't silently re-enable verify on a compacted - /// cache (physical slot != position). + /// Whether can run: the dense batched-decode configuration + /// ( — non-MoE, non-Gemma-4, no TurboQuant, no + /// final-logit softcap, GEMM-N-batchable weights) with an uncompacted cache. Unlike + /// , a CONFIGURED SnapKV budget does not disable + /// verify — only an actual prefill-time eviction does (then physical slot != logical + /// position and the batched kernels would mis-index). Dynamic: flips false after such a + /// prefill, so the speculative decoder (which re-checks per step) degrades to sequential + /// verify — the same once-evicted gating the GDN passes use (#130). /// - public bool SupportsBatchVerify => _kvEvictedCount == 0 && SupportsContinuousBatching; + public bool SupportsBatchVerify => _kvEvictedCount == 0 && DenseBatchedDecodeSupported(); /// /// Batched k-token verify for single-user speculative decoding (issue #207): one packed diff --git a/src/SharpInference.Engine/SpeculativeDecoder.cs b/src/SharpInference.Engine/SpeculativeDecoder.cs index 30cd992..1aac4a9 100644 --- a/src/SharpInference.Engine/SpeculativeDecoder.cs +++ b/src/SharpInference.Engine/SpeculativeDecoder.cs @@ -3,14 +3,17 @@ namespace SharpInference.Engine; /// -/// Speculative decoding (greedy): a small draft model generates K tokens which the target model -/// verifies via a batched forward pass, accepting each where they agree and generating a -/// correction token where they first diverge. +/// Speculative decoding (greedy): a small draft model proposes tokens which the target model +/// verifies via a batched forward pass, accepting each where they agree and correcting at the +/// first divergence. /// -/// Expected speedup: E[tokens/step] / E[target-forwards/step] where both equal -/// (1-α^(k+1))/(1-α) for acceptance rate α, but target uses batched matmuls (k tokens in one -/// Prefill-style call) reducing memory bandwidth from k×1 to approximately 1×batch. -/// Typical speedup 1.3–2× depending on model size ratio and acceptance rate. +/// Each step packs the CERTAIN next token (argmax of the saved target logits) together with +/// k−1 draft proposals into ONE batched target pass (the llama.cpp formulation): the batch +/// yields both the verification logits and the next step's saved logits, so the target runs +/// exactly one batched pass per step — no separate correction-commit forward. On memory-bound +/// decode paths the batched pass costs ~1–2× a single forward (issue #194/#207 weight +/// amortization), so the speedup is ≈ E[tokens/step] / (cost_batch/cost_forward + draft +/// overhead), with E[tokens/step] = 1 + E[accepted of k−1] for per-token acceptance α. /// /// Both target and draft must share the same tokenizer (same vocab size). /// Note: does NOT take ownership of the forward pass instances. @@ -22,10 +25,13 @@ public sealed class SpeculativeDecoder private readonly bool _batchVerify; private int _lookahead; - // Generation state + // Generation state. Invariant at step boundaries: both caches hold exactly _nextPos + // positions and _savedTargetLogits are the target's logits after the token at + // _nextPos−1 (so argmax(_savedTargetLogits) is the next emitted token, by greedy + // construction). The draft's own last logits are not part of the state — each step's + // first proposal requires forwarding the certain token through the draft anyway. private int _nextPos; private float[] _savedTargetLogits; - private float[] _savedDraftLogits; // Acceptance statistics private long _totalAccepted; @@ -67,7 +73,6 @@ public SpeculativeDecoder(IForwardPass target, IForwardPass draft, int lookahead _batchVerify = Environment.GetEnvironmentVariable("SHARPI_SPEC_BATCH_VERIFY") != "0"; _lookahead = Math.Max(1, lookahead); _savedTargetLogits = new float[target.VocabSize]; - _savedDraftLogits = new float[draft.VocabSize]; } /// Adaptive lookahead: increase/decrease based on recent acceptance rate. @@ -86,7 +91,7 @@ public int Lookahead /// Cumulative milliseconds spent batch-verifying with the target. public double VerifyMs { get; private set; } - /// Cumulative milliseconds spent in cache truncation + correction-token forwards. + /// Cumulative milliseconds spent in cache truncation + draft-sync forwards. public double CommitMs { get; private set; } /// @@ -95,12 +100,14 @@ public int Lookahead /// /// Number of prompt tokens (= new KV cache length). /// Logits from the target's last prefill step (vocab-size span). - /// Logits from the draft's last prefill step (vocab-size span). - public void Initialize(int prefillLength, ReadOnlySpan targetLogits, ReadOnlySpan draftLogits) + /// Accepted for call-site compatibility but no longer consulted: + /// each step's first draft proposal requires forwarding the (certain) next token through + /// the draft anyway, which produces fresher logits than the prefill tail. The draft must + /// still be PREFILLED before calling this — its cache has to hold the prompt. + public void Initialize(int prefillLength, ReadOnlySpan targetLogits, ReadOnlySpan draftLogits = default) { _nextPos = prefillLength; targetLogits.CopyTo(_savedTargetLogits); - draftLogits.CopyTo(_savedDraftLogits); _totalAccepted = 0; _totalEmitted = 0; DraftMs = 0; @@ -134,102 +141,77 @@ public void Decode(int maxTokens, ReadOnlySpan stopTokenIds, Action em } /// - /// Run one speculative step: draft k tokens, batch-verify with target, accept greedily. - /// Returns the emitted token array (accepted_count + 1 tokens, including the correction). - /// Updates internal state (_nextPos, _savedTargetLogits, _savedDraftLogits). + /// Run one speculative step (llama.cpp formulation): pack the certain next token with + /// k−1 draft proposals, batch-verify all k in ONE target pass, accept greedily. + /// Returns the emitted token array (1 + accepted tokens). Updates internal state + /// (_nextPos, _savedTargetLogits). /// private int[] Step(int k) { int P = _nextPos; - int vocabSize = _target.VocabSize; - // ── Draft phase ────────────────────────────────────────────────────────── - // d[0] is free: argmax of saved draft logits (no forward pass needed). - // d[1..k-1] require k-1 draft Forward calls, appending d[0..k-2] to draft cache. - _phaseSw.Restart(); - var draftTokens = new int[k]; - var draftLogitsPerPos = new float[k][]; - - draftLogitsPerPos[0] = _savedDraftLogits; - draftTokens[0] = ArgMax(_savedDraftLogits); + // tokens[0] is CERTAIN (greedy argmax of the saved target logits — it would be + // emitted by plain decode too); tokens[1..k-1] are draft proposals chained from it. + var tokens = new int[k]; + tokens[0] = ArgMax(_savedTargetLogits); + // ── Draft phase: k−1 draft forwards propose tokens[1..k-1] ─────────────── + _phaseSw.Restart(); for (int i = 1; i < k; i++) { - var logits = _draft.Forward(draftTokens[i - 1], P + i - 1); - draftLogitsPerPos[i] = new float[vocabSize]; - logits.CopyTo(draftLogitsPerPos[i]); - draftTokens[i] = ArgMax(draftLogitsPerPos[i]); + var logits = _draft.Forward(tokens[i - 1], P + i - 1); + tokens[i] = ArgMax(logits); } - // Draft cache is now at P + k - 1 (appended d[0..k-2]). + // Draft cache is now at P + k - 1 (appended tokens[0..k-2]). DraftMs += _phaseSw.Elapsed.TotalMilliseconds; // ── Target batch-verify ────────────────────────────────────────────────── - // Process d[0..k-1] in one batched forward pass. - // targetLogitsFromBatch[i] = P_target(·|ctx + d[0..i]) (logits AFTER d[i]). - // After this call, target cache is at P + k. + // One packed pass over tokens[0..k-1] at positions P..P+k-1. + // batch[i] = logits AFTER tokens[i]. Target cache advances to P + k. _phaseSw.Restart(); - float[][] targetLogitsFromBatch = BatchVerifyTarget(draftTokens, P); + float[][] batch = BatchVerifyTarget(tokens, P); VerifyMs += _phaseSw.Elapsed.TotalMilliseconds; - // targetLogits[0] = saved (before d[0]) - // targetLogits[i+1] = after d[i] (from batch) - // We use targetLogits[i] to verify d[i]: accept if argmax == d[i]. - - // ── Greedy accept/reject ───────────────────────────────────────────────── + // ── Greedy accept ──────────────────────────────────────────────────────── + // tokens[i] (i ≥ 1) is accepted iff the target's logits after tokens[i-1] + // pick it. tokens[0] needs no check (it IS the target's pick). int accepted = 0; - for (int i = 0; i < k; i++) + for (int i = 1; i < k; i++) { - float[] tLogits = i == 0 ? _savedTargetLogits : targetLogitsFromBatch[i - 1]; - if (ArgMax(tLogits) == draftTokens[i]) - accepted++; - else - break; + if (ArgMax(batch[i - 1]) == tokens[i]) accepted++; + else break; } - // targetLogits at position `accepted` (logits for deciding correction token): - float[] correctionSourceLogits = accepted == k - ? targetLogitsFromBatch[k - 1] // all accepted: logits after d[k-1] - : (accepted == 0 ? _savedTargetLogits : targetLogitsFromBatch[accepted - 1]); - - int correction = ArgMax(correctionSourceLogits); - - // Update acceptance stats _totalAccepted += accepted; _totalEmitted += accepted + 1; - // ── Truncate caches to accepted position ───────────────────────────────── - // Target is at P+k; truncate to P+accepted. + // ── Roll both caches back to the last emitted token ────────────────────── + // Emitted: tokens[0..accepted] → new length newPos. The batch already holds the + // logits after the last emitted token (batch[accepted]) — they seed the next step, + // so NO separate correction forward is needed on the target. _phaseSw.Restart(); - _target.TruncateTo(P + accepted); + int newPos = P + 1 + accepted; + _target.TruncateTo(newPos); - // Draft is at P+k-1; truncate to P+accepted. - // For all-accepted (accepted == k): need to sync d[k-1] into draft first. - if (accepted == k) + if (accepted == k - 1) { - // Draft phase only appended d[0..k-2]. Sync d[k-1] now. - _draft.Forward(draftTokens[k - 1], P + k - 1); - // Draft cache is now at P+k. No truncation needed before commit. + // Fully accepted: the draft never processed tokens[k-1] (its cache is at + // P+k-1 = newPos-1). Sync it so the next step's draft chain starts at newPos. + _draft.Forward(tokens[k - 1], P + k - 1); } else { - _draft.TruncateTo(P + accepted); + _draft.TruncateTo(newPos); } - - // ── Commit correction to both caches ───────────────────────────────────── - int commitPos = accepted == k ? P + k : P + accepted; - var newTargetLogits = _target.Forward(correction, commitPos); - var newDraftLogits = _draft.Forward(correction, commitPos); CommitMs += _phaseSw.Elapsed.TotalMilliseconds; // ── Update state ───────────────────────────────────────────────────────── - _nextPos = commitPos + 1; - newTargetLogits.CopyTo(_savedTargetLogits); - newDraftLogits.CopyTo(_savedDraftLogits); + _nextPos = newPos; + batch[accepted].CopyTo(_savedTargetLogits, 0); - // ── Build emitted token list: d[0..accepted-1] + correction ────────────── + // ── Emitted token list: the certain token + accepted proposals ─────────── var emitted = new int[accepted + 1]; - for (int i = 0; i < accepted; i++) emitted[i] = draftTokens[i]; - emitted[accepted] = correction; + for (int i = 0; i <= accepted; i++) emitted[i] = tokens[i]; return emitted; } @@ -256,6 +238,17 @@ private float[][] BatchVerifyTarget(int[] draftTokens, int startPos) return result; } + private static int ArgMax(ReadOnlySpan logits) + { + int best = 0; + float bestVal = logits[0]; + for (int i = 1; i < logits.Length; i++) + { + if (logits[i] > bestVal) { bestVal = logits[i]; best = i; } + } + return best; + } + private static int ArgMax(float[] logits) { int best = 0; diff --git a/tests/SharpInference.Tests.ForwardPass/SpeculativeDecoderTests.cs b/tests/SharpInference.Tests.ForwardPass/SpeculativeDecoderTests.cs index c8f0d47..c265aa9 100644 --- a/tests/SharpInference.Tests.ForwardPass/SpeculativeDecoderTests.cs +++ b/tests/SharpInference.Tests.ForwardPass/SpeculativeDecoderTests.cs @@ -76,14 +76,15 @@ public void Decode_BatchVerifyCapableTarget_RoutesThroughBatchVerify() spec.Decode(6, [], emitted.Add); // Draft and target agree everywhere (same chain), so both steps accept fully: - // step 1 (k=3) emits 3+1 tokens, step 2 (k=min(3, remaining=2)=2) emits 2+1. + // each k=3 step packs [certain, d1, d2] into one verify and emits all 3. Assert.Equal(new[] { 2, 3, 4, 5, 6, 7 }, emitted); Assert.Equal(2, target.BatchVerifyCalls); - // Target Forward runs only for the per-step correction commit, never for verify. - Assert.Equal(2, target.ForwardCalls); - // Full acceptance: 5 accepted out of 7 emitted (corrections always count against - // the rate — k/(k+1) per step is the ceiling). - Assert.Equal(5f / 7f, spec.AcceptanceRate, 3); + // The certain token rides in the verify batch, so the target NEVER runs a + // single-token Forward — one batched pass per step is the whole target cost. + Assert.Equal(0, target.ForwardCalls); + // 2 accepted proposals out of 3 emitted per step (the certain token never + // counts as accepted): 4/6. + Assert.Equal(4f / 6f, spec.AcceptanceRate, 3); } [Fact] @@ -108,11 +109,11 @@ public void Decode_KillSwitch_FallsBackToSequentialForward() var emitted = new List(); spec.Decode(6, [], emitted.Add); - // Identical emitted sequence, but verification ran as k sequential Forwards plus - // the per-step correction commit: step 1 (k=3) 3+1, step 2 (k=2) 2+1 → 7 total. + // Identical emitted sequence, but verification ran as k sequential Forwards: + // 2 steps × 3 = 6 total (no batched pass, no separate commit forward either). Assert.Equal(new[] { 2, 3, 4, 5, 6, 7 }, emitted); Assert.Equal(0, target.BatchVerifyCalls); - Assert.Equal(7, target.ForwardCalls); + Assert.Equal(6, target.ForwardCalls); } [Fact] From 22d11c27df8b3d142924f3fd9ad36e8ffc0bb274 Mon Sep 17 00:00:00 2001 From: Pekka Heikura Date: Thu, 11 Jun 2026 09:06:14 +0300 Subject: [PATCH 3/6] feat(engine): #207 prompt-lookup (n-gram) drafting for speculative decoding PromptLookupDraft: model-free draft source (llama.cpp lookup-decoding analog) that proposes continuation tokens by matching the tail n-gram (max..min, most recent occurrence wins, min=2 to suppress junk 1-gram fires) of prompt + generated history. Zero forward-pass cost; no match -> no proposals -> the step degrades to a plain single-token decode, so the floor is ~baseline. SpeculativeDecoder gains a lookup mode (second ctor + prompt-token Initialize overload): the draft phase consults the matcher instead of running k-1 draft forwards, and there is no draft cache to truncate/sync. Verify, accept, and rollback are unchanged - emitted tokens are the target greedy chain regardless of proposal quality. CLI: --draft-lookup (mutually exclusive with --draft-model), same CPU / full- CUDA-offload targets as model-draft speculation. Qwen3-8B Q4_K_M, 4070 Ti, greedy (baseline 74.9 t/s): echo-heavy prompt ("repeat this text three times") k=4/8/16 -> 170.5 / 182.2 / 148.7 t/s (2.4x at k=8, draft 0 ms); adversarial no-copy prompt k=4/8 -> 76.9 / 71.8 t/s (8-9% acceptance, floor holds within ~4% of baseline). Tests: PromptLookupDraft matcher units (longest-n-gram preference, recency, caps, reset), scripted lookup-mode decoder trace (proposal accept + floor path, no target Forwards), CUDA e2e 48-token greedy parity with lookup drafting on Qwen3-8B. Co-Authored-By: Claude Fable 5 --- src/SharpInference.Cli/RunCommand.cs | 79 +++++++++++++---- .../PromptLookupDraft.cs | 83 +++++++++++++++++ .../SpeculativeDecoder.cs | 88 ++++++++++++++++--- .../CudaSpecBatchVerifyTests.cs | 47 ++++++++++ .../PromptLookupDraftTests.cs | 78 ++++++++++++++++ .../SpeculativeDecoderTests.cs | 37 ++++++++ 6 files changed, 384 insertions(+), 28 deletions(-) create mode 100644 src/SharpInference.Engine/PromptLookupDraft.cs create mode 100644 tests/SharpInference.Tests.ForwardPass/PromptLookupDraftTests.cs diff --git a/src/SharpInference.Cli/RunCommand.cs b/src/SharpInference.Cli/RunCommand.cs index 8c28d30..7077f7e 100644 --- a/src/SharpInference.Cli/RunCommand.cs +++ b/src/SharpInference.Cli/RunCommand.cs @@ -103,6 +103,11 @@ public sealed class Settings : CommandSettings [DefaultValue(4)] public int SpecLookahead { get; init; } + [CommandOption("--draft-lookup")] + [Description("Speculative decoding via prompt-lookup (n-gram) drafting — proposes tokens by matching the generated tail against prompt+history; no draft model needed (greedy only, requires --temp 0)")] + [DefaultValue(false)] + public bool DraftLookup { get; init; } + [CommandOption("--spec-type")] [Description("Speculative decoding type: auto (default; enables MTP when supported), none, mtp (alias: draft-mtp). Mirrors llama.cpp.")] [DefaultValue("auto")] @@ -274,7 +279,7 @@ protected override int Execute(CommandContext context, Settings settings, Cancel AnsiConsole.MarkupLine("[red]Error:[/] TurboQuant is not supported for hybrid GDN models (no KV cache on GDN layers)."); return 1; } - if (hp.IsHybridSsm && settings.DraftModelPath is not null) + if (hp.IsHybridSsm && (settings.DraftModelPath is not null || settings.DraftLookup)) { AnsiConsole.MarkupLine("[red]Error:[/] Speculative decoding is not supported for hybrid GDN models (GDN state is destructively updated and cannot be rewound)."); return 1; @@ -638,9 +643,14 @@ protected override int Execute(CommandContext context, Settings settings, Cancel // packed k-token verify via CudaForwardPass.BatchVerify). Vulkan and the partial- // offload hybrids fall back to normal generation: without a batched verify, // speculation costs k sequential target forwards per step and is never a win. - if (settings.DraftModelPath is not null) + if (settings.DraftModelPath is not null || settings.DraftLookup) { bool cudaSpecTarget = gpuFwd is CudaForwardPass { SupportsBatchVerify: true }; + if (settings.DraftModelPath is not null && settings.DraftLookup) + { + AnsiConsole.MarkupLine("[red]Error:[/] --draft-model and --draft-lookup are mutually exclusive."); + return 1; + } if (settings.Temperature > 0f) { AnsiConsole.MarkupLine("[yellow]Warning:[/] Speculative decoding requires greedy sampling (--temp 0). Falling back to normal generation."); @@ -649,6 +659,32 @@ protected override int Execute(CommandContext context, Settings settings, Cancel { AnsiConsole.MarkupLine("[yellow]Warning:[/] Speculative decoding requires pure CPU (-g 0) or full CUDA offload of a dense model. Falling back to normal generation."); } + else if (settings.DraftLookup) + { + // Prompt-lookup drafting (issue #207): no draft model — proposals come from + // n-gram matches against prompt + generated history, verified by the same + // batched-verify step. Floor is ~baseline (no match → plain decode step). + try + { + IForwardPass lookupTarget = cudaSpecTarget ? (CudaForwardPass)gpuFwd! : fwd!; + AnsiConsole.MarkupLine($"[dim]Speculative decoding: prompt-lookup (n-gram) drafting | Lookahead k={settings.SpecLookahead}[/]"); + if (settings.Prompt is not null) + return RunSpeculativeSinglePrompt(settings, lookupTarget, null, tokenizer, sp); + return RunSpeculativeInteractive(settings, lookupTarget, null, tokenizer, sp); + } + catch (Exception ex) + { + Console.Error.WriteLine(ex); + return 1; + } + finally + { + gpuFwd?.Dispose(); + gpuBackend?.Dispose(); + fwd?.Dispose(); + hybridFwd?.Dispose(); + } + } else if (!File.Exists(settings.DraftModelPath)) { AnsiConsole.MarkupLine($"[red]Error:[/] Draft model not found: {settings.DraftModelPath}"); @@ -733,7 +769,7 @@ protected override int Execute(CommandContext context, Settings settings, Cancel } private static int RunSpeculativeSinglePrompt(Settings s, - IForwardPass target, IForwardPass draft, + IForwardPass target, IForwardPass? draft, GgufTokenizer tok, SamplingParams sp) { var prompt = FormatPrompt(s.Prompt!, s.SystemPrompt, enableThinking: !s_noThinking); @@ -743,19 +779,28 @@ private static int RunSpeculativeSinglePrompt(Settings s, Console.Write(s.Prompt); var sw = Stopwatch.StartNew(); - // Prefill both models with the same prompt (batched-trunk path on both — the - // per-token Forward loop this replaces was ~30× slower on the CUDA target). + // Prefill (batched-trunk path — the per-token Forward loop this replaces was + // ~30× slower on the CUDA target). A null draft means prompt-lookup mode. ReadOnlySpan targetLogits = target.Prefill(tokens); - ReadOnlySpan draftLogits = draft.Prefill(tokens); + ReadOnlySpan draftLogits = draft is not null ? draft.Prefill(tokens) : default; var prefillMs = sw.Elapsed.TotalMilliseconds; - var spec = new SpeculativeDecoder(target, draft, s.SpecLookahead); - spec.Initialize(tokens.Count, targetLogits, draftLogits); + SpeculativeDecoder spec; + if (draft is not null) + { + spec = new SpeculativeDecoder(target, draft, s.SpecLookahead); + spec.Initialize(tokens.Count, targetLogits, draftLogits); + } + else + { + spec = new SpeculativeDecoder(target, new PromptLookupDraft(), s.SpecLookahead); + spec.Initialize(tokens, targetLogits); + } // Bound generation by BOTH context windows (the draft's may be smaller — the CUDA // spec path caps its KV ring), leaving lookahead headroom for the last spec step. int maxNew = Math.Min(sp.MaxNewTokens, - Math.Min(target.MaxSeqLen, draft.MaxSeqLen) - tokens.Count - s.SpecLookahead - 1); + Math.Min(target.MaxSeqLen, draft?.MaxSeqLen ?? int.MaxValue) - tokens.Count - s.SpecLookahead - 1); if (maxNew < sp.MaxNewTokens) AnsiConsole.MarkupLine($"[yellow]Note:[/] generation capped at {maxNew} tokens by the context window."); @@ -785,11 +830,13 @@ private static int RunSpeculativeSinglePrompt(Settings s, } private static int RunSpeculativeInteractive(Settings s, - IForwardPass target, IForwardPass draft, + IForwardPass target, IForwardPass? draft, GgufTokenizer tok, SamplingParams sp) { AnsiConsole.MarkupLine("[green]Interactive chat (speculative decoding).[/] Type a message, or [yellow]/exit[/] to quit.\n"); - var spec = new SpeculativeDecoder(target, draft, s.SpecLookahead); + var spec = draft is not null + ? new SpeculativeDecoder(target, draft, s.SpecLookahead) + : new SpeculativeDecoder(target, new PromptLookupDraft(), s.SpecLookahead); while (true) { @@ -802,16 +849,18 @@ private static int RunSpeculativeInteractive(Settings s, var tokens = tok.Encode(prompt); target.ResetCache(); - draft.ResetCache(); + draft?.ResetCache(); var sw = Stopwatch.StartNew(); ReadOnlySpan targetLogits = target.Prefill(tokens); - ReadOnlySpan draftLogits = draft.Prefill(tokens); - spec.Initialize(tokens.Count, targetLogits, draftLogits); + if (draft is not null) + spec.Initialize(tokens.Count, targetLogits, draft.Prefill(tokens)); + else + spec.Initialize(tokens, targetLogits); int maxNew = Math.Min(sp.MaxNewTokens, - Math.Min(target.MaxSeqLen, draft.MaxSeqLen) - tokens.Count - s.SpecLookahead - 1); + Math.Min(target.MaxSeqLen, draft?.MaxSeqLen ?? int.MaxValue) - tokens.Count - s.SpecLookahead - 1); sw.Restart(); int generated = 0; diff --git a/src/SharpInference.Engine/PromptLookupDraft.cs b/src/SharpInference.Engine/PromptLookupDraft.cs new file mode 100644 index 0000000..4bdb1c6 --- /dev/null +++ b/src/SharpInference.Engine/PromptLookupDraft.cs @@ -0,0 +1,83 @@ +namespace SharpInference.Engine; + +/// +/// Model-free draft source for speculative decoding (issue #207, llama.cpp lookup-decoding +/// analog): proposes continuation tokens by matching the tail n-gram of the generated +/// context against the prompt + everything generated so far. Proposals cost no forward +/// passes, so on copy-heavy workloads (RAG quotation, summarization, code edits, +/// self-repetitive output) the speculative step gets its draft for free; when nothing +/// matches it proposes nothing and the step degrades to a plain single-token decode. +/// +/// Matching: longest tail n-gram first ( down to +/// ), most recent occurrence wins (generated text repeats locally). +/// defaults to 2 — 1-gram matches fire constantly and mostly +/// propose junk, which makes every step pay a wider verify batch for nothing. +/// +public sealed class PromptLookupDraft +{ + private readonly List _history = new(); + + /// Largest tail n-gram length tried for a match. + public int NgramMax { get; } + + /// Smallest tail n-gram length tried before giving up (no proposal). + public int NgramMin { get; } + + public PromptLookupDraft(int ngramMax = 3, int ngramMin = 2) + { + if (ngramMin < 1) throw new ArgumentOutOfRangeException(nameof(ngramMin)); + if (ngramMax < ngramMin) throw new ArgumentOutOfRangeException(nameof(ngramMax)); + NgramMax = ngramMax; + NgramMin = ngramMin; + } + + /// Tokens observed so far (prompt + emitted). + public int Count => _history.Count; + + /// Reset the history to a new prompt (start of a generation). + public void Reset(IReadOnlyList promptTokens) + { + _history.Clear(); + if (promptTokens is not null) _history.AddRange(promptTokens); + } + + /// Record an emitted token so future proposals can match it. + public void Append(int token) => _history.Add(token); + + /// + /// Propose up to continuation tokens for the current + /// history. Returns an empty array when no tail n-gram of length + /// [, ] recurs earlier in the history. + /// + public int[] Propose(int maxTokens) + { + var h = _history; + int len = h.Count; + if (maxTokens <= 0) return Array.Empty(); + + for (int n = NgramMax; n >= NgramMin; n--) + { + if (len < n + 1) continue; + + // Most recent earlier occurrence of the tail h[len-n .. len): candidate start + // positions i walk backward; the matched occurrence must end before the tail + // ends (i + n < len) so there is at least one continuation token to copy. + for (int i = len - n - 1; i >= 0; i--) + { + bool match = true; + for (int j = 0; j < n; j++) + { + if (h[i + j] != h[len - n + j]) { match = false; break; } + } + if (!match) continue; + + int start = i + n; + int count = Math.Min(maxTokens, len - start); + var proposal = new int[count]; + for (int t = 0; t < count; t++) proposal[t] = h[start + t]; + return proposal; + } + } + return Array.Empty(); + } +} diff --git a/src/SharpInference.Engine/SpeculativeDecoder.cs b/src/SharpInference.Engine/SpeculativeDecoder.cs index 1aac4a9..7525371 100644 --- a/src/SharpInference.Engine/SpeculativeDecoder.cs +++ b/src/SharpInference.Engine/SpeculativeDecoder.cs @@ -21,7 +21,8 @@ namespace SharpInference.Engine; public sealed class SpeculativeDecoder { private readonly IForwardPass _target; - private readonly IForwardPass _draft; + private readonly IForwardPass? _draft; // model-draft mode + private readonly PromptLookupDraft? _lookup; // prompt-lookup mode (issue #207) private readonly bool _batchVerify; private int _lookahead; @@ -41,6 +42,27 @@ public sealed class SpeculativeDecoder // batch-verifying, and committing (truncate + correction forwards) across all steps. private readonly System.Diagnostics.Stopwatch _phaseSw = new(); + /// + /// Prompt-lookup mode (issue #207): proposals come from + /// (n-gram matches against prompt + generated history) instead of a draft model — + /// zero draft-forward cost, and a step with no match degrades to plain decode. + /// Initialize with the prompt-token overload so the lookup sees the prompt. + /// + public SpeculativeDecoder(IForwardPass target, PromptLookupDraft lookup, int lookahead = 4) + { + ArgumentNullException.ThrowIfNull(lookup); + if (!target.SupportsPartialRewind) + throw new ArgumentException( + $"Speculative decoding requires the target forward pass to support partial rewind; " + + $"{target.GetType().Name} does not.", + nameof(target)); + _target = target; + _lookup = lookup; + _batchVerify = Environment.GetEnvironmentVariable("SHARPI_SPEC_BATCH_VERIFY") != "0"; + _lookahead = Math.Max(1, lookahead); + _savedTargetLogits = new float[target.VocabSize]; + } + public SpeculativeDecoder(IForwardPass target, IForwardPass draft, int lookahead = 4) { if (target.VocabSize != draft.VocabSize) @@ -115,6 +137,23 @@ public void Initialize(int prefillLength, ReadOnlySpan targetLogits, Read CommitMs = 0; } + /// + /// Prompt-lookup-mode initialization: seeds the lookup history with the prompt tokens + /// (proposals match against prompt + generated text). Call after prefilling the target. + /// + /// The prompt as fed to the target's prefill. + /// Logits from the target's last prefill step. + public void Initialize(IReadOnlyList promptTokens, ReadOnlySpan targetLogits) + { + if (_lookup is null) + throw new InvalidOperationException( + "Prompt-token initialization is only valid in prompt-lookup mode; " + + "use Initialize(prefillLength, targetLogits, draftLogits) with a draft model."); + ArgumentNullException.ThrowIfNull(promptTokens); + _lookup.Reset(promptTokens); + Initialize(promptTokens.Count, targetLogits); + } + /// /// Decode up to tokens using greedy speculative decoding, /// invoking for each accepted or correction token. @@ -151,18 +190,35 @@ private int[] Step(int k) int P = _nextPos; // tokens[0] is CERTAIN (greedy argmax of the saved target logits — it would be - // emitted by plain decode too); tokens[1..k-1] are draft proposals chained from it. - var tokens = new int[k]; - tokens[0] = ArgMax(_savedTargetLogits); + // emitted by plain decode too); tokens[1..] are draft proposals chained from it. + int n0 = ArgMax(_savedTargetLogits); - // ── Draft phase: k−1 draft forwards propose tokens[1..k-1] ─────────────── + // ── Draft phase ────────────────────────────────────────────────────────── _phaseSw.Restart(); - for (int i = 1; i < k; i++) + int[] tokens; + if (_lookup is not null) + { + // Prompt-lookup: tokens[0] is certain, so it joins the history before + // matching; proposals are whatever the tail n-gram match yields (possibly + // none — then the step is a plain single-token decode). + _lookup.Append(n0); + int[] proposals = _lookup.Propose(k - 1); + tokens = new int[1 + proposals.Length]; + tokens[0] = n0; + proposals.CopyTo(tokens, 1); + } + else { - var logits = _draft.Forward(tokens[i - 1], P + i - 1); - tokens[i] = ArgMax(logits); + // Model draft: k−1 draft forwards propose tokens[1..k-1]; the draft cache + // advances to P + k - 1 (appended tokens[0..k-2]). + tokens = new int[k]; + tokens[0] = n0; + for (int i = 1; i < k; i++) + { + var logits = _draft!.Forward(tokens[i - 1], P + i - 1); + tokens[i] = ArgMax(logits); + } } - // Draft cache is now at P + k - 1 (appended tokens[0..k-2]). DraftMs += _phaseSw.Elapsed.TotalMilliseconds; // ── Target batch-verify ────────────────────────────────────────────────── @@ -176,7 +232,7 @@ private int[] Step(int k) // tokens[i] (i ≥ 1) is accepted iff the target's logits after tokens[i-1] // pick it. tokens[0] needs no check (it IS the target's pick). int accepted = 0; - for (int i = 1; i < k; i++) + for (int i = 1; i < tokens.Length; i++) { if (ArgMax(batch[i - 1]) == tokens[i]) accepted++; else break; @@ -193,15 +249,21 @@ private int[] Step(int k) int newPos = P + 1 + accepted; _target.TruncateTo(newPos); - if (accepted == k - 1) + if (_lookup is not null) + { + // No draft cache to manage; record the accepted proposals (tokens[0] was + // appended before matching) so future tail n-grams can match them. + for (int i = 1; i <= accepted; i++) _lookup.Append(tokens[i]); + } + else if (accepted == tokens.Length - 1) { // Fully accepted: the draft never processed tokens[k-1] (its cache is at // P+k-1 = newPos-1). Sync it so the next step's draft chain starts at newPos. - _draft.Forward(tokens[k - 1], P + k - 1); + _draft!.Forward(tokens[^1], P + tokens.Length - 1); } else { - _draft.TruncateTo(newPos); + _draft!.TruncateTo(newPos); } CommitMs += _phaseSw.Elapsed.TotalMilliseconds; diff --git a/tests/SharpInference.Tests.ForwardPass/CudaSpecBatchVerifyTests.cs b/tests/SharpInference.Tests.ForwardPass/CudaSpecBatchVerifyTests.cs index b4e8b93..b34e9d2 100644 --- a/tests/SharpInference.Tests.ForwardPass/CudaSpecBatchVerifyTests.cs +++ b/tests/SharpInference.Tests.ForwardPass/CudaSpecBatchVerifyTests.cs @@ -263,4 +263,51 @@ public void Qwen3_8B_SpecDecode_GreedyParity_E2E() Assert.Equal(baseline, emitted); } + + /// + /// E2E greedy parity for prompt-lookup mode (issue #207 goal 3): with the n-gram + /// lookup draft (zero draft forwards, proposal quality irrelevant to correctness), + /// the emitted stream must still EXACTLY equal the target's non-spec greedy + /// continuation — accepted proposals only ever shortcut tokens the target would + /// have picked anyway. + /// + [Fact] + public void Qwen3_8B_SpecDecode_PromptLookup_GreedyParity_E2E() + { + using var gpu = TryCreate(); + if (gpu is null) return; + var targetPath = FindModelPath(TargetModelFile); + if (targetPath is null) return; + + const int DecodeTokens = 48; + + using var targetModel = GgufModel.Open(targetPath); + var targetHp = ModelHyperparams.FromGgufMetadata(targetModel.Metadata, targetModel); + using var target = NewFwd(targetModel, gpu, targetHp); + Assert.True(target.SupportsBatchVerify); + + // Non-spec greedy baseline. + target.ResetCache(); + var logits = target.Prefill(Prompt); + int P = Prompt.Length; + var baseline = new List(); + int tok = Argmax(logits); + for (int i = 0; i < DecodeTokens; i++) + { + baseline.Add(tok); + logits = target.Forward(tok, P + i); + tok = Argmax(logits); + } + + // Prompt-lookup spec decode. + target.ResetCache(); + var targetLogits = target.Prefill(Prompt).ToArray(); + var spec = new SpeculativeDecoder(target, new PromptLookupDraft(), lookahead: 4); + spec.Initialize(Prompt, targetLogits); + + var emitted = new List(); + spec.Decode(DecodeTokens, [], emitted.Add); + + Assert.Equal(baseline, emitted); + } } diff --git a/tests/SharpInference.Tests.ForwardPass/PromptLookupDraftTests.cs b/tests/SharpInference.Tests.ForwardPass/PromptLookupDraftTests.cs new file mode 100644 index 0000000..13eaad6 --- /dev/null +++ b/tests/SharpInference.Tests.ForwardPass/PromptLookupDraftTests.cs @@ -0,0 +1,78 @@ +using SharpInference.Engine; + +namespace SharpInference.Tests.ForwardPass; + +/// +/// Issue #207: prompt-lookup (n-gram) drafting — the model-free draft source for +/// speculative decoding. Pure token-matching logic, no models involved. +/// +public sealed class PromptLookupDraftTests +{ + [Fact] + public void Propose_TailBigramRecurs_ReturnsContinuationOfMostRecentMatch() + { + var d = new PromptLookupDraft(ngramMax: 3, ngramMin: 2); + // Tail [10, 11] matches index 0; continuation is [12, 10, 11]. + d.Reset(new[] { 10, 11, 12, 10, 11 }); + + Assert.Equal(new[] { 12, 10, 11 }, d.Propose(3)); + // maxTokens caps the proposal length. + Assert.Equal(new[] { 12 }, d.Propose(1)); + } + + [Fact] + public void Propose_LongerNgramPreferred() + { + var d = new PromptLookupDraft(ngramMax: 3, ngramMin: 2); + // Tail trigram [7, 8, 9] matches at index 0 (continuation 1); the bigram [8, 9] + // ALSO matches at index 4 (continuation 2). The trigram match must win. + d.Reset(new[] { 7, 8, 9, 1, 8, 9, 2, 7, 8, 9 }); + + Assert.Equal(new[] { 1 }, d.Propose(1)); + } + + [Fact] + public void Propose_MostRecentOccurrenceWins() + { + var d = new PromptLookupDraft(ngramMax: 2, ngramMin: 2); + // [5, 6] occurs twice; the later (index 3) continuation 9 must win over 7. + d.Reset(new[] { 5, 6, 7, 5, 6, 9, 5, 6 }); + + Assert.Equal(new[] { 9, 5, 6 }, d.Propose(5)); + } + + [Fact] + public void Propose_NoMatch_ReturnsEmpty() + { + var d = new PromptLookupDraft(ngramMax: 3, ngramMin: 2); + d.Reset(new[] { 1, 2, 3, 4, 5, 6 }); + + Assert.Empty(d.Propose(4)); + Assert.Empty(d.Propose(0)); + } + + [Fact] + public void Append_ExtendsMatchableHistory() + { + var d = new PromptLookupDraft(ngramMax: 3, ngramMin: 2); + d.Reset(new[] { 1, 2, 3 }); + Assert.Empty(d.Propose(2)); + + // Generated tokens repeat the prompt opening: tail [1, 2] now matches index 0. + d.Append(1); + d.Append(2); + Assert.Equal(new[] { 3, 1, 2 }, d.Propose(3)); + } + + [Fact] + public void Reset_ClearsPreviousSession() + { + var d = new PromptLookupDraft(ngramMax: 2, ngramMin: 2); + d.Reset(new[] { 1, 2, 1, 2 }); + Assert.NotEmpty(d.Propose(1)); + + d.Reset(new[] { 9, 8, 7 }); + Assert.Empty(d.Propose(1)); + Assert.Equal(3, d.Count); + } +} diff --git a/tests/SharpInference.Tests.ForwardPass/SpeculativeDecoderTests.cs b/tests/SharpInference.Tests.ForwardPass/SpeculativeDecoderTests.cs index c265aa9..cf57d16 100644 --- a/tests/SharpInference.Tests.ForwardPass/SpeculativeDecoderTests.cs +++ b/tests/SharpInference.Tests.ForwardPass/SpeculativeDecoderTests.cs @@ -136,6 +136,43 @@ public void Decode_DraftDiverges_RejectionEmitsCorrectionFromBatchLogits() Assert.True(target.BatchVerifyCalls > 0); } + [Fact] + public void Decode_PromptLookupMode_EmitsTargetChainAndUsesLookupProposals() + { + var target = new ChainForwardPass(vocab: 16, supportsBatchVerify: true); + var spec = new SpeculativeDecoder(target, new PromptLookupDraft(ngramMax: 3, ngramMin: 2), lookahead: 4); + + // Prompt [10,11,12,10]; the target's saved logits continue the chain with 11. + // Step 1: certain 11 joins the history → tail [10,11] matches index 0 → proposals + // [12,10,11]. The chain target accepts 12 (the chain's true next) and rejects 10, + // so the step emits [11,12]. Later steps find no matching tail and degrade to + // plain single-token decode steps — the floor behavior. + spec.Initialize(new[] { 10, 11, 12, 10 }, ChainForwardPass.Logits(16, next: 11)); + + var emitted = new List(); + spec.Decode(4, [], emitted.Add); + + // The emitted sequence is the target's greedy chain regardless of proposal quality. + Assert.Equal(new[] { 11, 12, 13, 14 }, emitted); + // Step 1 verified [11,12,10,11] (one batch), steps 2 and 3 verified the lone + // certain token; the certain token rides in the verify, so no target Forward runs. + Assert.Equal(3, target.BatchVerifyCalls); + Assert.Equal(0, target.ForwardCalls); + // Exactly one proposal (the 12) was accepted across the run. + Assert.True(spec.AcceptanceRate > 0f); + } + + [Fact] + public void Initialize_PromptOverloadWithoutLookup_Throws() + { + var target = new ChainForwardPass(vocab: 16, supportsBatchVerify: true); + var draft = new ChainForwardPass(vocab: 16, supportsBatchVerify: false); + var spec = new SpeculativeDecoder(target, draft); + + Assert.Throws( + () => spec.Initialize(new[] { 1, 2, 3 }, ChainForwardPass.Logits(16, next: 4))); + } + /// /// Deterministic "chain" model: greedy next token is always (token+1) mod vocab. /// Tracks Forward/BatchVerify call counts; divergeEvery > 0 makes every From 2e6f423eb107eb99e05fc04d59571d7555c6baf6 Mon Sep 17 00:00:00 2001 From: Pekka Heikura Date: Thu, 11 Jun 2026 09:30:01 +0300 Subject: [PATCH 4/6] chore: gitignore Nsight profiling artifacts (prof/) Multi-GB nsys-rep/sqlite dumps from kernel profiling sessions; far over GitHub's 100 MB file limit and regenerable from the bench scripts. Co-Authored-By: Claude Fable 5 --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 2f2d6ac..d272f2d 100644 --- a/.gitignore +++ b/.gitignore @@ -439,6 +439,8 @@ codebooks/*.bin models/ tools/ tmp/ +# Nsight Systems/Compute profiling artifacts (multi-GB sqlite/nsys-rep dumps) +prof/ # Per-developer host overrides (model paths, etc.) — not committed **/appsettings.Local.json From 0911cf3877cc643d7f948e1dee84a915169c7e4e Mon Sep 17 00:00:00 2001 From: Pekka Heikura Date: Thu, 11 Jun 2026 11:03:25 +0300 Subject: [PATCH 5/6] perf(engine): #207 goal 4 - MTP k-token batched verify on the GDN-hybrid path (#30) Realizes issue #25's >=1.3x criterion: Qwen3.6-27B-MTP Q4_K_M CUDA-hybrid decode 6.2 -> 10.4 t/s (1.68x) at default settings (k=2 verify batch, 90% draft acceptance). Sweep: draftN=1/2/3 -> 9.7/8.5/9.2 t/s with a 3-slot ring, 10.4 at the shipped 1-slot default. 35B-A3B-MTP (MoE) stays at parity (23.9 vs 24.4 warm) - per-token MoE routing can't share verify weight reads; grouped-by-expert verify batching is the follow-up there. Mechanics (the dense #207 BatchVerify seam, generalized to GDN): - MtpDecoder.DecodeBatched is now the folded k-token loop: the certain token plus a chained MTP draft sequence (new MtpLastHidden self-chaining) run through ONE IForwardPass.BatchVerify per step; on rejection the correction rides into the NEXT step's batch - no per-step commit forward (the #207 lesson). MTP KV entries for accepted drafts are rewritten with trunk hiddens (new HiddenAt) to keep the legacy trunk-quality contract. Per-step SupportsBatchVerify re-check preserves the #130 SnapKV gate; SHARPI_DISABLE_BATCH_VERIFY still forces the sequential N=1 path. Draft/Verify/Commit phase ms exposed for WDDM-paging diagnosis. - CudaHybridGdnForwardPass.BatchVerify reuses the #111/#114-B batched-prefill trunk (GEMM-batched projections, batched contiguous-position attention) with a per-position delta-net recurrence so a DEVICE-side GDN snapshot ring captures every token boundary; [k x vocab] all-position lm_head; pair-batched MatVec2In on the CPU mmap FFN layers (odd-k tail runs as a duplicated-input pair so per-token bits are k-parity-independent). - The ring (slots = SHARPI_MTP_BATCH_MAX-1, default 2 -> 1 slot x 149 MiB on 27B) is reserved BEFORE TryUploadDenseFfnLayers fills VRAM and skipped under SHARPI_DISABLE_MTP=1; each extra slot displaces ~2 GPU FFN layers (~0.35 t/s), which is why k=2 ships as the default. - Verify-path batched matmuls are latched onto the temp-free matvec re-stream: the #162 MMQ/dequant-GEMM prefill machinery allocates per-call fp16 temps (71 MB per Q6_K FFN layer, ~600 MB for the lm_head) that only amortize at prefill N and land in WDDM-paged VRAM at decode-k (measured 6.0 -> 9.2 t/s from this latch alone). - Fixes a latent bug: BatchForward2's GPU-GDN reject path snapshotted the STALE host _gdnStateCache (live state is the device tensors), so a rejected draft never actually rewound GPU GDN state. The device ring now serves slot 0 for BatchForward2 too; the new CUDA rollback oracle test fails without the fix. - HybridGdnForwardPass (CPU) gets the same BatchVerify with a lazily grown host ring; the MtpDecoder_GreedyParity_LlamaCpp byte oracle passes over the new path (also exercised at k=4 during bring-up). - Engine/CLI: the SpecDraftNMax>1 rejection is lifted; --spec-draft-n-max / SHARPI_MTP_DRAFT_N select the chain depth, clamped per step to the ring capacity (MaxBatchVerifyTokens) with an actionable CLI message. Next unlock (measured headroom): the GPU verify cost scales linearly with k on the matvec re-stream while the CPU FFN only pair-amortizes - a 4-input dense MatVec (est. ~12+ t/s at k=4) or #194-style weight-stationary batched decode kernels for the hybrid trunk. Tests: MtpDecoderBatchVerifyTests (scripted folded-loop contracts), CudaMtpBatchVerifyTests (per-position parity vs sequential, device-ring rollback oracle, chained-acceptance e2e); HybridGdnForwardPassTests, CudaHybridGdnForwardPassTests, CudaHybridGdnSnapKvTests, CudaSpecBatchVerifyTests, SpeculativeDecoderTests, GdnStateCache suites all green. Co-Authored-By: Claude Fable 5 --- src/SharpInference.Cli/RunCommand.cs | 26 +- src/SharpInference.Core/IForwardPass.cs | 36 +- .../CudaHybridGdnForwardPass.cs | 489 +++++++++++++++++- .../HybridGdnForwardPass.cs | 316 ++++++++++- src/SharpInference.Engine/InferenceEngine.cs | 21 +- src/SharpInference.Engine/MtpDecoder.cs | 312 +++++++---- .../CudaMtpBatchVerifyTests.cs | 243 +++++++++ .../MtpDecoderBatchVerifyTests.cs | 289 +++++++++++ 8 files changed, 1568 insertions(+), 164 deletions(-) create mode 100644 tests/SharpInference.Tests.ForwardPass/CudaMtpBatchVerifyTests.cs create mode 100644 tests/SharpInference.Tests.ForwardPass/MtpDecoderBatchVerifyTests.cs diff --git a/src/SharpInference.Cli/RunCommand.cs b/src/SharpInference.Cli/RunCommand.cs index 7077f7e..0c20587 100644 --- a/src/SharpInference.Cli/RunCommand.cs +++ b/src/SharpInference.Cli/RunCommand.cs @@ -114,7 +114,7 @@ public sealed class Settings : CommandSettings public string SpecTypeStr { get; init; } = "auto"; [CommandOption("--spec-draft-n-max")] - [Description("Max draft tokens per MTP step (default: 1). Currently only N=1 is supported on the MTP path; issue #30 will lift this. Mirrors llama.cpp.")] + [Description("Max draft tokens per MTP step (issue #30 batched verify). Unset resolves via SHARPI_MTP_DRAFT_N, then defaults to 1 (a 2-token verify batch — the measured optimum). Values > 1 also need snapshot-ring slots: set SHARPI_MTP_BATCH_MAX >= drafts+1 (default 2; each extra slot costs ~150 MiB VRAM on 27B). Mirrors llama.cpp.")] [DefaultValue(0)] public int SpecDraftNMax { get; init; } @@ -990,7 +990,12 @@ private static bool ResolveCliMtp(IForwardPass? mtpFwd, SamplingParams sp, bool } } - int maxBatchN = (mtpFwd is not null && mtpFwd.SupportsBatchVerify) ? 2 : 1; + // Max drafts per step = batch capacity − 1 (the certain token rides in the + // batch). The pass's snapshot-ring capacity bounds the batch (issue #30); + // without batched verify the sequential path drafts exactly 1 per step. + int maxDraftN = (mtpFwd is not null && mtpFwd.SupportsBatchVerify) + ? Math.Max(1, mtpFwd.MaxBatchVerifyTokens - 1) + : 1; switch (sp.SpecType) { @@ -1001,16 +1006,18 @@ private static bool ResolveCliMtp(IForwardPass? mtpFwd, SamplingParams sp, bool if (mtpFwd is null || !mtpFwd.HasMtpHead) { rejectReason = "--spec-type mtp requires a model with an MTP head (nextn tensors)."; return false; } if (sp.Temperature > 0f) { rejectReason = "--spec-type mtp requires greedy sampling (--temp 0)."; return false; } if (!noThinking) { rejectReason = "--spec-type mtp requires --no-thinking (chat template must render with enable_thinking=false)."; return false; } - if (sp.SpecDraftNMax > maxBatchN) + if (sp.SpecDraftNMax > maxDraftN) { - rejectReason = $"--spec-draft-n-max={sp.SpecDraftNMax} exceeds the supported N=2 batched verify ceiling (issue #30); N>2 is still TODO."; + rejectReason = $"--spec-draft-n-max={sp.SpecDraftNMax} exceeds this pass's batched-verify capacity " + + $"({maxDraftN} drafts/step); raise SHARPI_MTP_BATCH_MAX (snapshot-ring slots) to go deeper."; return false; } return true; default: // Auto - if (eligible && sp.SpecDraftNMax > maxBatchN) + if (eligible && sp.SpecDraftNMax > maxDraftN) { - rejectReason = $"--spec-draft-n-max={sp.SpecDraftNMax} exceeds the supported N=2 batched verify ceiling (issue #30); N>2 is still TODO."; + rejectReason = $"--spec-draft-n-max={sp.SpecDraftNMax} exceeds this pass's batched-verify capacity " + + $"({maxDraftN} drafts/step); raise SHARPI_MTP_BATCH_MAX (snapshot-ring slots) to go deeper."; return false; } return eligible; @@ -1044,7 +1051,12 @@ private static (int generated, int totalDecoded, float acceptanceRate, long acce Console.Error.WriteLine($"[DBG] tok={totalDecoded} next={next}('{tok.Decode([next])}')"); totalDecoded++; if (EmitToken(next, tok, streamDec, ref inThinking, hideThinking)) generated++; - }, pMin: sp.SpecDraftPMin, ct: CancellationToken.None); + }, pMin: sp.SpecDraftPMin, draftN: MtpDecoder.ResolveDraftN(sp.SpecDraftNMax), + ct: CancellationToken.None); + + if (Environment.GetEnvironmentVariable("SHARPI_TRACE_MTP") == "1" && mtpDec.TotalDraftsEmitted > 0) + Console.Error.WriteLine( + $"[mtp] phase ms: draft={mtpDec.DraftMs:F0} verify={mtpDec.VerifyMs:F0} commit={mtpDec.CommitMs:F0}"); // Flush the UTF-8 decoder tail, applying the same hide-thinking gate as DecodeLoop. var tail = streamDec.Flush(); diff --git a/src/SharpInference.Core/IForwardPass.cs b/src/SharpInference.Core/IForwardPass.cs index 778bb92..f294ef5 100644 --- a/src/SharpInference.Core/IForwardPass.cs +++ b/src/SharpInference.Core/IForwardPass.cs @@ -170,7 +170,8 @@ void PrefillMtp(IReadOnlyList tokens, int startPos = 0) { } /// (the cache must hold exactly /// positions), returning result[i] = logits after tokens[i]. All k K/V /// entries are appended to the cache; the caller rewinds rejected tokens via - /// . Amortizes the weight reads k× vs sequential + /// (rewindable passes) or + /// (GDN hybrids, issue #30). Amortizes the weight reads k× vs sequential /// calls on memory-bound decode paths. /// float[][] BatchVerify(int[] tokens, int startPos) => @@ -178,6 +179,17 @@ float[][] BatchVerify(int[] tokens, int startPos) => $"{GetType().Name} does not implement BatchVerify. " + "Check SupportsBatchVerify before calling."); + /// + /// Maximum token count accepted by a single call. + /// Rewindable dense passes have no structural limit (rollback is a cache + /// truncate), so the default is unbounded. GDN-hybrid passes (issue #30) + /// override this with their snapshot-ring capacity: rolling back to + /// position startPos + j requires the post-token-j-1 recurrent + /// state to have been captured, and the ring holds a fixed number of slots + /// sized at construction (SHARPI_MTP_BATCH_MAX). + /// + int MaxBatchVerifyTokens => int.MaxValue; + /// /// Last completed 's token-1 pre-output-norm hidden. /// Used by the MTP commit step on the batched verify path. Empty when no batched @@ -210,6 +222,28 @@ void RestoreBatchSnapshot(int lengthAfter) => $"{GetType().Name} does not implement RestoreBatchSnapshot. " + "Check SupportsBatchVerify before calling."); + /// + /// Post-trunk pre-final-norm hidden of an absolute position, read from the + /// hidden-history buffer that / / + /// / maintain on MTP-capable + /// passes (the contract, issue #33). Used by the batched + /// MTP decoder (issue #30) to fetch h_{p} for draft-chain commits and the + /// next step's first draft after a multi-token verify. Returns an empty span when + /// the position has not been populated (or on passes without an MTP head). + /// + ReadOnlySpan HiddenAt(int position) => default; + + /// + /// The MTP block's own residual-stream output from the most recent + /// call, captured before the shared-head norm. + /// Multi-token MTP drafting (issue #30) chains the head on itself: draft i+1's + /// prevHidden is draft i's block output, exactly the role the trunk's + /// plays for the first draft (the standard NEXTN/EAGLE + /// chaining; verification corrects any quality loss). Empty when no + /// has run or the pass has no MTP head. + /// + ReadOnlySpan MtpLastHidden => default; + /// Reset the MTP attention KV cache. No-op when is false. void MtpResetCache() { } diff --git a/src/SharpInference.Engine/CudaHybridGdnForwardPass.cs b/src/SharpInference.Engine/CudaHybridGdnForwardPass.cs index a52230f..eb80737 100644 --- a/src/SharpInference.Engine/CudaHybridGdnForwardPass.cs +++ b/src/SharpInference.Engine/CudaHybridGdnForwardPass.cs @@ -391,6 +391,49 @@ public sealed unsafe class CudaHybridGdnForwardPass : IForwardPass private byte* _batchSnapshotBuf; private long _batchSnapshotCap; private bool _batchSnapshotValid; + private int _batchStartPos; // startPos of the most recent batched verify + private int _batchK; // token count of the most recent batched verify + + // ── Device-side GDN snapshot ring (issues #30/#207 goal 4) ────────── + // On the GPU-GDN trunk (!_cpuGdn) the live recurrent state is the per-layer + // _gpuGdnScanState/_gpuGdnConvState device tensors, so rollback snapshots must + // be captured on-device: ring slot j holds every GDN layer's (scan, conv) state + // AFTER batch token j, packed per layer at offset gdnIdx × per-layer-floats. + // Allocated in the constructor BEFORE TryUploadDenseFfnLayers fills VRAM + // (~100 MB/slot for 27B; landing it in WDDM-paged memory would 5-10× the + // verify). _gdnRingSlots is the achieved slot count (alloc stops on OOM); + // SupportsBatchVerify requires ≥ 1 slot on this trunk. The host + // _batchSnapshotBuf above serves the SHARPI_CPU_GDN=1 debug trunk only. + private readonly Tensor?[]? _gpuGdnRingScan; + private readonly Tensor?[]? _gpuGdnRingConv; + private readonly int _gdnRingSlots; + + // Batched-verify scratch (exact-k; reallocated when the batch size changes the + // same way EnsureBatchedFfnScratch is — GEMM-N derives rows from ElementCount/k). + private Tensor? _gpuBvLogitsAll; // [k × vocab] all-position logits + private Tensor? _gpuBvFfnAll; // [k × embDim] CPU-FFN upload staging / combine + private float* _bvNormHost; // pinned [k × embDim] — moeNorm download for CPU FFN + private float* _bvFfnHost; // pinned [k × embDim] — CPU FFN outputs for upload + private float[]? _bvLogitsHost; // managed [k × vocab] download target + private int _bvCap; + + // MTP block-out hidden of the most recent MtpForward (pre-shared-head-norm), + // pinned so the capture rides the queued D2H stream (issue #30 draft chaining). + private float* _mtpSelfHidden; + + // Max tokens per BatchVerify call = ring slots + 1. SHARPI_MTP_BATCH_MAX in + // [2, 8] bounds the ring reservation. Default 2 (one slot): each slot costs + // ~149 MiB of VRAM that TryUploadDenseFfnLayers would otherwise fill with + // ~2 dense FFN layers, and the k=2 verify is the measured 27B optimum — + // deeper chains only pay once the CPU FFN amortizes more than pairwise + // (4-input MatVec follow-up). Instance-resolved so tests can override per + // construction. + private readonly int _mtpBatchMax = ResolveMtpBatchMax(); + private static int ResolveMtpBatchMax() + { + var s = Environment.GetEnvironmentVariable("SHARPI_MTP_BATCH_MAX"); + return s is not null && int.TryParse(s, out var v) ? Math.Clamp(v, 2, 8) : 2; + } // Token-2 host FFN scratch (intermediate gate/up post-MatVec2In, pre-SiLuMul). private readonly float* _cpuFfnGateBuf2; private readonly float* _cpuFfnUpBuf2; @@ -1103,6 +1146,53 @@ void TraceVram(string label) // We're conservative — most of the remaining budget is reserved for the // attention KV cache (10 layers × maxSeqLen × kvDim × 4 B × 2) and various // scratch. Use the GpuKvBytes from placement when the planner sized it. + // ── MTP detection + GDN snapshot ring reservation (issues #30/#207) ── + // Decided HERE (not at the head-upload block below) because the ring must + // be carved out of VRAM BEFORE TryUploadDenseFfnLayers greedily fills it + // to a 64 MiB margin — a later allocation would land in WDDM-paged memory + // and 5-10× every verify phase. SHARPI_DISABLE_MTP=1 skips the ring so + // MTP-off baseline runs keep the VRAM for FFN layers. + _hasMtp = hp.NumMtpLayers > 0 + && model.FindTensor($"blk.{hp.NumLayers}.nextn.eh_proj.weight") is not null; + if (_hasMtp && !_cpuGdn && _gdnStateCache.NumGdnLayers > 0 + && Environment.GetEnvironmentVariable("SHARPI_DISABLE_MTP") != "1") + { + int numGdn = _gdnStateCache.NumGdnLayers; + int scanF = _gdnStateCache.ScanStateFloatsPerLayer; + int convF = _gdnStateCache.ConvStateFloatsPerLayer; + int want = _mtpBatchMax - 1; + var ringScan = new Tensor?[want]; + var ringConv = new Tensor?[want]; + int got = 0; + for (int s = 0; s < want; s++) + { + try + { + ringScan[s] = gpu.Allocate(TensorShape.D1((long)numGdn * scanF)); + if (convF > 0) + ringConv[s] = gpu.Allocate(TensorShape.D1((long)numGdn * convF)); + got = s + 1; + } + catch (Exception ex) + { + // VRAM exhausted (or backend fault) — keep the slots we already + // have; MaxBatchVerifyTokens shrinks to match. + if (ringScan[s] is { } partial) { gpu.Free(partial); ringScan[s] = null; } + Console.Error.WriteLine( + $"[CudaHybridGdnForwardPass] GDN ring slot {s} allocation failed ({ex.GetType().Name}); " + + $"continuing with {got} slot(s)."); + break; + } + } + _gpuGdnRingScan = ringScan; + _gpuGdnRingConv = ringConv; + _gdnRingSlots = got; + long slotBytes = (long)numGdn * (scanF + convF) * sizeof(float); + Console.Error.WriteLine( + $"[CudaHybridGdnForwardPass] MTP batched-verify GDN ring: {got} slot(s) × " + + $"{slotBytes / (1024 * 1024)} MiB → max verify batch {got + 1} tokens."); + } + if (!hp.IsMoE) { // Dense FFN — no expert slot manager. Per-layer FFN-on-GPU upload runs @@ -1132,9 +1222,9 @@ void TraceVram(string label) // ── MTP / NEXTN head on GPU (issue #29; mirror of HybridGdnForwardPass) ── // Loaded when the GGUF reports nextn_predict_layers > 0 AND the expected // tensors exist at blk.{NumLayers}. Multi-head MTP (NumMtpLayers > 1) is - // out of scope for v1; only the first head is loaded. - _hasMtp = hp.NumMtpLayers > 0 - && model.FindTensor($"blk.{hp.NumLayers}.nextn.eh_proj.weight") is not null; + // out of scope for v1; only the first head is loaded. (_hasMtp itself is + // decided earlier, before the dense-FFN VRAM fill, so the batched-verify + // GDN ring could be reserved first.) if (_hasMtp) { int mtpLayerIdx = hp.NumLayers; @@ -1215,6 +1305,11 @@ void TraceVram(string label) // the queued PCIe transfer (issue #49). _lastHidden = AllocPinned(_embDim); + // MTP self-chaining hidden (issue #30): the MTP block's residual + // output, captured in MtpForward before the in-place shared-head + // norm. Pinned for the same queued-D2H reason as _lastHidden. + _mtpSelfHidden = AllocPinned(_embDim); + // GPU dense FFN scratch is allocated by TryUploadDenseFfnLayers only // when at least one trunk FFN layer lands on GPU. For MTP we need it // regardless — the MTP block's dense FFN runs on GPU even if all @@ -1716,7 +1811,7 @@ private void TrunkLayerSequential(int layer, int N, int startPos, bool isAttn, /// its single-token counterpart. /// private void TrunkBlockBatched(int layer, int N, int startPos, bool isAttn, - bool snapKvActive, int wStart) + bool snapKvActive, int wStart, bool gdnSnapRing = false) { int embDim = _embDim; EnsureBatchedTrunkScratch(N); @@ -1731,7 +1826,7 @@ private void TrunkBlockBatched(int layer, int N, int startPos, bool isAttn, if (isAttn) AttnBlockBatched(layer, N, startPos, snapKvActive, wStart, norm, blockOut); else - GdnBlockBatched(layer, N, norm, blockOut); + GdnBlockBatched(layer, N, norm, blockOut, gdnSnapRing); // postBlock = blockOut + stream (the pre-block residual). blockOut now holds // the MoE/FFN residual. @@ -1768,8 +1863,14 @@ private void TrunkLayerBatched(int layer, int N, int startPos, bool isAttn, /// Batched GDN block: projections over N tokens; fused sequential-scan /// recurrence + batched conv1d/L2norm/tile by default (issue #114-B), or the - /// per-position View loop under SHARPI_BATCHED_GDN_SCAN=0. - private void GdnBlockBatched(int layer, int N, Tensor norm, Tensor blockOut) + /// per-position View loop under SHARPI_BATCHED_GDN_SCAN=0. + /// (issue #30 batched verify) forces the + /// per-position loop — the fused scan only materialises the final state, but + /// rollback needs the state after EVERY non-final token — and captures each + /// boundary into device ring slot i via . + /// The per-position loop is the documented bit-identical fallback, so the + /// verify trunk stays in the same precision class as prefill/Forward. + private void GdnBlockBatched(int layer, int N, Tensor norm, Tensor blockOut, bool snapRing = false) { int convCh = _gdnConvChannels, valDim = _gdnValueDim, nVH = _gdnNumVHeads; int kDim = _gdnKeyDim, hd = _gdnHeadDim; @@ -1789,7 +1890,7 @@ private void GdnBlockBatched(int layer, int N, Tensor norm, Tensor blockOut) // one batched launch per stage + a single sequential-scan kernel. Output is // bit-identical to the per-position View loop below (same per-position math, // same reduction order; only the host launch overhead is removed). - if (BatchedGdnScanEnabled) + if (BatchedGdnScanEnabled && !snapRing) { var qkvConvAll = _gpuBtQkvConv!; var qHeadAll = _gpuBtQHead!; @@ -1847,6 +1948,11 @@ private void GdnBlockBatched(int layer, int N, Tensor norm, Tensor blockOut) scanState, _gpuGdnQHead, _gpuGdnKHead, _gpuGdnVHead, aIn, bIn, _gpuSsmA[layer], _gpuSsmDtBias[layer], _gpuSsmNormW[layer], zIn, outV, nVH, hd, normEps: 1e-6f); + // Batched verify (issue #30): capture the post-token-i state into + // device ring slot i so a rejection at position startPos+i+1 can + // restore it. Stream-ordered D2D, ~2 MB scan + conv per layer. + if (snapRing && i < N - 1) + CaptureGdnRingSlot(slot: i, layer); } finally { @@ -3092,9 +3198,28 @@ public ReadOnlySpan Forward(int token, int position) /// batched-verify coexist with eviction is the #130 follow-up. public bool SupportsBatchVerify => _hasMtp && (!_hp.IsMoE || _cpuMoe) + && (_cpuGdn || _gdnRingSlots >= 1) && !KvCacheCompacted && Environment.GetEnvironmentVariable("SHARPI_DISABLE_BATCH_VERIFY") != "1"; + /// + /// On the GPU-GDN trunk the ceiling is the device snapshot ring's capacity + /// (slots + 1, reserved at construction — SHARPI_MTP_BATCH_MAX). The + /// SHARPI_CPU_GDN=1 debug trunk keeps the legacy 2-token path. + public int MaxBatchVerifyTokens => _cpuGdn ? 2 : _gdnRingSlots + 1; + + /// + public ReadOnlySpan HiddenAt(int position) + { + if (!_hasMtp || position < 0 || position >= _mtpHiddenHistoryLength) + return default; + return new ReadOnlySpan(_mtpPrefillHiddens + (long)position * _embDim, _embDim); + } + + /// + public ReadOnlySpan MtpLastHidden => + _mtpSelfHidden != null ? new ReadOnlySpan(_mtpSelfHidden, _embDim) : default; + /// public ReadOnlySpan LastHiddenT1 => _lastHiddenT1 != null ? new ReadOnlySpan(_lastHiddenT1, _embDim) : default; @@ -3168,10 +3293,12 @@ public void BatchForward2(int t1, int t2, int startPos, else { GpuGdnBlockAt(layer, position: startPos, normIn: _gpuNormBuf, hiddenOut: _gpuHidden); - int gdnIdx = _gdnStateCache.GdnLayerOf(layer); - _gdnStateCache.SnapshotLayerInto(gdnIdx, - _batchSnapshotBuf + (long)gdnIdx * layerSnapBytes, - layerSnapBytes); + // Snapshot t1's state into DEVICE ring slot 0 — the live state on + // this trunk is the _gpuGdnScanState/_gpuGdnConvState tensors, not + // the host _gdnStateCache (which is stale outside CaptureSnapshot; + // the pre-#207 host SnapshotLayerInto here silently captured stale + // bytes, so a rejected draft never actually rewound the GPU state). + CaptureGdnRingSlot(slot: 0, layer); GpuGdnBlockAt(layer, position: startPos + 1, normIn: _gpuNormBuf2, hiddenOut: _gpuHidden2); } @@ -3261,6 +3388,8 @@ public void BatchForward2(int t1, int t2, int startPos, _gdnStateCache.IncrementPosition(); _kvCache.IncrementPosition(); _gdnStateCache.IncrementPosition(); + _batchStartPos = startPos; + _batchK = 2; // 5. Snapshot pre-output-norm hiddens for MTP commit + next iter draft. // Issue #49: queue both snapshots via DownloadAsync (pinned host @@ -3304,22 +3433,42 @@ public void BatchForward2(int t1, int t2, int startPos, } /// + /// + /// Roll the caches back to an intermediate point of the most recent batched + /// verify using the GDN snapshot ring: slot lengthAfter - startPos - 1 + /// holds the state after the batch token at position lengthAfter - 1. + /// GPU-GDN trunk: device-to-device ring → live state tensors (the host + /// _gdnStateCache stays stale, as in normal GPU operation — only its length + /// is bookkept). SHARPI_CPU_GDN=1 trunk: host ring → host state. + /// public void RestoreBatchSnapshot(int lengthAfter) { if (!_batchSnapshotValid) throw new InvalidOperationException( "RestoreBatchSnapshot: no batched-verify snapshot is held. " + - "Call BatchForward2 first."); - if (lengthAfter < 0) + "Call BatchForward2 or BatchVerify first."); + int slot = lengthAfter - _batchStartPos - 1; + if (slot < 0 || slot >= _batchK - 1) throw new ArgumentOutOfRangeException(nameof(lengthAfter), lengthAfter, - "lengthAfter must be >= 0."); + $"RestoreBatchSnapshot: lengthAfter must be in [{_batchStartPos + 1}, " + + $"{_batchStartPos + _batchK - 1}] — the most recent batched verify " + + $"covered positions [{_batchStartPos}, {_batchStartPos + _batchK})."); - long layerSnapBytes = _gdnStateCache.LayerSnapshotBytes; - for (int gdnIdx = 0; gdnIdx < _gdnStateCache.NumGdnLayers; gdnIdx++) + if (_cpuGdn) + { + long layerSnapBytes = _gdnStateCache.LayerSnapshotBytes; + long slotBytes = layerSnapBytes * _gdnStateCache.NumGdnLayers; + for (int gdnIdx = 0; gdnIdx < _gdnStateCache.NumGdnLayers; gdnIdx++) + { + _gdnStateCache.RestoreLayerFrom(gdnIdx, + _batchSnapshotBuf + slot * slotBytes + (long)gdnIdx * layerSnapBytes, + layerSnapBytes); + } + } + else { - _gdnStateCache.RestoreLayerFrom(gdnIdx, - _batchSnapshotBuf + (long)gdnIdx * layerSnapBytes, - layerSnapBytes); + for (int layer = 0; layer < _hp.NumLayers; layer++) + RestoreGdnRingSlot(slot, layer); } _gdnStateCache.SetLength(lengthAfter); _kvCache.TruncateTo(lengthAfter); @@ -3330,6 +3479,276 @@ public void RestoreBatchSnapshot(int lengthAfter) _batchSnapshotValid = false; } + /// + /// Copy one GDN layer's live device state (scan + conv) into ring slot + /// at the layer's packed offset. No-op for attention + /// layers. Device-to-device, stream-ordered — callers issue it right after the + /// layer's token-slot recurrence update. + /// + private void CaptureGdnRingSlot(int slot, int layer) + { + int gdnIdx = _gdnStateCache.GdnLayerOf(layer); + if (gdnIdx < 0) return; + if (_gpuGdnRingScan is null || slot >= _gdnRingSlots) + throw new InvalidOperationException( + $"CaptureGdnRingSlot({slot}): the GDN snapshot ring has {_gdnRingSlots} slot(s). " + + "Callers must clamp the batch to MaxBatchVerifyTokens."); + long scanBytes = (long)_gdnStateCache.ScanStateFloatsPerLayer * sizeof(float); + long convBytes = (long)_gdnStateCache.ConvStateFloatsPerLayer * sizeof(float); + if (_gpuGdnScanState[layer] is { } scanT && scanBytes > 0) + _gpu.CopyDeviceRegion(_gpuGdnRingScan[slot]!, gdnIdx * scanBytes, scanT, 0, scanBytes); + if (_gpuGdnConvState[layer] is { } convT && convBytes > 0 && _gpuGdnRingConv?[slot] is { } convRing) + _gpu.CopyDeviceRegion(convRing, gdnIdx * convBytes, convT, 0, convBytes); + } + + /// Inverse of : ring slot → live device state. + private void RestoreGdnRingSlot(int slot, int layer) + { + int gdnIdx = _gdnStateCache.GdnLayerOf(layer); + if (gdnIdx < 0) return; + if (_gpuGdnRingScan is null || slot >= _gdnRingSlots) + throw new InvalidOperationException( + $"RestoreGdnRingSlot({slot}): the GDN snapshot ring has {_gdnRingSlots} slot(s)."); + long scanBytes = (long)_gdnStateCache.ScanStateFloatsPerLayer * sizeof(float); + long convBytes = (long)_gdnStateCache.ConvStateFloatsPerLayer * sizeof(float); + if (_gpuGdnScanState[layer] is { } scanT && scanBytes > 0) + _gpu.CopyDeviceRegion(scanT, 0, _gpuGdnRingScan[slot]!, gdnIdx * scanBytes, scanBytes); + if (_gpuGdnConvState[layer] is { } convT && convBytes > 0 && _gpuGdnRingConv?[slot] is { } convRing) + _gpu.CopyDeviceRegion(convT, 0, convRing, gdnIdx * convBytes, convBytes); + } + + /// + /// k-token batched verify for the MTP folded decode loop (issues #30 / + /// #207 goal 4). The trunk runs as the #111/#114-B batched-prefill launches — + /// GEMM-batched projections, batched attention at contiguous positions + /// [startPos, startPos+k), per-position delta-net recurrence with the + /// device GDN snapshot ring captured at every token boundary — followed by the + /// FFN stage (GEMM-N on GPU dense layers; pair-batched MatVec2In on the + /// CPU mmap layers that dominate 27B decode; per-token CPU MoE) and an + /// all-position [k × vocab] lm_head. Returns result[i] = logits after + /// tokens[i]; rollback is . + /// + public float[][] BatchVerify(int[] tokens, int startPos) + { + ThrowIfFaulted(); + ArgumentNullException.ThrowIfNull(tokens); + if (!SupportsBatchVerify) + throw new InvalidOperationException( + "BatchVerify requires an MTP pass with an uncompacted cache and an " + + "available GDN snapshot ring. Check SupportsBatchVerify before calling."); + int k = tokens.Length; + if (k == 0) return Array.Empty(); + if (startPos < 0 || startPos + k > _maxSeqLen) + throw new ArgumentOutOfRangeException(nameof(startPos), + $"BatchVerify range [{startPos}, {startPos + k}) exceeds the context window (maxSeqLen={_maxSeqLen})."); + if (k > MaxBatchVerifyTokens) + throw new ArgumentOutOfRangeException(nameof(tokens), k, + $"BatchVerify token count exceeds MaxBatchVerifyTokens ({MaxBatchVerifyTokens}); " + + "raise SHARPI_MTP_BATCH_MAX (ring slots are reserved at construction)."); + if (k == 1) + { + // A single token amortizes nothing — plain Forward is strictly better. + var l = Forward(tokens[0], startPos); + return [l.ToArray()]; + } + if (_cpuGdn) + { + // SHARPI_CPU_GDN=1 debug trunk: the host-side GDN state is live, so the + // legacy 2-token path with its host snapshot is correct as-is. + // MaxBatchVerifyTokens caps k at 2 in this config. + BatchForward2(tokens[0], tokens[1], startPos, out var l1, out var l2); + return [l1.ToArray(), l2.ToArray()]; + } + if (_kvCache.Length != startPos) + throw new InvalidOperationException( + $"BatchVerify: _kvCache.Length={_kvCache.Length} != startPos={startPos}. " + + "Caches must sit exactly at startPos (a SnapKV-compacted cache is gated " + + "off via SupportsBatchVerify)."); + if (_gdnStateCache.Length != startPos) + throw new InvalidOperationException( + $"BatchVerify: _gdnStateCache.Length={_gdnStateCache.Length} != startPos={startPos}."); + + int embDim = _embDim; + long embBytes = (long)embDim * sizeof(float); + bool isMoe = _hp.IsMoE; + + EnsureStreamAll(k); + EnsureBatchedTrunkScratch(k); + if (BatchedFfnEnabled && !isMoe && _denseFfnGpuLayers > 0) + EnsureBatchedFfnScratch(k); + EnsureBatchVerifyScratch(k); + + // Pessimistic fault latch — same contract as the batched prefills: a + // mid-pass throw leaves the recurrent state partially advanced while the + // length counters still read startPos; fatal for this pass. + _faulted = true; + _batchSnapshotValid = false; + // Decode-sized batch: keep every batched matmul on the temp-free matvec + // re-stream (see GpuMatMulBatched). Cleared before returning; a mid-pass + // throw leaves it set, but _faulted already makes the pass unusable then. + _matVecBatchedOnly = true; + + var stream = _gpuStreamAll!; + + // 1. Embed every token into the residual-stream buffer + reserve KV blocks. + for (int i = 0; i < k; i++) + { + EmbedToken(_gpuHidden, tokens[i]); + _gpu.CopyDeviceRegion(stream, i * embBytes, _gpuHidden, 0, embBytes); + _kvCache.ReserveBlockAt(startPos + i); + } + + // 2. Trunk (batched, ring-capturing) + FFN stage, layer by layer. + for (int layer = 0; layer < _hp.NumLayers; layer++) + { + bool isAttn = _hp.LayerTypes![layer] == LayerType.Attention; + TrunkBlockBatched(layer, k, startPos, isAttn, snapKvActive: false, wStart: 0, + gdnSnapRing: true); + var blockOut = _gpuBtBlockOut!; + var moeNorm = _gpuBtMoeNorm!; + + bool denseGpuLayer = !isMoe && _gpuWFfnGate is not null && _gpuWFfnGate[layer] is not null; + bool batchLayer = BatchedFfnEnabled && denseGpuLayer + && BatchedMatMulSupported(_gpuWFfnGate![layer]!) + && BatchedMatMulSupported(_gpuWFfnUp![layer]!) + && BatchedMatMulSupported(_gpuWFfnDown![layer]!); + if (batchLayer) + { + // GEMM-N gate/up/down over all k tokens (issue #121 machinery). + BatchedGpuDenseFfn(layer, k, moeNorm, _gpuBfHiddenAll!); + _gpu.AddInPlace(_gpuBfHiddenAll!, blockOut); + _gpu.CopyDeviceRegion(stream, 0, _gpuBfHiddenAll!, 0, k * embBytes); + } + else if (!isMoe && !denseGpuLayer) + { + // CPU mmap dense FFN — the 27B/12GB decode cost center (~8.6 GB + // weight reads per token). Pair-batched MatVec2In reads each weight + // row once per pair; the odd tail re-runs as a duplicated-input + // pair (second output → sink) so every token's bits match the pair + // kernel regardless of k parity. + _gpu.Download(moeNorm, (nint)_bvNormHost, k * embDim); + for (int i = 0; i < k; i += 2) + { + bool tail = i + 1 >= k; + int j = tail ? i : i + 1; + CpuDenseFfn2(layer, + _bvNormHost + (long)i * embDim, _bvNormHost + (long)j * embDim, + _bvFfnHost + (long)i * embDim, + tail ? _cpuMoeHidden2 : _bvFfnHost + (long)j * embDim); + } + _gpu.UploadInto(_gpuBvFfnAll!, (nint)_bvFfnHost, k * embDim); + _gpu.AddInPlace(_gpuBvFfnAll!, blockOut); + _gpu.CopyDeviceRegion(stream, 0, _gpuBvFfnAll!, 0, k * embBytes); + } + else + { + // Per-token fallbacks: GPU dense layer with a non-GEMM-N weight + // dtype, or CPU MoE (per-token routing, issue #45 — the wins come + // from the batched trunk + lm_head, not the routed FFN itself). + // Full-GPU MoE never reaches here (SupportsBatchVerify gate). + for (int i = 0; i < k; i++) + { + _gpu.CopyDeviceRegion(_gpuNormBuf, 0, moeNorm, i * embBytes, embBytes); + if (!isMoe) + { + GpuDenseFfn(layer); + } + else + { + _gpu.Download(_gpuNormBuf, (nint)_cpuNormBuf, embDim); + CpuMoeFfnCore( + _gpuWGateShexp[layer], _gpuWUpShexp[layer], _gpuWDownShexp[layer], + _cpuFfnGateInp![layer], _cpuFfnGateInpShexp![layer], + _cpuFfnGateExps![layer], _cpuFfnUpExps![layer], _cpuFfnDownExps![layer], + gpuNormIn: _gpuNormBuf, gpuSharedOut: _gpuSharedOut, + cpuNormIn: _cpuNormBuf, cpuMoeOut: _cpuMoeHidden); + _gpu.UploadInto(_gpuHidden, (nint)_cpuMoeHidden, embDim); + } + _gpu.CopyDeviceRegion(_gpuResidual, 0, blockOut, i * embBytes, embBytes); + _gpu.AddInPlace(_gpuHidden, _gpuResidual); + _gpu.CopyDeviceRegion(stream, i * embBytes, _gpuHidden, 0, embBytes); + } + } + } + + // 3. Advance the position counters by k. + for (int i = 0; i < k; i++) + { + _kvCache.IncrementPosition(); + _gdnStateCache.IncrementPosition(); + } + _batchStartPos = startPos; + _batchK = k; + _faulted = false; + + // 4. MTP hidden history (issues #33/#106): stream[i] holds the + // pre-output-norm hidden for token startPos+i. + EnsureMtpHiddenHistoryCap(startPos + k); + _gpu.Download(stream, (nint)(_mtpPrefillHiddens + (long)startPos * embDim), k * embDim); + if (_mtpHiddenHistoryLength < startPos + k) + _mtpHiddenHistoryLength = startPos + k; + new ReadOnlySpan(_mtpPrefillHiddens + (long)(startPos + k - 1) * embDim, embDim) + .CopyTo(new Span(_lastHidden, embDim)); + + // 5. All-position logits: batched output norm + GEMM-N lm_head when the + // weight dtype supports it, else a per-token MatMul loop. + var normAll = _gpuBtNorm!; // free to reuse after the last trunk layer + _gpu.RmsNormBatched(normAll, stream, _gpuOutputNorm!, k, embDim, _hp.RmsNormEps); + var result = new float[k][]; + if (BatchedMatMulSupported(_gpuOutputWeight!)) + { + GpuMatMulBatched(_gpuBvLogitsAll!, _gpuOutputWeight!, normAll, k); + _gpu.Download(_gpuBvLogitsAll!, _bvLogitsHost.AsSpan(0, k * _hp.VocabSize)); + for (int i = 0; i < k; i++) + { + result[i] = new float[_hp.VocabSize]; + Array.Copy(_bvLogitsHost!, (long)i * _hp.VocabSize, result[i], 0, _hp.VocabSize); + } + } + else + { + var outDt = _gpuWeightDTypes.TryGetValue(_gpuOutputWeight!.Handle, out var dt) + ? dt : DType.Float32; + for (int i = 0; i < k; i++) + { + _gpu.CopyDeviceRegion(_gpuHidden, 0, normAll, i * embBytes, embBytes); + _gpu.MatMul(_gpuLogits, _gpuOutputWeight!, _gpuHidden, outDt); + _gpu.Download(_gpuLogits, _logitsBuf); + result[i] = (float[])_logitsBuf.Clone(); + } + } + + _matVecBatchedOnly = false; + _batchSnapshotValid = true; + return result; + } + + /// + /// (Re)allocate the batched-verify scratch for an exact batch size of + /// tokens: the [k × vocab] all-position logits tensor + + /// managed download buffer, the [k × embDim] CPU-FFN staging tensor, and the + /// pinned host norm/FFN roundtrip buffers. Exact-size (not grow-only) because + /// the GEMM-N kernels derive their row count from ElementCount / k. + /// + private void EnsureBatchVerifyScratch(int k) + { + if (_bvCap == k) return; + if (_gpuBvLogitsAll is { } l) { _gpu.Free(l); _gpuBvLogitsAll = null; } + if (_gpuBvFfnAll is { } f) { _gpu.Free(f); _gpuBvFfnAll = null; } + if (_bvNormHost != null) { CudaBackend.FreePinnedHost((nint)_bvNormHost); _bvNormHost = null; } + if (_bvFfnHost != null) { CudaBackend.FreePinnedHost((nint)_bvFfnHost); _bvFfnHost = null; } + long logitsTotal = (long)k * _hp.VocabSize; + if (logitsTotal > int.MaxValue) + throw new NotSupportedException( + $"Batched verify logits buffer ({k}×{_hp.VocabSize}) exceeds int.MaxValue."); + _gpuBvLogitsAll = _gpu.Allocate(TensorShape.D1(logitsTotal)); + _gpuBvFfnAll = _gpu.Allocate(TensorShape.D1((long)k * _embDim)); + _bvNormHost = AllocPinnedL((long)k * _embDim); + _bvFfnHost = AllocPinnedL((long)k * _embDim); + _bvLogitsHost = new float[(int)logitsTotal]; + _bvCap = k; + } + // ================================================================= // GPU attention block — GLU-gated Q, partial NEOX RoPE on first 64 dims // ================================================================= @@ -3522,6 +3941,12 @@ public ReadOnlySpan MtpForward(int token, int position, ReadOnlySpan 1) + // Issue #30/#207 batched verify: the MMQ/dequant-GEMM compute path only + // amortizes its fixed per-call costs (whole-weight dequant to an fp16 temp + // for Q6_K/Q5_K — 71 MB per 27B FFN layer, ~600 MB for the lm_head) at + // prefill-scale N. At decode-sized k those temps land in WDDM-paged VRAM + // behind the post-fill 64 MiB margin and 5-10× every step. The verify path + // latches _matVecBatchedOnly so every batched matmul takes the temp-free + // matvec re-stream (the same decode kernels the sequential Forward uses). + if (GdnPrefillComputeEnabled && nTok > 1 && !_matVecBatchedOnly) { int cols = (int)(inputAll.ElementCount / nTok); switch (dt) @@ -4663,6 +5095,10 @@ private void GpuMatMulBatched(Tensor outputAll, Tensor matrix, Tensor inputAll, _gpu.MatMulBatched(outputAll, matrix, inputAll, nTok, dt); } + // True while BatchVerify drives the batched trunk: forces GpuMatMulBatched onto + // the matvec re-stream path (see the comment there). Never set on prefill. + private bool _matVecBatchedOnly; + /// Issue #121: true when 's dtype is one of the dtypes /// implements a GEMM-N kernel for. Gates the /// batched FFN/MoE path so an unsupported dtype falls back to the per-token loop @@ -5367,6 +5803,17 @@ public void Dispose() _batchSnapshotBuf = null; } } + // Issue #30/#207-goal-4 k-token batched verify: device GDN snapshot + // ring + exact-k verify scratch + the MTP self-chaining hidden. + if (_gpuGdnRingScan is not null) + foreach (var t in _gpuGdnRingScan) if (t is { } rs) _gpu.Free(rs); + if (_gpuGdnRingConv is not null) + foreach (var t in _gpuGdnRingConv) if (t is { } rc) _gpu.Free(rc); + if (_gpuBvLogitsAll is { } bvl) _gpu.Free(bvl); + if (_gpuBvFfnAll is { } bvf) _gpu.Free(bvf); + if (_bvNormHost != null) CudaBackend.FreePinnedHost((nint)_bvNormHost); + if (_bvFfnHost != null) CudaBackend.FreePinnedHost((nint)_bvFfnHost); + if (_mtpSelfHidden != null) CudaBackend.FreePinnedHost((nint)_mtpSelfHidden); _mtpKvCache?.Dispose(); } diff --git a/src/SharpInference.Engine/HybridGdnForwardPass.cs b/src/SharpInference.Engine/HybridGdnForwardPass.cs index f12ab2e..5ec1b91 100644 --- a/src/SharpInference.Engine/HybridGdnForwardPass.cs +++ b/src/SharpInference.Engine/HybridGdnForwardPass.cs @@ -118,12 +118,44 @@ public sealed unsafe class HybridGdnForwardPass : IForwardPass private readonly float* _logits2; // [vocabSize] private readonly float* _lastHiddenT1; // [embDim] — t1's pre-output-norm hidden after BatchForward2 - // Per-layer GDN snapshot buffer holding the "between t1 and t2" state during - // BatchForward2. Each gdn-layer slot is _gdnStateCache.LayerSnapshotBytes wide, - // contiguous in this buffer. Sized in the constructor when MTP is loaded. + // Per-token-boundary GDN snapshot ring used by the batched verify paths + // (issues #30 / #207-goal-4). Slot j holds every GDN layer's state AFTER the + // batch's token j (j ∈ [0, k-2]), so a rejection at draft position j+1 can + // restore via RestoreBatchSnapshot(startPos + j + 1). Slot layout: + // offset(slot, gdnIdx) = slot × NumGdnLayers × LayerSnapshotBytes + // + gdnIdx × LayerSnapshotBytes + // BatchForward2 (the legacy 2-token path) writes slot 0 only. The buffer + // starts at 1 slot (constructor) and grows lazily in BatchVerify. private byte* _batchSnapshotBuf; private long _batchSnapshotCap; private bool _batchSnapshotValid; + private int _batchSnapshotSlots; // ring slots currently allocated + private int _batchStartPos; // startPos of the most recent batched verify + private int _batchK; // token count of the most recent batched verify + + // Batched-verify residual streams [k × embDim] (lazily grown; issue #30 k-token + // generalization). BatchForward2 keeps its dedicated _hidden2/... pair. + private float* _bvHiddenAll; + private float* _bvResidAll; + private float* _bvNormAll; + private int _bvCap; + + // MTP block-out hidden of the most recent MtpForward (pre-shared-head-norm), + // used as the next chained draft's prevHidden (issue #30 multi-token drafting). + private float* _mtpSelfHidden; + + // Max tokens per BatchVerify call (= 1 + max MTP draft chain length). The host + // snapshot ring grows lazily to k-1 slots of NumGdnLayers × LayerSnapshotBytes + // each (~149 MB/slot for 27B), so keep a sane ceiling. SHARPI_MTP_BATCH_MAX in + // [2, 8]; default matches the CUDA pass (k=2 is the measured optimum until the + // CPU FFN amortizes more than pairwise). Instance-resolved so tests can + // override per construction. + private readonly int _mtpBatchMax = ResolveMtpBatchMax(); + private static int ResolveMtpBatchMax() + { + var s = Environment.GetEnvironmentVariable("SHARPI_MTP_BATCH_MAX"); + return s is not null && int.TryParse(s, out var v) ? Math.Clamp(v, 2, 8) : 2; + } // ── Dimensions (cached) ──────────────────────────────────────────── private readonly int _embDim; @@ -602,6 +634,10 @@ public HybridGdnForwardPass(GgufModel model, IComputeBackend backend, ModelHyper // is needed as MTP input. Sized at embDim, refreshed each Forward. _lastHidden = Alloc(_embDim); + // MTP self-chaining hidden (issue #30): the MTP block's own residual + // output, captured before the shared-head norm in MtpForward. + _mtpSelfHidden = Alloc(_embDim); + // Issue #30 / #45: batched verify scratch. _hidden2/_residual2/_normBuf2/ // _logits2/_lastHiddenT1 are needed for any MTP-bearing model. The dense // FFN intermediate buffers (_ffnGate2/_ffnUp2) are only used on the dense @@ -622,8 +658,11 @@ public HybridGdnForwardPass(GgufModel model, IComputeBackend backend, ModelHyper long totalSnapBytes = perLayerBytes * _gdnStateCache.NumGdnLayers; if (totalSnapBytes > 0) { + // One ring slot up front (covers BatchForward2); BatchVerify grows + // the ring lazily via EnsureBatchSnapshotSlots(k - 1). _batchSnapshotBuf = (byte*)NativeMemory.Alloc((nuint)totalSnapBytes); _batchSnapshotCap = totalSnapBytes; + _batchSnapshotSlots = 1; } } @@ -1056,6 +1095,8 @@ public void BatchForward2(int t1, int t2, int startPos, _gdnStateCache.IncrementPosition(); _kvCache.IncrementPosition(); _gdnStateCache.IncrementPosition(); + _batchStartPos = startPos; + _batchK = 2; // 5. Snapshot the pre-output-norm hiddens before final norm overwrites them. var h1Span = new ReadOnlySpan(_hidden, _embDim); @@ -1092,29 +1133,36 @@ public void BatchForward2(int t1, int t2, int startPos, } /// - /// Roll the caches back to startPos + 1 (i.e. token-1 has been processed, - /// token-2 has not) using the snapshot taken inside the most recent - /// . Used by on a rejected - /// draft so a single follow-up can replay the corrected - /// token at position startPos + 1. + /// Roll the caches back to an intermediate point of the most recent batched + /// verify ( or ) using the + /// per-token-boundary GDN snapshot ring. selects + /// ring slot lengthAfter - startPos - 1: the state captured after the + /// batch's token at position lengthAfter - 1. Used by + /// on a rejected draft; the correction token then either replays via + /// (legacy N=2 path) or rides in the next verify batch + /// (folded k-token path). /// - /// Cache length to restore to (always - /// startPos + 1 for the issue-#30 N=2 path). + /// Cache length to restore to; must lie in + /// [startPos + 1, startPos + k - 1] of the most recent batched verify. public void RestoreBatchSnapshot(int lengthAfter) { if (!_batchSnapshotValid) throw new InvalidOperationException( "RestoreBatchSnapshot: no batched-verify snapshot is held. " + - "Call BatchForward2 first."); - if (lengthAfter < 0) + "Call BatchForward2 or BatchVerify first."); + int slot = lengthAfter - _batchStartPos - 1; + if (slot < 0 || slot >= _batchK - 1) throw new ArgumentOutOfRangeException(nameof(lengthAfter), lengthAfter, - "lengthAfter must be >= 0."); + $"RestoreBatchSnapshot: lengthAfter must be in [{_batchStartPos + 1}, " + + $"{_batchStartPos + _batchK - 1}] — the most recent batched verify " + + $"covered positions [{_batchStartPos}, {_batchStartPos + _batchK})."); long layerSnapBytes = _gdnStateCache.LayerSnapshotBytes; + long slotBytes = layerSnapBytes * _gdnStateCache.NumGdnLayers; for (int gdnIdx = 0; gdnIdx < _gdnStateCache.NumGdnLayers; gdnIdx++) { _gdnStateCache.RestoreLayerFrom(gdnIdx, - _batchSnapshotBuf + (long)gdnIdx * layerSnapBytes, + _batchSnapshotBuf + slot * slotBytes + (long)gdnIdx * layerSnapBytes, layerSnapBytes); } _gdnStateCache.SetLength(lengthAfter); @@ -1129,6 +1177,236 @@ public void RestoreBatchSnapshot(int lengthAfter) _batchSnapshotValid = false; } + /// + public int MaxBatchVerifyTokens => _mtpBatchMax; + + /// + public ReadOnlySpan HiddenAt(int position) + { + if (!_hasMtp || position < 0 || position >= _mtpHiddenHistoryLength) + return default; + return new ReadOnlySpan(_mtpPrefillHiddens + (long)position * _embDim, _embDim); + } + + /// + public ReadOnlySpan MtpLastHidden => + _mtpSelfHidden != null ? new ReadOnlySpan(_mtpSelfHidden, _embDim) : default; + + /// + /// k-token batched verify for the MTP folded decode loop (issue #30 / + /// #207 goal 4). Generalizes : processes + /// at positions [startPos, startPos + k) with + /// per-token sequential attn/GDN sublayers (causal order; GDN snapshot ring + /// captured after every non-final token) and pair-batched dense FFN / lm_head + /// via — each weight row read once per pair. + /// Returns result[i] = logits after tokens[i]. + /// Per-position outputs are bit-identical regardless of k: every token's + /// math runs through the same kernels with the same inputs (the odd-k tail is + /// processed as a duplicated-input MatVec2In pair so no kernel switch occurs). + /// This matches the BatchForward2 precision class — argmax-stable vs the + /// sequential path, whose dense FFN uses MatVecDual. + /// + public float[][] BatchVerify(int[] tokens, int startPos) + { + ArgumentNullException.ThrowIfNull(tokens); + if (!SupportsBatchVerify) + throw new InvalidOperationException( + "BatchVerify is only supported on MTP passes with an uncompacted cache. " + + "Check SupportsBatchVerify before calling."); + int k = tokens.Length; + if (k == 0) return Array.Empty(); + if (startPos < 0 || startPos + k > MaxSeqLen) + throw new ArgumentOutOfRangeException(nameof(startPos), + $"BatchVerify range [{startPos}, {startPos + k}) exceeds the context window (maxSeqLen={MaxSeqLen})."); + if (k > MaxBatchVerifyTokens) + throw new ArgumentOutOfRangeException(nameof(tokens), k, + $"BatchVerify token count exceeds MaxBatchVerifyTokens ({MaxBatchVerifyTokens}); " + + "raise SHARPI_MTP_BATCH_MAX or shorten the draft chain."); + if (k == 1) + { + // A single token amortizes nothing — same fallback as the dense passes. + var l = Forward(tokens[0], startPos); + return [l.ToArray()]; + } + if (_kvCache.Length != startPos) + throw new InvalidOperationException( + $"BatchVerify: _kvCache.Length={_kvCache.Length} != startPos={startPos}. " + + "Caches must sit exactly at startPos (a SnapKV-compacted cache is gated off " + + "via SupportsBatchVerify)."); + if (_gdnStateCache.Length != startPos) + throw new InvalidOperationException( + $"BatchVerify: _gdnStateCache.Length={_gdnStateCache.Length} != startPos={startPos}."); + + int embDim = _embDim; + EnsureBatchVerifyScratch(k); + EnsureBatchSnapshotSlots(k - 1); + + // 1. Embed all tokens into independent residual streams + reserve KV blocks. + for (int i = 0; i < k; i++) + { + EmbedTokenInto(tokens[i], _bvHiddenAll + (long)i * embDim); + _kvCache.ReserveBlockAt(startPos + i); + } + + _batchSnapshotValid = false; + long layerSnapBytes = _gdnStateCache.LayerSnapshotBytes; + long slotBytes = layerSnapBytes * _gdnStateCache.NumGdnLayers; + + // 2. Trunk layers — per-token attn/GDN (t_i before t_{i+1} so each token's + // attention reads its predecessors' K/V and GDN state), batched FFN. + for (int layer = 0; layer < _hp.NumLayers; layer++) + { + var attnNormW = GetNormWeight(_attnNorm[layer]); + for (int i = 0; i < k; i++) + { + float* h = _bvHiddenAll + (long)i * embDim; + Copy(_bvResidAll + (long)i * embDim, h, embDim); + SimdKernels.RmsNorm(_bvNormAll + (long)i * embDim, h, attnNormW, embDim, _hp.RmsNormEps); + } + + bool isAttn = _hp.LayerTypes![layer] == LayerType.Attention; + if (isAttn) + { + for (int i = 0; i < k; i++) + AttnBlockAt(layer, position: startPos + i, kvPosition: startPos + i, + normIn: _bvNormAll + (long)i * embDim, + hiddenOut: _bvHiddenAll + (long)i * embDim); + } + else + { + int gdnIdx = _gdnStateCache.GdnLayerOf(layer); + for (int i = 0; i < k; i++) + { + GdnBlockAt(layer, position: startPos + i, + normIn: _bvNormAll + (long)i * embDim, + hiddenOut: _bvHiddenAll + (long)i * embDim); + // Ring slot i = state after token i (rollback-to-startPos+i+1). + if (i < k - 1) + _gdnStateCache.SnapshotLayerInto(gdnIdx, + _batchSnapshotBuf + i * slotBytes + (long)gdnIdx * layerSnapBytes, + layerSnapBytes); + } + } + + for (int i = 0; i < k; i++) + SimdKernels.AddInPlace(_bvHiddenAll + (long)i * embDim, _bvResidAll + (long)i * embDim, embDim); + + var postNormW = GetNormWeight(_postAttnNorm[layer]); + for (int i = 0; i < k; i++) + { + float* h = _bvHiddenAll + (long)i * embDim; + Copy(_bvResidAll + (long)i * embDim, h, embDim); + SimdKernels.RmsNorm(_bvNormAll + (long)i * embDim, h, postNormW, embDim, _hp.RmsNormEps); + } + + if (_hp.IsMoE) + { + // Per-token MoE (issue #45): routed top-K differs per token, so no + // shared expert weight reads — same as BatchForward2. + for (int i = 0; i < k; i++) + MoeFfnCore( + _wGateInp[layer], + _wGateShexp[layer], _wUpShexp[layer], _wDownShexp[layer], + _wGateExps[layer], _wUpExps[layer], _wDownExps[layer], + _wGateInpShexp[layer], + normInExt: _bvNormAll + (long)i * embDim, + hiddenOutExt: _bvHiddenAll + (long)i * embDim); + } + else + { + // MatVec2In pairs; the odd tail duplicates its input into both pair + // slots (second output → _hidden2 sink) so EVERY token goes through + // the identical kernel — per-position bits don't depend on k parity. + for (int i = 0; i < k; i += 2) + { + bool tail = i + 1 >= k; + int j = tail ? i : i + 1; + DenseFfn2(layer, + _bvNormAll + (long)i * embDim, _bvNormAll + (long)j * embDim, + _bvHiddenAll + (long)i * embDim, + tail ? _hidden2 : _bvHiddenAll + (long)j * embDim); + } + } + + for (int i = 0; i < k; i++) + SimdKernels.AddInPlace(_bvHiddenAll + (long)i * embDim, _bvResidAll + (long)i * embDim, embDim); + } + + // 3. Advance both caches by k. + for (int i = 0; i < k; i++) + { + _kvCache.IncrementPosition(); + _gdnStateCache.IncrementPosition(); + } + _batchStartPos = startPos; + _batchK = k; + + // 4. Hidden history (issue #33/#106) + LastHidden before the final norm + // overwrites the streams in place. + if (_hasMtp) + { + EnsureMtpHiddenHistoryCap(startPos + k); + for (int i = 0; i < k; i++) + new ReadOnlySpan(_bvHiddenAll + (long)i * embDim, embDim).CopyTo( + new Span(_mtpPrefillHiddens + (long)(startPos + i) * embDim, embDim)); + if (_mtpHiddenHistoryLength < startPos + k) + _mtpHiddenHistoryLength = startPos + k; + Copy(_lastHidden, _bvHiddenAll + (long)(k - 1) * embDim, embDim); + } + + // 5. Final norm + lm_head, MatVec2In pairs (vocab-sized weight read once per + // pair). Odd tail uses the duplicated-input pair with _logits2 as sink. + var outNormW = GetNormWeight(_outputNorm); + for (int i = 0; i < k; i++) + { + float* h = _bvHiddenAll + (long)i * embDim; + SimdKernels.RmsNorm(h, h, outNormW, embDim, _hp.RmsNormEps); + } + + var result = new float[k][]; + for (int i = 0; i < k; i += 2) + { + bool tail = i + 1 >= k; + int j = tail ? i : i + 1; + SimdKernels.MatVec2In(_logits, _logits2, _outputWeight.DataPtr, + _bvHiddenAll + (long)i * embDim, _bvHiddenAll + (long)j * embDim, + _hp.VocabSize, embDim, _outputWeight.DType); + result[i] = new ReadOnlySpan(_logits, _hp.VocabSize).ToArray(); + if (!tail) + result[i + 1] = new ReadOnlySpan(_logits2, _hp.VocabSize).ToArray(); + } + + _batchSnapshotValid = true; + return result; + } + + /// Grow the [k × embDim] batched-verify residual streams (grow-only). + private void EnsureBatchVerifyScratch(int k) + { + if (_bvCap >= k) return; + nuint bytes = (nuint)((long)k * _embDim * sizeof(float)); + if (_bvHiddenAll != null) NativeMemory.Free(_bvHiddenAll); + if (_bvResidAll != null) NativeMemory.Free(_bvResidAll); + if (_bvNormAll != null) NativeMemory.Free(_bvNormAll); + _bvHiddenAll = (float*)NativeMemory.AllocZeroed(bytes); + _bvResidAll = (float*)NativeMemory.AllocZeroed(bytes); + _bvNormAll = (float*)NativeMemory.AllocZeroed(bytes); + _bvCap = k; + } + + /// Grow the GDN snapshot ring to at least slots + /// (grow-only; contents need not survive — the ring is rewritten every batch). + private void EnsureBatchSnapshotSlots(int slots) + { + if (_batchSnapshotSlots >= slots) return; + long slotBytes = _gdnStateCache.LayerSnapshotBytes * _gdnStateCache.NumGdnLayers; + if (slotBytes <= 0) { _batchSnapshotSlots = slots; return; } + if (_batchSnapshotBuf != null) NativeMemory.Free(_batchSnapshotBuf); + _batchSnapshotBuf = (byte*)NativeMemory.Alloc((nuint)(slotBytes * slots)); + _batchSnapshotCap = slotBytes * slots; + _batchSnapshotSlots = slots; + } + /// /// Batched gate × up → down dense FFN for two tokens sharing the same weight /// matrices. Each weight row is touched once per row iteration and dotted @@ -1511,6 +1789,12 @@ public ReadOnlySpan MtpForward(int token, int position, ReadOnlySpan GenerateChunksAsync( break; } - // --spec-draft-n-max parity with llama.cpp. Sharpi's MTP decoder - // is sequential N=1 today; issue #30 (Phase-7 batched verify + per-token - // GDN snapshot ring) is what unlocks N>1. Reject larger values up front - // so users don't silently get the same throughput they'd get without - // the flag. - if (useMtp && sp.SpecDraftNMax > 1) - throw new InvalidOperationException( - $"SamplingParams.SpecDraftNMax={sp.SpecDraftNMax} is not yet supported. " + - "Sharpi's MTP path is sequential N=1; issue #30 (Phase-7 batched verify) " + - "lifts this. Pass --spec-draft-n-max 1 (or omit) until then."); + // --spec-draft-n-max parity with llama.cpp (issue #30): the MTP + // draft-chain length per step. Unset (0) resolves via + // SHARPI_MTP_DRAFT_N → built-in default; MtpDecoder clamps per + // step against the pass's snapshot-ring capacity + // (MaxBatchVerifyTokens), so over-asking degrades gracefully. + int mtpDraftN = MtpDecoder.ResolveDraftN(sp.SpecDraftNMax); // Prefix cache decision: two branches. // (a) Rewindable attention pass — existing FindCacheablePrefix path, @@ -495,7 +491,7 @@ public async IAsyncEnumerable GenerateChunksAsync( if (chunk.Length > 0) channel.Writer.TryWrite( new GenerateChunk(GenerateChunkKind.Text, chunk)); - }, pMin: sp.SpecDraftPMin, ct: ct); + }, pMin: sp.SpecDraftPMin, draftN: mtpDraftN, ct: ct); var textFlushMtp = textDecMtp.Flush(); if (textFlushMtp.Length > 0) @@ -507,7 +503,8 @@ public async IAsyncEnumerable GenerateChunksAsync( { Console.Error.WriteLine( $"[InferenceEngine] MTP: {mtpDec.TotalDraftsAccepted}/{mtpDec.TotalDraftsEmitted} " + - $"drafts accepted ({mtpDec.AcceptanceRate:P1})"); + $"drafts accepted ({mtpDec.AcceptanceRate:P1}); " + + $"phase ms draft={mtpDec.DraftMs:F0} verify={mtpDec.VerifyMs:F0} commit={mtpDec.CommitMs:F0}"); } // End-of-decode snapshot — see the non-MTP twin below for the diff --git a/src/SharpInference.Engine/MtpDecoder.cs b/src/SharpInference.Engine/MtpDecoder.cs index b21b860..36b9409 100644 --- a/src/SharpInference.Engine/MtpDecoder.cs +++ b/src/SharpInference.Engine/MtpDecoder.cs @@ -8,10 +8,19 @@ namespace SharpInference.Engine; /// is the MTP head of the same network — not a separate weights file — so /// the vocab, tokenizer, and trunk state are all shared. /// -/// Algorithm (v1, sequential N=1). +/// Primary algorithm — folded k-token batched verify (issues #30/#207). +/// Per step, the certain next token (argmax of the saved main logits) plus a +/// chain of draftN MTP proposals run through ONE +/// pass, which amortizes the trunk's +/// weight reads k× on memory-bound decode. Leading drafts that match the +/// verifier's argmax are emitted; on a rejection the trunk rolls back via the +/// per-token GDN snapshot ring () +/// and the correction token rides into the NEXT step's batch — the target never +/// runs a separate correction forward. See . /// -/// The main pass and the MTP head each maintain their own KV/state caches -/// that advance in lockstep, one position per emitted token. Per iteration: +/// Fallback algorithm — sequential N=1 (), +/// used when the pass doesn't support batched verify (e.g. SnapKV-compacted +/// cache, issue #130) or under SHARPI_DISABLE_BATCH_VERIFY=1. Per iter: /// /// /// t1 = argmax(saved_main_logits) — already correct by greedy @@ -28,13 +37,9 @@ namespace SharpInference.Engine; /// for the next iter. /// /// -/// Speedup envelope. -/// Sequential N=1 emits 2 tokens per 2 main forwards + 2 MTP forwards. With -/// MTP ~1/64 the cost of a main forward, the per-iteration wall time is -/// roughly the same as 2 baseline forwards — i.e. v1 is near-baseline, not -/// 1.3×. Hitting the issue #25 acceptance criterion (≥1.3×) requires a -/// batched main verify pass (Phase 7 optimization) plus a per-token GDN -/// snapshot ring (Phase 11.7 / Risk #6). Both are tracked as follow-ups. +/// The sequential form emits 2 tokens per 2 main forwards — near-baseline +/// throughput (the pre-#30 state). The batched form is what realizes the issue +/// #25 ≥1.3× criterion. /// /// State invariants this class relies on. /// @@ -58,6 +63,21 @@ public sealed class MtpDecoder private long _totalDraftsEmitted; private long _totalDraftsAccepted; + // Phase timing (issue #207 bench reporting, mirrors SpeculativeDecoder): + // cumulative wall time in the MTP draft chain + KV refresh forwards, the + // batched main verify, and rollback/state sync. Uniform slowdown across all + // three phases under VRAM pressure indicates WDDM paging, not slow kernels. + private readonly System.Diagnostics.Stopwatch _phaseSw = new(); + + /// Cumulative ms in MTP head forwards (draft chain + KV refresh). + public double DraftMs { get; private set; } + + /// Cumulative ms in the batched main verify passes. + public double VerifyMs { get; private set; } + + /// Cumulative ms in snapshot rollback + saved-state sync. + public double CommitMs { get; private set; } + public MtpDecoder(IForwardPass fwd) { ArgumentNullException.ThrowIfNull(fwd); @@ -112,15 +132,38 @@ public void Initialize(int nextPosition, ReadOnlySpan lastMainLogits) h.CopyTo(_savedHidden); _totalDraftsEmitted = 0; _totalDraftsAccepted = 0; + DraftMs = 0; + VerifyMs = 0; + CommitMs = 0; + } + + /// + /// Resolve the effective MTP draft-chain length (proposed tokens per step) from the + /// llama.cpp-parity --spec-draft-n-max value. < 1 (unset) falls back + /// to SHARPI_MTP_DRAFT_N, then to the built-in default of 1 (a 2-token verify + /// batch — the measured 27B optimum: the GPU trunk's matvec re-stream scales linearly + /// with k while the CPU FFN only pair-amortizes, so deeper chains don't pay until a + /// 4-input CPU FFN kernel lands; see the issue #30 bench). Deeper chains also need + /// ring slots: raise SHARPI_MTP_BATCH_MAX alongside. The result is further + /// clamped per step against . + /// Shared by and the CLI so both resolve identically. + /// + public static int ResolveDraftN(int specDraftNMax) + { + if (specDraftNMax >= 1) return specDraftNMax; + var s = Environment.GetEnvironmentVariable("SHARPI_MTP_DRAFT_N"); + if (s is not null && int.TryParse(s, out var v) && v >= 1) return v; + return 1; } /// /// Decode up to tokens. Calls /// for every accepted or correction token. Stops when a token in /// is generated (and does NOT emit the stop token). - /// Dispatches to the batched N=2 verify path () when the - /// underlying forward pass implements ; - /// otherwise falls back to the sequential N=1 algorithm. + /// Dispatches to the folded k-token batched verify path (, + /// issues #30/#207) when the underlying forward pass implements + /// ; otherwise falls back to the sequential + /// N=1 algorithm. /// /// Min draft probability for probabilistic accept under MTP /// verification (llama.cpp's --spec-draft-p-min, issue #38). 1.0 @@ -128,26 +171,30 @@ public void Initialize(int nextPosition, ReadOnlySpan lastMainLogits) /// p ∈ (0, 1) also accepts when softmax(target)[draft] >= p; the /// emitted token then equals the draft, which can differ from baseline greedy. /// 0.0 or negative is treated as 1.0. + /// MTP draft-chain length: proposed tokens per step (the verify + /// batch is 1 + draftN wide — the certain token rides in the batch). Clamped + /// per step to , the remaining token + /// budget, and the context window. 1 reproduces the legacy N=2 behaviour. public void Decode(int maxTokens, ReadOnlySpan stopTokenIds, Action emitToken, float pMin = 1f, + int draftN = 1, CancellationToken ct = default) { // Treat 0 and negative as the default (argmax-match) for back-compat with // callers that left SpecDraftPMin at the previous default of 0f. if (pMin <= 0f) pMin = 1f; - // SupportsBatchVerify is evaluated once here and we commit to one path for the - // whole loop. Issue #130: it returns false when the KV cache is SnapKV-compacted, - // so an evicted cache routes to the eviction-safe sequential path. This relies on - // eviction being prefill-only (all ApplySnapKvEviction call sites are in Prefill) — - // if decode-time/rolling eviction is ever added, BatchForward2's precondition would - // need to handle a mid-loop compaction rather than this once-per-Decode gate. - bool batched = _fwd.SupportsBatchVerify; + draftN = Math.Max(1, draftN); + // Initial dispatch; DecodeBatched additionally re-checks the capability per + // step (#130: it flips false when the KV cache is SnapKV-compacted) and + // hands off to the sequential loop if it ever turns off mid-decode. + bool batched = _fwd.SupportsBatchVerify + && Environment.GetEnvironmentVariable("SHARPI_DISABLE_BATCH_VERIFY") != "1"; if (Environment.GetEnvironmentVariable("SHARPI_TRACE_MTP") == "1") Console.Error.WriteLine( - $"[mtp] batched-verify {(batched ? "ON" : "OFF (cache compacted / unsupported config)")}"); + $"[mtp] batched-verify {(batched ? $"ON (draftN={draftN}, maxBatch={_fwd.MaxBatchVerifyTokens})" : "OFF (cache compacted / unsupported config / disabled)")}"); if (batched) { - DecodeBatched(maxTokens, stopTokenIds, emitToken, pMin, ct); + DecodeBatched(maxTokens, stopTokenIds, emitToken, pMin, draftN, ct); return; } DecodeSequential(maxTokens, stopTokenIds, emitToken, pMin, ct); @@ -237,120 +284,167 @@ private void DecodeSequential(int maxTokens, ReadOnlySpan stopTokenIds, Act } /// - /// Batched N=2 verify (issue #30). Per iter: - /// t1 = argmax(saved_main_logits); emit t1 - /// t2_draft = argmax(MtpForward(t1, P, h_prev)) - /// BatchForward2(t1, t2_draft, P) → (l@P+1, l@P+2), LastHiddenT1 = h@P - /// if argmax(l@P+1) == t2_draft: accept; saved_main_logits = l@P+2; emit t2_draft - /// else: reject; RestoreBatchSnapshot(P+1); MtpTruncateTo(P+1); - /// Forward(t2_target, P+1) → l@P+2; saved_main_logits = l@P+2; emit t2_target - /// MtpForward(t2_emitted, P+1, LastHiddenT1) # commit MTP at P+1 - /// _savedHidden = LastHidden # h@P+1 ready for next iter + /// Folded k-token batched verify (issues #30 / #207 goal 4 — the llama.cpp / + /// SpeculativeDecoder formulation). Per step, with k = 1 + min(draftN, budget): + /// + /// t1 = argmax(saved_main_logits); emit t1 // CERTAIN token + /// tokens = [t1, d1..d_{k-1}] d_i chained via MtpForward // d1 from h@P-1, + /// // d_{i>1} from MtpLastHidden + /// batch = BatchVerify(tokens, P) // ONE pass; batch[i] = logits after tokens[i] + /// a = count of leading d_i with Accept(d_i, batch[i-1]) + /// if a < k-1: RestoreBatchSnapshot(P + 1 + a) // GDN ring + KV rewind + /// MtpTruncateTo(P+1); re-run MtpForward for d_1..d_a with trunk hiddens (HiddenAt) + /// emit d_1..d_a + /// saved_main_logits = batch[a] // the correction (or next certain token) + /// _savedHidden = HiddenAt(P + a); _nextPos = P + 1 + a + /// + /// The rejected-path correction token is argmax(batch[a]) — it rides into the NEXT + /// step's batch as t1, so the trunk runs exactly one batched pass per step and no + /// separate correction forward exists (the #207 lesson: a per-step commit forward + /// erases the batching win). At draftN=1 this emits exactly the legacy N=2 + /// sequence: every emitted token is argmax of main logits when pMin == 1. /// private void DecodeBatched(int maxTokens, ReadOnlySpan stopTokenIds, Action emitToken, - float pMin, CancellationToken ct) + float pMin, int draftN, CancellationToken ct) { bool trace = Environment.GetEnvironmentVariable("SHARPI_TRACE_MTP") == "1"; - // Per-iter copy of h@P (LastHiddenT1) so MtpForward's scratch can't - // disturb the slice between batched verify and MTP commit. Sized to the - // embedding dim (length of _savedHidden); allocated once outside the loop. - Span hAtPCopy = new float[_savedHidden.Length]; + // Draft-chain hidden scratch: MtpLastHidden points into pass scratch that the + // next MtpForward overwrites, so the chain copies it out between calls. + var chainHidden = new float[_savedHidden.Length]; int generated = 0; while (generated < maxTokens) { ct.ThrowIfCancellationRequested(); + int P = _nextPos; + // Step sizing: the batch holds t1 + the draft chain, bounded by the + // remaining token budget (a step never verifies tokens it can't emit), + // the pass's snapshot-ring capacity, and the context window. + int kEff = Math.Min(1 + draftN, maxTokens - generated); + kEff = Math.Min(kEff, _fwd.MaxBatchVerifyTokens); + kEff = Math.Min(kEff, _fwd.MaxSeqLen - P); + + // Per-step capability re-check (#130: SnapKV compaction turns it off). + // kEff < 2 also routes here: a 1-token "batch" is just a plain decode + // step, which the sequential loop already implements. + if (kEff < 2 || !_fwd.SupportsBatchVerify) + { + DecodeSequential(maxTokens - generated, stopTokenIds, emitToken, pMin, ct); + return; + } + // ── Token 1: argmax of last main logits (greedy correctness) ── int t1 = ArgMax(_savedMainLogits); if (IsStop(t1, stopTokenIds)) return; emitToken(t1); generated++; - if (generated >= maxTokens) return; - - // ── MTP draft for position P+1 ──────────────────────────── - int P = _nextPos; - ReadOnlySpan mtpLogits = _fwd.MtpForward(t1, P, _savedHidden); - int t2Draft = ArgMax(mtpLogits); - float t2DraftLogit = mtpLogits[t2Draft]; - _totalDraftsEmitted++; - // ── Batched main verify (advances main caches through t1 + t2_draft) ─ - _fwd.BatchForward2(t1, t2Draft, P, - out ReadOnlySpan l_atPplus1, - out ReadOnlySpan l_atPplus2); - int t2Target = ArgMax(l_atPplus1); - - // Snapshot LastHiddenT1 — h@P (t1's pre-output-norm hidden). Needed for - // the MTP commit at P+1 regardless of accept/reject. We copy out now - // because a subsequent Forward (reject path) doesn't overwrite it but - // the value's tied to the batched forward's scratch. - var hAtP = _fwd.LastHiddenT1; - if (hAtP.Length != _savedHidden.Length) - throw new InvalidOperationException( - $"LastHiddenT1 length {hAtP.Length} != EmbeddingDim {_savedHidden.Length}."); - // Use a local copy so MtpForward (which writes its own scratch) can't - // disturb the slice between now and the commit call. - hAtP.CopyTo(hAtPCopy); - - ReadOnlySpan mainLogitsAfter; - int t2; - bool accept = AcceptDraft(t2Draft, t2Target, l_atPplus1, pMin, - out float draftProbInMain); - if (accept) + // ── MTP draft chain: kEff−1 proposals ───────────────────── + // d1 is conditioned on the trunk's h@P-1; deeper drafts self-chain on + // the MTP block's own output hidden (NEXTN/EAGLE-style — verification + // corrects any quality loss, so this only affects acceptance rate). + _phaseSw.Restart(); + var tokens = new int[kEff]; + tokens[0] = t1; + ReadOnlySpan prevH = _savedHidden; + for (int i = 1; i < kEff; i++) { - _totalDraftsAccepted++; - // Emit the draft (which equals argmax on argmax-match, or differs - // from argmax on a prob-only accept under pMin < 1.0). The batched - // forward has already advanced both caches through t2_draft, so - // l_at_P+2 is the next iter's saved_main_logits. - t2 = t2Draft; - mainLogitsAfter = l_atPplus2; + ReadOnlySpan mtpLogits = _fwd.MtpForward(tokens[i - 1], P + i - 1, prevH); + tokens[i] = ArgMax(mtpLogits); + if (i < kEff - 1) + { + var selfH = _fwd.MtpLastHidden; + if (selfH.Length != chainHidden.Length) + throw new InvalidOperationException( + $"MtpLastHidden length {selfH.Length} != EmbeddingDim {chainHidden.Length}; " + + "the forward pass must capture the MTP block output for chained drafting."); + selfH.CopyTo(chainHidden); + prevH = chainHidden; + } } - else + _totalDraftsEmitted += kEff - 1; + DraftMs += _phaseSw.Elapsed.TotalMilliseconds; + + // ── ONE batched main verify over tokens[0..kEff) ────────── + _phaseSw.Restart(); + float[][] batch = _fwd.BatchVerify(tokens, P); + VerifyMs += _phaseSw.Elapsed.TotalMilliseconds; + + // ── Greedy accept: count leading agreeing drafts ────────── + int a = 0; + for (int i = 1; i < kEff; i++) { - t2 = t2Target; - _fwd.RestoreBatchSnapshot(P + 1); - _fwd.MtpTruncateTo(P + 1); - mainLogitsAfter = _fwd.Forward(t2Target, P + 1); + int target = ArgMax(batch[i - 1]); + if (AcceptDraft(tokens[i], target, batch[i - 1], pMin, out _)) a++; + else break; } + _totalDraftsAccepted += a; + int newPos = P + 1 + a; if (trace) - { - float draftLogitInMain = l_atPplus1[t2Draft]; - float mainTopLogit = l_atPplus1[t2Target]; Console.Error.WriteLine( - $"[mtp-batch] P={P} t1={t1} t2_draft={t2Draft}(draft_logit={t2DraftLogit:F3}, main_logit_at_draft={draftLogitInMain:F3}, p={draftProbInMain:F3}) " + - $"t2_target={t2Target}(main_logit={mainTopLogit:F3}) " + - $"{(accept ? "ACCEPT" : "reject")}"); - } - - if (IsStop(t2, stopTokenIds)) - { - // Keep state consistent for a follow-up call. - _fwd.LastHidden.CopyTo(_savedHidden); - mainLogitsAfter.CopyTo(_savedMainLogits); - _nextPos = P + 2; - return; - } - emitToken(t2); generated++; - if (generated >= maxTokens) + $"[mtp-batch] P={P} k={kEff} t1={t1} " + + $"drafts=[{string.Join(",", tokens[1..])}] accepted={a}/{kEff - 1}" + + (a < kEff - 1 ? $" correction={ArgMax(batch[a])}" : "")); + + _phaseSw.Restart(); + // ── Roll back the rejected tail (no-op when fully accepted) ── + if (newPos < P + kEff) + _fwd.RestoreBatchSnapshot(newPos); + CommitMs += _phaseSw.Elapsed.TotalMilliseconds; + + // ── MTP KV refresh ──────────────────────────────────────── + // The chain wrote MTP K/V at P (trunk-quality) and P+1..P+kEff-2 + // (self-hidden quality). Rewind to P+1 and rewrite the ACCEPTED + // positions with the trunk hiddens the verify produced, so future + // drafts attend over trunk-quality K/V — the same per-position + // contract the legacy N=2 commit kept. The rejected-path correction's + // MTP entry is written by the NEXT step's first chain call. + _phaseSw.Restart(); + _fwd.MtpTruncateTo(P + 1); + for (int i = 1; i <= a; i++) + _ = _fwd.MtpForward(tokens[i], P + i, HiddenAtChecked(P + i - 1)); + DraftMs += _phaseSw.Elapsed.TotalMilliseconds; + + // ── Emit accepted drafts ────────────────────────────────── + for (int i = 1; i <= a; i++) { - _fwd.LastHidden.CopyTo(_savedHidden); - mainLogitsAfter.CopyTo(_savedMainLogits); - _nextPos = P + 2; - return; + if (IsStop(tokens[i], stopTokenIds)) + { + // Don't emit the stop; sync state to "after tokens[i-1]" for a + // consistent follow-up. (The trunk cache may sit past _nextPos — + // same benign overshoot class as the legacy stop-at-t2 path.) + batch[i - 1].CopyTo(_savedMainLogits, 0); + HiddenAtChecked(P + i - 1).CopyTo(_savedHidden); + _nextPos = P + i; + return; + } + emitToken(tokens[i]); generated++; } - // ── MTP commit at P+1 ──────────────────────────────────── - // prevHidden = h@P (the hidden that came out of the trunk for t1). - _ = _fwd.MtpForward(t2, P + 1, hAtPCopy); - - // Update saved state for next iter. _fwd.LastHidden = h@P+1. - _fwd.LastHidden.CopyTo(_savedHidden); - mainLogitsAfter.CopyTo(_savedMainLogits); - _nextPos = P + 2; + // ── Saved state for the next step ───────────────────────── + // batch[a] predicts position newPos: it is the next certain token — + // the correction on a reject, the chain continuation on full accept. + batch[a].CopyTo(_savedMainLogits, 0); + HiddenAtChecked(newPos - 1).CopyTo(_savedHidden); + _nextPos = newPos; } } + /// + /// with an invariant check: the batched verify + /// must have populated the hidden-history slot for every batch position. + /// + private ReadOnlySpan HiddenAtChecked(int position) + { + var h = _fwd.HiddenAt(position); + if (h.Length != _savedHidden.Length) + throw new InvalidOperationException( + $"HiddenAt({position}) returned {h.Length} floats, expected {_savedHidden.Length}. " + + "BatchVerify must write the per-position trunk hiddens into the MTP hidden " + + "history (the PrefillMtp contract, issue #33)."); + return h; + } + /// /// MTP draft acceptance check (issue #38). Always accepts when the draft is the /// verifier's argmax. Under < 1.0, ALSO accepts when diff --git a/tests/SharpInference.Tests.ForwardPass/CudaMtpBatchVerifyTests.cs b/tests/SharpInference.Tests.ForwardPass/CudaMtpBatchVerifyTests.cs new file mode 100644 index 0000000..e576a6f --- /dev/null +++ b/tests/SharpInference.Tests.ForwardPass/CudaMtpBatchVerifyTests.cs @@ -0,0 +1,243 @@ +using SharpInference.Core; +using SharpInference.Cuda; +using SharpInference.Engine; + +namespace SharpInference.Tests.ForwardPass; + +/// +/// CUDA-hybrid coverage for the k-token MTP batched verify (issues #30 / #207 +/// goal 4) on the qwen35 27B-MTP model: +/// +/// pass-level: per-position +/// logits vs k sequential Forward calls (argmax + maxAbs — the +/// BatchForward2 precision class: the CPU mmap FFN layers run MatVec2In in +/// the batch vs MatVecDual sequentially, so bit-equality is not expected); +/// rollback: verify junk drafts, +/// to an intermediate position, and confirm the continued trajectory matches +/// the pure-sequential one — this exercises the DEVICE GDN snapshot ring +/// (pre-#207 the GPU-GDN reject path restored stale host state and the +/// rejected draft's rank-1 update stayed baked into the recurrence); +/// e2e: batched greedy decode is coherent and the +/// chained drafts actually get accepted. +/// +/// Skipped silently when CUDA is unavailable or the 27B-MTP GGUF isn't on disk. +/// +public sealed class CudaMtpBatchVerifyTests +{ + private static CudaBackend? TryCreate() + { + if (!CudaBackend.IsAvailable()) return null; + try { return CudaBackend.Create(); } catch { return null; } + } + + private static string? FindMtpModelPath() + { + string[] absoluteCandidates = + { + @"C:\p\sharpi\models\Qwen3.6-27B-MTP-Q4_K_M.gguf", + @"E:\models\Qwen3.6-27B-MTP-Q4_K_M.gguf", + }; + foreach (var p in absoluteCandidates) + if (File.Exists(p)) return p; + + var dir = Directory.GetCurrentDirectory(); + for (int i = 0; i < 8; i++) + { + var p = Path.Combine(dir, "models", "Qwen3.6-27B-MTP-Q4_K_M.gguf"); + if (File.Exists(p)) return p; + var parent = Directory.GetParent(dir); + if (parent is null) break; + dir = parent.FullName; + } + return null; + } + + /// + /// Constructs the pass with a 4-token snapshot ring (the production default is + /// 2 — the measured k=2 optimum — but these tests exercise k=4 batches). + /// SHARPI_MTP_BATCH_MAX is instance-resolved at construction, so the env scope + /// only needs to cover the ctor. + /// + private static CudaHybridGdnForwardPass CreatePass(GgufModel model, CudaBackend gpu, + ModelHyperparams hp) + { + var placement = new LayerPlacement( + GpuLayers: hp.NumLayers, + CpuLayers: 0, + GpuWeightBytes: 0, + GpuKvBytes: 0, + RecommendedCtxSize: Math.Min(hp.ContextLength, 2048)); + var prev = Environment.GetEnvironmentVariable("SHARPI_MTP_BATCH_MAX"); + Environment.SetEnvironmentVariable("SHARPI_MTP_BATCH_MAX", "4"); + try + { + return new CudaHybridGdnForwardPass(model, gpu, hp, placement); + } + finally + { + Environment.SetEnvironmentVariable("SHARPI_MTP_BATCH_MAX", prev); + } + } + + private static int ArgMax(float[] logits) + { + int best = 0; + for (int i = 1; i < logits.Length; i++) + if (logits[i] > logits[best]) best = i; + return best; + } + + private static float MaxAbsDiff(float[] a, float[] b) + { + float m = 0; + for (int i = 0; i < a.Length; i++) + { + float d = MathF.Abs(a[i] - b[i]); + if (d > m) m = d; + } + return m; + } + + [Fact] + public void BatchVerify_MatchesSequentialForward_PerPosition() + { + using var gpu = TryCreate(); + if (gpu is null) return; + var path = FindMtpModelPath(); + if (path is null) return; + + using var model = GgufModel.Open(path); + var hp = ModelHyperparams.FromGgufMetadata(model.Metadata, model); + Assert.True(hp.NumMtpLayers > 0); + var tokenizer = GgufTokenizer.FromGgufModel(model); + using var fwd = CreatePass(model, gpu, hp); + Assert.True(fwd.HasMtpHead); + Assert.True(fwd.SupportsBatchVerify, + "27B-MTP without SnapKV must support batched verify (GDN ring must have allocated)."); + Assert.True(fwd.MaxBatchVerifyTokens >= 4, + $"Default ring should allow ≥4-token batches; got {fwd.MaxBatchVerifyTokens}."); + + var prompt = tokenizer.Encode("The quick brown fox jumps over the lazy dog and then").ToArray(); + int P = prompt.Length; + + // Reference: greedy continuation via sequential Forward (k = 4 tokens). + var prefillLogits = fwd.Prefill(prompt).ToArray(); + const int K = 4; + var contTokens = new int[K]; + var seqLogits = new float[K][]; + contTokens[0] = ArgMax(prefillLogits); + for (int i = 0; i < K; i++) + { + seqLogits[i] = fwd.Forward(contTokens[i], P + i).ToArray(); + if (i + 1 < K) contTokens[i + 1] = ArgMax(seqLogits[i]); + } + + // Same tokens through one packed BatchVerify on a freshly prefilled state. + fwd.ResetCache(); + _ = fwd.Prefill(prompt); + var batch = fwd.BatchVerify(contTokens, P); + + Assert.Equal(K, batch.Length); + for (int i = 0; i < K; i++) + { + Assert.Equal(ArgMax(seqLogits[i]), ArgMax(batch[i])); + float maxAbs = MaxAbsDiff(seqLogits[i], batch[i]); + Assert.True(maxAbs < 0.25f, + $"BatchVerify logits at position {P + i} diverge from sequential Forward " + + $"(maxAbs={maxAbs:F4}) — beyond the MatVec2In-vs-MatVecDual noise envelope; " + + "suspect a position/state mismatch in the batched trunk."); + } + } + + [Fact] + public void BatchVerify_Rollback_RestoresDeviceGdnState() + { + using var gpu = TryCreate(); + if (gpu is null) return; + var path = FindMtpModelPath(); + if (path is null) return; + + using var model = GgufModel.Open(path); + var hp = ModelHyperparams.FromGgufMetadata(model.Metadata, model); + var tokenizer = GgufTokenizer.FromGgufModel(model); + using var fwd = CreatePass(model, gpu, hp); + if (!fwd.SupportsBatchVerify || fwd.MaxBatchVerifyTokens < 4) return; + + var prompt = tokenizer.Encode("Water boils at one hundred degrees and freezes at").ToArray(); + int P = prompt.Length; + + // Reference trajectory: g0 then two more greedy tokens, fully sequential. + var prefillLogits = fwd.Prefill(prompt).ToArray(); + int g0 = ArgMax(prefillLogits); + var l1 = fwd.Forward(g0, P).ToArray(); + int g1 = ArgMax(l1); + var l2 = fwd.Forward(g1, P + 1).ToArray(); + int g2 = ArgMax(l2); + var l3 = fwd.Forward(g2, P + 2).ToArray(); + + // Fresh state → verify g0 + three JUNK drafts, roll back to P+1 (only g0 + // kept), then replay the true continuation sequentially. If the device GDN + // ring restore is broken, the junk tokens' rank-1 recurrence updates stay + // baked in and the replayed logits drift far beyond kernel noise. + fwd.ResetCache(); + _ = fwd.Prefill(prompt); + int junk = (g1 + 7) % hp.VocabSize; + var batch = fwd.BatchVerify([g0, junk, junk, junk], P); + Assert.Equal(ArgMax(l1), ArgMax(batch[0])); + + fwd.RestoreBatchSnapshot(P + 1); + var r2 = fwd.Forward(g1, P + 1).ToArray(); + Assert.Equal(ArgMax(l2), ArgMax(r2)); + float d2 = MaxAbsDiff(l2, r2); + Assert.True(d2 < 0.25f, + $"Post-rollback Forward at P+1 diverges from the sequential trajectory " + + $"(maxAbs={d2:F4}) — the GDN snapshot ring did not restore the device state."); + + var r3 = fwd.Forward(g2, P + 2).ToArray(); + Assert.Equal(ArgMax(l3), ArgMax(r3)); + float d3 = MaxAbsDiff(l3, r3); + Assert.True(d3 < 0.25f, + $"Second post-rollback Forward diverges (maxAbs={d3:F4}); residual junk-draft " + + "contamination in the GDN recurrence."); + } + + [Fact] + public void MtpDecoder_BatchedGreedy_CoherentWithAcceptedDrafts() + { + using var gpu = TryCreate(); + if (gpu is null) return; + var path = FindMtpModelPath(); + if (path is null) return; + + using var model = GgufModel.Open(path); + var hp = ModelHyperparams.FromGgufMetadata(model.Metadata, model); + var tokenizer = GgufTokenizer.FromGgufModel(model); + using var fwd = CreatePass(model, gpu, hp); + if (!fwd.SupportsBatchVerify || fwd.MaxBatchVerifyTokens < 4) return; + + var prompt = tokenizer.Encode( + "Write a Python function that sorts a list using the quicksort algorithm:").ToArray(); + var logits = fwd.Prefill(prompt); + + var decoder = new MtpDecoder(fwd); + decoder.Initialize(prompt.Length, logits); + fwd.PrefillMtp(prompt); + + var produced = new List(24); + int[] stops = tokenizer.EogTokenIds.ToArray(); + decoder.Decode(24, stops, produced.Add, pMin: 1f, draftN: 3); + + Assert.True(produced.Count >= 8, + $"Batched MTP decode stopped after {produced.Count} tokens — unexpectedly early EOS."); + Assert.True(produced.Distinct().Count() >= 2, + $"Degenerate decode: [{string.Join(",", produced)}]"); + // Chained drafting must actually land accepts (the 27B head accepts 95-100% + // at depth 1; depth-3 chains compound but anything below ~30% means the + // chain/self-hidden wiring is broken even if output stays correct). + Assert.True(decoder.TotalDraftsEmitted > 0); + Assert.True(decoder.AcceptanceRate >= 0.3f, + $"Chained-draft acceptance {decoder.AcceptanceRate:P0} " + + $"({decoder.TotalDraftsAccepted}/{decoder.TotalDraftsEmitted}) is far below the " + + "depth-1 reference (95-100%); MtpLastHidden chaining or the MTP KV refresh is off."); + } +} diff --git a/tests/SharpInference.Tests.ForwardPass/MtpDecoderBatchVerifyTests.cs b/tests/SharpInference.Tests.ForwardPass/MtpDecoderBatchVerifyTests.cs new file mode 100644 index 0000000..8bd62f4 --- /dev/null +++ b/tests/SharpInference.Tests.ForwardPass/MtpDecoderBatchVerifyTests.cs @@ -0,0 +1,289 @@ +using SharpInference.Core; +using SharpInference.Engine; + +namespace SharpInference.Tests.ForwardPass; + +/// +/// Scripted-pass coverage for the folded k-token MTP batched verify loop +/// (, issues #30 / #207 goal 4). A deterministic fake +/// scripts the trunk's greedy chain and the MTP head's +/// drafts (with injectable disagreement positions), and asserts the decoder's +/// cache-length contracts on every call — so accept/reject bookkeeping bugs fail +/// loudly here without any model file. +/// +/// Invariant under pMin = 1 (argmax-match accept): the emitted sequence equals +/// the trunk's greedy chain REGARDLESS of what the MTP head drafts — rejections +/// only cost speed. Every test asserts that first. +/// +public sealed class MtpDecoderBatchVerifyTests +{ + private const int Vocab = 16; + private const int EmbDim = 4; + + /// Deterministic trunk chain: next token after t is (t + 1) % Vocab. + private static int NextTarget(int t) => (t + 1) % Vocab; + + private static float[] Logits(int next) + { + var l = new float[Vocab]; + l[next] = 1f; + return l; + } + + /// + /// Scripted MTP-capable pass. Trunk chain is ; the MTP + /// head drafts the same chain except at positions listed in + /// (where it proposes t+2, which the verify + /// rejects). Asserts the MtpDecoder's length contracts on every call. + /// + private sealed class ScriptedMtpPass : IForwardPass + { + public bool BatchVerifySupported = true; + public int MaxBatch = 8; + public HashSet RejectAtPositions = new(); + + public int MainLen; // trunk cache length + public int MtpLen; // MTP KV cache length + public int HistLen; // hidden-history length + + public int ForwardCalls; + public int BatchVerifyCalls; + public readonly List BatchVerifyKs = new(); + public readonly List RestoreCalls = new(); + // (position, prevHiddenMarker) for every MtpForward — markers: trunk hidden + // at pos p encodes p; MTP self-hidden at draft pos p encodes -(p + 1000). + public readonly List<(int Pos, float PrevMarker)> MtpCalls = new(); + + private readonly float[] _lastHidden = new float[EmbDim]; + private readonly float[] _mtpSelfHidden = new float[EmbDim]; + private readonly Dictionary _hist = new(); + + public ScriptedMtpPass(int prefillLen) + { + MainLen = prefillLen; + MtpLen = prefillLen; + HistLen = prefillLen; + _lastHidden[0] = prefillLen - 1; // trunk hidden marker for h@prefillLen-1 + } + + public int VocabSize => Vocab; + public int MaxSeqLen => 4096; + public bool HasMtpHead => true; + public bool SupportsBatchVerify => BatchVerifySupported; + public int MaxBatchVerifyTokens => MaxBatch; + public ReadOnlySpan LastHidden => _lastHidden; + public ReadOnlySpan MtpLastHidden => _mtpSelfHidden; + + public ReadOnlySpan HiddenAt(int position) + { + if (position < 0 || position >= HistLen) return default; + return _hist.TryGetValue(position, out var h) ? h : default; + } + + private void WriteHist(int position) + { + var h = new float[EmbDim]; + h[0] = position; + _hist[position] = h; + if (HistLen < position + 1) HistLen = position + 1; + } + + public ReadOnlySpan Forward(int token, int position) + { + Assert.Equal(MainLen, position); + ForwardCalls++; + MainLen++; + WriteHist(position); + _lastHidden[0] = position; + return Logits(NextTarget(token)); + } + + public float[][] BatchVerify(int[] tokens, int startPos) + { + Assert.True(BatchVerifySupported, "BatchVerify called while unsupported"); + Assert.True(tokens.Length <= MaxBatch, + $"BatchVerify k={tokens.Length} exceeds MaxBatchVerifyTokens={MaxBatch}"); + Assert.Equal(MainLen, startPos); + BatchVerifyCalls++; + BatchVerifyKs.Add(tokens.Length); + var result = new float[tokens.Length][]; + for (int i = 0; i < tokens.Length; i++) + { + result[i] = Logits(NextTarget(tokens[i])); + WriteHist(startPos + i); + } + MainLen += tokens.Length; + _lastHidden[0] = startPos + tokens.Length - 1; + return result; + } + + public void RestoreBatchSnapshot(int lengthAfter) + { + RestoreCalls.Add(lengthAfter); + Assert.True(lengthAfter < MainLen, + "RestoreBatchSnapshot must rewind, not extend, the trunk cache."); + MainLen = lengthAfter; + if (MtpLen > lengthAfter) MtpLen = lengthAfter; + if (HistLen > lengthAfter) HistLen = lengthAfter; + } + + public ReadOnlySpan MtpForward(int token, int position, ReadOnlySpan prevHidden) + { + Assert.Equal(MtpLen, position); + MtpCalls.Add((position, prevHidden[0])); + MtpLen++; + _mtpSelfHidden[0] = -(position + 1000); + int draft = RejectAtPositions.Contains(position + 1) + ? (token + 2) % Vocab + : NextTarget(token); + return Logits(draft); + } + + public void MtpTruncateTo(int length) + { + if (MtpLen > length) MtpLen = length; + } + + public ReadOnlySpan Prefill(IReadOnlyList tokens, int startPos = 0) => new float[Vocab]; + public void TruncateTo(int length) { } + public void ResetCache() { } + public void Dispose() { } + } + + private static List Decode(ScriptedMtpPass pass, int prefillLen, int firstToken, + int maxTokens, int draftN, int[]? stops = null) + { + var dec = new MtpDecoder(pass); + dec.Initialize(prefillLen, Logits(firstToken)); + var emitted = new List(); + dec.Decode(maxTokens, stops ?? [], emitted.Add, pMin: 1f, draftN: draftN); + return emitted; + } + + private static List TargetChain(int firstToken, int count) + { + var chain = new List(count) { firstToken }; + while (chain.Count < count) chain.Add(NextTarget(chain[^1])); + return chain; + } + + [Fact] + public void AllAccept_EmitsTargetChain_OneBatchPerStep_NoForward() + { + var pass = new ScriptedMtpPass(prefillLen: 10); + var emitted = Decode(pass, 10, firstToken: 3, maxTokens: 12, draftN: 3); + + Assert.Equal(TargetChain(3, 12), emitted); + // 12 tokens at 4 tokens/step (all drafts accepted) = 3 batched passes. + Assert.Equal(3, pass.BatchVerifyCalls); + Assert.All(pass.BatchVerifyKs, k => Assert.Equal(4, k)); + Assert.Empty(pass.RestoreCalls); + Assert.Equal(0, pass.ForwardCalls); // the fold: zero per-token trunk forwards + } + + [Fact] + public void RejectEveryDraft_StillEmitsTargetChain_WithoutCorrectionForwards() + { + var pass = new ScriptedMtpPass(prefillLen: 10); + // Reject the FIRST draft of every step: drafts predict positions 11, 12, ... + for (int p = 0; p < 64; p++) pass.RejectAtPositions.Add(p); + + var emitted = Decode(pass, 10, firstToken: 3, maxTokens: 8, draftN: 3); + + // pMin=1 invariant: corrections ride into the next batch, output unchanged. + Assert.Equal(TargetChain(3, 8), emitted); + // Every step emits only t1 → one batch per emitted token, EXCEPT the final + // token: with 1 budget left the step degrades to the sequential tail, which + // emits the pending argmax without any trunk pass. Every batched step rolls + // back to P+1 (zero accepted drafts); the target never runs Forward. + Assert.Equal(7, pass.BatchVerifyCalls); + Assert.Equal(7, pass.RestoreCalls.Count); + Assert.Equal(0, pass.ForwardCalls); + } + + [Fact] + public void MidChainReject_RollsBackToAcceptedBoundary() + { + var pass = new ScriptedMtpPass(prefillLen: 10); + // Step 1 verifies positions [10, 14); drafts predict 11, 12, 13. + // Reject the draft predicting 13 → 2 drafts accepted, rollback to 13. + pass.RejectAtPositions.Add(13); + + var emitted = Decode(pass, 10, firstToken: 3, maxTokens: 6, draftN: 3); + + Assert.Equal(TargetChain(3, 6), emitted); + Assert.Equal(13, Assert.Single(pass.RestoreCalls)); + } + + [Fact] + public void StopToken_NotEmitted_DecodeEnds() + { + var pass = new ScriptedMtpPass(prefillLen: 10); + // Chain from 3: 3, 4, 5, 6, ... — stop at 6. + var emitted = Decode(pass, 10, firstToken: 3, maxTokens: 12, draftN: 3, stops: [6]); + Assert.Equal(new List { 3, 4, 5 }, emitted); + } + + [Fact] + public void CapabilityOff_FallsBackToSequential() + { + var pass = new ScriptedMtpPass(prefillLen: 10) { BatchVerifySupported = false }; + var emitted = Decode(pass, 10, firstToken: 3, maxTokens: 6, draftN: 3); + + Assert.Equal(TargetChain(3, 6), emitted); + Assert.Equal(0, pass.BatchVerifyCalls); + Assert.True(pass.ForwardCalls > 0, "sequential fallback must drive Forward"); + } + + [Fact] + public void MaxBatchVerifyTokens_ClampsTheChain() + { + var pass = new ScriptedMtpPass(prefillLen: 10) { MaxBatch = 2 }; + var emitted = Decode(pass, 10, firstToken: 3, maxTokens: 8, draftN: 5); + + Assert.Equal(TargetChain(3, 8), emitted); + // Every batch clamped to the ring capacity (the fake also hard-asserts ≤ MaxBatch). + Assert.All(pass.BatchVerifyKs, k => Assert.Equal(2, k)); + } + + [Fact] + public void DraftN1_MatchesLegacyTwoTokenShape() + { + var pass = new ScriptedMtpPass(prefillLen: 10); + var emitted = Decode(pass, 10, firstToken: 3, maxTokens: 8, draftN: 1); + + Assert.Equal(TargetChain(3, 8), emitted); + Assert.All(pass.BatchVerifyKs, k => Assert.Equal(2, k)); + } + + [Fact] + public void DraftChain_SelfChains_AndRefreshUsesTrunkHiddens() + { + var pass = new ScriptedMtpPass(prefillLen: 10); + _ = Decode(pass, 10, firstToken: 3, maxTokens: 4, draftN: 3); + + // One step: chain at positions 10, 11, 12 then (all-accept) refresh at 11, 12, 13. + Assert.Equal(6, pass.MtpCalls.Count); + + // Chain call 1 (pos 10): prevHidden = trunk h@9 (marker 9). + Assert.Equal((10, 9f), pass.MtpCalls[0]); + // Chain calls 2-3 self-chain on the MTP block hidden (marker -(pos-1 + 1000)). + Assert.Equal((11, -1010f), pass.MtpCalls[1]); + Assert.Equal((12, -1011f), pass.MtpCalls[2]); + // Refresh rewrites accepted positions with TRUNK hiddens h@10..h@12. + Assert.Equal((11, 10f), pass.MtpCalls[3]); + Assert.Equal((12, 11f), pass.MtpCalls[4]); + Assert.Equal((13, 12f), pass.MtpCalls[5]); + } + + [Fact] + public void RemainingBudget_ShrinksTheLastBatch() + { + var pass = new ScriptedMtpPass(prefillLen: 10); + // maxTokens=6 with k=4 steps: step 1 emits 4, step 2 may batch at most 2. + var emitted = Decode(pass, 10, firstToken: 3, maxTokens: 6, draftN: 3); + + Assert.Equal(TargetChain(3, 6), emitted); + Assert.Equal(new List { 4, 2 }, pass.BatchVerifyKs); + } +} From 112b6af1119c560f946e9504e63c10b0fff12ed0 Mon Sep 17 00:00:00 2001 From: Pekka Heikura Date: Thu, 11 Jun 2026 14:39:57 +0300 Subject: [PATCH 6/6] fix(engine): #208 review findings - accepted-stop state consistency, dense SnapKV verify gate, matvec small-N threshold, CLI spec window guards - MtpDecoder: an accepted STOP draft now clamps acceptance before rollback so trunk KV / GDN state / MTP KV / hidden history all end exactly at _nextPos (previously stranded at P+1+a past the stop); one-time notice when draftN exceeds the snapshot-ring capacity (the silent-clamp server case); greedy selection delegates to Sampler.Greedy (3 copies -> 1). - ForwardPass (CPU dense): SupportsBatchVerify gates on Length==LogicalLength - a SnapKV-compacted cache routed BatchVerify's TruncateTo(startPos) past the compacted slots (the #130 gate the CUDA/GDN passes already had). - CudaHybridGdnForwardPass: the _matVecBatchedOnly latch (leak window between _faulted=false and latch-clear) is replaced by a MatMulComputeBatchMinN=8 threshold inside GpuMatMulBatched - the dequant-GEMM/MMQ temps only amortize at prefill N; host _batchSnapshotBuf now allocated only for the SHARPI_CPU_GDN=1 trunk (~150 MB host saved per MTP load); _bvCap reset before reallocs. - HybridGdnForwardPass: null-after-free in the scratch/ring growers (OOM-path double-free); SHARPI_MTP_BATCH_MAX semantics live in one shared GdnStateCache.ResolveMtpBatchMax. - CLI: prompts that exceed the (possibly 4096-capped) draft KV ring now fail fast with an actionable error before prefill writes out of range; a --spec-draft-n-max beyond ring capacity warns and clamps instead of disabling MTP outright; maxNew can no longer go negative silently. - PromptLookupDraft: List.CopyTo for the proposal copy (PR #208 review). - IForwardPass: SupportsBatchVerify doc no longer routes readers to the retired BatchForward2 dispatch. - README: 27B-MTP rows updated to the #30/#207 results (10.4 t/s, 1.68x). Tests: 45 scripted + 26 model-gated green incl. MtpDecoder_GreedyParity_LlamaCpp. Co-Authored-By: Claude Fable 5 --- README.md | 16 +++-- src/SharpInference.Cli/RunCommand.cs | 66 +++++++++++++++---- src/SharpInference.Core/IForwardPass.cs | 19 +++--- .../CudaHybridGdnForwardPass.cs | 58 ++++++++-------- src/SharpInference.Engine/ForwardPass.cs | 15 ++++- src/SharpInference.Engine/GdnStateCache.cs | 14 ++++ .../HybridGdnForwardPass.cs | 34 +++++----- src/SharpInference.Engine/MtpDecoder.cs | 63 ++++++++---------- .../PromptLookupDraft.cs | 2 +- .../SpeculativeDecoder.cs | 25 ++----- .../MtpDecoderBatchVerifyTests.cs | 16 ++++- 11 files changed, 191 insertions(+), 137 deletions(-) diff --git a/README.md b/README.md index de1c6c7..43d4e1b 100644 --- a/README.md +++ b/README.md @@ -43,8 +43,8 @@ coherent (`scripts/bench-all.ps1`); top-1 parity vs llama.cpp b8585 verified on | Llama-4 Scout 17B-16E (MoE) | (same) | 61 GB | CUDA `-g -1` (hybrid) | 1.2 | 2.6 | 7 GPU + 41 CPU layers — model dwarfs the 12 GB card so CPU-only wins; per-expert SLRU streaming (#72/#77) still lifts both (not on bench machine) | | Qwen3.6-35B-A3B (GDN+MoE) | [unsloth](https://huggingface.co/unsloth/Qwen3.6-35B-A3B-GGUF) | 22 GB | CPU | 9.0 | 9.3 | hybrid GDN/attn, 256 experts / 8 active | | Qwen3.6-35B-A3B (GDN+MoE) | (same) | 22 GB | **CUDA** `-g -1` (hybrid) | **55.1** | **23.7** | 10 attn + 30 GDN on GPU; MoE auto-routed to CPU, shared expert on GPU overlapped with the routed loop. Fused GDN scan + batched-query SDPA (#114-B/#118), bit-identical, win grows with ctx. `SHARPI_CPU_MOE=0` forces on-GPU experts (#129 fused MoE-reduce kernel, +20% prefill) | -| Qwen3.6-27B-MTP (GDN) | [unsloth](https://huggingface.co/unsloth/Qwen3.6-27B-MTP-GGUF) | 16 GB | CPU `--no-thinking` | 3.2 | **3.8** | dense 27B GDN/attn + native MTP head; auto MTP self-spec (#25) at greedy + `--no-thinking`. 95% draft acceptance; batched N=2 verify (#30) over MTP-off | -| Qwen3.6-27B-MTP (GDN) | (same) | 16 GB | **CUDA** `-g -1 --no-thinking` (hybrid) | **9.3** | **7.2** | 20/64 dense FFN on GPU + GDN/attn KV resident, 44/64 FFN CPU mmap. 95% acceptance; batched trunk + on-GPU dense-FFN (#119/#121), bit-identical | +| Qwen3.6-27B-MTP (GDN) | [unsloth](https://huggingface.co/unsloth/Qwen3.6-27B-MTP-GGUF) | 16 GB | CPU `--no-thinking` | 3.0 | **3.6** | dense 27B GDN/attn + native MTP head; auto MTP self-spec (#25) at greedy + `--no-thinking`. 90% draft acceptance; folded k-token batched verify (#30/#207) — 1.2× over MTP-off (3.0) | +| Qwen3.6-27B-MTP (GDN) | (same) | 16 GB | **CUDA** `-g -1 --no-thinking` (hybrid) | **7.3** | **10.4** | 22/64 dense FFN on GPU + GDN/attn KV resident, 42/64 FFN CPU mmap. 90% acceptance; folded k-token batched verify + GDN snapshot ring (#30/#207) — **1.68× over MTP-off (6.4)**. Deeper chains: `--spec-draft-n-max N` + `SHARPI_MTP_BATCH_MAX=N+1` (~150 MiB VRAM/slot) | | Qwen3.6-27B-MTP (GDN) | (same) | 19 GB | CPU `--no-thinking` `Q5_K_M` | 2.8 | **3.5** | ~10% slower than Q4_K_M; 100% acceptance | | Qwen3.6-27B-MTP (GDN) | (same) | 19 GB | **CUDA** `-g -1 --no-thinking` `Q5_K_M` (hybrid) | 5.9 | **5.5** | 13/64 FFN on GPU, 51/64 CPU mmap. 98% acceptance; batched trunk (#119) bit-identical | | Qwen3.6-35B-A3B-MTP (GDN+MoE) | [unsloth](https://huggingface.co/unsloth/Qwen3.6-35B-A3B-MTP-GGUF) | 22 GB | CPU `--no-thinking` | 9.1 | **8.5** | GDN/attn + 256-expert MoE + MTP head (#44). 100% acceptance; MoE-MTP batched verify (#45) — routed experts sequential per token, so ~MTP-off parity | @@ -116,10 +116,14 @@ to ~5 t/s at 6K, so FastScan is ~1.9× decode there. Models with native MTP heads (Qwen3.6-27B-MTP, Qwen3.5/3.6 A3B-MTP, DeepSeek V3/R1) get self-speculative decoding with no separate draft model. It engages automatically when the pass reports `HasMtpHead`, sampling -is greedy (`--temp 0`), and thinking is off (`--no-thinking`); the CLI prints `MTP accept: N%`. Batched N=2 -verify (#30) is the default for dense MTP; MoE MTP also batches the trunk while routed experts run per token -(#45). CLI mirrors llama.cpp: `--spec-type`, `--spec-draft-n-max <1|2>`, `--spec-draft-p-min <0..1>` -(lossy probabilistic accept). `SHARPI_DISABLE_MTP=1` / `SHARPI_DISABLE_BATCH_VERIFY=1` are the off-switches. +is greedy (`--temp 0`), and thinking is off (`--no-thinking`); the CLI prints `MTP accept: N%`. The default is +a folded k-token batched verify (#30/#207): the certain token plus a chained draft sequence run through ONE +batched trunk pass per step, with rejections rolled back via a per-token GDN snapshot ring; a rejected draft's +correction rides into the next step's batch, so no per-step commit forward exists. MoE MTP batches the trunk +while routed experts run per token (#45). CLI mirrors llama.cpp: `--spec-type`, `--spec-draft-n-max ` +(drafts/step; default 1 = the measured optimum — deeper chains also need `SHARPI_MTP_BATCH_MAX>=N+1` ring +slots at ~150 MiB VRAM each), `--spec-draft-p-min <0..1>` (lossy probabilistic accept). +`SHARPI_DISABLE_MTP=1` / `SHARPI_DISABLE_BATCH_VERIFY=1` are the off-switches. ### Chat-continuation cache diff --git a/src/SharpInference.Cli/RunCommand.cs b/src/SharpInference.Cli/RunCommand.cs index 0c20587..616e598 100644 --- a/src/SharpInference.Cli/RunCommand.cs +++ b/src/SharpInference.Cli/RunCommand.cs @@ -768,6 +768,28 @@ protected override int Execute(CommandContext context, Settings settings, Cancel } } + /// + /// True when a prompt of tokens leaves no room to + /// speculate inside BOTH context windows (prompt + lookahead + 1 correction token). + /// Prints an actionable error: the typical trigger is the CUDA draft's 4096-token + /// KV ring cap when -c isn't pinned, where prefilling past the ring would + /// write K/V out of range and a tail prompt would silently emit zero tokens. + /// + private static bool SpecWindowExhausted(int promptTokens, + IForwardPass target, IForwardPass? draft, int lookahead) + { + int window = Math.Min(target.MaxSeqLen, draft?.MaxSeqLen ?? int.MaxValue); + if (promptTokens + lookahead + 1 < window) return false; + AnsiConsole.MarkupLine( + $"[red]Error:[/] prompt ({promptTokens} tokens) + lookahead ({lookahead}) does not fit the " + + $"speculative context window ({window} tokens" + + (draft is not null && draft.MaxSeqLen < target.MaxSeqLen + ? $", limited by the draft model's KV ring — pass -c to size it explicitly" + : "") + + "). Shorten the prompt, raise -c, or drop --draft-model/--draft-lookup."); + return true; + } + private static int RunSpeculativeSinglePrompt(Settings s, IForwardPass target, IForwardPass? draft, GgufTokenizer tok, SamplingParams sp) @@ -775,6 +797,13 @@ private static int RunSpeculativeSinglePrompt(Settings s, var prompt = FormatPrompt(s.Prompt!, s.SystemPrompt, enableThinking: !s_noThinking); var tokens = tok.Encode(prompt); + // The prompt must fit BOTH context windows BEFORE any prefill runs — the + // draft's ring may be much smaller than the target's (the CUDA spec path + // caps it at 4096 when -c isn't pinned), and a too-long prompt would write + // K/V past the ring's end during draft.Prefill, not merely cap generation. + if (SpecWindowExhausted(tokens.Count, target, draft, s.SpecLookahead)) + return 1; + if (!s.NoDisplayPrompt) Console.Write(s.Prompt); @@ -799,6 +828,7 @@ private static int RunSpeculativeSinglePrompt(Settings s, // Bound generation by BOTH context windows (the draft's may be smaller — the CUDA // spec path caps its KV ring), leaving lookahead headroom for the last spec step. + // The guard above ensures maxNew >= 1 here. int maxNew = Math.Min(sp.MaxNewTokens, Math.Min(target.MaxSeqLen, draft?.MaxSeqLen ?? int.MaxValue) - tokens.Count - s.SpecLookahead - 1); if (maxNew < sp.MaxNewTokens) @@ -848,6 +878,12 @@ private static int RunSpeculativeInteractive(Settings s, var prompt = FormatPrompt(input, s.SystemPrompt, enableThinking: !s_noThinking); var tokens = tok.Encode(prompt); + // Same pre-prefill window guard as the single-prompt runner: the draft + // ring may be smaller than the target's window, and prefilling past it + // writes K/V out of range rather than just capping generation. + if (SpecWindowExhausted(tokens.Count, target, draft, s.SpecLookahead)) + continue; + target.ResetCache(); draft?.ResetCache(); @@ -1006,24 +1042,30 @@ private static bool ResolveCliMtp(IForwardPass? mtpFwd, SamplingParams sp, bool if (mtpFwd is null || !mtpFwd.HasMtpHead) { rejectReason = "--spec-type mtp requires a model with an MTP head (nextn tensors)."; return false; } if (sp.Temperature > 0f) { rejectReason = "--spec-type mtp requires greedy sampling (--temp 0)."; return false; } if (!noThinking) { rejectReason = "--spec-type mtp requires --no-thinking (chat template must render with enable_thinking=false)."; return false; } - if (sp.SpecDraftNMax > maxDraftN) - { - rejectReason = $"--spec-draft-n-max={sp.SpecDraftNMax} exceeds this pass's batched-verify capacity " + - $"({maxDraftN} drafts/step); raise SHARPI_MTP_BATCH_MAX (snapshot-ring slots) to go deeper."; - return false; - } + WarnIfDraftNClamped(sp.SpecDraftNMax, maxDraftN); return true; default: // Auto - if (eligible && sp.SpecDraftNMax > maxDraftN) - { - rejectReason = $"--spec-draft-n-max={sp.SpecDraftNMax} exceeds this pass's batched-verify capacity " + - $"({maxDraftN} drafts/step); raise SHARPI_MTP_BATCH_MAX (snapshot-ring slots) to go deeper."; - return false; - } + if (eligible) + WarnIfDraftNClamped(sp.SpecDraftNMax, maxDraftN); return eligible; } } + /// + /// A draft chain deeper than the snapshot ring's capacity is CLAMPED, not rejected + /// (rejecting would disable MTP entirely and run SLOWER — the silent-baseline trap + /// the old SpecDraftNMax>1 throw existed to prevent). Warn so the user knows the + /// effective depth and the knob that raises it; MtpDecoder clamps per step. + /// + private static void WarnIfDraftNClamped(int requested, int maxDraftN) + { + if (requested > maxDraftN) + AnsiConsole.MarkupLine( + $"[yellow]Note:[/] --spec-draft-n-max={requested} exceeds the snapshot-ring capacity; " + + $"running {maxDraftN} draft(s)/step. Set SHARPI_MTP_BATCH_MAX={requested + 1} to go deeper " + + "(each ring slot costs ~150 MiB VRAM on 27B-class models)."); + } + // MTP self-speculative decode path. Reuses the same UTF-8 streaming + EmitToken // logic as the baseline DecodeLoop but drives token emission via MtpDecoder. // Requires --no-thinking, so no thinking-mode bookkeeping here. diff --git a/src/SharpInference.Core/IForwardPass.cs b/src/SharpInference.Core/IForwardPass.cs index f294ef5..0ba7885 100644 --- a/src/SharpInference.Core/IForwardPass.cs +++ b/src/SharpInference.Core/IForwardPass.cs @@ -152,15 +152,16 @@ ReadOnlySpan MtpForward(int token, int position, ReadOnlySpan prev void PrefillMtp(IReadOnlyList tokens, int startPos = 0) { } /// - /// True when this pass implements a batched verify path. Two consumers, two methods: - /// the MTP decoder dispatches to (two-token self-speculative - /// verify, issue #30 — implemented by the hybrid GDN passes), and the speculative - /// decoder dispatches to (k-token draft verification, - /// issue #207 — implemented by the rewindable dense passes). The consumers' own gates - /// keep the two method sets disjoint: the MTP decoder requires - /// (GDN hybrids only), the speculative decoder requires - /// (which GDN hybrids never report). A pass that - /// returns true here must implement whichever method its reachable consumer calls. + /// True when this pass implements . Two consumers share the + /// method: the speculative decoder (k-token draft verification, issue #207 — rewindable + /// dense passes; rollback via ) and the MTP decoder (k-token + /// self-speculative verify, issue #30 — hybrid GDN passes; rollback via + /// and the per-token GDN snapshot ring, batch size + /// capped by ). The consumers' own gates pick the + /// rollback mechanism: the MTP decoder requires (GDN hybrids + /// only), the speculative decoder requires (which + /// GDN hybrids never report). GDN passes flip this false while the KV cache is + /// SnapKV-compacted (issue #130) — consumers re-check per step. /// bool SupportsBatchVerify => false; diff --git a/src/SharpInference.Engine/CudaHybridGdnForwardPass.cs b/src/SharpInference.Engine/CudaHybridGdnForwardPass.cs index eb80737..7f5116c 100644 --- a/src/SharpInference.Engine/CudaHybridGdnForwardPass.cs +++ b/src/SharpInference.Engine/CudaHybridGdnForwardPass.cs @@ -421,19 +421,13 @@ public sealed unsafe class CudaHybridGdnForwardPass : IForwardPass // pinned so the capture rides the queued D2H stream (issue #30 draft chaining). private float* _mtpSelfHidden; - // Max tokens per BatchVerify call = ring slots + 1. SHARPI_MTP_BATCH_MAX in - // [2, 8] bounds the ring reservation. Default 2 (one slot): each slot costs - // ~149 MiB of VRAM that TryUploadDenseFfnLayers would otherwise fill with - // ~2 dense FFN layers, and the k=2 verify is the measured 27B optimum — - // deeper chains only pay once the CPU FFN amortizes more than pairwise - // (4-input MatVec follow-up). Instance-resolved so tests can override per - // construction. - private readonly int _mtpBatchMax = ResolveMtpBatchMax(); - private static int ResolveMtpBatchMax() - { - var s = Environment.GetEnvironmentVariable("SHARPI_MTP_BATCH_MAX"); - return s is not null && int.TryParse(s, out var v) ? Math.Clamp(v, 2, 8) : 2; - } + // Max tokens per BatchVerify call = ring slots + 1. Each slot costs ~149 MiB + // of VRAM that TryUploadDenseFfnLayers would otherwise fill with ~2 dense FFN + // layers, hence the conservative default; deeper chains only pay once the CPU + // FFN amortizes more than pairwise (4-input MatVec follow-up). Instance-resolved + // at construction so tests can override per instance; the knob semantics live + // in one place (GdnStateCache.ResolveMtpBatchMax) shared with the CPU pass. + private readonly int _mtpBatchMax = GdnStateCache.ResolveMtpBatchMax(); // Token-2 host FFN scratch (intermediate gate/up post-MatVec2In, pre-SiLuMul). private readonly float* _cpuFfnGateBuf2; private readonly float* _cpuFfnUpBuf2; @@ -1349,9 +1343,13 @@ void TraceVram(string label) _cpuFfnUpBuf2 = Alloc(_intermDim); } + // Host snapshot buffer for BatchForward2's between-token capture — only + // the SHARPI_CPU_GDN=1 debug trunk uses it (the default GPU trunk's + // batched-verify snapshots live in the device ring); skip the ~150 MB + // host allocation otherwise. long perLayerBytes = _gdnStateCache.LayerSnapshotBytes; long totalSnapBytes = perLayerBytes * _gdnStateCache.NumGdnLayers; - if (totalSnapBytes > 0) + if (totalSnapBytes > 0 && _cpuGdn) { _batchSnapshotBuf = (byte*)NativeMemory.Alloc((nuint)totalSnapBytes); _batchSnapshotCap = totalSnapBytes; @@ -3583,10 +3581,6 @@ public float[][] BatchVerify(int[] tokens, int startPos) // length counters still read startPos; fatal for this pass. _faulted = true; _batchSnapshotValid = false; - // Decode-sized batch: keep every batched matmul on the temp-free matvec - // re-stream (see GpuMatMulBatched). Cleared before returning; a mid-pass - // throw leaves it set, but _faulted already makes the pass unusable then. - _matVecBatchedOnly = true; var stream = _gpuStreamAll!; @@ -3718,7 +3712,6 @@ public float[][] BatchVerify(int[] tokens, int startPos) } } - _matVecBatchedOnly = false; _batchSnapshotValid = true; return result; } @@ -3737,6 +3730,8 @@ private void EnsureBatchVerifyScratch(int k) if (_gpuBvFfnAll is { } f) { _gpu.Free(f); _gpuBvFfnAll = null; } if (_bvNormHost != null) { CudaBackend.FreePinnedHost((nint)_bvNormHost); _bvNormHost = null; } if (_bvFfnHost != null) { CudaBackend.FreePinnedHost((nint)_bvFfnHost); _bvFfnHost = null; } + _bvCap = -1; // a mid-sequence alloc failure must not leave a stale cap + // matching a future k (early return on half-built scratch) long logitsTotal = (long)k * _hp.VocabSize; if (logitsTotal > int.MaxValue) throw new NotSupportedException( @@ -5071,14 +5066,16 @@ private void GpuMatMulBatched(Tensor outputAll, Tensor matrix, Tensor inputAll, // bit-parity oracle keeps the matvec path. Q4_K/Q6_K/Q5_K need 256-aligned cols; // Q8_0 needs 32 — true for every projection dim, but guarded so we fall back // (never throw) on an odd shape. - // Issue #30/#207 batched verify: the MMQ/dequant-GEMM compute path only - // amortizes its fixed per-call costs (whole-weight dequant to an fp16 temp - // for Q6_K/Q5_K — 71 MB per 27B FFN layer, ~600 MB for the lm_head) at - // prefill-scale N. At decode-sized k those temps land in WDDM-paged VRAM - // behind the post-fill 64 MiB margin and 5-10× every step. The verify path - // latches _matVecBatchedOnly so every batched matmul takes the temp-free - // matvec re-stream (the same decode kernels the sequential Forward uses). - if (GdnPrefillComputeEnabled && nTok > 1 && !_matVecBatchedOnly) + // The MMQ/dequant-GEMM compute path only amortizes its fixed per-call costs + // (whole-weight dequant to an fp16 temp for Q6_K/Q5_K — 71 MB per 27B FFN + // layer, ~600 MB for the lm_head; MMQ's activation re-quant) at prefill-scale + // N. At decode-sized N — the #30 batched verify (k ≤ 8 by SHARPI_MTP_BATCH_MAX) + // or a tiny prefill tail chunk — those temps land in WDDM-paged VRAM behind + // the post-fill 64 MiB margin and 5-10× every step (measured on the verify: + // 6.0 → 9.2 t/s from this threshold alone). Small N takes the temp-free + // matvec re-stream below — the same decode kernels sequential Forward uses, + // and the bit-exact reference path the compute kernels are validated against. + if (GdnPrefillComputeEnabled && nTok > MatMulComputeBatchMinN) { int cols = (int)(inputAll.ElementCount / nTok); switch (dt) @@ -5095,9 +5092,10 @@ private void GpuMatMulBatched(Tensor outputAll, Tensor matrix, Tensor inputAll, _gpu.MatMulBatched(outputAll, matrix, inputAll, nTok, dt); } - // True while BatchVerify drives the batched trunk: forces GpuMatMulBatched onto - // the matvec re-stream path (see the comment there). Never set on prefill. - private bool _matVecBatchedOnly; + // Crossover below which the MMQ/dequant-GEMM compute kernels' fixed per-call + // costs exceed the matvec re-stream's k× weight reads. 8 = the verify-batch + // ceiling; prefill chunks run at hundreds, so the regimes are well separated. + private const int MatMulComputeBatchMinN = 8; /// Issue #121: true when 's dtype is one of the dtypes /// implements a GEMM-N kernel for. Gates the diff --git a/src/SharpInference.Engine/ForwardPass.cs b/src/SharpInference.Engine/ForwardPass.cs index 5d7a6c0..543c9b2 100644 --- a/src/SharpInference.Engine/ForwardPass.cs +++ b/src/SharpInference.Engine/ForwardPass.cs @@ -1165,10 +1165,19 @@ private ReadOnlySpan PrefillCoreTq(IReadOnlyList tokens, int startPo /// Whether can run (issue #207): everything except the two /// configurations it throws for — the TurboQuant KV cache (compressed ring can't take /// the batched appends) and gemma4-style per-layer head_dim (not wired into the batched - /// trunk). MoE stays true: itself falls back to - /// sequential calls for MoE, which is still correct. + /// trunk) — and a SnapKV-compacted cache. After Compact the physical slot count + /// () sits below the logical RoPE position + /// (), but appends at + /// the LOGICAL position via TruncateTo(startPos), which would declare slots + /// past the compacted length valid and read garbage K/V — same #130 gate the CUDA and + /// GDN passes already have; the sequential fallback handles the + /// compacted frame correctly. MoE stays true: itself + /// falls back to sequential calls for MoE, which is still correct. /// - public bool SupportsBatchVerify => _tqKvCache is null && _layerHeadDim is null; + public bool SupportsBatchVerify => + _tqKvCache is null + && _layerHeadDim is null + && _kvCache.Length == _kvCache.LogicalLength; /// /// Batched verification for speculative decoding: processes starting diff --git a/src/SharpInference.Engine/GdnStateCache.cs b/src/SharpInference.Engine/GdnStateCache.cs index 8baad90..ef89e3e 100644 --- a/src/SharpInference.Engine/GdnStateCache.cs +++ b/src/SharpInference.Engine/GdnStateCache.cs @@ -383,6 +383,20 @@ public void RestoreLayerFrom(int gdnLayerIndex, byte* src, long srcBytes) } } + /// + /// Resolve the SHARPI_MTP_BATCH_MAX knob: max tokens per batched-verify call + /// (= 1 + max MTP draft-chain length), which sizes the per-token-boundary GDN + /// snapshot ring at value − 1 slots. Clamped to [2, 8]; default 2 — one + /// ring slot (~149 MB for 27B on either side of the PCIe bus), the measured + /// k=2 optimum until the CPU FFN amortizes more than pairwise. Shared by both + /// hybrid GDN passes so the knob means the same thing on every backend. + /// + public static int ResolveMtpBatchMax() + { + var s = Environment.GetEnvironmentVariable("SHARPI_MTP_BATCH_MAX"); + return s is not null && int.TryParse(s, out var v) ? Math.Clamp(v, 2, 8) : 2; + } + /// /// Set explicitly. Used by the batched verify path (issue #30) /// to rewind length after a per-layer state restore — the per-layer copy does not diff --git a/src/SharpInference.Engine/HybridGdnForwardPass.cs b/src/SharpInference.Engine/HybridGdnForwardPass.cs index 5ec1b91..4da21b1 100644 --- a/src/SharpInference.Engine/HybridGdnForwardPass.cs +++ b/src/SharpInference.Engine/HybridGdnForwardPass.cs @@ -144,18 +144,11 @@ public sealed unsafe class HybridGdnForwardPass : IForwardPass // used as the next chained draft's prevHidden (issue #30 multi-token drafting). private float* _mtpSelfHidden; - // Max tokens per BatchVerify call (= 1 + max MTP draft chain length). The host - // snapshot ring grows lazily to k-1 slots of NumGdnLayers × LayerSnapshotBytes - // each (~149 MB/slot for 27B), so keep a sane ceiling. SHARPI_MTP_BATCH_MAX in - // [2, 8]; default matches the CUDA pass (k=2 is the measured optimum until the - // CPU FFN amortizes more than pairwise). Instance-resolved so tests can - // override per construction. - private readonly int _mtpBatchMax = ResolveMtpBatchMax(); - private static int ResolveMtpBatchMax() - { - var s = Environment.GetEnvironmentVariable("SHARPI_MTP_BATCH_MAX"); - return s is not null && int.TryParse(s, out var v) ? Math.Clamp(v, 2, 8) : 2; - } + // Max tokens per BatchVerify call (= 1 + max MTP draft chain length); the host + // snapshot ring grows lazily to k-1 slots. Instance-resolved at construction so + // tests can override per instance; the knob semantics live in one place + // (GdnStateCache.ResolveMtpBatchMax) shared with the CUDA pass. + private readonly int _mtpBatchMax = GdnStateCache.ResolveMtpBatchMax(); // ── Dimensions (cached) ──────────────────────────────────────────── private readonly int _embDim; @@ -1380,14 +1373,17 @@ public float[][] BatchVerify(int[] tokens, int startPos) return result; } - /// Grow the [k × embDim] batched-verify residual streams (grow-only). + /// Grow the [k × embDim] batched-verify residual streams (grow-only). + /// Fields are nulled before each re-allocation so a mid-sequence OOM leaves + /// null pointers (clean re-entry / Dispose) instead of dangling ones. private void EnsureBatchVerifyScratch(int k) { if (_bvCap >= k) return; nuint bytes = (nuint)((long)k * _embDim * sizeof(float)); - if (_bvHiddenAll != null) NativeMemory.Free(_bvHiddenAll); - if (_bvResidAll != null) NativeMemory.Free(_bvResidAll); - if (_bvNormAll != null) NativeMemory.Free(_bvNormAll); + if (_bvHiddenAll != null) { NativeMemory.Free(_bvHiddenAll); _bvHiddenAll = null; } + if (_bvResidAll != null) { NativeMemory.Free(_bvResidAll); _bvResidAll = null; } + if (_bvNormAll != null) { NativeMemory.Free(_bvNormAll); _bvNormAll = null; } + _bvCap = 0; _bvHiddenAll = (float*)NativeMemory.AllocZeroed(bytes); _bvResidAll = (float*)NativeMemory.AllocZeroed(bytes); _bvNormAll = (float*)NativeMemory.AllocZeroed(bytes); @@ -1395,13 +1391,15 @@ private void EnsureBatchVerifyScratch(int k) } /// Grow the GDN snapshot ring to at least slots - /// (grow-only; contents need not survive — the ring is rewritten every batch). + /// (grow-only; contents need not survive — the ring is rewritten every batch). + /// Same null-before-realloc discipline as . private void EnsureBatchSnapshotSlots(int slots) { if (_batchSnapshotSlots >= slots) return; long slotBytes = _gdnStateCache.LayerSnapshotBytes * _gdnStateCache.NumGdnLayers; if (slotBytes <= 0) { _batchSnapshotSlots = slots; return; } - if (_batchSnapshotBuf != null) NativeMemory.Free(_batchSnapshotBuf); + if (_batchSnapshotBuf != null) { NativeMemory.Free(_batchSnapshotBuf); _batchSnapshotBuf = null; } + _batchSnapshotSlots = 0; _batchSnapshotBuf = (byte*)NativeMemory.Alloc((nuint)(slotBytes * slots)); _batchSnapshotCap = slotBytes * slots; _batchSnapshotSlots = slots; diff --git a/src/SharpInference.Engine/MtpDecoder.cs b/src/SharpInference.Engine/MtpDecoder.cs index 36b9409..c176340 100644 --- a/src/SharpInference.Engine/MtpDecoder.cs +++ b/src/SharpInference.Engine/MtpDecoder.cs @@ -192,6 +192,14 @@ public void Decode(int maxTokens, ReadOnlySpan stopTokenIds, Action em if (Environment.GetEnvironmentVariable("SHARPI_TRACE_MTP") == "1") Console.Error.WriteLine( $"[mtp] batched-verify {(batched ? $"ON (draftN={draftN}, maxBatch={_fwd.MaxBatchVerifyTokens})" : "OFF (cache compacted / unsupported config / disabled)")}"); + // Surface a silent capability clamp (the pre-#30 code threw for draftN > 1; + // a server operator asking for a deeper chain than the snapshot ring allows + // should see WHY throughput matches a shallower setting). + if (batched && draftN > _fwd.MaxBatchVerifyTokens - 1) + Console.Error.WriteLine( + $"[mtp] requested draft chain {draftN} exceeds the snapshot-ring capacity; " + + $"clamping to {_fwd.MaxBatchVerifyTokens - 1} draft(s)/step " + + "(raise SHARPI_MTP_BATCH_MAX to reserve more ring slots)."); if (batched) { DecodeBatched(maxTokens, stopTokenIds, emitToken, pMin, draftN, ct); @@ -370,12 +378,20 @@ private void DecodeBatched(int maxTokens, ReadOnlySpan stopTokenIds, Action VerifyMs += _phaseSw.Elapsed.TotalMilliseconds; // ── Greedy accept: count leading agreeing drafts ────────── + // An accepted STOP draft also ends the chain here, EXCLUDED from `a`: + // like the sequential path, the stop token is neither emitted nor + // committed — clamping acceptance before the rollback below keeps + // every piece of state (trunk KV, GDN ring restore, MTP KV, hidden + // history, _nextPos) consistent at newPos, where the pre-#208-review + // emit-loop stop check returned with the caches stranded past _nextPos. int a = 0; + bool stopHit = false; for (int i = 1; i < kEff; i++) { int target = ArgMax(batch[i - 1]); - if (AcceptDraft(tokens[i], target, batch[i - 1], pMin, out _)) a++; - else break; + if (!AcceptDraft(tokens[i], target, batch[i - 1], pMin, out _)) break; + if (IsStop(tokens[i], stopTokenIds)) { stopHit = true; break; } + a++; } _totalDraftsAccepted += a; int newPos = P + 1 + a; @@ -405,28 +421,22 @@ private void DecodeBatched(int maxTokens, ReadOnlySpan stopTokenIds, Action _ = _fwd.MtpForward(tokens[i], P + i, HiddenAtChecked(P + i - 1)); DraftMs += _phaseSw.Elapsed.TotalMilliseconds; - // ── Emit accepted drafts ────────────────────────────────── + // ── Emit accepted drafts (stop drafts never reach here — the accept + // loop clamps `a` before the first accepted stop) ──────── for (int i = 1; i <= a; i++) { - if (IsStop(tokens[i], stopTokenIds)) - { - // Don't emit the stop; sync state to "after tokens[i-1]" for a - // consistent follow-up. (The trunk cache may sit past _nextPos — - // same benign overshoot class as the legacy stop-at-t2 path.) - batch[i - 1].CopyTo(_savedMainLogits, 0); - HiddenAtChecked(P + i - 1).CopyTo(_savedHidden); - _nextPos = P + i; - return; - } emitToken(tokens[i]); generated++; } // ── Saved state for the next step ───────────────────────── // batch[a] predicts position newPos: it is the next certain token — - // the correction on a reject, the chain continuation on full accept. + // the correction on a reject, the chain continuation on full accept, + // or the (un-emitted, un-committed) stop on stopHit. All caches sit + // exactly at newPos, so a follow-up call resumes consistently. batch[a].CopyTo(_savedMainLogits, 0); HiddenAtChecked(newPos - 1).CopyTo(_savedHidden); _nextPos = newPos; + if (stopHit) return; } } @@ -495,27 +505,10 @@ private static float SoftmaxProbAt(ReadOnlySpan logits, int idx) return (float)(Math.Exp(logits[idx] - max) / sumExp); } - private static int ArgMax(ReadOnlySpan logits) - { - int best = 0; - float bestVal = logits[0]; - for (int i = 1; i < logits.Length; i++) - { - if (logits[i] > bestVal) { bestVal = logits[i]; best = i; } - } - return best; - } - - private static int ArgMax(float[] logits) - { - int best = 0; - float bestVal = logits[0]; - for (int i = 1; i < logits.Length; i++) - { - if (logits[i] > bestVal) { bestVal = logits[i]; best = i; } - } - return best; - } + // Greedy selection delegates to the engine's single argmax implementation — + // spec-decode parity contracts hinge on draft selection, verification, and + // production sampling all breaking ties identically (first max wins). + private static int ArgMax(ReadOnlySpan logits) => Sampler.Greedy(logits); private static bool IsStop(int token, ReadOnlySpan stopTokenIds) { diff --git a/src/SharpInference.Engine/PromptLookupDraft.cs b/src/SharpInference.Engine/PromptLookupDraft.cs index 4bdb1c6..93b3e4a 100644 --- a/src/SharpInference.Engine/PromptLookupDraft.cs +++ b/src/SharpInference.Engine/PromptLookupDraft.cs @@ -74,7 +74,7 @@ public int[] Propose(int maxTokens) int start = i + n; int count = Math.Min(maxTokens, len - start); var proposal = new int[count]; - for (int t = 0; t < count; t++) proposal[t] = h[start + t]; + h.CopyTo(start, proposal, 0, count); return proposal; } } diff --git a/src/SharpInference.Engine/SpeculativeDecoder.cs b/src/SharpInference.Engine/SpeculativeDecoder.cs index 7525371..da622fe 100644 --- a/src/SharpInference.Engine/SpeculativeDecoder.cs +++ b/src/SharpInference.Engine/SpeculativeDecoder.cs @@ -300,27 +300,10 @@ private float[][] BatchVerifyTarget(int[] draftTokens, int startPos) return result; } - private static int ArgMax(ReadOnlySpan logits) - { - int best = 0; - float bestVal = logits[0]; - for (int i = 1; i < logits.Length; i++) - { - if (logits[i] > bestVal) { bestVal = logits[i]; best = i; } - } - return best; - } - - private static int ArgMax(float[] logits) - { - int best = 0; - float bestVal = logits[0]; - for (int i = 1; i < logits.Length; i++) - { - if (logits[i] > bestVal) { bestVal = logits[i]; best = i; } - } - return best; - } + // Greedy selection delegates to the engine's single argmax implementation — + // spec-decode parity contracts hinge on draft selection, verification, and + // production sampling all breaking ties identically (first max wins). + private static int ArgMax(ReadOnlySpan logits) => Sampler.Greedy(logits); private static bool IsStop(int token, ReadOnlySpan stopTokenIds) { diff --git a/tests/SharpInference.Tests.ForwardPass/MtpDecoderBatchVerifyTests.cs b/tests/SharpInference.Tests.ForwardPass/MtpDecoderBatchVerifyTests.cs index 8bd62f4..989bdf7 100644 --- a/tests/SharpInference.Tests.ForwardPass/MtpDecoderBatchVerifyTests.cs +++ b/tests/SharpInference.Tests.ForwardPass/MtpDecoderBatchVerifyTests.cs @@ -216,12 +216,24 @@ public void MidChainReject_RollsBackToAcceptedBoundary() } [Fact] - public void StopToken_NotEmitted_DecodeEnds() + public void StopToken_NotEmitted_DecodeEnds_StateConsistent() { var pass = new ScriptedMtpPass(prefillLen: 10); - // Chain from 3: 3, 4, 5, 6, ... — stop at 6. + // Chain from 3: 3, 4, 5, 6, ... — stop at 6. Step 1 verifies [10, 14): + // t1=3@10, drafts 4@11, 5@12, 6@13; the accepted stop (6) clamps + // acceptance at a=2 → rollback to 13, emit 4 and 5, end decode. var emitted = Decode(pass, 10, firstToken: 3, maxTokens: 12, draftN: 3, stops: [6]); Assert.Equal(new List { 3, 4, 5 }, emitted); + + // The accepted-stop boundary must leave EVERY cache exactly at the last + // emitted position + 1 (13) — the stop is neither emitted nor committed. + // Pre-#208-review the trunk/MTP caches were stranded at P+1+a past + // _nextPos, poisoning the GDN recurrence for any follow-up use. + Assert.Equal(13, Assert.Single(pass.RestoreCalls)); + Assert.Equal(13, pass.MainLen); + Assert.True(pass.MtpLen <= 13, + $"MTP KV must not hold positions past the stop boundary (len={pass.MtpLen})."); + Assert.Equal(13, pass.HistLen); } [Fact]