Skip to content

perf(cuda): chunked batched prefill for >4096-token prompts (#162)#163

Merged
pekkah merged 2 commits into
masterfrom
perf/cuda-q4k-decode-prefill-162-159
Jun 7, 2026
Merged

perf(cuda): chunked batched prefill for >4096-token prompts (#162)#163
pekkah merged 2 commits into
masterfrom
perf/cuda-q4k-decode-prefill-162-159

Conversation

@pekkah

@pekkah pekkah commented Jun 7, 2026

Copy link
Copy Markdown
Owner

Summary

This PR started as the #159/#162 profiling decision gates and produced one concrete shippable win plus a clear redirect of the remaining work.

Shipped: chunked batched prefill for long prompts (#162)

Prompts longer than 4096 tokens silently fell back to the per-token prefill loop — memory-bound, weight re-streamed per token. The batched-trunk gate hard-capped at startPos + N <= 4096, but that 4096 is a limit of the non-flash shared-scores AttentionBatched kernel (it throws above 4096). The flash prefill kernels (llm_flash_attn_prefill_tc2/tc/f32) stream KV in 16-position tiles with online softmax and register-resident O — no 4096 cap.

When flash is enabled (default) and the model has no SWA layers, we now run the batched-trunk path for prompts of any length up to the context size, chunking into 4096-token windows so the N-sized trunk scratch stays bounded. Each window is batched at its own startPos; flash attends to all prior KV.

Qwen3-8B Q4_K, 4711-token prompt prefill t/s
before (per-token fallback) 49.9
after (chunked batched) 415.5 (8.3×)

Greedy --temp 0 output is argmax-identical to the per-token reference. SWA models (Gemma 4) keep the proven 4096 cap — their window-ring KV semantics across chunk boundaries aren't yet validated (follow-up tracked in #162).

Investigation findings (redirect, no code change)

Decode (#162) — Profiled Qwen3-8B: matvecs are ~89% of forward time (ffn 58%, qkv 14%, lm_head 9%, o-proj 8%); non-matvec work (rope/qknorm/kv/attn) is only ~11%. The greedy sampler is a trivial O(vocab) argmax, not a Gemma-#142-style full-vocab artifact. Decode is genuinely BW-bound on already-SoA-optimized matvecs → no deletable hotspot. The only decode lever left is matvec kernel efficiency, which the issue itself rates low-yield. Decode item: no easy win.

Prefill (#159) — Decision gate confirmed: trunk matmuls dominate (95.5% at N=1844; attention 2.9%). But C1 (dequant→cuBLAS GEMM) and C2 (int8 MMQ) are already shipped on master (default-on), giving 2.8× over the memory-bound GEMM-N (153→433 t/s). So #159's stated deliverable is done.

The residual prefill gap to llama.cpp (433 vs 5764 t/s) is a cold-L2 ~5× kernel-efficiency penalty, reproduced precisely by the new CudaQ4KPrefillMatmulProbe: the isolated MMQ/GEMM kernels do the whole trunk in ~0.5s (looping one warm weight) but the real prefill — cycling the full 5GB model working set — spends ~4s in the same kernels at 100% SM / low DRAM util. That's exactly the MMQ tiling-rewrite territory of #149/#152 (large, risky; split-K and cp.async already ruled out there).

Tests

  • Qwen3_8B_ChunkedBatchedPrefill_Over4096_MatchesSequential — 5040-token prompt forces the two-window chunked branch; asserts argmax equality + top-5 overlap vs the bit-exact per-token loop.
  • CudaQ4KPrefillMatmulProbe — real-shape (nTok=1844) roofline probe for the Q4_K MMQ and cuBLAS-GEMM trunk paths.
  • All 25 batched-prefill tests green (Qwen3 ×5, Gemma4, both Hybrid suites) — Gemma4's ≤4096 path confirmed intact by the SWA guard.

Closes #159. Refs #162 (decode no-win + cold-L2 prefill lever → #149/#152; Gemma SWA chunking follow-up).

🤖 Generated with Claude Code

Prompts longer than 4096 tokens silently fell back to the per-token
prefill loop (memory-bound, weight re-streamed per token: ~50 t/s on
Qwen3-8B Q4_K @ 4070 Ti) because the batched-trunk gate hard-capped at
startPos+N <= 4096. That 4096 is a limit of the *non-flash* shared-scores
AttentionBatched kernel, which throws above 4096 — but the flash prefill
kernels stream KV and have no such cap.

When flash attention is enabled (default) and the model has no SWA layers,
run the batched-trunk path for prompts of any length up to the context
size, chunking into 4096-token windows so the N-sized trunk scratch stays
bounded. Each window is batched at its own startPos; flash attends to all
prior KV. Greedy output is argmax-identical to the per-token reference.

  Qwen3-8B Q4_K, 4711-token prompt: 49.9 -> 415.5 t/s prefill (8.3x).

SWA models (Gemma 4) keep the proven 4096 cap — their window-ring KV
semantics across chunk boundaries are not yet validated (follow-up #162).

Investigation context (no decode change — see PR/#162):
- decode is BW-bound on already-SoA matvecs (~89% of forward time); the
  greedy sampler is a trivial argmax, not a Gemma-style full-vocab artifact.
- prefill trunk matmuls dominate (95.5%); C1/C2 compute-bound paths already
  shipped. The residual gap is a cold-L2 ~5x kernel-efficiency penalty
  (#149/#152), reproduced by the new CudaQ4KPrefillMatmulProbe.

Tests: Qwen3_8B_ChunkedBatchedPrefill_Over4096_MatchesSequential (argmax +
top-5 parity vs per-token); CudaQ4KPrefillMatmulProbe (real-shape roofline).
All 25 batched-prefill tests green.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces chunked batched prefill support for prompts longer than 4096 tokens when flash attention is enabled, preventing fallback to the slower per-token loop. It also adds a performance probe for Q4_K trunk-matmuls and a corresponding test for chunked prefill. Feedback points out a performance bottleneck in CudaForwardPass.cs where passing ArraySegment to PrefillBatchedTrunk triggers unintentional heap allocations due to a failed type cast, and suggests using pattern matching with ReadOnlySpan to maintain a zero-allocation path.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +1869 to +1874
int[] all = tokens as int[] ?? System.Linq.Enumerable.ToArray(tokens);
ReadOnlySpan<float> chunkLogits = default;
for (int off = 0; off < N; off += PrefillBatchChunk)
{
int len = Math.Min(PrefillBatchChunk, N - off);
chunkLogits = PrefillBatchedTrunk(new ArraySegment<int>(all, off, len), startPos + off);

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Performance Bottleneck: Unintentional Heap Allocations During Chunked Prefill

When chunking is active, Prefill passes an ArraySegment<int> to PrefillBatchedTrunk:

chunkLogits = PrefillBatchedTrunk(new ArraySegment<int>(all, off, len), startPos + off);

Inside PrefillBatchedTrunk (around line 2071), the code attempts to cast tokens to int[]:

int[] ids = tokens as int[] ?? System.Linq.Enumerable.ToArray(tokens);

Since ArraySegment<int> is not int[], this cast evaluates to null, forcing a fallback to System.Linq.Enumerable.ToArray(tokens). This allocates a new int[] array of size up to 4096 on the heap for every single chunk during prefill, causing significant GC pressure and defeating the zero-allocation path for int[] inputs.

Recommended Solution

To eliminate these allocations, update PrefillBatchedTrunk to retrieve a ReadOnlySpan<int> from IReadOnlyList<int> using pattern matching, and then use MemoryMarshal.AsBytes on that span:

// Inside PrefillBatchedTrunk:
if (_embIsQuantized && embDType == DType.Q8_0)
{
    ReadOnlySpan<int> ids = tokens switch
    {
        int[] arr => arr,
        ArraySegment<int> seg => seg,
        List<int> list => CollectionsMarshal.AsSpan(list),
        _ => System.Linq.Enumerable.ToArray(tokens)
    };
    var idTensor = _gpu.UploadRaw(
        System.Runtime.InteropServices.MemoryMarshal.AsBytes(ids),
        TensorShape.D1(N), DType.Float32);
    _gpu.EmbedLookupQ8_0Batched(_gpuEmbedding, _bpHidden!, idTensor, N, embDim);
    _gpu.Free(idTensor);
}

This completely avoids heap allocations for int[], ArraySegment<int>, and List<int> inputs.

Addresses PR #163 review:
- silent-failure-hunter / pr-test-analyzer: `_hp.IsSwaLayer is null` means
  "not Gemma 4", not "no SWA". Add `_hp.SlidingWindowSize <= 0` so the
  chunked-prefill guard fails closed if SWA parsing is later extended to a
  uniform-window arch that sets the window without a per-layer pattern array.
- pr-test-analyzer: parameterize the chunked-prefill parity test to also cover
  the exact 4096-multiple boundary (N=8192, two full windows) alongside the
  partial-window case (N=5040), catching the classic chunk-loop off-by-one.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@pekkah

pekkah commented Jun 7, 2026

Copy link
Copy Markdown
Owner Author

Addressed the review (commit 4fe26c4):

  • SWA guard fail-closed (silent-failure-hunter + pr-test-analyzer): hardened to _hp.IsSwaLayer is null && _hp.SlidingWindowSize <= 0. Today IsSwaLayer is only populated for Gemma 4, so the explicit window check fails closed if SWA parsing is later extended to a uniform-window arch that sets the window without a per-layer pattern array.
  • Exact-boundary test (pr-test-analyzer): the chunked parity test is now a [Theory] covering both N=5040 (full+partial window) and N=8192 (exact 4096-multiple, two full windows) to catch chunk-loop off-by-one. Both pass.

@gemini-code-assist re the ArraySegmentToArray allocation: that tokens as int[] ?? ToArray path is inside the Q8_0-embed branch of PrefillBatchedTrunk, which is Gemma-4-only — and Gemma 4 (SWA) is excluded from the chunked path by the guard above. Dense non-SWA models (the only ones that chunk) use the non-Q8_0 embed loop that indexes the ArraySegment directly, so no per-chunk heap allocation occurs on the chunked path. Prefill setup is also not the per-token hot path. Left as-is.

@gemini-code-assist

Copy link
Copy Markdown

Thanks for the detailed update and for verifying the chunked parity with the N=8192 case. The logic for the SWA guard and the handling of the ArraySegment allocation in the PrefillBatchedTrunk path seems solid given the constraints you've outlined. The approach of keeping the SWA models at the 4096 cap while enabling chunking for dense models is a prudent way to maintain correctness until the windowed KV semantics are fully validated. The added tests look comprehensive for catching potential off-by-one errors in the chunking loop.

@pekkah pekkah merged commit bc3cd2c into master Jun 7, 2026
1 check passed
@pekkah pekkah deleted the perf/cuda-q4k-decode-prefill-162-159 branch June 7, 2026 11:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

perf(cuda): compute-bound Q4_K prefill matmul — dequant->cuBLAS or Q4_K MMQ (#156 Item C)

1 participant