diff --git a/src/SharpInference.Cuda/CudaBackend.cs b/src/SharpInference.Cuda/CudaBackend.cs
index 52b9b58..2c615e2 100644
--- a/src/SharpInference.Cuda/CudaBackend.cs
+++ b/src/SharpInference.Cuda/CudaBackend.cs
@@ -4225,8 +4225,8 @@ public void FlashAttentionPrefillTc(Tensor qAll, Tensor kCache, Tensor vCache, T
/// Issue #147: multi-warp / d-split tensor-core flash-attention prefill — same
/// args/semantics as but W=4 warps cooperate
/// on each 16-query tile with the head dim split across them, so O is register-
- /// resident (no shared-O rescale) and occupancy rises ~10×. Requires
- /// % 64 == 0 (W·16) and ≤ 512. Argmax-stable, not bit-exact.
+ /// resident (no shared-O rescale) and occupancy rises ~10× (RTX 4070 Ti / Ada).
+ /// Requires % 64 == 0 (W·16) and ≤ 512. Argmax-stable, not bit-exact.
///
/// K/V cache element dtype (issue #179): Float32 (default),
/// BFloat16, or Q8_0. The matching templated thunk decodes each element to fp32 on
diff --git a/src/SharpInference.Cuda/CudaTextKernels.cs b/src/SharpInference.Cuda/CudaTextKernels.cs
index f83c2e3..bd44ea4 100644
--- a/src/SharpInference.Cuda/CudaTextKernels.cs
+++ b/src/SharpInference.Cuda/CudaTextKernels.cs
@@ -6021,11 +6021,12 @@ asm volatile(
// ── Multi-warp / d-split tensor-core flash-attention prefill (issue #147) ───
// Fixes the single-warp llm_flash_attn_prefill_tc occupancy limit (1 warp/block +
-// 48 KB shared → ~2 warps/SM). Here a block is W warps that cooperate on ONE
-// 16-query tile, splitting the head dim: warp w owns output columns [w·dW, …) with
-// dW = head_dim/W, so O[16×dW] stays REGISTER-resident (16×128 = 64 regs/lane at
-// d=512,W=4) instead of in shared — no per-key-tile shared-O rescale traffic, and
-// the freed shared lets occupancy rise to ~16-20 warps/SM. Each warp computes a
+// 48 KB shared → ~2 warps/SM, measured on RTX 4070 Ti / Ada). Here a block is W
+// warps that cooperate on ONE 16-query tile, splitting the head dim: warp w owns
+// output columns [w·dW, …) with dW = head_dim/W, so O[16×dW] stays REGISTER-resident
+// (16×128 = 64 regs/lane at d=512,W=4) instead of in shared — no per-key-tile
+// shared-O rescale traffic, and the freed shared lets occupancy rise to ~16-20
+// warps/SM (RTX 4070 Ti / Ada). Each warp computes a
// PARTIAL QK^T over its d-slice; the partials are summed across warps through a
// small shared S buffer ([W×16×16] fp32), after which every warp holds the full
// reduced score tile in its C-fragment and proceeds exactly like the single-warp
diff --git a/tests/SharpInference.Tests.ForwardPass/CudaFlashAttnTcTests.cs b/tests/SharpInference.Tests.ForwardPass/CudaFlashAttnTcTests.cs
index 0be568f..29c24da 100644
--- a/tests/SharpInference.Tests.ForwardPass/CudaFlashAttnTcTests.cs
+++ b/tests/SharpInference.Tests.ForwardPass/CudaFlashAttnTcTests.cs
@@ -13,6 +13,15 @@ namespace SharpInference.Tests.ForwardPass;
/// looser fp16 tolerance. Same config matrix as the half2 flash test: GQA, both
/// Gemma 4 head_dims (256 SWA / 512 global), causal, windowing, partial last tile.
///
+/// Issue #151 hardens three coverage gaps from the #148 review:
+/// 1. startPos > 0 — the continued-prefill / chat-continuation re-prefill path
+/// (Q is only the new tokens at absolute positions [startPos, startPos+nTok),
+/// while K/V hold the full [0, startPos+nTok) history).
+/// 2. single partial tile (nTok < 16, gy == 1) — the degenerate
+/// online-softmax / masking case.
+/// 3. a genuinely TC1-only head_dim (%16==0 && %64!=0) — the #146 single-warp
+/// shared-O sizing path that the #147 multi-warp kernel (needs %64) never reaches.
+///
/// Silent no-op on hosts without CUDA, matching the other Cuda* test files.
///
public sealed unsafe class CudaFlashAttnTcTests
@@ -24,15 +33,40 @@ public sealed unsafe class CudaFlashAttnTcTests
catch { return null; }
}
- // (numHeads, numKvHeads, headDim, window, nTok). window=0 → global (full causal).
- // All head_dims are multiples of 64 so the same matrix exercises tc2 (#147, W·16=64).
- private static (int nh, int nkv, int hd, int win, int nTok)[] Configs() => new[]
+ // (numHeads, numKvHeads, headDim, window, nTok, startPos). window=0 → global (full causal).
+ // All head_dims are multiples of 64 so the same matrix exercises both the #146 single-warp
+ // (tc) and the #147 multi-warp (tc2, W·16=64) kernels.
+ private static (int nh, int nkv, int hd, int win, int nTok, int startPos)[] Configs() => new[]
+ {
+ (8, 2, 256, 0, 200, 0), // global, SWA head_dim
+ (8, 2, 512, 0, 173, 0), // global, global head_dim, partial last tile
+ (8, 2, 256, 64, 200, 0), // sliding window 64 < nTok
+ (8, 2, 512, 96, 130, 0), // sliding window 96, global head_dim
+ (4, 4, 128, 0, 64, 0), // MHA (no GQA), small head_dim
+
+ // #151 gap 1 — startPos > 0 (continued prefill): K/V carry prior context.
+ // These keep maxSeqLen = startPos+nTok (a flat cache), so they validate the
+ // causal/window mask + key-tile-span interaction with startPos — NOT the SWA
+ // ring wrap (abs_k % maxSeqLen is the identity here). Ring-wrap is covered at
+ // the model level by CudaForwardPassKvDtypeTests' long-prompt chunked prefill.
+ (8, 2, 256, 0, 64, 137), // global, prior context before the new tokens
+ (8, 2, 512, 96, 80, 211), // SWA, window (96) bounded well inside the prior context (startPos 211)
+
+ // #151 gap 2 — single partial tile (nTok < 16, gy == 1).
+ (8, 2, 256, 0, 1, 0), // single query, single key
+ (8, 2, 256, 0, 7, 0), // sub-tile (7 < 16), global
+ (8, 2, 256, 32, 7, 0), // sub-tile with a sliding window
+ (8, 2, 256, 0, 7, 40), // sub-tile with prior context (nTok<16 AND startPos>0)
+ };
+
+ // #151 gap 3 — head_dim % 16 == 0 but % 64 != 0. Only the #146 single-warp TC1 kernel
+ // (shared-O sizing) accepts these; FlashAttentionPrefillTc2 requires % 64 and throws,
+ // so these run in the TC1 path only.
+ private static (int nh, int nkv, int hd, int win, int nTok, int startPos)[] Tc1OnlyConfigs() => new[]
{
- (8, 2, 256, 0, 200), // global, SWA head_dim
- (8, 2, 512, 0, 173), // global, global head_dim, partial last tile
- (8, 2, 256, 64, 200), // sliding window 64 < nTok
- (8, 2, 512, 96, 130), // sliding window 96, global head_dim
- (4, 4, 128, 0, 64), // MHA (no GQA), small head_dim
+ (8, 2, 80, 0, 130, 0), // hd=80 (5×16), global — exercises the TC1-only shared-O path
+ (8, 2, 48, 64, 96, 0), // hd=48 (3×16), sliding window
+ (4, 4, 112, 0, 33, 17), // hd=112 (7×16), MHA, startPos>0, partial last tile
};
[Fact]
@@ -40,7 +74,8 @@ public void FlashAttentionPrefillTc_MatchesScalarBatched() // #146 single-warp
{
using var gpu = TryCreate();
if (gpu is null) return;
- RunParity(gpu, tc2: false, label: "TC1");
+ RunParity(gpu, Configs(), tc2: false, label: "TC1");
+ RunParity(gpu, Tc1OnlyConfigs(), tc2: false, label: "TC1-only-hd");
}
[Fact]
@@ -48,20 +83,24 @@ public void FlashAttentionPrefillTc2_MatchesScalarBatched() // #147 multi-warp/
{
using var gpu = TryCreate();
if (gpu is null) return;
- RunParity(gpu, tc2: true, label: "TC2");
+ RunParity(gpu, Configs(), tc2: true, label: "TC2");
}
- private static void RunParity(CudaBackend gpu, bool tc2, string label)
+ private static void RunParity(
+ CudaBackend gpu,
+ (int nh, int nkv, int hd, int win, int nTok, int startPos)[] configs,
+ bool tc2, string label)
{
- foreach (var (nh, nkv, hd, win, nTok) in Configs())
+ foreach (var (nh, nkv, hd, win, nTok, startPos) in configs)
{
- var rng = new Random(20260606 + nh * 7 + hd * 13 + win * 17 + nTok);
+ var rng = new Random(20260606 + nh * 7 + hd * 13 + win * 17 + nTok + startPos * 101);
int qDim = nh * hd, kvDim = nkv * hd;
+ int kvLen = startPos + nTok; // K/V cache holds the full [0, startPos+nTok) history
var q = new float[(long)nTok * qDim];
for (int i = 0; i < q.Length; i++) q[i] = (float)(rng.NextDouble() * 2 - 1);
- var k = new float[(long)nTok * kvDim];
- var v = new float[(long)nTok * kvDim];
+ var k = new float[(long)kvLen * kvDim];
+ var v = new float[(long)kvLen * kvDim];
for (int i = 0; i < k.Length; i++) { k[i] = (float)(rng.NextDouble() * 2 - 1); v[i] = (float)(rng.NextDouble() * 2 - 1); }
var gq = gpu.Upload(q, TensorShape.D1(q.Length));
@@ -71,13 +110,13 @@ private static void RunParity(CudaBackend gpu, bool tc2, string label)
var gTc = gpu.Allocate(TensorShape.D1(q.Length));
if (win == 0)
- gpu.AttentionBatched(gq, gk, gv, gRef, nh, nkv, hd, startPos: 0, maxSeqLen: nTok, nTok: nTok);
+ gpu.AttentionBatched(gq, gk, gv, gRef, nh, nkv, hd, startPos, maxSeqLen: kvLen, nTok: nTok);
else
- gpu.AttentionSwaBatched(gq, gk, gv, gRef, nh, nkv, hd, startPos: 0, windowSize: win, maxSeqLen: nTok, nTok: nTok);
+ gpu.AttentionSwaBatched(gq, gk, gv, gRef, nh, nkv, hd, startPos, windowSize: win, maxSeqLen: kvLen, nTok: nTok);
if (tc2)
- gpu.FlashAttentionPrefillTc2(gq, gk, gv, gTc, nh, nkv, hd, startPos: 0, windowSize: win, maxSeqLen: nTok, nTok: nTok);
+ gpu.FlashAttentionPrefillTc2(gq, gk, gv, gTc, nh, nkv, hd, startPos, windowSize: win, maxSeqLen: kvLen, nTok: nTok);
else
- gpu.FlashAttentionPrefillTc(gq, gk, gv, gTc, nh, nkv, hd, startPos: 0, windowSize: win, maxSeqLen: nTok, nTok: nTok);
+ gpu.FlashAttentionPrefillTc(gq, gk, gv, gTc, nh, nkv, hd, startPos, windowSize: win, maxSeqLen: kvLen, nTok: nTok);
gpu.Synchronize();
var outRef = new float[q.Length];
@@ -103,10 +142,10 @@ private static void RunParity(CudaBackend gpu, bool tc2, string label)
if (diff > 2e-2f * rms) mismatches++;
}
Console.WriteLine(
- $"Flash{label} nh={nh} nkv={nkv} hd={hd} win={win} nTok={nTok}: maxAbs={maxAbs:E2} rms={rms:E2} mismatches={mismatches}/{outRef.Length}");
+ $"Flash{label} nh={nh} nkv={nkv} hd={hd} win={win} nTok={nTok} startPos={startPos}: maxAbs={maxAbs:E2} rms={rms:E2} mismatches={mismatches}/{outRef.Length}");
// Allow a tiny tail of outliers from fp16 P·V accumulation; the bulk must match.
Assert.True(mismatches <= outRef.Length / 200 + 1,
- $"{label} flash attention diverged from scalar reference: {mismatches}/{outRef.Length} beyond 2e-2·rms ({rms:E3}), maxAbs={maxAbs:E3} (nh={nh} hd={hd} win={win} nTok={nTok}).");
+ $"{label} flash attention diverged from scalar reference: {mismatches}/{outRef.Length} beyond 2e-2·rms ({rms:E3}), maxAbs={maxAbs:E3} (nh={nh} hd={hd} win={win} nTok={nTok} startPos={startPos}).");
}
}
}