From 04b95ca96842b90152052f670ce56df7e7b9df71 Mon Sep 17 00:00:00 2001 From: Pekka Heikura Date: Sun, 7 Jun 2026 16:59:38 +0300 Subject: [PATCH 1/3] perf(cuda): chunked batched prefill for Gemma 4 SWA via a real KV ring (#162) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Generalizes >4096-token chunked batched prefill to Gemma 4 sliding-window (SWA) layers. This required first fixing a pre-existing latent bug: SWA KV caches were allocated window-sized but every CUDA kernel indexed them by absolute position, so any context beyond the 512-token window read/wrote out of bounds (silently wrong output; never caught because all Gemma 4 tests used <=12-token prompts at maxContextLength 512). Fix: a real SWA KV ring. - Cache sized SwaRingSize(ctx, window) = min(ctx, window + 4096). The +chunk headroom is required because batched prefill appends a whole chunk before any of those queries attend; a bare-window ring would clobber the earliest queries' windows. - All KV-indexing kernels wrap at `pos % max_seq_len` (= allocated cache size). Identity for full/dense/global caches (pos < size); wraps for windowed SWA rings. Kernels: llm_kv_append (+bf16), llm_kv_append_batched (+bf16), llm_attention_swa, llm_attention_swa_batched, llm_flash_attn_prefill_tc / tc2 / f32. - Chunked-prefill guard relaxed from "no SWA" to "flash enabled". - EstimateMaxContext SWA per-layer cap updated to window + headroom. Also fixes long-context Gemma 4 correctness on the decode path and in CudaHybridForwardPass (both indexed the window-sized cache absolutely). Validation: - New tests: PastWindow (700-token prompt > window) and chunked-ring Theory (5040/6144, 8192/9216 — ctx > window+headroom so the cache is a true ring; chunked vs per-token overwrite the ring differently, so agreement validates sizing). Gemma4 + Qwen3 prefill/decode suites green; 26 synthetic bit-wise kernel tests (incl. bf16 append) green. - Speedup A/B (Gemma 4 E4B, 5313-token prompt, RTX 4070 Ti): chunked 129.6 t/s vs 47.7 per-token = 2.72x. Long-context greedy decode stays coherent. Closes the SWA sub-item of #162. Co-Authored-By: Claude Opus 4.8 --- src/SharpInference.Cuda/CudaTextKernels.cs | 49 ++++-- src/SharpInference.Engine/CudaForwardPass.cs | 79 ++++++--- .../CudaHybridForwardPass.cs | 10 ++ .../Gemma4CudaBatchedPrefillTests.cs | 150 ++++++++++++++++++ 4 files changed, 251 insertions(+), 37 deletions(-) diff --git a/src/SharpInference.Cuda/CudaTextKernels.cs b/src/SharpInference.Cuda/CudaTextKernels.cs index 4421642..abd884b 100644 --- a/src/SharpInference.Cuda/CudaTextKernels.cs +++ b/src/SharpInference.Cuda/CudaTextKernels.cs @@ -599,7 +599,10 @@ __device__ __forceinline__ unsigned int sharpi_uint_at(const unsigned int* __res { int i = (int)(blockIdx.x * blockDim.x + threadIdx.x); if (i >= kv_dim) return; - long offset = (long)position * (long)kv_dim + (long)i; + // Ring slot: `position % max_seq_len`. `max_seq_len` is the allocated cache size, + // so for a full-context (dense / global) cache `position < max_seq_len` makes this + // the identity; for a window-sized SWA ring it wraps the write into the ring. + long offset = (long)(position % max_seq_len) * (long)kv_dim + (long)i; k_cache[offset] = k_in[i]; v_cache[offset] = v_in[i]; } @@ -618,7 +621,11 @@ __device__ __forceinline__ unsigned int sharpi_uint_at(const unsigned int* __res { int i = (int)(blockIdx.x * blockDim.x + threadIdx.x); if (i >= kv_dim) return; - long offset = (long)position * (long)kv_dim + (long)i; + // Ring slot `position % max_seq_len` (identity for a full-context cache; wraps a + // window-sized ring). Matches the f32 llm_kv_append so the write/read indexing stays + // uniform if a windowed model ever uses the bf16 KV cache (today only the full-context + // GDN-hybrid path does, where position < max_seq_len makes this the identity). + long offset = (long)(position % max_seq_len) * (long)kv_dim + (long)i; k_cache[offset] = (unsigned short)sharpi_fp32_to_bf16(k_in[i]); v_cache[offset] = (unsigned short)sharpi_fp32_to_bf16(v_in[i]); } @@ -4066,7 +4073,9 @@ __device__ __forceinline__ void sharpi_q4k_scale_min( for (int t = (int)tid; t < eff_seq; t += 256) { int abs_t = t + window_start; float dot = 0.f; - long k_off = (long)abs_t * (long)kv_dim + (long)kv_head * (long)head_dim; + // Ring slot `abs_t % max_seq_len` (max_seq_len = allocated cache size): identity + // for a full cache, wraps a window-sized SWA ring. abs_t itself stays logical. + long k_off = (long)(abs_t % max_seq_len) * (long)kv_dim + (long)kv_head * (long)head_dim; for (int d = 0; d < head_dim; d++) dot += q[q_off + d] * k_cache[k_off + d]; float score = dot * scale; @@ -4127,7 +4136,7 @@ __device__ __forceinline__ void sharpi_q4k_scale_min( for (int t = 0; t < eff_seq; t++) { int abs_t = t + window_start; float weight = use_shared ? shared_scores[t] : head_scratch[t]; - long v_off = (long)abs_t * (long)kv_dim + (long)kv_head * (long)head_dim; + long v_off = (long)(abs_t % max_seq_len) * (long)kv_dim + (long)kv_head * (long)head_dim; acc += weight * v_cache[v_off + d]; } out[out_off + d] = acc; @@ -4246,8 +4255,10 @@ __device__ __forceinline__ float fatc_mask(float s, int qpos, int abs_k, int win for (int idx = lane; idx < FATC_KT * head_dim; idx += 32) { int kk = idx / head_dim, d = idx - kk * head_dim; int abs_k = kt0 + kk; + // Cache read at ring slot `abs_k % max_seq_len` (identity for a full cache, + // wraps a window-sized SWA ring); abs_k stays logical for the causal bound. float kv = (abs_k < key_end) - ? k_cache[(long)abs_k * kv_dim + (long)kv_head * head_dim + d] : 0.f; + ? k_cache[(long)(abs_k % max_seq_len) * kv_dim + (long)kv_head * head_dim + d] : 0.f; sKV[idx] = (unsigned short)sharpi_fp32_to_fp16(kv); } __syncthreads(); @@ -4345,7 +4356,7 @@ asm volatile( int kk = idx / head_dim, d = idx - kk * head_dim; int abs_k = kt0 + kk; float vv = (abs_k < key_end) - ? v_cache[(long)abs_k * kv_dim + (long)kv_head * head_dim + d] : 0.f; + ? v_cache[(long)(abs_k % max_seq_len) * kv_dim + (long)kv_head * head_dim + d] : 0.f; sKV[idx] = (unsigned short)sharpi_fp32_to_fp16(vv); } __syncthreads(); @@ -4454,8 +4465,10 @@ asm volatile( for (int idx = tid; idx < FATC2_KT * head_dim; idx += FATC2_W * 32) { int kk = idx / head_dim, d = idx - kk * head_dim; int abs_k = kt0 + kk; + // Cache read at ring slot `abs_k % max_seq_len` (identity for a full cache, + // wraps a window-sized SWA ring); abs_k stays logical for the causal bound. float kv = (abs_k < key_end) - ? k_cache[(long)abs_k * kv_dim + (long)kv_head * head_dim + d] : 0.f; + ? k_cache[(long)(abs_k % max_seq_len) * kv_dim + (long)kv_head * head_dim + d] : 0.f; sKV[idx] = (unsigned short)sharpi_fp32_to_fp16(kv); } __syncthreads(); @@ -4566,7 +4579,7 @@ asm volatile( int kk = idx / head_dim, d = idx - kk * head_dim; int abs_k = kt0 + kk; float vv = (abs_k < key_end) - ? v_cache[(long)abs_k * kv_dim + (long)kv_head * head_dim + d] : 0.f; + ? v_cache[(long)(abs_k % max_seq_len) * kv_dim + (long)kv_head * head_dim + d] : 0.f; sKV[idx] = (unsigned short)sharpi_fp32_to_fp16(vv); } __syncthreads(); @@ -4697,7 +4710,9 @@ asm volatile( int kk = idx / hd2, pr = idx - kk * hd2; unsigned int kh = 0u; if (kk < tile_keys) { - long off = (long)(kt0 + kk) * kv_dim + (long)kv_head * head_dim + 2 * pr; + // Ring slot `(kt0+kk) % max_seq_len`: identity for a full cache, wraps a + // window-sized SWA ring. The kt0+kk index stays logical for tile bounds. + long off = (long)((kt0 + kk) % max_seq_len) * kv_dim + (long)kv_head * head_dim + 2 * pr; kh = sharpi_f32x2_to_f16x2(k_cache[off], k_cache[off + 1]); } sKh[idx] = kh; @@ -4706,7 +4721,7 @@ asm volatile( for (int idx = tid; idx < kt_tile * head_dim; idx += (int)blockDim.x) { int kk = idx / head_dim, d = idx - kk * head_dim; sV[idx] = (kk < tile_keys) - ? v_cache[(long)(kt0 + kk) * kv_dim + (long)kv_head * head_dim + d] + ? v_cache[(long)((kt0 + kk) % max_seq_len) * kv_dim + (long)kv_head * head_dim + d] : 0.f; } __syncthreads(); @@ -4799,7 +4814,9 @@ asm volatile( for (int t = (int)tid; t < eff_seq; t += 256) { int abs_t = t + window_start; float dot = 0.f; - long k_off = (long)abs_t * (long)kv_dim + (long)kv_head * (long)head_dim; + // Ring slot `abs_t % max_seq_len` (max_seq_len = allocated cache size): identity + // for a full cache, wraps a window-sized SWA ring. abs_t itself stays logical. + long k_off = (long)(abs_t % max_seq_len) * (long)kv_dim + (long)kv_head * (long)head_dim; for (int dd = 0; dd < head_dim; dd++) dot += q[q_off + dd] * k_cache[k_off + dd]; shared_scores[t] = dot * scale; @@ -4841,7 +4858,7 @@ asm volatile( float acc = 0.f; for (int t = 0; t < eff_seq; t++) { int abs_t = t + window_start; - long v_off = (long)abs_t * (long)kv_dim + (long)kv_head * (long)head_dim; + long v_off = (long)(abs_t % max_seq_len) * (long)kv_dim + (long)kv_head * (long)head_dim; acc += shared_scores[t] * v_cache[v_off + dd]; } out[out_off + dd] = acc; @@ -5616,7 +5633,9 @@ asm volatile( int e = (int)(blockIdx.x * blockDim.x + threadIdx.x); int i = (int)blockIdx.y; if (e >= kv_dim || i >= n_tok) return; - long off = (long)(start_pos + i) * (long)kv_dim + (long)e; + // Ring slot `(start_pos+i) % max_seq_len`: identity for a full-context cache + // (position < max_seq_len), wraps into a window-sized SWA ring otherwise. + long off = (long)((start_pos + i) % max_seq_len) * (long)kv_dim + (long)e; k_cache[off] = k_all[(long)i * kv_dim + e]; v_cache[off] = v_all[(long)i * kv_dim + e]; } @@ -5630,7 +5649,9 @@ asm volatile( int e = (int)(blockIdx.x * blockDim.x + threadIdx.x); int i = (int)blockIdx.y; if (e >= kv_dim || i >= n_tok) return; - long off = (long)(start_pos + i) * (long)kv_dim + (long)e; + // Ring slot `(start_pos+i) % max_seq_len` (identity for a full-context cache; wraps a + // window-sized ring) — kept in lockstep with the f32 llm_kv_append_batched. + long off = (long)((start_pos + i) % max_seq_len) * (long)kv_dim + (long)e; k_cache[off] = (unsigned short)sharpi_fp32_to_bf16(k_all[(long)i * kv_dim + e]); v_cache[off] = (unsigned short)sharpi_fp32_to_bf16(v_all[(long)i * kv_dim + e]); } diff --git a/src/SharpInference.Engine/CudaForwardPass.cs b/src/SharpInference.Engine/CudaForwardPass.cs index 977e97d..8bbfd61 100644 --- a/src/SharpInference.Engine/CudaForwardPass.cs +++ b/src/SharpInference.Engine/CudaForwardPass.cs @@ -183,6 +183,30 @@ public sealed unsafe class CudaForwardPass : IForwardPass /// single-shot batch size. /// private const int PrefillBatchChunk = 4096; + + /// + /// Headroom (in positions) added to a Gemma-4 SWA layer's window when sizing its KV + /// ring (issue #162). A batched prefill appends a whole batch of K/V before any of + /// those queries attend, so the ring must hold the window PLUS one batched-append span + /// or the earliest queries' window would be overwritten by the latest appends. The + /// widest single append span is the larger of the chunked-prefill window + /// () and the 4096 non-flash batched-attention cap, so a + /// ring of window + SwaRingHeadroom is always large enough. Capped at the model + /// context () — a full-context cache needs no ring at all. + /// + private const int SwaRingHeadroom = PrefillBatchChunk > 4096 ? PrefillBatchChunk : 4096; + + /// + /// Allocated KV-cache size, in positions, for a Gemma-4 sliding-window layer: the + /// window plus , capped at the full context. When this + /// equals the context the cache is full (the ring modulo in the kernels degenerates to + /// the identity); when the context exceeds it the kernels wrap writes/reads modulo this + /// size. The value passed as each SWA append/attention call's maxSeqLen argument + /// MUST equal this so the kernel's pos % maxSeqLen lands in the right ring slot. + /// + private static int SwaRingSize(int maxSeqLen, int window) => + (int)Math.Min(maxSeqLen, (long)window + SwaRingHeadroom); + /// /// Issue #141: route Q8_0 trunk matmuls in the batched prefill through the /// compute-bound cuBLAS GEMM () @@ -548,8 +572,12 @@ void TraceVram(string label) int layerHd = perLayerKv ? hp.LayerHeadDim![i] : _headDim; int layerKvDim = _numKvHeads * layerHd; + // SWA layers use a window-sized ring (window + headroom for one batched + // append span, issue #162); everything else is full-context. The same + // SwaRingSize value is passed to the kernels as maxSeqLen so their + // pos % maxSeqLen wraps into this exact ring. int layerCtx = (perLayerKv && hp.IsSwaLayer is { } swa && swa[i]) - ? Math.Min(_maxSeqLen, swaWindow) + ? SwaRingSize(_maxSeqLen, swaWindow) : _maxSeqLen; _gpuKCache[i] = gpu.Allocate(TensorShape.D1((long)layerCtx * layerKvDim)); _gpuVCache[i] = gpu.Allocate(TensorShape.D1((long)layerCtx * layerKvDim)); @@ -1338,7 +1366,7 @@ private void RunGemma4DeviceRegion(int position) if (!kvShared) { int layerCtx = isSwa && _hp.SlidingWindowSize > 0 - ? Math.Min(_maxSeqLen, _hp.SlidingWindowSize) + ? SwaRingSize(_maxSeqLen, _hp.SlidingWindowSize) : _maxSeqLen; _gpu.KvAppend(kView, vView, _gpuKCache[layer], _gpuVCache[layer], kvDimL, position, layerCtx); @@ -1346,7 +1374,7 @@ private void RunGemma4DeviceRegion(int position) int effLayerCtx = (_hp.IsSwaLayer is { } swaEff && swaEff[effLayer] && _hp.SlidingWindowSize > 0) - ? Math.Min(_maxSeqLen, _hp.SlidingWindowSize) + ? SwaRingSize(_maxSeqLen, _hp.SlidingWindowSize) : _maxSeqLen; // Gemma 4 uses attention_scale = 1.0 (no 1/sqrt(head_dim) prefactor). Pass @@ -1731,13 +1759,13 @@ private ReadOnlySpan ForwardProfiledGemma4(int token, int position) if (!kvShared) { int layerCtx = isSwa && _hp.SlidingWindowSize > 0 - ? Math.Min(_maxSeqLen, _hp.SlidingWindowSize) + ? SwaRingSize(_maxSeqLen, _hp.SlidingWindowSize) : _maxSeqLen; _gpu.KvAppend(kView, vView, _gpuKCache[layer], _gpuVCache[layer], kvDimL, position, layerCtx); } int effLayerCtx = (_hp.IsSwaLayer is { } swaEff && swaEff[effLayer] && _hp.SlidingWindowSize > 0) - ? Math.Min(_maxSeqLen, _hp.SlidingWindowSize) + ? SwaRingSize(_maxSeqLen, _hp.SlidingWindowSize) : _maxSeqLen; if (isSwa) _gpu.AttentionSwa(qView, _gpuKCache[effLayer], _gpuVCache[effLayer], attnOutView, @@ -1851,15 +1879,14 @@ public ReadOnlySpan Prefill(IReadOnlyList tokens, int startPos = 0) if (BatchedPrefillEnabled && !snapKvActive && N >= 2 && startPos + N <= _maxSeqLen && IsBatchedPrefillSupported()) { - // Chunking past 4096 requires a streaming attention path (flash) AND simple - // (non-windowed) KV position semantics. SWA layers wrap KV in a window-sized - // ring whose cross-chunk behaviour past the window boundary is not yet - // validated, so any windowed model keeps the proven 4096 cap (follow-up: #162). - // Note: `IsSwaLayer` is only populated for Gemma 4 today, so the explicit - // `SlidingWindowSize <= 0` check fails closed if SWA parsing is later extended - // to a uniform-window arch that sets the window without a per-layer pattern. - bool canChunkPast4096 = PrefillFlashAttnEnabled - && _hp.IsSwaLayer is null && _hp.SlidingWindowSize <= 0; + // Chunking past 4096 requires a streaming attention path (flash) for the + // global (full-causal) layers — the non-flash AttentionBatched caps at 4096. + // SWA layers are now correct across chunk boundaries (issue #162): their KV + // ring is sized window + SwaRingHeadroom (≥ one chunk span), so appending a + // whole chunk before attending never overwrites a still-needed window, and the + // flash/append kernels wrap reads/writes modulo the ring (SwaRingSize). So flash + // alone gates chunking; windowed (Gemma 4) and dense models both qualify. + bool canChunkPast4096 = PrefillFlashAttnEnabled; int cap = canChunkPast4096 ? _maxSeqLen : 4096; if (startPos + N <= cap) { @@ -2254,12 +2281,12 @@ void ApplyRopeBatched() if (!kvShared) { - int layerCtx = isSwa && window > 0 ? Math.Min(_maxSeqLen, window) : _maxSeqLen; + int layerCtx = isSwa && window > 0 ? SwaRingSize(_maxSeqLen, window) : _maxSeqLen; _gpu.KvAppendBatched(kAll, vAll, _gpuKCache[layer], _gpuVCache[layer], kvDimL, startPos, layerCtx, N); } int effLayerCtx = (_hp.IsSwaLayer is { } swaEff && swaEff[effLayer] && window > 0) - ? Math.Min(_maxSeqLen, window) : _maxSeqLen; + ? SwaRingSize(_maxSeqLen, window) : _maxSeqLen; if (s_prefillProfile) { _gpu.Synchronize(); _profSw.Restart(); } // Gemma 4: attention_scale = 1.0, passed explicitly (kernel skips its rsqrtf). @@ -2815,6 +2842,10 @@ public static int EstimateMaxContext(GgufModel model, CudaBackend gpu, ModelHype if (hp.LayerHeadDim is { } lhd && hp.IsSwaLayer is { } swa) { int swaWindow = hp.SlidingWindowSize > 0 ? hp.SlidingWindowSize : int.MaxValue; + // SWA layers are sized as a ring of window + SwaRingHeadroom positions (issue + // #162), so the cap for the per-token byte formula is the ring size, not the + // bare window. Guard against overflow when swaWindow is "unbounded". + long swaCap = swaWindow == int.MaxValue ? long.MaxValue : (long)swaWindow + SwaRingHeadroom; long globalKvDimPerToken = 0; long swaKvDimPerToken = 0; for (int i = 0; i < hp.NumLayers; i++) @@ -2827,24 +2858,26 @@ public static int EstimateMaxContext(GgufModel model, CudaBackend gpu, ModelHype else globalKvDimPerToken += layerKvDim; } // For a given maxCtx C: bytes = globalKvDimPerToken * C - // + swaKvDimPerToken * min(C, swaWindow) + // + swaKvDimPerToken * min(C, swaCap) // Solve for the largest C ≤ hp.ContextLength that fits in `available`. - // Branch on whether C ≤ swaWindow: - // if C ≤ swaWindow: bytes = (global+swa) * C - // else: bytes = global * C + swa * swaWindow + // Branch on whether C ≤ swaCap: + // if C ≤ swaCap: bytes = (global+swa) * C + // else: bytes = global * C + swa * swaCap long globalPlusSwa = globalKvDimPerToken + swaKvDimPerToken; int candA = globalPlusSwa > 0 ? (int)(available / globalPlusSwa) : int.MaxValue; int maxCtxL; - if (candA <= swaWindow) + if (candA <= swaCap) { maxCtxL = candA; } else { - long remain = available - swaKvDimPerToken * swaWindow; + // swaCap is finite here (candA ≤ long.MaxValue always takes the branch + // above when swaCap is unbounded), so swaKvDimPerToken * swaCap is safe. + long remain = available - swaKvDimPerToken * swaCap; int candB = globalKvDimPerToken > 0 && remain > 0 ? (int)(remain / globalKvDimPerToken) : 0; - maxCtxL = Math.Max(swaWindow, candB); + maxCtxL = (int)Math.Max(swaCap, candB); } return Math.Clamp(maxCtxL, 512, hp.ContextLength); } diff --git a/src/SharpInference.Engine/CudaHybridForwardPass.cs b/src/SharpInference.Engine/CudaHybridForwardPass.cs index a624f46..524b178 100644 --- a/src/SharpInference.Engine/CudaHybridForwardPass.cs +++ b/src/SharpInference.Engine/CudaHybridForwardPass.cs @@ -446,6 +446,16 @@ void TraceVram(string label) // Gemma 4: each GPU layer sizes its KV cache by per-layer head_dim and // (for SWA layers) caps at SlidingWindowSize. Non-gemma4 stays at the // model-wide head_dim × full context. + // + // This path is DECODE-ONLY for Gemma 4 (IsBatchedPrefillSupported requires + // !_isGemma4Like), so the SWA ring only needs to hold `window` positions: + // each token appends one K/V then attends its window before the next append, + // so the bare-window ring is never overwritten early. The shared kernels' + // `pos % max_seq_len` (max_seq_len == this layerCtx) wraps writes/reads into + // it — which also fixes the pre-#162 latent OOB here (positions ≥ window used + // to index a window-sized cache absolutely). It deliberately does NOT use + // CudaForwardPass.SwaRingSize (the window+chunk headroom) because that's only + // needed by the batched-append (chunked-prefill) path, which never runs here. int layerHd = _isGemma4Like ? hp.LayerHeadDim![i] : _headDim; int layerKvDim = _numKvHeads * layerHd; int layerCtx = (_isGemma4Like && hp.IsSwaLayer is { } swa && swa[i] && hp.SlidingWindowSize > 0) diff --git a/tests/SharpInference.Tests.ForwardPass/Gemma4CudaBatchedPrefillTests.cs b/tests/SharpInference.Tests.ForwardPass/Gemma4CudaBatchedPrefillTests.cs index d9259fb..f0bae3d 100644 --- a/tests/SharpInference.Tests.ForwardPass/Gemma4CudaBatchedPrefillTests.cs +++ b/tests/SharpInference.Tests.ForwardPass/Gemma4CudaBatchedPrefillTests.cs @@ -325,6 +325,156 @@ public void Gemma4_E4B_BatchedPrefill_GemmOff_MatchesSequentialBitExact() $"(expected >95%); maxAbs={maxAbs}."); } + /// + /// Issue #162 (SWA sub-item): a Gemma-4 prompt LONGER than the 512-token sliding window + /// but still under the 4096 single-batch cap must produce correct output. This is the + /// scenario the old window-sized-cache-with-absolute-indexing got wrong (positions ≥ + /// window read/wrote out of bounds). With the SWA ring (cache sized window + + /// SwaRingHeadroom, capped at ctx) the per-token loop is itself correct again, so the + /// batched path is checked against it. ctx=1024 makes the SWA cache full (1024 < + /// window+headroom), so this isolates the windowed-attention-past-the-window fix from + /// the ring-wrap fix (covered by the chunked test below). + /// + [Fact] + public void Gemma4_E4B_BatchedPrefill_PastWindow_MatchesSequential() + { + using var gpu = TryCreate(); + if (gpu is null) return; + var path = FindModelPath(); + if (path is null) return; + + using var model = GgufModel.Open(path); + var hp = ModelHyperparams.FromGgufMetadata(model.Metadata, model); + Assert.NotNull(hp.LayerHeadDim); + Assert.True(hp.SlidingWindowSize > 0 && hp.SlidingWindowSize < 700, + $"Test assumes a sliding window < the 700-token prompt; got {hp.SlidingWindowSize}."); + + // 700 > window (512): the trailing queries' windows exclude the earliest tokens, + // exactly the regime that the absolute-into-window-sized-cache bug corrupted. + var tokens = MakeTokens(700); + + using var fwd = new CudaForwardPass(model, gpu, hp, maxContextLength: 1024); + + fwd.BatchedPrefillEnabled = true; // shipped defaults (flash TC on) + var batched = fwd.Prefill(tokens).ToArray(); + Assert.True(fwd.LastPrefillWasBatched); + + fwd.ResetCache(); + fwd.BatchedPrefillEnabled = false; + var sequential = fwd.Prefill(tokens).ToArray(); + Assert.False(fwd.LastPrefillWasBatched); + + Assert.Equal(sequential.Length, batched.Length); + Assert.Equal(Argmax(sequential), Argmax(batched)); + + // Logit-envelope check on top of argmax: the flash TC path rounds to fp16 over 42 + // layers + softcap, so a few tenths of a logit is expected; a whole-number gap would + // signal a wiring bug (e.g. a ring slot reading the wrong position) that argmax alone + // could miss when it doesn't quite flip the top token. + float maxAbs = 0f; + for (int i = 0; i < sequential.Length; i++) + maxAbs = MathF.Max(maxAbs, MathF.Abs(sequential[i] - batched[i])); + Assert.True(maxAbs < 1.5f, + $"Past-window batched vs per-token logits diverged beyond fp16 tolerance: maxAbs={maxAbs}."); + + var seqTop = TopKSet(sequential, 5); + var batTop = TopKSet(batched, 5); + int overlap = 0; + foreach (var t in batTop) if (seqTop.Contains(t)) overlap++; + Assert.True(overlap >= 4, + $"Past-window batched top-5 overlaps the per-token reference in only {overlap}/5 slots."); + } + + /// + /// Issue #162 (SWA sub-item): a Gemma-4 prompt longer than the 4096 cap must take the + /// chunked batched path (flash streaming the prior KV) and stay argmax-stable vs the + /// per-token loop, with both running through the SWA KV ring. + /// + /// This is the decisive ring oracle: the per-token loop attends right after each single + /// append (so it only ever needs ring ≥ window) while the chunked path appends a whole + /// 4096-token chunk before any of those queries attend (so it needs ring ≥ window + + /// chunk span). If the ring were undersized, the chunked path would overwrite an + /// earlier query's window and diverge from the per-token reference — so agreement + /// validates the window + SwaRingHeadroom sizing. ctx exceeds window + headroom, so the + /// SWA cache is a true ring (positions ≥ ring size wrap) rather than a full cache. + /// + /// + /// N=5040 exercises a full + partial chunk (4096 + 944); N=8192 the exact-multiple + /// boundary (two full 4096 chunks). + /// + /// + [Theory] + [InlineData(5040, 6144)] + [InlineData(8192, 9216)] + public void Gemma4_E4B_ChunkedBatchedPrefill_Over4096_MatchesSequential(int promptLen, int ctx) + { + using var gpu = TryCreate(); + if (gpu is null) return; + var path = FindModelPath(); + if (path is null) return; + + using var model = GgufModel.Open(path); + var hp = ModelHyperparams.FromGgufMetadata(model.Metadata, model); + Assert.NotNull(hp.LayerHeadDim); + Assert.True(hp.SlidingWindowSize > 0); + + var longTokens = MakeTokens(promptLen); + + // Disable SnapKV (a >budget prompt would otherwise route to the per-token SnapKV + // eviction path before the batched gate). Construct under the override, then restore. + var prevSnap = Environment.GetEnvironmentVariable("SHARPI_SNAPKV_BUDGET"); + Environment.SetEnvironmentVariable("SHARPI_SNAPKV_BUDGET", "0"); + CudaForwardPass fwd; + try { fwd = new CudaForwardPass(model, gpu, hp, maxContextLength: ctx); } + finally { Environment.SetEnvironmentVariable("SHARPI_SNAPKV_BUDGET", prevSnap); } + using var _fwd = fwd; + + fwd.BatchedPrefillEnabled = true; // shipped defaults (flash TC on) → chunked path + var batched = fwd.Prefill(longTokens).ToArray(); + Assert.True(fwd.LastPrefillWasBatched, + "Chunked batched prefill did not engage for a >4096-token Gemma 4 prompt (#162)."); + + fwd.ResetCache(); + fwd.BatchedPrefillEnabled = false; + var sequential = fwd.Prefill(longTokens).ToArray(); + Assert.False(fwd.LastPrefillWasBatched); + + Assert.Equal(sequential.Length, batched.Length); + Assert.Equal(Argmax(sequential), Argmax(batched)); + + // Envelope check: the chunked flash path reassociates the softmax over thousands of + // streamed keys, so it's looser than the ≤4096 case, but a several-unit gap would + // still flag a ring-overwrite bug (an early query reading a clobbered window slot) + // that argmax-stability alone might not surface. + float maxAbs = 0f; + for (int i = 0; i < sequential.Length; i++) + maxAbs = MathF.Max(maxAbs, MathF.Abs(sequential[i] - batched[i])); + Assert.True(maxAbs < 2.0f, + $"Chunked batched vs per-token logits diverged beyond fp tolerance: maxAbs={maxAbs}."); + + var seqTop = TopKSet(sequential, 5); + var batTop = TopKSet(batched, 5); + int overlap = 0; + foreach (var t in batTop) if (seqTop.Contains(t)) overlap++; + Assert.True(overlap >= 4, + $"Chunked batched top-5 overlaps the per-token reference in only {overlap}/5 slots."); + } + + // Deterministic spread across the vocab via a small LCG; all ids well within Gemma 4's + // vocab. Token 2 (BOS) leads so the prompt starts in-distribution. + private static int[] MakeTokens(int n) + { + var t = new int[n]; + t[0] = 2; + uint s = 0x9E3779B9u; + for (int i = 1; i < n; i++) + { + s = s * 1664525u + 1013904223u; + t[i] = (int)(s % 100000u) + 5; + } + return t; + } + private static HashSet TopKSet(ReadOnlySpan logits, int k) { var idx = new int[logits.Length]; From 87813a4a9dae0c94d746f826847473a45da0febe Mon Sep 17 00:00:00 2001 From: Pekka Heikura Date: Sun, 7 Jun 2026 17:11:35 +0300 Subject: [PATCH 2/3] docs: trim Gemma 4 CUDA prefill note + add SWA ring / >4096 chunked (#162) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Gemma 4 `-g -1` row's Notes cell had grown into a ~400-word historical dump of #141/#146/#147/#149/#142. Condense to the essentials (the prefill optimizations, env toggles, decode path, remaining llama.cpp gap) and fold in this PR's SWA KV ring: >4096-token prompts now take the chunked batched-flash path (2.72×, 47.7 → 129.6 t/s on a 5.3K prompt) instead of the per-token fallback, and long-context correctness past the 512 window is fixed. The ~1K-ctx prefill/decode columns are unchanged — re-benchmarked the affected CUDA rows (qwen3 439 vs 432, gemma4 3594 vs 3698, gemma4-hybrid 6.5 vs 6.6 t/s, all within run-to-run noise); the chunked path only triggers >4096 tokens, so the change is modulo-identity for the ~1K column by construction. Co-Authored-By: Claude Opus 4.8 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b5b37d2..3409c7e 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ coherent (`scripts/bench-all.ps1`); top-1 parity vs llama.cpp b8585 verified on | Qwen3.6-35B-A3B-MTP (GDN+MoE) | (same) | 22 GB | **CUDA** `-g -1 --no-thinking` (hybrid) | **65.0** | **22.9** | Requires `SHARPI_CPU_MOE=1`: 30 GDN + 10 attn + shared expert on GPU, routed experts CPU mmap. 100% acceptance. Fused GDN scan + batched SDPA (#114-B/#118), bit-identical, grows with ctx | | Carnice (Qwen3.6-35B-A3B-MTP finetune) | [mudler](https://huggingface.co/mudler/Carnice-Qwen3.6-MoE-35B-A3B-APEX-MTP-GGUF) | 17 GB | **CUDA** `-g -1 --no-thinking` (hybrid) | **43.6** | **25.0** | agentic finetune of 35B-A3B-MTP; 77% acceptance (`bench-carnice.ps1` — the default prompt 1-token-EOSes on this terser tune). APEX mixed-precision (Q3_K + Q8_0 experts); Q8_KS per-32 int dots auto-enable at load (#99/#101/#107), +4.6% decode at ~4× tighter parity vs plain Q8_K (`SHARPI_Q3K_Q8K=0`/`SHARPI_Q8_0_Q8K=0` to disable). Fused GDN scan + wave SDPA (#114-B/#118) bit-identical past 4096 | | Gemma 4 E4B-it Q8 | [unsloth](https://huggingface.co/unsloth/gemma-4-E4B-it-GGUF) | 8 GB | CPU | 4.9 | 5.0 | dense 42-layer gemma4: per-layer head_dim (256 SWA / 512 global), dual-RoPE, KV-share tail (18 layers), 5:1 SWA:global, logit softcap 30, PLE-256 injection (~4.2 GB mmap-resident) | -| Gemma 4 E4B-it Q8 | (same) | 8 GB | **CUDA** `-g -1 -c 2048` | **3698** | **59** | all 42 layers fit at `-c 2048`. KV-share alias + SWA/global split per layer; PLE projections (~215 MB) upload at construction. **Prefill (#141):** int8 **tensor-core MMQ** matmul (`mma.m16n8k32.s8`, each Q8_0 weight read once as int8 — beats the dequant→fp16→cuBLAS GEMM, drops its fp16 HBM temp; `SHARPI_PREFILL_MMQ=0` reverts) + a **tensor-core flash-attention** prefill (#146/#147): both QK^T and P·V on the mma cores (`mma.m16n8k16.f16`), multi-warp **d-split** so the O tile stays register-resident — replaces the scalar O(n²) per-query attention (which re-streamed each query's K/V window up to ~512×) and is **+27% at ~1K / +40% at 1.8K** over the earlier half2 flash kernel (`SHARPI_PREFILL_FLASH_TC=0` reverts to half2, `=…_FLASH=0` to scalar) + a **SoA Q8_0 weight repack** (#149): all Q8_0 readers (MMQ, dp4a, fp32 matvec, GEMM-N, dequant) read the quants 16-byte-aligned with the fp16 scales split out, killing the `qs` 2-byte-misalignment funnelshift tax — **+10-12% prefill, bit-identical**; `SHARPI_MMQ_SOA=0` reverts) + a batched Q8_0 embedding lookup. **109→3698 at ~1K ctx, →4240 at 1.8K** — profiling showed *attention*, then the matmul inner-loop efficiency, were the dominant prefill costs at realistic prompt lengths. **Decode (#142):** dp4a/Q8_1 int8 matvec (`SHARPI_Q80_DP4A=0` to bisect) + CUDA-graph capture/replay default-on (`SHARPI_CUDA_GRAPH=0` to bisect). All prefill/decode fast paths are argmax-stable vs the fp32 path, not bit-exact (the SoA repack is bit-identical). Remaining gap to llama.cpp (~8475 prefill / ~78 decode): cp.async-pipelined MMQ on the SoA layout + decode matvec work | +| Gemma 4 E4B-it Q8 | (same) | 8 GB | **CUDA** `-g -1 -c 2048` | **3698** | **59** | all 42 layers fit at `-c 2048`; KV-share alias + per-layer SWA/global split; PLE projections (~215 MB) upload at construction. **Prefill (#141/#146/#147/#149):** int8 tensor-core MMQ (`mma.m16n8k32.s8`, weight read once as int8) + tensor-core flash attention (QK^T and P·V on mma cores, register-resident O via d-split) + a SoA Q8_0 weight repack (16-byte-aligned quants) + batched embedding lookup — **109 → 3698 t/s @1K, 4240 @1.8K**. **>4096 prompts (#162/#164):** a real SWA KV ring (cache sized window + one chunk span, indexed `pos % size`) lets long prompts take the chunked batched-flash path instead of the ~8× slower per-token fallback (**2.72× — 47.7 → 129.6 t/s on a 5.3K-token prompt**), and fixes long-context correctness past the 512 window (the cache was previously window-sized but indexed absolutely → out of bounds). **Decode (#142):** dp4a/Q8_1 int8 matvec + CUDA-graph replay. Fast paths argmax-stable vs fp32 (SoA repack + ring bit-identical below the window). Bisect env: `SHARPI_PREFILL_MMQ` / `_FLASH_TC` / `SHARPI_MMQ_SOA` / `SHARPI_Q80_DP4A` / `SHARPI_CUDA_GRAPH` / `SHARPI_BATCHED_PREFILL` (`=0`). Remaining gap to llama.cpp (~8475 prefill / ~78 decode): cp.async-pipelined MMQ + decode matvec | | Gemma 4 E4B-it Q8 | (same) | 8 GB | **CUDA** `-g 22 -c 2048` (hybrid) | 6.6 | 6.8 | 22 GPU + 20 CPU layers. `-g ≤ 22` required so the CPU shared-KV tail can read its own-KV source layers; CPU dense-FFN dominates decode (bandwidth-bound). `SHARPI_CUDA_PROFILE=1` for per-phase breakdown | _Numbers re-measured across every on-disk row at ~1K ctx so the prefill column is comparable; per-issue From 44752c3e6d9d9a76f172ce1d2d88db3c6780ab03 Mon Sep 17 00:00:00 2001 From: Pekka Heikura Date: Sun, 7 Jun 2026 17:52:38 +0300 Subject: [PATCH 3/3] review: fail-closed chunk gate, modulo-by-zero guard, cross-chunk ring test (#162) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Second review-toolkit cycle findings: - Restore the fail-closed clause on the chunked-prefill gate. canChunkPast4096 is again gated on a real per-layer IsSwaLayer pattern (or no window), so a future arch that sets a model-wide SlidingWindowSize without a per-layer SWA pattern can't silently run full-causal attention past 4096 (window dropped). - Guard _maxSeqLen < 1 at construction. The KV kernels now index `pos % max_seq_len`, so a malformed zero-context GGUF reached via an explicit ctx-size would GPU-trap on divide-by-zero; fail loud instead. - Strengthen the chunked-prefill test. Prefill returns only the final token's logits; with whole-chunk prompts (8192) the final window sits entirely inside the last chunk, so a cross-chunk ring overwrite would be invisible to the assertion. New lengths 4296 (=4096+200) and 8292 (=2*4096+100) make the last chunk shorter than the 512 window, so the final token's window reaches back across the chunk boundary — the observed logit now genuinely depends on cross-chunk (and, for 8292, ring-wrapped) KV reads. 8/8 green. README: clarify the >4096 2.72x number is measured at a context that admits the 5.3K prompt, not the row's -c 2048. Co-Authored-By: Claude Opus 4.8 --- README.md | 2 +- src/SharpInference.Engine/CudaForwardPass.cs | 22 ++++++++++++++++--- .../Gemma4CudaBatchedPrefillTests.cs | 16 +++++++++----- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 3409c7e..d212791 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ coherent (`scripts/bench-all.ps1`); top-1 parity vs llama.cpp b8585 verified on | Qwen3.6-35B-A3B-MTP (GDN+MoE) | (same) | 22 GB | **CUDA** `-g -1 --no-thinking` (hybrid) | **65.0** | **22.9** | Requires `SHARPI_CPU_MOE=1`: 30 GDN + 10 attn + shared expert on GPU, routed experts CPU mmap. 100% acceptance. Fused GDN scan + batched SDPA (#114-B/#118), bit-identical, grows with ctx | | Carnice (Qwen3.6-35B-A3B-MTP finetune) | [mudler](https://huggingface.co/mudler/Carnice-Qwen3.6-MoE-35B-A3B-APEX-MTP-GGUF) | 17 GB | **CUDA** `-g -1 --no-thinking` (hybrid) | **43.6** | **25.0** | agentic finetune of 35B-A3B-MTP; 77% acceptance (`bench-carnice.ps1` — the default prompt 1-token-EOSes on this terser tune). APEX mixed-precision (Q3_K + Q8_0 experts); Q8_KS per-32 int dots auto-enable at load (#99/#101/#107), +4.6% decode at ~4× tighter parity vs plain Q8_K (`SHARPI_Q3K_Q8K=0`/`SHARPI_Q8_0_Q8K=0` to disable). Fused GDN scan + wave SDPA (#114-B/#118) bit-identical past 4096 | | Gemma 4 E4B-it Q8 | [unsloth](https://huggingface.co/unsloth/gemma-4-E4B-it-GGUF) | 8 GB | CPU | 4.9 | 5.0 | dense 42-layer gemma4: per-layer head_dim (256 SWA / 512 global), dual-RoPE, KV-share tail (18 layers), 5:1 SWA:global, logit softcap 30, PLE-256 injection (~4.2 GB mmap-resident) | -| Gemma 4 E4B-it Q8 | (same) | 8 GB | **CUDA** `-g -1 -c 2048` | **3698** | **59** | all 42 layers fit at `-c 2048`; KV-share alias + per-layer SWA/global split; PLE projections (~215 MB) upload at construction. **Prefill (#141/#146/#147/#149):** int8 tensor-core MMQ (`mma.m16n8k32.s8`, weight read once as int8) + tensor-core flash attention (QK^T and P·V on mma cores, register-resident O via d-split) + a SoA Q8_0 weight repack (16-byte-aligned quants) + batched embedding lookup — **109 → 3698 t/s @1K, 4240 @1.8K**. **>4096 prompts (#162/#164):** a real SWA KV ring (cache sized window + one chunk span, indexed `pos % size`) lets long prompts take the chunked batched-flash path instead of the ~8× slower per-token fallback (**2.72× — 47.7 → 129.6 t/s on a 5.3K-token prompt**), and fixes long-context correctness past the 512 window (the cache was previously window-sized but indexed absolutely → out of bounds). **Decode (#142):** dp4a/Q8_1 int8 matvec + CUDA-graph replay. Fast paths argmax-stable vs fp32 (SoA repack + ring bit-identical below the window). Bisect env: `SHARPI_PREFILL_MMQ` / `_FLASH_TC` / `SHARPI_MMQ_SOA` / `SHARPI_Q80_DP4A` / `SHARPI_CUDA_GRAPH` / `SHARPI_BATCHED_PREFILL` (`=0`). Remaining gap to llama.cpp (~8475 prefill / ~78 decode): cp.async-pipelined MMQ + decode matvec | +| Gemma 4 E4B-it Q8 | (same) | 8 GB | **CUDA** `-g -1 -c 2048` | **3698** | **59** | all 42 layers fit at `-c 2048`; KV-share alias + per-layer SWA/global split; PLE projections (~215 MB) upload at construction. **Prefill (#141/#146/#147/#149):** int8 tensor-core MMQ (`mma.m16n8k32.s8`, weight read once as int8) + tensor-core flash attention (QK^T and P·V on mma cores, register-resident O via d-split) + a SoA Q8_0 weight repack (16-byte-aligned quants) + batched embedding lookup — **109 → 3698 t/s @1K, 4240 @1.8K**. **>4096 prompts (#162/#164):** a real SWA KV ring (cache sized window + one chunk span, indexed `pos % size`) lets long prompts take the chunked batched-flash path instead of the ~8× slower per-token fallback (**2.72× — 47.7 → 129.6 t/s on a 5.3K-token prompt**, measured at a context that admits it, not the `-c 2048` of this row), and fixes long-context correctness past the 512 window (the cache was previously window-sized but indexed absolutely → out of bounds). **Decode (#142):** dp4a/Q8_1 int8 matvec + CUDA-graph replay. Fast paths argmax-stable vs fp32 (SoA repack + ring bit-identical below the window). Bisect env: `SHARPI_PREFILL_MMQ` / `_FLASH_TC` / `SHARPI_MMQ_SOA` / `SHARPI_Q80_DP4A` / `SHARPI_CUDA_GRAPH` / `SHARPI_BATCHED_PREFILL` (`=0`). Remaining gap to llama.cpp (~8475 prefill / ~78 decode): cp.async-pipelined MMQ + decode matvec | | Gemma 4 E4B-it Q8 | (same) | 8 GB | **CUDA** `-g 22 -c 2048` (hybrid) | 6.6 | 6.8 | 22 GPU + 20 CPU layers. `-g ≤ 22` required so the CPU shared-KV tail can read its own-KV source layers; CPU dense-FFN dominates decode (bandwidth-bound). `SHARPI_CUDA_PROFILE=1` for per-phase breakdown | _Numbers re-measured across every on-disk row at ~1K ctx so the prefill column is comparable; per-issue diff --git a/src/SharpInference.Engine/CudaForwardPass.cs b/src/SharpInference.Engine/CudaForwardPass.cs index 8bbfd61..f6e41ad 100644 --- a/src/SharpInference.Engine/CudaForwardPass.cs +++ b/src/SharpInference.Engine/CudaForwardPass.cs @@ -406,6 +406,14 @@ public CudaForwardPass(GgufModel model, CudaBackend gpu, ModelHyperparams hp, _maxSeqLen = EstimateMaxContextTq(model, gpu, hp, tqFp32Window, tqBits); else _maxSeqLen = EstimateMaxContext(model, gpu, hp); + // The KV-append/attention kernels index the cache at `pos % _maxSeqLen` (the ring + // modulo, identity for full caches), so a zero context — e.g. a malformed GGUF with + // context_length=0 reached via an explicit ctx-size — would be an in-kernel + // divide-by-zero (GPU trap). Fail loud at construction instead. + if (_maxSeqLen < 1) + throw new ArgumentException( + $"Resolved max context length is {_maxSeqLen}; the model's context_length " + + "metadata is missing or zero.", nameof(maxContextLength)); if (_tqEnabled) { @@ -1884,9 +1892,17 @@ public ReadOnlySpan Prefill(IReadOnlyList tokens, int startPos = 0) // SWA layers are now correct across chunk boundaries (issue #162): their KV // ring is sized window + SwaRingHeadroom (≥ one chunk span), so appending a // whole chunk before attending never overwrites a still-needed window, and the - // flash/append kernels wrap reads/writes modulo the ring (SwaRingSize). So flash - // alone gates chunking; windowed (Gemma 4) and dense models both qualify. - bool canChunkPast4096 = PrefillFlashAttnEnabled; + // flash/append kernels wrap reads/writes modulo the ring (SwaRingSize). + // + // Fail-closed guard (kept from the pre-#162 gate): the per-layer SWA dispatch + // keys off `IsSwaLayer`, which today only Gemma 4 populates. A future arch that + // sets a model-wide `SlidingWindowSize` WITHOUT a per-layer `IsSwaLayer` pattern + // would run every layer as full-causal (window silently ignored) — harmless + // within the proven 4096 cap, but extending that past 4096 would silently drop + // the window. So only chunk past 4096 when either there's a real per-layer SWA + // pattern or the model has no window at all. + bool canChunkPast4096 = PrefillFlashAttnEnabled + && (_hp.IsSwaLayer is not null || _hp.SlidingWindowSize <= 0); int cap = canChunkPast4096 ? _maxSeqLen : 4096; if (startPos + N <= cap) { diff --git a/tests/SharpInference.Tests.ForwardPass/Gemma4CudaBatchedPrefillTests.cs b/tests/SharpInference.Tests.ForwardPass/Gemma4CudaBatchedPrefillTests.cs index f0bae3d..94e61fb 100644 --- a/tests/SharpInference.Tests.ForwardPass/Gemma4CudaBatchedPrefillTests.cs +++ b/tests/SharpInference.Tests.ForwardPass/Gemma4CudaBatchedPrefillTests.cs @@ -395,17 +395,21 @@ public void Gemma4_E4B_BatchedPrefill_PastWindow_MatchesSequential() /// 4096-token chunk before any of those queries attend (so it needs ring ≥ window + /// chunk span). If the ring were undersized, the chunked path would overwrite an /// earlier query's window and diverge from the per-token reference — so agreement - /// validates the window + SwaRingHeadroom sizing. ctx exceeds window + headroom, so the - /// SWA cache is a true ring (positions ≥ ring size wrap) rather than a full cache. + /// validates the window + SwaRingHeadroom sizing. /// /// - /// N=5040 exercises a full + partial chunk (4096 + 944); N=8192 the exact-multiple - /// boundary (two full 4096 chunks). + /// The prompt lengths make the LAST chunk SHORTER than the 512 window (4296 = 4096+200, + /// 8292 = 2·4096+100). Prefill returns only the final token's logits, and with a short + /// last chunk the final token's window reaches BACK across the preceding chunk boundary — + /// so the single observable logit genuinely depends on cross-chunk KV reads (and, for + /// 8292 at ctx 9216 > window+headroom=4608, ring-WRAPPED reads). A whole-chunk prompt + /// (e.g. 8192) would leave the final window entirely inside the last chunk, hiding any + /// cross-chunk ring overwrite from the assertion. 8292 also drives a 3-chunk loop. /// /// [Theory] - [InlineData(5040, 6144)] - [InlineData(8192, 9216)] + [InlineData(4296, 6144)] + [InlineData(8292, 9216)] public void Gemma4_E4B_ChunkedBatchedPrefill_Over4096_MatchesSequential(int promptLen, int ctx) { using var gpu = TryCreate();