
Commit a279960

fix: DDP deadlock when no valid loss positions on a rank
When a rank's batch has no valid loss positions (e.g., all tokens in Block 0, which is excluded), the loss was a detached zero tensor with no connection to the dflash_module parameters. DDP then waited forever for gradient sync on those parameters, ending in an NCCL ALLREDUCE timeout.

Fix: use `logits.sum() * 0.0` as the zero loss, which keeps the computation graph connected through the dflash_module parameters so DDP can sync zero gradients properly. Also revert to `super().forward()` for training (matching the EAGLE pattern) and add `--ddp_find_unused_parameters True` and `--ddp_timeout 300`.

Root cause analysis: rank 4 completed ALLREDUCE #272 and proceeded to ALLGATHER #273, while the other ranks were stuck at ALLREDUCE #272. This indicated that rank 4 had a different backward graph (no gradients for dflash_module on that rank).

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
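The failure mode described above can be reproduced on a single process. A minimal sketch (the toy `nn.Linear` stands in for dflash_module; names are hypothetical): a fresh leaf tensor such as `torch.tensor(0.0, requires_grad=True)` has no autograd path to any module parameter, so `backward()` leaves their `.grad` unset, whereas `logits.sum() * 0.0` flows through the module and produces zero gradients for every parameter, which is what DDP needs to keep its all-reduce buckets consistent across ranks.

```python
import torch
import torch.nn as nn

# Toy stand-in for dflash_module (hypothetical, not the real module).
module = nn.Linear(4, 4)
logits = module(torch.randn(2, 4))

# Broken variant: a detached leaf tensor has no path to module parameters,
# so backward() populates no parameter gradients on this rank. Under DDP,
# ranks then disagree on which gradient buckets to all-reduce and hang.
detached_loss = torch.tensor(0.0, requires_grad=True)
detached_loss.backward()
assert module.weight.grad is None

# Fixed variant: multiplying the summed logits by 0.0 keeps the autograd
# graph through the module, so every parameter receives a (zero) gradient
# and DDP can all-reduce consistently across ranks.
zero_loss = logits.sum() * 0.0
zero_loss.backward()
assert module.weight.grad is not None
assert torch.all(module.weight.grad == 0)
```

The numerically identical loss values behave very differently under autograd, which is why the deadlock only surfaced on ranks whose batches happened to contain no valid loss positions.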
1 parent 2c42363 commit a279960

File tree

1 file changed

+3
-1
lines changed


modelopt/torch/speculative/plugins/hf_dflash.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -523,7 +523,9 @@ def forward(
         preds = active_logits.argmax(dim=-1)
         accuracy = (preds == active_labels).float().mean().item()
     else:
-        loss = torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True)
+        # No valid positions — compute a zero loss that still flows through
+        # dflash_module parameters to keep DDP gradient sync happy
+        loss = logits.sum() * 0.0
         accuracy = 0.0

     return ModelOutput(
```
