Skip to content

perf(cuda): chunked batched prefill for Gemma 4 SWA via a real KV ring (#162)#164

Merged
pekkah merged 3 commits into
masterfrom
perf/cuda-gemma4-swa-chunked-prefill-162
Jun 7, 2026
Merged

perf(cuda): chunked batched prefill for Gemma 4 SWA via a real KV ring (#162)#164
pekkah merged 3 commits into
masterfrom
perf/cuda-gemma4-swa-chunked-prefill-162

Conversation

@pekkah

@pekkah pekkah commented Jun 7, 2026

Copy link
Copy Markdown
Owner

perf(cuda): chunked batched prefill for Gemma 4 SWA via a real KV ring (#162)

What & why

Generalizes the >4096-token chunked batched prefill (#162/#163, dense models) to Gemma 4
sliding-window-attention (SWA) layers. Getting there required first fixing a pre-existing
latent correctness bug
: SWA KV caches were allocated window-sized (min(ctx, 512) for
E4B) but every CUDA kernel indexed them by absolute position with no modulo. So for any
context beyond the 512-token window, the SWA layers read/wrote out of bounds
producing silently wrong output (OOB on pooled VRAM doesn't crash). It was never caught
because every Gemma 4 test used a ≤12-token prompt at maxContextLength: 512.

The fix: a real SWA KV ring

  • SWA caches are now sized SwaRingSize(ctx, window) = min(ctx, window + SwaRingHeadroom)
    with SwaRingHeadroom = max(PrefillBatchChunk, 4096) = 4096. The +chunk headroom is
    required because batched prefill appends a whole chunk's K/V before any of those
    queries attend — a bare-window ring would overwrite the earliest queries' windows.
  • All KV-cache-indexing kernels now wrap at pos % max_seq_len, where the max_seq_len
    arg at every call site equals the allocated cache size. For full-context (dense/global)
    caches pos < max_seq_len, so the modulo is the identity (zero behaviour change);
    for window-sized SWA rings it wraps. Kernels touched: llm_kv_append(+bf16),
    llm_kv_append_batched(+bf16), llm_attention_swa, llm_attention_swa_batched,
    llm_flash_attn_prefill_tc / tc2 / f32.
  • Chunked-prefill guard relaxed from "no SWA" to "flash enabled" (canChunkPast4096 = PrefillFlashAttnEnabled). Flash streams the windowed KV across chunks; the global layers
    need flash anyway (the non-flash AttentionBatched caps at 4096).
  • EstimateMaxContext SWA per-layer byte cap updated from bare window to window + headroom.

This also fixes long-context Gemma 4 correctness on the decode path and in
CudaHybridForwardPass (both indexed the window-sized cache absolutely before): per-token
decode only needs ring ≥ window, which the modulo now provides.

Validation

  • New tests in Gemma4CudaBatchedPrefillTests:
    • PastWindow_MatchesSequential — 700-token prompt (> 512 window), batched vs per-token.
      Exercises windowed attention past the window (the latent-bug regime); full cache here.
    • ChunkedBatchedPrefill_Over4096_MatchesSequential [Theory 5040/6144, 8192/9216] — ctx
      exceeds window + headroom so the SWA cache is a true ring (positions wrap). This is
      the decisive ring-sizing oracle: the per-token reference attends right after each single
      append (needs ring ≥ window) while the chunked path appends a whole chunk first (needs
      ring ≥ window + chunk span), so the two overwrite the ring differently — agreement
      validates the sizing.
  • Gemma4 batched-prefill tests green (5 existing ≤512 + 3 new); Gemma4 decode + Qwen3
    prefill suites green; 26 synthetic bit-wise kernel tests (incl. bf16 append) green.
  • Speedup A/B (Gemma 4 E4B, 5313-token prompt, RTX 4070 Ti): chunked SWA prefill
    129.6 t/s vs 47.7 t/s per-token = 2.72×. (More modest than dense Qwen3's 8.3×
    because E4B is much heavier per token; the win is real and the chunked path engages.)
  • End-to-end coherence: greedy decode after a >5000-token prompt (well past the 512
    window, exercising the ring across chunks) produced a fluent, on-topic continuation
    ("This is an incredibly detailed and well-structured design document…") — a garbage KV
    ring would yield gibberish, so this confirms correctness end-to-end. (The llama.cpp
    cross-tool check was skipped: llama-cli hung in interactive mode and the chat-template
    token-parity matching is a known rabbit hole; the chunked-vs-per-token oracle + this
    coherence run cover correctness.)

Notes / follow-ups

  • The chunk-loop boundary off-by-one (exact-4096-multiple) is shared dense code already
    covered by Qwen3_8B_ChunkedBatchedPrefill_Over4096_MatchesSequential (8192 case).
  • The bf16 append-kernel ring modulo is identity-only in production today (only the
    full-context GDN-hybrid path uses the bf16 KV cache); it's added for write/read symmetry
    so a future windowed-bf16 model can't silently corrupt. No bf16-wrap test exists.
  • The ring's safety margin over a 4096-token chunk is tight (≈window + chunk holds the
    needed span with 1–2 slots to spare); SwaRingHeadroom is derived from PrefillBatchChunk
    so the two stay coupled.

Closes the SWA sub-item of #162.

🤖 Generated with Claude Code

#162)

Generalizes >4096-token chunked batched prefill to Gemma 4 sliding-window
(SWA) layers. This required first fixing a pre-existing latent bug: SWA KV
caches were allocated window-sized but every CUDA kernel indexed them by
absolute position, so any context beyond the 512-token window read/wrote
out of bounds (silently wrong output; never caught because all Gemma 4
tests used <=12-token prompts at maxContextLength 512).

Fix: a real SWA KV ring.
- Cache sized SwaRingSize(ctx, window) = min(ctx, window + 4096). The +chunk
  headroom is required because batched prefill appends a whole chunk before
  any of those queries attend; a bare-window ring would clobber the earliest
  queries' windows.
- All KV-indexing kernels wrap at `pos % max_seq_len` (= allocated cache
  size). Identity for full/dense/global caches (pos < size); wraps for
  windowed SWA rings. Kernels: llm_kv_append (+bf16), llm_kv_append_batched
  (+bf16), llm_attention_swa, llm_attention_swa_batched,
  llm_flash_attn_prefill_tc / tc2 / f32.
- Chunked-prefill guard relaxed from "no SWA" to "flash enabled".
- EstimateMaxContext SWA per-layer cap updated to window + headroom.

Also fixes long-context Gemma 4 correctness on the decode path and in
CudaHybridForwardPass (both indexed the window-sized cache absolutely).

Validation:
- New tests: PastWindow (700-token prompt > window) and chunked-ring Theory
  (5040/6144, 8192/9216 — ctx > window+headroom so the cache is a true ring;
  chunked vs per-token overwrite the ring differently, so agreement validates
  sizing). Gemma4 + Qwen3 prefill/decode suites green; 26 synthetic bit-wise
  kernel tests (incl. bf16 append) green.
- Speedup A/B (Gemma 4 E4B, 5313-token prompt, RTX 4070 Ti): chunked 129.6
  t/s vs 47.7 per-token = 2.72x. Long-context greedy decode stays coherent.

Closes the SWA sub-item of #162.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements a sliding window attention (SWA) ring buffer for Gemma-4 models to support prompts longer than the sliding window and enable chunked batched prefill past 4096 tokens. It introduces modulo indexing in CUDA kernels and adds headroom to the SWA ring size calculation to prevent overwriting needed cache slots. Feedback highlights a potential thread-safety issue in the newly added tests due to process-wide environment variable modifications, suggesting sequential execution for these tests.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +425 to +426
var prevSnap = Environment.GetEnvironmentVariable("SHARPI_SNAPKV_BUDGET");
Environment.SetEnvironmentVariable("SHARPI_SNAPKV_BUDGET", "0");

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Modifying environment variables via Environment.SetEnvironmentVariable is a process-wide operation and is not thread-safe. If tests are executed in parallel, this can lead to race conditions and flaky test results if other tests concurrently instantiate CudaForwardPass and read SHARPI_SNAPKV_BUDGET.

To prevent this, consider marking this test class with the [Collection("Sequential")] attribute (or a shared collection) to ensure it runs sequentially with respect to other tests that access or modify environment variables.

…162)

The Gemma 4 `-g -1` row's Notes cell had grown into a ~400-word historical
dump of #141/#146/#147/#149/#142. Condense to the essentials (the prefill
optimizations, env toggles, decode path, remaining llama.cpp gap) and fold in
this PR's SWA KV ring: >4096-token prompts now take the chunked batched-flash
path (2.72×, 47.7 → 129.6 t/s on a 5.3K prompt) instead of the per-token
fallback, and long-context correctness past the 512 window is fixed.

The ~1K-ctx prefill/decode columns are unchanged — re-benchmarked the affected
CUDA rows (qwen3 439 vs 432, gemma4 3594 vs 3698, gemma4-hybrid 6.5 vs 6.6 t/s,
all within run-to-run noise); the chunked path only triggers >4096 tokens, so
the change is modulo-identity for the ~1K column by construction.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@pekkah

pekkah commented Jun 7, 2026

Copy link
Copy Markdown
Owner Author

Thanks. Re: the SHARPI_SNAPKV_BUDGET env-var set/restore — this project disables xUnit parallelization (tests/SharpInference.Tests.ForwardPass/xunit.runner.json: parallelizeAssembly: false, parallelizeTestCollections: false), so no two tests run concurrently and the temporary value can't be observed by another test. The pattern also matches the existing Qwen3CudaBatchedPrefillTests chunked test. Leaving it as-is rather than adding a redundant [Collection].

Second review-toolkit cycle did surface three real items, now fixed in follow-up commits:

  • Restored a fail-closed guard on the chunked gate: chunking past 4096 now requires a real per-layer IsSwaLayer pattern (or no window), so a future model-wide-window-without-pattern arch can't silently run full-causal past 4096.
  • Strengthened the chunked test: the prompt lengths now make the last chunk shorter than the window (4296, 8292), so the returned final-token logit's window spans the chunk boundary — the assertion now actually observes cross-chunk (and wrapped) ring reads. Whole-chunk prompts (8192) hid that.
  • Added a fail-loud guard for _maxSeqLen < 1 (the kernels now do pos % maxSeqLen, which would GPU-trap on a malformed zero-context GGUF).

…g test (#162)

Second review-toolkit cycle findings:

- Restore the fail-closed clause on the chunked-prefill gate. canChunkPast4096
  is again gated on a real per-layer IsSwaLayer pattern (or no window), so a
  future arch that sets a model-wide SlidingWindowSize without a per-layer SWA
  pattern can't silently run full-causal attention past 4096 (window dropped).

- Guard _maxSeqLen < 1 at construction. The KV kernels now index `pos %
  max_seq_len`, so a malformed zero-context GGUF reached via an explicit
  ctx-size would GPU-trap on divide-by-zero; fail loud instead.

- Strengthen the chunked-prefill test. Prefill returns only the final token's
  logits; with whole-chunk prompts (8192) the final window sits entirely inside
  the last chunk, so a cross-chunk ring overwrite would be invisible to the
  assertion. New lengths 4296 (=4096+200) and 8292 (=2*4096+100) make the last
  chunk shorter than the 512 window, so the final token's window reaches back
  across the chunk boundary — the observed logit now genuinely depends on
  cross-chunk (and, for 8292, ring-wrapped) KV reads. 8/8 green.

README: clarify the >4096 2.72x number is measured at a context that admits the
5.3K prompt, not the row's -c 2048.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant