diff --git a/src/SharpInference.Engine/CudaForwardPass.cs b/src/SharpInference.Engine/CudaForwardPass.cs index 0bc9d94..977e97d 100644 --- a/src/SharpInference.Engine/CudaForwardPass.cs +++ b/src/SharpInference.Engine/CudaForwardPass.cs @@ -176,6 +176,14 @@ public sealed unsafe class CudaForwardPass : IForwardPass /// public bool BatchedPrefillEnabled { get; set; } /// <summary> + /// Issue #162: window size (tokens) for chunked batched prefill of prompts longer + /// than the non-flash 4096 cap. Each window is batched at its own startPos with flash + /// attention streaming the prior KV, so the N-sized trunk scratch stays bounded to + /// this many tokens regardless of prompt length. 4096 matches the well-tested + /// single-shot batch size. + /// </summary> + private const int PrefillBatchChunk = 4096; + /// <summary> /// Issue #141: route Q8_0 trunk matmuls in the batched prefill through the /// compute-bound cuBLAS GEMM () /// instead of the memory-bound matvec GEMM-N. Default on @@ -1830,11 +1838,48 @@ public ReadOnlySpan Prefill(IReadOnlyList tokens, int startPos = 0) // attention/FFN/PLE launches (whose count grows with N) into batched GEMM-N + // batched-attention launches. Originally Gemma-4-only; #156 opened it to any // dense model the batched kernels cover (e.g. Qwen3-8B Q4_K). Everything else - // (MoE, SnapKV-active, TQ, >4096 context, non-NEOX RoPE, L2 QK-norm, attn bias, - // unbatchable weight dtype) falls back to the per-token loop below. + // (MoE, SnapKV-active, TQ, non-NEOX RoPE, L2 QK-norm, attn bias, unbatchable + // weight dtype) falls back to the per-token loop below. + // + // Issue #162: the 4096 fast-path cap is a limit of the *non-flash* shared-scores + // AttentionBatched kernel (it throws above startPos+nTok=4096). The flash prefill + // kernels stream KV, so when flash is enabled we run the batched path for prompts + // of any length, chunking into PrefillBatchChunk-token windows so the N-sized trunk + // scratch stays bounded. 
Each chunk is batched at its own startPos; flash attends + // to all prior KV. Without this, a >4096-token prompt drops to the per-token loop + // (memory-bound, ~8× slower: 432 → 50 t/s on Qwen3-8B Q4_K @ 4070 Ti). if (BatchedPrefillEnabled && !snapKvActive && N >= 2 - && startPos + N <= 4096 && IsBatchedPrefillSupported()) - return PrefillBatchedTrunk(tokens, startPos); + && startPos + N <= _maxSeqLen && IsBatchedPrefillSupported()) + { + // Chunking past 4096 requires a streaming attention path (flash) AND simple + // (non-windowed) KV position semantics. SWA layers wrap KV in a window-sized + // ring whose cross-chunk behaviour past the window boundary is not yet + // validated, so any windowed model keeps the proven 4096 cap (follow-up: #162). + // Note: `IsSwaLayer` is only populated for Gemma 4 today, so the explicit + // `SlidingWindowSize <= 0` check fails closed if SWA parsing is later extended + // to a uniform-window arch that sets the window without a per-layer pattern. + bool canChunkPast4096 = PrefillFlashAttnEnabled + && _hp.IsSwaLayer is null && _hp.SlidingWindowSize <= 0; + int cap = canChunkPast4096 ? _maxSeqLen : 4096; + if (startPos + N <= cap) + { + if (N <= PrefillBatchChunk || !canChunkPast4096) + return PrefillBatchedTrunk(tokens, startPos); + + // Chunked: flash streams KV across windows, so process PrefillBatchChunk + // tokens at a time. Only the last chunk's logits are returned (decode + // starts from the final token); the per-chunk final norm + output proj on + // earlier chunks is discarded, a negligible cost vs the batched-trunk win. NOTE(review): when N % PrefillBatchChunk == 1 (e.g. N = 4097) the final chunk is a single token, although the gate above requires N >= 2 — confirm PrefillBatchedTrunk accepts nTok = 1. + int[] all = tokens as int[] ?? 
System.Linq.Enumerable.ToArray(tokens); + ReadOnlySpan chunkLogits = default; + for (int off = 0; off < N; off += PrefillBatchChunk) + { + int len = Math.Min(PrefillBatchChunk, N - off); + chunkLogits = PrefillBatchedTrunk(new ArraySegment(all, off, len), startPos + off); + } + return chunkLogits; + } + } int W = 0, wStart = 0; if (snapKvActive) diff --git a/tests/SharpInference.Tests.ForwardPass/CudaQ4KPrefillMatmulProbe.cs b/tests/SharpInference.Tests.ForwardPass/CudaQ4KPrefillMatmulProbe.cs new file mode 100644 index 0000000..d1eeac7 --- /dev/null +++ b/tests/SharpInference.Tests.ForwardPass/CudaQ4KPrefillMatmulProbe.cs @@ -0,0 +1,128 @@ +using System.Diagnostics; +using SharpInference.Core; +using SharpInference.Cuda; + +namespace SharpInference.Tests.ForwardPass; + +/// <summary> +/// Issue #159/#162 prefill probe (not a correctness test): times the two compute-bound +/// Q4_K trunk-matmul paths — int8 MMQ (<see cref="CudaBackend.MatMulBatchedMmq"/>) and +/// dequant→fp16→cuBLAS GEMM (<see cref="CudaBackend.MatMulBatchedGemm"/>) — at the REAL +/// Qwen3-8B prefill shapes (nTok=1844), and reports achieved TFLOP/s vs the RTX 4070 Ti's +/// ~165 TFLOP/s fp16-TC / ~330 TOP/s int8-TC peak. +/// +/// Motivation: the full batched prefill achieves only ~6 TFLOP/s-equivalent (25.6 TFLOP / +/// ~4.0 s at N=1844), yet an older measurement reported 36–54 int8 +/// TOPS for the isolated MMQ kernel at nTok=1024. This probe times the kernels at the exact +/// production shapes so we can tell whether the matmul kernels themselves are the bottleneck +/// (→ #149/#152 tiling rewrite) or whether per-call overhead / redundant work outside the +/// kernel dominates (→ a cheaper engine-side fix). +/// +/// Run explicitly: --filter FullyQualifiedName~CudaQ4KPrefillMatmulProbe. Silent no-op +/// without CUDA. Always asserts true — it only prints the measurement. +/// </summary> +public sealed unsafe class CudaQ4KPrefillMatmulProbe +{ + private static CudaBackend? 
TryCreate() + { + if (!CudaBackend.IsAvailable()) return null; + try { return CudaBackend.Create(); } + catch { return null; } + } + + private static ushort HalfToUshort(Half h) => BitConverter.ToUInt16(BitConverter.GetBytes(h), 0); + + // block_q4_K = 144 B / 256 elems: d(fp16) dmin(fp16) scales[12] qs[128]. Builds a rows×cols matrix of such blocks: d and dmin are small positive random halves stored low byte first, scales and qs are random bytes. Assumes cols is a multiple of 256 (the integer division drops any remainder). NOTE(review): HalfToUshort above allocates a byte[] per call; BitConverter.HalfToUInt16Bits would avoid that if the target framework provides it. + private static byte[] BuildQ4KMatrix(int rows, int cols, Random rng) + { + int blocksPerRow = cols / 256, bytesPerRow = blocksPerRow * 144; + var bytes = new byte[(long)rows * bytesPerRow]; + for (int r = 0; r < rows; r++) + for (int b = 0; b < blocksPerRow; b++) + { + long off = (long)r * bytesPerRow + (long)b * 144; + ushort d = HalfToUshort((Half)(float)(rng.NextDouble() * 0.04 + 0.01)); + ushort dmin = HalfToUshort((Half)(float)(rng.NextDouble() * 0.02 + 0.005)); + bytes[off] = (byte)(d & 0xFF); bytes[off + 1] = (byte)(d >> 8); + bytes[off + 2] = (byte)(dmin & 0xFF); bytes[off + 3] = (byte)(dmin >> 8); + for (int i = 0; i < 12; i++) bytes[off + 4 + i] = (byte)rng.Next(256); + for (int i = 0; i < 128; i++) bytes[off + 16 + i] = (byte)rng.Next(256); + } + return bytes; + } + + [Fact] + public void Q4K_PrefillMatmul_AchievedThroughput_AtRealShapes() + { + using var gpu = TryCreate(); + if (gpu is null) return; + + // Qwen3-8B trunk matmuls (rows=out, cols=in) at the real prefill batch N=1844. + // hidden=4096, intermediate=12288, qkv=6144 (32+8+8 heads × 128). 
+ const int nTok = 1844; + (int rows, int cols, string what)[] shapes = + { + (6144, 4096, "qkv"), + (4096, 4096, "o-proj"), + (12288, 4096, "ffn-gate"), + (12288, 4096, "ffn-up"), + (4096, 12288, "ffn-down"), + }; + const double fp16Peak = 165.0; // ~RTX 4070 Ti fp16-TC TFLOP/s + const double int8Peak = 330.0; // ~RTX 4070 Ti int8-TC TOP/s + + var rng = new Random(20260607); + double totGflopPerLayer = 0, totMmqMs = 0, totGemmMs = 0; + + foreach (var (rows, cols, what) in shapes) + { + byte[] weightBytes = BuildQ4KMatrix(rows, cols, rng); + var acts = new float[(long)nTok * cols]; + for (int i = 0; i < acts.Length; i++) acts[i] = (float)(rng.NextDouble() * 2 - 1); + + var gpuW = gpu.UploadRaw(weightBytes, TensorShape.D1(weightBytes.Length), DType.Q4_K); + var gpuX = gpu.Upload(acts, TensorShape.D1(acts.Length)); + var gpuY = gpu.Allocate(TensorShape.D1((long)nTok * rows)); // NOTE(review): gpuW/gpuX/gpuY are released only by the gpu.Free calls at the end of this loop body, so a throwing matmul skips them (no try/finally) — confirm disposing gpu reclaims them + + double macs = (double)rows * cols * nTok; + double gflop = 2.0 * macs / 1e9; // 2 FLOPs (multiply + add) per MAC + totGflopPerLayer += gflop; + + // --- MMQ (int8 tensor core, default prefill path): 5 untimed warm-up calls, then the mean wall time of iters calls, each phase closed by Synchronize --- + for (int i = 0; i < 5; i++) gpu.MatMulBatchedMmq(gpuY, gpuW, gpuX, nTok, DType.Q4_K); + gpu.Synchronize(); + const int iters = 30; + var sw = Stopwatch.StartNew(); + for (int i = 0; i < iters; i++) gpu.MatMulBatchedMmq(gpuY, gpuW, gpuX, nTok, DType.Q4_K); + gpu.Synchronize(); + sw.Stop(); + double mmqMs = sw.Elapsed.TotalMilliseconds / iters; + totMmqMs += mmqMs; + double mmqTops = 2.0 * macs / (mmqMs * 1e-3) / 1e12; + + // --- C1: dequant→fp16→cuBLAS GEMM (same warm-up and timing scheme as the MMQ phase) --- + for (int i = 0; i < 5; i++) gpu.MatMulBatchedGemm(gpuY, gpuW, gpuX, nTok, DType.Q4_K); + gpu.Synchronize(); + sw.Restart(); + for (int i = 0; i < iters; i++) gpu.MatMulBatchedGemm(gpuY, gpuW, gpuX, nTok, DType.Q4_K); + gpu.Synchronize(); + sw.Stop(); + double gemmMs = sw.Elapsed.TotalMilliseconds / iters; + totGemmMs += gemmMs; + double gemmTflop = gflop / (gemmMs * 1e-3) / 1e3; + + Console.WriteLine( + $"{what,-9} [{rows}×{cols}]·[{cols}×{nTok}] " + + $"MMQ 
{mmqMs,6:F2} ms ({mmqTops,5:F1} TOP/s, {100 * mmqTops / int8Peak,2:F0}% int8) | " + + $"GEMM {gemmMs,6:F2} ms ({gemmTflop,5:F1} TFLOP/s, {100 * gemmTflop / fp16Peak,2:F0}% fp16)"); + + gpu.Free(gpuW); gpu.Free(gpuX); gpu.Free(gpuY); + } + + // Per-layer totals summed over the five shapes above, projected to the full trunk (×36 layers) to compare with the e2e prefill. + Console.WriteLine( + $"per-layer trunk: {totGflopPerLayer:F1} GFLOP | MMQ {totMmqMs:F2} ms GEMM {totGemmMs:F2} ms " + + $"(×36 layers @ N={nTok}: MMQ {totMmqMs * 36 / 1000:F2} s, GEMM {totGemmMs * 36 / 1000:F2} s)"); + Assert.True(true); + } +} diff --git a/tests/SharpInference.Tests.ForwardPass/Qwen3CudaBatchedPrefillTests.cs b/tests/SharpInference.Tests.ForwardPass/Qwen3CudaBatchedPrefillTests.cs index 0f5ccb5..83096a0 100644 --- a/tests/SharpInference.Tests.ForwardPass/Qwen3CudaBatchedPrefillTests.cs +++ b/tests/SharpInference.Tests.ForwardPass/Qwen3CudaBatchedPrefillTests.cs @@ -125,6 +125,72 @@ public void Qwen3_8B_BatchedPrefill_DefaultMatchesSequential() $"Default batched top-5 overlaps the per-token reference in only {overlap}/5 slots."); } + /// <summary> + /// Issue #162: prompts longer than the non-flash 4096 cap must still take the fast + /// batched-trunk path (chunked into PrefillBatchChunk windows, flash streaming + /// the prior KV) and stay argmax-stable vs the bit-exact per-token loop — instead of + /// silently dropping to the ~8× slower memory-bound per-token prefill. + /// <para> + /// N=5040 exercises a full + partial window (4096 + 944); N=8192 exercises the + /// exact-multiple boundary (two full 4096 windows, the final chunk's len == chunk size) + /// to catch the classic off-by-one in the chunk loop. 
+ /// </para> + /// </summary> + [Theory] + [InlineData(5040, 6144)] + [InlineData(8192, 9216)] + public void Qwen3_8B_ChunkedBatchedPrefill_Over4096_MatchesSequential(int promptLen, int ctx) + { + using var gpu = TryCreate(); + if (gpu is null) return; + var path = FindModelPath(); + if (path is null) return; + + using var model = GgufModel.Open(path); + var hp = ModelHyperparams.FromGgufMetadata(model.Metadata, model); + Assert.Null(hp.LayerHeadDim); + + // > 4096 tokens → forces the chunked branch. Deterministic spread across the vocab + // via a small LCG; ids land in [1, 150000], well within Qwen3's 151936 vocab. + var longTokens = new int[promptLen]; + uint s = 0x9E3779B9u; + for (int i = 0; i < longTokens.Length; i++) + { + s = s * 1664525u + 1013904223u; + longTokens[i] = (int)(s % 150000u) + 1; + } + + // Disable SnapKV (a >budget prompt would otherwise route to the per-token SnapKV + // eviction path before the batched gate). Construct under the override, then restore. A null prevSnap removes the variable again. NOTE(review): the variable is process-wide, so a test running in parallel could observe the temporary "0" — confirm nothing else reads SHARPI_SNAPKV_BUDGET concurrently. + var prevSnap = Environment.GetEnvironmentVariable("SHARPI_SNAPKV_BUDGET"); + Environment.SetEnvironmentVariable("SHARPI_SNAPKV_BUDGET", "0"); + CudaForwardPass fwd; + try { fwd = new CudaForwardPass(model, gpu, hp, maxContextLength: ctx); } + finally { Environment.SetEnvironmentVariable("SHARPI_SNAPKV_BUDGET", prevSnap); } + using var _fwd = fwd; + + // Shipped defaults (flash TC on) → chunked batched path. 
+ fwd.BatchedPrefillEnabled = true; + var batched = fwd.Prefill(longTokens).ToArray(); + Assert.True(fwd.LastPrefillWasBatched, + "Chunked batched prefill did not engage for a >4096-token prompt (#162)."); + + fwd.ResetCache(); + fwd.BatchedPrefillEnabled = false; + var sequential = fwd.Prefill(longTokens).ToArray(); + Assert.False(fwd.LastPrefillWasBatched); + + Assert.Equal(sequential.Length, batched.Length); + Assert.Equal(Argmax(sequential), Argmax(batched)); + + var seqTop = TopKSet(sequential, 5); + var batTop = TopKSet(batched, 5); + int overlap = 0; + foreach (var t in batTop) if (seqTop.Contains(t)) overlap++; + Assert.True(overlap >= 4, + $"Chunked batched top-5 overlaps the per-token reference in only {overlap}/5 slots."); + } + /// /// Flash off + GEMM off: Q4_K trunk runs the batched matvec GEMM-N, which is built to /// be bit-identical to N per-token Q4_K dp4a matvecs. Verifies the batched