Skip to content

perf(ce): dtype-aware num_warps (Blackwell-gated)#1267

Merged
vaibhavjindal merged 3 commits into
linkedin:mainfrom
justinhh4:ce-opt-2-numwarps
Jul 2, 2026
Merged

perf(ce): dtype-aware num_warps (Blackwell-gated)#1267
vaibhavjindal merged 3 commits into
linkedin:mainfrom
justinhh4:ce-opt-2-numwarps

Conversation

@justinhh4

@justinhh4 justinhh4 commented Jun 25, 2026

Copy link
Copy Markdown
Contributor

CE num_warps: dtype- & arch-aware (Blackwell-gated)

Rebased onto current main (post-#1266). All numbers below use #1266 (v1_exp2) as the
baseline
, so they isolate this PR's contribution and exclude #1266's exp2 gains. Measured on
B200, which is what this gate targets. Correctness: full CE suite (bf16+fp32, with
weight/softcap/label_smoothing/z_loss/ignore_index combos) passes.

What & why

Replace the hardcoded num_warps=32 in CE forward with a dtype- and arch-aware choice:

  • Blackwell (sm_100+) bf16/fp16 → 8; everything else (Hopper & earlier, all fp32) → 32; AMD → 16.
  • Adds a small reusable helper get_gpu_arch() in ops/utils.py that returns
    "blackwell" / "hopper" / "ampere" / "". It uses infer_device() as the CUDA guard and
    excludes AMD via is_hip() (since infer_device() reports AMD as "cuda" too). The gate reads
    get_gpu_arch() == "blackwell".
  • Why 8 on Blackwell bf16: CE there is instruction-issue-bound (ncu: ALU is the top pipe; the
    dominant stall is "Not Selected" from a surplus of warps) — fewer warps cut issue contention.
    Hopper/earlier and fp32 are bandwidth-bound and keep 32 to hide memory latency.
  • Scoped to exactly bf16/fp16 (element_size() == 2) on sm_100+; fp8 (1 B) and fp32 keep 32
    warps (unmeasured → safe default). num_warps is purely a scheduling parameter → no correctness impact.

Headline (BT=8192, V=128256, full fwd+bwd; baseline = #1266)

dtype baseline (#1266) this PR latency throughput
bf16 1.686 ms 1.438 ms −14.7% +17.2%
fp32 1.892 ms 1.892 ms +0.0% (unchanged) +0.0%

Only bf16/fp16 fires the gate; fp32 keeps 32 warps and is flat — shown to confirm no regression.

bf16 across context length (V=128256; baseline = #1266)

BT baseline this PR latency
1024 0.405 ms 0.391 ms −3.4%
2048 0.583 ms 0.542 ms −7.0%
4096 0.950 ms 0.839 ms −11.6%
8192 1.686 ms 1.438 ms −14.7%
16384 3.184 ms 2.694 ms −15.4%
32768 6.092 ms 5.101 ms −16.3%
65536 11.958 ms 9.924 ms −17.0%

Gain grows with batch size (−3% → −17%; avg −12% over the sweep).

bf16 across vocab (BT=8192; baseline = #1266)

−17.1% @ V=32k → −8.6% @ V=262k (avg −13.0%) — larger relative win at smaller vocab.

Figures

pr2_bf16_bt pr2_bf16_vocab pr2_fp32_bt pr2_fp32_vocab

@justinhh4 justinhh4 force-pushed the ce-opt-2-numwarps branch 2 times, most recently from 7cdf3f5 to fdc3c24 Compare June 26, 2026 17:48
@vaibhavjindal

Copy link
Copy Markdown
Collaborator

@justinhh4 Let's rebase this PR on the latest changes and maybe just report numbers on improvement this specific PR provides on B200. Let's not include the positive results from #1266

Comment thread src/liger_kernel/ops/cross_entropy.py Outdated
from triton.language.math import tanh

# log2(e): lets us emit the hardware ex2.approx instruction via e^x = 2^(x * LOG2_E)
LOG2_E = 1.4426950408889634

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's rebase

scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
m_new = tl.maximum(m, block_max)
d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
d = d * tl.exp2((m - m_new) * LOG2_E) + tl.sum(tl.exp2((X_block - m_new) * LOG2_E))

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Comment thread src/liger_kernel/ops/cross_entropy.py Outdated
if is_hip():
ce_num_warps = 16
else:
is_blackwell = _input.is_cuda and torch.cuda.get_device_capability(_input.device)[0] >= 10

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe use infer_device() from utils instead of _input.is_cuda?

Also, maybe let's define another function in utils which returns the gpu type: "hopper", "blackwell", "ampere" etc

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vaibhavjindal thanks for the suggestions. I have rebased and added a gpu type helper function and updated the numbers and figures to reflect the gains from just this PR.

@justinhh4 justinhh4 force-pushed the ce-opt-2-numwarps branch from fdc3c24 to 4cfe820 Compare June 30, 2026 18:12
@vaibhavjindal

Copy link
Copy Markdown
Collaborator

@justinhh4 let's actually wait for #1273 to get merged. You can use the changes made in that PR.
cc @yueyiming2009

@vaibhavjindal

Copy link
Copy Markdown
Collaborator

@justinhh4 #1273 is merged now. Let's use the utils added in that PR.

Replace the PR-local get_gpu_arch() helper in ops/utils.py with the
infer_device_arch() utility merged in linkedin#1273 (src/liger_kernel/utils.py).

The Blackwell gate now reads infer_device_arch().startswith("blackwell"),
which covers the whole sm_100+ generation (blackwell / blackwell_ultra /
blackwell_consumer) — matching the original major>=10 intent. bf16/fp16 on
Blackwell -> 8 warps; fp32 and Hopper/earlier -> 32; AMD -> 16. num_warps is
a scheduling-only parameter, so there is no correctness impact.

Full CE suite (177 tests, bf16+fp32) passes on B200 (sm_100).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@justinhh4

Copy link
Copy Markdown
Contributor Author

@vaibhavjindal done. Used the infer_device_arch function from #1273 utils to choose number of warps.

@vaibhavjindal vaibhavjindal left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for making the changes and addressing the comments. LGTM!

@vaibhavjindal vaibhavjindal added this pull request to the merge queue Jul 2, 2026
Merged via the queue into linkedin:main with commit 0b93bc9 Jul 2, 2026
5 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants