perf(gemma4-cuda): collapse per-token launches + all-GPU batched-trunk prefill (#136)#137

Merged: pekkah merged 11 commits into master from perf/gemma4-cuda-launch-reduction-136 on Jun 5, 2026

pekkah (Owner) commented Jun 5, 2026

Gemma 4 CUDA performance: collapse per-token launches + batched-trunk prefill

Closes the work tracked in #136. Gemma 4 decode on CUDA is launch-bound (roughly 1,000 or more kernel launches per token), and prefill previously ran a per-token loop. This PR collapses launches on the decode path, adds an all-GPU batched-trunk prefill, and adds two launch fusions.

Measured impact (E4B-it Q8, RTX 4070 Ti, -g -1 -c 2048, ~1K warm)

| Metric | Before | After | Δ |
|---|---|---|---|
| Prefill t/s | 47.7 | 109.4 | ~2.05× (SHARPI_BATCHED_PREFILL=0 A/B: 53.4 → 109.4) |
| Decode t/s (near-zero-ctx) | 44.1 | 49.1 | +11% |

What's in it (8 perf commits + 1 docs)

P0 — decode launch collapse

P0 — all-GPU batched-trunk prefill (CudaForwardPass)

  • Embeds N tokens, builds the PLE projections batched, runs every layer batched across N, then applies the final norm and output projection to the last token. Per-layer head_dim is handled via max-sized scratch + exact-N views.
  • Four new batched kernels, each with a bit-exact oracle: RoPEWithFactorsBatched, AttentionSwaBatched, llm_matvec_q8_0_gemm_n (unblocks the Q8_0 model), llm_gelu_tanh_mul_strided (gather-free batched PLE).
  • Gated by IsGemma4BatchedPrefillSupported (dense, NEOX, non-L2 QK-norm, no attn-bias, TQ off, batchable dtypes) and startPos+N ≤ 4096; SnapKV-active / unsupported configs fall back to the bit-exact per-token loop. SHARPI_BATCHED_PREFILL=0 forces sequential.

P1/P2 — launch fusions

  • Drop the per-layer Q-prescale via an attn_scale kernel param (sentinel -1f keeps every non-Gemma caller bit-identical; Gemma passes 1.0 and now matches the CPU scale=1.0 reference exactly).
  • Fuse Q+K HeadNorm into one launch (llm_head_norm_qk + batched).

Verification

  • Every new kernel has a *_BitwiseMatchesSingle/Sequential/Separate oracle.
  • End-to-end Gemma4_E4B_BatchedPrefill_MatchesSequential confirms batched and per-token post-prompt logits agree (argmax-exact, maxAbs<1e-2) on the real Q8_0 GGUF.
  • Full CUDA regression green: 111/111 (all non-Gemma model paths — Qwen/GDN/MoE/Coder — confirm the attn_scale sentinel + shared kernel-array additions are inert). 0 warnings (TreatWarningsAsErrors).

Deferred (tracked as open boxes on #136)

  • CUDA graphs (the structural fix for the launch-bound ceiling; large, its own effort — no-op BeginRecord/RecordBarrier scaffolding already present).
  • layer_output_scale fold / EmbeddingScale bake (~1 launch/layer each; marginal vs. residual-path risk).

🤖 Generated with Claude Code

pekkah and others added 9 commits June 4, 2026 22:29
…ices (#136)

The PLE projection slices have static per-layer offsets (layer*PleWidth), but
both BuildPerLayerProjectionsGpu and ApplyPerLayerEmbeddingGpu round-tripped
each slice through CopyDeviceRegion every token because the comment claimed
"the Tensor abstraction can't encode a device-pointer offset". That's stale —
CudaBackend.View (added for #111) does exactly this.

All-GPU path (CudaForwardPass):
- Precompute static per-layer views into _gpuProjPerLayer / _gpuPleRow once.
- Build loop now does RmsNorm+Add+Scale in place on views (3 copies/slice gone).
- ApplyPerLayerEmbeddingGpu reads the proj slice via its view (1 copy/layer gone).

Hybrid path (CudaHybridForwardPass):
- Upload the full [nGpu*pleWidth] CPU projection once per token instead of L
  tiny per-layer H2D transfers; GPU layer i reads slice i via a static view.
  Removes L-1 uploads and their barriers per token.
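
The static-view idea above can be illustrated on the CPU. This is only a sketch in Python, not the project's CudaBackend.View API: `PLE_WIDTH`, `N_LAYERS`, `layer_views` and `upload` are hypothetical names, and a `memoryview` slice stands in for a device-pointer offset. Each per-layer view is created once at offset layer * PLE_WIDTH and keeps seeing the data written by a single full-buffer upload.

```python
from array import array

PLE_WIDTH = 4   # hypothetical slice width
N_LAYERS = 3    # hypothetical layer count

# One contiguous projection buffer, laid out as [layer0 | layer1 | layer2].
proj = array("f", [0.0] * (N_LAYERS * PLE_WIDTH))

# Built once: a zero-copy view per layer at the static offset layer * PLE_WIDTH.
base = memoryview(proj)
layer_views = [base[l * PLE_WIDTH:(l + 1) * PLE_WIDTH] for l in range(N_LAYERS)]


def upload(values):
    """Stand-in for the single per-token upload of the full projection."""
    for i, v in enumerate(values):
        proj[i] = float(v)
```

After `upload(range(12))`, `layer_views[1]` reads the second slice without any per-layer copy, which is the property the commit relies on.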

Numerically inert. Gemma4 CUDA + hybrid + kernel tests green (12/12).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…136)

BuildPerLayerProjectionsGpu ran RmsNorm+Add+Scale per layer slice — 3×L tiny
launches per token. The add and scale are pure elementwise and don't couple
across slice boundaries; only the RmsNorm is per-slice. So the loop collapses
to: one RmsNormBatched (one block per row, byte-identical to llm_rmsnorm) over
all L rows, then a single full-buffer AddInPlace and ScaleInPlace. 3×L → 3
launches, reusing existing parity-proven kernels (no new CUDA).
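
A small CPU sketch of why the collapse is numerically inert, under stated assumptions: the function names are hypothetical, one norm weight is shared by all rows, and the RmsNorm formula (with `eps`) is the textbook one rather than the project's kernel. The per-row norm stays per-row; the add and scale touch each element independently, so running them once over the flattened buffer performs the same arithmetic per element.

```python
import math


def rmsnorm(row, weight, eps=1e-6):
    inv = 1.0 / math.sqrt(sum(x * x for x in row) / len(row) + eps)
    return [x * inv * w for x, w in zip(row, weight)]


def per_slice(rows, weight, addend, scale):
    # Old shape: norm, add, scale issued separately for every layer slice.
    out = []
    for l, row in enumerate(rows):
        normed = rmsnorm(row, weight)
        added = [a + b for a, b in zip(normed, addend[l])]
        out.append([x * scale for x in added])
    return out


def collapsed(rows, weight, addend, scale):
    # New shape: one batched norm over all rows, then one add and one scale
    # over the whole flattened buffer.
    flat = [x for row in rows for x in rmsnorm(row, weight)]
    flat_add = [x for row in addend for x in row]
    flat = [a + b for a, b in zip(flat, flat_add)]
    flat = [x * scale for x in flat]
    w = len(weight)
    return [flat[i:i + w] for i in range(0, len(flat), w)]
```

Both functions return identical lists for the same inputs, since each element goes through the same operations in the same order.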

Removes the now-unused _gpuPleRowSliceViews. Gemma4 CUDA argmax-vs-CPU parity
and long-decode coherence tests green (12/12).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…ill (#136)

Batched NEOX-with-factors RoPE over N tokens (Gemma 4 global layers), mirroring
llm_rope_neox_partial_batched's grid.y=token / position=base+token layout.
Per row bit-identical to the per-token llm_rope_neox_with_factors. Oracle test
RoPEWithFactorsBatched_BitwiseMatchesSingle green across nTok={3,17,64}.
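
The row-to-position mapping can be sketched on the CPU as follows. This is an illustration, not the kernel: the function names are hypothetical, `theta_base` is a placeholder value, and the convention that the per-frequency factor divides the rotation angle is an assumption. The point is only that batch row t is rotated at position base + t, so each row matches a single-token call.

```python
import math


def rope_neox_with_factors(x, pos, factors, theta_base=10000.0):
    """Rotate one head vector; NEOX layout pairs element i with i + d/2."""
    d = len(x)
    half = d // 2
    out = list(x)
    for i in range(half):
        # Assumption: the per-frequency factor divides the angle.
        angle = pos * theta_base ** (-2.0 * i / d) / factors[i]
        c, s = math.cos(angle), math.sin(angle)
        out[i] = x[i] * c - x[i + half] * s
        out[i + half] = x[i] * s + x[i + half] * c
    return out


def rope_batched(rows, base_pos, factors):
    # The row index plays the role of grid.y: token t sits at base_pos + t.
    return [rope_neox_with_factors(row, base_pos + t, factors)
            for t, row in enumerate(rows)]
```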

First of two new batched kernels needed for all-GPU Gemma 4 batched-trunk prefill.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…#136)

Batched sliding-window attention over N query tokens (Gemma 4 SWA layers):
llm_full_seq_attention's grid=(num_heads,n_tok) structure with the per-token
llm_attention_swa windowing. Window bounds eff_seq ≤ window_size ≤ 4096, so the
shared-scores path always suffices — no global scratch. Bit-identical per
(head,token) to the per-token kernel. Oracle AttentionSwaBatched_BitwiseMatchesSingle
green across startPos/nTok/window combinations incl. clamped and >0 windows.
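
A single-head CPU sketch of the batched windowing, with hypothetical function names. The exact window bounds used here (a query at position p sees keys p - window + 1 through p, clamped at 0) are an assumption; the kernel's convention may differ. Each batch row t is treated as the query at position start_pos + t, which is what lets the batched result be compared row by row against the per-token version.

```python
import math


def swa_attention_single(q, keys, values, pos, window, scale):
    """One query at absolute position pos attends to the most recent keys."""
    start = max(0, pos + 1 - window)  # assumed window bounds: [start, pos]
    scores = [scale * sum(a * b for a, b in zip(q, keys[j]))
              for j in range(start, pos + 1)]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    dim = len(values[0])
    return [sum(e * values[start + k][c] for k, e in enumerate(exps)) / total
            for c in range(dim)]


def swa_attention_batched(queries, keys, values, start_pos, window, scale):
    # Token t of the batch sits at position start_pos + t (causal per row).
    return [swa_attention_single(q, keys, values, start_pos + t, window, scale)
            for t, q in enumerate(queries)]
```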

Second of two new batched kernels for all-GPU Gemma 4 batched-trunk prefill.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…prefill (#136)

Two more batched primitives the all-GPU Gemma 4 trunk needs:

- llm_matvec_q8_0_gemm_n: batches the Q8_0 matvec over N tokens (the on-disk
  gemma-4-E4B model is Q8_0, which MatMulBatched previously rejected). Same
  (rows+7)/8 × nTok geometry as the F32/Q5/Q6 GEMM-N path; bit-identical per
  (token,row) to llm_matvec_q8_0. Oracle MatMulBatched_Q8_0_BitwiseMatchesSequential green.
- llm_gelu_tanh_mul_strided: GeluTanhMul whose up operand is read with a per-token
  stride/offset, so batched PLE can inject the per-layer slice of a [N×(L*pleWidth)]
  projection buffer without a transpose/gather. Per element bit-identical to llm_gelu_tanh_mul.
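
The strided read in the second item can be sketched like this. It is a CPU illustration with hypothetical names and parameter order, not the kernel signature; the GELU uses the standard tanh approximation. The `up` operand is indexed in place at t * up_stride + up_offset + j, so picking layer l out of an [N×(L*pleWidth)] buffer only needs up_stride = L*pleWidth and up_offset = l*pleWidth.

```python
import math


def gelu_tanh(x):
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def gelu_tanh_mul_strided(gate, up_buf, n_tok, width, up_stride, up_offset):
    """gate is dense [n_tok x width]; up is read in place from a wider buffer."""
    out = [0.0] * (n_tok * width)
    for t in range(n_tok):
        for j in range(width):
            up = up_buf[t * up_stride + up_offset + j]
            out[t * width + j] = gelu_tanh(gate[t * width + j]) * up
    return out
```

The result equals what a gather into a dense [N×pleWidth] buffer followed by the plain GeluTanhMul would give, without materializing the gathered copy.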

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Lifts Gemma 4 out of the per-token prefill loop in CudaForwardPass. Embeds all N
prompt tokens, builds the per-layer PLE projections batched (one proj GEMM-N +
per-(token,layer) RmsNormBatched + full-buffer add/scale), runs every transformer
layer batched across N, then the final norm + output projection on the last token.

Batched per layer: RmsNormBatched, Q/K/V/O/gate/up/down GEMM-N (MatMulBatched,
now incl. Q8_0), HeadNormBatched QK-norm, dual-RoPE (RoPEWithFactorsBatched global
/ RoPEPartialBatched SWA), KvAppendBatched, SWA/full batched attention, GeluTanhMul
FFN, and fully-batched PLE injection via the strided GeluTanhMul (per-layer slice of
the [N×L*pleWidth] projection read by stride — no transpose/gather).

Per-layer head_dim handled via max-sized scratch + exact-N dense views. Gated by
IsGemma4BatchedPrefillSupported (dense, NEOX, non-L2 QK-norm, no attn-bias, TQ off,
batchable weight dtypes) and startPos+N ≤ 4096 (shared-scores attention); SnapKV-active /
unsupported configs fall back to the bit-exact per-token loop.
SHARPI_BATCHED_PREFILL=0 forces sequential.
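
The path selection described above amounts to a small decision function. The sketch below is hypothetical: `use_batched_prefill` and its parameters are illustrative names, and how the real code parses SHARPI_BATCHED_PREFILL (here: exactly the string "0" opts out) is an assumption. Only the env var name, the 4096 limit and the fallback conditions come from the commit message.

```python
import os

MAX_SHARED_SCORES_CTX = 4096  # limit quoted in the commit message


def use_batched_prefill(gate_supported, snapkv_active, start_pos, n_tokens, env=None):
    env = os.environ if env is None else env
    if env.get("SHARPI_BATCHED_PREFILL") == "0":
        return False  # explicit opt-out forces the sequential loop
    if snapkv_active or not gate_supported:
        return False  # fall back to the per-token path
    return start_pos + n_tokens <= MAX_SHARED_SCORES_CTX
```

Passing `env` explicitly keeps the function easy to exercise without mutating the process environment.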

End-to-end parity: Gemma4_E4B_BatchedPrefill_MatchesSequential confirms batched and
per-token post-prompt logits agree (argmax-exact, maxAbs<1e-2) on the real Q8_0 E4B
GGUF. Full Gemma4 CUDA + batched-kernel suite green (26/26).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…ram (#136)

The Gemma 4 attention scale is 1.0, but the CUDA attention kernels baked in
rsqrtf(head_dim) unconditionally, so every layer prescaled Q by sqrt(head_dim)
to cancel it — one ScaleInPlace launch per layer per token (decode + prefill).

Added a trailing `float attn_scale` to llm_attention, llm_attention_swa,
llm_full_seq_attention, and llm_attention_swa_batched: `scale = attn_scale>0 ?
attn_scale : rsqrtf(head_dim)`. The dispatches default attnScale=-1f (sentinel),
so every non-Gemma caller is bit-identical to before. Gemma passes attnScale=1f
and drops the prescale at all 5 call sites (ForwardGemma4, ForwardProfiledGemma4,
batched trunk, hybrid GpuLayerGemma4) — which also matches the CPU reference's
scale=1.0 exactly (no prescale round-trip rounding).
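
The sentinel contract can be sketched in a few lines. `resolve_scale` mirrors the quoted expression; `score` and `score_with_prescale` are hypothetical helpers added only to contrast the old prescale round trip with passing the scale directly.

```python
import math


def resolve_scale(attn_scale, head_dim):
    # Mirrors the quoted expression: attn_scale > 0 ? attn_scale : rsqrtf(head_dim).
    return attn_scale if attn_scale > 0 else 1.0 / math.sqrt(head_dim)


def score(q, k, attn_scale=-1.0):
    # Default -1.0 is the sentinel, so callers that pass nothing keep 1/sqrt(d).
    return resolve_scale(attn_scale, len(q)) * sum(a * b for a, b in zip(q, k))


def score_with_prescale(q, k):
    # Old Gemma path: multiply Q by sqrt(head_dim) so the baked-in rsqrt cancels.
    pre = [x * math.sqrt(len(q)) for x in q]
    return score(pre, k)
```

With an explicit scale of 1.0 the score is the plain dot product, while the prescale variant only approximates it after two extra roundings, which is the rounding difference the commit message refers to.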

Gemma4 CUDA argmax-vs-CPU + long-decode coherence + batched-prefill parity green
(14/14); AttentionSwaBatched sentinel oracle confirms non-Gemma bit-identity.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Gemma 4 applies per-head RmsNorm to Q (numHeads) and K (numKvHeads) with separate
weights — two HeadNorm launches per layer. New llm_head_norm_qk / _batched run both
in one launch over numHeads+numKvHeads blocks (first numHeads do Q, rest K). Per
block bit-identical to llm_head_norm. Used in ForwardGemma4 / ForwardProfiledGemma4
/ batched trunk / hybrid GpuLayerGemma4 when !kvShared (kv-shared layers keep the
single Q-only HeadNorm). Cuts the two QK-norm launches per layer to one (decode + prefill).
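
A CPU sketch of the block dispatch, assuming the textbook RmsNorm formula and hypothetical function names. One loop over numHeads + numKvHeads "blocks" stands in for the single launch; the block index decides whether the block normalizes a Q head or a K head, each with its own weight.

```python
import math


def head_norm(head, weight, eps=1e-6):
    inv = 1.0 / math.sqrt(sum(x * x for x in head) / len(head) + eps)
    return [x * inv * w for x, w in zip(head, weight)]


def head_norm_qk(q_heads, k_heads, q_weight, k_weight):
    """One pass over num_heads + num_kv_heads blocks; the index picks Q or K."""
    num_heads = len(q_heads)
    q_out, k_out = [], []
    for block in range(num_heads + len(k_heads)):
        if block < num_heads:
            q_out.append(head_norm(q_heads[block], q_weight))
        else:
            k_out.append(head_norm(k_heads[block - num_heads], k_weight))
    return q_out, k_out
```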

Oracle HeadNormQk_BitwiseMatchesSeparate (single + batched nTok={3,17,40}) green;
Gemma4 CUDA suite + batched-prefill parity green (10/10).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Batched-trunk prefill (#136) lifts the all-GPU Gemma 4 prefill from a per-token
loop to batched GEMM-N + batched attention. Re-measured at the table's ~1K warm
methodology (bench-gemma4-136.ps1, same prompt as bench-allrows-1k.ps1):

  prefill  47.7 -> 109.4 t/s  (SHARPI_BATCHED_PREFILL=0 A/B: 53.4 -> 109.4, ~2.05x)
  decode   44.1 ->  49.1 t/s  (near-zero-ctx; PLE-projection fusion + dropped
                               Q-prescale + fused Q/K HeadNorm on a launch-bound model)

Hybrid -g 22 row unchanged: prefill is still per-token there and decode is CPU-dense-FFN
bandwidth-bound, so the launch fusions don't move it (measured 6.7/6.5 vs. 6.6/6.8, within noise).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

gemini-code-assist (bot) left a comment

Code Review

This pull request implements an all-GPU batched-trunk prefill optimization for Gemma 4 (Issue #136), which collapses per-position attention, FFN, and PLE launches into batched GEMM-N and batched-attention launches to significantly improve prefill performance. It introduces several new batched CUDA kernels (such as dual Q+K HeadNorm, batched SWA, and strided GELU-tanh multiplication) and integrates them into the forward pass logic, backed by comprehensive unit tests to ensure bit-identical outputs. The review feedback suggests improving the benchmark script by parameterizing the hardcoded model directory for better portability and wrapping the environment variable modifications in a try-finally block to ensure robust cleanup.

Comment thread scripts/bench-gemma4-136.ps1 Outdated
Comment on lines +6 to +7
$E = "E:\models"
$model = "$E\gemma-4-E4B-it-Q8_0.gguf"

medium

The model path is hardcoded, which makes the script less portable. Consider making the model directory a script parameter with a default value. This would allow running the benchmark against models in different locations without modifying the script. It's a good practice to define parameters at the beginning of the script.

param(
    [string]$ModelDir = "E:\models"
)
$model = Join-Path $ModelDir "gemma-4-E4B-it-Q8_0.gguf"

Comment on lines +31 to +39
function Run($tag, $a, $envOff) {
    if ($envOff) { $env:SHARPI_BATCHED_PREFILL = "0" } else { Remove-Item env:SHARPI_BATCHED_PREFILL -ErrorAction SilentlyContinue }
    # Run twice, keep the second (warm) — matches the discard-first guidance.
    $null = .\scripts\bench-textgen.ps1 -Model $model -Tag "$tag-w" -NTokens 60 -Prompt $prompt -TimeoutSec 900 -ExtraArgs $a
    $r = .\scripts\bench-textgen.ps1 -Model $model -Tag $tag -NTokens 60 -Prompt $prompt -TimeoutSec 900 -ExtraArgs $a
    Remove-Item env:SHARPI_BATCHED_PREFILL -ErrorAction SilentlyContinue
    $script:rows += [PSCustomObject]@{ Tag=$tag; PrefTok=$r.PrefillTok; PrefillTps=$r.PrefillTps; DecodeTps=$r.DecodeTps; Wall=$r.WallSec; TO=$r.TimedOut }
    Write-Host (" {0,-22} pref={1,7} t/s dec={2,6} t/s ({3} tok)" -f $tag,$r.PrefillTps,$r.DecodeTps,$r.PrefillTok) -ForegroundColor Green
}

medium

The SHARPI_BATCHED_PREFILL environment variable is modified but not robustly cleaned up. If an error occurs during the benchmark run, the variable might persist in the shell session. Using a try...finally block ensures that the environment variable is always cleaned up, regardless of whether the script runs successfully or encounters an error.

function Run($tag, $a, $envOff) {
    try {
        if ($envOff) { $env:SHARPI_BATCHED_PREFILL = "0" } else { Remove-Item env:SHARPI_BATCHED_PREFILL -ErrorAction SilentlyContinue }
        # Run twice, keep the second (warm) — matches the discard-first guidance.
        $null = .\scripts\bench-textgen.ps1 -Model $model -Tag "$tag-w" -NTokens 60 -Prompt $prompt -TimeoutSec 900 -ExtraArgs $a
        $r = .\scripts\bench-textgen.ps1 -Model $model -Tag $tag -NTokens 60 -Prompt $prompt -TimeoutSec 900 -ExtraArgs $a
        $script:rows += [PSCustomObject]@{ Tag=$tag; PrefTok=$r.PrefillTok; PrefillTps=$r.PrefillTps; DecodeTps=$r.DecodeTps; Wall=$r.WallSec; TO=$r.TimedOut }
        Write-Host ("  {0,-22} pref={1,7} t/s  dec={2,6} t/s  ({3} tok)" -f $tag,$r.PrefillTps,$r.DecodeTps,$r.PrefillTok) -ForegroundColor Green
    }
    finally {
        Remove-Item env:SHARPI_BATCHED_PREFILL -ErrorAction SilentlyContinue
    }
}

- Fail-safe prefill gate: new BatchableWeight() treats an unregistered weight
  handle as NOT batchable (was WDType→Q4_K default, which biased the gate toward
  "supported"); also gate on _gpuPerLayerModelProj's dtype. Forces the per-token
  fallback for any untracked weight instead of misdispatching into GEMM-N.
- Add the missing high-value oracle:
  AttentionExplicitScale_BitwiseMatchesSingle_AndDiffersFromSentinel exercises the
  attn_scale=1f path (the PR's central change) bit-exact for batched SWA + full
  attention, and asserts 1f != the rsqrt sentinel so the param is provably plumbed
  (previously only end-to-end coverage).
- Tighten the end-to-end parity test: assert >95% bit-exact fraction, not just the
  loose maxAbs<1e-2 envelope (the exact count was computed but unused).
- Add <param name="attnScale"> docs to the 4 attention dispatches (sentinel contract).
- Comment fixes: PLE collapse note (≈6 ops×L, not 3×L); batched-scratch note now
  states Q/K/V/attn buffers are max-head_dim-sized + viewed per layer.

Gemma4 CUDA suite + all batched-kernel oracles 28/28 green.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…ni, PR #137)

Addresses the two gemini-code-assist comments on scripts/bench-gemma4-136.ps1:
- $ModelDir param (default E:\models) instead of a hardcoded path.
- Wrap the SHARPI_BATCHED_PREFILL toggle in try/finally so the env var is always
  cleaned up even if a bench run throws.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
pekkah (Owner, Author) commented Jun 5, 2026

Review addressed (toolkit + gemini)

Ran the pr-review-toolkit (code-reviewer, silent-failure-hunter, pr-test-analyzer, comment-analyzer) plus the gemini-code-assist comments. Fixes pushed in 5b1a859 (toolkit) and 9ae0329 (gemini). 28/28 Gemma 4 CUDA + batched-kernel oracles green after the changes.

Code / safety

  • Fail-safe prefill gate — new BatchableWeight() treats an unregistered weight handle as not-batchable (was WDType → Q4_K default, which biased the gate toward "supported"); also gate on _gpuPerLayerModelProj's dtype. An untracked weight now forces the per-token fallback instead of being misdispatched into GEMM-N.
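
The fail-safe behaviour can be sketched as a lookup that treats "unknown" as "no". Everything here is illustrative: the registry is a plain dict, the handle names are made up, and `BATCHABLE_DTYPES` is a placeholder set rather than the project's actual list.

```python
# Illustrative set only; the real list of batchable dtypes lives in the project.
BATCHABLE_DTYPES = {"F32", "Q4_K", "Q8_0"}


def batchable_weight(dtype_registry, handle):
    dtype = dtype_registry.get(handle)  # None when the handle was never registered
    return dtype is not None and dtype in BATCHABLE_DTYPES


def gate_supported(dtype_registry, handles):
    # A single untracked or unsupported weight forces the per-token fallback.
    return all(batchable_weight(dtype_registry, h) for h in handles)
```

The design choice is that a missing entry never resolves to a default dtype, so the gate can only fail toward the slower, known-correct path.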

Tests (the top finding — the central change was under-covered)

  • Added AttentionExplicitScale_BitwiseMatchesSingle_AndDiffersFromSentinel: exercises the attn_scale=1f path (Gemma's production path) bit-exact for both batched SWA and full attention, and asserts 1f ≠ the rsqrt sentinel so the param is provably plumbed. Previously only end-to-end coverage hit it.
  • Tightened the end-to-end parity test to assert a >95% bit-exact fraction (the exact count was computed but unused), not just the loose maxAbs<1e-2 envelope.
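
A sketch of what the tightened parity assertion checks, using hypothetical helper names. The thresholds (more than 95% bit-exact, maxAbs below 1e-2, matching argmax) come from the PR text; comparing float32 bit patterns via `struct` is just one way to count exact matches on the CPU.

```python
import struct


def f32_bits(x):
    # Bit pattern of x after rounding to float32.
    return struct.unpack("<I", struct.pack("<f", x))[0]


def parity_report(a, b):
    exact = sum(f32_bits(x) == f32_bits(y) for x, y in zip(a, b))
    max_abs = max(abs(x - y) for x, y in zip(a, b))
    same_argmax = (max(range(len(a)), key=a.__getitem__)
                   == max(range(len(b)), key=b.__getitem__))
    return exact / len(a), max_abs, same_argmax


def parity_ok(a, b, min_exact=0.95, max_abs_tol=1e-2):
    frac, max_abs, same_argmax = parity_report(a, b)
    return same_argmax and frac > min_exact and max_abs < max_abs_tol
```

The bit-exact fraction catches a systematic small drift that would still pass the maxAbs envelope on its own.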

Comments / docs

  • <param name="attnScale"> added to all 4 attention dispatches (the ≤0 → 1/sqrt(headDim) sentinel contract).
  • Fixed the PLE-collapse note (≈6 ops×L, not 3×L) and the batched-scratch note (Q/K/V/attn buffers are max-head_dim-sized + viewed per layer).

Bench script (gemini)

  • $ModelDir param instead of a hardcoded path; try/finally around the SHARPI_BATCHED_PREFILL toggle.

Deferred to a separate issue (#138), not a blocker
