From ba291a93be535b3522992733983c1b296dc54ab7 Mon Sep 17 00:00:00 2001 From: Pekka Heikura Date: Sun, 14 Jun 2026 13:21:53 +0300 Subject: [PATCH 1/2] =?UTF-8?q?test(cuda):=20harden=20TC=20flash-attention?= =?UTF-8?q?=20parity=20=E2=80=94=20startPos>0,=20sub-tile,=20TC1-only=20he?= =?UTF-8?q?ad=5Fdim=20(#151)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes the three coverage gaps from the PR #148 review: 1. startPos > 0 — the continued-prefill / chat-continuation re-prefill path that production runs but no test exercised. RunParity now sizes K/V for the full [0, startPos+nTok) history while Q carries only the new nTok tokens, and passes startPos through to both the reference and the TC kernels. 2. single partial tile (nTok < 16, gy == 1) — added nTok ∈ {1, 7} configs, with/without window and with startPos>0. Both kernels guard their output store (TC1 `q < n_tok`, TC2 `active`), so no OOB. 3. TC1-only head_dim (%16==0 && %64!=0) — new Tc1OnlyConfigs (hd 80/48/112) run in the #146 single-warp path only (TC2 requires %64), hitting the shared-O sizing path the model level never reaches. All 14 TC1 + 11 TC2 configs match the scalar batched reference with 0 mismatches. Also anchored the kernel occupancy figures ("~2 warps/SM", "~16-20 warps/SM", "~10×") to the hardware they were measured on (RTX 4070 Ti / Ada), per the review's minor note. 
Co-Authored-By: Claude Opus 4.8 --- src/SharpInference.Cuda/CudaBackend.cs | 2 +- src/SharpInference.Cuda/CudaTextKernels.cs | 11 +-- .../CudaFlashAttnTcTests.cs | 77 ++++++++++++++----- 3 files changed, 63 insertions(+), 27 deletions(-) diff --git a/src/SharpInference.Cuda/CudaBackend.cs b/src/SharpInference.Cuda/CudaBackend.cs index 52b9b58..fadf0eb 100644 --- a/src/SharpInference.Cuda/CudaBackend.cs +++ b/src/SharpInference.Cuda/CudaBackend.cs @@ -4225,7 +4225,7 @@ public void FlashAttentionPrefillTc(Tensor qAll, Tensor kCache, Tensor vCache, T /// Issue #147: multi-warp / d-split tensor-core flash-attention prefill — same /// args/semantics as but W=4 warps cooperate /// on each 16-query tile with the head dim split across them, so O is register- - /// resident (no shared-O rescale) and occupancy rises ~10×. Requires + /// resident (no shared-O rescale) and occupancy rises ~10× (RTX 4070 Ti / Ada). Requires /// % 64 == 0 (W·16) and ≤ 512. Argmax-stable, not bit-exact. /// /// K/V cache element dtype (issue #179): Float32 (default), diff --git a/src/SharpInference.Cuda/CudaTextKernels.cs b/src/SharpInference.Cuda/CudaTextKernels.cs index f83c2e3..bd44ea4 100644 --- a/src/SharpInference.Cuda/CudaTextKernels.cs +++ b/src/SharpInference.Cuda/CudaTextKernels.cs @@ -6021,11 +6021,12 @@ asm volatile( // ── Multi-warp / d-split tensor-core flash-attention prefill (issue #147) ─── // Fixes the single-warp llm_flash_attn_prefill_tc occupancy limit (1 warp/block + -// 48 KB shared → ~2 warps/SM). Here a block is W warps that cooperate on ONE -// 16-query tile, splitting the head dim: warp w owns output columns [w·dW, …) with -// dW = head_dim/W, so O[16×dW] stays REGISTER-resident (16×128 = 64 regs/lane at -// d=512,W=4) instead of in shared — no per-key-tile shared-O rescale traffic, and -// the freed shared lets occupancy rise to ~16-20 warps/SM. Each warp computes a +// 48 KB shared → ~2 warps/SM, measured on RTX 4070 Ti / Ada). 
Here a block is W +// warps that cooperate on ONE 16-query tile, splitting the head dim: warp w owns +// output columns [w·dW, …) with dW = head_dim/W, so O[16×dW] stays REGISTER-resident +// (16×128 = 64 regs/lane at d=512,W=4) instead of in shared — no per-key-tile +// shared-O rescale traffic, and the freed shared lets occupancy rise to ~16-20 +// warps/SM (RTX 4070 Ti / Ada). Each warp computes a // PARTIAL QK^T over its d-slice; the partials are summed across warps through a // small shared S buffer ([W×16×16] fp32), after which every warp holds the full // reduced score tile in its C-fragment and proceeds exactly like the single-warp diff --git a/tests/SharpInference.Tests.ForwardPass/CudaFlashAttnTcTests.cs b/tests/SharpInference.Tests.ForwardPass/CudaFlashAttnTcTests.cs index 0be568f..ce95da3 100644 --- a/tests/SharpInference.Tests.ForwardPass/CudaFlashAttnTcTests.cs +++ b/tests/SharpInference.Tests.ForwardPass/CudaFlashAttnTcTests.cs @@ -13,6 +13,15 @@ namespace SharpInference.Tests.ForwardPass; /// looser fp16 tolerance. Same config matrix as the half2 flash test: GQA, both /// Gemma 4 head_dims (256 SWA / 512 global), causal, windowing, partial last tile. /// +/// Issue #151 hardens three coverage gaps from the #148 review: +/// 1. startPos > 0 — the continued-prefill / chat-continuation re-prefill path +/// (Q is only the new tokens at absolute positions [startPos, startPos+nTok), +/// while K/V hold the full [0, startPos+nTok) history). +/// 2. single partial tile (nTok < 16, gy == 1) — the degenerate +/// online-softmax / masking case. +/// 3. a genuinely TC1-only head_dim (%16==0 && %64!=0) — the #146 single-warp +/// shared-O sizing path that the #147 multi-warp kernel (needs %64) never reaches. +/// /// Silent no-op on hosts without CUDA, matching the other Cuda* test files. 
/// public sealed unsafe class CudaFlashAttnTcTests @@ -24,15 +33,36 @@ public sealed unsafe class CudaFlashAttnTcTests catch { return null; } } - // (numHeads, numKvHeads, headDim, window, nTok). window=0 → global (full causal). - // All head_dims are multiples of 64 so the same matrix exercises tc2 (#147, W·16=64). - private static (int nh, int nkv, int hd, int win, int nTok)[] Configs() => new[] + // (numHeads, numKvHeads, headDim, window, nTok, startPos). window=0 → global (full causal). + // All head_dims are multiples of 64 so the same matrix exercises both the #146 single-warp + // (tc) and the #147 multi-warp (tc2, W·16=64) kernels. + private static (int nh, int nkv, int hd, int win, int nTok, int startPos)[] Configs() => new[] + { + (8, 2, 256, 0, 200, 0), // global, SWA head_dim + (8, 2, 512, 0, 173, 0), // global, global head_dim, partial last tile + (8, 2, 256, 64, 200, 0), // sliding window 64 < nTok + (8, 2, 512, 96, 130, 0), // sliding window 96, global head_dim + (4, 4, 128, 0, 64, 0), // MHA (no GQA), small head_dim + + // #151 gap 1 — startPos > 0 (continued prefill): K/V carry prior context. + (8, 2, 256, 0, 64, 137), // global, prior context before the new tokens + (8, 2, 512, 96, 80, 211), // SWA, window (96) bounded well inside the prior context (startPos 211) + + // #151 gap 2 — single partial tile (nTok < 16, gy == 1). + (8, 2, 256, 0, 1, 0), // single query, single key + (8, 2, 256, 0, 7, 0), // sub-tile (7 < 16), global + (8, 2, 256, 32, 7, 0), // sub-tile with a sliding window + (8, 2, 256, 0, 7, 40), // sub-tile with prior context (nTok<16 AND startPos>0) + }; + + // #151 gap 3 — head_dim % 16 == 0 but % 64 != 0. Only the #146 single-warp TC1 kernel + // (shared-O sizing) accepts these; FlashAttentionPrefillTc2 requires % 64 and throws, + // so these run in the TC1 path only. 
+ private static (int nh, int nkv, int hd, int win, int nTok, int startPos)[] Tc1OnlyConfigs() => new[] { - (8, 2, 256, 0, 200), // global, SWA head_dim - (8, 2, 512, 0, 173), // global, global head_dim, partial last tile - (8, 2, 256, 64, 200), // sliding window 64 < nTok - (8, 2, 512, 96, 130), // sliding window 96, global head_dim - (4, 4, 128, 0, 64), // MHA (no GQA), small head_dim + (8, 2, 80, 0, 130, 0), // hd=80 (5×16), global — exercises the TC1-only shared-O path + (8, 2, 48, 64, 96, 0), // hd=48 (3×16), sliding window + (4, 4, 112, 0, 33, 17), // hd=112 (7×16), MHA, startPos>0, partial last tile }; [Fact] @@ -40,7 +70,8 @@ public void FlashAttentionPrefillTc_MatchesScalarBatched() // #146 single-warp { using var gpu = TryCreate(); if (gpu is null) return; - RunParity(gpu, tc2: false, label: "TC1"); + RunParity(gpu, Configs(), tc2: false, label: "TC1"); + RunParity(gpu, Tc1OnlyConfigs(), tc2: false, label: "TC1-only-hd"); } [Fact] @@ -48,20 +79,24 @@ public void FlashAttentionPrefillTc2_MatchesScalarBatched() // #147 multi-warp/ { using var gpu = TryCreate(); if (gpu is null) return; - RunParity(gpu, tc2: true, label: "TC2"); + RunParity(gpu, Configs(), tc2: true, label: "TC2"); } - private static void RunParity(CudaBackend gpu, bool tc2, string label) + private static void RunParity( + CudaBackend gpu, + (int nh, int nkv, int hd, int win, int nTok, int startPos)[] configs, + bool tc2, string label) { - foreach (var (nh, nkv, hd, win, nTok) in Configs()) + foreach (var (nh, nkv, hd, win, nTok, startPos) in configs) { - var rng = new Random(20260606 + nh * 7 + hd * 13 + win * 17 + nTok); + var rng = new Random(20260606 + nh * 7 + hd * 13 + win * 17 + nTok + startPos * 101); int qDim = nh * hd, kvDim = nkv * hd; + int kvLen = startPos + nTok; // K/V cache holds the full [0, startPos+nTok) history var q = new float[(long)nTok * qDim]; for (int i = 0; i < q.Length; i++) q[i] = (float)(rng.NextDouble() * 2 - 1); - var k = new float[(long)nTok * kvDim]; - var v 
= new float[(long)nTok * kvDim]; + var k = new float[(long)kvLen * kvDim]; + var v = new float[(long)kvLen * kvDim]; for (int i = 0; i < k.Length; i++) { k[i] = (float)(rng.NextDouble() * 2 - 1); v[i] = (float)(rng.NextDouble() * 2 - 1); } var gq = gpu.Upload(q, TensorShape.D1(q.Length)); @@ -71,13 +106,13 @@ private static void RunParity(CudaBackend gpu, bool tc2, string label) var gTc = gpu.Allocate(TensorShape.D1(q.Length)); if (win == 0) - gpu.AttentionBatched(gq, gk, gv, gRef, nh, nkv, hd, startPos: 0, maxSeqLen: nTok, nTok: nTok); + gpu.AttentionBatched(gq, gk, gv, gRef, nh, nkv, hd, startPos, maxSeqLen: kvLen, nTok: nTok); else - gpu.AttentionSwaBatched(gq, gk, gv, gRef, nh, nkv, hd, startPos: 0, windowSize: win, maxSeqLen: nTok, nTok: nTok); + gpu.AttentionSwaBatched(gq, gk, gv, gRef, nh, nkv, hd, startPos, windowSize: win, maxSeqLen: kvLen, nTok: nTok); if (tc2) - gpu.FlashAttentionPrefillTc2(gq, gk, gv, gTc, nh, nkv, hd, startPos: 0, windowSize: win, maxSeqLen: nTok, nTok: nTok); + gpu.FlashAttentionPrefillTc2(gq, gk, gv, gTc, nh, nkv, hd, startPos, windowSize: win, maxSeqLen: kvLen, nTok: nTok); else - gpu.FlashAttentionPrefillTc(gq, gk, gv, gTc, nh, nkv, hd, startPos: 0, windowSize: win, maxSeqLen: nTok, nTok: nTok); + gpu.FlashAttentionPrefillTc(gq, gk, gv, gTc, nh, nkv, hd, startPos, windowSize: win, maxSeqLen: kvLen, nTok: nTok); gpu.Synchronize(); var outRef = new float[q.Length]; @@ -103,10 +138,10 @@ private static void RunParity(CudaBackend gpu, bool tc2, string label) if (diff > 2e-2f * rms) mismatches++; } Console.WriteLine( - $"Flash{label} nh={nh} nkv={nkv} hd={hd} win={win} nTok={nTok}: maxAbs={maxAbs:E2} rms={rms:E2} mismatches={mismatches}/{outRef.Length}"); + $"Flash{label} nh={nh} nkv={nkv} hd={hd} win={win} nTok={nTok} startPos={startPos}: maxAbs={maxAbs:E2} rms={rms:E2} mismatches={mismatches}/{outRef.Length}"); // Allow a tiny tail of outliers from fp16 P·V accumulation; the bulk must match. 
Assert.True(mismatches <= outRef.Length / 200 + 1, - $"{label} flash attention diverged from scalar reference: {mismatches}/{outRef.Length} beyond 2e-2·rms ({rms:E3}), maxAbs={maxAbs:E3} (nh={nh} hd={hd} win={win} nTok={nTok})."); + $"{label} flash attention diverged from scalar reference: {mismatches}/{outRef.Length} beyond 2e-2·rms ({rms:E3}), maxAbs={maxAbs:E3} (nh={nh} hd={hd} win={win} nTok={nTok} startPos={startPos})."); } } } From fbe0151a035ade3f6991b8632e400a8701819f61 Mon Sep 17 00:00:00 2001 From: Pekka Heikura Date: Sun, 14 Jun 2026 13:27:17 +0300 Subject: [PATCH 2/2] test(cuda): clarify flat-cache SWA+startPos config; rewrap doc comment (review) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Note that the startPos>0 configs keep maxSeqLen=startPos+nTok (flat cache), so they validate the mask/key-span interaction with startPos, not the SWA ring wrap (covered at the model level elsewhere) — per the pr-test-analyzer review. - Rewrap the FlashAttentionPrefillTc2 doc comment so "Requires" doesn't dangle. Co-Authored-By: Claude Opus 4.8 --- src/SharpInference.Cuda/CudaBackend.cs | 4 ++-- .../SharpInference.Tests.ForwardPass/CudaFlashAttnTcTests.cs | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/SharpInference.Cuda/CudaBackend.cs b/src/SharpInference.Cuda/CudaBackend.cs index fadf0eb..2c615e2 100644 --- a/src/SharpInference.Cuda/CudaBackend.cs +++ b/src/SharpInference.Cuda/CudaBackend.cs @@ -4225,8 +4225,8 @@ public void FlashAttentionPrefillTc(Tensor qAll, Tensor kCache, Tensor vCache, T /// Issue #147: multi-warp / d-split tensor-core flash-attention prefill — same /// args/semantics as but W=4 warps cooperate /// on each 16-query tile with the head dim split across them, so O is register- - /// resident (no shared-O rescale) and occupancy rises ~10× (RTX 4070 Ti / Ada). Requires - /// % 64 == 0 (W·16) and ≤ 512. Argmax-stable, not bit-exact. 
+ /// resident (no shared-O rescale) and occupancy rises ~10× (RTX 4070 Ti / Ada). + /// Requires % 64 == 0 (W·16) and ≤ 512. Argmax-stable, not bit-exact. /// /// K/V cache element dtype (issue #179): Float32 (default), /// BFloat16, or Q8_0. The matching templated thunk decodes each element to fp32 on diff --git a/tests/SharpInference.Tests.ForwardPass/CudaFlashAttnTcTests.cs b/tests/SharpInference.Tests.ForwardPass/CudaFlashAttnTcTests.cs index ce95da3..29c24da 100644 --- a/tests/SharpInference.Tests.ForwardPass/CudaFlashAttnTcTests.cs +++ b/tests/SharpInference.Tests.ForwardPass/CudaFlashAttnTcTests.cs @@ -45,6 +45,10 @@ private static (int nh, int nkv, int hd, int win, int nTok, int startPos)[] Conf (4, 4, 128, 0, 64, 0), // MHA (no GQA), small head_dim // #151 gap 1 — startPos > 0 (continued prefill): K/V carry prior context. + // These keep maxSeqLen = startPos+nTok (a flat cache), so they validate the + // causal/window mask + key-tile-span interaction with startPos — NOT the SWA + // ring wrap (abs_k % maxSeqLen is the identity here). Ring-wrap is covered at + // the model level by CudaForwardPassKvDtypeTests' long-prompt chunked prefill. (8, 2, 256, 0, 64, 137), // global, prior context before the new tokens (8, 2, 512, 96, 80, 211), // SWA, window (96) bounded well inside the prior context (startPos 211)