diff --git a/src/SharpInference.Cuda/CudaBackend.cs b/src/SharpInference.Cuda/CudaBackend.cs index 52b9b58..2c615e2 100644 --- a/src/SharpInference.Cuda/CudaBackend.cs +++ b/src/SharpInference.Cuda/CudaBackend.cs @@ -4225,8 +4225,8 @@ public void FlashAttentionPrefillTc(Tensor qAll, Tensor kCache, Tensor vCache, T /// Issue #147: multi-warp / d-split tensor-core flash-attention prefill — same /// args/semantics as but W=4 warps cooperate /// on each 16-query tile with the head dim split across them, so O is register- - /// resident (no shared-O rescale) and occupancy rises ~10×. Requires - /// % 64 == 0 (W·16) and ≤ 512. Argmax-stable, not bit-exact. + /// resident (no shared-O rescale) and occupancy rises ~10× (RTX 4070 Ti / Ada). + /// Requires % 64 == 0 (W·16) and ≤ 512. Argmax-stable, not bit-exact. /// /// K/V cache element dtype (issue #179): Float32 (default), /// BFloat16, or Q8_0. The matching templated thunk decodes each element to fp32 on diff --git a/src/SharpInference.Cuda/CudaTextKernels.cs b/src/SharpInference.Cuda/CudaTextKernels.cs index f83c2e3..bd44ea4 100644 --- a/src/SharpInference.Cuda/CudaTextKernels.cs +++ b/src/SharpInference.Cuda/CudaTextKernels.cs @@ -6021,11 +6021,12 @@ asm volatile( // ── Multi-warp / d-split tensor-core flash-attention prefill (issue #147) ─── // Fixes the single-warp llm_flash_attn_prefill_tc occupancy limit (1 warp/block + -// 48 KB shared → ~2 warps/SM). Here a block is W warps that cooperate on ONE -// 16-query tile, splitting the head dim: warp w owns output columns [w·dW, …) with -// dW = head_dim/W, so O[16×dW] stays REGISTER-resident (16×128 = 64 regs/lane at -// d=512,W=4) instead of in shared — no per-key-tile shared-O rescale traffic, and -// the freed shared lets occupancy rise to ~16-20 warps/SM. Each warp computes a +// 48 KB shared → ~2 warps/SM, measured on RTX 4070 Ti / Ada). Here a block is W +// warps that cooperate on ONE 16-query tile, splitting the head dim: warp w owns +// output columns [w·dW, …) with dW = head_dim/W, so O[16×dW] stays REGISTER-resident +// (16×128 = 64 regs/lane at d=512,W=4) instead of in shared — no per-key-tile +// shared-O rescale traffic, and the freed shared lets occupancy rise to ~16-20 +// warps/SM (RTX 4070 Ti / Ada). Each warp computes a // PARTIAL QK^T over its d-slice; the partials are summed across warps through a // small shared S buffer ([W×16×16] fp32), after which every warp holds the full // reduced score tile in its C-fragment and proceeds exactly like the single-warp diff --git a/tests/SharpInference.Tests.ForwardPass/CudaFlashAttnTcTests.cs b/tests/SharpInference.Tests.ForwardPass/CudaFlashAttnTcTests.cs index 0be568f..29c24da 100644 --- a/tests/SharpInference.Tests.ForwardPass/CudaFlashAttnTcTests.cs +++ b/tests/SharpInference.Tests.ForwardPass/CudaFlashAttnTcTests.cs @@ -13,6 +13,15 @@ namespace SharpInference.Tests.ForwardPass; /// looser fp16 tolerance. Same config matrix as the half2 flash test: GQA, both /// Gemma 4 head_dims (256 SWA / 512 global), causal, windowing, partial last tile. /// +/// Issue #151 hardens three coverage gaps from the #148 review: +/// 1. startPos > 0 — the continued-prefill / chat-continuation re-prefill path +/// (Q is only the new tokens at absolute positions [startPos, startPos+nTok), +/// while K/V hold the full [0, startPos+nTok) history). +/// 2. single partial tile (nTok < 16, gy == 1) — the degenerate +/// online-softmax / masking case. +/// 3. a genuinely TC1-only head_dim (%16==0 && %64!=0) — the #146 single-warp +/// shared-O sizing path that the #147 multi-warp kernel (needs %64) never reaches. +/// /// Silent no-op on hosts without CUDA, matching the other Cuda* test files. /// public sealed unsafe class CudaFlashAttnTcTests @@ -24,15 +33,40 @@ public sealed unsafe class CudaFlashAttnTcTests catch { return null; } } - // (numHeads, numKvHeads, headDim, window, nTok). window=0 → global (full causal). - // All head_dims are multiples of 64 so the same matrix exercises tc2 (#147, W·16=64). - private static (int nh, int nkv, int hd, int win, int nTok)[] Configs() => new[] + // (numHeads, numKvHeads, headDim, window, nTok, startPos). window=0 → global (full causal). + // All head_dims are multiples of 64 so the same matrix exercises both the #146 single-warp + // (tc) and the #147 multi-warp (tc2, W·16=64) kernels. + private static (int nh, int nkv, int hd, int win, int nTok, int startPos)[] Configs() => new[] + { + (8, 2, 256, 0, 200, 0), // global, SWA head_dim + (8, 2, 512, 0, 173, 0), // global, global head_dim, partial last tile + (8, 2, 256, 64, 200, 0), // sliding window 64 < nTok + (8, 2, 512, 96, 130, 0), // sliding window 96, global head_dim + (4, 4, 128, 0, 64, 0), // MHA (no GQA), small head_dim + + // #151 gap 1 — startPos > 0 (continued prefill): K/V carry prior context. + // These keep maxSeqLen = startPos+nTok (a flat cache), so they validate the + // causal/window mask + key-tile-span interaction with startPos — NOT the SWA + // ring wrap (abs_k % maxSeqLen is the identity here). Ring-wrap is covered at + // the model level by CudaForwardPassKvDtypeTests' long-prompt chunked prefill. + (8, 2, 256, 0, 64, 137), // global, prior context before the new tokens + (8, 2, 512, 96, 80, 211), // SWA, window (96) bounded well inside the prior context (startPos 211) + + // #151 gap 2 — single partial tile (nTok < 16, gy == 1). + (8, 2, 256, 0, 1, 0), // single query, single key + (8, 2, 256, 0, 7, 0), // sub-tile (7 < 16), global + (8, 2, 256, 32, 7, 0), // sub-tile with a sliding window + (8, 2, 256, 0, 7, 40), // sub-tile with prior context (nTok<16 AND startPos>0) + }; + + // #151 gap 3 — head_dim % 16 == 0 but % 64 != 0. Only the #146 single-warp TC1 kernel + // (shared-O sizing) accepts these; FlashAttentionPrefillTc2 requires % 64 and throws, + // so these run in the TC1 path only. + private static (int nh, int nkv, int hd, int win, int nTok, int startPos)[] Tc1OnlyConfigs() => new[] { - (8, 2, 256, 0, 200), // global, SWA head_dim - (8, 2, 512, 0, 173), // global, global head_dim, partial last tile - (8, 2, 256, 64, 200), // sliding window 64 < nTok - (8, 2, 512, 96, 130), // sliding window 96, global head_dim - (4, 4, 128, 0, 64), // MHA (no GQA), small head_dim + (8, 2, 80, 0, 130, 0), // hd=80 (5×16), global — exercises the TC1-only shared-O path + (8, 2, 48, 64, 96, 0), // hd=48 (3×16), sliding window + (4, 4, 112, 0, 33, 17), // hd=112 (7×16), MHA, startPos>0, partial last tile }; [Fact] @@ -40,7 +74,8 @@ public void FlashAttentionPrefillTc_MatchesScalarBatched() // #146 single-warp { using var gpu = TryCreate(); if (gpu is null) return; - RunParity(gpu, tc2: false, label: "TC1"); + RunParity(gpu, Configs(), tc2: false, label: "TC1"); + RunParity(gpu, Tc1OnlyConfigs(), tc2: false, label: "TC1-only-hd"); } [Fact] @@ -48,20 +83,24 @@ public void FlashAttentionPrefillTc2_MatchesScalarBatched() // #147 multi-warp/ { using var gpu = TryCreate(); if (gpu is null) return; - RunParity(gpu, tc2: true, label: "TC2"); + RunParity(gpu, Configs(), tc2: true, label: "TC2"); } - private static void RunParity(CudaBackend gpu, bool tc2, string label) + private static void RunParity( + CudaBackend gpu, + (int nh, int nkv, int hd, int win, int nTok, int startPos)[] configs, + bool tc2, string label) { - foreach (var (nh, nkv, hd, win, nTok) in Configs()) + foreach (var (nh, nkv, hd, win, nTok, startPos) in configs) { - var rng = new Random(20260606 + nh * 7 + hd * 13 + win * 17 + nTok); + var rng = new Random(20260606 + nh * 7 + hd * 13 + win * 17 + nTok + startPos * 101); int qDim = nh * hd, kvDim = nkv * hd; + int kvLen = startPos + nTok; // K/V cache holds the full [0, startPos+nTok) history var q = new float[(long)nTok * qDim]; for (int i = 0; i < q.Length; i++) q[i] = (float)(rng.NextDouble() * 2 - 1); - var k = new float[(long)nTok * kvDim]; - var v = new float[(long)nTok * kvDim]; + var k = new float[(long)kvLen * kvDim]; + var v = new float[(long)kvLen * kvDim]; for (int i = 0; i < k.Length; i++) { k[i] = (float)(rng.NextDouble() * 2 - 1); v[i] = (float)(rng.NextDouble() * 2 - 1); } var gq = gpu.Upload(q, TensorShape.D1(q.Length)); @@ -71,13 +110,13 @@ private static void RunParity(CudaBackend gpu, bool tc2, string label) var gTc = gpu.Allocate(TensorShape.D1(q.Length)); if (win == 0) - gpu.AttentionBatched(gq, gk, gv, gRef, nh, nkv, hd, startPos: 0, maxSeqLen: nTok, nTok: nTok); + gpu.AttentionBatched(gq, gk, gv, gRef, nh, nkv, hd, startPos, maxSeqLen: kvLen, nTok: nTok); else - gpu.AttentionSwaBatched(gq, gk, gv, gRef, nh, nkv, hd, startPos: 0, windowSize: win, maxSeqLen: nTok, nTok: nTok); + gpu.AttentionSwaBatched(gq, gk, gv, gRef, nh, nkv, hd, startPos, windowSize: win, maxSeqLen: kvLen, nTok: nTok); if (tc2) - gpu.FlashAttentionPrefillTc2(gq, gk, gv, gTc, nh, nkv, hd, startPos: 0, windowSize: win, maxSeqLen: nTok, nTok: nTok); + gpu.FlashAttentionPrefillTc2(gq, gk, gv, gTc, nh, nkv, hd, startPos, windowSize: win, maxSeqLen: kvLen, nTok: nTok); else - gpu.FlashAttentionPrefillTc(gq, gk, gv, gTc, nh, nkv, hd, startPos: 0, windowSize: win, maxSeqLen: nTok, nTok: nTok); + gpu.FlashAttentionPrefillTc(gq, gk, gv, gTc, nh, nkv, hd, startPos, windowSize: win, maxSeqLen: kvLen, nTok: nTok); gpu.Synchronize(); var outRef = new float[q.Length]; @@ -103,10 +142,10 @@ private static void RunParity(CudaBackend gpu, bool tc2, string label) if (diff > 2e-2f * rms) mismatches++; } Console.WriteLine( - $"Flash{label} nh={nh} nkv={nkv} hd={hd} win={win} nTok={nTok}: maxAbs={maxAbs:E2} rms={rms:E2} mismatches={mismatches}/{outRef.Length}"); + $"Flash{label} nh={nh} nkv={nkv} hd={hd} win={win} nTok={nTok} startPos={startPos}: maxAbs={maxAbs:E2} rms={rms:E2} mismatches={mismatches}/{outRef.Length}"); // Allow a tiny tail of outliers from fp16 P·V accumulation; the bulk must match. Assert.True(mismatches <= outRef.Length / 200 + 1, - $"{label} flash attention diverged from scalar reference: {mismatches}/{outRef.Length} beyond 2e-2·rms ({rms:E3}), maxAbs={maxAbs:E3} (nh={nh} hd={hd} win={win} nTok={nTok})."); + $"{label} flash attention diverged from scalar reference: {mismatches}/{outRef.Length} beyond 2e-2·rms ({rms:E3}), maxAbs={maxAbs:E3} (nh={nh} hd={hd} win={win} nTok={nTok} startPos={startPos})."); } } }