feat(cuda): Gemma 4 continuous batching (per-layer head_dim / SWA / shared-KV / PLE) (#195) #273
pekkah wants to merge 2 commits into
…shared-KV / PLE / softcap (#195)

PR #192 (#190) shipped CUDA continuous batching for dense transformers only; CudaForwardPass excluded Gemma-4 models (SupportsContinuousBatching / ThrowIfBatchingUnsupported threw on per-layer head_dim). Gemma 4 is a primary target (#82/#124), so this adds batched serving for it.

- FillKvCacheArrays: extract the owned-cache allocation into the single source of truth for per-layer KV geometry (per-layer head_dim + KV-head count, SWA ring sizing, shared-KV tail aliasing). Used by both the constructor (fills the owned arrays in place, preserving identity so _ownedKCache stays valid) and CreateCache, so a per-sequence cache that the batched decode / PrefillWithCache binds is byte-identical to the owned one. Adds a forward-reference guard and frees partial (non-aliased) allocations on OOM.
- CreateCache: now allocates Gemma-4 per-layer geometry + SWA rings + shared-KV aliasing (was dense-only).
- RunBatchedTrunkGemma4 + GpuLayerBatchedDecodeGemma4: the N-sequence decode analogue of the single-token RunGemma4DeviceRegion. Batched embed (+scale) + PLE pre-pass; per layer {batched RmsNorm/QKV/QK-norm/V-norm, per-sequence RoPE/KV-append/SWA-or-global attention against caches[n] at positions[n], batched O-proj + sandwich norms + FFN + PLE injection + layer_output_scale}; batched final norm + output GEMM + final-logit softcap into _decodeLogitsAll. Matmuls use the proven GpuMatMulBatched (cuBLAS GEMM, weight-amortized); the dense WS/MMQ decode routing for Gemma 4 is a follow-up.
- Gates: SupportsContinuousBatching = dense OR Gemma4BatchedDecodeSupported (the latter does NOT exclude softcap, which the decode applies); ThrowIfBatchingUnsupported admits Gemma 4. Spec-verify (SupportsBatchVerify), the 2-slot prefix cache (SupportsMultiSlotPrefix), and the #193 dense packed prefill (IsDensePackablePrefill) stay dense-only; Gemma 4 prefill uses the sequential PrefillWithCache loop (the Gemma-4-capable single-seq batched trunk).
- Loader: dense CUDA full-offload already routes via SupportsContinuousBatching, so Gemma 4 full-offload now batches automatically (image input + batching stays rejected; hybrid/Q4_0 fall back to single-user).

Tests (Gemma 4 E4B Q8_0, 4070 Ti; exercises per-layer head_dim 256, SWA rings, the 18-layer shared-KV tail, PLE): SupportsContinuousBatching gate, PrefillWithCache vs single-user prefill (CreateCache geometry/aliasing), BatchForwardMulti N=2 and two decode steps vs single-user ForwardGemma4, all argmax-stable within the fp16-GEMM tolerance. Dense (15) + Gemma 4 single-user + prefill (14) regressions all pass (the constructor cache refactor is byte-correct for both families).

Two review passes (code-reviewer + silent-failure-hunter), no critical issues.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
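The "single source of truth for per-layer KV geometry" idea can be illustrated with a small, self-contained sketch. Everything here is hypothetical: the `LayerKvSpec`, `KvSlot`, and `KvGeometry.Plan` names are invented for illustration, and the ring size is assumed to be `min(window, maxContext)`, which may differ from what `SwaRingSize` actually computes in the PR. It only plans sizes and aliases; it allocates nothing on a GPU.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical per-layer description; field names are illustrative only.
// SwaWindow <= 0 means a global-attention layer; KvSourceLayer < 0 means the layer owns its KV.
public sealed record LayerKvSpec(int HeadDim, int KvHeads, int SwaWindow, int KvSourceLayer);

// One planned slot: which layer's storage it uses and how many elements it holds.
public sealed record KvSlot(int StorageLayer, long Elements, bool IsAlias);

public static class KvGeometry
{
    // Plans K (or V) storage for every layer in one place, so every caller
    // that builds a cache gets the same layout.
    public static KvSlot[] Plan(IReadOnlyList<LayerKvSpec> layers, int maxContext)
    {
        var slots = new KvSlot[layers.Count];
        for (int i = 0; i < layers.Count; i++)
        {
            var spec = layers[i];
            if (spec.KvSourceLayer >= 0)
            {
                // Forward-reference guard: a shared layer may only alias an earlier layer.
                if (spec.KvSourceLayer >= i)
                    throw new InvalidOperationException(
                        $"Layer {i} aliases layer {spec.KvSourceLayer}, which is not planned yet.");
                // Shared-KV layer: reuse the source layer's storage, allocate nothing new.
                slots[i] = slots[spec.KvSourceLayer] with { IsAlias = true };
                continue;
            }
            // Assumed sizing: SWA layers keep a ring of at most SwaWindow positions,
            // global layers keep the full context.
            int positions = spec.SwaWindow > 0 ? Math.Min(spec.SwaWindow, maxContext) : maxContext;
            slots[i] = new KvSlot(i, (long)positions * spec.KvHeads * spec.HeadDim, false);
        }
        return slots;
    }
}
```

Because aliased slots are marked, a cleanup path can skip them and free each underlying buffer exactly once, which is the property the double-free check in the review relies on.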
Code Review
This pull request implements CUDA continuous batching and batched decode support for Gemma-4 models (issue #195). Key changes include refactoring KV-cache allocation into a shared helper to maintain layout consistency, updating batching support gates, and introducing a dedicated Gemma-4 batched decode path (RunBatchedTrunkGemma4 and GpuLayerBatchedDecodeGemma4) that handles per-layer head dimensions, SWA rings, shared-KV aliasing, and PLE injection. New integration tests are also added to verify these paths. The review feedback provides excellent, actionable optimization suggestions to avoid allocating and freeing key/value tensor views on the hot path when dealing with shared-KV layers.
var qAll = _gpu.View(_bpQ!, 0, (long)N * qDimL);
var kAll = _gpu.View(_bpK!, 0, (long)N * kvDimL);
var vAll = _gpu.View(_bpV!, 0, (long)N * kvDimL);
var attnAll = _gpu.View(_bpAttnOut!, 0, (long)N * qDimL);
For shared-KV layers (kvShared is true), the key and value projections are not computed. We can avoid creating the kAll and vAll views entirely in this case, reducing the overhead of creating and freeing views on the hot path.
var qAll = _gpu.View(_bpQ!, 0, (long)N * qDimL);
Tensor? kAll = null;
Tensor? vAll = null;
if (!kvShared)
{
    kAll = _gpu.View(_bpK!, 0, (long)N * kvDimL);
    vAll = _gpu.View(_bpV!, 0, (long)N * kvDimL);
}
var attnAll = _gpu.View(_bpAttnOut!, 0, (long)N * qDimL);

var qv = _gpu.View(_bpQ!, (long)n * qDimL, qDimL);
var kv = _gpu.View(_bpK!, (long)n * kvDimL, kvDimL);
var vv = _gpu.View(_bpV!, (long)n * kvDimL, kvDimL);
var av = _gpu.View(_bpAttnOut!, (long)n * qDimL, qDimL);
Similarly, inside the sequence loop, we can avoid creating the kv and vv views when kvShared is true, as they are not used or modified for shared-KV layers.
var qv = _gpu.View(_bpQ!, (long)n * qDimL, qDimL);
Tensor? kv = null;
Tensor? vv = null;
if (!kvShared)
{
    kv = _gpu.View(_bpK!, (long)n * kvDimL, kvDimL);
    vv = _gpu.View(_bpV!, (long)n * kvDimL, kvDimL);
}
var av = _gpu.View(_bpAttnOut!, (long)n * qDimL, qDimL);

finally
{
    _gpu.Free(qv); _gpu.Free(kv); _gpu.Free(vv); _gpu.Free(av);
}

if (_layerOutputScale is not null)
    _gpu.ScaleInPlace(_bpHidden!, _layerOutputScale[layer]);

_gpu.Free(qAll); _gpu.Free(kAll); _gpu.Free(vAll); _gpu.Free(attnAll);
…195 review) Address Gemini Code Assist feedback on PR #273: GpuLayerBatchedDecodeGemma4 created the kAll/vAll batched views and per-sequence kv/vv views on every layer, but a shared-KV tail layer (KvSourceLayer >= 0) reuses the source layer's cache and computes no K/V — those views are never written or read. Create them only on KV-owning layers (nullable + conditional), eliding two batched + 2·N per-sequence view alloc/frees per shared layer per decode step. Bit-identical (the elided views were dead); Gemma 4 E4B batched-decode oracles still pass. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Thanks @gemini-code-assist, applied the suggestion:

Understood. The reduction in view allocations for shared-KV layers should effectively minimize overhead in the batched decode path. Since the oracles pass and the logic is verified, this optimization appears complete.
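One detail the suggestion leaves implicit is that once the K/V views become nullable, the cleanup path has to tolerate nulls. The sketch below shows that pattern in isolation. `View`, `FakeGpu`, and `SharedKvViews.DecodeLayer` are hypothetical stand-ins, not the project's `Tensor` or `_gpu` API, and whether the real `_gpu.Free` accepts null is not stated in the thread; here the fake `Free` is written as a no-op on null so the `finally` block can free unconditionally.

```csharp
#nullable enable
using System;

// Hypothetical stand-in for a tensor view; not the project's real type.
public sealed class View
{
    public long Offset { get; }
    public long Length { get; }
    public View(long offset, long length) { Offset = offset; Length = length; }
}

// Hypothetical stand-in for the GPU backend that only counts view lifetimes.
public sealed class FakeGpu
{
    public int Created { get; private set; }
    public int Freed { get; private set; }

    public View MakeView(long offset, long length)
    {
        Created++;
        return new View(offset, length);
    }

    // Freeing null is a no-op, so callers can free every slot unconditionally.
    public void Free(View? v)
    {
        if (v is not null) Freed++;
    }
}

public static class SharedKvViews
{
    // Creates the per-sequence views for one layer, skipping K/V views
    // when the layer borrows another layer's cache (kvShared).
    public static void DecodeLayer(FakeGpu gpu, int n, int qDim, int kvDim, bool kvShared)
    {
        for (int s = 0; s < n; s++)
        {
            View? q = null, k = null, v = null;
            try
            {
                q = gpu.MakeView((long)s * qDim, qDim);
                if (!kvShared)
                {
                    k = gpu.MakeView((long)s * kvDim, kvDim);
                    v = gpu.MakeView((long)s * kvDim, kvDim);
                }
                // RoPE, KV append (owning layers only) and attention would run here.
            }
            finally
            {
                // Safe for shared layers: k and v are still null and are skipped.
                gpu.Free(q); gpu.Free(k); gpu.Free(v);
            }
        }
    }
}
```

With four sequences, a shared layer creates and frees 4 views in this sketch while an owning layer creates and frees 12, matching the 2·N per-sequence saving described in the follow-up commit.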
…shared-KV / PLE (#195) (#278)
Squash of the two commits above: feat(cuda): Gemma 4 continuous batching (per-layer head_dim / SWA / shared-KV / PLE / softcap) (#195) and perf(cuda): skip unused K/V views on Gemma 4 shared-KV decode layers (#195 review). Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
Summary
Closes #195. Stacked on #193 (PR #272): the base is feat/193-cuda-packed-prefill, so this diff shows only the Gemma-4 work; rebase to master once #272 merges.

PR #192 (#190) shipped CUDA continuous batching for dense transformers only; CudaForwardPass excluded Gemma-4 models (SupportsContinuousBatching / ThrowIfBatchingUnsupported threw on per-layer head_dim). Gemma 4 is a primary target (#82/#124), so this adds batched serving for it.

What changed (src/SharpInference.Engine/CudaForwardPass.cs)
- FillKvCacheArrays: extracts the owned-cache allocation into the single source of truth for per-layer KV geometry (per-layer head_dim + KV-head count, SWA ring sizing via SwaRingSize, shared-KV tail aliasing via KvSourceLayer). Used by both the constructor (fills the owned arrays in place, preserving identity so _ownedKCache stays valid) and CreateCache, so a per-sequence cache that the batched decode / PrefillWithCache binds is byte-identical to the owned one. Adds a forward-reference guard and frees partial (non-aliased) allocations on OOM.
- CreateCache: now allocates Gemma-4 per-layer geometry + SWA rings + shared-KV aliasing (was dense-only).
- RunBatchedTrunkGemma4 + GpuLayerBatchedDecodeGemma4: the N-sequence decode analogue of the single-token RunGemma4DeviceRegion. Batched embed (+scale) + PLE pre-pass; per layer {batched RmsNorm/QKV/QK-norm/V-norm, per-sequence RoPE / KV-append / SWA-or-global attention against caches[n] at positions[n], batched O-proj + sandwich norms + FFN + PLE injection + layer_output_scale}; batched final norm + output GEMM + final-logit softcap into _decodeLogitsAll. Matmuls use the proven GpuMatMulBatched (cuBLAS GEMM, weight-amortized across the batch); the dense WS/MMQ decode routing for Gemma 4 is a follow-up.
- Gates: SupportsContinuousBatching = dense OR Gemma4BatchedDecodeSupported (the Gemma-4 gate does not exclude softcap, which the decode applies); ThrowIfBatchingUnsupported admits Gemma 4. Spec-verify (SupportsBatchVerify), the 2-slot prefix cache (SupportsMultiSlotPrefix), and the #193 dense packed prefill (IsDensePackablePrefill) stay dense-only; Gemma 4 prefill uses the sequential PrefillWithCache loop (the Gemma-4-capable single-seq batched trunk).
- Loader: dense CUDA full-offload already routes via SupportsContinuousBatching, so Gemma 4 full-offload now batches automatically. Image input + batching stays rejected; hybrid / Q4_0 fall back to single-user.

Testing (Gemma 4 E4B Q8_0, RTX 4070 Ti)
E4B exercises per-layer head_dim (256), SWA rings, the 18-layer shared-KV tail, and PLE. New Gemma4CudaBatchForwardMultiTests:
- SupportsContinuousBatching gate is true for E4B.
- PrefillWithCache vs single-user Prefill (validates CreateCache geometry + aliasing).
- BatchForwardMulti N=2 and two decode steps vs single-user ForwardGemma4: argmax-stable within the fp16-GEMM tolerance.

Regressions: dense batched + #193 packed (15) and Gemma 4 single-user + prefill trunk (14) all pass; the constructor cache refactor is byte-correct for both families. The only local 12B is Q4_0 (not GEMM-N-batchable → single-user fallback), so k_eq_v's batched path isn't covered by a runnable test; it mirrors RunGemma4DeviceRegion exactly.

Review
Code-reviewer + silent-failure-hunter passes: no critical/high issues. Partial-state (cache .Length advanced only by the caller's tail after the trunk completes; caller discards all caches on throw), double-free (aliased-layer skip), and Q4_0-slip-through (dtype gate) all verified safe. Applied the suggested forward-reference guard and fixed two stale comments.

🤖 Generated with Claude Code
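For readers unfamiliar with the "argmax-stable within tolerance" style of oracle used in the testing section, here is a minimal sketch of one way such a check could be written. It is an assumption about the shape of the comparison, not the PR's actual test code: `LogitCheck`, `ArgMax`, and `Agrees` are invented names, and the tolerance value would have to be chosen for the fp16 GEMM in question.

```csharp
using System;

public static class LogitCheck
{
    // Index of the largest logit; the first index wins on ties.
    public static int ArgMax(ReadOnlySpan<float> x)
    {
        if (x.IsEmpty) throw new ArgumentException("empty logits");
        int best = 0;
        for (int i = 1; i < x.Length; i++)
            if (x[i] > x[best]) best = i;
        return best;
    }

    // True when no logit differs by more than tol and both paths
    // select the same token.
    public static bool Agrees(ReadOnlySpan<float> reference, ReadOnlySpan<float> batched, float tol)
    {
        if (reference.Length != batched.Length) return false;
        for (int i = 0; i < reference.Length; i++)
            if (MathF.Abs(reference[i] - batched[i]) > tol) return false;
        return ArgMax(reference) == ArgMax(batched);
    }
}
```

Comparing the argmax in addition to the element-wise difference matters because a batched GEMM can reorder floating-point accumulation: small per-logit drift is expected, but a change in the selected token would alter the generated text.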