perf(cuda): chunked batched prefill for Gemma 4 SWA via a real KV ring (#162)#164
Conversation
#162) Generalizes >4096-token chunked batched prefill to Gemma 4 sliding-window (SWA) layers. This required first fixing a pre-existing latent bug: SWA KV caches were allocated window-sized but every CUDA kernel indexed them by absolute position, so any context beyond the 512-token window read/wrote out of bounds (silently wrong output; never caught because all Gemma 4 tests used <=12-token prompts at maxContextLength 512). Fix: a real SWA KV ring. - Cache sized SwaRingSize(ctx, window) = min(ctx, window + 4096). The +chunk headroom is required because batched prefill appends a whole chunk before any of those queries attend; a bare-window ring would clobber the earliest queries' windows. - All KV-indexing kernels wrap at `pos % max_seq_len` (= allocated cache size). Identity for full/dense/global caches (pos < size); wraps for windowed SWA rings. Kernels: llm_kv_append (+bf16), llm_kv_append_batched (+bf16), llm_attention_swa, llm_attention_swa_batched, llm_flash_attn_prefill_tc / tc2 / f32. - Chunked-prefill guard relaxed from "no SWA" to "flash enabled". - EstimateMaxContext SWA per-layer cap updated to window + headroom. Also fixes long-context Gemma 4 correctness on the decode path and in CudaHybridForwardPass (both indexed the window-sized cache absolutely). Validation: - New tests: PastWindow (700-token prompt > window) and chunked-ring Theory (5040/6144, 8192/9216 — ctx > window+headroom so the cache is a true ring; chunked vs per-token overwrite the ring differently, so agreement validates sizing). Gemma4 + Qwen3 prefill/decode suites green; 26 synthetic bit-wise kernel tests (incl. bf16 append) green. - Speedup A/B (Gemma 4 E4B, 5313-token prompt, RTX 4070 Ti): chunked 129.6 t/s vs 47.7 per-token = 2.72x. Long-context greedy decode stays coherent. Closes the SWA sub-item of #162. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
There was a problem hiding this comment.
Code Review
This pull request implements a sliding window attention (SWA) ring buffer for Gemma-4 models to support prompts longer than the sliding window and enable chunked batched prefill past 4096 tokens. It introduces modulo indexing in CUDA kernels and adds headroom to the SWA ring size calculation to prevent overwriting needed cache slots. Feedback highlights a potential thread-safety issue in the newly added tests due to process-wide environment variable modifications, suggesting sequential execution for these tests.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| var prevSnap = Environment.GetEnvironmentVariable("SHARPI_SNAPKV_BUDGET"); | ||
| Environment.SetEnvironmentVariable("SHARPI_SNAPKV_BUDGET", "0"); |
There was a problem hiding this comment.
Modifying environment variables via Environment.SetEnvironmentVariable is a process-wide operation and is not thread-safe. If tests are executed in parallel, this can lead to race conditions and flaky test results if other tests concurrently instantiate CudaForwardPass and read SHARPI_SNAPKV_BUDGET.
To prevent this, consider marking this test class with the [Collection("Sequential")] attribute (or a shared collection) to ensure it runs sequentially with respect to other tests that access or modify environment variables.
…162) The Gemma 4 `-g -1` row's Notes cell had grown into a ~400-word historical dump of #141/#146/#147/#149/#142. Condense to the essentials (the prefill optimizations, env toggles, decode path, remaining llama.cpp gap) and fold in this PR's SWA KV ring: >4096-token prompts now take the chunked batched-flash path (2.72×, 47.7 → 129.6 t/s on a 5.3K prompt) instead of the per-token fallback, and long-context correctness past the 512 window is fixed. The ~1K-ctx prefill/decode columns are unchanged — re-benchmarked the affected CUDA rows (qwen3 439 vs 432, gemma4 3594 vs 3698, gemma4-hybrid 6.5 vs 6.6 t/s, all within run-to-run noise); the chunked path only triggers >4096 tokens, so the change is modulo-identity for the ~1K column by construction. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
|
Thanks. Re: the Second review-toolkit cycle did surface three real items, now fixed in follow-up commits:
|
…g test (#162) Second review-toolkit cycle findings: - Restore the fail-closed clause on the chunked-prefill gate. canChunkPast4096 is again gated on a real per-layer IsSwaLayer pattern (or no window), so a future arch that sets a model-wide SlidingWindowSize without a per-layer SWA pattern can't silently run full-causal attention past 4096 (window dropped). - Guard _maxSeqLen < 1 at construction. The KV kernels now index `pos % max_seq_len`, so a malformed zero-context GGUF reached via an explicit ctx-size would GPU-trap on divide-by-zero; fail loud instead. - Strengthen the chunked-prefill test. Prefill returns only the final token's logits; with whole-chunk prompts (8192) the final window sits entirely inside the last chunk, so a cross-chunk ring overwrite would be invisible to the assertion. New lengths 4296 (=4096+200) and 8292 (=2*4096+100) make the last chunk shorter than the 512 window, so the final token's window reaches back across the chunk boundary — the observed logit now genuinely depends on cross-chunk (and, for 8292, ring-wrapped) KV reads. 8/8 green. README: clarify the >4096 2.72x number is measured at a context that admits the 5.3K prompt, not the row's -c 2048. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
perf(cuda): chunked batched prefill for Gemma 4 SWA via a real KV ring (#162)
What & why
Generalizes the >4096-token chunked batched prefill (#162/#163, dense models) to Gemma 4
sliding-window-attention (SWA) layers. Getting there required first fixing a pre-existing
latent correctness bug: SWA KV caches were allocated window-sized (
min(ctx, 512)forE4B) but every CUDA kernel indexed them by absolute position with no modulo. So for any
context beyond the 512-token window, the SWA layers read/wrote out of bounds —
producing silently wrong output (OOB on pooled VRAM doesn't crash). It was never caught
because every Gemma 4 test used a ≤12-token prompt at
maxContextLength: 512.The fix: a real SWA KV ring
SwaRingSize(ctx, window) = min(ctx, window + SwaRingHeadroom)with
SwaRingHeadroom = max(PrefillBatchChunk, 4096) = 4096. The+chunkheadroom isrequired because batched prefill appends a whole chunk's K/V before any of those
queries attend — a bare-window ring would overwrite the earliest queries' windows.
pos % max_seq_len, where themax_seq_lenarg at every call site equals the allocated cache size. For full-context (dense/global)
caches
pos < max_seq_len, so the modulo is the identity (zero behaviour change);for window-sized SWA rings it wraps. Kernels touched:
llm_kv_append(+bf16),llm_kv_append_batched(+bf16),llm_attention_swa,llm_attention_swa_batched,llm_flash_attn_prefill_tc/tc2/f32.canChunkPast4096 = PrefillFlashAttnEnabled). Flash streams the windowed KV across chunks; the global layersneed flash anyway (the non-flash AttentionBatched caps at 4096).
EstimateMaxContextSWA per-layer byte cap updated from bare window towindow + headroom.This also fixes long-context Gemma 4 correctness on the decode path and in
CudaHybridForwardPass(both indexed the window-sized cache absolutely before): per-tokendecode only needs
ring ≥ window, which the modulo now provides.Validation
Gemma4CudaBatchedPrefillTests:PastWindow_MatchesSequential— 700-token prompt (> 512 window), batched vs per-token.Exercises windowed attention past the window (the latent-bug regime); full cache here.
ChunkedBatchedPrefill_Over4096_MatchesSequential[Theory 5040/6144, 8192/9216] — ctxexceeds
window + headroomso the SWA cache is a true ring (positions wrap). This isthe decisive ring-sizing oracle: the per-token reference attends right after each single
append (needs ring ≥ window) while the chunked path appends a whole chunk first (needs
ring ≥ window + chunk span), so the two overwrite the ring differently — agreement
validates the sizing.
prefill suites green; 26 synthetic bit-wise kernel tests (incl. bf16 append) green.
129.6 t/s vs 47.7 t/s per-token = 2.72×. (More modest than dense Qwen3's 8.3×
because E4B is much heavier per token; the win is real and the chunked path engages.)
window, exercising the ring across chunks) produced a fluent, on-topic continuation
("This is an incredibly detailed and well-structured design document…") — a garbage KV
ring would yield gibberish, so this confirms correctness end-to-end. (The llama.cpp
cross-tool check was skipped: llama-cli hung in interactive mode and the chat-template
token-parity matching is a known rabbit hole; the chunked-vs-per-token oracle + this
coherence run cover correctness.)
Notes / follow-ups
covered by
Qwen3_8B_ChunkedBatchedPrefill_Over4096_MatchesSequential(8192 case).full-context GDN-hybrid path uses the bf16 KV cache); it's added for write/read symmetry
so a future windowed-bf16 model can't silently corrupt. No bf16-wrap test exists.
needed span with 1–2 slots to spare);
SwaRingHeadroomis derived fromPrefillBatchChunkso the two stay coupled.
Closes the SWA sub-item of #162.
🤖 Generated with Claude Code