perf(mtp): weight-stationary batched-decode kernels for the hybrid GPU trunk (#209 item 2) #288

@pekkah

Description

Follow-up to #209 (work item #2). Work item #1 from #209 (the 4-input CPU MatVec4In) shipped in #287 and moved the 27B-MTP CUDA-hybrid verify optimum from k=2 (10.1 t/s) to k=4 (12.3 t/s). The bench now shows k>4 regressing (k=6: 10.4 t/s, k=8: 10.2 t/s). Beyond k=4, the GPU trunk and lm_head run the linear-in-k matvec re-stream (GpuMatMulBatched → GEMM-N for small N, in CudaHybridGdnForwardPass.cs), so per-step cost grows roughly linearly with k while the CPU FFN cost is now flat.
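A toy throughput model shows why a linear-in-k step cost caps the optimum. This is a minimal Python sketch that assumes an i.i.d. per-draft acceptance probability (0.8 here) and made-up cost units. `expected_tokens`, `step_cost` and `best_k` are hypothetical helpers, and none of the numbers are fitted to the bench above. Throughput is the expected number of emitted tokens per step divided by the step cost. Lowering the per-token slope, which is what weight-stationary kernels aim to do, moves the optimum to a larger k.

```python
# Toy model only: acceptance rate and cost units are hypothetical, not fitted to the bench.


def expected_tokens(k: int, accept: float) -> float:
    """Expected tokens emitted per verify step with k drafts, assuming i.i.d.
    per-draft acceptance probability `accept` and counting the bonus token:
    (1 - a^(k+1)) / (1 - a)."""
    return (1.0 - accept ** (k + 1)) / (1.0 - accept)


def step_cost(k: int, fixed: float, per_token: float) -> float:
    """Per-step cost: a flat part (e.g. CPU FFN after MatVec4In) plus a part that
    grows linearly in k (GPU trunk + lm_head re-streaming weights per token)."""
    return fixed + per_token * k


def best_k(per_token: float, fixed: float = 1.0, accept: float = 0.8, k_max: int = 8) -> int:
    """k in [1, k_max] maximizing expected tokens per unit of step cost."""
    return max(
        range(1, k_max + 1),
        key=lambda k: expected_tokens(k, accept) / step_cost(k, fixed, per_token),
    )


if __name__ == "__main__":
    # Steep linear-in-k slope (re-stream): the optimum sits at a small k.
    print("re-stream optimum k:", best_k(per_token=0.2))
    # Flatter slope (weight read amortized across the batch): the optimum moves right.
    print("weight-stationary optimum k:", best_k(per_token=0.05))
```

With these illustrative parameters it prints 4 for the re-stream slope and 8 (the top of the searched range) for the flatter slope.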

Goal

Wire the proven weight-stationary (WS) batched-decode kernels (CudaWsKernels.cs, the #194 pattern; the dense path measured 2.42× at N=8) into the GDN-hybrid GPU trunk, so that the QKV/gate/SSM projections and lm_head read each weight from HBM once per N-token batch instead of re-streaming it for every token. This should push the optimum past k=4.
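The difference between re-streaming and weight-stationary can be modelled on the CPU. This is a minimal Python sketch, not the CudaWsKernels.cs implementation. `matvec_restream` and `matvec_weight_stationary` are hypothetical names, and the `reads` counter stands in for HBM weight traffic (on the GPU the stationary row would sit in registers or shared memory). Each (row, slot) dot product accumulates in the same order in both versions, so the outputs are bit-identical per slot. That is the property the verify path relies on.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def matvec_restream(w: Matrix, xs: Matrix) -> Tuple[Matrix, int]:
    """Baseline (GEMM-N style at small N): a full pass over W for every token."""
    reads = 0
    outs: Matrix = []
    for x in xs:                        # one token (slot) at a time
        y = []
        for row in w:
            reads += len(row)           # the whole row is fetched again for this token
            acc = 0.0
            for wi, xi in zip(row, x):
                acc += wi * xi
            y.append(acc)
        outs.append(y)
    return outs, reads


def matvec_weight_stationary(w: Matrix, xs: Matrix) -> Tuple[Matrix, int]:
    """Weight-stationary: fetch each row once and reuse it for all N tokens."""
    reads = 0
    outs: Matrix = [[0.0] * len(w) for _ in xs]
    for r, row in enumerate(w):
        reads += len(row)               # fetched once per batch, not once per token
        for n, x in enumerate(xs):
            acc = 0.0
            for wi, xi in zip(row, x):  # same accumulation order as the baseline
                acc += wi * xi
            outs[n][r] = acc
    return outs, reads


if __name__ == "__main__":
    w = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    xs = [[1.0, 2.0, 3.0], [0.5, -1.0, 2.0], [3.0, 0.0, -1.0]]  # N = 3 tokens
    y_rs, reads_rs = matvec_restream(w, xs)
    y_ws, reads_ws = matvec_weight_stationary(w, xs)
    print("bit-identical:", y_rs == y_ws)
    print("weight reads:", reads_rs, "vs", reads_ws)  # 18 vs 6 (N-fold fewer)
```

The saving grows with N because the re-stream count is N × |W| while the weight-stationary count stays at |W|. Per-token arithmetic is unchanged, which is why the win shows up at decode sizes where the matvec is bandwidth-bound.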

Notes

  • GpuMatMulBatched (the GDN trunk dispatch) currently routes small N to MatMulBatched (GEMM-N re-stream). MMQ/dequant-GEMM only engage at nTok > MatMulComputeBatchMinN (8) because their fp16 temporaries get paged by WDDM at decode-sized k. WS keeps GEMM-N's near-zero fixed cost and adds the weight-read amortization that GEMM-N lacks.
  • Outputs must stay bit-identical per slot (the verify path's k-parity-independence contract); the dense path's WS kernels are argmax-stable.
  • After this lands, re-evaluate raising the default SHARPI_MTP_BATCH_MAX/SHARPI_MTP_DRAFT_N past 4.
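The routing change described in these notes could look roughly like the following. This is a minimal Python sketch. `Path`, `route_current` and `route_proposed` are hypothetical names, and only the MatMulComputeBatchMinN threshold (8) comes from the issue. Falling back to GEMM-N for a single token is an assumption, since the issue does not say what N=1 should use.

```python
from enum import Enum

# Threshold named in the issue (MatMulComputeBatchMinN); MMQ/dequant-GEMM engage only above it.
MATMUL_COMPUTE_BATCH_MIN_N = 8


class Path(Enum):
    GEMM_N = "MatMulBatched (GEMM-N re-stream)"
    WS = "weight-stationary batched decode"
    MMQ = "MMQ / dequant-GEMM"


def route_current(n_tok: int) -> Path:
    """Current GpuMatMulBatched behaviour as described: small N re-streams via GEMM-N."""
    return Path.MMQ if n_tok > MATMUL_COMPUTE_BATCH_MIN_N else Path.GEMM_N


def route_proposed(n_tok: int) -> Path:
    """Hypothetical proposed routing: keep MMQ for large N and send decode-sized
    batches (2..MatMulComputeBatchMinN) to WS so weights are read once per batch.
    ASSUMPTION: a single token stays on GEMM-N, since there is nothing to amortize."""
    if n_tok > MATMUL_COMPUTE_BATCH_MIN_N:
        return Path.MMQ
    if n_tok >= 2:
        return Path.WS
    return Path.GEMM_N


if __name__ == "__main__":
    for n in (1, 4, 8, 9):
        print(n, route_current(n).name, "->", route_proposed(n).name)
```

Leaving the MMQ threshold unchanged keeps the fp16-temp paging problem out of the decode range. WS only replaces GEMM-N where GEMM-N was already the choice.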

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
