refactor: refactor train engine high level APIs #658
Conversation
Code Review
This pull request is a great step towards unifying the training engine's API by introducing a forward_backward_batch interface. The refactoring in areal/api/engine_api.py and areal/engine/megatron_engine.py is well-executed and significantly improves code structure and reduces duplication. However, the implementation for FSDPEngine in areal/engine/fsdp_engine.py doesn't fully adhere to the new API contract, which undermines the goal of a unified interface. I've left detailed comments on this, along with a few other potential bugs and areas for improvement.
```diff
 forward_step_counts = [0] * len(self.model)

 def forward_step(batch_iter, model):
     nonlocal forward_step_counts
-    batch = next(batch_iter)
-    model_vp_stage = getattr(model, "vp_stage", 0)
-    forward_step_count = forward_step_counts[model_vp_stage]
-    padding_length = mb_list.padding_lengths[forward_step_count]
-    orig_input = mb_list.mbs[forward_step_count]
-    cu_seqlens = batch["cu_seqlens"]
-    old_cu_seqlens = mb_list.old_cu_seqlens_list[forward_step_count]
-
-    forward_step_counts[model_vp_stage] += 1
-    output = packed_context_parallel_forward(model, batch)
-
-    if mpu.is_pipeline_last_stage(
-        ignore_virtual=False, vp_stage=model_vp_stage
-    ):
-        output = unpad_logits(
-            output,
-            padding_length=padding_length,
-            cu_seqlens=cu_seqlens,
-            old_cu_seqlens=old_cu_seqlens,
-        )
-
-    def _post_process_fn(input_, output):
-        loss = torch.tensor(1.0, device=output.device)
-        if post_hook is not None:
-            output = post_hook(output, input_)
-        return loss, {"output": output}
-
-    return output, functools.partial(_post_process_fn, orig_input)
+    batch_ctx = next(batch_iter)
+    return self._forward_compute_mb(
+        mb_input=batch_ctx,
+        loss_fn=loss_fn,
+        loss_weight_fn=loss_weight_fn,
+        model=model,
+        forward_step_counts=forward_step_counts,
+    )
```
There's an issue in the `forward_step` function. The `forward_step_counts` variable is a remnant from a previous implementation and is no longer used, so it and its `nonlocal` declaration can be removed.
More importantly, `output_post_hook` and `return_outputs` are not being passed to `_forward_compute_mb`. This will cause incorrect behavior for `forward_only=True` cases, as `_forward_compute_mb` will not be able to apply the post-processing hook or know that it should return outputs instead of computing a loss:
```python
def forward_step(batch_iter, model):
    batch_ctx = next(batch_iter)
    return self._forward_compute_mb(
        mb_input=batch_ctx,
        loss_fn=loss_fn,
        loss_weight_fn=loss_weight_fn,
        model=model,
        post_hook=output_post_hook,
        return_output=return_outputs,
    )
```
areal/engine/fsdp_engine.py (outdated)
```python
assert total_loss_weight != 0
dist.all_reduce(total_loss_weight, group=self.dp_group)
```
Inviting @ChangyiYang and @zhaochenyang20 for review since you are familiar with the TrainEngine refactoring.

great job!
nuzant left a comment:
This PR still needs some refactoring. There are still some problems to be discussed about the API design, and the code quality has room for improvement.
Additionally, please format the code according to the instructions in CONTRIBUTING.md and double-check the Gemini reviews.
We should also make sure all related tests (`areal/tests/test_fsdp_*.py` and `areal/tests/test_megatron_*.py`) can pass.
areal/api/engine_api.py (outdated)
| """ | ||
| raise NotImplementedError() | ||
|
|
||
| def split_micro_batch( |
Since the way we split the micro batches is not related to engines, we should not expose this API in this class. We can add an API in areal/utils/data.py to assist user-side data handling.
In `areal/utils/data`, I added `create_mb_iterator` to convert an `mb_list` into an iterator that yields micro-batch tuples on each iteration. This allows the engine to select specific micro-batch elements from the iterator and wrap metadata for downstream use.
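For readers following along, here is a minimal sketch of what such a helper might look like; the actual `create_mb_iterator` in `areal/utils/data.py` may differ, and the metadata fields shown are taken from the diff earlier in this thread:

```python
from typing import Any, Iterator

# Hedged sketch, not the actual implementation: yields
# (micro_batch, metadata) tuples so the engine can pick
# the fields it needs for each micro-batch.
def create_mb_iterator(mb_list) -> Iterator[tuple[dict[str, Any], dict[str, Any]]]:
    for i, mb in enumerate(mb_list.mbs):
        metadata = {
            "padding_length": mb_list.padding_lengths[i],
            "old_cu_seqlens": mb_list.old_cu_seqlens_list[i],
        }
        yield mb, metadata
```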
areal/engine/megatron_engine.py (outdated)
```python
    max_seqlen = data_iterator.max_seqlen
    num_microbatches = data_iterator.num_microbatches
else:
    max_seqlen = self.config.mb_spec.max_tokens_per_mb
```
Since `forward_backward_batch` is an API exposed to users, please provide a clear definition of the expected input `data_iterator`.
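One way to make that contract explicit is a structural type; this is illustrative only — the attribute names come from the snippet above, but the `Protocol` itself is an assumption, not code from the PR:

```python
from typing import Any, Iterator, Protocol

class MicroBatchIterator(Protocol):
    """Expected shape of data_iterator when the caller supplies one."""

    max_seqlen: int        # read by the engine; otherwise it falls back
                           # to self.config.mb_spec.max_tokens_per_mb
    num_microbatches: int  # how many micro-batches the iterator yields

    def __iter__(self) -> Iterator[dict[str, Any]]: ...
```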
areal/engine/fsdp_engine.py (outdated)
```python
    loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
    **kwargs,
) -> tuple[torch.Tensor, Callable[[torch.Tensor], tuple[torch.Tensor, dict]]]:
    batch_type = kwargs.get("batch_type")
```
Same problem as the implementation in `MegatronEngine`. `_forward_compute_mb` should be an atomic operation that takes a micro-batch as input and outputs logits or logprobs. There should not be a pile of if-else branches deciding what to output; just process the output in `forward_backward_batch`.
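A sketch of the shape this comment is asking for — hedged, since only `_forward_compute_mb` and `forward_backward_batch` are names from the PR; the rest is illustrative:

```python
import torch
from torch import nn

class Engine:
    model: nn.Module

    def _forward_compute_mb(self, mb_input: dict) -> torch.Tensor:
        # Atomic step: one micro-batch in, logits out.
        # No branching on batch or output type here.
        return self.model(**mb_input)

    def forward_backward_batch(self, data_iterator, post_hook=None):
        outputs = []
        for mb_input in data_iterator:
            output = self._forward_compute_mb(mb_input)
            # All output shaping lives in the caller, per the review.
            if post_hook is not None:
                output = post_hook(output, mb_input)
            outputs.append(output)
        return outputs
```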
Hi @nuzant, thank you so much for your patient review and thoughtful comments! I have modified the code and will finish the conflict handling and testing today.
@nuzant I have fixed the code based on your feedback, resolved the conflicts, and ensured all related tests pass. Please review again at your convenience. To address the branch conflicts, mainly the post-processing logic was updated, and the `post_hook` in `forward_batch` has been removed.
@rchardx Hi, I have incorporated your feedback. I've optimized the return-value retrieval by designing hook methods within the API. Additionally, I introduced a `BaseTrainEngine` to implement the Template Method pattern, which improves the overall usability of the design, and I have adapted the existing implementations accordingly.
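As a rough illustration of the Template Method structure described above — hook names other than `forward_backward_batch`, `BaseTrainEngine`, `FSDPEngine`, and `MegatronEngine` are hypothetical:

```python
from abc import ABC, abstractmethod

class BaseTrainEngine(ABC):
    def forward_backward_batch(self, data_iterator, loss_fn, forward_only=False):
        # Template method: the skeleton is fixed here; engine-specific
        # steps are deferred to hooks overridden by subclasses.
        mbs = self._prepare_mb_list(data_iterator)
        results = self._run_forward_backward(mbs, loss_fn, forward_only)
        return self._aggregate_results(results)

    @abstractmethod
    def _run_forward_backward(self, mbs, loss_fn, forward_only): ...

    def _prepare_mb_list(self, data_iterator):
        return data_iterator  # optional hook

    def _aggregate_results(self, results):
        return results  # optional hook

class FSDPEngine(BaseTrainEngine):
    def _run_forward_backward(self, mbs, loss_fn, forward_only):
        ...  # plain PyTorch loop over micro-batches

class MegatronEngine(BaseTrainEngine):
    def _run_forward_backward(self, mbs, loss_fn, forward_only):
        ...  # delegates to Megatron's pipeline-parallel schedule
```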
rchardx left a comment:
Merge now, improve later. Delaying the merge would create more conflicts and complications. Follow-up PRs will address these quality issues.
Description
This PR refactors the top-level API implementation of TrainEngine and its subclasses (FSDPEngine, MegatronEngine).
Previously, the execution logic was fragmented across train_batch, forward_batch, and eval_batch. This PR introduces a unified forward_backward_batch interface to handle the execution flow. This change significantly reduces code duplication across different engines and provides greater flexibility for custom training loops.
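In outline, the three entry points can then become thin wrappers over the unified interface. This is a hedged sketch; the actual signatures in `areal/api/engine_api.py` may differ:

```python
import torch

class TrainEngine:
    # Sketch only: each entry point funnels into forward_backward_batch,
    # which the concrete engines implement.
    def train_batch(self, data, loss_fn, loss_weight_fn):
        return self.forward_backward_batch(
            data, loss_fn=loss_fn, loss_weight_fn=loss_weight_fn, forward_only=False
        )

    def eval_batch(self, data, loss_fn, loss_weight_fn):
        with torch.no_grad():
            return self.forward_backward_batch(
                data, loss_fn=loss_fn, loss_weight_fn=loss_weight_fn, forward_only=True
            )

    def forward_batch(self, data):
        with torch.no_grad():
            return self.forward_backward_batch(data, forward_only=True)
```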
Key Changes
Related Issue
#601