Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 49 additions & 4 deletions src/SharpInference.Engine/CudaForwardPass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,14 @@ public sealed unsafe class CudaForwardPass : IForwardPass
/// </summary>
public bool BatchedPrefillEnabled { get; set; }
/// <summary>
/// Issue #162: window size (tokens) for chunked batched prefill of prompts longer
/// than the non-flash 4096 cap. Each window is batched at its own startPos with flash
/// attention streaming the prior KV, so the N-sized trunk scratch stays bounded to
/// this many tokens regardless of prompt length. 4096 matches the well-tested
/// single-shot batch size.
/// <para>
/// <c>Prefill</c> only splits a prompt when it holds more than this many tokens and
/// chunking past 4096 is permitted; the final window may be shorter than the chunk
/// size, and only the last window's logits are returned to the caller.
/// </para>
/// </summary>
private const int PrefillBatchChunk = 4096;
/// <summary>
/// Issue #141: route Q8_0 trunk matmuls in the batched prefill through the
/// compute-bound cuBLAS GEMM (<see cref="CudaBackend.MatMulBatchedGemm"/>)
/// instead of the memory-bound matvec GEMM-N. Default on
Expand Down Expand Up @@ -1830,11 +1838,48 @@ public ReadOnlySpan<float> Prefill(IReadOnlyList<int> tokens, int startPos = 0)
// attention/FFN/PLE launches (whose count grows with N) into batched GEMM-N +
// batched-attention launches. Originally Gemma-4-only; #156 opened it to any
// dense model the batched kernels cover (e.g. Qwen3-8B Q4_K). Everything else
// (MoE, SnapKV-active, TQ, >4096 context, non-NEOX RoPE, L2 QK-norm, attn bias,
// unbatchable weight dtype) falls back to the per-token loop below.
// (MoE, SnapKV-active, TQ, non-NEOX RoPE, L2 QK-norm, attn bias, unbatchable
// weight dtype) falls back to the per-token loop below.
//
// Issue #162: the 4096 fast-path cap is a limit of the *non-flash* shared-scores
// AttentionBatched kernel (it throws above startPos+nTok=4096). The flash prefill
// kernels stream KV, so when flash is enabled we run the batched path for prompts
// of any length, chunking into PrefillBatchChunk-token windows so the N-sized trunk
// scratch stays bounded. Each chunk is batched at its own startPos; flash attends
// to all prior KV. Without this, a >4096-token prompt drops to the per-token loop
// (memory-bound, ~8× slower: 432 → 50 t/s on Qwen3-8B Q4_K @ 4070 Ti).
if (BatchedPrefillEnabled && !snapKvActive && N >= 2
&& startPos + N <= 4096 && IsBatchedPrefillSupported())
return PrefillBatchedTrunk(tokens, startPos);
&& startPos + N <= _maxSeqLen && IsBatchedPrefillSupported())
{
// Chunking past 4096 requires a streaming attention path (flash) AND simple
// (non-windowed) KV position semantics. SWA layers wrap KV in a window-sized
// ring whose cross-chunk behaviour past the window boundary is not yet
// validated, so any windowed model keeps the proven 4096 cap (follow-up: #162).
// Note: `IsSwaLayer` is only populated for Gemma 4 today, so the explicit
// `SlidingWindowSize <= 0` check fails closed if SWA parsing is later extended
// to a uniform-window arch that sets the window without a per-layer pattern.
bool canChunkPast4096 = PrefillFlashAttnEnabled
&& _hp.IsSwaLayer is null && _hp.SlidingWindowSize <= 0;
int cap = canChunkPast4096 ? _maxSeqLen : 4096;
if (startPos + N <= cap)
{
if (N <= PrefillBatchChunk || !canChunkPast4096)
return PrefillBatchedTrunk(tokens, startPos);

// Chunked: flash streams KV across windows, so process PrefillBatchChunk
// tokens at a time. Only the last chunk's logits are returned (decode
// starts from the final token); the per-chunk final norm + output proj on
// earlier chunks is discarded, a negligible cost vs the batched-trunk win.
int[] all = tokens as int[] ?? System.Linq.Enumerable.ToArray(tokens);
ReadOnlySpan<float> chunkLogits = default;
for (int off = 0; off < N; off += PrefillBatchChunk)
{
int len = Math.Min(PrefillBatchChunk, N - off);
chunkLogits = PrefillBatchedTrunk(new ArraySegment<int>(all, off, len), startPos + off);
Comment on lines +1873 to +1878

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Performance Bottleneck: Unintentional Heap Allocations During Chunked Prefill

When chunking is active, Prefill passes an ArraySegment<int> to PrefillBatchedTrunk:

chunkLogits = PrefillBatchedTrunk(new ArraySegment<int>(all, off, len), startPos + off);

Inside PrefillBatchedTrunk (around line 2071), the code attempts to cast tokens to int[]:

int[] ids = tokens as int[] ?? System.Linq.Enumerable.ToArray(tokens);

Since ArraySegment<int> is not int[], this cast evaluates to null, forcing a fallback to System.Linq.Enumerable.ToArray(tokens). This allocates a new int[] array of size up to 4096 on the heap for every single chunk during prefill, causing significant GC pressure and defeating the zero-allocation path for int[] inputs.

Recommended Solution

To eliminate these allocations, update PrefillBatchedTrunk to retrieve a ReadOnlySpan<int> from IReadOnlyList<int> using pattern matching, and then use MemoryMarshal.AsBytes on that span:

// Inside PrefillBatchedTrunk:
if (_embIsQuantized && embDType == DType.Q8_0)
{
    ReadOnlySpan<int> ids = tokens switch
    {
        int[] arr => arr,
        ArraySegment<int> seg => seg,
        List<int> list => CollectionsMarshal.AsSpan(list),
        _ => System.Linq.Enumerable.ToArray(tokens)
    };
    var idTensor = _gpu.UploadRaw(
        System.Runtime.InteropServices.MemoryMarshal.AsBytes(ids),
        TensorShape.D1(N), DType.Float32);
    _gpu.EmbedLookupQ8_0Batched(_gpuEmbedding, _bpHidden!, idTensor, N, embDim);
    _gpu.Free(idTensor);
}

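This removes the per-chunk int[] copy for int[], ArraySegment<int>, and List<int> inputs. Note that passing an ArraySegment<int> (a struct) as IReadOnlyList<int> still boxes it at the call site, so one small fixed-size allocation per chunk remains unless PrefillBatchedTrunk also gains a span- or array-plus-offset overload.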

}
return chunkLogits;
}
}

int W = 0, wStart = 0;
if (snapKvActive)
Expand Down
128 changes: 128 additions & 0 deletions tests/SharpInference.Tests.ForwardPass/CudaQ4KPrefillMatmulProbe.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
using System.Diagnostics;
using SharpInference.Core;
using SharpInference.Cuda;

namespace SharpInference.Tests.ForwardPass;

/// <summary>
/// Issue #159/#162 prefill probe (not a correctness test): times the two compute-bound
/// Q4_K trunk-matmul paths — int8 MMQ (<see cref="CudaBackend.MatMulBatchedMmq"/>) and
/// dequant→fp16→cuBLAS GEMM (<see cref="CudaBackend.MatMulBatchedGemm"/>) — at the REAL
/// Qwen3-8B prefill shapes (nTok=1844), and reports achieved TFLOP/s vs the RTX 4070 Ti's
/// ~165 TFLOP/s fp16-TC / ~330 TOP/s int8-TC peak.
///
/// Motivation: the full batched prefill achieves only ~6 TFLOP/s-equivalent (25.6 TFLOP /
/// ~4.0 s at N=1844), yet the older <see cref="CudaMmqRooflineProbe"/> reported 36–54 int8
/// TOPS for the isolated MMQ kernel at nTok=1024. This probe times the kernels at the exact
/// production shapes so we can tell whether the matmul kernels themselves are the bottleneck
/// (→ #149/#152 tiling rewrite) or whether per-call overhead / redundant work outside the
/// kernel dominates (→ a cheaper engine-side fix).
///
/// Run explicitly: --filter FullyQualifiedName~CudaQ4KPrefillMatmulProbe. Silent no-op
/// without CUDA. Always asserts true — it only prints the measurement.
/// </summary>
public sealed unsafe class CudaQ4KPrefillMatmulProbe
{
    // block_q4_K = 144 B / 256 elems: d(fp16) dmin(fp16) scales[12] qs[128].
    private const int Q4KBlockElems = 256;
    private const int Q4KBlockBytes = 144;
    private const int Q4KScaleBytes = 12;
    private const int Q4KQuantBytes = 128;

    // Untimed launches before each measurement, then the number of timed launches averaged.
    private const int WarmupIters = 5;
    private const int TimedIters = 30;

    // Layer count used to project the per-layer total onto the whole trunk.
    private const int LayerCount = 36;

    /// <summary>
    /// Creates a CUDA backend, or returns <c>null</c> when CUDA is unavailable or creation
    /// fails, so the probe can silently no-op.
    /// </summary>
    private static CudaBackend? TryCreate()
    {
        if (!CudaBackend.IsAvailable()) return null;
        // Deliberate best-effort: any creation failure means "skip the probe", not "fail the run".
        try { return CudaBackend.Create(); }
        catch { return null; }
    }

    /// <summary>Returns the raw IEEE-754 binary16 bit pattern of <paramref name="h"/>.</summary>
    private static ushort HalfToUshort(Half h) => BitConverter.HalfToUInt16Bits(h);

    /// <summary>
    /// Builds a row-major matrix of random Q4_K blocks: per block a small positive fp16
    /// <c>d</c> and <c>dmin</c> (written little-endian) followed by random scale and quant bytes.
    /// </summary>
    /// <param name="rows">Number of matrix rows.</param>
    /// <param name="cols">Number of matrix columns; must be a multiple of 256.</param>
    /// <param name="rng">Random source; the draw order is d, dmin, scales, quants per block.</param>
    /// <returns>The packed block bytes, <c>rows * (cols / 256) * 144</c> long.</returns>
    /// <exception cref="ArgumentException"><paramref name="cols"/> is not a multiple of 256.</exception>
    private static byte[] BuildQ4KMatrix(int rows, int cols, Random rng)
    {
        // Integer division below would otherwise silently drop the trailing partial block.
        if (cols % Q4KBlockElems != 0)
            throw new ArgumentException(
                $"cols ({cols}) must be a multiple of the Q4_K block size ({Q4KBlockElems}).", nameof(cols));

        int blocksPerRow = cols / Q4KBlockElems, bytesPerRow = blocksPerRow * Q4KBlockBytes;
        var bytes = new byte[(long)rows * bytesPerRow];
        for (int r = 0; r < rows; r++)
            for (int b = 0; b < blocksPerRow; b++)
            {
                long off = (long)r * bytesPerRow + (long)b * Q4KBlockBytes;
                ushort d = HalfToUshort((Half)(float)(rng.NextDouble() * 0.04 + 0.01));
                ushort dmin = HalfToUshort((Half)(float)(rng.NextDouble() * 0.02 + 0.005));
                bytes[off] = (byte)(d & 0xFF); bytes[off + 1] = (byte)(d >> 8);
                bytes[off + 2] = (byte)(dmin & 0xFF); bytes[off + 3] = (byte)(dmin >> 8);
                for (int i = 0; i < Q4KScaleBytes; i++) bytes[off + 4 + i] = (byte)rng.Next(256);
                for (int i = 0; i < Q4KQuantBytes; i++) bytes[off + 4 + Q4KScaleBytes + i] = (byte)rng.Next(256);
            }
        return bytes;
    }

    /// <summary>
    /// Runs <paramref name="launch"/> for a few warm-up iterations, then times
    /// <see cref="TimedIters"/> launches bracketed by device synchronisation.
    /// </summary>
    /// <returns>Average wall-clock milliseconds per launch.</returns>
    private static double TimeKernelMs(CudaBackend gpu, Action launch)
    {
        for (int i = 0; i < WarmupIters; i++) launch();
        gpu.Synchronize();

        var sw = Stopwatch.StartNew();
        for (int i = 0; i < TimedIters; i++) launch();
        gpu.Synchronize();
        sw.Stop();
        return sw.Elapsed.TotalMilliseconds / TimedIters;
    }

    /// <summary>
    /// Uploads a random Q4_K weight matrix and random activations for one matmul shape and
    /// times both kernel paths on them. Device tensors are released even if a kernel throws.
    /// </summary>
    /// <returns>Average milliseconds per launch for the MMQ path and for the GEMM path.</returns>
    private static (double MmqMs, double GemmMs) MeasureShape(
        CudaBackend gpu, int rows, int cols, int nTok, Random rng)
    {
        byte[] weightBytes = BuildQ4KMatrix(rows, cols, rng);
        var acts = new float[(long)nTok * cols];
        for (int i = 0; i < acts.Length; i++) acts[i] = (float)(rng.NextDouble() * 2 - 1);

        var gpuW = gpu.UploadRaw(weightBytes, TensorShape.D1(weightBytes.Length), DType.Q4_K);
        try
        {
            var gpuX = gpu.Upload(acts, TensorShape.D1(acts.Length));
            try
            {
                var gpuY = gpu.Allocate(TensorShape.D1((long)nTok * rows));
                try
                {
                    // --- MMQ (int8 tensor core, default prefill path) ---
                    double mmqMs = TimeKernelMs(
                        gpu, () => gpu.MatMulBatchedMmq(gpuY, gpuW, gpuX, nTok, DType.Q4_K));

                    // --- C1: dequant→fp16→cuBLAS GEMM ---
                    double gemmMs = TimeKernelMs(
                        gpu, () => gpu.MatMulBatchedGemm(gpuY, gpuW, gpuX, nTok, DType.Q4_K));

                    return (mmqMs, gemmMs);
                }
                finally { gpu.Free(gpuY); }
            }
            finally { gpu.Free(gpuX); }
        }
        finally { gpu.Free(gpuW); }
    }

    /// <summary>
    /// Prints per-shape and summed timings for the Qwen3-8B trunk matmuls at N=1844.
    /// Returns immediately when no CUDA backend can be created.
    /// </summary>
    [Fact]
    public void Q4K_PrefillMatmul_AchievedThroughput_AtRealShapes()
    {
        using var gpu = TryCreate();
        if (gpu is null) return;

        // Qwen3-8B trunk matmuls (rows=out, cols=in) at the real prefill batch N=1844.
        // hidden=4096, intermediate=12288, qkv=6144 (32+8+8 heads × 128).
        const int nTok = 1844;
        (int rows, int cols, string what)[] shapes =
        {
            (6144, 4096, "qkv"),
            (4096, 4096, "o-proj"),
            (12288, 4096, "ffn-gate"),
            (12288, 4096, "ffn-up"),
            (4096, 12288, "ffn-down"),
        };
        const double fp16Peak = 165.0; // ~RTX 4070 Ti fp16-TC TFLOP/s
        const double int8Peak = 330.0; // ~RTX 4070 Ti int8-TC TOP/s

        var rng = new Random(20260607);
        double totGflopPerLayer = 0, totMmqMs = 0, totGemmMs = 0;

        foreach (var (rows, cols, what) in shapes)
        {
            double macs = (double)rows * cols * nTok;
            double gflop = 2.0 * macs / 1e9;
            totGflopPerLayer += gflop;

            var (mmqMs, gemmMs) = MeasureShape(gpu, rows, cols, nTok, rng);
            totMmqMs += mmqMs;
            totGemmMs += gemmMs;

            // GFLOP per millisecond is numerically TFLOP/s (TOP/s for the int8 path).
            double mmqTops = gflop / (mmqMs * 1e-3) / 1e3;
            double gemmTflop = gflop / (gemmMs * 1e-3) / 1e3;

            Console.WriteLine(
                $"{what,-9} [{rows}×{cols}]·[{cols}×{nTok}] " +
                $"MMQ {mmqMs,6:F2} ms ({mmqTops,5:F1} TOP/s, {100 * mmqTops / int8Peak,2:F0}% int8) | " +
                $"GEMM {gemmMs,6:F2} ms ({gemmTflop,5:F1} TFLOP/s, {100 * gemmTflop / fp16Peak,2:F0}% fp16)");
        }

        // Per-layer totals, projected across all layers to compare with the e2e prefill time.
        Console.WriteLine(
            $"per-layer trunk: {totGflopPerLayer:F1} GFLOP | MMQ {totMmqMs:F2} ms GEMM {totGemmMs:F2} ms " +
            $"(×{LayerCount} layers @ N={nTok}: MMQ {totMmqMs * LayerCount / 1000:F2} s, GEMM {totGemmMs * LayerCount / 1000:F2} s)");
        Assert.True(true);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,72 @@ public void Qwen3_8B_BatchedPrefill_DefaultMatchesSequential()
$"Default batched top-5 overlaps the per-token reference in only {overlap}/5 slots.");
}

/// <summary>
/// Issue #162 regression guard: a prompt longer than the non-flash 4096 cap has to stay
/// on the fast batched-trunk path (split into <c>PrefillBatchChunk</c> windows while flash
/// streams the earlier KV) and remain argmax-stable against the bit-exact per-token
/// loop, rather than quietly falling back to the ~8× slower memory-bound per-token prefill.
/// <para>
/// The two cases cover both window layouts: N=5040 is one full window plus a partial one
/// (4096 + 944), and N=8192 is an exact multiple (two full 4096 windows, so the last
/// chunk's length equals the chunk size), which is where a chunk-loop off-by-one would show.
/// </para>
/// </summary>
[Theory]
[InlineData(5040, 6144)]
[InlineData(8192, 9216)]
public void Qwen3_8B_ChunkedBatchedPrefill_Over4096_MatchesSequential(int promptLen, int ctx)
{
    // Silently skip when either the GPU or the model file is unavailable.
    using var gpu = TryCreate();
    if (gpu is null) return;
    var path = FindModelPath();
    if (path is null) return;

    using var model = GgufModel.Open(path);
    var hp = ModelHyperparams.FromGgufMetadata(model.Metadata, model);
    Assert.Null(hp.LayerHeadDim);

    // A prompt of more than 4096 tokens forces the chunked branch. Token ids come from a
    // small LCG so the run is deterministic; every id lies in [1, 150000], inside
    // Qwen3's 151936-entry vocab.
    var promptTokens = new int[promptLen];
    uint lcg = 0x9E3779B9u;
    int filled = 0;
    while (filled < promptTokens.Length)
    {
        lcg = lcg * 1664525u + 1013904223u;
        promptTokens[filled] = (int)(lcg % 150000u) + 1;
        filled++;
    }

    // SnapKV must be off: a prompt above its budget would be routed to the per-token
    // SnapKV eviction path before the batched gate is reached. The override only needs
    // to be in place during construction, so it is restored straight afterwards.
    var savedSnapBudget = Environment.GetEnvironmentVariable("SHARPI_SNAPKV_BUDGET");
    Environment.SetEnvironmentVariable("SHARPI_SNAPKV_BUDGET", "0");
    CudaForwardPass fwd;
    try
    {
        fwd = new CudaForwardPass(model, gpu, hp, maxContextLength: ctx);
    }
    finally
    {
        Environment.SetEnvironmentVariable("SHARPI_SNAPKV_BUDGET", savedSnapBudget);
    }
    using var fwdLifetime = fwd;

    // Pass 1: shipped defaults (flash TC on) with batching enabled → chunked batched path.
    fwd.BatchedPrefillEnabled = true;
    var chunkedLogits = fwd.Prefill(promptTokens).ToArray();
    Assert.True(fwd.LastPrefillWasBatched,
        "Chunked batched prefill did not engage for a >4096-token prompt (#162).");

    // Pass 2: same prompt through the per-token reference loop on a cleared cache.
    fwd.ResetCache();
    fwd.BatchedPrefillEnabled = false;
    var referenceLogits = fwd.Prefill(promptTokens).ToArray();
    Assert.False(fwd.LastPrefillWasBatched);

    Assert.Equal(referenceLogits.Length, chunkedLogits.Length);
    Assert.Equal(Argmax(referenceLogits), Argmax(chunkedLogits));

    // Require at least 4 of the top-5 candidates to agree between the two paths.
    var referenceTop = TopKSet(referenceLogits, 5);
    var chunkedTop = TopKSet(chunkedLogits, 5);
    int overlap = 0;
    foreach (var candidate in chunkedTop)
    {
        if (referenceTop.Contains(candidate))
        {
            overlap += 1;
        }
    }
    Assert.True(overlap >= 4,
        $"Chunked batched top-5 overlaps the per-token reference in only {overlap}/5 slots.");
}

/// <summary>
/// Flash off + GEMM off: Q4_K trunk runs the batched matvec GEMM-N, which is built to
/// be bit-identical to N per-token Q4_K dp4a matvecs. Verifies the batched
Expand Down