Follow-up to #209 (work item #2). Work item #1 of #209 (the 4-input CPU MatVec4In) shipped in #287 and moved the 27B-MTP CUDA-hybrid verify optimum from k=2 (10.1 t/s) to k=4 (12.3 t/s). The bench shows that k>4 now regresses (k=6: 10.4 t/s, k=8: 10.2 t/s). Beyond k=4 the GPU trunk and lm_head run the linear-in-k matvec re-stream (GpuMatMulBatched → GEMM-N for small N, in CudaHybridGdnForwardPass.cs), so per-step cost grows roughly linearly in k while the CPU FFN is now flat.
Goal
Wire the proven weight-stationary (WS) batched-decode kernels (CudaWsKernels.cs, the #194 pattern; dense path measured 2.42× at N=8) into the GDN-hybrid GPU trunk, so that the QKV/gate/SSM projections and lm_head amortize each weight HBM read across N tokens instead of re-streaming the weights per token. This should push the optimum past k=4.
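To make the amortization argument concrete, here is a toy CPU model of the two loop orders. It is only a sketch and not the code in CudaWsKernels.cs: the function names `restream_matvec` and `weight_stationary` are hypothetical, and the `weight_reads` counter merely stands in for HBM traffic.

```python
# Toy CPU model of the two loop orders; not the CUDA kernels themselves.
# All names here are illustrative and do not exist in the repository.

def restream_matvec(weights, tokens):
    """GEMM-N-style: one full pass over the weight matrix per token."""
    weight_reads = 0
    outputs = []
    for x in tokens:                      # N passes over the whole matrix
        y = []
        for row in weights:
            weight_reads += len(row)      # every weight is fetched again
            acc = 0.0
            for w, xi in zip(row, x):
                acc += w * xi
            y.append(acc)
        outputs.append(y)
    return outputs, weight_reads


def weight_stationary(weights, tokens):
    """WS-style: fetch each weight once and apply it to all N tokens."""
    weight_reads = 0
    outputs = [[0.0] * len(weights) for _ in tokens]
    for r, row in enumerate(weights):
        accs = [0.0] * len(tokens)        # one accumulator per token slot
        for c, w in enumerate(row):
            weight_reads += 1             # single fetch, reused N times
            for t, x in enumerate(tokens):
                accs[t] += w * x[c]
        for t, acc in enumerate(accs):
            outputs[t][r] = acc
    return outputs, weight_reads
```

For a 3×4 weight matrix and N=8 tokens the toy counters come out at 96 weight fetches for the re-stream order and 12 for the weight-stationary order. Each slot accumulates its products in the same left-to-right order in both functions, so the outputs match exactly in this model.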
Notes
- GpuMatMulBatched (the GDN trunk dispatch) currently routes small N to MatMulBatched (GEMM-N re-stream); MMQ/dequant-GEMM only engage at nTok > MatMulComputeBatchMinN (8) because their fp16 temps WDDM-page at decode k. WS keeps GEMM-N's near-zero fixed cost and adds the weight-read amortization that GEMM-N lacks.
- Must stay bit-identical per slot (the verify path's k-parity-independence contract); the dense path's WS kernels are argmax-stable.
- After this lands, re-evaluate raising the default SHARPI_MTP_BATCH_MAX/SHARPI_MTP_DRAFT_N past 4.
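The bit-identical-per-slot requirement could be checked with a harness along these lines. This is a sketch under assumptions, not the project's actual test: `ws_matvec`, `k_dependent_matvec` and `slot_parity_holds` are hypothetical names, and the kernels are pure-Python stand-ins. It compares every slot of a k-wide batch, bit for bit, against the same token run alone.

```python
import struct

def bits(value):
    """Raw IEEE-754 bytes, so -0.0 vs 0.0 and NaN payloads are compared too."""
    return struct.pack("<d", value)

def ws_matvec(weights, tokens):
    """Toy weight-stationary kernel: fixed left-to-right reduction per slot."""
    outputs = [[0.0] * len(weights) for _ in tokens]
    for r, row in enumerate(weights):
        for c, w in enumerate(row):          # each weight fetched once
            for t, x in enumerate(tokens):   # and reused for every slot
                outputs[t][r] += w * x[c]
    return outputs

def k_dependent_matvec(weights, tokens):
    """Counter-example: the reduction is split in two whenever k > 1."""
    def dot(row, x):
        acc = 0.0
        for w, xi in zip(row, x):
            acc += w * xi
        return acc
    outputs = []
    for x in tokens:
        y = []
        for row in weights:
            if len(tokens) > 1:
                half = len(row) // 2
                y.append(dot(row[:half], x[:half]) + dot(row[half:], x[half:]))
            else:
                y.append(dot(row, x))
        outputs.append(y)
    return outputs

def slot_parity_holds(kernel, weights, tokens, ks=(1, 2, 4, 6, 8)):
    """True if every slot's output is bit-identical for every batch width k."""
    reference = [kernel(weights, [x])[0] for x in tokens]   # each token alone
    for k in ks:
        batch = tokens[:k]
        out = kernel(weights, batch)
        for slot in range(len(batch)):
            got = [bits(v) for v in out[slot]]
            want = [bits(v) for v in reference[slot]]
            if got != want:
                return False
    return True
```

With a cancellation-prone row such as `[1e16, 1.0, -1e16, 1.0]` and all-ones inputs, the toy WS kernel passes while the k-dependent counter-example fails, because splitting the reduction changes the rounding. A real WS kernel would need the same property: the per-slot reduction order must not depend on k.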