perf(cuda): Gemma 4 batched decode routes through the WS/decode-MMQ matvec (#275) #280

Merged: pekkah merged 1 commit into master from perf/275-gemma4-decode-routing on Jun 17, 2026

@pekkah pekkah commented Jun 17, 2026

Closes #275. Follow-up to #195 (PR #273) Gemma 4 CUDA continuous batching.

Problem

#195 shipped the Gemma 4 batched decode (RunBatchedTrunkGemma4 / GpuLayerBatchedDecodeGemma4) with every trunk matmul on the prefill cuBLAS GEMM (GpuMatMulBatched), as a correctness-first choice. But the cuBLAS GEMM is compute-bound and loses to the #194 weight-stationary (WS) matvec and the #201 decode-MMQ for small-N decode; the dense RunBatchedTrunk routes through BatchDecodeMatMul for exactly this reason (see #190).

Fix

Route the Gemma 4 trunk matmuls (Q/K/V/O, gate/up/down, lm-head) through BatchDecodeMatMul(..., allowDecodeMmq) — the same WS / decode-MMQ / GEMM-N router the dense path uses — threading allowDecodeMmq from the RunBatchedTrunk call site.

  • The PLE pre-pass (BuildPerLayerProjectionsBatched) and PLE-injection (_gpuInpGate/_gpuPleProj) matmuls stay on the cuBLAS GEMM: their pleWidth shapes aren't a decode-matvec win and the GEMM path is proven argmax-stable for them.
  • The k_eq_v V=rawK CopyDevice and the shared-KV skips are unchanged.
  • Per-layer-head_dim output views (qAll/kAll/vAll) work unchanged — MatMulBatchedWeightStationary/MatMulBatchedDecodeMmq derive rows/cols from ElementCount/n exactly like the cuBLAS path.
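
The shape rule in the last bullet can be illustrated with a small sketch. It assumes an output view is a flat buffer holding `n` rows (one per batched sequence), so the per-row width is the element count divided by `n`. The helper name and the head counts below are hypothetical and are not taken from the repository.

```python
def derive_cols(element_count: int, n: int) -> int:
    """Per-sequence output width of a flat view that holds n rows.

    Hypothetical helper: it mirrors the "rows/cols from ElementCount/n"
    rule described above, not the repository's actual code.
    """
    if n <= 0:
        raise ValueError("n must be positive")
    if element_count % n != 0:
        raise ValueError("view size is not a multiple of the batch size")
    return element_count // n


# Illustrative numbers only: two layers with different head_dim, same batch.
n = 4
for heads, head_dim in [(8, 256), (8, 512)]:
    view_elements = n * heads * head_dim
    # The width follows the view, so a per-layer head_dim needs no special case.
    assert derive_cols(view_elements, n) == heads * head_dim
```

Because the width is read off the view rather than a global constant, a layer with a different head_dim simply presents a differently sized view.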

For E4B Q8_0, decode-MMQ (Q4_K-only) correctly falls back to the bit-exact WS matvec; a future Q4_K/Q5_K/Q6_K 12B would additionally pick up the decode-MMQ tile at N≥5.
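
As a rough illustration of the fallback described above, the following sketch models only the facts stated in this PR: decode-MMQ applies to Q4_K at N≥5 when allowed, the all-GEMM path can be forced (as with SHARPI_BATCH_DECODE_GEMM=1), and everything else lands on the WS matvec. The function and parameter names are hypothetical, and the GEMM-N branch of the real router is omitted because the PR does not say when it is selected.

```python
def choose_decode_kernel(quant: str, n: int, allow_decode_mmq: bool,
                         force_gemm: bool = False, mmq_min_n: int = 5) -> str:
    """Simplified, hypothetical model of the decode matmul routing.

    Assumptions (not the repository's code):
      - force_gemm stands in for the all-GEMM override.
      - decode-MMQ is available only for Q4_K weights and only at n >= mmq_min_n.
      - the GEMM-N branch is left out; the PR does not give its condition.
    """
    if force_gemm:
        return "gemm"
    if allow_decode_mmq and quant == "Q4_K" and n >= mmq_min_n:
        return "decode_mmq"
    return "ws"  # weight-stationary matvec fallback


# E4B Q8_0: decode-MMQ never applies, so every batch size uses the WS matvec.
assert all(choose_decode_kernel("Q8_0", n, True) == "ws" for n in (1, 2, 4, 8))
# A Q4_K weight would pick up the decode-MMQ tile from N=5.
assert choose_decode_kernel("Q4_K", 4, True) == "ws"
assert choose_decode_kernel("Q4_K", 5, True) == "decode_mmq"
```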

Performance

Bench Gemma4_E4B_Decode_Throughput_Batched (Gemma 4 E4B Q8_0, RTX 4070 Ti, BatchForwardMulti aggregate t/s), comparing the new WS routing against the old all-GEMM routing (forced via SHARPI_BATCH_DECODE_GEMM=1):

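| N | new (WS/decode-MMQ), t/s | old (all-GEMM), t/s | speedup |
|---|--------------------------|---------------------|---------|
| 1 | 60.5                     | 22.4                | 2.70×   |
| 2 | 105.3                    | 43.0                | 2.45×   |
| 4 | 144.2                    | 80.7                | 1.79×   |
| 8 | 157.1                    | 143.4               | 1.10×   |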

This is the expected shape: WS/GEMM-N dominates small-N decode (weight reads are amortized) and converges toward the compute-bound GEMM as N grows. The new routing is faster at every measured batch size.
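
The speedup column can be rechecked from the aggregate figures with a few lines of Python. The per-sequence number (aggregate divided by N) is an added derived quantity, not something the benchmark reports.

```python
# Aggregate tokens/s from the benchmark above: N -> (new WS routing, old all-GEMM).
bench = {
    1: (60.5, 22.4),
    2: (105.3, 43.0),
    4: (144.2, 80.7),
    8: (157.1, 143.4),
}


def speedup(n: int) -> float:
    """Ratio of new to old aggregate throughput at batch size n."""
    new, old = bench[n]
    return new / old


def per_sequence(n: int) -> float:
    """Derived figure: new-routing throughput seen by each of the n sequences."""
    return bench[n][0] / n


for n in sorted(bench):
    print(f"N={n}: speedup {speedup(n):.2f}x, per-sequence {per_sequence(n):.1f} t/s")

# The gain shrinks monotonically with N but stays above 1 in every measured case.
ratios = [speedup(n) for n in sorted(bench)]
assert ratios == sorted(ratios, reverse=True)
assert min(ratios) > 1.0
```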

Correctness

Argmax-stable with the single-user ForwardGemma4 loop (same contract as the prior all-GEMM routing). The 4 Gemma4CudaBatchForwardMultiTests oracles now run through the WS routing and pass (GPU, 4070 Ti). Builds clean (Release, TreatWarningsAsErrors). One code-review pass: no actionable findings.
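
A minimal sketch of what an argmax-stability check could look like, assuming "argmax-stable" means the greedy token of each batched sequence matches the single-user reference even if the logits differ slightly. The function names and logit values are illustrative and are not the actual test oracles.

```python
def argmax(row):
    """Index of the largest logit (first index wins on ties)."""
    return max(range(len(row)), key=row.__getitem__)


def argmax_stable(batched_logits, reference_logits):
    """True if every batched row selects the same token as its reference row.

    Hypothetical helper: compares greedy tokens only, so small numeric drift
    between kernels is tolerated as long as the winning token does not change.
    """
    if len(batched_logits) != len(reference_logits):
        return False
    return all(argmax(b) == argmax(r)
               for b, r in zip(batched_logits, reference_logits))


# Illustrative logits: slight drift that leaves the greedy tokens unchanged.
reference = [[0.1, 2.0, -1.0], [3.0, 0.5, 0.2]]
batched = [[0.1001, 1.9998, -1.0], [2.9999, 0.5, 0.2001]]
assert argmax_stable(batched, reference)

# A flipped winner in the second row breaks stability.
assert not argmax_stable([[0.1, 2.0, -1.0], [0.4, 0.5, 0.2]], reference)
```

This contract is weaker than bit-exact equality, which is why it can hold across kernels that accumulate in a different order.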

Note: the 12B-global k_eq_v batched path still lacks runnable coverage (no GEMM-N-batchable 12B fits 12 GB) — tracked separately by #276.

🤖 Generated with Claude Code

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request updates the Gemma 4 batched decode path in CudaForwardPass.cs to route trunk matmuls (Q/K/V/O, gate/up/down, and lm-head) through BatchDecodeMatMul instead of GpuMatMulBatched, matching the dense decode path. It also introduces a new benchmark test class Gemma4CudaBatchedDecodeBench to measure and compare decode throughput under different routing configurations. There are no review comments, so no additional feedback is provided.

…atvec (#275)

#195 shipped Gemma 4 continuous batching with every trunk matmul on the cuBLAS
GEMM (GpuMatMulBatched), correctness-first. cuBLAS GEMM is compute-bound and
loses to the #194 weight-stationary matvec / #201 decode-MMQ tile for small-N
decode (the dense RunBatchedTrunk already routes through BatchDecodeMatMul for
exactly this reason).

Route the Gemma 4 trunk matmuls (Q/K/V/O, gate/up/down, lm-head) through
BatchDecodeMatMul (WS / decode-MMQ / GEMM-N), threading allowDecodeMmq from the
RunBatchedTrunk call site. The PLE pre-pass + injection matmuls stay on the
cuBLAS GEMM — their pleWidth shapes aren't a decode-matvec win and the GEMM path
is proven argmax-stable for them. The k_eq_v V=rawK copy and shared-KV skips are
unchanged. Output views (per-layer head_dim) work unchanged: WS/MMQ derive
rows/cols from ElementCount/n exactly like the cuBLAS path.

Argmax-stable with the single-user ForwardGemma4 loop (same contract as the prior
all-GEMM routing); the 4 Gemma4CudaBatchForwardMultiTests oracles now exercise
the WS routing and pass.

Bench (Gemma 4 E4B Q8_0, 4070 Ti, BatchForwardMulti aggregate t/s, new WS vs old
all-GEMM via SHARPI_BATCH_DECODE_GEMM=1):
  N=1 60.5 vs 22.4 (2.70x), N=2 105.3 vs 43.0 (2.45x),
  N=4 144.2 vs 80.7 (1.79x), N=8 157.1 vs 143.4 (1.10x).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@pekkah pekkah force-pushed the perf/275-gemma4-decode-routing branch from 724fbd2 to 6d93e41 Compare June 17, 2026 16:04
@pekkah pekkah merged commit c9a3ec1 into master Jun 17, 2026
1 check passed
@pekkah pekkah deleted the perf/275-gemma4-decode-routing branch June 17, 2026 16:04
Development

Successfully merging this pull request may close these issues.

perf(cuda): Gemma 4 batched decode uses cuBLAS GEMM, not the dense WS/MMQ decode routing (#195 follow-up)
