perf(cuda): chunked batched prefill for >4096-token prompts (#162) #163
Prompts longer than 4096 tokens silently fell back to the per-token prefill loop (memory-bound, weight re-streamed per token: ~50 t/s on Qwen3-8B Q4_K @ 4070 Ti) because the batched-trunk gate hard-capped at startPos+N <= 4096. That 4096 is a limit of the *non-flash* shared-scores AttentionBatched kernel, which throws above 4096, but the flash prefill kernels stream KV and have no such cap.

When flash attention is enabled (default) and the model has no SWA layers, run the batched-trunk path for prompts of any length up to the context size, chunking into 4096-token windows so the N-sized trunk scratch stays bounded. Each window is batched at its own startPos; flash attends to all prior KV. Greedy output is argmax-identical to the per-token reference.

Qwen3-8B Q4_K, 4711-token prompt: 49.9 -> 415.5 t/s prefill (8.3x).

SWA models (Gemma 4) keep the proven 4096 cap: their window-ring KV semantics across chunk boundaries are not yet validated (follow-up #162).

Investigation context (no decode change, see PR/#162):
- decode is BW-bound on already-SoA matvecs (~89% of forward time); the greedy sampler is a trivial argmax, not a Gemma-style full-vocab artifact.
- prefill trunk matmuls dominate (95.5%); C1/C2 compute-bound paths already shipped. The residual gap is a cold-L2 ~5x kernel-efficiency penalty (#149/#152), reproduced by the new CudaQ4KPrefillMatmulProbe.

Tests: Qwen3_8B_ChunkedBatchedPrefill_Over4096_MatchesSequential (argmax + top-5 parity vs per-token); CudaQ4KPrefillMatmulProbe (real-shape roofline). All 25 batched-prefill tests green.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
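The windowing described in this commit message can be sketched as follows. This is a minimal illustration, not the code in CudaForwardPass.cs: the `ChunkPlanner` class and its `Plan` method are hypothetical names, and the 4096 constant is assumed to match `PrefillBatchChunk`.

```csharp
using System;
using System.Collections.Generic;

public static class ChunkPlanner
{
    // Assumed window size, mirroring the 4096-token cap described above.
    public const int PrefillBatchChunk = 4096;

    // Splits an n-token prompt that begins at KV position startPos into
    // windows of at most PrefillBatchChunk tokens. Each window carries its
    // offset into the prompt, its length, and its own start position.
    public static List<(int Offset, int Length, int StartPos)> Plan(int n, int startPos)
    {
        if (n < 0) throw new ArgumentOutOfRangeException(nameof(n));
        var windows = new List<(int Offset, int Length, int StartPos)>();
        for (int off = 0; off < n; off += PrefillBatchChunk)
        {
            // The last window may be shorter than a full chunk.
            int len = Math.Min(PrefillBatchChunk, n - off);
            windows.Add((off, len, startPos + off));
        }
        return windows;
    }
}
```

For the 4711-token prompt quoted above, `ChunkPlanner.Plan(4711, 0)` yields two windows: a full one of 4096 tokens at position 0 and a partial one of 615 tokens at position 4096.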
Code Review
This pull request introduces chunked batched prefill support for prompts longer than 4096 tokens when flash attention is enabled, preventing fallback to the slower per-token loop. It also adds a performance probe for Q4_K trunk-matmuls and a corresponding test for chunked prefill. Feedback points out a performance bottleneck in CudaForwardPass.cs where passing ArraySegment to PrefillBatchedTrunk triggers unintentional heap allocations due to a failed type cast, and suggests using pattern matching with ReadOnlySpan to maintain a zero-allocation path.
int[] all = tokens as int[] ?? System.Linq.Enumerable.ToArray(tokens);
ReadOnlySpan<float> chunkLogits = default;
for (int off = 0; off < N; off += PrefillBatchChunk)
{
    int len = Math.Min(PrefillBatchChunk, N - off);
    chunkLogits = PrefillBatchedTrunk(new ArraySegment<int>(all, off, len), startPos + off);
Performance Bottleneck: Unintentional Heap Allocations During Chunked Prefill
When chunking is active, Prefill passes an ArraySegment<int> to PrefillBatchedTrunk:
chunkLogits = PrefillBatchedTrunk(new ArraySegment<int>(all, off, len), startPos + off);
Inside PrefillBatchedTrunk (around line 2071), the code attempts to cast tokens to int[]:
int[] ids = tokens as int[] ?? System.Linq.Enumerable.ToArray(tokens);
Since ArraySegment<int> is not int[], this cast evaluates to null, forcing a fallback to System.Linq.Enumerable.ToArray(tokens). This allocates a new int[] array of size up to 4096 on the heap for every single chunk during prefill, causing significant GC pressure and defeating the zero-allocation path for int[] inputs.
Recommended Solution
To eliminate these allocations, update PrefillBatchedTrunk to retrieve a ReadOnlySpan<int> from IReadOnlyList<int> using pattern matching, and then use MemoryMarshal.AsBytes on that span:
// Inside PrefillBatchedTrunk:
if (_embIsQuantized && embDType == DType.Q8_0)
{
    // View the token ids as a span without copying where the concrete type allows it.
    ReadOnlySpan<int> ids = tokens switch
    {
        int[] arr => arr,
        ArraySegment<int> seg => seg,
        List<int> list => System.Runtime.InteropServices.CollectionsMarshal.AsSpan(list),
        _ => System.Linq.Enumerable.ToArray(tokens) // fallback: copies
    };
    var idTensor = _gpu.UploadRaw(
        System.Runtime.InteropServices.MemoryMarshal.AsBytes(ids),
        TensorShape.D1(N), DType.Float32);
    _gpu.EmbedLookupQ8_0Batched(_gpuEmbedding, _bpHidden!, idTensor, N, embDim);
    _gpu.Free(idTensor);
}
This avoids the per-chunk int[] copy for int[], ArraySegment<int>, and List<int> inputs.
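The cast behaviour the review describes can be reproduced in isolation. The sketch below is an assumption-laden illustration: `TokenSpans`, `AsSpan`, and `CastToArrayFails` are hypothetical helpers, not members of the project, and the parameter type `IReadOnlyList<int>` is taken from the review text.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;

public static class TokenSpans
{
    // True when the `as int[]` cast used in the original code yields null,
    // which is what sends ArraySegment<int> inputs down the ToArray() path.
    public static bool CastToArrayFails(IReadOnlyList<int> tokens)
        => (tokens as int[]) is null;

    // Returns a span over the caller's storage when the concrete type allows
    // it, and only copies for other IReadOnlyList<int> implementations.
    public static ReadOnlySpan<int> AsSpan(IReadOnlyList<int> tokens)
    {
        if (tokens is int[] arr) return arr;                 // no copy
        if (tokens is ArraySegment<int> seg) return seg;     // no copy, views the backing array
        if (tokens is List<int> list) return CollectionsMarshal.AsSpan(list); // no copy
        return tokens.ToArray();                             // fallback: allocates a copy
    }
}
```

Note that passing an `ArraySegment<int>` through an interface-typed parameter still boxes the struct, so a small fixed-size allocation per call remains; what the span path removes is the token-array copy that grows with the chunk length.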
Addresses PR #163 review:
- silent-failure-hunter / pr-test-analyzer: `_hp.IsSwaLayer is null` means "not Gemma 4", not "no SWA". Add `_hp.SlidingWindowSize <= 0` so the chunked-prefill guard fails closed if SWA parsing is later extended to a uniform-window arch that sets the window without a per-layer pattern array.
- pr-test-analyzer: parameterize the chunked-prefill parity test to also cover the exact 4096-multiple boundary (N=8192, two full windows) alongside the partial-window case (N=5040), catching the classic chunk-loop off-by-one.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
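The fail-closed guard described in this commit might look roughly like the predicate below. It is a sketch only: `ChunkedPrefillGate` and `CanChunk` are hypothetical names, and the parameters stand in for the flash-attention setting, `_hp.IsSwaLayer`, and `_hp.SlidingWindowSize`; the real condition in the code may include further terms.

```csharp
#nullable enable

public static class ChunkedPrefillGate
{
    // Chunked prefill is allowed only when flash attention is on AND there is
    // no sign of sliding-window attention from either source:
    //   - isSwaLayer is null  -> no per-layer SWA pattern array was parsed
    //   - slidingWindowSize <= 0 -> no uniform window was configured
    // Requiring both means a model that sets a window without a per-layer
    // pattern is rejected (fails closed) rather than silently chunked.
    public static bool CanChunk(bool flashEnabled, bool[]? isSwaLayer, int slidingWindowSize)
        => flashEnabled && isSwaLayer is null && slidingWindowSize <= 0;
}
```

With the first check alone, a hypothetical uniform-window model (`isSwaLayer == null`, `slidingWindowSize == 1024`) would have passed the gate; with both checks it keeps the 4096 cap.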
Addressed the review (commit 4fe26c4):
@gemini-code-assist re the |
Thanks for the detailed update and for verifying the chunked parity with the N=8192 case. The logic for the SWA guard and the handling of the |
Summary
This PR started as the #159/#162 profiling decision gates and produced one concrete shippable win plus a clear redirect of the remaining work.
Shipped: chunked batched prefill for long prompts (#162)
Prompts longer than 4096 tokens silently fell back to the per-token prefill loop (memory-bound, weight re-streamed per token). The batched-trunk gate hard-capped at `startPos + N <= 4096`, but that 4096 is a limit of the non-flash shared-scores `AttentionBatched` kernel (it throws above 4096). The flash prefill kernels (`llm_flash_attn_prefill_tc2/tc/f32`) stream KV in 16-position tiles with online softmax and register-resident O, so they have no 4096 cap.

When flash is enabled (default) and the model has no SWA layers, we now run the batched-trunk path for prompts of any length up to the context size, chunking into 4096-token windows so the N-sized trunk scratch stays bounded. Each window is batched at its own `startPos`; flash attends to all prior KV.

Greedy `--temp 0` output is argmax-identical to the per-token reference. SWA models (Gemma 4) keep the proven 4096 cap: their window-ring KV semantics across chunk boundaries aren't yet validated (follow-up tracked in #162).

Investigation findings (redirect, no code change)
Decode (#162) — Profiled Qwen3-8B: matvecs are ~89% of forward time (ffn 58%, qkv 14%, lm_head 9%, o-proj 8%); non-matvec work (rope/qknorm/kv/attn) is only ~11%. The greedy sampler is a trivial O(vocab) argmax, not a Gemma-#142-style full-vocab artifact. Decode is genuinely BW-bound on already-SoA-optimized matvecs → no deletable hotspot. The only decode lever left is matvec kernel efficiency, which the issue itself rates low-yield. Decode item: no easy win.
Prefill (#159) — Decision gate confirmed: trunk matmuls dominate (95.5% at N=1844; attention 2.9%). But C1 (dequant→cuBLAS GEMM) and C2 (int8 MMQ) are already shipped on master (default-on), giving 2.8× over the memory-bound GEMM-N (153→433 t/s). So #159's stated deliverable is done.
The residual prefill gap to llama.cpp (433 vs 5764 t/s) is a cold-L2 ~5× kernel-efficiency penalty, reproduced precisely by the new `CudaQ4KPrefillMatmulProbe`: the isolated MMQ/GEMM kernels do the whole trunk in ~0.5s (looping one warm weight) but the real prefill, cycling the full 5GB model working set, spends ~4s in the same kernels at 100% SM / low DRAM util. That's exactly the MMQ tiling-rewrite territory of #149/#152 (large, risky; split-K and cp.async already ruled out there).

Tests
- `Qwen3_8B_ChunkedBatchedPrefill_Over4096_MatchesSequential`: 5040-token prompt forces the two-window chunked branch; asserts argmax equality + top-5 overlap vs the bit-exact per-token loop.
- `CudaQ4KPrefillMatmulProbe`: real-shape (nTok=1844) roofline probe for the Q4_K MMQ and cuBLAS-GEMM trunk paths.

Closes #159. Refs #162 (decode no-win + cold-L2 prefill lever → #149/#152; Gemma SWA chunking follow-up).
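The parity assertion named under Tests (argmax equality plus top-5 overlap) can be illustrated with a small helper. This is a sketch under assumptions: `LogitParity` and its methods are hypothetical and not taken from the test project, and the overlap threshold the real test applies is not stated in this PR, so none is encoded here.

```csharp
using System;
using System.Linq;

public static class LogitParity
{
    // Index of the largest logit; ties resolve to the lowest index.
    public static int ArgMax(ReadOnlySpan<float> logits)
    {
        if (logits.IsEmpty) throw new ArgumentException("empty logits");
        int best = 0;
        for (int i = 1; i < logits.Length; i++)
            if (logits[i] > logits[best]) best = i;
        return best;
    }

    // Indices of the k largest logits, highest first. OrderByDescending is a
    // stable sort, so equal values keep their original index order.
    public static int[] TopK(float[] logits, int k) =>
        Enumerable.Range(0, logits.Length)
                  .OrderByDescending(i => logits[i])
                  .Take(k)
                  .ToArray();

    // Number of token ids that appear in both top-k sets.
    public static int TopKOverlap(float[] a, float[] b, int k) =>
        TopK(a, k).Intersect(TopK(b, k)).Count();
}
```

A chunked-versus-sequential comparison would then check `ArgMax(chunked) == ArgMax(sequential)` and inspect `TopKOverlap(chunked, sequential, 5)` on the final-position logits.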
🤖 Generated with Claude Code