[trainer,hparams,docs] feat: add CRD (Centered Reward Distillation) algorithm #121
Status: Open
yuanzhi-zhu wants to merge 1 commit into X-GenGroup:main
Conversation
…lgorithm

Implements Centered Reward Distillation (arXiv:2603.14128) as a new decoupled RL trainer for flow-matching models.

Key changes:
- `trainers/crd.py`: Full CRDTrainer implementation with old/sampling model snapshots, dual-direction centering loss, adaptive KL, and per-step velocity-space implicit reward estimation
- `hparams/training_args.py`: CRDTrainingArguments with paper-aligned defaults (decay schedules, kl_beta=0.1, kl_cfg=4.5)
- `trainers/registry.py`: Register 'crd' key
- `hparams/__init__.py`: Export CRDTrainingArguments
- `examples/crd/lora/sd3_5.yaml`: SD3.5 + OCR example config matching paper Table 3 hyperparameters (K=24, 2 grad steps, timestep_range=0.99)
- `guidance/algorithms.md`: CRD section with hyperparameter reference and centering modes table
- `.agents/knowledge/architecture.md`: Add CRD to trainer registry table

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Pull request overview
Adds a new decoupled RL trainer implementing Centered Reward Distillation (CRD) for flow-matching models, along with corresponding hyperparameters, registry wiring, and user/internal documentation.
Changes:
- Introduces `CRDTrainer` with old/sampling parameter snapshots, centered reward-matching loss, and reference-model KL regularization.
- Adds `CRDTrainingArguments` and registers the new `'crd'` trainer + hparams key.
- Documents CRD usage/hyperparameters and provides an SD3.5 + OCR example config.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| src/flow_factory/trainers/registry.py | Registers 'crd' → CRDTrainer for dynamic trainer loading. |
| src/flow_factory/trainers/crd.py | New CRD trainer implementation (sampling/optimization/loss/KL + snapshot decay). |
| src/flow_factory/hparams/training_args.py | Adds CRDTrainingArguments and registers it under 'crd'. |
| src/flow_factory/hparams/__init__.py | Exposes CRDTrainingArguments from the hparams package. |
| guidance/algorithms.md | Adds CRD algorithm documentation and hyperparameter reference. |
| examples/crd/lora/sd3_5.yaml | Adds a paper-aligned CRD LoRA config example for SD3.5 + OCR reward. |
| .agents/knowledge/architecture.md | Updates internal architecture docs to include CRD in the trainer registry table. |
Comment on lines +768 to +793
```python
# Gather r_theta across all GPUs for centering
r_theta_gathered = self.accelerator.gather(r_theta_local.detach()).to(
    self.accelerator.device
)

# 5. Compute advantages for CRD centering
adv = batch['advantage']
adv_clip_range = self.training_args.adv_clip_range
adv_clipped = torch.clamp(adv, adv_clip_range[0], adv_clip_range[1])

# Normalize to [0, 1]
normalized_adv = (adv_clipped / max(adv_clip_range)) / 2.0 + 0.5
adv_cur_rank = torch.clamp(normalized_adv, 0, 1)

# Gather advantages across all GPUs
adv_cur = self.accelerator.gather(adv_cur_rank.detach()).to(
    self.accelerator.device
)

# 6. Centered Reward Distillation loss (supports dual-direction centering)
ori_policy_loss = self._compute_crd_loss(
    adv_cur=adv_cur,
    adv_cur_rank=adv_cur_rank,
    r_theta_gathered=r_theta_gathered,
    r_theta_local=r_theta_local,
)
```
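The normalization step above maps clamped advantages into [0, 1] by dividing by the upper clip bound, halving, and shifting. A standalone sketch (the `(-5.0, 5.0)` clip range is a hypothetical default; the real value comes from `training_args.adv_clip_range`):

```python
import torch

def normalize_advantages(adv: torch.Tensor, clip_range=(-5.0, 5.0)) -> torch.Tensor:
    """Clamp advantages, then map them into [0, 1] as in the trainer snippet.

    A symmetric clip range maps clip_range[0] -> 0, zero -> 0.5,
    and clip_range[1] -> 1.
    """
    adv_clipped = torch.clamp(adv, clip_range[0], clip_range[1])
    # Divide by the largest bound (lands in [-1, 1] for a symmetric range),
    # halve to [-0.5, 0.5], shift to [0, 1]
    normalized = (adv_clipped / max(clip_range)) / 2.0 + 0.5
    return torch.clamp(normalized, 0.0, 1.0)

adv = torch.tensor([-10.0, -5.0, 0.0, 5.0, 10.0])
out = normalize_advantages(adv)
```

Values outside the clip range saturate at 0 or 1 rather than stretching the scale, so outlier rewards cannot dominate the centering weights.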
Comment on lines +380 to +405
```python
def _blend_named_params(self, name: str, decay: float):
    """
    Blend a named parameter snapshot towards the current trainable parameters.

    Formula: ``snapshot = decay * snapshot + (1 - decay) * current``

    Args:
        name: Name of the parameter snapshot.
        decay: Blending coefficient. 0.0 = full copy, 1.0 = no change.
    """
    if decay <= 0.0:
        # Full copy from current params (no blending)
        self.adapter.update_named_parameters(name)
    elif decay >= 1.0:
        # Keep snapshot unchanged (fully offline)
        pass
    else:
        # Exponential blending: snapshot = decay * snapshot + (1 - decay) * current
        info = self.adapter._named_parameters[name]
        current_params = self.adapter._get_component_parameters(info.target_components)
        with torch.no_grad():
            for ema_param, param in zip(info.ema_wrapper.ema_parameters, current_params, strict=True):
                ema_param.data.mul_(decay).add_(
                    param.detach().to(ema_param.device), alpha=(1.0 - decay)
                )
```
Comment on lines +804 to +810
```python
if self.reward_adaptive_kl:
    # Linearly scale KL based on reward value
    raw_reward = adv_cur_rank  # Already in [0, 1]
    base_beta = 1e-4
    min_coef = base_beta / max(self.kl_beta, 1e-8)
    kl_loss = self.kl_beta * torch.mean((min_coef + raw_reward * (1 - min_coef)) * kl_div)
else:
```
Comment on lines +306 to +313
### Centering Modes (`weight_temp`)

| `weight_temp` | Mode | Description |
|---|---|---|
| `< 0` | Uniform (τ→∞) | Simple mean centering; recommended default |
| `== 0` | Hard selection | Positive pool (adv > 0) vs negative pool (adv < 0) |
| `> 0` | Softmax temperature | Dual-direction: `softmax(adv/τ)` and `softmax(-adv/τ)` |
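The three modes in the table can be read as three ways of building a pair of weight distributions over the batch. A hedged sketch of that interpretation (the function name and return convention are assumptions for illustration; the in-trainer implementation may differ in detail):

```python
import torch

def centering_weights(adv: torch.Tensor, weight_temp: float):
    """Return (pos_weights, neg_weights), each summing to 1, for the
    three weight_temp centering modes described in the table."""
    n = adv.numel()
    if weight_temp < 0:
        # Uniform (tau -> inf): both directions reduce to simple mean centering
        w = torch.full_like(adv, 1.0 / n)
        return w, w
    if weight_temp == 0:
        # Hard selection: uniform weight over the positive / negative pools
        pos = (adv > 0).float()
        neg = (adv < 0).float()
        return pos / pos.sum().clamp(min=1), neg / neg.sum().clamp(min=1)
    # Softmax temperature: dual-direction softmax(adv/tau) and softmax(-adv/tau)
    return (torch.softmax(adv / weight_temp, dim=0),
            torch.softmax(-adv / weight_temp, dim=0))

adv = torch.tensor([-1.0, 0.5, 2.0])
pos_w, neg_w = centering_weights(adv, weight_temp=1.0)
```

As `weight_temp` grows, the softmax weights flatten towards the uniform case, which is why `< 0` is documented as the τ→∞ limit.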
Summary
Implements Centered Reward Distillation (CRD) (arXiv:2603.14128) as a new decoupled RL trainer for flow-matching models.
- `trainers/crd.py`: Full `CRDTrainer` — maintains `_crd_old` and `_crd_sampling` named parameter snapshots; dual-direction centering loss (`weight_temp` modes); adaptive KL regularization against a CFG-guided pretrained reference; per-step velocity-space implicit reward estimation
- `hparams/training_args.py`: `CRDTrainingArguments` with paper-aligned defaults (linear decay schedules, `kl_beta=0.1`, `kl_cfg=4.5`, `timestep_range=0.99`)
- `trainers/registry.py`: Register `'crd'` key
- `hparams/__init__.py`: Export `CRDTrainingArguments`
- `examples/crd/lora/sd3_5.yaml`: SD3.5 + OCR config matching paper Table 3 (K=24, 2 gradient steps, `old_model_decay="0-0.25-0.005-0.999"`, `sampling_model_decay="75-0.0-0.0075-0.999"`)
- `guidance/algorithms.md`: CRD section with hyperparameter reference and centering modes table
- `.agents/knowledge/architecture.md`: Add CRD to trainer registry table

Test plan
- `get_training_args_class('crd')` returns `CRDTrainingArguments`
- `get_trainer_class('crd')` loads `CRDTrainer`

🤖 Generated with Claude Code
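The test plan exercises the registry wiring. A minimal self-contained sketch of the pattern those checks assume (the dict-based lookup and the stub classes here are assumptions for illustration, not the project's actual implementation):

```python
# Stub classes standing in for the real trainer and hparams classes
class CRDTrainer: ...
class CRDTrainingArguments: ...

# Registry mapping algorithm keys to their classes, as the PR wires up 'crd'
TRAINER_REGISTRY = {'crd': CRDTrainer}
TRAINING_ARGS_REGISTRY = {'crd': CRDTrainingArguments}

def get_trainer_class(key: str):
    return TRAINER_REGISTRY[key]

def get_training_args_class(key: str):
    return TRAINING_ARGS_REGISTRY[key]

# The two test-plan checks
assert get_trainer_class('crd') is CRDTrainer
assert get_training_args_class('crd') is CRDTrainingArguments
```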