
perf(swiglu): architecture-aware column tiling for Blackwell (B200) #1271

Merged
vaibhavjindal merged 5 commits into linkedin:main from Celaena24:b200-swiglu-optimization
Jul 2, 2026

Conversation

@Celaena24 Celaena24 commented Jun 26, 2026

Contributor

Summary

The SwiGLU (SiLUMul) Triton kernels launch one program per row with
BLOCK_SIZE = next_pow2(n_cols). That geometry was tuned for H100, but it leaves the
backward kernel occupancy-starved on B200. This PR adds a column-tiled 2D-grid path
selected only on Blackwell (SM 10.x) for wide rows, giving 1.63× backward / 1.41× full
step at the Llama-3 8B shape with bit-for-bit identical numerics (zero tolerance change).
Hopper / Ampere and narrow rows keep the original kernels unchanged, so there is no
regression off Blackwell (H100 cross-validated at 1.00×).

Details

What & why

Decouple BLOCK_SIZE from n_cols: split each row into fixed 1024-wide column tiles
over a 2D grid (n_rows, cdiv(n_cols, BLOCK_SIZE)), instead of one program covering the
whole row (BLOCK_SIZE = next_pow2(n_cols), e.g. 16384 for Llama-3 8B's 14336).
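The two launch geometries can be compared with a small pure-Python sketch. The helper names (`next_pow2`, `cdiv`, `one_row_launch`, `tiled_launch`) are illustrative stand-ins, not functions from this PR; only the 1024 tile width and the grid shapes come from the description above.

```python
def next_pow2(n: int) -> int:
    """Smallest power of two >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(-a // b)


def one_row_launch(n_rows: int, n_cols: int) -> dict:
    """Original geometry: one program per row, block spans the whole row."""
    return {"grid": (n_rows,), "BLOCK_SIZE": next_pow2(n_cols)}


def tiled_launch(n_rows: int, n_cols: int, tile: int = 1024) -> dict:
    """Tiled geometry: 2D grid of (row, column tile) programs."""
    block = min(tile, next_pow2(n_cols))
    return {"grid": (n_rows, cdiv(n_cols, block)), "BLOCK_SIZE": block}


if __name__ == "__main__":
    # Llama-3 8B shape used throughout the description.
    print(one_row_launch(8192, 14336))
    print(tiled_launch(8192, 14336))
```

For the 8192 × 14336 case the one-row launch uses 8192 programs with a 16384-wide block, while the tiled launch uses an 8192 × 14 grid of 1024-wide blocks.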

  • Why it's faster: the one-row block spans the full hidden dim → 70 regs/thread → only
    1 block/SM → the backward kernel is latency-bound (22.7% occupancy, 42% DRAM). Smaller
    tiles drop to 32 regs/thread → 8 blocks/SM, so the memory-bound kernel saturates HBM.
    Access stays coalesced; a column mask handles non-divisible n_cols.
  • Bit-exact: only launch geometry changes — outputs match the original kernel exactly.
  • Tile = 1024: a per-shape sweep over {512…8192} found 1024 best-or-tied for backward
    on every shape (2048 ties; 8192 regresses).
  • Width gate (next_pow2(n_cols) >= 16384): narrow rows aren't register-heavy, so
    tiling only adds overhead (~2–3% regression). They dispatch to the exact original
    one-row kernels, byte-for-byte upstream, as does everything on H100/A100, which
    already runs at ~90% of the memory roofline.
  • Bonus: lifts the original MAX_FUSED_SIZE = 65536 one-row cap on n_cols.
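The bit-exact and column-mask points can be illustrated with a CPU emulation. This is a sketch only: `silu_mul` is a plain-Python stand-in for the kernel body (the real kernels are Triton and are not reproduced here), and the loops emulate programs and lanes rather than launching anything. Because each output element depends only on its own inputs, changing which "program" computes it cannot change the value.

```python
import math
import random

TILE = 1024  # fixed column-tile width from the PR description


def silu_mul(a: float, b: float) -> float:
    """Stand-in for the kernel body: silu(a) * b for one element."""
    return a / (1.0 + math.exp(-a)) * b


def forward_one_row(a_row, b_row):
    """Reference: a single 'program' covers the whole row."""
    return [silu_mul(a, b) for a, b in zip(a_row, b_row)]


def forward_tiled(a_row, b_row, tile=TILE):
    """Emulates the 2D grid: one 'program' per column tile, with a column mask."""
    n_cols = len(a_row)
    out = [0.0] * n_cols
    n_tiles = -(-n_cols // tile)  # ceiling division
    for tile_id in range(n_tiles):
        for lane in range(tile):
            col = tile_id * tile + lane
            if col < n_cols:  # column mask: lanes past the row end are skipped
                out[col] = silu_mul(a_row[col], b_row[col])
    return out


if __name__ == "__main__":
    rng = random.Random(0)
    n_cols = 2500  # deliberately not a multiple of 1024
    a = [rng.gauss(0.0, 1.0) for _ in range(n_cols)]
    b = [rng.gauss(0.0, 1.0) for _ in range(n_cols)]
    print(forward_tiled(a, b) == forward_one_row(a, b))
```

The emulation says nothing about performance; it only shows why a pure launch-geometry change with a correct mask leaves every output element untouched.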

NCU diagnosis (8192 × 14336, bf16, B200 — backward kernel)

| metric | original (one-row) | optimized (tiled) |
|---|---|---|
| achieved occupancy | 22.7% | 83.9% |
| registers / thread | 70 → 1 block/SM | 32 → 8 blocks/SM |
| DRAM throughput | 42% | 83% |
| warps issued / scheduler | 0.39 (idle ~61%) | 0.82 |
| backward SOL efficiency | 50.0% | 81.5% |

Results — Llama-3 8B (isolated SiLUMul, n_cols=14336, bf16, B200)

Headline at T=8192, then across context length (win grows with T as more rows expose the
occupancy gap):

| pass | baseline | this PR | latency | throughput |
|---|---|---|---|---|
| forward | 0.1141 ms | 0.1106 ms | −3.1% | +3.2% |
| backward | 0.2937 ms | 0.1802 ms | −38.6% | +63.0% |
| full | 0.4020 ms | 0.2855 ms | −29.0% | +40.8% |

| T | backward | full | backward SOL (orig → opt) |
|---|---|---|---|
| 1024 | 1.52× | 1.34× | 42.1% → 63.8% |
| 2048 | 1.56× | 1.39× | 45.9% → 71.8% |
| 4096 | 1.61× | 1.41× | 48.4% → 77.9% |
| 8192 | 1.63× | 1.41× | 50.0% → 81.5% |
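The latency, throughput, and speedup figures for T=8192 are related by simple ratios. A minimal sketch of that arithmetic follows; the `compare` helper is made up for illustration, and the timings are the ones reported above.

```python
def compare(baseline_ms: float, new_ms: float) -> dict:
    """Derive speedup, latency change, and throughput change from two timings."""
    speedup = baseline_ms / new_ms
    return {
        "speedup": round(speedup, 2),
        "latency_pct": round((new_ms / baseline_ms - 1.0) * 100.0, 1),
        "throughput_pct": round((speedup - 1.0) * 100.0, 1),
    }


# (baseline ms, this-PR ms) at T=8192, n_cols=14336, bf16, B200.
timings = {
    "forward": (0.1141, 0.1106),
    "backward": (0.2937, 0.1802),
    "full": (0.4020, 0.2855),
}

if __name__ == "__main__":
    for name, (base, new) in timings.items():
        print(name, compare(base, new))
```

Throughput gain is the speedup minus one, which is why a 38.6% latency reduction corresponds to a 63.0% throughput increase rather than 38.6%.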

Results — across model shapes (isolated SiLUMul, T = 1024…8192, B200)

Same model set as the existing SwiGLU benchmarks. Only n_cols (= intermediate_size)
and dtype affect the activation kernel; ranges span T=1024→8192.

| config | dtype | n_cols | tiled? | backward | forward | full |
|---|---|---|---|---|---|---|
| llama-8B | bf16 | 14336 | tiled | 1.51–1.63× | 1.03–1.10× | 1.34–1.40× |
| qwen2.5-7B | bf16 | 18944 | tiled | 1.15–1.21× | 1.48–1.49× | 1.28–1.34× |
| llama-7B | bf16 | 11008 | tiled | 1.20–1.24× | 0.98–1.00× | 1.12–1.15× |
| llama-7B | fp16 | 11008 | tiled | 1.19–1.24× | 0.94–1.00× | 1.11–1.15× |
| llama-7B | fp32 | 11008 | tiled | 1.03–1.05× | 0.96–1.01× | 1.01–1.04× |
| small | bf16 | 4096 | ➖ fallback | 1.00× | 1.00× | 1.00× |

  • Biggest backward wins on the widest bf16 rows (14336). qwen2.5-7B (18944) gets a
    smaller backward win but a large forward win (its original forward is itself
    under-occupied at that width).
  • fp32 is near-flat (already close to roofline) but never regresses on the tiled path.
  • small (4096) is below the width gate → falls back → exactly 1.00× at every T.
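Which shapes clear the width gate follows directly from next_pow2(n_cols) >= 16384. The sketch below applies that rule to the benchmark widths; `should_tile` here takes the architecture as an explicit boolean so it runs without a GPU, which is a simplification and not the signature used in the PR.

```python
TILE_MIN_BLOCK = 16384  # width gate from the PR description


def next_pow2(n: int) -> int:
    """Smallest power of two >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


def should_tile(n_cols: int, is_blackwell: bool) -> bool:
    """Tile only on Blackwell and only for wide-enough rows."""
    return is_blackwell and next_pow2(n_cols) >= TILE_MIN_BLOCK


if __name__ == "__main__":
    for n_cols in (14336, 18944, 11008, 4096):
        path = "tiled" if should_tile(n_cols, True) else "fallback"
        print(n_cols, next_pow2(n_cols), path)
```

The effective threshold is n_cols >= 8193, since 8192 is itself a power of two below the gate.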

H100 cross-validation (no regression off Blackwell)

Re-run on a real H100 (self-contained bench_swiglu_h100.py, n_cols=14336, bf16).
With the arch dispatch, the H100 path executes the identical original kernel:

| T | backward speedup | full speedup |
|---|---|---|
| 1024 | 1.000× | 1.000× |
| 2048 | 1.001× | 1.000× |
| 4096 | 1.000× | 1.000× |
| 8192 | 1.000× | 1.000× |

(An earlier unconditional-tiling version regressed H100 by 1–3% because the one-row
backward there is already at ~90% SOL; dispatching to the original kernels removes that
entirely.)

Testing Done

  • Bit-exact correctness vs the original one-row kernel: forward 0.00e+00,
    backward (da, db) 0.00e+00 max abs diff — on both the tiled path (n_cols=14336) and
    the fallback path (n_cols=4096), across T = 1024–8192.

  • Validated the dispatch: tiled on B200 (SM 10.0) for wide rows, original one-row path on
    H100 (SM 9.0) and for narrow rows on B200; all bit-exact on their respective GPUs.

  • Multi-model sweep (6 configs × T∈{1024,2048,4096,8192}): backward avg 1.21×, max 1.63×,
    min 1.00× (the min is the width-gated small shape, i.e. no regression).

  • Hardware Type: NVIDIA B200 (SM 10.0) + NVIDIA H100 80GB HBM3 (SM 9.0)

  • run make test to ensure correctness

  • run make checkstyle to ensure code style

  • run make test-convergence to ensure convergence

Comment thread src/liger_kernel/ops/swiglu.py Outdated
_SWIGLU_TILE_MIN_BLOCK = 16384


def _is_blackwell():

Collaborator

should this be an enumeration over all the devices instead of a boolean?

Collaborator

maybe infer cuda device capability once in src/liger_kernel/ops/utils.py so all op impls can query?

Collaborator

+1 on infer cuda device arch. Sent a PR for it in #1273.

Contributor Author

Sounds good, will rebase on #1273 once it lands and swap to the shared helper
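For readers following the thread, a shared architecture helper of the kind discussed here might look roughly like the sketch below. This is not the code from #1273, which is not shown on this page; `arch_from_capability` is a hypothetical name, and the mapping covers only the devices mentioned in this PR (A100 = SM 8.0, H100 = SM 9.0, B200 = SM 10.0).

```python
def arch_from_capability(major: int, minor: int) -> str:
    """Map a CUDA compute capability to a coarse architecture name.

    Hypothetical helper; only the architectures named in this PR are covered.
    """
    if major == 10:
        return "blackwell"  # SM 10.x, e.g. B200
    if major == 9:
        return "hopper"  # SM 9.0, e.g. H100
    if (major, minor) == (8, 0):
        return "ampere"  # SM 8.0, e.g. A100
    return "unknown"


if __name__ == "__main__":
    # In real code the tuple would come from torch.cuda.get_device_capability(),
    # queried once and cached; it is hard-coded here so the sketch runs anywhere.
    for cap in [(8, 0), (9, 0), (10, 0)]:
        print(cap, arch_from_capability(*cap))
```

Returning a string instead of a boolean lets each op decide which architectures it cares about, which is the direction the reviewers suggest.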

@yueyiming2009

Collaborator

Nice results, and the bit-exact framing is great. One gap worth closing before merge: the tiled path has no automated coverage. _should_tile requires _is_blackwell() (SM 10.x) and next_power_of_2(n_cols) >= 16384 (i.e. n_cols >= 8193). CI doesn't run on B200, and every shape in test/transformers/test_swiglu.py has n_cols <= 512, so both _swiglu_forward_kernel_tiled / _swiglu_backward_kernel_tiled and the new column-mask logic are never executed by the suite — correctness rests entirely on the manual B200 runs in the description.

Since the tiled kernels are a launch-geometry change with no Blackwell-only instructions, we can force the gate on and assert bit-exactness against the original one-row kernel on any CUDA GPU. Suggested addition to test/transformers/test_swiglu.py (needs one extra import: import liger_kernel.ops.swiglu as swiglu_ops):

import liger_kernel.ops.swiglu as swiglu_ops


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Blackwell tiled SwiGLU path is CUDA-only")
@pytest.mark.parametrize(
    "n_rows, n_cols",
    [
        (4, 11009),  # wide + NOT a multiple of the 1024 tile -> exercises the column mask
        (3, 14337),  # ragged final tile, odd row count
        (4, 16384),  # exactly tile-aligned (16 full tiles)
    ],
)
@pytest.mark.parametrize("gate_multiplier", [1.0, 1.3])
@pytest.mark.parametrize(
    "dtype",
    [
        torch.float32,
        pytest.param(
            torch.bfloat16,
            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
        ),
    ],
)
def test_swiglu_blackwell_tiled_matches_original(monkeypatch, n_rows, n_cols, gate_multiplier, dtype):
    """The Blackwell column-tiled path must be bit-for-bit identical to the one-row kernel.

    The dispatch normally only fires on SM 10.x, so CI never reaches it. The tiled kernels
    use no Blackwell-only instructions, so we force the gate on and compare against the
    original kernel on any CUDA GPU. Widths land on a ragged final tile (n_cols % 1024 != 0)
    to cover the column mask.
    """
    torch.manual_seed(0)
    a = torch.randn(n_rows, n_cols, device=device, dtype=dtype)
    b = torch.randn(n_rows, n_cols, device=device, dtype=dtype)
    dc = torch.randn(n_rows, n_cols, device=device, dtype=dtype)

    def run():
        a_ = a.clone().detach().requires_grad_(True)
        b_ = b.clone().detach().requires_grad_(True)
        c = LigerSiLUMulFunction.apply(a_, b_, gate_multiplier)
        c.backward(dc)
        return c.detach(), a_.grad.detach(), b_.grad.detach()

    # Original one-row kernel (gate forced off).
    monkeypatch.setattr(swiglu_ops, "_is_blackwell", lambda: False)
    assert not swiglu_ops._should_tile(n_cols)
    c_ref, da_ref, db_ref = run()

    # Forced Blackwell tiled kernel (gate forced on).
    monkeypatch.setattr(swiglu_ops, "_is_blackwell", lambda: True)
    assert swiglu_ops._should_tile(n_cols)
    c_tiled, da_tiled, db_tiled = run()

    # Launch-geometry change only -> require exact equality (the PR's headline claim).
    torch.testing.assert_close(c_tiled, c_ref, rtol=0, atol=0)
    torch.testing.assert_close(da_tiled, da_ref, rtol=0, atol=0)
    torch.testing.assert_close(db_tiled, db_ref, rtol=0, atol=0)

Notes:

  • assert _should_tile(...) / assert not _should_tile(...) guard against the dispatch silently changing under you, so the test can't pass by accidentally running the same kernel twice.
  • The non-1024-divisible widths (11009, 14337) are the important ones — they're what actually exercises the col_offsets < n_cols mask on the ragged final tile; the benchmark shapes (14336, 18944) mostly miss it.
  • rtol=0, atol=0 directly encodes the "bit-for-bit identical" claim; if the geometry ever perturbs a value this fails loudly.
  • All widths stay under calculate_settings' 65536 cap so the reference (one-row) path is valid.
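To see how much of the final tile each width actually masks off, the tile count and the number of valid lanes in the last tile can be computed directly. The helper name `final_tile` is illustrative; the widths are the test and benchmark shapes mentioned above.

```python
TILE = 1024  # column-tile width from the PR


def final_tile(n_cols: int, tile: int = TILE):
    """Return (number of tiles, valid lanes in the last tile) for one row."""
    n_tiles = -(-n_cols // tile)  # ceiling division
    valid = n_cols - (n_tiles - 1) * tile
    return n_tiles, valid


if __name__ == "__main__":
    for n_cols in (11009, 14337, 16384, 14336, 18944):
        n_tiles, valid = final_tile(n_cols)
        masked = TILE - valid
        print(f"n_cols={n_cols}: {n_tiles} tiles, last tile has {valid} valid / {masked} masked lanes")
```

A width of 14337 is the most extreme case here: its final tile has a single valid lane and 1023 masked ones.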

@Celaena24

Contributor Author

@yueyiming2009 thanks for catching that, just added the test coverage in 33d64ef.

@Celaena24 Celaena24 force-pushed the b200-swiglu-optimization branch from 86ef172 to 33d64ef Compare June 28, 2026 23:49
@vaibhavjindal

Collaborator

@Celaena24 let's rebase on #1273, it is merged now.

@Celaena24 Celaena24 force-pushed the b200-swiglu-optimization branch from 9093e5b to 1b2c9c3 Compare July 2, 2026 17:54
@Celaena24

Contributor Author

@vaibhavjindal rebased

Comment thread src/liger_kernel/ops/swiglu.py Outdated
# occupancy-starved on Blackwell for wide rows. Splitting each row into fixed
# 1024-wide column tiles over a 2D grid raises occupancy so HBM saturates
# (~1.63x backward at n_cols=14336, bit-exact). Tile size 1024 chosen by sweep.
# See optimization/swiglu/report.md for the full diagnosis and numbers.

Collaborator

optimization/swiglu/report.md reference doesn't look valid. Probably an LLM artifact. We should remove this reference.

Comment thread src/liger_kernel/ops/swiglu.py Outdated

def _should_tile(n_cols):
    """Use the Blackwell column-tiled path only for wide-enough rows."""
    return _is_blackwell() and triton.next_power_of_2(n_cols) >= _SWIGLU_TILE_MIN_BLOCK

Collaborator

Let's just replace _is_blackwell() with infer_device_arch().startswith("blackwell"). It's a simple function, no need to define a separate _is_blackwell() function.

Comment thread src/liger_kernel/ops/swiglu.py Outdated
return _is_blackwell() and triton.next_power_of_2(n_cols) >= _SWIGLU_TILE_MIN_BLOCK


def swiglu_tile_settings(n_cols):

Collaborator

nit: let's rename it by adding underscore in front of the function name as it should be a private function.

swiglu_tile_settings() ---> _swiglu_tile_settings()

@vaibhavjindal vaibhavjindal left a comment

Collaborator

Left a few minor comments

Comment thread src/liger_kernel/ops/swiglu.py Outdated
def swiglu_tile_settings(n_cols):
    """Pick the column-tile BLOCK_SIZE and num_warps for the Blackwell 2D-grid path."""
    BLOCK_SIZE = min(_SWIGLU_TILE, triton.next_power_of_2(n_cols))
    num_warps = 8 if BLOCK_SIZE >= 2048 else 4

Collaborator

Since _SWIGLU_TILE is 1024, BLOCK_SIZE will never be greater than 1024. Thus, num_warps is just gonna be 4 at all times. Let's just return num_warps=4?
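The reviewer's point can be checked mechanically. The sketch below re-implements the snippet under review next to the proposed simplification and confirms they agree for every width tried; both function names are illustrative, and `next_pow2` stands in for triton.next_power_of_2 so the check runs without Triton.

```python
SWIGLU_TILE = 1024  # value of _SWIGLU_TILE per the review comment


def next_pow2(n: int) -> int:
    """Smallest power of two >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


def tile_settings_reviewed(n_cols: int):
    """Logic as shown in the snippet under review."""
    block = min(SWIGLU_TILE, next_pow2(n_cols))
    num_warps = 8 if block >= 2048 else 4  # branch can never pick 8
    return block, num_warps


def tile_settings_simplified(n_cols: int):
    """Suggested simplification: num_warps is a constant 4."""
    return min(SWIGLU_TILE, next_pow2(n_cols)), 4


if __name__ == "__main__":
    same = all(
        tile_settings_reviewed(n) == tile_settings_simplified(n)
        for n in range(1, 70000)
    )
    print(same)
```

The branch would only matter again if the tile constant were raised to 2048 or more, so the simplification is tied to the current tile size.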

@Celaena24

Contributor Author

@vaibhavjindal made the changes

@vaibhavjindal vaibhavjindal left a comment

Collaborator

Thanks for the changes, LGTM! 🚢

@vaibhavjindal vaibhavjindal enabled auto-merge July 2, 2026 21:41
@vaibhavjindal vaibhavjindal added this pull request to the merge queue Jul 2, 2026
Merged via the queue into linkedin:main with commit c5d3e24 Jul 2, 2026
5 of 7 checks passed

5 participants