diff --git a/README.md b/README.md
index b5b37d2..d212791 100644
--- a/README.md
+++ b/README.md
@@ -51,7 +51,7 @@ coherent (`scripts/bench-all.ps1`); top-1 parity vs llama.cpp b8585 verified on
| Qwen3.6-35B-A3B-MTP (GDN+MoE) | (same) | 22 GB | **CUDA** `-g -1 --no-thinking` (hybrid) | **65.0** | **22.9** | Requires `SHARPI_CPU_MOE=1`: 30 GDN + 10 attn + shared expert on GPU, routed experts CPU mmap. 100% acceptance. Fused GDN scan + batched SDPA (#114-B/#118), bit-identical, grows with ctx |
| Carnice (Qwen3.6-35B-A3B-MTP finetune) | [mudler](https://huggingface.co/mudler/Carnice-Qwen3.6-MoE-35B-A3B-APEX-MTP-GGUF) | 17 GB | **CUDA** `-g -1 --no-thinking` (hybrid) | **43.6** | **25.0** | agentic finetune of 35B-A3B-MTP; 77% acceptance (`bench-carnice.ps1` — the default prompt 1-token-EOSes on this terser tune). APEX mixed-precision (Q3_K + Q8_0 experts); Q8_KS per-32 int dots auto-enable at load (#99/#101/#107), +4.6% decode at ~4× tighter parity vs plain Q8_K (`SHARPI_Q3K_Q8K=0`/`SHARPI_Q8_0_Q8K=0` to disable). Fused GDN scan + wave SDPA (#114-B/#118) bit-identical past 4096 |
| Gemma 4 E4B-it Q8 | [unsloth](https://huggingface.co/unsloth/gemma-4-E4B-it-GGUF) | 8 GB | CPU | 4.9 | 5.0 | dense 42-layer gemma4: per-layer head_dim (256 SWA / 512 global), dual-RoPE, KV-share tail (18 layers), 5:1 SWA:global, logit softcap 30, PLE-256 injection (~4.2 GB mmap-resident) |
-| Gemma 4 E4B-it Q8 | (same) | 8 GB | **CUDA** `-g -1 -c 2048` | **3698** | **59** | all 42 layers fit at `-c 2048`. KV-share alias + SWA/global split per layer; PLE projections (~215 MB) upload at construction. **Prefill (#141):** int8 **tensor-core MMQ** matmul (`mma.m16n8k32.s8`, each Q8_0 weight read once as int8 — beats the dequant→fp16→cuBLAS GEMM, drops its fp16 HBM temp; `SHARPI_PREFILL_MMQ=0` reverts) + a **tensor-core flash-attention** prefill (#146/#147): both QK^T and P·V on the mma cores (`mma.m16n8k16.f16`), multi-warp **d-split** so the O tile stays register-resident — replaces the scalar O(n²) per-query attention (which re-streamed each query's K/V window up to ~512×) and is **+27% at ~1K / +40% at 1.8K** over the earlier half2 flash kernel (`SHARPI_PREFILL_FLASH_TC=0` reverts to half2, `=…_FLASH=0` to scalar) + a **SoA Q8_0 weight repack** (#149): all Q8_0 readers (MMQ, dp4a, fp32 matvec, GEMM-N, dequant) read the quants 16-byte-aligned with the fp16 scales split out, killing the `qs` 2-byte-misalignment funnelshift tax — **+10-12% prefill, bit-identical**; `SHARPI_MMQ_SOA=0` reverts) + a batched Q8_0 embedding lookup. **109→3698 at ~1K ctx, →4240 at 1.8K** — profiling showed *attention*, then the matmul inner-loop efficiency, were the dominant prefill costs at realistic prompt lengths. **Decode (#142):** dp4a/Q8_1 int8 matvec (`SHARPI_Q80_DP4A=0` to bisect) + CUDA-graph capture/replay default-on (`SHARPI_CUDA_GRAPH=0` to bisect). All prefill/decode fast paths are argmax-stable vs the fp32 path, not bit-exact (the SoA repack is bit-identical). Remaining gap to llama.cpp (~8475 prefill / ~78 decode): cp.async-pipelined MMQ on the SoA layout + decode matvec work |
+| Gemma 4 E4B-it Q8 | (same) | 8 GB | **CUDA** `-g -1 -c 2048` | **3698** | **59** | all 42 layers fit at `-c 2048`; KV-share alias + per-layer SWA/global split; PLE projections (~215 MB) upload at construction. **Prefill (#141/#146/#147/#149):** int8 tensor-core MMQ (`mma.m16n8k32.s8`, weight read once as int8) + tensor-core flash attention (QK^T and P·V on mma cores, register-resident O via d-split) + a SoA Q8_0 weight repack (16-byte-aligned quants) + batched embedding lookup — **109 → 3698 t/s @1K, 4240 @1.8K**. **>4096 prompts (#162/#164):** a real SWA KV ring (cache sized window + one chunk span, indexed `pos % size`) lets long prompts take the chunked batched-flash path instead of the ~8× slower per-token fallback (**2.72× — 47.7 → 129.6 t/s on a 5.3K-token prompt**, measured at a context that admits it, not the `-c 2048` of this row), and fixes long-context correctness past the 512 window (the cache was previously window-sized but indexed absolutely → out of bounds). **Decode (#142):** dp4a/Q8_1 int8 matvec + CUDA-graph replay. Fast paths argmax-stable vs fp32 (SoA repack + ring bit-identical below the window). Bisect env: `SHARPI_PREFILL_MMQ` / `_FLASH_TC` / `SHARPI_MMQ_SOA` / `SHARPI_Q80_DP4A` / `SHARPI_CUDA_GRAPH` / `SHARPI_BATCHED_PREFILL` (`=0`). Remaining gap to llama.cpp (~8475 prefill / ~78 decode): cp.async-pipelined MMQ + decode matvec |
| Gemma 4 E4B-it Q8 | (same) | 8 GB | **CUDA** `-g 22 -c 2048` (hybrid) | 6.6 | 6.8 | 22 GPU + 20 CPU layers. `-g ≤ 22` required so the CPU shared-KV tail can read its own-KV source layers; CPU dense-FFN dominates decode (bandwidth-bound). `SHARPI_CUDA_PROFILE=1` for per-phase breakdown |
_Numbers re-measured across every on-disk row at ~1K ctx so the prefill column is comparable; per-issue
diff --git a/src/SharpInference.Cuda/CudaTextKernels.cs b/src/SharpInference.Cuda/CudaTextKernels.cs
index 4421642..abd884b 100644
--- a/src/SharpInference.Cuda/CudaTextKernels.cs
+++ b/src/SharpInference.Cuda/CudaTextKernels.cs
@@ -599,7 +599,10 @@ __device__ __forceinline__ unsigned int sharpi_uint_at(const unsigned int* __res
{
int i = (int)(blockIdx.x * blockDim.x + threadIdx.x);
if (i >= kv_dim) return;
- long offset = (long)position * (long)kv_dim + (long)i;
+ // Ring slot: `position % max_seq_len`. `max_seq_len` is the allocated cache size,
+ // so for a full-context (dense / global) cache `position < max_seq_len` makes this
+ // the identity; for a window-sized SWA ring it wraps the write into the ring.
+ long offset = (long)(position % max_seq_len) * (long)kv_dim + (long)i;
k_cache[offset] = k_in[i];
v_cache[offset] = v_in[i];
}
@@ -618,7 +621,11 @@ __device__ __forceinline__ unsigned int sharpi_uint_at(const unsigned int* __res
{
int i = (int)(blockIdx.x * blockDim.x + threadIdx.x);
if (i >= kv_dim) return;
- long offset = (long)position * (long)kv_dim + (long)i;
+ // Ring slot `position % max_seq_len` (identity for a full-context cache; wraps a
+ // window-sized ring). Matches the f32 llm_kv_append so the write/read indexing stays
+ // uniform if a windowed model ever uses the bf16 KV cache (today only the full-context
+ // GDN-hybrid path does, where position < max_seq_len makes this the identity).
+ long offset = (long)(position % max_seq_len) * (long)kv_dim + (long)i;
k_cache[offset] = (unsigned short)sharpi_fp32_to_bf16(k_in[i]);
v_cache[offset] = (unsigned short)sharpi_fp32_to_bf16(v_in[i]);
}
@@ -4066,7 +4073,9 @@ __device__ __forceinline__ void sharpi_q4k_scale_min(
for (int t = (int)tid; t < eff_seq; t += 256) {
int abs_t = t + window_start;
float dot = 0.f;
- long k_off = (long)abs_t * (long)kv_dim + (long)kv_head * (long)head_dim;
+ // Ring slot `abs_t % max_seq_len` (max_seq_len = allocated cache size): identity
+ // for a full cache, wraps a window-sized SWA ring. abs_t itself stays logical.
+ long k_off = (long)(abs_t % max_seq_len) * (long)kv_dim + (long)kv_head * (long)head_dim;
for (int d = 0; d < head_dim; d++)
dot += q[q_off + d] * k_cache[k_off + d];
float score = dot * scale;
@@ -4127,7 +4136,7 @@ __device__ __forceinline__ void sharpi_q4k_scale_min(
for (int t = 0; t < eff_seq; t++) {
int abs_t = t + window_start;
float weight = use_shared ? shared_scores[t] : head_scratch[t];
- long v_off = (long)abs_t * (long)kv_dim + (long)kv_head * (long)head_dim;
+ long v_off = (long)(abs_t % max_seq_len) * (long)kv_dim + (long)kv_head * (long)head_dim;
acc += weight * v_cache[v_off + d];
}
out[out_off + d] = acc;
@@ -4246,8 +4255,10 @@ __device__ __forceinline__ float fatc_mask(float s, int qpos, int abs_k, int win
for (int idx = lane; idx < FATC_KT * head_dim; idx += 32) {
int kk = idx / head_dim, d = idx - kk * head_dim;
int abs_k = kt0 + kk;
+ // Cache read at ring slot `abs_k % max_seq_len` (identity for a full cache,
+ // wraps a window-sized SWA ring); abs_k stays logical for the causal bound.
float kv = (abs_k < key_end)
- ? k_cache[(long)abs_k * kv_dim + (long)kv_head * head_dim + d] : 0.f;
+ ? k_cache[(long)(abs_k % max_seq_len) * kv_dim + (long)kv_head * head_dim + d] : 0.f;
sKV[idx] = (unsigned short)sharpi_fp32_to_fp16(kv);
}
__syncthreads();
@@ -4345,7 +4356,7 @@ asm volatile(
int kk = idx / head_dim, d = idx - kk * head_dim;
int abs_k = kt0 + kk;
float vv = (abs_k < key_end)
- ? v_cache[(long)abs_k * kv_dim + (long)kv_head * head_dim + d] : 0.f;
+ ? v_cache[(long)(abs_k % max_seq_len) * kv_dim + (long)kv_head * head_dim + d] : 0.f;
sKV[idx] = (unsigned short)sharpi_fp32_to_fp16(vv);
}
__syncthreads();
@@ -4454,8 +4465,10 @@ asm volatile(
for (int idx = tid; idx < FATC2_KT * head_dim; idx += FATC2_W * 32) {
int kk = idx / head_dim, d = idx - kk * head_dim;
int abs_k = kt0 + kk;
+ // Cache read at ring slot `abs_k % max_seq_len` (identity for a full cache,
+ // wraps a window-sized SWA ring); abs_k stays logical for the causal bound.
float kv = (abs_k < key_end)
- ? k_cache[(long)abs_k * kv_dim + (long)kv_head * head_dim + d] : 0.f;
+ ? k_cache[(long)(abs_k % max_seq_len) * kv_dim + (long)kv_head * head_dim + d] : 0.f;
sKV[idx] = (unsigned short)sharpi_fp32_to_fp16(kv);
}
__syncthreads();
@@ -4566,7 +4579,7 @@ asm volatile(
int kk = idx / head_dim, d = idx - kk * head_dim;
int abs_k = kt0 + kk;
float vv = (abs_k < key_end)
- ? v_cache[(long)abs_k * kv_dim + (long)kv_head * head_dim + d] : 0.f;
+ ? v_cache[(long)(abs_k % max_seq_len) * kv_dim + (long)kv_head * head_dim + d] : 0.f;
sKV[idx] = (unsigned short)sharpi_fp32_to_fp16(vv);
}
__syncthreads();
@@ -4697,7 +4710,9 @@ asm volatile(
int kk = idx / hd2, pr = idx - kk * hd2;
unsigned int kh = 0u;
if (kk < tile_keys) {
- long off = (long)(kt0 + kk) * kv_dim + (long)kv_head * head_dim + 2 * pr;
+ // Ring slot `(kt0+kk) % max_seq_len`: identity for a full cache, wraps a
+ // window-sized SWA ring. The kt0+kk index stays logical for tile bounds.
+ long off = (long)((kt0 + kk) % max_seq_len) * kv_dim + (long)kv_head * head_dim + 2 * pr;
kh = sharpi_f32x2_to_f16x2(k_cache[off], k_cache[off + 1]);
}
sKh[idx] = kh;
@@ -4706,7 +4721,7 @@ asm volatile(
for (int idx = tid; idx < kt_tile * head_dim; idx += (int)blockDim.x) {
int kk = idx / head_dim, d = idx - kk * head_dim;
sV[idx] = (kk < tile_keys)
- ? v_cache[(long)(kt0 + kk) * kv_dim + (long)kv_head * head_dim + d]
+ ? v_cache[(long)((kt0 + kk) % max_seq_len) * kv_dim + (long)kv_head * head_dim + d]
: 0.f;
}
__syncthreads();
@@ -4799,7 +4814,9 @@ asm volatile(
for (int t = (int)tid; t < eff_seq; t += 256) {
int abs_t = t + window_start;
float dot = 0.f;
- long k_off = (long)abs_t * (long)kv_dim + (long)kv_head * (long)head_dim;
+ // Ring slot `abs_t % max_seq_len` (max_seq_len = allocated cache size): identity
+ // for a full cache, wraps a window-sized SWA ring. abs_t itself stays logical.
+ long k_off = (long)(abs_t % max_seq_len) * (long)kv_dim + (long)kv_head * (long)head_dim;
for (int dd = 0; dd < head_dim; dd++)
dot += q[q_off + dd] * k_cache[k_off + dd];
shared_scores[t] = dot * scale;
@@ -4841,7 +4858,7 @@ asm volatile(
float acc = 0.f;
for (int t = 0; t < eff_seq; t++) {
int abs_t = t + window_start;
- long v_off = (long)abs_t * (long)kv_dim + (long)kv_head * (long)head_dim;
+ long v_off = (long)(abs_t % max_seq_len) * (long)kv_dim + (long)kv_head * (long)head_dim;
acc += shared_scores[t] * v_cache[v_off + dd];
}
out[out_off + dd] = acc;
@@ -5616,7 +5633,9 @@ asm volatile(
int e = (int)(blockIdx.x * blockDim.x + threadIdx.x);
int i = (int)blockIdx.y;
if (e >= kv_dim || i >= n_tok) return;
- long off = (long)(start_pos + i) * (long)kv_dim + (long)e;
+ // Ring slot `(start_pos+i) % max_seq_len`: identity for a full-context cache
+ // (position < max_seq_len), wraps into a window-sized SWA ring otherwise.
+ long off = (long)((start_pos + i) % max_seq_len) * (long)kv_dim + (long)e;
k_cache[off] = k_all[(long)i * kv_dim + e];
v_cache[off] = v_all[(long)i * kv_dim + e];
}
@@ -5630,7 +5649,9 @@ asm volatile(
int e = (int)(blockIdx.x * blockDim.x + threadIdx.x);
int i = (int)blockIdx.y;
if (e >= kv_dim || i >= n_tok) return;
- long off = (long)(start_pos + i) * (long)kv_dim + (long)e;
+ // Ring slot `(start_pos+i) % max_seq_len` (identity for a full-context cache; wraps a
+ // window-sized ring) — kept in lockstep with the f32 llm_kv_append_batched.
+ long off = (long)((start_pos + i) % max_seq_len) * (long)kv_dim + (long)e;
k_cache[off] = (unsigned short)sharpi_fp32_to_bf16(k_all[(long)i * kv_dim + e]);
v_cache[off] = (unsigned short)sharpi_fp32_to_bf16(v_all[(long)i * kv_dim + e]);
}
diff --git a/src/SharpInference.Engine/CudaForwardPass.cs b/src/SharpInference.Engine/CudaForwardPass.cs
index 977e97d..f6e41ad 100644
--- a/src/SharpInference.Engine/CudaForwardPass.cs
+++ b/src/SharpInference.Engine/CudaForwardPass.cs
@@ -183,6 +183,30 @@ public sealed unsafe class CudaForwardPass : IForwardPass
/// single-shot batch size.
///
private const int PrefillBatchChunk = 4096;
+
+ ///
+ /// Headroom (in positions) added to a Gemma-4 SWA layer's window when sizing its KV
+ /// ring (issue #162). A batched prefill appends a whole batch of K/V before any of
+ /// those queries attend, so the ring must hold the window PLUS one batched-append span
+ /// or the earliest queries' window would be overwritten by the latest appends. The
+ /// widest single append span is the larger of the chunked-prefill window
+ /// () and the 4096 non-flash batched-attention cap, so a
+ /// ring of window + SwaRingHeadroom is always large enough. Capped at the model
+ /// context () — a full-context cache needs no ring at all.
+ ///
+ private const int SwaRingHeadroom = PrefillBatchChunk > 4096 ? PrefillBatchChunk : 4096;
+
+ ///
+ /// Allocated KV-cache size, in positions, for a Gemma-4 sliding-window layer: the
+ /// window plus , capped at the full context. When this
+ /// equals the context the cache is full (the ring modulo in the kernels degenerates to
+ /// the identity); when the context exceeds it the kernels wrap writes/reads modulo this
+ /// size. The value passed as each SWA append/attention call's maxSeqLen argument
+ /// MUST equal this so the kernel's pos % maxSeqLen lands in the right ring slot.
+ ///
+ private static int SwaRingSize(int maxSeqLen, int window) =>
+ (int)Math.Min(maxSeqLen, (long)window + SwaRingHeadroom);
+
///
/// Issue #141: route Q8_0 trunk matmuls in the batched prefill through the
/// compute-bound cuBLAS GEMM ()
@@ -382,6 +406,14 @@ public CudaForwardPass(GgufModel model, CudaBackend gpu, ModelHyperparams hp,
_maxSeqLen = EstimateMaxContextTq(model, gpu, hp, tqFp32Window, tqBits);
else
_maxSeqLen = EstimateMaxContext(model, gpu, hp);
+ // The KV-append/attention kernels index the cache at `pos % _maxSeqLen` (the ring
+ // modulo, identity for full caches), so a zero context — e.g. a malformed GGUF with
+ // context_length=0 reached via an explicit ctx-size — would be an in-kernel
+ // divide-by-zero (GPU trap). Fail loud at construction instead.
+ if (_maxSeqLen < 1)
+ throw new ArgumentException(
+ $"Resolved max context length is {_maxSeqLen}; the model's context_length " +
+ "metadata is missing or zero.", nameof(maxContextLength));
if (_tqEnabled)
{
@@ -548,8 +580,12 @@ void TraceVram(string label)
int layerHd = perLayerKv ? hp.LayerHeadDim![i] : _headDim;
int layerKvDim = _numKvHeads * layerHd;
+ // SWA layers use a window-sized ring (window + headroom for one batched
+ // append span, issue #162); everything else is full-context. The same
+ // SwaRingSize value is passed to the kernels as maxSeqLen so their
+ // pos % maxSeqLen wraps into this exact ring.
int layerCtx = (perLayerKv && hp.IsSwaLayer is { } swa && swa[i])
- ? Math.Min(_maxSeqLen, swaWindow)
+ ? SwaRingSize(_maxSeqLen, swaWindow)
: _maxSeqLen;
_gpuKCache[i] = gpu.Allocate(TensorShape.D1((long)layerCtx * layerKvDim));
_gpuVCache[i] = gpu.Allocate(TensorShape.D1((long)layerCtx * layerKvDim));
@@ -1338,7 +1374,7 @@ private void RunGemma4DeviceRegion(int position)
if (!kvShared)
{
int layerCtx = isSwa && _hp.SlidingWindowSize > 0
- ? Math.Min(_maxSeqLen, _hp.SlidingWindowSize)
+ ? SwaRingSize(_maxSeqLen, _hp.SlidingWindowSize)
: _maxSeqLen;
_gpu.KvAppend(kView, vView, _gpuKCache[layer], _gpuVCache[layer],
kvDimL, position, layerCtx);
@@ -1346,7 +1382,7 @@ private void RunGemma4DeviceRegion(int position)
int effLayerCtx = (_hp.IsSwaLayer is { } swaEff && swaEff[effLayer]
&& _hp.SlidingWindowSize > 0)
- ? Math.Min(_maxSeqLen, _hp.SlidingWindowSize)
+ ? SwaRingSize(_maxSeqLen, _hp.SlidingWindowSize)
: _maxSeqLen;
// Gemma 4 uses attention_scale = 1.0 (no 1/sqrt(head_dim) prefactor). Pass
@@ -1731,13 +1767,13 @@ private ReadOnlySpan ForwardProfiledGemma4(int token, int position)
if (!kvShared)
{
int layerCtx = isSwa && _hp.SlidingWindowSize > 0
- ? Math.Min(_maxSeqLen, _hp.SlidingWindowSize)
+ ? SwaRingSize(_maxSeqLen, _hp.SlidingWindowSize)
: _maxSeqLen;
_gpu.KvAppend(kView, vView, _gpuKCache[layer], _gpuVCache[layer], kvDimL, position, layerCtx);
}
int effLayerCtx = (_hp.IsSwaLayer is { } swaEff && swaEff[effLayer]
&& _hp.SlidingWindowSize > 0)
- ? Math.Min(_maxSeqLen, _hp.SlidingWindowSize)
+ ? SwaRingSize(_maxSeqLen, _hp.SlidingWindowSize)
: _maxSeqLen;
if (isSwa)
_gpu.AttentionSwa(qView, _gpuKCache[effLayer], _gpuVCache[effLayer], attnOutView,
@@ -1851,15 +1887,22 @@ public ReadOnlySpan Prefill(IReadOnlyList tokens, int startPos = 0)
if (BatchedPrefillEnabled && !snapKvActive && N >= 2
&& startPos + N <= _maxSeqLen && IsBatchedPrefillSupported())
{
- // Chunking past 4096 requires a streaming attention path (flash) AND simple
- // (non-windowed) KV position semantics. SWA layers wrap KV in a window-sized
- // ring whose cross-chunk behaviour past the window boundary is not yet
- // validated, so any windowed model keeps the proven 4096 cap (follow-up: #162).
- // Note: `IsSwaLayer` is only populated for Gemma 4 today, so the explicit
- // `SlidingWindowSize <= 0` check fails closed if SWA parsing is later extended
- // to a uniform-window arch that sets the window without a per-layer pattern.
+ // Chunking past 4096 requires a streaming attention path (flash) for the
+ // global (full-causal) layers — the non-flash AttentionBatched caps at 4096.
+ // SWA layers are now correct across chunk boundaries (issue #162): their KV
+ // ring is sized window + SwaRingHeadroom (≥ one chunk span), so appending a
+ // whole chunk before attending never overwrites a still-needed window, and the
+ // flash/append kernels wrap reads/writes modulo the ring (SwaRingSize).
+ //
+ // Fail-closed guard (kept from the pre-#162 gate): the per-layer SWA dispatch
+ // keys off `IsSwaLayer`, which today only Gemma 4 populates. A future arch that
+ // sets a model-wide `SlidingWindowSize` WITHOUT a per-layer `IsSwaLayer` pattern
+ // would run every layer as full-causal (window silently ignored) — harmless
+ // within the proven 4096 cap, but extending that past 4096 would silently drop
+ // the window. So only chunk past 4096 when either there's a real per-layer SWA
+ // pattern or the model has no window at all.
bool canChunkPast4096 = PrefillFlashAttnEnabled
- && _hp.IsSwaLayer is null && _hp.SlidingWindowSize <= 0;
+ && (_hp.IsSwaLayer is not null || _hp.SlidingWindowSize <= 0);
int cap = canChunkPast4096 ? _maxSeqLen : 4096;
if (startPos + N <= cap)
{
@@ -2254,12 +2297,12 @@ void ApplyRopeBatched()
if (!kvShared)
{
- int layerCtx = isSwa && window > 0 ? Math.Min(_maxSeqLen, window) : _maxSeqLen;
+ int layerCtx = isSwa && window > 0 ? SwaRingSize(_maxSeqLen, window) : _maxSeqLen;
_gpu.KvAppendBatched(kAll, vAll, _gpuKCache[layer], _gpuVCache[layer], kvDimL, startPos, layerCtx, N);
}
int effLayerCtx = (_hp.IsSwaLayer is { } swaEff && swaEff[effLayer] && window > 0)
- ? Math.Min(_maxSeqLen, window) : _maxSeqLen;
+ ? SwaRingSize(_maxSeqLen, window) : _maxSeqLen;
if (s_prefillProfile) { _gpu.Synchronize(); _profSw.Restart(); }
// Gemma 4: attention_scale = 1.0, passed explicitly (kernel skips its rsqrtf).
@@ -2815,6 +2858,10 @@ public static int EstimateMaxContext(GgufModel model, CudaBackend gpu, ModelHype
if (hp.LayerHeadDim is { } lhd && hp.IsSwaLayer is { } swa)
{
int swaWindow = hp.SlidingWindowSize > 0 ? hp.SlidingWindowSize : int.MaxValue;
+ // SWA layers are sized as a ring of window + SwaRingHeadroom positions (issue
+ // #162), so the cap for the per-token byte formula is the ring size, not the
+ // bare window. Guard against overflow when swaWindow is "unbounded".
+ long swaCap = swaWindow == int.MaxValue ? long.MaxValue : (long)swaWindow + SwaRingHeadroom;
long globalKvDimPerToken = 0;
long swaKvDimPerToken = 0;
for (int i = 0; i < hp.NumLayers; i++)
@@ -2827,24 +2874,26 @@ public static int EstimateMaxContext(GgufModel model, CudaBackend gpu, ModelHype
else globalKvDimPerToken += layerKvDim;
}
// For a given maxCtx C: bytes = globalKvDimPerToken * C
- // + swaKvDimPerToken * min(C, swaWindow)
+ // + swaKvDimPerToken * min(C, swaCap)
// Solve for the largest C ≤ hp.ContextLength that fits in `available`.
- // Branch on whether C ≤ swaWindow:
- // if C ≤ swaWindow: bytes = (global+swa) * C
- // else: bytes = global * C + swa * swaWindow
+ // Branch on whether C ≤ swaCap:
+ // if C ≤ swaCap: bytes = (global+swa) * C
+ // else: bytes = global * C + swa * swaCap
long globalPlusSwa = globalKvDimPerToken + swaKvDimPerToken;
int candA = globalPlusSwa > 0 ? (int)(available / globalPlusSwa) : int.MaxValue;
int maxCtxL;
- if (candA <= swaWindow)
+ if (candA <= swaCap)
{
maxCtxL = candA;
}
else
{
- long remain = available - swaKvDimPerToken * swaWindow;
+ // swaCap is finite here (candA ≤ long.MaxValue always takes the branch
+ // above when swaCap is unbounded), so swaKvDimPerToken * swaCap is safe.
+ long remain = available - swaKvDimPerToken * swaCap;
int candB = globalKvDimPerToken > 0 && remain > 0
? (int)(remain / globalKvDimPerToken) : 0;
- maxCtxL = Math.Max(swaWindow, candB);
+ maxCtxL = (int)Math.Max(swaCap, candB);
}
return Math.Clamp(maxCtxL, 512, hp.ContextLength);
}
diff --git a/src/SharpInference.Engine/CudaHybridForwardPass.cs b/src/SharpInference.Engine/CudaHybridForwardPass.cs
index a624f46..524b178 100644
--- a/src/SharpInference.Engine/CudaHybridForwardPass.cs
+++ b/src/SharpInference.Engine/CudaHybridForwardPass.cs
@@ -446,6 +446,16 @@ void TraceVram(string label)
// Gemma 4: each GPU layer sizes its KV cache by per-layer head_dim and
// (for SWA layers) caps at SlidingWindowSize. Non-gemma4 stays at the
// model-wide head_dim × full context.
+ //
+ // This path is DECODE-ONLY for Gemma 4 (IsBatchedPrefillSupported requires
+ // !_isGemma4Like), so the SWA ring only needs to hold `window` positions:
+ // each token appends one K/V then attends its window before the next append,
+ // so the bare-window ring is never overwritten early. The shared kernels'
+ // `pos % max_seq_len` (max_seq_len == this layerCtx) wraps writes/reads into
+ // it — which also fixes the pre-#162 latent OOB here (positions ≥ window used
+ // to index a window-sized cache absolutely). It deliberately does NOT use
+ // CudaForwardPass.SwaRingSize (the window+chunk headroom) because that's only
+ // needed by the batched-append (chunked-prefill) path, which never runs here.
int layerHd = _isGemma4Like ? hp.LayerHeadDim![i] : _headDim;
int layerKvDim = _numKvHeads * layerHd;
int layerCtx = (_isGemma4Like && hp.IsSwaLayer is { } swa && swa[i] && hp.SlidingWindowSize > 0)
diff --git a/tests/SharpInference.Tests.ForwardPass/Gemma4CudaBatchedPrefillTests.cs b/tests/SharpInference.Tests.ForwardPass/Gemma4CudaBatchedPrefillTests.cs
index d9259fb..94e61fb 100644
--- a/tests/SharpInference.Tests.ForwardPass/Gemma4CudaBatchedPrefillTests.cs
+++ b/tests/SharpInference.Tests.ForwardPass/Gemma4CudaBatchedPrefillTests.cs
@@ -325,6 +325,160 @@ public void Gemma4_E4B_BatchedPrefill_GemmOff_MatchesSequentialBitExact()
$"(expected >95%); maxAbs={maxAbs}.");
}
+ ///
+ /// Issue #162 (SWA sub-item): a Gemma-4 prompt LONGER than the 512-token sliding window
+ /// but still under the 4096 single-batch cap must produce correct output. This is the
+ /// scenario the old window-sized-cache-with-absolute-indexing got wrong (positions ≥
+ /// window read/wrote out of bounds). With the SWA ring (cache sized window +
+ /// SwaRingHeadroom, capped at ctx) the per-token loop is itself correct again, so the
+ /// batched path is checked against it. ctx=1024 makes the SWA cache full (1024 <
+ /// window+headroom), so this isolates the windowed-attention-past-the-window fix from
+ /// the ring-wrap fix (covered by the chunked test below).
+ ///
+ [Fact]
+ public void Gemma4_E4B_BatchedPrefill_PastWindow_MatchesSequential()
+ {
+ using var gpu = TryCreate();
+ if (gpu is null) return;
+ var path = FindModelPath();
+ if (path is null) return;
+
+ using var model = GgufModel.Open(path);
+ var hp = ModelHyperparams.FromGgufMetadata(model.Metadata, model);
+ Assert.NotNull(hp.LayerHeadDim);
+ Assert.True(hp.SlidingWindowSize > 0 && hp.SlidingWindowSize < 700,
+ $"Test assumes a sliding window < the 700-token prompt; got {hp.SlidingWindowSize}.");
+
+ // 700 > window (512): the trailing queries' windows exclude the earliest tokens,
+ // exactly the regime that the absolute-into-window-sized-cache bug corrupted.
+ var tokens = MakeTokens(700);
+
+ using var fwd = new CudaForwardPass(model, gpu, hp, maxContextLength: 1024);
+
+ fwd.BatchedPrefillEnabled = true; // shipped defaults (flash TC on)
+ var batched = fwd.Prefill(tokens).ToArray();
+ Assert.True(fwd.LastPrefillWasBatched);
+
+ fwd.ResetCache();
+ fwd.BatchedPrefillEnabled = false;
+ var sequential = fwd.Prefill(tokens).ToArray();
+ Assert.False(fwd.LastPrefillWasBatched);
+
+ Assert.Equal(sequential.Length, batched.Length);
+ Assert.Equal(Argmax(sequential), Argmax(batched));
+
+ // Logit-envelope check on top of argmax: the flash TC path rounds to fp16 over 42
+ // layers + softcap, so a few tenths of a logit is expected; a whole-number gap would
+ // signal a wiring bug (e.g. a ring slot reading the wrong position) that argmax alone
+ // could miss when it doesn't quite flip the top token.
+ float maxAbs = 0f;
+ for (int i = 0; i < sequential.Length; i++)
+ maxAbs = MathF.Max(maxAbs, MathF.Abs(sequential[i] - batched[i]));
+ Assert.True(maxAbs < 1.5f,
+ $"Past-window batched vs per-token logits diverged beyond fp16 tolerance: maxAbs={maxAbs}.");
+
+ var seqTop = TopKSet(sequential, 5);
+ var batTop = TopKSet(batched, 5);
+ int overlap = 0;
+ foreach (var t in batTop) if (seqTop.Contains(t)) overlap++;
+ Assert.True(overlap >= 4,
+ $"Past-window batched top-5 overlaps the per-token reference in only {overlap}/5 slots.");
+ }
+
+ ///
+ /// Issue #162 (SWA sub-item): a Gemma-4 prompt longer than the 4096 cap must take the
+ /// chunked batched path (flash streaming the prior KV) and stay argmax-stable vs the
+ /// per-token loop, with both running through the SWA KV ring.
+ ///
+ /// This is the decisive ring oracle: the per-token loop attends right after each single
+ /// append (so it only ever needs ring ≥ window) while the chunked path appends a whole
+ /// 4096-token chunk before any of those queries attend (so it needs ring ≥ window +
+ /// chunk span). If the ring were undersized, the chunked path would overwrite an
+ /// earlier query's window and diverge from the per-token reference — so agreement
+ /// validates the window + SwaRingHeadroom sizing.
+ ///
+ ///
+ /// The prompt lengths make the LAST chunk SHORTER than the 512 window (4296 = 4096+200,
+ /// 8292 = 2·4096+100). Prefill returns only the final token's logits, and with a short
+ /// last chunk the final token's window reaches BACK across the preceding chunk boundary —
+ /// so the single observable logit genuinely depends on cross-chunk KV reads (and, for
+ /// 8292 at ctx 9216 > window+headroom=4608, ring-WRAPPED reads). A whole-chunk prompt
+ /// (e.g. 8192) would leave the final window entirely inside the last chunk, hiding any
+ /// cross-chunk ring overwrite from the assertion. 8292 also drives a 3-chunk loop.
+ ///
+ ///
+ [Theory]
+ [InlineData(4296, 6144)]
+ [InlineData(8292, 9216)]
+ public void Gemma4_E4B_ChunkedBatchedPrefill_Over4096_MatchesSequential(int promptLen, int ctx)
+ {
+ using var gpu = TryCreate();
+ if (gpu is null) return;
+ var path = FindModelPath();
+ if (path is null) return;
+
+ using var model = GgufModel.Open(path);
+ var hp = ModelHyperparams.FromGgufMetadata(model.Metadata, model);
+ Assert.NotNull(hp.LayerHeadDim);
+ Assert.True(hp.SlidingWindowSize > 0);
+
+ var longTokens = MakeTokens(promptLen);
+
+ // Disable SnapKV (a >budget prompt would otherwise route to the per-token SnapKV
+ // eviction path before the batched gate). Construct under the override, then restore.
+ var prevSnap = Environment.GetEnvironmentVariable("SHARPI_SNAPKV_BUDGET");
+ Environment.SetEnvironmentVariable("SHARPI_SNAPKV_BUDGET", "0");
+ CudaForwardPass fwd;
+ try { fwd = new CudaForwardPass(model, gpu, hp, maxContextLength: ctx); }
+ finally { Environment.SetEnvironmentVariable("SHARPI_SNAPKV_BUDGET", prevSnap); }
+ using var _fwd = fwd;
+
+ fwd.BatchedPrefillEnabled = true; // shipped defaults (flash TC on) → chunked path
+ var batched = fwd.Prefill(longTokens).ToArray();
+ Assert.True(fwd.LastPrefillWasBatched,
+ "Chunked batched prefill did not engage for a >4096-token Gemma 4 prompt (#162).");
+
+ fwd.ResetCache();
+ fwd.BatchedPrefillEnabled = false;
+ var sequential = fwd.Prefill(longTokens).ToArray();
+ Assert.False(fwd.LastPrefillWasBatched);
+
+ Assert.Equal(sequential.Length, batched.Length);
+ Assert.Equal(Argmax(sequential), Argmax(batched));
+
+ // Envelope check: the chunked flash path reassociates the softmax over thousands of
+ // streamed keys, so it's looser than the ≤4096 case, but a several-unit gap would
+ // still flag a ring-overwrite bug (an early query reading a clobbered window slot)
+ // that argmax-stability alone might not surface.
+ float maxAbs = 0f;
+ for (int i = 0; i < sequential.Length; i++)
+ maxAbs = MathF.Max(maxAbs, MathF.Abs(sequential[i] - batched[i]));
+ Assert.True(maxAbs < 2.0f,
+ $"Chunked batched vs per-token logits diverged beyond fp tolerance: maxAbs={maxAbs}.");
+
+ var seqTop = TopKSet(sequential, 5);
+ var batTop = TopKSet(batched, 5);
+ int overlap = 0;
+ foreach (var t in batTop) if (seqTop.Contains(t)) overlap++;
+ Assert.True(overlap >= 4,
+ $"Chunked batched top-5 overlaps the per-token reference in only {overlap}/5 slots.");
+ }
+
+ // Deterministic spread across the vocab via a small LCG; all ids well within Gemma 4's
+ // vocab. Token 2 (BOS) leads so the prompt starts in-distribution.
+ private static int[] MakeTokens(int n)
+ {
+ var t = new int[n];
+ t[0] = 2;
+ uint s = 0x9E3779B9u;
+ for (int i = 1; i < n; i++)
+ {
+ s = s * 1664525u + 1013904223u;
+ t[i] = (int)(s % 100000u) + 5;
+ }
+ return t;
+ }
+
private static HashSet TopKSet(ReadOnlySpan logits, int k)
{
var idx = new int[logits.Length];