Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions src/SharpInference.Engine/CudaForwardPass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4090,7 +4090,7 @@ private void RunBatchedTrunk(int[] tokens, int[] positions, CudaSequenceKvCache[
throw new InvalidOperationException(
"Gemma-4 batched decode reached a SnapKV-evicted cache; the constructor forces " +
"the SnapKV budget off for Gemma 4, so this means that invariant was broken.");
RunBatchedTrunkGemma4(tokens, positions, caches);
RunBatchedTrunkGemma4(tokens, positions, caches, allowDecodeMmq);
return;
}

Expand Down Expand Up @@ -4340,12 +4340,16 @@ private void RunBatchedTrunk(int[] tokens, int[] positions, CudaSequenceKvCache[
/// PLE pre-pass → per layer <see cref="GpuLayerBatchedDecodeGemma4"/> → batched final norm +
/// output GEMM + final-logit softcap. Leaves [N×vocab] in <see cref="_decodeLogitsAll"/> on
/// the stream (the caller's tail downloads/argmaxes + advances each cache length), exactly
/// like <see cref="RunBatchedTrunk"/>. The matmuls use the prefill GEMM routing
/// (<see cref="GpuMatMulBatched"/> — proven for Gemma 4 shapes; weight reads amortize across
/// the batch via cuBLAS GEMM), not the dense decode WS/MMQ path (a Gemma-4 follow-up).
/// Argmax-stable with the single-user <see cref="ForwardGemma4"/> loop.
/// like <see cref="RunBatchedTrunk"/>. The trunk matmuls (Q/K/V/O, gate/up/down, lm-head) route
/// through <see cref="BatchDecodeMatMul"/> — the #194 weight-stationary matvec (small-N decode)
/// or the #201/#206 int8 decode-MMQ tile for big Q4_K shapes — exactly like the dense
/// <see cref="RunBatchedTrunk"/> (issue #275). The PLE pre-pass + injection matmuls stay on the
/// cuBLAS GEMM (<see cref="GpuMatMulBatched"/>): their pleWidth shapes aren't a decode-matvec
/// win and the GEMM path is proven argmax-stable for them. Argmax-stable with the single-user
/// <see cref="ForwardGemma4"/> loop (the same contract as the prior all-GEMM routing).
/// </summary>
private void RunBatchedTrunkGemma4(int[] tokens, int[] positions, CudaSequenceKvCache[] caches)
private void RunBatchedTrunkGemma4(int[] tokens, int[] positions, CudaSequenceKvCache[] caches,
bool allowDecodeMmq)
{
int N = tokens.Length;
int embDim = _embDim;
Expand All @@ -4368,13 +4372,13 @@ private void RunBatchedTrunkGemma4(int[] tokens, int[] positions, CudaSequenceKv

// 3. Transformer layers.
for (int layer = 0; layer < _hp.NumLayers; layer++)
GpuLayerBatchedDecodeGemma4(layer, N, positions, caches);
GpuLayerBatchedDecodeGemma4(layer, N, positions, caches, allowDecodeMmq);

// 4. Final norm + output projection + softcap, batched across N (softcap is monotonic, so
// argmax is invariant — but the returned logit values must be softcapped to match the
// single-user ForwardGemma4 finisher and the full-logits BatchForwardMulti path).
_gpu.RmsNormBatched(_bpHidden!, _bpHidden!, _wOutputNorm, N, embDim, _hp.RmsNormEps);
GpuMatMulBatched(_decodeLogitsAll!, _wOutput, _bpHidden!, N);
BatchDecodeMatMul(_decodeLogitsAll!, _wOutput, _bpHidden!, N, allowDecodeMmq);
if (_hp.FinalLogitSoftcap > 0f)
_gpu.SoftcapInPlace(_decodeLogitsAll!, _hp.FinalLogitSoftcap);
}
Expand All @@ -4388,8 +4392,12 @@ private void RunBatchedTrunkGemma4(int[] tokens, int[] positions, CudaSequenceKv
/// single-query attention against caches[n] (shared-KV reads the source layer via effLayer).
/// This is exactly the per-layer attention of the single-token <see cref="RunGemma4DeviceRegion"/>,
/// run once per sequence on a slice of the batched Q/K/V/attnOut scratch.
/// <para>The QKV / O-proj / FFN matmuls route through <see cref="BatchDecodeMatMul"/> (WS /
/// decode-MMQ / GEMM-N, issue #275); the PLE-injection matmuls stay on <see cref="GpuMatMulBatched"/>.
/// <paramref name="allowDecodeMmq"/> mirrors <see cref="RunBatchedTrunk"/>'s parameter.</para>
/// </summary>
private void GpuLayerBatchedDecodeGemma4(int layer, int N, int[] positions, CudaSequenceKvCache[] caches)
private void GpuLayerBatchedDecodeGemma4(int layer, int N, int[] positions, CudaSequenceKvCache[] caches,
bool allowDecodeMmq)
{
int layerHd = _hp.LayerHeadDim is { } lhd ? lhd[layer] : _headDim;
int layerKv = _hp.LayerKvHeads is { } lkv ? lkv[layer] : _numKvHeads;
Expand All @@ -4413,14 +4421,14 @@ private void GpuLayerBatchedDecodeGemma4(int layer, int N, int[] positions, Cuda
_gpu.CopyDevice(_bpResidual!, _bpHidden!);
_gpu.RmsNormBatched(_bpNorm!, _bpHidden!, _wAttnNorm[layer], N, _embDim, _hp.RmsNormEps);

GpuMatMulBatched(qAll, _wq[layer], _bpNorm!, N);
BatchDecodeMatMul(qAll, _wq[layer], _bpNorm!, N, allowDecodeMmq);
if (!kvShared)
{
GpuMatMulBatched(kAll!, _wk[layer], _bpNorm!, N);
BatchDecodeMatMul(kAll!, _wk[layer], _bpNorm!, N, allowDecodeMmq);
if (kEqV)
_gpu.CopyDevice(vAll!, kAll!); // V = raw K projection (pre-norm, pre-RoPE)
else
GpuMatMulBatched(vAll!, _wv[layer]!, _bpNorm!, N);
BatchDecodeMatMul(vAll!, _wv[layer]!, _bpNorm!, N, allowDecodeMmq);
}

// Batched QK-norm (before RoPE, #157) + V-norm — both position-independent across N.
Expand Down Expand Up @@ -4488,21 +4496,21 @@ private void GpuLayerBatchedDecodeGemma4(int layer, int N, int[] positions, Cuda
}

// O-proj + sandwich post-attn norm + residual.
GpuMatMulBatched(_bpHidden!, _wo[layer], attnAll, N);
BatchDecodeMatMul(_bpHidden!, _wo[layer], attnAll, N, allowDecodeMmq);
if (_wPostAttnNorm is not null)
_gpu.RmsNormBatched(_bpHidden!, _bpHidden!, _wPostAttnNorm[layer], N, _embDim, _hp.RmsNormEps);
_gpu.AddInPlace(_bpHidden!, _bpResidual!);

// FFN (GEGLU for Gemma 4) + sandwich post-ffn norm + residual.
_gpu.CopyDevice(_bpResidual!, _bpHidden!);
_gpu.RmsNormBatched(_bpNorm!, _bpHidden!, _wFfnNorm[layer], N, _embDim, _hp.RmsNormEps);
GpuMatMulBatched(_bpFfnGate!, _wGate[layer], _bpNorm!, N);
GpuMatMulBatched(_bpFfnUp!, _wUp[layer], _bpNorm!, N);
BatchDecodeMatMul(_bpFfnGate!, _wGate[layer], _bpNorm!, N, allowDecodeMmq);
BatchDecodeMatMul(_bpFfnUp!, _wUp[layer], _bpNorm!, N, allowDecodeMmq);
if (_hp.FfnActivation == FfnActivation.GeluApprox)
_gpu.GeluTanhMul(_bpFfnGate!, _bpFfnUp!);
else
_gpu.SiLuMul(_bpFfnGate!, _bpFfnUp!);
GpuMatMulBatched(_bpHidden!, _wDown[layer], _bpFfnGate!, N);
BatchDecodeMatMul(_bpHidden!, _wDown[layer], _bpFfnGate!, N, allowDecodeMmq);
if (_wPostFfwNorm is not null)
_gpu.RmsNormBatched(_bpHidden!, _bpHidden!, _wPostFfwNorm[layer], N, _embDim, _hp.RmsNormEps);
_gpu.AddInPlace(_bpHidden!, _bpResidual!);
Expand Down
139 changes: 139 additions & 0 deletions tests/SharpInference.Tests.ForwardPass/Gemma4CudaBatchedDecodeBench.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
using System.Diagnostics;
using SharpInference.Core;
using SharpInference.Cuda;
using SharpInference.Engine;
using Xunit.Abstractions;

namespace SharpInference.Tests.ForwardPass;

/// <summary>
/// Issue #275: aggregate decode-throughput measurement for the Gemma 4 batched decode. #195 shipped
/// Gemma 4 continuous batching with every trunk matmul on the cuBLAS GEMM (<c>GpuMatMulBatched</c>);
/// #275 routes the trunk matmuls (Q/K/V/O, gate/up/down, lm-head) through
/// <see cref="CudaForwardPass.BatchForwardMulti"/>'s decode router (<c>BatchDecodeMatMul</c> — the
/// #194 weight-stationary matvec / #201 decode-MMQ), like the dense path, while the PLE matmuls stay
/// on GEMM. cuBLAS GEMM is compute-bound and known to lose to WS/GEMM-N for small-N decode (#190),
/// so this is the A/B that proves the win.
///
/// Run on <b>Gemma 4 E4B Q8_0</b> (the only GEMM-N-batchable Gemma 4 that fits 12 GB). The routing
/// is selected at construction by the ambient env, so A/B by running twice:
///   # new WS routing (default):
///   $env:SHARPI_BENCH_BATCH=1; dotnet test ... --filter "FullyQualifiedName~Gemma4CudaBatchedDecodeBench"
///   # old all-GEMM routing (#195 baseline):
///   $env:SHARPI_BENCH_BATCH=1; $env:SHARPI_BATCH_DECODE_GEMM=1; dotnet test ... (same filter)
/// Optional knobs: <c>SHARPI_BENCH_BATCH_N</c> (comma-separated batch sizes, default "1,2,4,8") and
/// <c>SHARPI_BENCH_STEPS</c> (timed decode steps, default 128; clamped so that prompt + warm-up +
/// timed steps stay within the context length the forward pass is built with).
/// Surfaces t/s, asserts no threshold (mirrors <see cref="CudaBatchedDecodeBench"/>). Silent-skips
/// without CUDA / the GGUF.
/// </summary>
public sealed class Gemma4CudaBatchedDecodeBench
{
    private const string ModelFile = "gemma-4-E4B-it-Q8_0.gguf";

    /// <summary>Batch sizes used when <c>SHARPI_BENCH_BATCH_N</c> is unset or has no usable entry.</summary>
    private const string DefaultBatchSizes = "1,2,4,8";

    /// <summary>Timed decode steps used when <c>SHARPI_BENCH_STEPS</c> is unset or invalid.</summary>
    private const int DefaultDecodeSteps = 128;

    /// <summary>Context length handed to the <see cref="CudaForwardPass"/> constructor.</summary>
    private const int MaxContextLength = 2048;

    /// <summary>Untimed decode steps run before the stopwatch starts.</summary>
    private const int WarmupSteps = 8;

    private readonly ITestOutputHelper _out;

    public Gemma4CudaBatchedDecodeBench(ITestOutputHelper outHelper) { _out = outHelper; }

    /// <summary>Writes a line to both the xUnit output and stderr (so it shows up in plain console runs).</summary>
    private void Log(string line) { _out.WriteLine(line); Console.Error.WriteLine(line); }

    /// <summary>The bench only runs when <c>SHARPI_BENCH_BATCH=1</c>; otherwise the test returns immediately.</summary>
    private static bool BenchEnabled => Environment.GetEnvironmentVariable("SHARPI_BENCH_BATCH") == "1";

    /// <summary>
    /// Batch sizes from <c>SHARPI_BENCH_BATCH_N</c>. Malformed or non-positive entries are dropped
    /// (same tolerant style as <see cref="DecodeSteps"/>) instead of throwing; if nothing usable
    /// remains, the default list is used.
    /// </summary>
    private static int[] BatchSizesToRun
    {
        get
        {
            int[] requested = ParseBatchSizes(
                Environment.GetEnvironmentVariable("SHARPI_BENCH_BATCH_N") ?? DefaultBatchSizes);
            return requested.Length > 0 ? requested : ParseBatchSizes(DefaultBatchSizes);
        }
    }

    /// <summary>Requested timed decode steps from <c>SHARPI_BENCH_STEPS</c> (positive integer), else the default.</summary>
    private static int DecodeSteps =>
        int.TryParse(
            Environment.GetEnvironmentVariable("SHARPI_BENCH_STEPS"),
            System.Globalization.NumberStyles.Integer,
            System.Globalization.CultureInfo.InvariantCulture,
            out int s) && s > 0
            ? s
            : DefaultDecodeSteps;

    private static readonly int[] Prompt =
        { 2, 651, 6037, 576, 6081, 603, 1234, 4567, 8901, 222, 333, 444, 555, 666, 777, 888 };

    /// <summary>Parses a comma-separated list of integers, keeping only the strictly positive, well-formed entries.</summary>
    private static int[] ParseBatchSizes(string raw) =>
        raw.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries)
            .Select(entry => int.TryParse(
                entry,
                System.Globalization.NumberStyles.Integer,
                System.Globalization.CultureInfo.InvariantCulture,
                out int n) ? n : 0)
            .Where(n => n > 0)
            .ToArray();

    /// <summary>Returns a CUDA backend, or <c>null</c> when CUDA is unavailable or creation fails.</summary>
    private static CudaBackend? TryCreate()
    {
        if (!CudaBackend.IsAvailable()) return null;
        try { return CudaBackend.Create(); }
        catch { return null; } // deliberate best-effort: any creation failure means "skip the bench"
    }

    /// <summary>
    /// Looks for the GGUF in two fixed absolute locations, then in a <c>models</c> directory next to
    /// the current directory or any of up to eight ancestors. Returns <c>null</c> when not found.
    /// </summary>
    private static string? FindModelPath()
    {
        string[] absolute = { $@"E:\models\{ModelFile}", $@"C:\p\sharpi\models\{ModelFile}" };
        foreach (var p in absolute)
            if (File.Exists(p)) return p;
        var dir = Directory.GetCurrentDirectory();
        for (int i = 0; i < 8; i++)
        {
            var p = Path.Combine(dir, "models", ModelFile);
            if (File.Exists(p)) return p;
            var parent = Directory.GetParent(dir);
            if (parent is null) break;
            dir = parent.FullName;
        }
        return null;
    }

    /// <summary>Index of the largest logit (first occurrence wins on ties).</summary>
    /// <exception cref="ArgumentException"><paramref name="logits"/> is empty.</exception>
    private static int Argmax(ReadOnlySpan<float> logits)
    {
        if (logits.IsEmpty)
            throw new ArgumentException("Cannot take the argmax of an empty logits span.", nameof(logits));

        int best = 0; float bestVal = logits[0];
        for (int i = 1; i < logits.Length; i++)
            if (logits[i] > bestVal) { bestVal = logits[i]; best = i; }
        return best;
    }

    [Fact]
    public void Gemma4_E4B_Decode_Throughput_Batched()
    {
        if (!BenchEnabled) return;
        using var gpu = TryCreate();
        if (gpu is null) return;
        var path = FindModelPath();
        if (path is null) return;

        using var model = GgufModel.Open(path);
        var hp = ModelHyperparams.FromGgufMetadata(model.Metadata, model);

        // Pin SnapKV off (the constructor already forces it off for Gemma 4, but be explicit).
        // The previous value is restored as soon as construction finishes, success or not.
        var prevSnap = Environment.GetEnvironmentVariable("SHARPI_SNAPKV_BUDGET");
        Environment.SetEnvironmentVariable("SHARPI_SNAPKV_BUDGET", "0");
        CudaForwardPass fwdTmp;
        try { fwdTmp = new CudaForwardPass(model, gpu, hp, maxContextLength: MaxContextLength); }
        finally { Environment.SetEnvironmentVariable("SHARPI_SNAPKV_BUDGET", prevSnap); }
        using var fwd = fwdTmp;

        bool gemm = Environment.GetEnvironmentVariable("SHARPI_BATCH_DECODE_GEMM") == "1";
        Log($"[bench-275] routing = {(gemm ? "all-GEMM (#195 baseline)" : "WS/decode-MMQ (#275)")}");

        // Every sequence consumes Prompt.Length + WarmupSteps + decodeSteps positions, so cap the
        // timed steps to what the context length requested above can hold.
        int requestedSteps = DecodeSteps;
        int maxSteps = MaxContextLength - Prompt.Length - WarmupSteps;
        int decodeSteps = Math.Min(requestedSteps, maxSteps);
        if (decodeSteps < requestedSteps)
            Log($"[bench-275] SHARPI_BENCH_STEPS={requestedSteps} exceeds the {MaxContextLength}-token context; clamped to {decodeSteps}");

        int[] batchSizes = BatchSizesToRun;
        Log($"[bench-275] batch sizes = {string.Join(",", batchSizes)}, steps = {decodeSteps}, warmup = {WarmupSteps}");

        foreach (int n in batchSizes)
        {
            var caches = new CudaSequenceKvCache[n];
            try
            {
                var toks = new int[n];
                var poss = new int[n];

                // Prefill each sequence with the same prompt; its greedy next token seeds decode.
                for (int s = 0; s < n; s++)
                {
                    caches[s] = fwd.CreateCache();
                    toks[s] = Argmax(fwd.PrefillWithCache(Prompt, caches[s]));
                    poss[s] = Prompt.Length;
                }

                // Untimed warm-up steps (greedy decode, same loop body as the timed section).
                for (int i = 0; i < WarmupSteps; i++)
                {
                    var lg = fwd.BatchForwardMulti(toks, poss, caches);
                    for (int s = 0; s < n; s++) { toks[s] = Argmax(lg[s]); poss[s]++; }
                }

                var sw = Stopwatch.StartNew();
                for (int i = 0; i < decodeSteps; i++)
                {
                    var lg = fwd.BatchForwardMulti(toks, poss, caches);
                    for (int s = 0; s < n; s++) { toks[s] = Argmax(lg[s]); poss[s]++; }
                }
                sw.Stop();

                double seconds = sw.Elapsed.TotalSeconds;
                double aggTps = (double)n * decodeSteps / seconds;
                double perSeq = (double)decodeSteps / seconds;
                Log($"[bench-275] N={n}: aggregate {aggTps:F1} t/s ({perSeq:F1} t/s/seq)");
            }
            finally
            {
                // Entries may still be null if cache creation or prefill threw part-way through.
                foreach (var c in caches) c?.Dispose();
            }
        }
    }
}