perf(mtp): 4-input CPU MatVec4In for k>2 batched verify (#209) #287
Work item #1 of #209: amortize the dominant CPU mmap dense-FFN weight read across four MTP draft tokens. The 27B-MTP CUDA-hybrid decode cost center is the 46/64 CPU-resident FFN layers (~6 GB/token); the prior MatVec2In only pair-amortized that read, so the verify optimum sat at k=2 (10.1 t/s) and deeper chains lost to the linear-in-k re-stream.

- SimdKernels.MatVec4In + register-tiled DotQ5K_4In (AVX-512) / DotQ6K_Q8K_4In (AVX2); Q4_K reuses the existing #114 DotQ4K_4In, F32 via 4x DotF32. Each decodes one weight block once and FMAs four input columns, bit-identical per slot to a single MatVec (per-position bits are k-parity-independent). Q8_0 deliberately routes to the dequant fallback to stay bit-identical to single-token dense-FFN decode (MatVec/MatVecDual), which never specialized it.
- CpuDenseFfn4 / DenseFfn4 + a 4-wide lm_head replace the 2-wide loops in both GDN passes' BatchVerify. The three hand-copied duplicated-input-tail idioms collapse into one shared MtpBatchTail.Group4 helper (clamps past-the-end lanes to the last real token, routes their output to a sink).
- Default verify batch moves k=2 -> k=4 (ResolveMtpBatchMax 2->4, ResolveDraftN 1->3) now that the 4-input kernel makes k=4 the measured optimum. The GDN ring alloc stops on OOM and clamps MaxBatchVerifyTokens, so tight-VRAM cards degrade gracefully.
- Amputate the dead public IForwardPass.LastHiddenT1 accessor (no consumers; MtpDecoder drives BatchVerify/HiddenAt). The CPU pass's backing buffer is removed; the CUDA pass keeps _lastHiddenT1/_gpuLastHiddenT1 as internal SHARPI_CPU_GDN=1 debug-trunk scratch.

Bench (27B Q4_K_M CUDA-hybrid, RTX 4070 Ti, -g -1 --no-thinking): decode MTP-off 6.5, k=2 10.1, k=4 12.3, k=6 10.4, k=8 10.2 t/s. New default 12.3 t/s (84% accept) = 1.9x over MTP-off, +22% over the old k=2 default.
Tests: MatVec4In_BitwiseMatchesSingleMatVec (Q4_K/Q5_K/Q6_K/Q8_0/F32, serial + Parallel.For) and MtpBatchTail lane-mapping (k=1..9) are bitwise oracles; the CUDA k=4 batched-verify suite stays green. MtpDecoder_GreedyParity_LlamaCpp is untouched and unaffected (both 2In and 4In are per-token bit-identical to single MatVec, so the CPU pass emits identical per-token logits).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
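The amortization idea described in this commit message can be illustrated with a minimal scalar sketch. This is not the PR's kernel: it assumes plain float weights instead of quantized blocks, uses scalar loops instead of AVX2/AVX-512 register tiling, and the class name MatVecSketch and both method signatures are hypothetical. It only shows why one pass over the weights can serve four input columns while each output slot still sees the same multiply-add sequence as the single-input loop.

```csharp
using System;

// Hypothetical stand-in for the PR's SIMD kernels: float weights, scalar loops.
static class MatVecSketch
{
    // Baseline: one input column, one full pass over the weight matrix.
    public static void MatVec(float[] weights, int rows, int cols,
                              float[] input, float[] output)
    {
        for (int r = 0; r < rows; r++)
        {
            float sum = 0f;
            for (int c = 0; c < cols; c++)
                sum += weights[r * cols + c] * input[c];
            output[r] = sum;
        }
    }

    // Four input columns share a single pass over the weights: each weight is
    // loaded once and feeds four independent accumulators.
    public static void MatVec4In(float[] weights, int rows, int cols,
                                 float[] in0, float[] in1, float[] in2, float[] in3,
                                 float[] out0, float[] out1, float[] out2, float[] out3)
    {
        for (int r = 0; r < rows; r++)
        {
            float s0 = 0f, s1 = 0f, s2 = 0f, s3 = 0f;
            for (int c = 0; c < cols; c++)
            {
                // In a quantized kernel this is where a block would be decoded once.
                float w = weights[r * cols + c];
                s0 += w * in0[c];
                s1 += w * in1[c];
                s2 += w * in2[c];
                s3 += w * in3[c];
            }
            out0[r] = s0;
            out1[r] = s1;
            out2[r] = s2;
            out3[r] = s3;
        }
    }
}
```

Each accumulator in the sketch runs the same operations in the same order as the single-input loop, which is the property the bitwise-oracle test in the PR checks for the real kernels. Whether that holds bit for bit in general depends on the runtime evaluating float arithmetic in single precision without fusing operations differently between the two paths.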
/gemini review
Code Review
This pull request implements a 4-input fused matrix-vector multiplication kernel (MatVec4In) to optimize the CPU mmap dense FFN evaluation during multi-token prediction (MTP) batched verification (Issue #209). It introduces register-tiled quad-batch dot product implementations for Q4_K, Q5_K, and Q6_K quantizations, as well as Float32 and fallbacks. The default MTP batch size is increased from 2 to 4 (SHARPI_MTP_BATCH_MAX defaults to 4, and draft-chain length defaults to 3), yielding a performance improvement on CUDA-hybrid backends. Additionally, the PR centralizes the lane-to-token mapping logic in a new MtpBatchTail helper class, cleans up the unused LastHiddenT1 property, and adds comprehensive unit tests to verify the correctness and bit-identity of the new 4-input SIMD kernels. I have no feedback to provide as there are no review comments.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
Code Review
This pull request introduces a width-4 batched-verify (MTP) implementation, upgrading the previous pairwise (k=2) batching to quad-batching (k=4) to improve performance. It adds a new MatVec4In fused mat-vec kernel in SimdKernels.cs for various data types (Q4_K, Q5_K, Q6_K, and Float32) and implements corresponding 4-input dot products. The engine's forward passes (CudaHybridGdnForwardPass and HybridGdnForwardPass) are updated to support the new 4-token dense FFN and lm_head operations, utilizing a new centralized MtpBatchTail helper for lane-to-token mapping. The default draft-chain length is increased to 3 (k=4 verify batch), and corresponding unit tests are added. Feedback on the CPU kernels suggests optimizing fallback and sequential paths by checking for duplicate input pointers (which occur in the partial-group tail) to avoid redundant allocations, quantization, and dot-product computations.
byte* sc0 = stackalloc byte[scratchBytes];
byte* sc1 = stackalloc byte[scratchBytes];
byte* sc2 = stackalloc byte[scratchBytes];
byte* sc3 = stackalloc byte[scratchBytes];
QuantizeRowToQ8K(input0, cols, sc0);
QuantizeRowToQ8K(input1, cols, sc1);
QuantizeRowToQ8K(input2, cols, sc2);
QuantizeRowToQ8K(input3, cols, sc3);
When the batch size k is not a multiple of 4, the duplicate-input tail clamps empty lanes to the last real token, resulting in duplicate input pointers (e.g., input1 == input2). We can avoid redundant QuantizeRowToQ8K calls by checking for duplicate input pointers and reusing the already-quantized scratch buffers. The four scratch buffers are still declared up front, because C# only permits a pointer-typed stackalloc in a local variable declaration initializer.
byte* sc0 = stackalloc byte[scratchBytes];
byte* sc1 = stackalloc byte[scratchBytes];
byte* sc2 = stackalloc byte[scratchBytes];
byte* sc3 = stackalloc byte[scratchBytes];
QuantizeRowToQ8K(input0, cols, sc0);
// Alias each duplicate lane to the buffer already quantized for the same input.
if (input1 == input0) sc1 = sc0;
else QuantizeRowToQ8K(input1, cols, sc1);
if (input2 == input1) sc2 = sc1;
else if (input2 == input0) sc2 = sc0;
else QuantizeRowToQ8K(input2, cols, sc2);
if (input3 == input2) sc3 = sc2;
else if (input3 == input1) sc3 = sc1;
else if (input3 == input0) sc3 = sc0;
else QuantizeRowToQ8K(input3, cols, sc3);

sum0 = DotQ5K(row, input0, cols);
sum1 = DotQ5K(row, input1, cols);
sum2 = DotQ5K(row, input2, cols);
sum3 = DotQ5K(row, input3, cols);
On non-AVX-512 systems, DotQ5K_4In falls back to calling DotQ5K sequentially. We can avoid redundant computations on duplicate inputs by checking if the input pointers are equal and reusing the computed dot products.
sum0 = DotQ5K(row, input0, cols);
sum1 = input1 == input0 ? sum0 : DotQ5K(row, input1, cols);
sum2 = input2 == input1 ? sum1 : (input2 == input0 ? sum0 : DotQ5K(row, input2, cols));
sum3 = input3 == input2 ? sum2 : (input3 == input1 ? sum1 : (input3 == input0 ? sum0 : DotQ5K(row, input3, cols)));

sum0 = DotQ6K_Q8K_Scalar(row, scratch0, numBlocks);
sum1 = DotQ6K_Q8K_Scalar(row, scratch1, numBlocks);
sum2 = DotQ6K_Q8K_Scalar(row, scratch2, numBlocks);
sum3 = DotQ6K_Q8K_Scalar(row, scratch3, numBlocks);
On non-AVX2/FMA systems, DotQ6K_Q8K_4In falls back to calling DotQ6K_Q8K_Scalar sequentially. We can avoid redundant computations on duplicate inputs by checking if the scratch pointers are equal and reusing the computed dot products.
sum0 = DotQ6K_Q8K_Scalar(row, scratch0, numBlocks);
sum1 = scratch1 == scratch0 ? sum0 : DotQ6K_Q8K_Scalar(row, scratch1, numBlocks);
sum2 = scratch2 == scratch1 ? sum1 : (scratch2 == scratch0 ? sum0 : DotQ6K_Q8K_Scalar(row, scratch2, numBlocks));
sum3 = scratch3 == scratch2 ? sum2 : (scratch3 == scratch1 ? sum1 : (scratch3 == scratch0 ? sum0 : DotQ6K_Q8K_Scalar(row, scratch3, numBlocks)));

Parallel.For(0, rows, s_parallelOpts, r =>
{
    float* row = m + (long)r * c;
    o0[r] = DotF32(row, i0, c);
    o1[r] = DotF32(row, i1, c);
    o2[r] = DotF32(row, i2, c);
    o3[r] = DotF32(row, i3, c);
});
For Float32 inputs, we can avoid redundant DotF32 calls on duplicate inputs in the parallel path by checking if the input pointers are equal.
Parallel.For(0, rows, s_parallelOpts, r =>
{
    float* row = m + (long)r * c;
    o0[r] = DotF32(row, i0, c);
    o1[r] = i1 == i0 ? o0[r] : DotF32(row, i1, c);
    o2[r] = i2 == i1 ? o1[r] : (i2 == i0 ? o0[r] : DotF32(row, i2, c));
    o3[r] = i3 == i2 ? o2[r] : (i3 == i1 ? o1[r] : (i3 == i0 ? o0[r] : DotF32(row, i3, c)));
});

for (int r = 0; r < rows; r++)
{
    float* row = m + (long)r * cols;
    output0[r] = DotF32(row, input0, cols);
    output1[r] = DotF32(row, input1, cols);
    output2[r] = DotF32(row, input2, cols);
    output3[r] = DotF32(row, input3, cols);
}
For Float32 inputs, we can avoid redundant DotF32 calls on duplicate inputs in the sequential path by checking if the input pointers are equal.
for (int r = 0; r < rows; r++)
{
    float* row = m + (long)r * cols;
    output0[r] = DotF32(row, input0, cols);
    output1[r] = input1 == input0 ? output0[r] : DotF32(row, input1, cols);
    output2[r] = input2 == input1 ? output1[r] : (input2 == input0 ? output0[r] : DotF32(row, input2, cols));
    output3[r] = input3 == input2 ? output2[r] : (input3 == input1 ? output1[r] : (input3 == input0 ? output0[r] : DotF32(row, input3, cols)));
}
Summary
Implements work item #1 of #209 (the scope chosen for this PR: the highest-leverage lever + the directly-coupled cleanups): a 4-input batched CPU mat-vec that amortizes the dominant CPU mmap dense-FFN weight read across four MTP draft tokens, plus the dead-code amputation and helper dedup the #208 review flagged. Items #2–#6 remain as follow-ups.
The 27B-MTP CUDA-hybrid decode cost center is the CPU-resident FFN layers (~6 GB/token). The prior MatVec2In only pair-amortized that weight read, pinning the verify optimum at k=2; deeper chains lost to the per-step re-stream. MatVec4In reads each weight row once per four tokens, moving the optimum to k=4.
Changes
- SimdKernels.MatVec4In + register-tiled DotQ5K_4In (AVX-512) / DotQ6K_Q8K_4In (AVX2). Q4_K reuses the existing #114 DotQ4K_4In; F32 via 4× DotF32. Bit-identical per output slot to a single MatVec (so per-position bits are k-parity-independent). Q8_0 deliberately uses the dequant fallback to stay bit-identical to single-token dense-FFN decode (MatVec/MatVecDual never specialized it).
- CpuDenseFfn4 / DenseFfn4 + 4-wide lm_head replace the 2-wide loops in both GDN passes' BatchVerify. The three hand-copied duplicated-input-tail idioms collapse into one shared MtpBatchTail.Group4 helper.
- Default verify batch moves k=2 → k=4 (ResolveMtpBatchMax 2→4, ResolveDraftN 1→3). The GDN ring alloc stops on OOM and clamps MaxBatchVerifyTokens, so tight-VRAM cards degrade gracefully.
- Amputate the dead public IForwardPass.LastHiddenT1 accessor (no consumers; MtpDecoder drives BatchVerify/HiddenAt). The CPU pass's backing buffer is removed; the CUDA pass keeps _lastHiddenT1/_gpuLastHiddenT1 as internal SHARPI_CPU_GDN=1 debug-trunk scratch.
Acceptance
Bench (27B Q4_K_M CUDA-hybrid, RTX 4070 Ti, -g -1 --no-thinking), decode t/s: MTP-off 6.5, k=2 10.1, k=4 12.3, k=6 10.4, k=8 10.2.
- MatVec4In ≡ single MatVec, proven by MatVec4In_BitwiseMatchesSingleMatVec.
- MtpDecoder_GreedyParity_LlamaCpp untouched and unaffected (both 2In and 4In are per-token bit-identical to single MatVec, so the CPU pass emits identical per-token logits).
- bench-27b-mtp.ps1 + README row updated.
Tests
- MatVec4In_BitwiseMatchesSingleMatVec: Q4_K/Q5_K/Q6_K/Q8_0/F32, serial + Parallel.For (rows=128) paths.
- MtpBatchTail lane-mapping oracle: k=1..9, real lanes cover every token once, tail clamps to last real.
- The CUDA k=4 batched-verify suite (CudaMtpBatchVerifyTests) stays green on the real 27B model; coherent e2e output confirmed on hardware with dense FFN on CPU.

🤖 Generated with Claude Code
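The lane mapping that the MtpBatchTail oracle checks can be sketched as follows. This is an illustration under assumptions, not the PR's helper: the class name MtpBatchTailSketch, the Group4 signature, and the tuple return shape are all hypothetical. It shows only the rule stated in this PR: lanes past the end of the batch are clamped to the last real token, and such lanes are flagged so that a caller could route their output to a sink.

```csharp
using System;

// Hypothetical sketch of the tail-clamping rule; not the PR's MtpBatchTail.
static class MtpBatchTailSketch
{
    // For a batch of k tokens and a group starting at token index groupStart,
    // returns the token each of the four lanes reads and whether the lane is real.
    public static (int[] Token, bool[] IsReal) Group4(int k, int groupStart)
    {
        if (k < 1) throw new ArgumentOutOfRangeException(nameof(k));
        var token = new int[4];
        var isReal = new bool[4];
        for (int lane = 0; lane < 4; lane++)
        {
            int t = groupStart + lane;
            isReal[lane] = t < k;             // lanes at or past k are padding
            token[lane] = Math.Min(t, k - 1); // padding lanes reuse the last real token
        }
        return (token, isReal);
    }
}
```

For k=6 the second group (groupStart=4) maps to tokens 4, 5, 5, 5 with only the first two lanes real, which is the duplicate-input pattern the review comments above refer to.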