
[trainer,hparams,docs] feat: add CRD (Centered Reward Distillation) algorithm#121

Open
yuanzhi-zhu wants to merge 1 commit into X-GenGroup:main from yuanzhi-zhu:feat/crd-trainer

Conversation

@yuanzhi-zhu

Summary

Implements Centered Reward Distillation (CRD) (arXiv:2603.14128) as a new decoupled RL trainer for flow-matching models.

  • trainers/crd.py: Full CRDTrainer — maintains _crd_old and _crd_sampling named parameter snapshots; dual-direction centering loss (weight_temp modes); adaptive KL regularization against a CFG-guided pretrained reference; per-step velocity-space implicit reward estimation
  • hparams/training_args.py: CRDTrainingArguments with paper-aligned defaults (linear decay schedules, kl_beta=0.1, kl_cfg=4.5, timestep_range=0.99)
  • trainers/registry.py: Register 'crd' key
  • hparams/__init__.py: Export CRDTrainingArguments
  • examples/crd/lora/sd3_5.yaml: SD3.5 + OCR config matching paper Table 3 (K=24, 2 gradient steps, old_model_decay="0-0.25-0.005-0.999", sampling_model_decay="75-0.0-0.0075-0.999")
  • guidance/algorithms.md: CRD section with hyperparameter reference and centering modes table
  • .agents/knowledge/architecture.md: Add CRD to trainer registry table

Test plan

  • get_training_args_class('crd') returns CRDTrainingArguments
  • get_trainer_class('crd') loads CRDTrainer
  • Training runs end-to-end on SD3.5 + OCR reward with the provided config
  • Loss values are non-NaN and policy loss decreases over epochs
  • Old/sampling model decay values logged correctly to wandb

🤖 Generated with Claude Code

[trainer,hparams,docs] feat: add CRD (Centered Reward Distillation) algorithm

Implements Centered Reward Distillation (arXiv:2603.14128) as a new
decoupled RL trainer for flow-matching models.

Key changes:
- `trainers/crd.py`: Full CRDTrainer implementation with old/sampling
  model snapshots, dual-direction centering loss, adaptive KL, and
  per-step velocity-space implicit reward estimation
- `hparams/training_args.py`: CRDTrainingArguments with paper-aligned
  defaults (decay schedules, kl_beta=0.1, kl_cfg=4.5)
- `trainers/registry.py`: Register 'crd' key
- `hparams/__init__.py`: Export CRDTrainingArguments
- `examples/crd/lora/sd3_5.yaml`: SD3.5 + OCR example config matching
  paper Table 3 hyperparameters (K=24, 2 grad steps, timestep_range=0.99)
- `guidance/algorithms.md`: CRD section with hyperparameter reference
  and centering modes table
- `.agents/knowledge/architecture.md`: Add CRD to trainer registry table

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings April 10, 2026 10:13

Copilot AI left a comment


Pull request overview

Adds a new decoupled RL trainer implementing Centered Reward Distillation (CRD) for flow-matching models, along with corresponding hyperparameters, registry wiring, and user/internal documentation.

Changes:

  • Introduces CRDTrainer with old/sampling parameter snapshots, centered reward-matching loss, and reference-model KL regularization.
  • Adds CRDTrainingArguments and registers the new 'crd' trainer + hparams key.
  • Documents CRD usage/hyperparameters and provides an SD3.5 + OCR example config.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.

| File | Description |
|---|---|
| `src/flow_factory/trainers/registry.py` | Registers `'crd'` → `CRDTrainer` for dynamic trainer loading. |
| `src/flow_factory/trainers/crd.py` | New CRD trainer implementation (sampling/optimization/loss/KL + snapshot decay). |
| `src/flow_factory/hparams/training_args.py` | Adds `CRDTrainingArguments` and registers it under `'crd'`. |
| `src/flow_factory/hparams/__init__.py` | Exposes `CRDTrainingArguments` from the hparams package. |
| `guidance/algorithms.md` | Adds CRD algorithm documentation and hyperparameter reference. |
| `examples/crd/lora/sd3_5.yaml` | Adds a paper-aligned CRD LoRA config example for SD3.5 + OCR reward. |
| `.agents/knowledge/architecture.md` | Updates internal architecture docs to include CRD in the trainer registry table. |


Comment on lines +768 to +793
```python
# Gather r_theta across all GPUs for centering
r_theta_gathered = self.accelerator.gather(r_theta_local.detach()).to(
    self.accelerator.device
)

# 5. Compute advantages for CRD centering
adv = batch['advantage']
adv_clip_range = self.training_args.adv_clip_range
adv_clipped = torch.clamp(adv, adv_clip_range[0], adv_clip_range[1])

# Normalize to [0, 1]
normalized_adv = (adv_clipped / max(adv_clip_range)) / 2.0 + 0.5
adv_cur_rank = torch.clamp(normalized_adv, 0, 1)

# Gather advantages across all GPUs
adv_cur = self.accelerator.gather(adv_cur_rank.detach()).to(
    self.accelerator.device
)

# 6. Centered Reward Distillation loss (supports dual-direction centering)
ori_policy_loss = self._compute_crd_loss(
    adv_cur=adv_cur,
    adv_cur_rank=adv_cur_rank,
    r_theta_gathered=r_theta_gathered,
    r_theta_local=r_theta_local,
)
```
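The normalization in step 5 maps clipped advantages into [0, 1] via `(adv / max_range) / 2 + 0.5`, which implicitly assumes a symmetric clip range. A quick standalone check with an assumed range of `(-5.0, 5.0)` (the actual configured value is not shown in this excerpt):

```python
# Scalar re-implementation of the advantage normalization above,
# assuming a symmetric adv_clip_range of (-5.0, 5.0).
adv_clip_range = (-5.0, 5.0)

def normalize_adv(adv):
    # clamp, scale by the max range, then shift into [0, 1]
    adv_clipped = max(adv_clip_range[0], min(adv, adv_clip_range[1]))
    normalized = (adv_clipped / max(adv_clip_range)) / 2.0 + 0.5
    return max(0.0, min(normalized, 1.0))

assert normalize_adv(5.0) == 1.0    # upper clip bound -> 1.0
assert normalize_adv(0.0) == 0.5    # zero advantage -> midpoint
assert normalize_adv(-5.0) == 0.0   # lower clip bound -> 0.0
```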
Comment on lines +380 to +405
```python
def _blend_named_params(self, name: str, decay: float):
    """
    Blend a named parameter snapshot towards the current trainable parameters.

    Formula: ``snapshot = decay * snapshot + (1 - decay) * current``

    Args:
        name: Name of the parameter snapshot.
        decay: Blending coefficient. 0.0 = full copy, 1.0 = no change.
    """
    if decay <= 0.0:
        # Full copy from current params (no blending)
        self.adapter.update_named_parameters(name)
    elif decay >= 1.0:
        # Keep snapshot unchanged (fully offline)
        pass
    else:
        # Exponential blending: snapshot = decay * snapshot + (1 - decay) * current
        info = self.adapter._named_parameters[name]
        current_params = self.adapter._get_component_parameters(info.target_components)
        with torch.no_grad():
            for ema_param, param in zip(info.ema_wrapper.ema_parameters, current_params, strict=True):
                ema_param.data.mul_(decay).add_(
                    param.detach().to(ema_param.device), alpha=(1.0 - decay)
                )
```
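The three decay branches reduce to the stated formula; a standalone sketch with plain floats standing in for parameter tensors:

```python
# Minimal sketch of the blend rule snapshot = decay*snapshot + (1-decay)*current,
# using Python floats in place of the trainer's parameter tensors.
def blend(snapshot, current, decay):
    if decay <= 0.0:
        return list(current)      # full copy of current params
    if decay >= 1.0:
        return list(snapshot)     # snapshot frozen (fully offline)
    return [decay * s + (1.0 - decay) * c for s, c in zip(snapshot, current)]

snap, cur = [1.0, 2.0], [3.0, 4.0]
assert blend(snap, cur, 0.0) == [3.0, 4.0]  # decay 0 -> take current
assert blend(snap, cur, 1.0) == [1.0, 2.0]  # decay 1 -> keep snapshot
assert blend(snap, cur, 0.5) == [2.0, 3.0]  # midpoint blend
```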

Comment on lines +804 to +810
```python
if self.reward_adaptive_kl:
    # Linearly scale KL based on reward value
    raw_reward = adv_cur_rank  # Already in [0, 1]
    base_beta = 1e-4
    min_coef = base_beta / max(self.kl_beta, 1e-8)
    kl_loss = self.kl_beta * torch.mean((min_coef + raw_reward * (1 - min_coef)) * kl_div)
else:
```
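With the paper-aligned default `kl_beta=0.1` and the hard-coded `base_beta=1e-4`, the effective per-sample KL coefficient interpolates linearly from `base_beta` at zero reward to `kl_beta` at full reward:

```python
# Scalar view of the adaptive-KL coefficient, using the defaults quoted
# in this PR (kl_beta=0.1) and the hard-coded base_beta=1e-4.
kl_beta = 0.1
base_beta = 1e-4
min_coef = base_beta / max(kl_beta, 1e-8)  # 1e-3

def kl_coef(raw_reward):
    # raw_reward is the normalized advantage in [0, 1]
    return kl_beta * (min_coef + raw_reward * (1 - min_coef))

assert abs(kl_coef(0.0) - base_beta) < 1e-12  # low reward -> floor of base_beta
assert abs(kl_coef(1.0) - kl_beta) < 1e-12    # high reward -> full kl_beta
```

So samples with low reward are only weakly pulled back toward the reference model, while high-reward samples get the full regularization strength.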
Comment on lines +306 to +313
### Centering Modes (`weight_temp`)

| `weight_temp` | Mode | Description |
|---|---|---|
| `< 0` | Uniform (τ→∞) | Simple mean centering; recommended default |
| `== 0` | Hard selection | Positive pool (adv > 0) vs negative pool (adv < 0) |
| `> 0` | Softmax temperature | Dual-direction: `softmax(adv/τ)` and `softmax(-adv/τ)` |
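A standalone sketch of the three modes (function and variable names are illustrative; this shows only the positive-direction weights, while the PR's dual-direction loss also uses the mirrored `softmax(-adv/τ)` pool):

```python
import math

# Illustrative weighting for the three weight_temp modes in the table above.
# Not the PR's internals: names and the uniform fallback are assumptions.
def centering_weights(advs, weight_temp):
    n = len(advs)
    if weight_temp < 0:
        # Uniform (tau -> inf): simple mean centering
        return [1.0 / n] * n
    if weight_temp == 0:
        # Hard selection: average over the positive pool (adv > 0) only
        pos = [1.0 if a > 0 else 0.0 for a in advs]
        s = sum(pos)
        return [p / s for p in pos] if s else [1.0 / n] * n
    # Softmax temperature: softmax(adv / tau)
    exps = [math.exp(a / weight_temp) for a in advs]
    z = sum(exps)
    return [e / z for e in exps]

assert centering_weights([1.0, -1.0, 2.0], -1.0) == [1/3, 1/3, 1/3]
assert centering_weights([1.0, -1.0, 2.0], 0.0) == [0.5, 0.0, 0.5]
```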

