Skip to content

perf(cuda): packed multi-prompt prefill for continuous batching (#193)#272

Merged
pekkah merged 1 commit into
masterfrom
feat/193-cuda-packed-prefill
Jun 17, 2026
Merged

perf(cuda): packed multi-prompt prefill for continuous batching (#193)#272
pekkah merged 1 commit into
masterfrom
feat/193-cuda-packed-prefill

Conversation

@pekkah

@pekkah pekkah commented Jun 17, 2026

Copy link
Copy Markdown
Owner

Summary

Closes #193 (the last remaining scope item — item 3 — of #190 CUDA dense continuous batching).

CudaForwardPass.PrefillPackedMulti prefilled the S pending prompts' chunks sequentially — a loop over PrefillWithCache, each chunk a separate batched-trunk forward pass that re-read every weight. With several prompts admitted concurrently their prefill GEMMs were not amortized across prompts (unlike the CPU ForwardPass.PrefillPackedMulti, #183 Gap 2).

This packs the S chunks token-major into the batched-trunk scratch and runs one forward pass over N = Σ chunk_len, so every trunk + output GEMM amortizes its weight read across all prompts. Per-sequence RoPE / QK-norm / KV-append / varlen attention run on each chunk's slice at its absolute startPos against its own per-sequence cache (cu_seqlens-style — no cross-sequence attention, no padding), mirroring the CPU packed path and the single-sequence PrefillBatchedTrunk.

What changed

  • Extracted GpuPrefillAppendAttention from GpuLayerBatchedTrunk so the packed path drives the identical KV-append + attention dispatch per sub-sequence — guaranteeing argmax-stability with the single-sequence trunk. The extraction is byte-for-byte the prior inline dispatch.
  • New GpuLayerPackedTrunk (dense-only packed layer) + PrefillPackedTrunkMulti (driver) + AllChunksPackable gate (mirrors the single-seq perf(cuda): close remaining Qwen3-8B Q4_K DECODE gap to llama.cpp — non-matvec cost (prefill handled by #167; kernel-efficiency in #149/#152) #162 attention cap).
  • PrefillPackedMulti packs when S>=2 && IsBatchedPrefillSupported() && AllChunksPackable, else falls back to the (always-correct) sequential loop — so attn-bias / L2-QK-norm / non-NEOX models and over-cap sub-sequences degrade cleanly.
  • Dense-only assertion in PrefillPackedTrunkMulti self-enforces the contract (the real Gemma-4/softcap gate is ThrowIfBatchingUnsupported via DenseBatchedDecodeSupported, not IsBatchedPrefillSupported, which Gemma 4 satisfies).

Testing (Qwen3-8B Q4_K, RTX 4070 Ti)

New oracles in CudaBatchForwardMultiTests, all argmax-stable within the cross-path tolerance the batched-trunk oracles hold:

  • PrefillPackedMulti_N2_MatchesSequential — packed final-token logits vs per-sequence PrefillWithCache.
  • PrefillPackedMulti_Chunked_MatchesWhole — two prompts chunked together vs whole-prompt prefill (cross-chunk KV reads + the S<2 fallback when the shorter prompt finishes).
  • PrefillPackedMulti_ThenBatchedDecode_MatchesSingleUser — end-to-end packed admission → batched decode vs single-user prefill+decode.

Regression-checked the extraction: full CudaBatchForwardMultiTests + Qwen3CudaBatchedPrefillTests (18) and Gemma4CudaBatchedPrefillTests (8, covers the SWA / shared-KV / PLE / per-layer-head_dim branches of the shared helper) all pass. Full solution builds clean (TreatWarningsAsErrors).

Review

Two review passes (code-reviewer + silent-failure-hunter) found no critical issues; the consensus hardening item — making the dense-only contract self-enforcing rather than depending on a guard two call-frames up, and fixing doc comments that credited the wrong gate — is applied.

🤖 Generated with Claude Code

CudaForwardPass.PrefillPackedMulti prefilled the S pending prompts'
chunks sequentially — a loop over PrefillWithCache, each chunk a separate
batched-trunk forward pass that re-read every weight. With several prompts
admitted concurrently their prefill GEMMs were not amortized across prompts
(unlike the CPU ForwardPass.PrefillPackedMulti, #183 Gap 2).

Pack the S chunks token-major into the batched-trunk scratch and run ONE
forward pass over N = Σ chunk_len so every trunk + output GEMM amortizes its
weight read across all prompts. Per-sequence RoPE / QK-norm / KV-append /
varlen attention run on each chunk's slice at its absolute startPos against
its own per-sequence cache (cu_seqlens-style — no cross-sequence attention,
no padding), mirroring the CPU packed path and the single-sequence
PrefillBatchedTrunk.

- Extract GpuPrefillAppendAttention from GpuLayerBatchedTrunk so the packed
  path drives the IDENTICAL KV-append + attention dispatch per sub-sequence
  (argmax-stable with the single-sequence trunk; verified the extraction is
  byte-for-byte the prior inline dispatch — Qwen3 dense + Gemma4 SWA/shared-KV
  prefill oracles all pass).
- New GpuLayerPackedTrunk (dense-only packed layer) + PrefillPackedTrunkMulti
  (driver) + AllChunksPackable gate (mirrors the single-seq #162 attention cap).
- PrefillPackedMulti packs when S>=2 && IsBatchedPrefillSupported() &&
  AllChunksPackable, else falls back to the (always-correct) sequential loop.
- Dense-only assertion in PrefillPackedTrunkMulti self-enforces the contract
  (ThrowIfBatchingUnsupported is the real Gemma-4/softcap gate, not
  IsBatchedPrefillSupported which Gemma 4 satisfies).

Tests (Qwen3-8B Q4_K, 4070 Ti): packed-vs-sequential final-token logits,
chunked packed prefill vs whole-prompt (cross-chunk KV + S<2 fallback), and
packed prefill -> batched decode vs single-user prefill+decode. All
argmax-stable within the cross-path tolerance the batched-trunk oracles hold.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@gemini-code-assist

Copy link
Copy Markdown

Warning

Gemini encountered an error creating the review. You can try again by commenting /gemini review.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

feat(cuda): packed multi-prompt prefill for continuous batching — PrefillPackedMulti is sequential (#190 follow-up)

1 participant