perf(swiglu): architecture-aware column tiling for Blackwell (B200)#1271
Conversation
| _SWIGLU_TILE_MIN_BLOCK = 16384 | ||
|
|
||
|
|
||
| def _is_blackwell(): |
There was a problem hiding this comment.
should this be an enumeration over all the devices instead of a boolean?
There was a problem hiding this comment.
maybe infer cuda device capability once in src/liger_kernel/ops/utils.py so all op impls can query?
There was a problem hiding this comment.
+1 on infer cuda device arch. Sent a PR for it in #1273.
There was a problem hiding this comment.
Sounds good, will rebase on #1273 once it lands and swap to the shared helper
|
Nice results, and the bit-exact framing is great. One gap worth closing before merge: the tiled path has no automated coverage. Since the tiled kernels are a launch-geometry change with no Blackwell-only instructions, we can force the gate on and assert bit-exactness against the original one-row kernel on any CUDA GPU. Suggested addition to import liger_kernel.ops.swiglu as swiglu_ops
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Blackwell tiled SwiGLU path is CUDA-only")
@pytest.mark.parametrize(
"n_rows, n_cols",
[
(4, 11009), # wide + NOT a multiple of the 1024 tile -> exercises the column mask
(3, 14337), # ragged final tile, odd row count
(4, 16384), # exactly tile-aligned (16 full tiles)
],
)
@pytest.mark.parametrize("gate_multiplier", [1.0, 1.3])
@pytest.mark.parametrize(
"dtype",
[
torch.float32,
pytest.param(
torch.bfloat16,
marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
),
],
)
def test_swiglu_blackwell_tiled_matches_original(monkeypatch, n_rows, n_cols, gate_multiplier, dtype):
"""The Blackwell column-tiled path must be bit-for-bit identical to the one-row kernel.
The dispatch normally only fires on SM 10.x, so CI never reaches it. The tiled kernels
use no Blackwell-only instructions, so we force the gate on and compare against the
original kernel on any CUDA GPU. Widths land on a ragged final tile (n_cols % 1024 != 0)
to cover the column mask.
"""
torch.manual_seed(0)
a = torch.randn(n_rows, n_cols, device=device, dtype=dtype)
b = torch.randn(n_rows, n_cols, device=device, dtype=dtype)
dc = torch.randn(n_rows, n_cols, device=device, dtype=dtype)
def run():
a_ = a.clone().detach().requires_grad_(True)
b_ = b.clone().detach().requires_grad_(True)
c = LigerSiLUMulFunction.apply(a_, b_, gate_multiplier)
c.backward(dc)
return c.detach(), a_.grad.detach(), b_.grad.detach()
# Original one-row kernel (gate forced off).
monkeypatch.setattr(swiglu_ops, "_is_blackwell", lambda: False)
assert not swiglu_ops._should_tile(n_cols)
c_ref, da_ref, db_ref = run()
# Forced Blackwell tiled kernel (gate forced on).
monkeypatch.setattr(swiglu_ops, "_is_blackwell", lambda: True)
assert swiglu_ops._should_tile(n_cols)
c_tiled, da_tiled, db_tiled = run()
# Launch-geometry change only -> require exact equality (the PR's headline claim).
torch.testing.assert_close(c_tiled, c_ref, rtol=0, atol=0)
torch.testing.assert_close(da_tiled, da_ref, rtol=0, atol=0)
torch.testing.assert_close(db_tiled, db_ref, rtol=0, atol=0)Notes:
|
|
@yueyiming2009 thanks for catching that, just added the test coverage in 33d64ef. |
86ef172 to
33d64ef
Compare
|
@Celaena24 let's rebase on #1273, it is merged now. |
9093e5b to
1b2c9c3
Compare
|
@vaibhavjindal rebased |
| # occupancy-starved on Blackwell for wide rows. Splitting each row into fixed | ||
| # 1024-wide column tiles over a 2D grid raises occupancy so HBM saturates | ||
| # (~1.63x backward at n_cols=14336, bit-exact). Tile size 1024 chosen by sweep. | ||
| # See optimization/swiglu/report.md for the full diagnosis and numbers. |
There was a problem hiding this comment.
optimization/swiglu/report.md reference doesn't look valid. Probably an LLM artifact. We should remove this reference.
|
|
||
| def _should_tile(n_cols): | ||
| """Use the Blackwell column-tiled path only for wide-enough rows.""" | ||
| return _is_blackwell() and triton.next_power_of_2(n_cols) >= _SWIGLU_TILE_MIN_BLOCK |
There was a problem hiding this comment.
Let's just replace _is_blackwell() with infer_device_arch().startswith("blackwell"). It's a simple function, no need to define a separate _is_blackwell() function.
| return _is_blackwell() and triton.next_power_of_2(n_cols) >= _SWIGLU_TILE_MIN_BLOCK | ||
|
|
||
|
|
||
| def swiglu_tile_settings(n_cols): |
There was a problem hiding this comment.
nit: let's rename it by adding underscore in front of the function name as it should be a private function.
swiglu_tile_settings() ---> _swiglu_tile_settings()
vaibhavjindal
left a comment
There was a problem hiding this comment.
Left a few minor comments
| def swiglu_tile_settings(n_cols): | ||
| """Pick the column-tile BLOCK_SIZE and num_warps for the Blackwell 2D-grid path.""" | ||
| BLOCK_SIZE = min(_SWIGLU_TILE, triton.next_power_of_2(n_cols)) | ||
| num_warps = 8 if BLOCK_SIZE >= 2048 else 4 |
There was a problem hiding this comment.
Since _SWIGLU_TILE is 1024, BLOCK_SIZE will never be greater than 1024. Thus, num_warps is just gonna be 4 at all times. Let's just return num_warps=4?
|
@vaibhavjindal made the changes |
vaibhavjindal
left a comment
There was a problem hiding this comment.
Thanks for the changes, LGTM! 🚢
Summary
The SwiGLU (
SiLUMul) Triton kernels launch one program per row withBLOCK_SIZE = next_pow2(n_cols)— tuned for H100, but it leaves the backward kerneloccupancy-starved on B200. This PR adds a column-tiled 2D-grid path selected only on
Blackwell (SM 10.x) for wide rows, giving 1.63× backward / 1.41× full step at the
Llama-3 8B shape, bit-for-bit identical numerics (zero tolerance change). Hopper /
Ampere and narrow rows keep the original kernels unchanged → no regression off
Blackwell (H100 cross-validated at 1.00×).
Details
What & why
Decouple
BLOCK_SIZEfromn_cols: split each row into fixed 1024-wide column tilesover a 2D grid
(n_rows, cdiv(n_cols, BLOCK_SIZE)), instead of one program covering thewhole row (
BLOCK_SIZE = next_pow2(n_cols), e.g. 16384 for Llama-3 8B's 14336).1 block/SM → the backward kernel is latency-bound (22.7% occupancy, 42% DRAM). Smaller
tiles drop to 32 regs/thread → 8 blocks/SM, so the memory-bound kernel saturates HBM.
Access stays coalesced; a column mask handles non-divisible
n_cols.on every shape (2048 ties; 8192 regresses).
next_pow2(n_cols) >= 16384): narrow rows aren't register-heavy, sotiling only adds overhead (~2–3% regression). They — plus all of H100/A100, which
already run ~90% of the memory roofline — dispatch to the exact original one-row
kernels, byte-for-byte upstream.
MAX_FUSED_SIZE = 65536one-row cap onn_cols.NCU diagnosis (8192 × 14336, bf16, B200 — backward kernel)
Results — Llama-3 8B (isolated SiLUMul, n_cols=14336, bf16, B200)
Headline at T=8192, then across context length (win grows with T as more rows expose the
occupancy gap):
Results — across model shapes (isolated SiLUMul, T = 1024…8192, B200)
Same model set as the existing SwiGLU benchmarks. Only
n_cols(=intermediate_size)and dtype affect the activation kernel; ranges span T=1024→8192.
smaller backward win but a large forward win (its original forward is itself
under-occupied at that width).
H100 cross-validation (no regression off Blackwell)
Re-run on a real H100 (self-contained
bench_swiglu_h100.py, n_cols=14336, bf16).With the arch dispatch, the H100 path executes the identical original kernel:
(An earlier unconditional-tiling version regressed H100 by 1–3% because the one-row
backward there is already at ~90% SOL; dispatching to the original kernels removes that
entirely.)
Testing Done
Bit-exact correctness vs the original one-row kernel: forward
0.00e+00,backward (da, db)
0.00e+00max abs diff — on both the tiled path (n_cols=14336) andthe fallback path (n_cols=4096), across T = 1024–8192.
Validated the dispatch: tiled on B200 (SM 10.0) for wide rows, original one-row path on
H100 (SM 9.0) and for narrow rows on B200; all bit-exact on their respective GPUs.
Multi-model sweep (6 configs × T∈{1024,2048,4096,8192}): backward avg 1.21×, max 1.63×,
min 1.00× (the min is the width-gated small shape, i.e. no regression).
Hardware Type: NVIDIA B200 (SM 10.0) + NVIDIA H100 80GB HBM3 (SM 9.0)
run
make testto ensure correctnessrun
make checkstyleto ensure code stylerun
make test-convergenceto ensure convergence