perf(gemma4-cuda): collapse per-token launches + all-GPU batched-trunk prefill (#136) #137
…ices (#136)

The PLE projection slices have static per-layer offsets (layer*PleWidth), but both BuildPerLayerProjectionsGpu and ApplyPerLayerEmbeddingGpu round-tripped each slice through CopyDeviceRegion every token because the comment claimed "the Tensor abstraction can't encode a device-pointer offset". That's stale: CudaBackend.View (added for #111) does exactly this.

All-GPU path (CudaForwardPass):
- Precompute static per-layer views into _gpuProjPerLayer / _gpuPleRow once.
- Build loop now does RmsNorm+Add+Scale in place on views (3 copies/slice gone).
- ApplyPerLayerEmbeddingGpu reads the proj slice via its view (1 copy/layer gone).

Hybrid path (CudaHybridForwardPass):
- Upload the full [nGpu*pleWidth] CPU projection once per token instead of L tiny per-layer H2D transfers; GPU layer i reads slice i via a static view. Removes L-1 uploads and their barriers per token.

Numerically inert. Gemma4 CUDA + hybrid + kernel tests green (12/12).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
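The static-view idea in this commit can be sketched on the host. This is a minimal illustration only: Python's `memoryview` stands in for `CudaBackend.View`, and `PLE_WIDTH`, `NUM_LAYERS`, and `scale_in_place` are hypothetical names, not the repo's API. It shows that views built once at offset `layer * PLE_WIDTH` alias the parent buffer, so an in-place op needs no staging copy.

```python
from array import array

PLE_WIDTH = 4    # hypothetical slice width
NUM_LAYERS = 3   # hypothetical layer count

# One contiguous float32 buffer holding every layer's projection slice.
proj = array("f", range(NUM_LAYERS * PLE_WIDTH))
proj_mem = memoryview(proj)

# Built once: view i aliases elements [i*PLE_WIDTH, (i+1)*PLE_WIDTH).
layer_views = [
    proj_mem[layer * PLE_WIDTH:(layer + 1) * PLE_WIDTH]
    for layer in range(NUM_LAYERS)
]


def scale_in_place(view, factor):
    """Scale a slice through its view; the parent buffer is modified directly."""
    for i in range(len(view)):
        view[i] = view[i] * factor


# Only layer 1's slice changes; no copy out of or back into `proj` happens.
scale_in_place(layer_views[1], 2.0)
```

Because the offsets never change between tokens, the list of views can be built once and reused on every token.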
…136)

BuildPerLayerProjectionsGpu ran RmsNorm+Add+Scale per layer slice: 3×L tiny launches per token. The add and scale are pure elementwise and don't couple across slice boundaries; only the RmsNorm is per-slice. So the loop collapses to: one RmsNormBatched (one block per row, byte-identical to llm_rmsnorm) over all L rows, then a single full-buffer AddInPlace and ScaleInPlace. 3×L → 3 launches, reusing existing parity-proven kernels (no new CUDA). Removes the now-unused _gpuPleRowSliceViews.

Gemma4 CUDA argmax-vs-CPU parity and long-decode coherence tests green (12/12).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
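The argument that the loop collapses without changing results can be checked with a small host-side reference. This is a sketch under assumptions: a single norm weight vector shared by all rows, `eps=1e-6`, and the function names are hypothetical. Each element goes through the same operations in the same order on both paths, so the outputs compare equal exactly.

```python
import math


def rms_norm(row, weight, eps=1e-6):
    """RmsNorm over one row (assumed eps; one weight per column)."""
    mean_sq = sum(x * x for x in row) / len(row)
    inv = 1.0 / math.sqrt(mean_sq + eps)
    return [x * inv * w for x, w in zip(row, weight)]


def per_slice(proj, addend, weight, scale, width):
    """Old shape: norm, add, and scale issued once per layer slice."""
    out = []
    for start in range(0, len(proj), width):
        normed = rms_norm(proj[start:start + width], weight)
        added = [n + a for n, a in zip(normed, addend[start:start + width])]
        out.extend(v * scale for v in added)
    return out


def collapsed(proj, addend, weight, scale, width):
    """New shape: one row-batched norm, then full-buffer add and scale."""
    rows = [proj[s:s + width] for s in range(0, len(proj), width)]
    normed = [v for row in rows for v in rms_norm(row, weight)]
    added = [n + a for n, a in zip(normed, addend)]
    return [v * scale for v in added]


width = 4
proj = [0.25 * i - 1.5 for i in range(3 * width)]
addend = [0.1 * i for i in range(3 * width)]
weight = [1.0, 0.5, 2.0, 1.5]
same = per_slice(proj, addend, weight, 0.7, width) == collapsed(proj, addend, weight, 0.7, width)
```

Only the norm needs to know where a row starts and ends; the add and scale treat the buffer as one flat array.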
…ill (#136)

Batched NEOX-with-factors RoPE over N tokens (Gemma 4 global layers), mirroring llm_rope_neox_partial_batched's grid.y=token / position=base+token layout. Per row bit-identical to the per-token llm_rope_neox_with_factors. Oracle test RoPEWithFactorsBatched_BitwiseMatchesSingle green across nTok={3,17,64}. First of two new batched kernels needed for all-GPU Gemma 4 batched-trunk prefill.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…#136)

Batched sliding-window attention over N query tokens (Gemma 4 SWA layers): llm_full_seq_attention's grid=(num_heads,n_tok) structure with the per-token llm_attention_swa windowing. Window bounds eff_seq ≤ window_size ≤ 4096, so the shared-scores path always suffices, with no global scratch. Bit-identical per (head,token) to the per-token kernel. Oracle AttentionSwaBatched_BitwiseMatchesSingle green across startPos/nTok/window combinations incl. clamped and >0 windows. Second of two new batched kernels for all-GPU Gemma 4 batched-trunk prefill.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
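The windowing bound above can be illustrated with a small index calculation. This sketch assumes a causal window in which the query at absolute position `p` attends to keys in `[max(0, p - window_size + 1), p]`; the kernel's exact convention is not stated in the commit, and the function names are hypothetical. Under that assumption the number of keys per query never exceeds `window_size`.

```python
def swa_bounds(pos, window_size):
    """Half-open key range [start, end) for one query (assumed convention)."""
    start = max(0, pos - window_size + 1)
    return start, pos + 1


def batched_bounds(start_pos, n_tok, window_size):
    """One range per query token, using position = start_pos + token index."""
    return [swa_bounds(start_pos + t, window_size) for t in range(n_tok)]


bounds = batched_bounds(start_pos=2, n_tok=3, window_size=4)
eff_seq = [end - start for start, end in bounds]
```

Early positions produce a clamped window (fewer than `window_size` keys); later positions slide the start forward by one per token.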
…prefill (#136)

Two more batched primitives the all-GPU Gemma 4 trunk needs:
- llm_matvec_q8_0_gemm_n: batches the Q8_0 matvec over N tokens (the on-disk gemma-4-E4B model is Q8_0, which MatMulBatched previously rejected). Same (rows+7)/8 × nTok geometry as the F32/Q5/Q6 GEMM-N path; bit-identical per (token,row) to llm_matvec_q8_0. Oracle MatMulBatched_Q8_0_BitwiseMatchesSequential green.
- llm_gelu_tanh_mul_strided: GeluTanhMul whose up operand is read with a per-token stride/offset, so batched PLE can inject the per-layer slice of a [N×(L*pleWidth)] projection buffer without a transpose/gather. Per element bit-identical to llm_gelu_tanh_mul.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
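The strided read described for llm_gelu_tanh_mul_strided can be sketched as plain indexing. This is a host-side illustration with hypothetical function and parameter names; it uses the standard tanh approximation of GELU and assumes the gate is dense `[N×width]` while the `up` operand lives at `token * up_stride + up_offset` inside a wider buffer. It compares the strided read against an explicit gather followed by a dense op.

```python
import math


def gelu_tanh(x):
    """Standard tanh approximation of GELU."""
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + math.tanh(c * (x + 0.044715 * x ** 3)))


def gelu_tanh_mul_strided(gate, up, n_tok, width, up_stride, up_offset):
    """out[t, j] = gelu_tanh(gate[t, j]) * up[t*up_stride + up_offset + j]."""
    out = []
    for t in range(n_tok):
        for j in range(width):
            g = gate[t * width + j]
            u = up[t * up_stride + up_offset + j]
            out.append(gelu_tanh(g) * u)
    return out


def gather_then_dense(gate, up, n_tok, width, up_stride, up_offset):
    """Reference: copy the per-token slices out first, then run a dense op."""
    gathered = []
    for t in range(n_tok):
        base = t * up_stride + up_offset
        gathered.extend(up[base:base + width])
    return [gelu_tanh(g) * u for g, u in zip(gate, gathered)]


# N=2 tokens, L=3 layers, width=2: layer 1 sits at offset 2 of each 6-wide row.
n_tok, num_layers, width, layer = 2, 3, 2, 1
gate = [0.5, -1.0, 2.0, 0.0]
up = [float(i) for i in range(n_tok * num_layers * width)]
strided = gelu_tanh_mul_strided(gate, up, n_tok, width, num_layers * width, layer * width)
dense = gather_then_dense(gate, up, n_tok, width, num_layers * width, layer * width)
```

The strided form reads the same elements the gather would have copied, so the intermediate buffer is unnecessary.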
Lifts Gemma 4 out of the per-token prefill loop in CudaForwardPass. Embeds all N prompt tokens, builds the per-layer PLE projections batched (one proj GEMM-N + per-(token,layer) RmsNormBatched + full-buffer add/scale), runs every transformer layer batched across N, then the final norm + output projection on the last token.

Batched per layer: RmsNormBatched, Q/K/V/O/gate/up/down GEMM-N (MatMulBatched, now incl. Q8_0), HeadNormBatched QK-norm, dual-RoPE (RoPEWithFactorsBatched global / RoPEPartialBatched SWA), KvAppendBatched, SWA/full batched attention, GeluTanhMul FFN, and fully-batched PLE injection via the strided GeluTanhMul (per-layer slice of the [N×L*pleWidth] projection read by stride, with no transpose/gather). Per-layer head_dim handled via max-sized scratch + exact-N dense views.

Gated by IsGemma4BatchedPrefillSupported (dense, NEOX, non-L2 QK-norm, no attn-bias, TQ off, batchable weight dtypes) and startPos+N ≤ 4096 (shared-scores attention); SnapKV-active / unsupported configs fall back to the bit-exact per-token loop. SHARPI_BATCHED_PREFILL=0 forces sequential.

End-to-end parity: Gemma4_E4B_BatchedPrefill_MatchesSequential confirms batched and per-token post-prompt logits agree (argmax-exact, maxAbs<1e-2) on the real Q8_0 E4B GGUF. Full Gemma4 CUDA + batched-kernel suite green (26/26).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…ram (#136)

The Gemma 4 attention scale is 1.0, but the CUDA attention kernels baked in rsqrtf(head_dim) unconditionally, so every layer prescaled Q by sqrt(head_dim) to cancel it: one ScaleInPlace launch per layer per token (decode + prefill).

Added a trailing `float attn_scale` to llm_attention, llm_attention_swa, llm_full_seq_attention, and llm_attention_swa_batched: `scale = attn_scale>0 ? attn_scale : rsqrtf(head_dim)`. The dispatches default attnScale=-1f (sentinel), so every non-Gemma caller is bit-identical to before. Gemma passes attnScale=1f and drops the prescale at all 5 call sites (ForwardGemma4, ForwardProfiledGemma4, batched trunk, hybrid GpuLayerGemma4), which also matches the CPU reference's scale=1.0 exactly (no prescale round-trip rounding).

Gemma4 CUDA argmax-vs-CPU + long-decode coherence + batched-prefill parity green (14/14); AttentionSwaBatched sentinel oracle confirms non-Gemma bit-identity.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
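The sentinel rule quoted in this commit, `scale = attn_scale>0 ? attn_scale : rsqrtf(head_dim)`, can be mirrored on the host as follows. The Python function names are hypothetical and the score is a single unnormalized Q·K product, not the full kernel; it only shows how a non-positive sentinel selects the default while an explicit value overrides it.

```python
import math


def effective_scale(attn_scale, head_dim):
    """Non-positive attn_scale is the sentinel for the 1/sqrt(head_dim) default."""
    return attn_scale if attn_scale > 0 else 1.0 / math.sqrt(head_dim)


def attention_score(q, k, attn_scale=-1.0):
    """Scaled dot product for one query/key pair (softmax omitted)."""
    scale = effective_scale(attn_scale, len(q))
    return sum(a * b for a, b in zip(q, k)) * scale


q = [1.0, 1.0, 1.0, 1.0]
k = [2.0, 2.0, 2.0, 2.0]
default_score = attention_score(q, k)         # sentinel: scale = 1/sqrt(4)
explicit_score = attention_score(q, k, 1.0)   # explicit scale of 1.0
```

With the explicit value, Q no longer has to be multiplied by sqrt(head_dim) beforehand just to cancel the built-in factor.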
Gemma 4 applies per-head RmsNorm to Q (numHeads) and K (numKvHeads) with separate
weights — two HeadNorm launches per layer. New llm_head_norm_qk / _batched run both
in one launch over numHeads+numKvHeads blocks (first numHeads do Q, rest K). Per
block bit-identical to llm_head_norm. Used in ForwardGemma4 / ForwardProfiledGemma4
/ batched trunk / hybrid GpuLayerGemma4 when !kvShared (kv-shared layers keep the
single Q-only HeadNorm). Halves the QK-norm launch pair per layer (decode + prefill).
Oracle HeadNormQk_BitwiseMatchesSeparate (single + batched nTok={3,17,40}) green;
Gemma4 CUDA suite + batched-prefill parity 10/10.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
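The block-index dispatch in the fused Q/K head norm can be sketched on the host. This is a minimal reference under assumptions: RmsNorm with `eps=1e-6`, one weight vector per tensor of length `head_dim`, and hypothetical function names. Blocks `0..num_heads-1` normalize Q heads and the remaining `num_kv_heads` blocks normalize K heads, which is compared against two separate passes.

```python
import math


def head_norm(head, weight, eps=1e-6):
    """RmsNorm over a single head (assumed eps)."""
    inv = 1.0 / math.sqrt(sum(x * x for x in head) / len(head) + eps)
    return [x * inv * w for x, w in zip(head, weight)]


def head_norm_qk(q, k, q_weight, k_weight, num_heads, num_kv_heads, head_dim):
    """One pass over num_heads + num_kv_heads blocks: first Q, then K."""
    q_out, k_out = list(q), list(k)
    for block in range(num_heads + num_kv_heads):
        if block < num_heads:
            src, dst, w, h = q, q_out, q_weight, block
        else:
            src, dst, w, h = k, k_out, k_weight, block - num_heads
        lo = h * head_dim
        dst[lo:lo + head_dim] = head_norm(src[lo:lo + head_dim], w)
    return q_out, k_out


def head_norm_separate(x, weight, n_heads, head_dim):
    """Reference: the original single-tensor pass."""
    out = []
    for h in range(n_heads):
        out.extend(head_norm(x[h * head_dim:(h + 1) * head_dim], weight))
    return out


num_heads, num_kv_heads, head_dim = 4, 2, 2
q = [0.5 * i - 1.0 for i in range(num_heads * head_dim)]
k = [0.3 * i + 0.2 for i in range(num_kv_heads * head_dim)]
q_w, k_w = [1.0, 2.0], [0.5, 1.5]
fused_q, fused_k = head_norm_qk(q, k, q_w, k_w, num_heads, num_kv_heads, head_dim)
```

Each block performs the same arithmetic as the single-tensor pass; only the choice of source tensor and weight depends on the block index.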
Batched-trunk prefill (#136) lifts the all-GPU Gemma 4 prefill from a per-token loop to batched GEMM-N + batched attention. Re-measured at the table's ~1K warm methodology (bench-gemma4-136.ps1, same prompt as bench-allrows-1k.ps1):

prefill 47.7 -> 109.4 t/s (SHARPI_BATCHED_PREFILL=0 A/B: 53.4 -> 109.4, ~2.05x)
decode 44.1 -> 49.1 t/s (near-zero-ctx; PLE-projection fusion + dropped Q-prescale + fused Q/K HeadNorm on a launch-bound model)

Hybrid -g 22 row unchanged: prefill still per-token there, decode is CPU-dense-FFN bandwidth-bound, so the launch fusions don't move it (measured 6.7/6.5 ~ 6.6/6.8 noise).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Code Review
This pull request implements an all-GPU batched-trunk prefill optimization for Gemma 4 (Issue #136), which collapses per-position attention, FFN, and PLE launches into batched GEMM-N and batched-attention launches to significantly improve prefill performance. It introduces several new batched CUDA kernels (such as dual Q+K HeadNorm, batched SWA, and strided GELU-tanh multiplication) and integrates them into the forward pass logic, backed by comprehensive unit tests to ensure bit-identical outputs. The review feedback suggests improving the benchmark script by parameterizing the hardcoded model directory for better portability and wrapping the environment variable modifications in a try-finally block to ensure robust cleanup.
$E = "E:\models"
$model = "$E\gemma-4-E4B-it-Q8_0.gguf"
The model path is hardcoded, which makes the script less portable. Consider making the model directory a script parameter with a default value. This would allow running the benchmark against models in different locations without modifying the script. It's a good practice to define parameters at the beginning of the script.
param(
    [string]$ModelDir = "E:\models"
)
$model = Join-Path $ModelDir "gemma-4-E4B-it-Q8_0.gguf"
function Run($tag, $a, $envOff) {
    if ($envOff) { $env:SHARPI_BATCHED_PREFILL = "0" } else { Remove-Item env:SHARPI_BATCHED_PREFILL -ErrorAction SilentlyContinue }
    # Run twice, keep the second (warm); matches the discard-first guidance.
    $null = .\scripts\bench-textgen.ps1 -Model $model -Tag "$tag-w" -NTokens 60 -Prompt $prompt -TimeoutSec 900 -ExtraArgs $a
    $r = .\scripts\bench-textgen.ps1 -Model $model -Tag $tag -NTokens 60 -Prompt $prompt -TimeoutSec 900 -ExtraArgs $a
    Remove-Item env:SHARPI_BATCHED_PREFILL -ErrorAction SilentlyContinue
    $script:rows += [PSCustomObject]@{ Tag=$tag; PrefTok=$r.PrefillTok; PrefillTps=$r.PrefillTps; DecodeTps=$r.DecodeTps; Wall=$r.WallSec; TO=$r.TimedOut }
    Write-Host (" {0,-22} pref={1,7} t/s dec={2,6} t/s ({3} tok)" -f $tag,$r.PrefillTps,$r.DecodeTps,$r.PrefillTok) -ForegroundColor Green
}
The SHARPI_BATCHED_PREFILL environment variable is modified but not robustly cleaned up. If an error occurs during the benchmark run, the variable might persist in the shell session. Using a try...finally block ensures that the environment variable is always cleaned up, regardless of whether the script runs successfully or encounters an error.
function Run($tag, $a, $envOff) {
    try {
        if ($envOff) { $env:SHARPI_BATCHED_PREFILL = "0" } else { Remove-Item env:SHARPI_BATCHED_PREFILL -ErrorAction SilentlyContinue }
        # Run twice, keep the second (warm); matches the discard-first guidance.
        $null = .\scripts\bench-textgen.ps1 -Model $model -Tag "$tag-w" -NTokens 60 -Prompt $prompt -TimeoutSec 900 -ExtraArgs $a
        $r = .\scripts\bench-textgen.ps1 -Model $model -Tag $tag -NTokens 60 -Prompt $prompt -TimeoutSec 900 -ExtraArgs $a
        $script:rows += [PSCustomObject]@{ Tag=$tag; PrefTok=$r.PrefillTok; PrefillTps=$r.PrefillTps; DecodeTps=$r.DecodeTps; Wall=$r.WallSec; TO=$r.TimedOut }
        Write-Host (" {0,-22} pref={1,7} t/s dec={2,6} t/s ({3} tok)" -f $tag,$r.PrefillTps,$r.DecodeTps,$r.PrefillTok) -ForegroundColor Green
    }
    finally {
        # Always restore the environment, even if a bench run throws.
        Remove-Item env:SHARPI_BATCHED_PREFILL -ErrorAction SilentlyContinue
    }
}
- Fail-safe prefill gate: new BatchableWeight() treats an unregistered weight handle as NOT batchable (was WDType→Q4_K default, which biased the gate toward "supported"); also gate on _gpuPerLayerModelProj's dtype. Forces the per-token fallback for any untracked weight instead of misdispatching into GEMM-N.
- Add the missing high-value oracle: AttentionExplicitScale_BitwiseMatchesSingle_AndDiffersFromSentinel exercises the attn_scale=1f path (the PR's central change) bit-exact for batched SWA + full attention, and asserts 1f != the rsqrt sentinel so the param is provably plumbed (previously only end-to-end coverage).
- Tighten the end-to-end parity test: assert >95% bit-exact fraction, not just the loose maxAbs<1e-2 envelope (the exact count was computed but unused).
- Add <param name="attnScale"> docs to the 4 attention dispatches (sentinel contract).
- Comment fixes: PLE collapse note (≈6 ops×L, not 3×L); batched-scratch note now states Q/K/V/attn buffers are max-head_dim-sized + viewed per layer.

Gemma4 CUDA suite + all batched-kernel oracles 28/28 green.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
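The fail-safe behavior of the gate can be sketched as a lookup that refuses anything it does not recognize. This is an illustration only: the registry shape, the dtype set, and the function names are hypothetical, and the real gate checks more conditions than shown here. The point is that a missing registry entry yields "not batchable" rather than falling through to a default dtype.

```python
# Hypothetical set of dtypes with a batched GEMM-N path.
BATCHABLE_DTYPES = {"F32", "Q5_K", "Q6_K", "Q8_0"}


def batchable_weight(registry, handle):
    """An unregistered handle is treated as NOT batchable (fail-safe)."""
    dtype = registry.get(handle)
    return dtype is not None and dtype in BATCHABLE_DTYPES


def prefill_supported(registry, handles, start_pos, n_tok, max_ctx=4096):
    """Batched prefill only if the context fits and every weight is batchable."""
    if start_pos + n_tok > max_ctx:
        return False
    return all(batchable_weight(registry, h) for h in handles)


registry = {"wq": "Q8_0", "wk": "F32", "wv": "Q4_K"}
ok = prefill_supported(registry, ["wq", "wk"], start_pos=0, n_tok=1024)
untracked = prefill_supported(registry, ["wq", "missing"], start_pos=0, n_tok=1024)
unsupported = prefill_supported(registry, ["wq", "wv"], start_pos=0, n_tok=1024)
too_long = prefill_supported(registry, ["wq", "wk"], start_pos=4000, n_tok=200)
```

Any `False` result corresponds to taking the per-token fallback, which is slower but already known to be correct.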
…ni, PR #137)

Addresses the two gemini-code-assist comments on scripts/bench-gemma4-136.ps1:
- $ModelDir param (default E:\models) instead of a hardcoded path.
- Wrap the SHARPI_BATCHED_PREFILL toggle in try/finally so the env var is always cleaned up even if a bench run throws.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Review addressed (toolkit + gemini)
Ran the pr-review-toolkit (code-reviewer, silent-failure-hunter, pr-test-analyzer, comment-analyzer) plus the gemini-code-assist comments. Fixes pushed in …
Code / safety
Tests (the top finding — the central change was under-covered)
Comments / docs
Bench script (gemini)
Deferred to a separate issue (#138), not a blocker
Gemma 4 CUDA performance: collapse per-token launches + batched-trunk prefill
Closes the work tracked in #136. Gemma 4 decode on CUDA is launch-bound (~1,000+ kernel launches/token), and prefill ran a per-token loop. This PR collapses launches on the decode path and adds an all-GPU batched-trunk prefill, plus two launch fusions.
Measured impact (E4B-it Q8, RTX 4070 Ti, `-g -1 -c 2048`, ~1K warm)
… (`SHARPI_BATCHED_PREFILL=0` A/B: 53.4 → 109.4)

What's in it (8 perf commits + 1 docs)

P0: decode launch collapse
- … `View` slices (the "Tensor can't encode a device offset" comment was stale; `View` from #111 does exactly that).
- … `RmsNormBatched` + 2 full-buffer ops (~204 → 3 launches/token).

P0: all-GPU batched-trunk prefill (`CudaForwardPass`)
- … `RoPEWithFactorsBatched`, `AttentionSwaBatched`, `llm_matvec_q8_0_gemm_n` (unblocks the Q8_0 model), `llm_gelu_tanh_mul_strided` (gather-free batched PLE).
- … `IsGemma4BatchedPrefillSupported` (dense, NEOX, non-L2 QK-norm, no attn-bias, TQ off, batchable dtypes, ctx ≤ 4096); SnapKV-active / unsupported configs fall back to the bit-exact per-token loop. `SHARPI_BATCHED_PREFILL=0` forces sequential.

P1/P2: launch fusions
- … `attn_scale` kernel param (sentinel `-1f` keeps every non-Gemma caller bit-identical; Gemma passes `1.0` and now matches the CPU `scale=1.0` reference exactly).
- … `llm_head_norm_qk` + batched).

Verification
- … `*_BitwiseMatchesSingle/Sequential/Separate` oracle.
- `Gemma4_E4B_BatchedPrefill_MatchesSequential` confirms batched ≡ per-token post-prompt logits on the real Q8_0 GGUF.
- … `attn_scale` sentinel + shared kernel-array additions are inert). 0 warnings (`TreatWarningsAsErrors`).

Deferred (tracked as open boxes on #136)
- … `BeginRecord`/`RecordBarrier` scaffolding already present).
- `layer_output_scale` fold / `EmbeddingScale` bake (~1 launch/layer each; marginal vs. residual-path risk).

🤖 Generated with Claude Code