perf(ce): dtype-aware num_warps (Blackwell-gated) #1267
@justinhh4 Let's rebase this PR on the latest changes and maybe just report numbers on improvement this specific PR provides on B200. Let's not include the positive results from #1266
from triton.language.math import tanh

# log2(e): lets us emit the hardware ex2.approx instruction via e^x = 2^(x * LOG2_E)
LOG2_E = 1.4426950408889634

scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
m_new = tl.maximum(m, block_max)
d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
d = d * tl.exp2((m - m_new) * LOG2_E) + tl.sum(tl.exp2((X_block - m_new) * LOG2_E))
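The two forms of the update for d in the hunk above (tl.exp, and tl.exp2 scaled by LOG2_E) are meant to compute the same running softmax denominator. A minimal pure-Python sketch of that equivalence follows. It is not the Triton kernel: the function name online_denominator and the list-of-blocks input are assumptions made for illustration only.

```python
import math

LOG2_E = 1.4426950408889634  # log2(e), same constant as in the diff above


def online_denominator(blocks, use_exp2):
    """Running max m and softmax denominator d over blocks of logits.

    Hypothetical pure-Python stand-in for the kernel loop, not the kernel itself.
    """
    def exp(x):
        # exp2 path: e^x evaluated as 2^(x * log2(e))
        return 2.0 ** (x * LOG2_E) if use_exp2 else math.exp(x)

    m, d = float("-inf"), 0.0
    for block in blocks:
        m_new = max(m, max(block))
        # Rescale the old sum to the new max, then add this block's terms.
        d = d * exp(m - m_new) + sum(exp(x - m_new) for x in block)
        m = m_new
    return m, d


blocks = [[1.0, 2.0], [3.0, 0.5]]
m_exp, d_exp = online_denominator(blocks, use_exp2=False)
m_exp2, d_exp2 = online_denominator(blocks, use_exp2=True)
```

Both paths should agree up to floating-point rounding, and both should match the direct sum of exp(x - max) over all logits.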
if is_hip():
    ce_num_warps = 16
else:
    is_blackwell = _input.is_cuda and torch.cuda.get_device_capability(_input.device)[0] >= 10
maybe use infer_device() from utils instead of _input.is_cuda?
Also, maybe let's define another function in utils which returns the gpu type: "hopper", "blackwell", "ampere" etc
@vaibhavjindal thanks for the suggestions. I have rebased and added a gpu type helper function and updated the numbers and figures to reflect the gains from just this PR.
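For illustration, a minimal sketch of the kind of GPU-type helper discussed above. It assumes the architecture is derived from the CUDA compute-capability major version, as in the major >= 10 check in the diff. The name arch_from_capability is hypothetical and is not the helper added in this PR; lumping every major-8 device under "ampere" is also a simplification.

```python
def arch_from_capability(major: int) -> str:
    """Map a CUDA compute-capability major version to a coarse arch name.

    Hypothetical helper for illustration; returns "" for unrecognized devices.
    """
    if major >= 10:
        # sm_100 and later, matching the major >= 10 check in the diff
        return "blackwell"
    if major == 9:
        return "hopper"
    if major == 8:
        # Simplification: all major-8 parts are reported as "ampere" here
        return "ampere"
    return ""


b200_arch = arch_from_capability(10)
h100_arch = arch_from_capability(9)
```

On a CUDA device the major version would come from torch.cuda.get_device_capability(device)[0], as in the diff; AMD devices would need to be excluded separately before calling it.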
@justinhh4 let's actually wait for #1273 to get merged. You can use the changes made in that PR.
@justinhh4 #1273 is merged now. Let's use the utils added in that PR.
Replace the PR-local get_gpu_arch() helper in ops/utils.py with the infer_device_arch() utility merged in linkedin#1273 (src/liger_kernel/utils.py).

The Blackwell gate now reads infer_device_arch().startswith("blackwell"), which covers the whole sm_100+ generation (blackwell / blackwell_ultra / blackwell_consumer), matching the original major>=10 intent.

bf16/fp16 on Blackwell -> 8 warps; fp32 and Hopper/earlier -> 32; AMD -> 16. num_warps is a scheduling-only parameter, so there is no correctness impact. Full CE suite (177 tests, bf16+fp32) passes on B200 (sm_100).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
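The selection rule in the commit message above can be sketched as a small pure function. This is a hedged illustration, not the code in the PR: the name select_ce_num_warps and its parameters are hypothetical, and the arch string is assumed to be whatever infer_device_arch() returns (for example "blackwell", "blackwell_ultra", or "hopper").

```python
def select_ce_num_warps(arch: str, element_size: int, is_hip: bool = False) -> int:
    """Pick num_warps for the CE forward kernel (illustrative sketch only).

    arch: architecture name string, assumed to come from infer_device_arch()
    element_size: bytes per element of the input dtype (2 for bf16/fp16, 4 for fp32)
    is_hip: True on AMD
    """
    if is_hip:
        return 16  # AMD keeps 16 warps
    if arch.startswith("blackwell") and element_size == 2:
        return 8  # bf16/fp16 on the Blackwell generation
    return 32  # fp32, fp8, and Hopper/earlier keep the previous default


warps_b200_bf16 = select_ce_num_warps("blackwell", 2)
warps_b200_fp32 = select_ce_num_warps("blackwell", 4)
warps_h100_bf16 = select_ce_num_warps("hopper", 2)
```

Keeping the rule as a pure function of (arch, element_size, is_hip) would make it easy to unit-test without a GPU; whether the PR structures it this way is not stated.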
@vaibhavjindal done. Used the infer_device_arch function from #1273 utils to choose number of warps.
vaibhavjindal left a comment
Thanks for making the changes and addressing the comments. LGTM!
CE num_warps: dtype- & arch-aware (Blackwell-gated)

Rebased onto current main (post-#1266). All numbers below use #1266 (v1_exp2) as the baseline, so they isolate this PR's contribution and exclude #1266's exp2 gains. Measured on B200, which is what this gate targets. Correctness: full CE suite (bf16+fp32, with weight/softcap/label_smoothing/z_loss/ignore_index combos) passes.
What & why
Replace the hardcoded num_warps=32 in CE forward with a dtype- and arch-aware choice:
- Adds get_gpu_arch() in ops/utils.py, which returns "blackwell"/"hopper"/"ampere"/"". It uses infer_device() as the CUDA guard and excludes AMD via is_hip() (since infer_device() reports AMD as "cuda" too). The gate reads get_gpu_arch() == "blackwell".
- bf16/fp16 on Blackwell use 8 warps (the dominant stall is "Not Selected" from a surplus of warps); fewer warps cut issue contention. Hopper/earlier and fp32 are bandwidth-bound and keep 32 to hide memory latency.
- The gate fires only for 2-byte dtypes (element_size() == 2) on sm_100+; fp8 (1 B) and fp32 keep 32 warps (unmeasured → safe default).
- num_warps is purely a scheduling parameter → no correctness impact.

Headline (BT=8192, V=128256, full fwd+bwd; baseline = #1266)
Only bf16/fp16 fires the gate; fp32 keeps 32 warps and is flat — shown to confirm no regression.
bf16 across context length (V=128256; baseline = #1266)
Gain grows with batch size (−3% → −17%; avg −12% over the sweep).
bf16 across vocab (BT=8192; baseline = #1266)
−17.1% @ V=32k → −8.6% @ V=262k (avg −13.0%) — larger relative win at smaller vocab.
Figures