
perf(cuda): evaluate routing Gemma 4 (E-series) PLE matmuls through BatchDecodeMatMul — profile-gated, likely not a win #286

@pekkah


Deferred secondary from #283 (the PLE part of the #275 follow-up).

Context

The Gemma 4 batched decode (RunBatchedTrunkGemma4 / GpuLayerBatchedDecodeGemma4) routes the trunk matmuls (Q/K/V/O, gate/up/down, lm-head) through BatchDecodeMatMul (#275), but the PLE pre-pass + injection matmuls (_gpuInpGate / _gpuPleProj / per_layer_model_proj) deliberately stay on cuBLAS GpuMatMulBatched. #283 asked whether routing them through BatchDecodeMatMul too is a decode win.

Assessment (from #283) — likely NOT worth it

  1. The PLE matmul weights are F32 (_gpuInpGate is [PleWidth, embDim] F32, _gpuPleProj is [embDim, PleWidth] F32; see CudaForwardPass.cs ~lines 151-152). The #201/#206 int8 decode-MMQ tile is Q4_K/Q6_K-only, so it can never engage for these. Routing them through BatchDecodeMatMul would only swap the cuBLAS GEMM for the WS matvec; it would not unlock the decode-MMQ lever.
  2. The shapes are narrow: inp_gate output rows = pleWidth (~256, below the 2048 decode-MMQ floor anyway), proj is [embDim × ~256]. These are a small fraction of per-layer decode cost, which is dominated by the trunk GEMMs.
  3. Scope is E-series only: the 12B (#283's model) has no PLE (embedding_length_per_layer_input = 0), so this concerns only E2B/E4B.
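
The first two points amount to a two-part eligibility gate. A minimal sketch of that reasoning, in Python for brevity (it is not code from the repository): the names `MatMulShape`, `decode_mmq_can_engage` and `batch_decode_path` are hypothetical, the two constants are taken from the text above, and the embDim value is a placeholder, not a real E2B/E4B dimension.

```python
from dataclasses import dataclass

# Gates as described in the issue text (assumed, not read from the source code):
# the int8 decode-MMQ tile covers only Q4_K/Q6_K weights and has a 2048
# output-row floor.
DECODE_MMQ_TYPES = frozenset({"Q4_K", "Q6_K"})
DECODE_MMQ_MIN_ROWS = 2048

# Placeholder only: the real embDim of E2B/E4B is not stated in the issue.
EMB_DIM_PLACEHOLDER = 4096


@dataclass(frozen=True)
class MatMulShape:
    name: str
    weight_type: str  # e.g. "F32", "Q4_K"
    out_rows: int     # output rows of the weight matrix


def decode_mmq_can_engage(m: MatMulShape) -> bool:
    """Hypothetical predicate: True only when both gates pass."""
    return m.weight_type in DECODE_MMQ_TYPES and m.out_rows >= DECODE_MMQ_MIN_ROWS


def batch_decode_path(m: MatMulShape) -> str:
    """Path a matmul would take if routed through BatchDecodeMatMul."""
    return "decode-MMQ" if decode_mmq_can_engage(m) else "WS matvec"


inp_gate = MatMulShape("inp_gate", "F32", 256)                  # ~pleWidth rows
ple_proj = MatMulShape("ple_proj", "F32", EMB_DIM_PLACEHOLDER)  # embDim rows

for m in (inp_gate, ple_proj):
    print(m.name, "->", batch_decode_path(m))
```

Under these assumptions both PLE matmuls fall through to the WS matvec: inp_gate fails both gates, and the projection fails on weight type regardless of embDim.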

What would change the call

Only ship this if a real Gemma 4 (E-series) decode profile shows the PLE pre-pass + injection matmuls are a non-trivial fraction of decode time. Bench before shipping (same discipline as #275; Gemma4CudaBatchedDecodeBench can A/B it once the routing exists). Until then, leaving PLE on cuBLAS GEMM is the right default.
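
The profile gate could be expressed roughly as below. This is a sketch under stated assumptions: the function names are hypothetical, the 5% cutoff is an arbitrary stand-in for "non-trivial" (the issue does not fix a number), and the timings are expected to come from a real profile or from an A/B run such as Gemma4CudaBatchedDecodeBench.

```python
def ple_fraction(ple_ms: float, step_ms: float) -> float:
    """Share of one decode step spent in the PLE pre-pass + injection matmuls."""
    if step_ms <= 0:
        raise ValueError("step_ms must be positive")
    return ple_ms / step_ms


def worth_routing(ple_ms: float, step_ms: float, threshold: float = 0.05) -> bool:
    """Hypothetical gate: only consider routing when PLE time is non-trivial.

    The default threshold is an assumption, not a value from the issue.
    """
    return ple_fraction(ple_ms, step_ms) >= threshold


def ab_speedup(baseline_ms: float, routed_ms: float) -> float:
    """Speedup of the routed build over the cuBLAS baseline (>1.0 is a win)."""
    if baseline_ms <= 0 or routed_ms <= 0:
        raise ValueError("timings must be positive")
    return baseline_ms / routed_ms
```

The two checks are deliberately separate: the fraction decides whether the routing is worth building at all, and the A/B speedup decides whether it ships once it exists.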

Refs: #283, #275 (PR #280), #195, #190 (umbrella).

Labels: enhancement
