Skip to content

perf(mtp): 4-input CPU MatVec4In for k>2 batched verify (#209)#287

Merged
pekkah merged 1 commit into
masterfrom
perf/209-matvec4in-mtp-batched-verify
Jun 17, 2026
Merged

perf(mtp): 4-input CPU MatVec4In for k>2 batched verify (#209)#287
pekkah merged 1 commit into
masterfrom
perf/209-matvec4in-mtp-batched-verify

Conversation

@pekkah

@pekkah pekkah commented Jun 17, 2026

Copy link
Copy Markdown
Owner

Summary

Implements work item #1 of #209 (the scope chosen for this PR: the highest-leverage lever + the directly-coupled cleanups): a 4-input batched CPU mat-vec that amortizes the dominant CPU mmap dense-FFN weight read across four MTP draft tokens, plus the dead-code amputation and helper dedup the #208 review flagged. Items #2#6 remain as follow-ups.

The 27B-MTP CUDA-hybrid decode cost center is the CPU-resident FFN layers (~6 GB/token). The prior MatVec2In only pair-amortized that weight read, pinning the verify optimum at k=2; deeper chains lost to the per-step re-stream. MatVec4In reads each weight row once per four tokens, moving the optimum to k=4.

Changes

  • SimdKernels.MatVec4In + register-tiled DotQ5K_4In (AVX-512) / DotQ6K_Q8K_4In (AVX2). Q4_K reuses the existing perf(engine,cpu,cuda): remaining GDN-hybrid prefill headroom after #111/#112 (N-input MoE dots + per-position recurrence/SDPA batching) #114 DotQ4K_4In; F32 via 4× DotF32. Bit-identical per output slot to a single MatVec (so per-position bits are k-parity-independent). Q8_0 deliberately uses the dequant fallback to stay bit-identical to single-token dense-FFN decode (MatVec/MatVecDual never specialized it).
  • CpuDenseFfn4 / DenseFfn4 + 4-wide lm_head replace the 2-wide loops in both GDN passes' BatchVerify. The three hand-copied duplicated-input-tail idioms collapse into one shared MtpBatchTail.Group4 helper.
  • Default verify batch k=2 → k=4 (ResolveMtpBatchMax 2→4, ResolveDraftN 1→3). The GDN ring alloc stops on OOM and clamps MaxBatchVerifyTokens, so tight-VRAM cards degrade gracefully.
  • Amputate the dead public IForwardPass.LastHiddenT1 accessor (no consumers — MtpDecoder drives BatchVerify/HiddenAt). The CPU pass's backing buffer is removed; the CUDA pass keeps _lastHiddenT1/_gpuLastHiddenT1 as internal SHARPI_CPU_GDN=1 debug-trunk scratch.

Acceptance

Bench — 27B Q4_K_M CUDA-hybrid, RTX 4070 Ti, -g -1 --no-thinking (decode t/s):

config decode t/s accept
MTP-off 6.5
k=2 (old default) 10.1 90%
k=4 (new default) 12.3 84%
k=6 10.4 83%
k=8 10.2 71%
  • ≥ 12 t/s at the new optimum k (12.3 vs 10.4 today, 6.2 MTP-off) — 1.9× over MTP-off, +22% over the old k=2 default
  • Per-position bit-identity across k parities preserved (MatVec4In ≡ single MatVec, proven by MatVec4In_BitwiseMatchesSingleMatVec)
  • MtpDecoder_GreedyParity_LlamaCpp untouched (both 2In and 4In are per-token bit-identical to single MatVec, so the CPU pass emits identical per-token logits — unaffected)
  • bench-27b-mtp.ps1 + README row updated

Tests

  • MatVec4In_BitwiseMatchesSingleMatVec — Q4_K/Q5_K/Q6_K/Q8_0/F32, serial + Parallel.For (rows=128) paths.
  • MtpBatchTail lane-mapping oracle — k=1..9: real lanes cover every token once, tail clamps to last real.
  • CUDA k=4 batched-verify suite (CudaMtpBatchVerifyTests) stays green on the real 27B model; coherent e2e output confirmed on hardware with dense FFN on CPU.

🤖 Generated with Claude Code

Work item #1 of #209: amortize the dominant CPU mmap dense-FFN weight read
across four MTP draft tokens. The 27B-MTP CUDA-hybrid decode cost center is the
46/64 CPU-resident FFN layers (~6 GB/token); the prior MatVec2In only
pair-amortized that read, so the verify optimum sat at k=2 (10.1 t/s) and deeper
chains lost to the linear-in-k re-stream.

- SimdKernels.MatVec4In + register-tiled DotQ5K_4In (AVX-512) / DotQ6K_Q8K_4In
  (AVX2); Q4_K reuses the existing #114 DotQ4K_4In, F32 via 4x DotF32. Each
  decodes one weight block once and FMAs four input columns, bit-identical per
  slot to a single MatVec (per-position bits are k-parity-independent).
  Q8_0 deliberately routes to the dequant fallback to stay bit-identical to
  single-token dense-FFN decode (MatVec/MatVecDual), which never specialized it.
- CpuDenseFfn4 / DenseFfn4 + a 4-wide lm_head replace the 2-wide loops in both
  GDN passes' BatchVerify. The three hand-copied duplicated-input-tail idioms
  collapse into one shared MtpBatchTail.Group4 helper (clamps past-the-end lanes
  to the last real token, routes their output to a sink).
- Default verify batch moves k=2 -> k=4 (ResolveMtpBatchMax 2->4,
  ResolveDraftN 1->3) now that the 4-input kernel makes k=4 the measured optimum.
  The GDN ring alloc stops on OOM and clamps MaxBatchVerifyTokens, so tight-VRAM
  cards degrade gracefully.
- Amputate the dead public IForwardPass.LastHiddenT1 accessor (no consumers;
  MtpDecoder drives BatchVerify/HiddenAt). The CPU pass's backing buffer is
  removed; the CUDA pass keeps _lastHiddenT1/_gpuLastHiddenT1 as internal
  SHARPI_CPU_GDN=1 debug-trunk scratch.

Bench (27B Q4_K_M CUDA-hybrid, RTX 4070 Ti, -g -1 --no-thinking): decode
MTP-off 6.5, k=2 10.1, k=4 12.3, k=6 10.4, k=8 10.2 t/s. New default 12.3 t/s
(84% accept) = 1.9x over MTP-off, +22% over the old k=2 default.

Tests: MatVec4In_BitwiseMatchesSingleMatVec (Q4_K/Q5_K/Q6_K/Q8_0/F32, serial +
Parallel.For) and MtpBatchTail lane-mapping (k=1..9) are bitwise oracles; the
CUDA k=4 batched-verify suite stays green. MtpDecoder_GreedyParity_LlamaCpp is
untouched and unaffected (both 2In and 4In are per-token bit-identical to single
MatVec, so the CPU pass emits identical per-token logits).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@pekkah

pekkah commented Jun 17, 2026

Copy link
Copy Markdown
Owner Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements a 4-input fused matrix-vector multiplication kernel (MatVec4In) to optimize the CPU mmap dense FFN evaluation during multi-token prediction (MTP) batched verification (Issue #209). It introduces register-tiled quad-batch dot product implementations for Q4_K, Q5_K, and Q6_K quantizations, as well as Float32 and fallbacks. The default MTP batch size is increased from 2 to 4 (SHARPI_MTP_BATCH_MAX defaults to 4, and draft-chain length defaults to 3), yielding a performance improvement on CUDA-hybrid backends. Additionally, the PR centralizes the lane-to-token mapping logic in a new MtpBatchTail helper class, cleans up the unused LastHiddenT1 property, and adds comprehensive unit tests to verify the correctness and bit-identity of the new 4-input SIMD kernels. I have no feedback to provide as there are no review comments.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a width-4 batched-verify (MTP) implementation, upgrading the previous pairwise (k=2) batching to quad-batching (k=4) to improve performance. It adds a new MatVec4In fused mat-vec kernel in SimdKernels.cs for various data types (Q4_K, Q5_K, Q6_K, and Float32) and implements corresponding 4-input dot products. The engine's forward passes (CudaHybridGdnForwardPass and HybridGdnForwardPass) are updated to support the new 4-token dense FFN and lm_head operations, utilizing a new centralized MtpBatchTail helper for lane-to-token mapping. The default draft-chain length is increased to 3 (k=4 verify batch), and corresponding unit tests are added. Feedback on the CPU kernels suggests optimizing fallback and sequential paths by checking for duplicate input pointers (which occur in the partial-group tail) to avoid redundant allocations, quantization, and dot-product computations.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +562 to +569
byte* sc0 = stackalloc byte[scratchBytes];
byte* sc1 = stackalloc byte[scratchBytes];
byte* sc2 = stackalloc byte[scratchBytes];
byte* sc3 = stackalloc byte[scratchBytes];
QuantizeRowToQ8K(input0, cols, sc0);
QuantizeRowToQ8K(input1, cols, sc1);
QuantizeRowToQ8K(input2, cols, sc2);
QuantizeRowToQ8K(input3, cols, sc3);

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

When the batch size k is not a multiple of 4, the duplicate-input tail clamps empty lanes to the last real token, resulting in duplicate input pointers (e.g., input1 == input2). We can avoid redundant QuantizeRowToQ8K calls and stack allocations by checking for duplicate input pointers and reusing the already-computed scratch buffers.

                byte* sc0 = stackalloc byte[scratchBytes];
                QuantizeRowToQ8K(input0, cols, sc0);

                byte* sc1;
                if (input1 == input0)
                {
                    sc1 = sc0;
                }
                else
                {
                    sc1 = stackalloc byte[scratchBytes];
                    QuantizeRowToQ8K(input1, cols, sc1);
                }

                byte* sc2;
                if (input2 == input1)
                {
                    sc2 = sc1;
                }
                else if (input2 == input0)
                {
                    sc2 = sc0;
                }
                else
                {
                    sc2 = stackalloc byte[scratchBytes];
                    QuantizeRowToQ8K(input2, cols, sc2);
                }

                byte* sc3;
                if (input3 == input2)
                {
                    sc3 = sc2;
                }
                else if (input3 == input1)
                {
                    sc3 = sc1;
                }
                else if (input3 == input0)
                {
                    sc3 = sc0;
                }
                else
                {
                    sc3 = stackalloc byte[scratchBytes];
                    QuantizeRowToQ8K(input3, cols, sc3);
                }

Comment on lines +1034 to +1037
sum0 = DotQ5K(row, input0, cols);
sum1 = DotQ5K(row, input1, cols);
sum2 = DotQ5K(row, input2, cols);
sum3 = DotQ5K(row, input3, cols);

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

On non-AVX-512 systems, DotQ5K_4In falls back to calling DotQ5K sequentially. We can avoid redundant computations on duplicate inputs by checking if the input pointers are equal and reusing the computed dot products.

        sum0 = DotQ5K(row, input0, cols);
        sum1 = input1 == input0 ? sum0 : DotQ5K(row, input1, cols);
        sum2 = input2 == input1 ? sum1 : (input2 == input0 ? sum0 : DotQ5K(row, input2, cols));
        sum3 = input3 == input2 ? sum2 : (input3 == input1 ? sum1 : (input3 == input0 ? sum0 : DotQ5K(row, input3, cols)));

Comment on lines +3533 to +3536
sum0 = DotQ6K_Q8K_Scalar(row, scratch0, numBlocks);
sum1 = DotQ6K_Q8K_Scalar(row, scratch1, numBlocks);
sum2 = DotQ6K_Q8K_Scalar(row, scratch2, numBlocks);
sum3 = DotQ6K_Q8K_Scalar(row, scratch3, numBlocks);

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

On non-AVX2/FMA systems, DotQ6K_Q8K_4In falls back to calling DotQ6K_Q8K_Scalar sequentially. We can avoid redundant computations on duplicate inputs by checking if the scratch pointers are equal and reusing the computed dot products.

        sum0 = DotQ6K_Q8K_Scalar(row, scratch0, numBlocks);
        sum1 = scratch1 == scratch0 ? sum0 : DotQ6K_Q8K_Scalar(row, scratch1, numBlocks);
        sum2 = scratch2 == scratch1 ? sum1 : (scratch2 == scratch0 ? sum0 : DotQ6K_Q8K_Scalar(row, scratch2, numBlocks));
        sum3 = scratch3 == scratch2 ? sum2 : (scratch3 == scratch1 ? sum1 : (scratch3 == scratch0 ? sum0 : DotQ6K_Q8K_Scalar(row, scratch3, numBlocks)));

Comment on lines +600 to +607
Parallel.For(0, rows, s_parallelOpts, r =>
{
float* row = m + (long)r * c;
o0[r] = DotF32(row, i0, c);
o1[r] = DotF32(row, i1, c);
o2[r] = DotF32(row, i2, c);
o3[r] = DotF32(row, i3, c);
});

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For Float32 inputs, we can avoid redundant DotF32 calls on duplicate inputs in the parallel path by checking if the input pointers are equal.

                    Parallel.For(0, rows, s_parallelOpts, r =>
                    {
                        float* row = m + (long)r * c;
                        o0[r] = DotF32(row, i0, c);
                        o1[r] = i1 == i0 ? o0[r] : DotF32(row, i1, c);
                        o2[r] = i2 == i1 ? o1[r] : (i2 == i0 ? o0[r] : DotF32(row, i2, c));
                        o3[r] = i3 == i2 ? o2[r] : (i3 == i1 ? o1[r] : (i3 == i0 ? o0[r] : DotF32(row, i3, c)));
                    });

Comment on lines +611 to +618
for (int r = 0; r < rows; r++)
{
float* row = m + (long)r * cols;
output0[r] = DotF32(row, input0, cols);
output1[r] = DotF32(row, input1, cols);
output2[r] = DotF32(row, input2, cols);
output3[r] = DotF32(row, input3, cols);
}

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For Float32 inputs, we can avoid redundant DotF32 calls on duplicate inputs in the sequential path by checking if the input pointers are equal.

                    for (int r = 0; r < rows; r++)
                    {
                        float* row = m + (long)r * cols;
                        output0[r] = DotF32(row, input0, cols);
                        output1[r] = input1 == input0 ? output0[r] : DotF32(row, input1, cols);
                        output2[r] = input2 == input1 ? output1[r] : (input2 == input0 ? output0[r] : DotF32(row, input2, cols));
                        output3[r] = input3 == input2 ? output2[r] : (input3 == input1 ? output1[r] : (input3 == input0 ? output0[r] : DotF32(row, input3, cols)));
                    }

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant