test(cuda): harden TC flash-attention parity — startPos>0, single-partial-tile (nTok<16), TC1-only head_dim (#148 review) #151

Description

@pekkah

Deferred from the PR #148 review cycle (pr-review-toolkit pr-test-analyzer). The tensor-core flash-attention parity tests (CudaFlashAttnTcTests, both #146 single-warp and #147 multi-warp via RunParity) are solid for the Gemma 4 shapes but have three coverage gaps. None indicates a known bug; these harden the kernels against real production paths and shape edges.

1. startPos > 0 is never tested (sev 5 — real production path)

Every CudaFlashAttnTcTests config and the model-level oracle use startPos: 0. But continued-prefill (chat-continuation cache) re-prefills only the new tokens with startPos = priorLength, so FlashAttentionPrefillTc/Tc2 do run with startPos > 0 in production. The kernels account for it (qpos = start_pos + qi, and both windowing and the causal mask depend on qpos), but that path is unverified. Add a parity config that fills K/V for [0, startPos+nTok) and runs both the scalar reference and the TC kernel with startPos > 0.
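One property such a config could check, sketched here in plain Python rather than the project's C#/CUDA: with a causal mask, prefilling only the tail tokens at startPos > 0 should reproduce the last rows of a full prefill from position 0. The function name `scalar_attention`, the single-head layout, the 1/sqrt(head_dim) scale, and the sliding-window convention (a query at qpos sees keys in [qpos - window + 1, qpos]) are all assumptions for illustration, not the repository's actual reference implementation.

```python
import math

def scalar_attention(q, k, v, start_pos, window=None):
    """Single-head causal attention, one query row at a time.

    q holds nTok rows; k and v hold start_pos + nTok rows.
    Scale and window semantics are assumptions, not taken from the kernels.
    """
    hd = len(q[0])
    scale = 1.0 / math.sqrt(hd)
    out = []
    for qi, q_row in enumerate(q):
        qpos = start_pos + qi  # absolute position, as described in the issue
        lo = 0 if window is None else max(0, qpos - window + 1)
        # Causal mask: only keys at positions lo..qpos are visible.
        scores = [scale * sum(a * b for a, b in zip(q_row, k[j]))
                  for j in range(lo, qpos + 1)]
        m = max(scores)  # subtract the max for a numerically stable softmax
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(w[t] * v[lo + t][d] for t in range(len(w))) / z
                    for d in range(hd)])
    return out
```

A parity config could then compare `scalar_attention(q[startPos:], k, v, startPos)` against rows `startPos` onward of `scalar_attention(q, k, v, 0)`, and compare the TC kernel output against the same rows.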

2. No single-partial-tile config (nTok < 16, gy == 1) (sev 5)

The smallest current config is nTok=64. The gy = ceil(nTok/16) grid and the online-softmax masking are most fragile in the degenerate sub-tile case, where the only tile (gy == 1) is partially filled. Add an nTok ∈ {1, 7} config.
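To make the degenerate case concrete, here is a small sketch of the tile arithmetic. The 16-row tile height comes from the gy = ceil(nTok/16) formula above; the helper names `grid_y` and `valid_rows_in_last_tile` are hypothetical.

```python
TILE_ROWS = 16  # query rows per tile, from gy = ceil(nTok/16)

def grid_y(n_tok):
    """Number of query tiles: ceil(n_tok / TILE_ROWS) in integer arithmetic."""
    return (n_tok + TILE_ROWS - 1) // TILE_ROWS

def valid_rows_in_last_tile(n_tok):
    """Rows of the final tile that hold real tokens; the rest are padding."""
    return n_tok - (grid_y(n_tok) - 1) * TILE_ROWS

for n_tok in (1, 7, 64):
    print(n_tok, grid_y(n_tok), valid_rows_in_last_tile(n_tok))
```

For nTok = 1 and nTok = 7 there is a single tile with 15 and 9 padded rows respectively, while nTok = 64 yields four full tiles, so the current configs never exercise padded rows.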

3. No genuinely TC1-only head_dim (%16==0 && %64!=0) (sev 4)

The #147 multi-warp kernel needs head_dim % 64 == 0; the #146 single-warp kernel is the %16 fallback. Gemma 4's head dims (256/512) are both %64, so at the model level TC1 is only ever reached via SHARPI_PREFILL_FLASH_TC1=1 (no test sets it). CudaFlashAttnTcTests covers hd=128 (still TC2-eligible) — there is no config with hd ∈ {48, 80, 112, ...} that exercises the TC1-only shared-O sizing path. Add one.
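The eligibility rule described above can be written as a small classifier, which also shows why hd=128 does not cover the gap. This is a sketch of the divisibility conditions only: the function name `eligible_kernel` and the "none" result for head dims that are not multiples of 16 are assumptions, and the real dispatch may involve further conditions (such as the SHARPI_PREFILL_FLASH_TC1 override).

```python
def eligible_kernel(head_dim):
    """Classify a head_dim by the divisibility rules stated in the issue."""
    if head_dim % 64 == 0:
        return "TC2"   # multi-warp kernel (#147)
    if head_dim % 16 == 0:
        return "TC1"   # single-warp fallback (#146)
    return "none"      # assumption: neither TC kernel applies

for hd in (48, 80, 112, 128, 256, 512):
    print(hd, eligible_kernel(hd))
```

Under these rules 128, 256, and 512 all classify as TC2, while 48, 80, and 112 reach TC1 only.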

Files: tests/SharpInference.Tests.ForwardPass/CudaFlashAttnTcTests.cs (extend Configs() / RunParity).

Minor related polish (optional): the occupancy/perf figures in the TC kernel comments ("~2 warps/SM", "23-34% of int8 TC peak") are RTX 4070 Ti / Ada measurements stated without a device anchor — add "measured on RTX 4070 Ti" so they don't read as universal invariants.

Labels: enhancement
