
Commit 7699860

ChenhanYu and claude committed
address PR review feedback
Code simplification:
- _get_attn_fn: use ALL_ATTENTION_FUNCTIONS directly, remove _eager_attention
- Remove dead code: _original_forward_cls, _base_forward, mlp_bias, _psg_debug
- Remove redundant DFlashExporter.__init__
- Extract helpers: _build_noise_embedding, _build_position_ids, _build_draft_attention_mask, _compute_loss from forward()
- Merge duplicate estimate_ar if-branches in eagle_utils.py
- Rename make_eagle_supervised_data_module -> make_speculative_data_module
- Reuse shared path constants from modeling_fakebase in _find_base_model_parts
- Fix import_plugin: nest hf_dflash under transformers block

Config and defaults:
- Move static defaults to dflash/default_config.py (matching EAGLE pattern)
- Align block_size config default to 8 (matching recipe)
- Expose answer_only_loss as configurable training arg (recipe default: true for DFlash, false for EAGLE3)
- Add shift_labels option for correct label/mask alignment (shifted for autoregressive EAGLE3, unshifted for diffusion DFlash)
- Pin transformers>=4.58 for Qwen3.5 support

Bug fixes:
- _apply meta buffer fix: always check (no one-shot flag), skip .to(device) in modify() when base model is on meta device (from_pretrained context)
- Attention masks: create directly in target dtype (matching EAGLE convention)
- Combine labels and attention_mask for loss mask (LabelSmoother.ignore_index)
- validate_online: don't count rejection correction token as accepted
- trust_remote_code: pass through from config instead of hardcoding True
- Restore --trust_remote_code in export_hf_checkpoint.py
- Add quantization-aware export to DFlashExporter (matching EAGLE pattern)
- Remove ShareGPT format: only accept OpenAI messages format
- Validate HEAD_NODE_IP in multi-node training script

Script reorganization:
- Move scripts to common/specdec/ (mode-agnostic for EAGLE/DFlash)
- Add generalized vllm_smoke_test.sh (configurable via SPEC_METHOD)
- Auto-export last checkpoint after training (rank 0 only)
- Remove standalone export.sh (merged into training script)
- Add Qwen3.5-4B DFlash launcher example

Documentation and tests:
- Add answer_only_loss / chat template limitation note in dflash.md
- Add TODO for epoch-seeded anchor sampling and co-training
- Add tests: validate_online, DFlashExporter, _ensure_generation_tags
- Remove misleading "Legacy" comments on utility functions

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
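The shift_labels option listed above can be illustrated with a minimal sketch. This is not the actual data-collator code: `build_labels`, the toy token ids, and the loss mask are made up for the example; -100 is `LabelSmoother.ignore_index` from transformers.

```python
# Minimal sketch of the shift_labels behavior described in the commit
# message (hypothetical helper, not the actual data collator).
# -100 matches transformers' LabelSmoother.ignore_index.
IGNORE_INDEX = -100

def build_labels(input_ids, loss_mask, shift_labels=True):
    if shift_labels:
        # Autoregressive (EAGLE3): position i is trained to predict token i+1,
        # so labels and mask shift left by one; the last position has no target.
        labels = input_ids[1:] + [IGNORE_INDEX]
        mask = loss_mask[1:] + [0]
    else:
        # Diffusion-style (DFlash): the draft fills in masked positions,
        # so labels stay aligned with the inputs.
        labels, mask = list(input_ids), list(loss_mask)
    return [tok if m else IGNORE_INDEX for tok, m in zip(labels, mask)]

ids = [101, 102, 103, 104]
mask = [0, 0, 1, 1]  # loss only on the two assistant tokens
print(build_labels(ids, mask, shift_labels=True))   # [-100, 103, 104, -100]
print(build_labels(ids, mask, shift_labels=False))  # [-100, -100, 103, 104]
```

With the same mask, the shifted variant supervises positions that *predict* assistant tokens, while the unshifted variant supervises the assistant positions themselves.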
1 parent 64dd332 commit 7699860

23 files changed

Lines changed: 823 additions & 657 deletions


examples/speculative_decoding/doc/dflash.md

Lines changed: 21 additions & 7 deletions
@@ -9,7 +9,7 @@ Reference: [arXiv:2602.06036](https://arxiv.org/abs/2602.06036) |

 ## Architecture

-```
+```text
 Target Model (frozen)

 ├─ hidden_states[layer 1, 9, 17, 25, 33] ──► concat ──► FC + RMSNorm ──► target_hidden
@@ -43,7 +43,7 @@ Target Model (frozen)

 Given context `"The answer is"` and block_size=4 with anchor `"is"`:

-```
+```text
 Target model hidden states (from frozen base model):
 h["The"]   h["answer"]   h["is"]   ← target_hidden (ctx_len=3)
    │           │            │
@@ -87,7 +87,7 @@ In each DFlash decoder layer:

 **Training vs Inference:**

-```
+```text
 TRAINING (2 anchors, block_size=4):

 Context tokens: "The" "answer" "is" "5" "."
@@ -166,13 +166,25 @@ See [`modelopt_recipes/general/speculative_decoding/dflash.yaml`](../../../model
 | `dflash.dflash_architecture_config.mask_token_id` | auto | Token ID for masked positions |
 | `training.answer_only_loss` | false | Mask loss on non-assistant tokens |

+> **Note on `answer_only_loss` and chat templates:** When `answer_only_loss=true`, the
+> dataset loader replaces the tokenizer's chat template with a simplified version that has
+> `{% generation %}` tags to identify assistant turns. This simplified template may not
+> support all features of the original (e.g., tool use formatting, multi-turn system
+> prompts). During serving, the draft model reuses the target model's original tokenizer
+> and template, so there is no train/inference mismatch in the tokenization itself — only
+> the loss masking during training uses the simplified template. However, if training data
+> contains tool-use conversations with model-family-specific formatting, the simplified
+> template may tokenize them differently, affecting which tokens get masked. For best
+> results with tool-use data, set `answer_only_loss=false` or provide a custom
+> `chat_template` that supports both generation tags and tool-use formatting.
+
 ### Random Anchor Sampling (`num_anchors`)

 During training, anchor positions are sampled randomly from valid (assistant response)
 tokens in each batch, rather than dividing the sequence into fixed blocks. Each anchor
 starts a block of `block_size` tokens where the draft model predicts positions 1..B-1.

-```
+```text
 Sequence:  [SYS] You helpful [USR] What 2+3? [AST] The answer is 5
 Position:    0    1     2      3    4    5     6    7    8     9  10
 loss_mask:   0    0     0      0    0    0     0    1    1     1   1
@@ -208,7 +220,7 @@ The exponential decay factor (gamma) weights early block positions higher than l
 If position 1 in a block is wrong, all subsequent positions are rejected in speculative
 decoding. Decay aligns the training loss with what matters for acceptance rate.

-```
+```text
 weight[k] = exp(-(k-1).clamp(min=0) / gamma)   for k = 0..B-1
 ```

@@ -324,8 +336,8 @@ ModelOpt wins acceptance length on 7/8 categories and TPS on 8/8 categories.
 - **FP8 / NVFP4 quantization**: Export pipeline supports quantized checkpoints via
   `hf_ptq.py` (PTQ succeeded in testing). AR impact of quantization not yet measured.
   The flow: train (bf16) → `mtq.quantize(model, quant_cfg)` → `export_hf_checkpoint.py`.
-- **Checkpoint resume**: `DFlashModule._apply()` handles meta-tensor rotary buffers.
-  Validated in training runs but not covered by integration tests.
+- **Checkpoint resume**: `DFlashModule._apply()` handles meta-tensor rotary buffers
+  (one-shot check on first `.to(device)` call). Validated in train+resume E2E tests.

 ### Validated

@@ -334,10 +346,12 @@ ModelOpt wins acceptance length on 7/8 categories and TPS on 8/8 categories.
 - **AR evaluation**: `ar_validate.py` with online GT, per-category MT-Bench.
 - **vLLM deployment**: Speculative decoding with `vllm/vllm-openai:nightly` (v0.19.1+).
   3.1x speedup over baseline. Per-category benchmarks on MT-Bench.
+
   ```bash
   vllm serve Qwen/Qwen3-8B \
     --speculative-config '{"method": "dflash", "model": "path/to/checkpoint", "num_speculative_tokens": 7}' \
     --max-num-batched-tokens 32768
   ```
+
 - **Export**: z-lab compatible HF format, loadable by vLLM and z-lab benchmark.
 - **Loss decay**: Validated +0.12 AR improvement with gamma=7 (bs16).
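The decay formula in the dflash.md diff above can be checked with a few lines of plain Python (`math.exp` in place of the torch expression; `decay_weights` is an illustrative helper, and block_size=8 / gamma=7 mirror values mentioned elsewhere in this commit):

```python
import math

def decay_weights(block_size, gamma):
    # weight[k] = exp(-(k-1).clamp(min=0) / gamma) for k = 0..B-1:
    # positions 0 and 1 get full weight; later positions decay exponentially.
    return [math.exp(-max(k - 1, 0) / gamma) for k in range(block_size)]

weights = decay_weights(8, 7.0)
print([round(w, 3) for w in weights])
# weights[0] == weights[1] == 1.0; each later position weighs less than the previous
```

This matches the stated rationale: if position 1 is rejected, everything after it is discarded, so early positions should dominate the loss.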

examples/speculative_decoding/eagle_utils.py

Lines changed: 14 additions & 8 deletions
@@ -137,12 +137,19 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
         return batch


-def make_eagle_supervised_data_module(
+def make_speculative_data_module(
     tokenizer: transformers.PreTrainedTokenizer,
     data_args,
     train_len=None,
     answer_only_loss=False,
+    shift_labels=True,
 ) -> dict:
+    """Create data module for speculative decoding training.
+
+    Args:
+        shift_labels: If True, labels are shifted by 1 for autoregressive training (EAGLE3).
+            If False, labels are unshifted for diffusion-style training (DFlash).
+    """
     if data_args.offline_data_path is None:
         train_dataset = ShardedDataset("json", data_files=data_args.data_path)

@@ -152,6 +159,7 @@ def make_eagle_supervised_data_module(
             train_len=train_len,
             return_labels=True,
             answer_only_loss=answer_only_loss,
+            shift_labels=shift_labels,
         )
     else:
         data_collator = VisionLanguageDataCollator(
@@ -213,6 +221,11 @@ def on_log(self, args, state, control, **kwargs):
                 print_rank_0(f"Step {state.global_step} Training Acc: [{acc_str}]")
             except Exception:
                 print_rank_0(f"Step {state.global_step} Training Acc: {average_acc}")
+        # Log accuracy to HF Trainer's logs dict (picked up by TensorBoard)
+        logs = kwargs.get("logs") or {}
+        for i, draft_acc in enumerate(average_acc):
+            for j, step_acc in enumerate(draft_acc):
+                logs[f"train_acc/parallel_{i}_step_{j}"] = float(step_acc)
         if self.estimate_ar:
             # Calculate mean training AR since last log
             # NOTE: This is only an estimate of the real AR.
@@ -226,13 +239,6 @@ def on_log(self, args, state, control, **kwargs):
                 acc_cumprod *= draft_acc[-1]
                 est_ar += acc_cumprod
             print_rank_0(f"Step {state.global_step} Estimated Training AR: {est_ar:.4f}")
-
-        # Log accuracy to HF Trainer's logs dict (picked up by TensorBoard)
-        logs = kwargs.get("logs") or {}
-        for i, draft_acc in enumerate(average_acc):
-            for j, step_acc in enumerate(draft_acc):
-                logs[f"train_acc/parallel_{i}_step_{j}"] = float(step_acc)
-        if self.estimate_ar:
             logs["estimated_training_ar"] = est_ar

         # log to wandb
examples/speculative_decoding/main.py

Lines changed: 11 additions & 3 deletions
@@ -40,7 +40,7 @@
 from eagle_utils import (
     EagleTrainerWithAccLog,
     EagleTrainingPlot,
-    make_eagle_supervised_data_module,
+    make_speculative_data_module,
     patch_ring_attention_for_ttt,
 )
 from omegaconf import OmegaConf
@@ -108,6 +108,12 @@ class TrainingArguments(transformers.TrainingArguments):
         default=False, metadata={"help": "Whether to estimate AR using training accuracy to log."}
     )
     ar_validate_steps: int = field(default=1000, metadata={"help": "AR validation interval."})
+    answer_only_loss: bool = field(
+        default=False,
+        metadata={
+            "help": "Mask loss on non-assistant tokens. Default: True for dflash, False for eagle3."
+        },
+    )
     cp_size: int = field(default=1, metadata={"help": "Context parallelism size."})
     dp_shard_size: int | None = field(
         default=None,
@@ -262,12 +268,14 @@ def train():
         raise Exception(f"{training_args.mode} is not supported!")

     print_rank_0("Loading dataset...")
+    is_dflash = training_args.mode == "dflash"
     if training_args.mode in ("eagle3", "dflash"):
-        data_module = make_eagle_supervised_data_module(
+        data_module = make_speculative_data_module(
             tokenizer,
             data_args,
             train_len=training_args.training_seq_len,
-            answer_only_loss=(training_args.mode == "dflash"),
+            answer_only_loss=training_args.answer_only_loss,
+            shift_labels=not is_dflash,
         )

     trainer = EagleTrainerWithAccLog(
Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 accelerate==1.12.0
-transformers<5.4
+transformers>=4.58,<5.4

examples/speculative_decoding/scripts/export_hf_checkpoint.py

Lines changed: 4 additions & 1 deletion
@@ -29,6 +29,7 @@ def parse_args():
         description="Export a HF checkpoint (with ModelOpt state) for deployment."
     )
     parser.add_argument("--model_path", type=str, default="Path of the trained checkpoint.")
+    parser.add_argument("--trust_remote_code", action="store_true", help="Trust remote code")
     parser.add_argument(
         "--export_path", type=str, default="Destination directory for exported files."
     )
@@ -38,7 +39,9 @@ def parse_args():
 mto.enable_huggingface_checkpointing()

 args = parse_args()
-model = load_vlm_or_llm(args.model_path, torch_dtype="auto")
+model = load_vlm_or_llm(
+    args.model_path, torch_dtype="auto", trust_remote_code=args.trust_remote_code
+)
 model.eval()
 with torch.inference_mode():
     export_speculative_decoding(

modelopt/torch/export/plugins/hf_spec_export.py

Lines changed: 18 additions & 7 deletions
@@ -253,10 +253,6 @@ class DFlashExporter(SpeculativeDecodingExporter):
     - config.json: Qwen3-style config with dflash_config field
     """

-    def __init__(self, model: nn.Module):
-        """Initialize the DFlashExporter."""
-        super().__init__(model)
-
     def _extract_state_dict(self, full_state_dict: dict):
         """Extract DFlash module weights, stripping the dflash_module prefix."""
         export_sd = {}
@@ -316,7 +312,9 @@ def _export_config(self):
             ),
             "rope_scaling": getattr(base_config, "rope_scaling", None),
             "tie_word_embeddings": False,
-            "torch_dtype": str(getattr(base_config, "torch_dtype", torch.bfloat16)).replace("torch.", ""),
+            "torch_dtype": str(getattr(base_config, "torch_dtype", torch.bfloat16)).replace(
+                "torch.", ""
+            ),
             "num_target_layers": getattr(base_config, "num_hidden_layers", 36),
         }

@@ -333,18 +331,31 @@ def export(self, export_dir: Path | str, dtype: torch.dtype | None = None):
         export_dir = Path(export_dir)
         export_dir.mkdir(parents=True, exist_ok=True)

+        # Export quantized modules if applicable
+        if has_quant_opt(self.model):
+            from ..unified_export_hf import _export_transformers_checkpoint
+
+            full_sd, hf_quant_config = _export_transformers_checkpoint(self.model, dtype)
+        else:
+            full_sd, hf_quant_config = self.model.state_dict(), None
+
         # Export state dict
-        full_sd = self.model.state_dict()
         drafter_sd = self._extract_state_dict(full_sd)
-        if dtype is not None:
+        if dtype is not None and hf_quant_config is None:
             drafter_sd = {k: v.to(dtype) for k, v in drafter_sd.items()}
         save_file(drafter_sd, f"{export_dir}/model.safetensors")

         # Export config
         drafter_config = self._export_config()
+        if hf_quant_config is not None:
+            drafter_config["quantization_config"] = hf_quant_config
         with open(f"{export_dir}/config.json", "w") as f:
             json.dump(drafter_config, f, indent=2)

+        if hf_quant_config is not None:
+            with open(f"{export_dir}/hf_quant_config.json", "w") as f:
+                json.dump(hf_quant_config, f, indent=2)
+
         print(
             f"Exported DFlash draft model: {len(drafter_sd)} tensors, "
             f"config keys: {list(drafter_config.keys())[:5]}..."
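The `_extract_state_dict` step above ("stripping the dflash_module prefix") follows a common checkpoint-filtering pattern; a standalone sketch with string stand-ins for tensors and illustrative key names:

```python
# Standalone sketch of the prefix-stripping pattern used by
# _extract_state_dict in the hunk above. Values are string stand-ins for
# tensors; key names are illustrative, not the real checkpoint keys.
def extract_drafter_sd(full_sd, prefix="dflash_module."):
    return {k[len(prefix):]: v for k, v in full_sd.items() if k.startswith(prefix)}

full_sd = {
    "model.embed_tokens.weight": "base tensor",          # base model weight, skipped
    "dflash_module.layers.0.fc.weight": "draft tensor",  # draft weight, exported
}
print(extract_drafter_sd(full_sd))  # {'layers.0.fc.weight': 'draft tensor'}
```

Only the draft module's tensors land in `model.safetensors`; the frozen base model is excluded.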

modelopt/torch/speculative/config.py

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ class DFlashConfig(ModeloptBaseConfig):
     """DFlash config for block-wise parallel speculative decoding."""

     dflash_block_size: int = ModeloptField(
-        default=16,
+        default=8,
         description="Block size for parallel prediction. Draft predicts this many tokens per block.",
     )


modelopt/torch/speculative/dflash/default_config.py

Lines changed: 9 additions & 2 deletions
@@ -16,13 +16,20 @@
 """Default DFlash architecture config.

 Model-specific settings (hidden_size, num_attention_heads, rope_*, etc.)
-are inherited from the base model in HFDFlashModel.modify(). Only
-DFlash-specific defaults are set here.
+are inherited from the base model in HFDFlashModel.modify(). Static
+defaults that don't depend on the base model are set here, similar to
+``eagle/default_config.py``.
 """

 default_dflash_config = {
+    # DFlash-specific
     "num_hidden_layers": 5,
+    # Architecture defaults (overridable by user config)
+    "hidden_act": "silu",
     "rms_norm_eps": 1e-06,
+    "initializer_range": 0.02,
     "attention_bias": False,
     "attention_dropout": 0.0,
+    "tie_word_embeddings": False,
+    "_attn_implementation": "sdpa",
 }
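How these static defaults interact with a user-supplied architecture config can be sketched as a plain dict merge. The helper below is hypothetical; the actual merge in `HFDFlashModel.modify()` (which also pulls model-specific settings from the base model) may differ:

```python
# Hypothetical sketch: user-provided keys override the static defaults
# shown above. The real merge in HFDFlashModel.modify() may differ.
default_dflash_config = {
    "num_hidden_layers": 5,
    "hidden_act": "silu",
    "rms_norm_eps": 1e-06,
    "_attn_implementation": "sdpa",
}

def resolve_architecture_config(user_config):
    # Later dict wins on key collisions, so user overrides take effect.
    return {**default_dflash_config, **user_config}

cfg = resolve_architecture_config({"num_hidden_layers": 3})
print(cfg["num_hidden_layers"], cfg["hidden_act"])  # 3 silu
```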

modelopt/torch/speculative/plugins/__init__.py

Lines changed: 1 addition & 3 deletions
@@ -30,7 +30,5 @@
 from .megatron_medusa import *

 with import_plugin("transformers"):
-    from .transformers import *
-
-with import_plugin("hf_dflash"):
     from .hf_dflash import *
+    from .transformers import *
