diff --git a/src/SharpInference.Engine/CudaForwardPass.cs b/src/SharpInference.Engine/CudaForwardPass.cs
index 0bc9d94..977e97d 100644
--- a/src/SharpInference.Engine/CudaForwardPass.cs
+++ b/src/SharpInference.Engine/CudaForwardPass.cs
@@ -176,6 +176,14 @@ public sealed unsafe class CudaForwardPass : IForwardPass
///
public bool BatchedPrefillEnabled { get; set; }
///
+ /// Issue #162: window size (tokens) for chunked batched prefill of prompts longer
+ /// than the non-flash 4096 cap. Each window is batched at its own startPos with flash
+ /// attention streaming the prior KV, so the N-sized trunk scratch stays bounded to
+ /// this many tokens regardless of prompt length. 4096 matches the well-tested
+ /// single-shot batch size.
+ ///
+ private const int PrefillBatchChunk = 4096;
+ ///
/// Issue #141: route Q8_0 trunk matmuls in the batched prefill through the
/// compute-bound cuBLAS GEMM ()
/// instead of the memory-bound matvec GEMM-N. Default on
@@ -1830,11 +1838,48 @@ public ReadOnlySpan Prefill(IReadOnlyList tokens, int startPos = 0)
// attention/FFN/PLE launches (whose count grows with N) into batched GEMM-N +
// batched-attention launches. Originally Gemma-4-only; #156 opened it to any
// dense model the batched kernels cover (e.g. Qwen3-8B Q4_K). Everything else
- // (MoE, SnapKV-active, TQ, >4096 context, non-NEOX RoPE, L2 QK-norm, attn bias,
- // unbatchable weight dtype) falls back to the per-token loop below.
+ // (MoE, SnapKV-active, TQ, non-NEOX RoPE, L2 QK-norm, attn bias, unbatchable
+ // weight dtype) falls back to the per-token loop below.
+ //
+ // Issue #162: the 4096 fast-path cap is a limit of the *non-flash* shared-scores
+ // AttentionBatched kernel (it throws above startPos+nTok=4096). The flash prefill
+ // kernels stream KV, so when flash is enabled we run the batched path for prompts
+ // of any length, chunking into PrefillBatchChunk-token windows so the N-sized trunk
+ // scratch stays bounded. Each chunk is batched at its own startPos; flash attends
+ // to all prior KV. Without this, a >4096-token prompt drops to the per-token loop
+ // (memory-bound, ~8× slower: 432 → 50 t/s on Qwen3-8B Q4_K @ 4070 Ti).
if (BatchedPrefillEnabled && !snapKvActive && N >= 2
- && startPos + N <= 4096 && IsBatchedPrefillSupported())
- return PrefillBatchedTrunk(tokens, startPos);
+ && startPos + N <= _maxSeqLen && IsBatchedPrefillSupported())
+ {
+ // Chunking past 4096 requires a streaming attention path (flash) AND simple
+ // (non-windowed) KV position semantics. SWA layers wrap KV in a window-sized
+ // ring whose cross-chunk behaviour past the window boundary is not yet
+ // validated, so any windowed model keeps the proven 4096 cap (follow-up: #162).
+ // Note: `IsSwaLayer` is only populated for Gemma 4 today, so the explicit
+ // `SlidingWindowSize <= 0` check fails closed if SWA parsing is later extended
+ // to a uniform-window arch that sets the window without a per-layer pattern.
+ bool canChunkPast4096 = PrefillFlashAttnEnabled
+ && _hp.IsSwaLayer is null && _hp.SlidingWindowSize <= 0;
+ int cap = canChunkPast4096 ? _maxSeqLen : 4096;
+ if (startPos + N <= cap)
+ {
+ if (N <= PrefillBatchChunk || !canChunkPast4096)
+ return PrefillBatchedTrunk(tokens, startPos);
+
+ // Chunked: flash streams KV across windows, so process PrefillBatchChunk
+ // tokens at a time. Only the last chunk's logits are returned (decode
+ // starts from the final token); the per-chunk final norm + output proj on
+ // earlier chunks is discarded, a negligible cost vs the batched-trunk win.
+ int[] all = tokens as int[] ?? System.Linq.Enumerable.ToArray(tokens);
+ ReadOnlySpan chunkLogits = default;
+ for (int off = 0; off < N; off += PrefillBatchChunk)
+ {
+ int len = Math.Min(PrefillBatchChunk, N - off);
+ chunkLogits = PrefillBatchedTrunk(new ArraySegment(all, off, len), startPos + off);
+ }
+ return chunkLogits;
+ }
+ }
int W = 0, wStart = 0;
if (snapKvActive)
diff --git a/tests/SharpInference.Tests.ForwardPass/CudaQ4KPrefillMatmulProbe.cs b/tests/SharpInference.Tests.ForwardPass/CudaQ4KPrefillMatmulProbe.cs
new file mode 100644
index 0000000..d1eeac7
--- /dev/null
+++ b/tests/SharpInference.Tests.ForwardPass/CudaQ4KPrefillMatmulProbe.cs
@@ -0,0 +1,128 @@
+using System.Diagnostics;
+using SharpInference.Core;
+using SharpInference.Cuda;
+
+namespace SharpInference.Tests.ForwardPass;
+
+///
+/// Issue #159/#162 prefill probe (not a correctness test): times the two compute-bound
+/// Q4_K trunk-matmul paths — int8 MMQ () and
+/// dequant→fp16→cuBLAS GEMM () — at the REAL
+/// Qwen3-8B prefill shapes (nTok=1844), and reports achieved TFLOP/s vs the RTX 4070 Ti's
+/// ~165 TFLOP/s fp16-TC / ~330 TOP/s int8-TC peak.
+///
+/// Motivation: the full batched prefill achieves only ~6 TFLOP/s-equivalent (25.6 TFLOP /
+/// ~4.0 s at N=1844), yet the older reported 36–54 int8
+/// TOPS for the isolated MMQ kernel at nTok=1024. This probe times the kernels at the exact
+/// production shapes so we can tell whether the matmul kernels themselves are the bottleneck
+/// (→ #149/#152 tiling rewrite) or whether per-call overhead / redundant work outside the
+/// kernel dominates (→ a cheaper engine-side fix).
+///
+/// Run explicitly: --filter FullyQualifiedName~CudaQ4KPrefillMatmulProbe. Silent no-op
+/// without CUDA. Always asserts true — it only prints the measurement.
+///
+public sealed unsafe class CudaQ4KPrefillMatmulProbe
+{
+ private static CudaBackend? TryCreate()
+ {
+ if (!CudaBackend.IsAvailable()) return null;
+ try { return CudaBackend.Create(); }
+ catch { return null; }
+ }
+
+ private static ushort HalfToUshort(Half h) => BitConverter.ToUInt16(BitConverter.GetBytes(h), 0);
+
+ // block_q4_K = 144 B / 256 elems: d(fp16) dmin(fp16) scales[12] qs[128].
+ private static byte[] BuildQ4KMatrix(int rows, int cols, Random rng)
+ {
+ int blocksPerRow = cols / 256, bytesPerRow = blocksPerRow * 144;
+ var bytes = new byte[(long)rows * bytesPerRow];
+ for (int r = 0; r < rows; r++)
+ for (int b = 0; b < blocksPerRow; b++)
+ {
+ long off = (long)r * bytesPerRow + (long)b * 144;
+ ushort d = HalfToUshort((Half)(float)(rng.NextDouble() * 0.04 + 0.01));
+ ushort dmin = HalfToUshort((Half)(float)(rng.NextDouble() * 0.02 + 0.005));
+ bytes[off] = (byte)(d & 0xFF); bytes[off + 1] = (byte)(d >> 8);
+ bytes[off + 2] = (byte)(dmin & 0xFF); bytes[off + 3] = (byte)(dmin >> 8);
+ for (int i = 0; i < 12; i++) bytes[off + 4 + i] = (byte)rng.Next(256);
+ for (int i = 0; i < 128; i++) bytes[off + 16 + i] = (byte)rng.Next(256);
+ }
+ return bytes;
+ }
+
+ [Fact]
+ public void Q4K_PrefillMatmul_AchievedThroughput_AtRealShapes()
+ {
+ using var gpu = TryCreate();
+ if (gpu is null) return;
+
+ // Qwen3-8B trunk matmuls (rows=out, cols=in) at the real prefill batch N=1844.
+ // hidden=4096, intermediate=12288, qkv=6144 (32+8+8 heads × 128).
+ const int nTok = 1844;
+ (int rows, int cols, string what)[] shapes =
+ {
+ (6144, 4096, "qkv"),
+ (4096, 4096, "o-proj"),
+ (12288, 4096, "ffn-gate"),
+ (12288, 4096, "ffn-up"),
+ (4096, 12288, "ffn-down"),
+ };
+ const double fp16Peak = 165.0; // ~RTX 4070 Ti fp16-TC TFLOP/s
+ const double int8Peak = 330.0; // ~RTX 4070 Ti int8-TC TOP/s
+
+ var rng = new Random(20260607);
+ double totGflopPerLayer = 0, totMmqMs = 0, totGemmMs = 0;
+
+ foreach (var (rows, cols, what) in shapes)
+ {
+ byte[] weightBytes = BuildQ4KMatrix(rows, cols, rng);
+ var acts = new float[(long)nTok * cols];
+ for (int i = 0; i < acts.Length; i++) acts[i] = (float)(rng.NextDouble() * 2 - 1);
+
+ var gpuW = gpu.UploadRaw(weightBytes, TensorShape.D1(weightBytes.Length), DType.Q4_K);
+ var gpuX = gpu.Upload(acts, TensorShape.D1(acts.Length));
+ var gpuY = gpu.Allocate(TensorShape.D1((long)nTok * rows));
+
+ double macs = (double)rows * cols * nTok;
+ double gflop = 2.0 * macs / 1e9;
+ totGflopPerLayer += gflop;
+
+ // --- MMQ (int8 tensor core, default prefill path) ---
+ for (int i = 0; i < 5; i++) gpu.MatMulBatchedMmq(gpuY, gpuW, gpuX, nTok, DType.Q4_K);
+ gpu.Synchronize();
+ const int iters = 30;
+ var sw = Stopwatch.StartNew();
+ for (int i = 0; i < iters; i++) gpu.MatMulBatchedMmq(gpuY, gpuW, gpuX, nTok, DType.Q4_K);
+ gpu.Synchronize();
+ sw.Stop();
+ double mmqMs = sw.Elapsed.TotalMilliseconds / iters;
+ totMmqMs += mmqMs;
+ double mmqTops = 2.0 * macs / (mmqMs * 1e-3) / 1e12;
+
+ // --- C1: dequant→fp16→cuBLAS GEMM ---
+ for (int i = 0; i < 5; i++) gpu.MatMulBatchedGemm(gpuY, gpuW, gpuX, nTok, DType.Q4_K);
+ gpu.Synchronize();
+ sw.Restart();
+ for (int i = 0; i < iters; i++) gpu.MatMulBatchedGemm(gpuY, gpuW, gpuX, nTok, DType.Q4_K);
+ gpu.Synchronize();
+ sw.Stop();
+ double gemmMs = sw.Elapsed.TotalMilliseconds / iters;
+ totGemmMs += gemmMs;
+ double gemmTflop = gflop / (gemmMs * 1e-3) / 1e3;
+
+ Console.WriteLine(
+ $"{what,-9} [{rows}×{cols}]·[{cols}×{nTok}] " +
+ $"MMQ {mmqMs,6:F2} ms ({mmqTops,5:F1} TOP/s, {100 * mmqTops / int8Peak,2:F0}% int8) | " +
+ $"GEMM {gemmMs,6:F2} ms ({gemmTflop,5:F1} TFLOP/s, {100 * gemmTflop / fp16Peak,2:F0}% fp16)");
+
+ gpu.Free(gpuW); gpu.Free(gpuX); gpu.Free(gpuY);
+ }
+
+ // Per-token full-trunk projection (×36 layers) to compare with the e2e prefill.
+ Console.WriteLine(
+ $"per-layer trunk: {totGflopPerLayer:F1} GFLOP | MMQ {totMmqMs:F2} ms GEMM {totGemmMs:F2} ms " +
+ $"(×36 layers @ N={nTok}: MMQ {totMmqMs * 36 / 1000:F2} s, GEMM {totGemmMs * 36 / 1000:F2} s)");
+ Assert.True(true);
+ }
+}
diff --git a/tests/SharpInference.Tests.ForwardPass/Qwen3CudaBatchedPrefillTests.cs b/tests/SharpInference.Tests.ForwardPass/Qwen3CudaBatchedPrefillTests.cs
index 0f5ccb5..83096a0 100644
--- a/tests/SharpInference.Tests.ForwardPass/Qwen3CudaBatchedPrefillTests.cs
+++ b/tests/SharpInference.Tests.ForwardPass/Qwen3CudaBatchedPrefillTests.cs
@@ -125,6 +125,72 @@ public void Qwen3_8B_BatchedPrefill_DefaultMatchesSequential()
$"Default batched top-5 overlaps the per-token reference in only {overlap}/5 slots.");
}
+ ///
+ /// Issue #162: prompts longer than the non-flash 4096 cap must still take the fast
+ /// batched-trunk path (chunked into PrefillBatchChunk windows, flash streaming
+ /// the prior KV) and stay argmax-stable vs the bit-exact per-token loop — instead of
+ /// silently dropping to the ~8× slower memory-bound per-token prefill.
+ ///
+ /// N=5040 exercises a full + partial window (4096 + 944); N=8192 exercises the
+ /// exact-multiple boundary (two full 4096 windows, the final chunk's len == chunk size)
+ /// to catch the classic off-by-one in the chunk loop.
+ ///
+ ///
+ [Theory]
+ [InlineData(5040, 6144)]
+ [InlineData(8192, 9216)]
+ public void Qwen3_8B_ChunkedBatchedPrefill_Over4096_MatchesSequential(int promptLen, int ctx)
+ {
+ using var gpu = TryCreate();
+ if (gpu is null) return;
+ var path = FindModelPath();
+ if (path is null) return;
+
+ using var model = GgufModel.Open(path);
+ var hp = ModelHyperparams.FromGgufMetadata(model.Metadata, model);
+ Assert.Null(hp.LayerHeadDim);
+
+ // > 4096 tokens → forces the chunked branch. Deterministic spread across the vocab
+ // via a small LCG; all ids well within Qwen3's 151936 vocab.
+ var longTokens = new int[promptLen];
+ uint s = 0x9E3779B9u;
+ for (int i = 0; i < longTokens.Length; i++)
+ {
+ s = s * 1664525u + 1013904223u;
+ longTokens[i] = (int)(s % 150000u) + 1;
+ }
+
+ // Disable SnapKV (a >budget prompt would otherwise route to the per-token SnapKV
+ // eviction path before the batched gate). Construct under the override, then restore.
+ var prevSnap = Environment.GetEnvironmentVariable("SHARPI_SNAPKV_BUDGET");
+ Environment.SetEnvironmentVariable("SHARPI_SNAPKV_BUDGET", "0");
+ CudaForwardPass fwd;
+ try { fwd = new CudaForwardPass(model, gpu, hp, maxContextLength: ctx); }
+ finally { Environment.SetEnvironmentVariable("SHARPI_SNAPKV_BUDGET", prevSnap); }
+ using var _fwd = fwd;
+
+ // Shipped defaults (flash TC on) → chunked batched path.
+ fwd.BatchedPrefillEnabled = true;
+ var batched = fwd.Prefill(longTokens).ToArray();
+ Assert.True(fwd.LastPrefillWasBatched,
+ "Chunked batched prefill did not engage for a >4096-token prompt (#162).");
+
+ fwd.ResetCache();
+ fwd.BatchedPrefillEnabled = false;
+ var sequential = fwd.Prefill(longTokens).ToArray();
+ Assert.False(fwd.LastPrefillWasBatched);
+
+ Assert.Equal(sequential.Length, batched.Length);
+ Assert.Equal(Argmax(sequential), Argmax(batched));
+
+ var seqTop = TopKSet(sequential, 5);
+ var batTop = TopKSet(batched, 5);
+ int overlap = 0;
+ foreach (var t in batTop) if (seqTop.Contains(t)) overlap++;
+ Assert.True(overlap >= 4,
+ $"Chunked batched top-5 overlaps the per-token reference in only {overlap}/5 slots.");
+ }
+
///
/// Flash off + GEMM off: Q4_K trunk runs the batched matvec GEMM-N, which is built to
/// be bit-identical to N per-token Q4_K dp4a matvecs. Verifies the batched