perf(cuda): Gemma 4 batched decode routes through the WS/decode-MMQ matvec (#275)#280
Conversation
There was a problem hiding this comment.
Code Review
This pull request updates the Gemma 4 batched decode path in CudaForwardPass.cs to route trunk matmuls (Q/K/V/O, gate/up/down, and lm-head) through BatchDecodeMatMul instead of GpuMatMulBatched, matching the dense decode path. It also introduces a new benchmark test class Gemma4CudaBatchedDecodeBench to measure and compare decode throughput under different routing configurations. There are no review comments, so no additional feedback is provided.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
…atvec (#275) #195 shipped Gemma 4 continuous batching with every trunk matmul on the cuBLAS GEMM (GpuMatMulBatched), correctness-first. cuBLAS GEMM is compute-bound and loses to the #194 weight-stationary matvec / #201 decode-MMQ tile for small-N decode (the dense RunBatchedTrunk already routes through BatchDecodeMatMul for exactly this reason). Route the Gemma 4 trunk matmuls (Q/K/V/O, gate/up/down, lm-head) through BatchDecodeMatMul (WS / decode-MMQ / GEMM-N), threading allowDecodeMmq from the RunBatchedTrunk call site. The PLE pre-pass + injection matmuls stay on the cuBLAS GEMM — their pleWidth shapes aren't a decode-matvec win and the GEMM path is proven argmax-stable for them. The k_eq_v V=rawK copy and shared-KV skips are unchanged. Output views (per-layer head_dim) work unchanged: WS/MMQ derive rows/cols from ElementCount/n exactly like the cuBLAS path. Argmax-stable with the single-user ForwardGemma4 loop (same contract as the prior all-GEMM routing); the 4 Gemma4CudaBatchForwardMultiTests oracles now exercise the WS routing and pass. Bench (Gemma 4 E4B Q8_0, 4070 Ti, BatchForwardMulti aggregate t/s, new WS vs old all-GEMM via SHARPI_BATCH_DECODE_GEMM=1): N=1 60.5 vs 22.4 (2.70x), N=2 105.3 vs 43.0 (2.45x), N=4 144.2 vs 80.7 (1.79x), N=8 157.1 vs 143.4 (1.10x). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
724fbd2 to
6d93e41
Compare
Closes #275. Follow-up to #195 (PR #273) Gemma 4 CUDA continuous batching.
Problem
#195 shipped the Gemma 4 batched decode (
RunBatchedTrunkGemma4/GpuLayerBatchedDecodeGemma4) with every trunk matmul on the prefill cuBLAS GEMM (GpuMatMulBatched) — correctness-first. But cuBLAS GEMM is compute-bound and loses to the #194 weight-stationary matvec / #201 decode-MMQ for small-N decode (the denseRunBatchedTrunkroutes throughBatchDecodeMatMulfor exactly this reason — see #190).Fix
Route the Gemma 4 trunk matmuls (Q/K/V/O, gate/up/down, lm-head) through
BatchDecodeMatMul(..., allowDecodeMmq)— the same WS / decode-MMQ / GEMM-N router the dense path uses — threadingallowDecodeMmqfrom theRunBatchedTrunkcall site.BuildPerLayerProjectionsBatched) and PLE-injection (_gpuInpGate/_gpuPleProj) matmuls stay on the cuBLAS GEMM: theirpleWidthshapes aren't a decode-matvec win and the GEMM path is proven argmax-stable for them.k_eq_vV=rawKCopyDeviceand the shared-KV skips are unchanged.qAll/kAll/vAll) work unchanged —MatMulBatchedWeightStationary/MatMulBatchedDecodeMmqderive rows/cols fromElementCount/nexactly like the cuBLAS path.For E4B Q8_0, decode-MMQ (Q4_K-only) correctly falls back to the bit-exact WS matvec; a future Q4_K/Q5_K/Q6_K 12B would additionally pick up the decode-MMQ tile at N≥5.
Performance
Bench
Gemma4_E4B_Decode_Throughput_Batched(Gemma 4 E4B Q8_0, RTX 4070 Ti,BatchForwardMultiaggregate t/s), new WS routing vs old all-GEMM (SHARPI_BATCH_DECODE_GEMM=1):The expected shape: WS/GEMM-N dominates small-N decode (weight-read amortized), converging toward the compute-bound GEMM as N grows — positive at every batch size.
Correctness
Argmax-stable with the single-user
ForwardGemma4loop (same contract as the prior all-GEMM routing). The 4Gemma4CudaBatchForwardMultiTestsoracles now run through the WS routing and pass (GPU, 4070 Ti). Builds clean (Release,TreatWarningsAsErrors). One code-review pass: no actionable findings.Note: the 12B-global
k_eq_vbatched path still lacks runnable coverage (no GEMM-N-batchable 12B fits 12 GB) — tracked separately by #276.🤖 Generated with Claude Code