Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/SharpInference.Cuda/CudaBackend.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4225,8 +4225,8 @@ public void FlashAttentionPrefillTc(Tensor qAll, Tensor kCache, Tensor vCache, T
/// Issue #147: multi-warp / d-split tensor-core flash-attention prefill — same
/// args/semantics as <see cref="FlashAttentionPrefillTc"/> but W=4 warps cooperate
/// on each 16-query tile with the head dim split across them, so O is register-
/// resident (no shared-O rescale) and occupancy rises ~10×. Requires
/// <paramref name="headDim"/> % 64 == 0 (W·16) and ≤ 512. Argmax-stable, not bit-exact.
/// resident (no shared-O rescale) and occupancy rises ~10× (RTX 4070 Ti / Ada).
/// Requires <paramref name="headDim"/> % 64 == 0 (W·16) and ≤ 512. Argmax-stable, not bit-exact.
/// </summary>
/// <param name="kvCacheType">K/V cache element dtype (issue #179): Float32 (default),
/// BFloat16, or Q8_0. The matching templated thunk decodes each element to fp32 on
Expand Down
11 changes: 6 additions & 5 deletions src/SharpInference.Cuda/CudaTextKernels.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6021,11 +6021,12 @@ asm volatile(

// ── Multi-warp / d-split tensor-core flash-attention prefill (issue #147) ───
// Fixes the single-warp llm_flash_attn_prefill_tc occupancy limit (1 warp/block +
// 48 KB shared → ~2 warps/SM). Here a block is W warps that cooperate on ONE
// 16-query tile, splitting the head dim: warp w owns output columns [w·dW, …) with
// dW = head_dim/W, so O[16×dW] stays REGISTER-resident (16×128 = 64 regs/lane at
// d=512,W=4) instead of in shared — no per-key-tile shared-O rescale traffic, and
// the freed shared lets occupancy rise to ~16-20 warps/SM. Each warp computes a
// 48 KB shared → ~2 warps/SM, measured on RTX 4070 Ti / Ada). Here a block is W
// warps that cooperate on ONE 16-query tile, splitting the head dim: warp w owns
// output columns [w·dW, …) with dW = head_dim/W, so O[16×dW] stays REGISTER-resident
// (16×128 = 64 regs/lane at d=512,W=4) instead of in shared — no per-key-tile
// shared-O rescale traffic, and the freed shared lets occupancy rise to ~16-20
// warps/SM (RTX 4070 Ti / Ada). Each warp computes a
// PARTIAL QK^T over its d-slice; the partials are summed across warps through a
// small shared S buffer ([W×16×16] fp32), after which every warp holds the full
// reduced score tile in its C-fragment and proceeds exactly like the single-warp
Expand Down
81 changes: 60 additions & 21 deletions tests/SharpInference.Tests.ForwardPass/CudaFlashAttnTcTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ namespace SharpInference.Tests.ForwardPass;
/// looser fp16 tolerance. Same config matrix as the half2 flash test: GQA, both
/// Gemma 4 head_dims (256 SWA / 512 global), causal, windowing, partial last tile.
///
/// Issue #151 hardens three coverage gaps from the #148 review:
/// 1. <c>startPos &gt; 0</c> — the continued-prefill / chat-continuation re-prefill path
/// (Q is only the new tokens at absolute positions <c>[startPos, startPos+nTok)</c>,
/// while K/V hold the full <c>[0, startPos+nTok)</c> history).
/// 2. single partial tile (<c>nTok &lt; 16</c>, <c>gy == 1</c>) — the degenerate
/// online-softmax / masking case.
/// 3. a genuinely TC1-only head_dim (<c>%16==0 &amp;&amp; %64!=0</c>) — the #146 single-warp
/// shared-O sizing path that the #147 multi-warp kernel (needs <c>%64</c>) never reaches.
///
/// Silent no-op on hosts without CUDA, matching the other Cuda* test files.
/// </summary>
public sealed unsafe class CudaFlashAttnTcTests
Expand All @@ -24,44 +33,74 @@ public sealed unsafe class CudaFlashAttnTcTests
catch { return null; }
}

// (numHeads, numKvHeads, headDim, window, nTok). window=0 → global (full causal).
// All head_dims are multiples of 64 so the same matrix exercises tc2 (#147, W·16=64).
private static (int nh, int nkv, int hd, int win, int nTok)[] Configs() => new[]
// (numHeads, numKvHeads, headDim, window, nTok, startPos). window=0 → global (full causal).
// All head_dims are multiples of 64 so the same matrix exercises both the #146 single-warp
// (tc) and the #147 multi-warp (tc2, W·16=64) kernels.
private static (int nh, int nkv, int hd, int win, int nTok, int startPos)[] Configs() => new[]
{
(8, 2, 256, 0, 200, 0), // global, SWA head_dim
(8, 2, 512, 0, 173, 0), // global, global head_dim, partial last tile
(8, 2, 256, 64, 200, 0), // sliding window 64 < nTok
(8, 2, 512, 96, 130, 0), // sliding window 96, global head_dim
(4, 4, 128, 0, 64, 0), // MHA (no GQA), small head_dim

// #151 gap 1 — startPos > 0 (continued prefill): K/V carry prior context.
// These keep maxSeqLen = startPos+nTok (a flat cache), so they validate the
// causal/window mask + key-tile-span interaction with startPos — NOT the SWA
// ring wrap (abs_k % maxSeqLen is the identity here). Ring-wrap is covered at
// the model level by CudaForwardPassKvDtypeTests' long-prompt chunked prefill.
(8, 2, 256, 0, 64, 137), // global, prior context before the new tokens
(8, 2, 512, 96, 80, 211), // SWA, window (96) bounded well inside the prior context (startPos 211)

// #151 gap 2 — single partial tile (nTok < 16, gy == 1).
(8, 2, 256, 0, 1, 0), // single query, single key
(8, 2, 256, 0, 7, 0), // sub-tile (7 < 16), global
(8, 2, 256, 32, 7, 0), // sub-tile with a sliding window
(8, 2, 256, 0, 7, 40), // sub-tile with prior context (nTok<16 AND startPos>0)
};

// #151 gap 3 — head_dim % 16 == 0 but % 64 != 0. Only the #146 single-warp TC1 kernel
// (shared-O sizing) accepts these; FlashAttentionPrefillTc2 requires % 64 and throws,
// so these run in the TC1 path only.
private static (int nh, int nkv, int hd, int win, int nTok, int startPos)[] Tc1OnlyConfigs() => new[]
{
(8, 2, 256, 0, 200), // global, SWA head_dim
(8, 2, 512, 0, 173), // global, global head_dim, partial last tile
(8, 2, 256, 64, 200), // sliding window 64 < nTok
(8, 2, 512, 96, 130), // sliding window 96, global head_dim
(4, 4, 128, 0, 64), // MHA (no GQA), small head_dim
(8, 2, 80, 0, 130, 0), // hd=80 (5×16), global — exercises the TC1-only shared-O path
(8, 2, 48, 64, 96, 0), // hd=48 (3×16), sliding window
(4, 4, 112, 0, 33, 17), // hd=112 (7×16), MHA, startPos>0, partial last tile
};

[Fact]
public void FlashAttentionPrefillTc_MatchesScalarBatched() // #146 single-warp
{
using var gpu = TryCreate();
if (gpu is null) return;
RunParity(gpu, tc2: false, label: "TC1");
RunParity(gpu, Configs(), tc2: false, label: "TC1");
RunParity(gpu, Tc1OnlyConfigs(), tc2: false, label: "TC1-only-hd");
}

[Fact]
public void FlashAttentionPrefillTc2_MatchesScalarBatched() // #147 multi-warp/d-split
{
using var gpu = TryCreate();
if (gpu is null) return;
RunParity(gpu, tc2: true, label: "TC2");
RunParity(gpu, Configs(), tc2: true, label: "TC2");
}

private static void RunParity(CudaBackend gpu, bool tc2, string label)
private static void RunParity(
CudaBackend gpu,
(int nh, int nkv, int hd, int win, int nTok, int startPos)[] configs,
bool tc2, string label)
{
foreach (var (nh, nkv, hd, win, nTok) in Configs())
foreach (var (nh, nkv, hd, win, nTok, startPos) in configs)
{
var rng = new Random(20260606 + nh * 7 + hd * 13 + win * 17 + nTok);
var rng = new Random(20260606 + nh * 7 + hd * 13 + win * 17 + nTok + startPos * 101);
int qDim = nh * hd, kvDim = nkv * hd;
int kvLen = startPos + nTok; // K/V cache holds the full [0, startPos+nTok) history

var q = new float[(long)nTok * qDim];
for (int i = 0; i < q.Length; i++) q[i] = (float)(rng.NextDouble() * 2 - 1);
var k = new float[(long)nTok * kvDim];
var v = new float[(long)nTok * kvDim];
var k = new float[(long)kvLen * kvDim];
var v = new float[(long)kvLen * kvDim];
for (int i = 0; i < k.Length; i++) { k[i] = (float)(rng.NextDouble() * 2 - 1); v[i] = (float)(rng.NextDouble() * 2 - 1); }

var gq = gpu.Upload(q, TensorShape.D1(q.Length));
Expand All @@ -71,13 +110,13 @@ private static void RunParity(CudaBackend gpu, bool tc2, string label)
var gTc = gpu.Allocate(TensorShape.D1(q.Length));

if (win == 0)
gpu.AttentionBatched(gq, gk, gv, gRef, nh, nkv, hd, startPos: 0, maxSeqLen: nTok, nTok: nTok);
gpu.AttentionBatched(gq, gk, gv, gRef, nh, nkv, hd, startPos, maxSeqLen: kvLen, nTok: nTok);
else
gpu.AttentionSwaBatched(gq, gk, gv, gRef, nh, nkv, hd, startPos: 0, windowSize: win, maxSeqLen: nTok, nTok: nTok);
gpu.AttentionSwaBatched(gq, gk, gv, gRef, nh, nkv, hd, startPos, windowSize: win, maxSeqLen: kvLen, nTok: nTok);
if (tc2)
gpu.FlashAttentionPrefillTc2(gq, gk, gv, gTc, nh, nkv, hd, startPos: 0, windowSize: win, maxSeqLen: nTok, nTok: nTok);
gpu.FlashAttentionPrefillTc2(gq, gk, gv, gTc, nh, nkv, hd, startPos, windowSize: win, maxSeqLen: kvLen, nTok: nTok);
else
gpu.FlashAttentionPrefillTc(gq, gk, gv, gTc, nh, nkv, hd, startPos: 0, windowSize: win, maxSeqLen: nTok, nTok: nTok);
gpu.FlashAttentionPrefillTc(gq, gk, gv, gTc, nh, nkv, hd, startPos, windowSize: win, maxSeqLen: kvLen, nTok: nTok);
gpu.Synchronize();

var outRef = new float[q.Length];
Expand All @@ -103,10 +142,10 @@ private static void RunParity(CudaBackend gpu, bool tc2, string label)
if (diff > 2e-2f * rms) mismatches++;
}
Console.WriteLine(
$"Flash{label} nh={nh} nkv={nkv} hd={hd} win={win} nTok={nTok}: maxAbs={maxAbs:E2} rms={rms:E2} mismatches={mismatches}/{outRef.Length}");
$"Flash{label} nh={nh} nkv={nkv} hd={hd} win={win} nTok={nTok} startPos={startPos}: maxAbs={maxAbs:E2} rms={rms:E2} mismatches={mismatches}/{outRef.Length}");
// Allow a tiny tail of outliers from fp16 P·V accumulation; the bulk must match.
Assert.True(mismatches <= outRef.Length / 200 + 1,
$"{label} flash attention diverged from scalar reference: {mismatches}/{outRef.Length} beyond 2e-2·rms ({rms:E3}), maxAbs={maxAbs:E3} (nh={nh} hd={hd} win={win} nTok={nTok}).");
$"{label} flash attention diverged from scalar reference: {mismatches}/{outRef.Length} beyond 2e-2·rms ({rms:E3}), maxAbs={maxAbs:E3} (nh={nh} hd={hd} win={win} nTok={nTok} startPos={startPos}).");
}
}
}