From 6d93e41a6da7c484daca12dea8cf1cdf05cf2553 Mon Sep 17 00:00:00 2001
From: Pekka Heikura
Date: Wed, 17 Jun 2026 18:41:00 +0300
Subject: [PATCH] perf(cuda): Gemma 4 batched decode routes through the WS/decode-MMQ matvec (#275)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

#195 shipped Gemma 4 continuous batching with every trunk matmul on the cuBLAS GEMM (GpuMatMulBatched), correctness-first. cuBLAS GEMM is compute-bound and loses to the #194 weight-stationary matvec / #201 decode-MMQ tile for small-N decode (the dense RunBatchedTrunk already routes through BatchDecodeMatMul for exactly this reason).

Route the Gemma 4 trunk matmuls (Q/K/V/O, gate/up/down, lm-head) through BatchDecodeMatMul (WS / decode-MMQ / GEMM-N), threading allowDecodeMmq from the RunBatchedTrunk call site. The PLE pre-pass + injection matmuls stay on the cuBLAS GEMM — their pleWidth shapes aren't a decode-matvec win and the GEMM path is proven argmax-stable for them. The k_eq_v V=rawK copy and shared-KV skips are unchanged.

Output views (per-layer head_dim) work unchanged: WS/MMQ derive rows/cols from ElementCount/n exactly like the cuBLAS path.

Argmax-stable with the single-user ForwardGemma4 loop (same contract as the prior all-GEMM routing); the 4 Gemma4CudaBatchForwardMultiTests oracles now exercise the WS routing and pass.

Bench (Gemma 4 E4B Q8_0, 4070 Ti, BatchForwardMulti aggregate t/s, new WS vs old all-GEMM via SHARPI_BATCH_DECODE_GEMM=1): N=1 60.5 vs 22.4 (2.70x), N=2 105.3 vs 43.0 (2.45x), N=4 144.2 vs 80.7 (1.79x), N=8 157.1 vs 143.4 (1.10x).

Co-Authored-By: Claude Opus 4.8 --- src/SharpInference.Engine/CudaForwardPass.cs | 40 +++-- .../Gemma4CudaBatchedDecodeBench.cs | 139 ++++++++++++++++++ 2 files changed, 163 insertions(+), 16 deletions(-) create mode 100644 tests/SharpInference.Tests.ForwardPass/Gemma4CudaBatchedDecodeBench.cs diff --git a/src/SharpInference.Engine/CudaForwardPass.cs b/src/SharpInference.Engine/CudaForwardPass.cs index a5b5834..cec1167 100644 --- a/src/SharpInference.Engine/CudaForwardPass.cs +++ b/src/SharpInference.Engine/CudaForwardPass.cs @@ -4090,7 +4090,7 @@ private void RunBatchedTrunk(int[] tokens, int[] positions, CudaSequenceKvCache[ throw new InvalidOperationException( "Gemma-4 batched decode reached a SnapKV-evicted cache; the constructor forces " + "the SnapKV budget off for Gemma 4, so this means that invariant was broken."); - RunBatchedTrunkGemma4(tokens, positions, caches); + RunBatchedTrunkGemma4(tokens, positions, caches, allowDecodeMmq); return; } @@ -4340,12 +4340,16 @@ private void RunBatchedTrunk(int[] tokens, int[] positions, CudaSequenceKvCache[ /// PLE pre-pass → per layer → batched final norm + /// output GEMM + final-logit softcap. Leaves [N×vocab] in on /// the stream (the caller's tail downloads/argmaxes + advances each cache length), exactly - /// like . The matmuls use the prefill GEMM routing - /// ( — proven for Gemma 4 shapes; weight reads amortize across - /// the batch via cuBLAS GEMM), not the dense decode WS/MMQ path (a Gemma-4 follow-up). - /// Argmax-stable with the single-user loop. + /// like . The trunk matmuls (Q/K/V/O, gate/up/down, lm-head) route + /// through — the #194 weight-stationary matvec (small-N decode) + /// or the #201/#206 int8 decode-MMQ tile for big Q4_K shapes — exactly like the dense + /// (issue #275). The PLE pre-pass + injection matmuls stay on the + /// cuBLAS GEMM (): their pleWidth shapes aren't a decode-matvec + /// win and the GEMM path is proven argmax-stable for them. 
Argmax-stable with the single-user + /// loop (the same contract as the prior all-GEMM routing). /// - private void RunBatchedTrunkGemma4(int[] tokens, int[] positions, CudaSequenceKvCache[] caches) + private void RunBatchedTrunkGemma4(int[] tokens, int[] positions, CudaSequenceKvCache[] caches, + bool allowDecodeMmq) { int N = tokens.Length; int embDim = _embDim; @@ -4368,13 +4372,13 @@ private void RunBatchedTrunkGemma4(int[] tokens, int[] positions, CudaSequenceKv // 3. Transformer layers. for (int layer = 0; layer < _hp.NumLayers; layer++) - GpuLayerBatchedDecodeGemma4(layer, N, positions, caches); + GpuLayerBatchedDecodeGemma4(layer, N, positions, caches, allowDecodeMmq); // 4. Final norm + output projection + softcap, batched across N (softcap is monotonic, so // argmax is invariant — but the returned logit values must be softcapped to match the // single-user ForwardGemma4 finisher and the full-logits BatchForwardMulti path). _gpu.RmsNormBatched(_bpHidden!, _bpHidden!, _wOutputNorm, N, embDim, _hp.RmsNormEps); - GpuMatMulBatched(_decodeLogitsAll!, _wOutput, _bpHidden!, N); + BatchDecodeMatMul(_decodeLogitsAll!, _wOutput, _bpHidden!, N, allowDecodeMmq); if (_hp.FinalLogitSoftcap > 0f) _gpu.SoftcapInPlace(_decodeLogitsAll!, _hp.FinalLogitSoftcap); } @@ -4388,8 +4392,12 @@ private void RunBatchedTrunkGemma4(int[] tokens, int[] positions, CudaSequenceKv /// single-query attention against caches[n] (shared-KV reads the source layer via effLayer). /// This is exactly the per-layer attention of the single-token , /// run once per sequence on a slice of the batched Q/K/V/attnOut scratch. + /// The QKV / O-proj / FFN matmuls route through (WS / + /// decode-MMQ / GEMM-N, issue #275); the PLE-injection matmuls stay on . + /// mirrors 's parameter. 
/// - private void GpuLayerBatchedDecodeGemma4(int layer, int N, int[] positions, CudaSequenceKvCache[] caches) + private void GpuLayerBatchedDecodeGemma4(int layer, int N, int[] positions, CudaSequenceKvCache[] caches, + bool allowDecodeMmq) { int layerHd = _hp.LayerHeadDim is { } lhd ? lhd[layer] : _headDim; int layerKv = _hp.LayerKvHeads is { } lkv ? lkv[layer] : _numKvHeads; @@ -4413,14 +4421,14 @@ private void GpuLayerBatchedDecodeGemma4(int layer, int N, int[] positions, Cuda _gpu.CopyDevice(_bpResidual!, _bpHidden!); _gpu.RmsNormBatched(_bpNorm!, _bpHidden!, _wAttnNorm[layer], N, _embDim, _hp.RmsNormEps); - GpuMatMulBatched(qAll, _wq[layer], _bpNorm!, N); + BatchDecodeMatMul(qAll, _wq[layer], _bpNorm!, N, allowDecodeMmq); if (!kvShared) { - GpuMatMulBatched(kAll!, _wk[layer], _bpNorm!, N); + BatchDecodeMatMul(kAll!, _wk[layer], _bpNorm!, N, allowDecodeMmq); if (kEqV) _gpu.CopyDevice(vAll!, kAll!); // V = raw K projection (pre-norm, pre-RoPE) else - GpuMatMulBatched(vAll!, _wv[layer]!, _bpNorm!, N); + BatchDecodeMatMul(vAll!, _wv[layer]!, _bpNorm!, N, allowDecodeMmq); } // Batched QK-norm (before RoPE, #157) + V-norm — both position-independent across N. @@ -4488,7 +4496,7 @@ private void GpuLayerBatchedDecodeGemma4(int layer, int N, int[] positions, Cuda } // O-proj + sandwich post-attn norm + residual. - GpuMatMulBatched(_bpHidden!, _wo[layer], attnAll, N); + BatchDecodeMatMul(_bpHidden!, _wo[layer], attnAll, N, allowDecodeMmq); if (_wPostAttnNorm is not null) _gpu.RmsNormBatched(_bpHidden!, _bpHidden!, _wPostAttnNorm[layer], N, _embDim, _hp.RmsNormEps); _gpu.AddInPlace(_bpHidden!, _bpResidual!); @@ -4496,13 +4504,13 @@ private void GpuLayerBatchedDecodeGemma4(int layer, int N, int[] positions, Cuda // FFN (GEGLU for Gemma 4) + sandwich post-ffn norm + residual. 
_gpu.CopyDevice(_bpResidual!, _bpHidden!); _gpu.RmsNormBatched(_bpNorm!, _bpHidden!, _wFfnNorm[layer], N, _embDim, _hp.RmsNormEps); - GpuMatMulBatched(_bpFfnGate!, _wGate[layer], _bpNorm!, N); - GpuMatMulBatched(_bpFfnUp!, _wUp[layer], _bpNorm!, N); + BatchDecodeMatMul(_bpFfnGate!, _wGate[layer], _bpNorm!, N, allowDecodeMmq); + BatchDecodeMatMul(_bpFfnUp!, _wUp[layer], _bpNorm!, N, allowDecodeMmq); if (_hp.FfnActivation == FfnActivation.GeluApprox) _gpu.GeluTanhMul(_bpFfnGate!, _bpFfnUp!); else _gpu.SiLuMul(_bpFfnGate!, _bpFfnUp!); - GpuMatMulBatched(_bpHidden!, _wDown[layer], _bpFfnGate!, N); + BatchDecodeMatMul(_bpHidden!, _wDown[layer], _bpFfnGate!, N, allowDecodeMmq); if (_wPostFfwNorm is not null) _gpu.RmsNormBatched(_bpHidden!, _bpHidden!, _wPostFfwNorm[layer], N, _embDim, _hp.RmsNormEps); _gpu.AddInPlace(_bpHidden!, _bpResidual!); diff --git a/tests/SharpInference.Tests.ForwardPass/Gemma4CudaBatchedDecodeBench.cs b/tests/SharpInference.Tests.ForwardPass/Gemma4CudaBatchedDecodeBench.cs new file mode 100644 index 0000000..f3d2a30 --- /dev/null +++ b/tests/SharpInference.Tests.ForwardPass/Gemma4CudaBatchedDecodeBench.cs @@ -0,0 +1,139 @@ +using System.Diagnostics; +using SharpInference.Core; +using SharpInference.Cuda; +using SharpInference.Engine; +using Xunit.Abstractions; + +namespace SharpInference.Tests.ForwardPass; + +/// +/// Issue #275: aggregate decode-throughput measurement for the Gemma 4 batched decode. #195 shipped +/// Gemma 4 continuous batching with every trunk matmul on the cuBLAS GEMM (GpuMatMulBatched); +/// #275 routes the trunk matmuls (Q/K/V/O, gate/up/down, lm-head) through +/// 's decode router (BatchDecodeMatMul — the +/// #194 weight-stationary matvec / #201 decode-MMQ), like the dense path, while the PLE matmuls stay +/// on GEMM. cuBLAS GEMM is compute-bound and known to lose to WS/GEMM-N for small-N decode (#190), +/// so this is the A/B that proves the win. 
+/// +/// Run on Gemma 4 E4B Q8_0 (the only GEMM-N-batchable Gemma 4 that fits 12 GB). The routing +/// is selected at construction by the ambient env, so A/B by running twice: +/// # new WS routing (default): +/// $env:SHARPI_BENCH_BATCH=1; dotnet test ... --filter "FullyQualifiedName~Gemma4CudaBatchedDecodeBench" +/// # old all-GEMM routing (#195 baseline): +/// $env:SHARPI_BENCH_BATCH=1; $env:SHARPI_BATCH_DECODE_GEMM=1; dotnet test ... (same filter) +/// Surfaces t/s, asserts no threshold (mirrors ). Silent-skips +/// without CUDA / the GGUF. +/// +public sealed class Gemma4CudaBatchedDecodeBench +{ + private const string ModelFile = "gemma-4-E4B-it-Q8_0.gguf"; + private readonly ITestOutputHelper _out; + public Gemma4CudaBatchedDecodeBench(ITestOutputHelper outHelper) { _out = outHelper; } + + private void Log(string line) { _out.WriteLine(line); Console.Error.WriteLine(line); } + + private static bool BenchEnabled => Environment.GetEnvironmentVariable("SHARPI_BENCH_BATCH") == "1"; + + private static int[] BatchSizesToRun => + (Environment.GetEnvironmentVariable("SHARPI_BENCH_BATCH_N") ?? "1,2,4,8") + .Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) + .Select(int.Parse).ToArray(); + private static int DecodeSteps => + int.TryParse(Environment.GetEnvironmentVariable("SHARPI_BENCH_STEPS"), out int s) && s > 0 ? s : 128; + + private static readonly int[] Prompt = + { 2, 651, 6037, 576, 6081, 603, 1234, 4567, 8901, 222, 333, 444, 555, 666, 777, 888 }; + + private static CudaBackend? TryCreate() + { + if (!CudaBackend.IsAvailable()) return null; + try { return CudaBackend.Create(); } + catch { return null; } + } + + private static string? 
FindModelPath() + { + string[] absolute = { $@"E:\models\{ModelFile}", $@"C:\p\sharpi\models\{ModelFile}" }; + foreach (var p in absolute) + if (File.Exists(p)) return p; + var dir = Directory.GetCurrentDirectory(); + for (int i = 0; i < 8; i++) + { + var p = Path.Combine(dir, "models", ModelFile); + if (File.Exists(p)) return p; + var parent = Directory.GetParent(dir); + if (parent is null) break; + dir = parent.FullName; + } + return null; + } + + private static int Argmax(ReadOnlySpan logits) + { + int best = 0; float bestVal = logits[0]; + for (int i = 1; i < logits.Length; i++) + if (logits[i] > bestVal) { bestVal = logits[i]; best = i; } + return best; + } + + [Fact] + public void Gemma4_E4B_Decode_Throughput_Batched() + { + if (!BenchEnabled) return; + using var gpu = TryCreate(); + if (gpu is null) return; + var path = FindModelPath(); + if (path is null) return; + + using var model = GgufModel.Open(path); + var hp = ModelHyperparams.FromGgufMetadata(model.Metadata, model); + + // Pin SnapKV off (the constructor already forces it off for Gemma 4, but be explicit). + var prevSnap = Environment.GetEnvironmentVariable("SHARPI_SNAPKV_BUDGET"); + Environment.SetEnvironmentVariable("SHARPI_SNAPKV_BUDGET", "0"); + CudaForwardPass fwdTmp; + try { fwdTmp = new CudaForwardPass(model, gpu, hp, maxContextLength: 2048); } + finally { Environment.SetEnvironmentVariable("SHARPI_SNAPKV_BUDGET", prevSnap); } + using var fwd = fwdTmp; + + bool gemm = Environment.GetEnvironmentVariable("SHARPI_BATCH_DECODE_GEMM") == "1"; + Log($"[bench-275] routing = {(gemm ? 
"all-GEMM (#195 baseline)" : "WS/decode-MMQ (#275)")}"); + + int decodeSteps = DecodeSteps; + const int warmup = 8; + foreach (int n in BatchSizesToRun) + { + var caches = new CudaSequenceKvCache[n]; + try + { + var toks = new int[n]; + var poss = new int[n]; + for (int s = 0; s < n; s++) + { + caches[s] = fwd.CreateCache(); + toks[s] = Argmax(fwd.PrefillWithCache(Prompt, caches[s])); + poss[s] = Prompt.Length; + } + for (int i = 0; i < warmup; i++) + { + var lg = fwd.BatchForwardMulti(toks, poss, caches); + for (int s = 0; s < n; s++) { toks[s] = Argmax(lg[s]); poss[s]++; } + } + var sw = Stopwatch.StartNew(); + for (int i = 0; i < decodeSteps; i++) + { + var lg = fwd.BatchForwardMulti(toks, poss, caches); + for (int s = 0; s < n; s++) { toks[s] = Argmax(lg[s]); poss[s]++; } + } + sw.Stop(); + double aggTps = (double)n * decodeSteps / sw.Elapsed.TotalSeconds; + double perSeq = (double)decodeSteps / sw.Elapsed.TotalSeconds; + Log($"[bench-275] N={n}: aggregate {aggTps:F1} t/s ({perSeq:F1} t/s/seq)"); + } + finally + { + foreach (var c in caches) c?.Dispose(); + } + } + } +}