test(cuda): harden TC flash-attention parity (#151)#245
Conversation
…TC1-only head_dim (#151) Closes the three coverage gaps from the PR #148 review: 1. startPos > 0 — the continued-prefill / chat-continuation re-prefill path that production runs but no test exercised. RunParity now sizes K/V for the full [0, startPos+nTok) history while Q carries only the new nTok tokens, and passes startPos through to both the reference and the TC kernels. 2. single partial tile (nTok < 16, gy == 1) — added nTok ∈ {1, 7} configs, with/without window and with startPos>0. Both kernels guard their output store (TC1 `q < n_tok`, TC2 `active`), so no OOB. 3. TC1-only head_dim (%16==0 && %64!=0) — new Tc1OnlyConfigs (hd 80/48/ 112) run in the #146 single-warp path only (TC2 requires %64), hitting the shared-O sizing path the model level never reaches. All 14 TC1 + 11 TC2 configs match the scalar batched reference with 0 mismatches. Also anchored the kernel occupancy figures ("~2 warps/SM", "~10×") to the measured RTX 4070 Ti / Ada, per the review's minor note. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
There was a problem hiding this comment.
Code Review
This pull request addresses Issue #151 by hardening test coverage for three specific gaps in the CUDA Flash Attention Tensor Core implementation: continued-prefill paths with startPos > 0, single partial tiles where nTok < 16, and head dimensions that are multiples of 16 but not 64 (which are exclusive to the TC1 single-warp path). It updates the test suite to include these scenarios and verifies parity against the scalar reference. Additionally, it refines documentation and inline comments in CudaBackend.cs and CudaTextKernels.cs to specify that the occupancy measurements were taken on an RTX 4070 Ti / Ada GPU. There are no review comments, so I have no feedback to provide.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
…t (review) - Note that the startPos>0 configs keep maxSeqLen=startPos+nTok (flat cache), so they validate the mask/key-span interaction with startPos, not the SWA ring wrap (covered at the model level elsewhere) — per the pr-test-analyzer review. - Rewrap the FlashAttentionPrefillTc2 doc comment so "Requires" doesn't dangle. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Closes #151 — the three TC flash-attention parity coverage gaps deferred from the PR #148 review (
pr-test-analyzer). Test-only + comment anchors; no kernel/behavior changes.Gaps closed (in
CudaFlashAttnTcTests)startPos > 0(sev 5 — real production path). Continued-prefill / chat-continuation re-prefill runsFlashAttentionPrefillTc/Tc2withstartPos = priorLength, previously unverified.RunParitynow sizes K/V for the full[0, startPos+nTok)history while Q carries only the newnToktokens, and threadsstartPosthrough both the reference (AttentionBatched/AttentionSwaBatched) and the TC kernels. Added a global config (startPos 137) and a SWA one whose window sits well inside the prior context (startPos 211,win 96).nTok < 16,gy == 1) (sev 5). AddednTok ∈ {1, 7}configs — global, windowed, and withstartPos>0. Verified both kernels guard the output store (TC1: q < n_tok,TC2: active), so the 16-row tile never writes past thenTok-sized output.head_dim(%16==0 && %64!=0) (sev 4). NewTc1OnlyConfigs(hd 80/48/112) run in the perf(cuda): full tensor-core (mma.sync) flash-attention prefill for d=512 — beat the half2 kernel (#141 follow-up) #146 single-warp path only — TC2 requires%64and throws — exercising the shared-O sizing path the model level only reaches viaSHARPI_PREFILL_FLASH_TC1=1.Plus the review's minor note: anchored the kernel occupancy figures ("~2 warps/SM", "~10×") to the measured RTX 4070 Ti / Ada so they don't read as universal invariants.
Result
All 14 TC1 + 11 TC2 configs match the scalar batched reference with 0 mismatches (maxAbs ~1e-4–4e-4 vs the 2e-2·rms threshold). No bug surfaced — these harden the kernels against the production paths and shape edges. The
startPos>0rows show distinct rms, confirming they aren't trivially identical to thestartPos=0cases.🤖 Generated with Claude Code