@@ -9,7 +9,7 @@ Reference: [arXiv:2602.06036](https://arxiv.org/abs/2602.06036) |

## Architecture

-```
+```text
Target Model (frozen)
     │
     ├─ hidden_states[layer 1, 9, 17, 25, 33] ──► concat ──► FC + RMSNorm ──► target_hidden
@@ -43,7 +43,7 @@ Target Model (frozen)

Given context `"The answer is"` and block_size=4 with anchor `"is"`:

-```
+```text
Target model hidden states (from frozen base model):
    h["The"]   h["answer"]   h["is"]    ← target_hidden (ctx_len=3)
       │           │            │
@@ -87,7 +87,7 @@ In each DFlash decoder layer:

**Training vs Inference:**

-```
+```text
TRAINING (2 anchors, block_size=4):

  Context tokens:  "The"  "answer"  "is"  "5"  "."
@@ -166,13 +166,25 @@ See [`modelopt_recipes/general/speculative_decoding/dflash.yaml`](../../../model
| `dflash.dflash_architecture_config.mask_token_id` | auto | Token ID for masked positions |
| `training.answer_only_loss` | false | Mask loss on non-assistant tokens |

+> **Note on `answer_only_loss` and chat templates:** When `answer_only_loss=true`, the
+> dataset loader replaces the tokenizer's chat template with a simplified version that has
+> `{% generation %}` tags to identify assistant turns. This simplified template may not
+> support all features of the original (e.g., tool-use formatting, multi-turn system
+> prompts). During serving, the draft model reuses the target model's original tokenizer
+> and template, so there is no train/inference mismatch in the tokenization itself; only
+> the loss masking during training uses the simplified template. However, if training data
+> contains tool-use conversations with model-family-specific formatting, the simplified
+> template may tokenize them differently, affecting which tokens get masked. For best
+> results with tool-use data, set `answer_only_loss=false` or provide a custom
+> `chat_template` that supports both generation tags and tool-use formatting.
+
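As a hedged illustration of how `{% generation %}` tags can drive loss masking, here is a minimal sketch with a toy whitespace tokenizer. `assistant_mask` is a hypothetical helper, not the actual dataset-loader code; real tokenizers segment text differently.

```python
import re

def assistant_mask(rendered: str) -> tuple[list[str], list[int]]:
    """Toy sketch: tokens inside {% generation %} blocks keep their loss (mask=1)."""
    tokens, mask = [], []
    # Split into tagged and untagged segments; the capture group keeps the matches.
    parts = re.split(r"({% generation %}.*?{% endgeneration %})", rendered, flags=re.S)
    for seg in parts:
        inside = seg.startswith("{% generation %}")
        body = re.sub(r"{%.*?%}", "", seg)  # strip the tags themselves
        for tok in body.split():            # toy whitespace "tokenizer"
            tokens.append(tok)
            mask.append(1 if inside else 0)
    return tokens, mask

toks, mask = assistant_mask(
    "[USR] What is 2+3? {% generation %}The answer is 5{% endgeneration %}"
)
# mask is 0 for the user tokens and 1 for the assistant turn
```

This mirrors the intent of the simplified template: only tokens rendered inside the generation tags contribute to the training loss.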
### Random Anchor Sampling (`num_anchors`)

During training, anchor positions are sampled randomly from valid (assistant-response)
tokens in each batch, rather than by dividing the sequence into fixed blocks. Each anchor
starts a block of `block_size` tokens where the draft model predicts positions 1..B-1.

-```
+```text
Sequence:   [SYS]  You  helpful  [USR]  What  2+3?  [AST]  The  answer  is   5
Position:     0     1      2       3      4     5      6     7     8     9   10
loss_mask:    0     0      0       0      0     0      0     1     1     1    1
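The sampling above can be sketched in a few lines. This is an illustrative sketch assuming a plain Python `loss_mask` list; `sample_anchors` is a hypothetical helper name, not the actual trainer code.

```python
import random

def sample_anchors(loss_mask: list[int], num_anchors: int) -> list[int]:
    # Valid anchor positions are assistant-response tokens (loss_mask == 1);
    # anchors are drawn uniformly without replacement, not on a fixed grid.
    valid = [i for i, m in enumerate(loss_mask) if m == 1]
    return sorted(random.sample(valid, num_anchors))

loss_mask = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]  # mask from the diagram above
anchors = sample_anchors(loss_mask, num_anchors=2)  # two positions in 7..10
```

Because anchors are resampled every batch, each assistant token eventually serves as a block start, giving the draft model coverage of all response positions.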
@@ -208,7 +220,7 @@ The exponential decay factor (gamma) weights early block positions higher than l
If position 1 in a block is wrong, all subsequent positions are rejected in speculative
decoding. Decay aligns the training loss with what matters for acceptance rate.

-```
+```text
weight[k] = exp(-(k-1).clamp(min=0) / gamma)   for k = 0..B-1
```

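The formula corresponds to this small sketch (`decay_weights` is a hypothetical helper name; because of the clamp, positions 0 and 1 both get full weight):

```python
import math

def decay_weights(block_size: int, gamma: float) -> list[float]:
    # weight[k] = exp(-max(k - 1, 0) / gamma) for k = 0..B-1
    return [math.exp(-max(k - 1, 0) / gamma) for k in range(block_size)]

w = decay_weights(block_size=4, gamma=7.0)
# w[0] == w[1] == 1.0; later positions decay toward 0
```

A larger gamma flattens the weights (all positions matter nearly equally); a smaller gamma concentrates the loss on the first predicted position.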
@@ -324,8 +336,8 @@ ModelOpt wins acceptance length on 7/8 categories and TPS on 8/8 categories.
- **FP8 / NVFP4 quantization**: Export pipeline supports quantized checkpoints via
  `hf_ptq.py` (PTQ succeeded in testing). AR impact of quantization not yet measured.
  The flow: train (bf16) → `mtq.quantize(model, quant_cfg)` → `export_hf_checkpoint.py`.
-- **Checkpoint resume**: `DFlashModule._apply()` handles meta-tensor rotary buffers.
-  Validated in training runs but not covered by integration tests.
+- **Checkpoint resume**: `DFlashModule._apply()` handles meta-tensor rotary buffers
+  (one-shot check on first `.to(device)` call). Validated in train+resume E2E tests.

### Validated

@@ -334,10 +346,12 @@ ModelOpt wins acceptance length on 7/8 categories and TPS on 8/8 categories.
- **AR evaluation**: `ar_validate.py` with online ground truth (GT), per-category MT-Bench.
- **vLLM deployment**: Speculative decoding with `vllm/vllm-openai:nightly` (v0.19.1+);
  3.1x speedup over baseline. Per-category benchmarks on MT-Bench.
+
  ```bash
  vllm serve Qwen/Qwen3-8B \
    --speculative-config '{"method": "dflash", "model": "path/to/checkpoint", "num_speculative_tokens": 7}' \
    --max-num-batched-tokens 32768
  ```
+
- **Export**: z-lab-compatible HF format, loadable by vLLM and the z-lab benchmark.
- **Loss decay**: Validated a +0.12 AR improvement with gamma=7 (bs16).