[bug][train] Fix max_seq_len calculation#1303
Conversation
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
```python
if cfg.trainer.algorithm.loss_reduction == "seq_mean_token_sum_norm":
    if cfg.trainer.algorithm.max_seq_len is None:
        raise ValueError(
            "`trainer.algorithm.max_seq_len` must be set explicitly when "
            "`trainer.algorithm.loss_reduction='seq_mean_token_sum_norm'`. "
            "Choose the total sequence-length normalization constant for your setup; "
            "this often matches the model context window / vLLM `max_model_len` when appropriate."
        )
```
🔴 Breaking change: Dr. GRPO example script fails because auto-calculated max_seq_len fallback was removed
The PR removes the max_seq_len auto-calculation from SkyRLTrainConfig.__post_init__ (skyrl/train/config/config.py:713-722 on LEFT) and adds a hard assertion requiring it to be set explicitly when loss_reduction='seq_mean_token_sum_norm'. However, the official Dr. GRPO example script at examples/train/algorithms/drgrpo/run_drgrpo_gsm8k.sh:15,23 uses LOSS_REDUCTION="seq_mean_token_sum_norm" but never passes trainer.algorithm.max_seq_len. This script previously worked because __post_init__ auto-computed max_seq_len = max_input_length + max_generate_length. Now it will crash with an AssertionError at validation time.
Same issue in skyrl-agent example
skyrl-agent/examples/run_skyrl/run_skyrl_swe.sh:67 also sets trainer.algorithm.loss_reduction="seq_mean_token_sum_norm" without setting max_seq_len, so it will also fail.
Prompt for agents
Two example scripts need to be updated to explicitly pass trainer.algorithm.max_seq_len now that the auto-calculation fallback has been removed:
1. examples/train/algorithms/drgrpo/run_drgrpo_gsm8k.sh: Add a line like trainer.algorithm.max_seq_len=1536 (512 + 1024, matching max_prompt_length + max_generate_length from the script) to the uv run command.
2. skyrl-agent/examples/run_skyrl/run_skyrl_swe.sh: Add a line like trainer.algorithm.max_seq_len=40768 (8000 + 32768, matching max_prompt_length + max_generate_length from the script) to the uv run command.
Both scripts use loss_reduction=seq_mean_token_sum_norm and will now fail the new assertion at skyrl/train/utils/utils.py:279-285 without this fix.
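As a sketch of the suggested fix for the first script (the lengths 512 and 1024 are the values quoted above; the variable names here are illustrative, not necessarily those used in the script):

```shell
# Illustrative: derive the explicit max_seq_len for run_drgrpo_gsm8k.sh
# from the prompt and generation lengths quoted above (512 + 1024).
MAX_PROMPT_LENGTH=512
MAX_GENERATE_LENGTH=1024
MAX_SEQ_LEN=$((MAX_PROMPT_LENGTH + MAX_GENERATE_LENGTH))
# This value would then be passed to the uv run command as:
echo "trainer.algorithm.max_seq_len=$MAX_SEQ_LEN"
```

Computing the value from the script's own length variables, rather than hard-coding 1536, keeps the two settings from drifting apart if the lengths are later changed.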
@tamoghnokandar this is important - can you grep for all usages of seq_mean_token_sum_norm in our example scripts and ensure that max_seq_len is explicitly passed in now (calculate it based on the generation and input lengths in the script).
SumanthRH
left a comment
@tamoghnokandar can you merge the latest changes from main? We've had some important updates, especially bf243b8
Done!
can you also add a note in the loss_type docstring that max_seq_len is required for seq_mean_token_sum_norm now?
@SumanthRH Ready to be merged
SumanthRH
left a comment
One small fix: Can we ensure that we update this script as well now with an explicit max_seq_len argument?
CharlieFRuan
left a comment
This is great, thank you!
```shell
generator.inference_engine.http_endpoint_port=8000 \
generator.sampling_params.max_generate_length=4096 \
trainer.algorithm.max_seq_len=$MAX_MODEL_LEN \
generator.inference_engine.engine_init_kwargs.max_model_len=$MAX_MODEL_LEN \
```
hmm why constrain inference engine max model len?
This is just a generation entrypoint, so it doesn't really matter there. For training, this is needed to limit the total context window size for the model, which should match $MAX_MODEL_LEN. For token-in-token-out, that is essentially the maximum sequence length (padded or not).
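For context, here is a minimal pure-Python sketch of how `seq_mean_token_sum_norm` uses `max_seq_len` as a fixed normalization constant (the function name and list-based interface are illustrative; the actual implementation operates on masked tensors):

```python
def seq_mean_token_sum_norm(token_losses, max_seq_len):
    """Sum token losses per sequence, divide each sum by the fixed
    constant max_seq_len (not the actual sequence length), then
    average across sequences (Dr. GRPO-style normalization)."""
    per_seq = [sum(seq) / max_seq_len for seq in token_losses]
    return sum(per_seq) / len(per_seq)

# Two sequences of (already masked) token losses; max_seq_len=4 is
# the fixed normalizer, regardless of how long each sequence is.
loss = seq_mean_token_sum_norm([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0]], max_seq_len=4)
# -> 0.75, i.e. (4/4 + 2/4) / 2
```

Because the denominator is a fixed constant rather than the per-sequence token count, choosing `max_seq_len` changes the loss scale, which is why the PR requires it to be set explicitly.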
Fixes #1154
Summary
This PR removes the implicit `max_seq_len` heuristic calculation and requires users to set it explicitly when using `trainer.algorithm.loss_reduction=seq_mean_token_sum_norm`.
Changes
- Removed the implicit `trainer.algorithm.max_seq_len` default from `SkyRLTrainConfig.__post_init__`
- Required `trainer.algorithm.max_seq_len` to be explicitly set when `loss_reduction == "seq_mean_token_sum_norm"`; `max_seq_len` must be chosen based on the user's intended sequence-length normalization budget
- `max_seq_len` remains `None` by default, with explicit `max_seq_len` values being preserved
- `validate_cfg()` fails when `seq_mean_token_sum_norm` is used without `max_seq_len`
- `validate_cfg()` continues to allow `token_mean` and `sequence_mean` without `max_seq_len`
- `validate_cfg()` passes when `seq_mean_token_sum_norm` is used with an explicit `max_seq_len`
Testing
- Added tests in `tests/train/test_config.py` for the new behavior
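A self-contained sketch of the validation behavior the tests cover (the helper below is a hypothetical stand-in for the check in `validate_cfg()`, not the actual SkyRL code):

```python
def check_max_seq_len(loss_reduction, max_seq_len):
    """Hypothetical stand-in for the max_seq_len check in validate_cfg()."""
    if loss_reduction == "seq_mean_token_sum_norm" and max_seq_len is None:
        raise ValueError(
            "`trainer.algorithm.max_seq_len` must be set explicitly when "
            "`loss_reduction='seq_mean_token_sum_norm'`."
        )

# token_mean and sequence_mean remain valid without max_seq_len
check_max_seq_len("token_mean", None)
check_max_seq_len("sequence_mean", None)
# seq_mean_token_sum_norm passes with an explicit max_seq_len
check_max_seq_len("seq_mean_token_sum_norm", 1536)
```

Only the `seq_mean_token_sum_norm` + `None` combination raises; the other reductions are unaffected by this PR.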