
Conversation

@aaaandychen
Contributor

Description

This PR refactors the top-level API implementation of TrainEngine and its subclasses (FSDPEngine, MegatronEngine).

Previously, the execution logic was fragmented across train_batch, forward_batch, and eval_batch. This PR introduces a unified forward_backward_batch interface to handle the execution flow. This change significantly reduces code duplication across different engines and provides greater flexibility for custom training loops.

Key Changes

  • API Unification: Deprecated the separate train_batch, forward_batch, and eval_batch implementations in favor of a single forward_backward_batch entry point.
  • Engine Update: Updated FSDPEngine and MegatronEngine to adapt to the new unified API.
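The unification can be sketched as follows. This is a minimal, hypothetical illustration of the idea, not the PR's exact API; the class name, signatures, and the forward_only flag are assumptions.

```python
# Hypothetical sketch: the old entry points become thin wrappers over a
# single forward_backward_batch code path.
class TrainEngineSketch:
    def __init__(self, model):
        self.model = model

    def forward_backward_batch(self, mbs, loss_fn, forward_only=False):
        # One code path for all modes: iterate micro-batches, run the
        # model, and compute losses only when a loss is requested.
        outputs, losses = [], []
        for mb in mbs:
            out = self.model(mb)
            outputs.append(out)
            if not forward_only:
                losses.append(loss_fn(out, mb))
        return outputs if forward_only else losses

    def train_batch(self, mbs, loss_fn):
        return self.forward_backward_batch(mbs, loss_fn)

    def eval_batch(self, mbs, loss_fn):
        return self.forward_backward_batch(mbs, loss_fn)

    def forward_batch(self, mbs):
        return self.forward_backward_batch(mbs, None, forward_only=True)
```

With this shape, engine subclasses only need to override the shared execution path rather than three near-duplicate methods.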

Related Issue

Related issue: #601

Type of Change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change that adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not
    work as expected)
  • Documentation update
  • Code refactoring (no functional changes)
  • Performance improvement
  • Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • I have run formatting tools (pre-commit or manual)
  • I have run relevant unit tests and they pass
  • I have added tests for new functionality
  • I have updated documentation if needed
  • My branch is up to date with main
  • This PR introduces breaking changes (if yes, fill out details below)
  • If this PR changes documentation, I have built and previewed it locally with
    jb build docs
  • No critical issues raised by AI reviewers (/gemini review)

Breaking Change Details (if applicable):

Additional Context


Need help? Check the Contributing Guide or ask in
GitHub Discussions!

Signed-off-by: chenzhenyang <[email protected]>
@gemini-code-assist
Contributor

Summary of Changes

Hello @aaaandychen, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the core API of the TrainEngine and its specialized distributed training implementations, FSDPEngine and MegatronEngine. The primary goal is to consolidate the previously disparate methods for training, evaluation, and forward passes into a single, comprehensive forward_backward_batch interface. This change not only streamlines the internal architecture by reducing redundant code but also empowers users with more adaptable tools for designing and executing complex training strategies.

Highlights

  • API Unification: The pull request introduces a unified forward_backward_batch interface in TrainEngine and its subclasses, replacing fragmented train_batch, forward_batch, and eval_batch implementations.
  • Code Duplication Reduction: This refactoring significantly reduces code duplication across different engine implementations by centralizing the execution logic.
  • Enhanced Flexibility: The new unified API provides greater flexibility for implementing custom training loops and workflows.
  • Engine Adaption: FSDPEngine and MegatronEngine have been updated to adapt to and utilize the new unified forward_backward_batch API.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request is a great step towards unifying the training engine's API by introducing a forward_backward_batch interface. The refactoring in areal/api/engine_api.py and areal/engine/megatron_engine.py is well-executed and significantly improves code structure and reduces duplication. However, the implementation for FSDPEngine in areal/engine/fsdp_engine.py doesn't fully adhere to the new API contract, which undermines the goal of a unified interface. I've left detailed comments on this, along with a few other potential bugs and areas for improvement.

Comment on lines 1096 to 1107
forward_step_counts = [0] * len(self.model)

def forward_step(batch_iter, model):
    nonlocal forward_step_counts
    batch = next(batch_iter)
    model_vp_stage = getattr(model, "vp_stage", 0)
    forward_step_count = forward_step_counts[model_vp_stage]
    padding_length = mb_list.padding_lengths[forward_step_count]
    orig_input = mb_list.mbs[forward_step_count]
    cu_seqlens = batch["cu_seqlens"]
    old_cu_seqlens = mb_list.old_cu_seqlens_list[forward_step_count]

    forward_step_counts[model_vp_stage] += 1
    output = packed_context_parallel_forward(model, batch)

    if mpu.is_pipeline_last_stage(
        ignore_virtual=False, vp_stage=model_vp_stage
    ):
        output = unpad_logits(
            output,
            padding_length=padding_length,
            cu_seqlens=cu_seqlens,
            old_cu_seqlens=old_cu_seqlens,
        )

    def _post_process_fn(input_, output):
        loss = torch.tensor(1.0, device=output.device)
        if post_hook is not None:
            output = post_hook(output, input_)
        return loss, {"output": output}

    return output, functools.partial(_post_process_fn, orig_input)

    batch_ctx = next(batch_iter)

    return self._forward_compute_mb(
        mb_input=batch_ctx,
        loss_fn=loss_fn,
        loss_weight_fn=loss_weight_fn,
        model=model,
        forward_step_counts=forward_step_counts
    )
Contributor


Severity: high

There's an issue in the forward_step function. The forward_step_counts variable is a remnant from a previous implementation and is no longer used, so it and its nonlocal declaration can be removed.

More importantly, output_post_hook and return_outputs are not being passed to _forward_compute_mb. This will cause incorrect behavior for forward_only=True cases, as _forward_compute_mb will not be able to apply the post-processing hook or know that it should return outputs instead of computing a loss.

        def forward_step(batch_iter, model):
            batch_ctx = next(batch_iter)

            return self._forward_compute_mb(
                mb_input=batch_ctx,
                loss_fn=loss_fn,
                loss_weight_fn=loss_weight_fn,
                model=model,
                post_hook=output_post_hook,
                return_output=return_outputs,
            )
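To make the reviewer's point concrete, here is a hedged sketch of how a helper in the spirit of _forward_compute_mb could honor post_hook and return_output in the forward-only case. The function name, signature, and return convention are assumptions for illustration, not AReaL's actual API.

```python
# Illustrative sketch only: not AReaL's actual _forward_compute_mb.
def forward_compute_mb(mb_input, model, loss_fn=None,
                       post_hook=None, return_output=False):
    output = model(mb_input)
    if return_output:
        # Forward-only path: apply the post-processing hook and skip the
        # real loss computation entirely.
        if post_hook is not None:
            output = post_hook(output, mb_input)
        return None, {"output": output}
    # Training/eval path: compute the actual loss from the raw output.
    return loss_fn(output, mb_input), {}
```

Without the two extra arguments, the forward-only branch can neither transform the output nor signal that outputs (rather than a loss) should be returned, which is exactly the bug described above.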

Comment on lines 729 to 730
assert total_loss_weight != 0
dist.all_reduce(total_loss_weight, group=self.dp_group)
Contributor


Severity: medium

There seems to be a redundant dist.all_reduce call here. total_loss_weight is already reduced on line 725. This second call is unnecessary and can be removed.
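A plain-Python toy (no torch.distributed; all_reduce_sum here is a stand-in for a SUM all-reduce) shows why the second call is harmful rather than merely wasteful: after the first reduce every rank already holds the global total, so reducing again inflates it by the world size.

```python
def all_reduce_sum(per_rank_values):
    # Stand-in for dist.all_reduce with op=SUM: every rank ends up
    # holding the global sum of all ranks' local values.
    total = sum(per_rank_values)
    return [total] * len(per_rank_values)

weights = [1.0, 2.0, 3.0, 4.0]     # local loss weights on 4 ranks
once = all_reduce_sum(weights)      # every rank holds 10.0 (correct)
twice = all_reduce_sum(once)        # every rank holds 40.0 (inflated)
```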

@aaaandychen
Contributor Author

Inviting @ChangyiYang and @zhaochenyang20 for review since you are familiar with the TrainEngine refactoring.

@zhaochenyang20
Contributor

great job!

Collaborator

@nuzant nuzant left a comment


This PR still needs some refactoring. There are still some problems to be discussed about the API designs, and the code quality has room for improvement.

Additionally, please format the code according to the instructions in CONTRIBUTING.md and double-check the gemini reviews.

We should also make sure all related tests (`areal/tests/test_fsdp_*.py` and `areal/tests/test_megatron_*.py`) can pass.

"""
raise NotImplementedError()

def split_micro_batch(
Collaborator


Since the way we split the micro batches is not related to engines, we should not expose this API in this class. We can add an API in areal/utils/data.py to assist user-side data handling.
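An engine-independent helper along these lines might look like the following sketch. The function name and the greedy packing policy are assumptions for illustration, not the actual areal/utils/data.py implementation.

```python
# Hedged sketch: greedily pack sequence indices into micro-batches whose
# total token count stays within max_tokens_per_mb.
def split_micro_batches(seq_lens, max_tokens_per_mb):
    mbs, cur, cur_tokens = [], [], 0
    for idx, n in enumerate(seq_lens):
        # Start a new micro-batch when adding this sequence would
        # overflow the current one (an oversized sequence still gets a
        # micro-batch of its own).
        if cur and cur_tokens + n > max_tokens_per_mb:
            mbs.append(cur)
            cur, cur_tokens = [], 0
        cur.append(idx)
        cur_tokens += n
    if cur:
        mbs.append(cur)
    return mbs
```

Because nothing here touches an engine, it can live in a data utility module and be reused by both FSDP and Megatron paths.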

Contributor Author


In areal/utils/data, I added create_mb_iterator to convert an mb_list into an iterator that yields micro-batch tuples on each iteration. This lets the engine select specific micro-batch elements for the iterator and wrap metadata for downstream use.
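A hedged sketch of such a helper follows; the exact attributes carried in the metadata dict are assumptions, not the actual create_mb_iterator implementation.

```python
# Illustrative sketch: yield (micro-batch, metadata) tuples so engines
# can pick the elements they need and thread metadata (padding, cu_seqlens)
# through to post-processing.
def create_mb_iterator(mbs, padding_lengths, old_cu_seqlens_list):
    for mb, pad, old_cu in zip(mbs, padding_lengths, old_cu_seqlens_list):
        yield mb, {"padding_length": pad, "old_cu_seqlens": old_cu}
```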

        max_seqlen = data_iterator.max_seqlen
        num_microbatches = data_iterator.num_microbatches
    else:
        max_seqlen = self.config.mb_spec.max_tokens_per_mb
Collaborator


Since forward_backward_batch is an API exposed to users, please provide a clear definition of what is the expected input data_iterator.

        loss_weight_fn: Callable[[dict[str, Any]], torch.Tensor],
        **kwargs,
    ) -> tuple[torch.Tensor, Callable[[torch.Tensor], tuple[torch.Tensor, dict]]]:
        batch_type = kwargs.get("batch_type")
Collaborator


Same problem as the implementation in MegatronEngine. _forward_compute_mb should be an atomic operation that takes a micro-batch as input and outputs logits or log-probabilities. There should not be a bunch of if-else branches deciding what to output; just process the output in forward_backward_batch.
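The design principle can be sketched as follows (illustrative names, not AReaL's API): the per-micro-batch step stays atomic, and all result shaping lives in the caller.

```python
def forward_compute_mb(model, mb):
    # Atomic: one micro-batch in, raw model output out. No branching on
    # a "batch_type" flag at this level.
    return model(mb)

def forward_backward_batch(model, mbs, post_process=None):
    # The caller decides how to shape each result; the atomic step does
    # not need to know whether it is producing logits, logprobs, or losses.
    outputs = []
    for mb in mbs:
        out = forward_compute_mb(model, mb)
        outputs.append(post_process(out) if post_process else out)
    return outputs
```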

@aaaandychen aaaandychen force-pushed the refactor-trainengine-api branch from a197ede to b5eb5b7 Compare December 4, 2025 06:54
@aaaandychen
Contributor Author

aaaandychen commented Dec 4, 2025

Hi @nuzant, thank you so much for your patient review and thoughtful comments! I have modified the code and will finish resolving the conflicts and testing today.

Signed-off-by: chenzhenyang <[email protected]>

# Conflicts:
#	areal/engine/fsdp_engine.py
#	areal/engine/megatron_engine.py
@aaaandychen aaaandychen force-pushed the refactor-trainengine-api branch 2 times, most recently from 5f8224f to 8a8b641 Compare December 5, 2025 12:44
@aaaandychen aaaandychen force-pushed the refactor-trainengine-api branch from 8a8b641 to 53dcd95 Compare December 5, 2025 16:01
@aaaandychen
Contributor Author

aaaandychen commented Dec 5, 2025

@nuzant I have fixed the code based on your feedback, resolved the conflicts, and ensured all related tests pass. Please review again at your convenience. To address the branch conflicts, the post-processing logic has been updated and the post_hook in forward_batch has been removed.

@rchardx rchardx changed the title refactor: refactor train engine hign level APIs refactor: refactor train engine high level APIs Dec 8, 2025
@aaaandychen aaaandychen force-pushed the refactor-trainengine-api branch from b7d8a73 to 42c4f4e Compare December 10, 2025 10:35
@aaaandychen aaaandychen force-pushed the refactor-trainengine-api branch from 5751e6e to 77a1aba Compare December 10, 2025 14:21
@aaaandychen
Contributor Author

@rchardx Hi, I have incorporated your feedback. I've optimized return-value retrieval by designing hook methods within the API. Additionally, I introduced a BaseTrainEngine to implement the Template Method pattern, which improves the overall usability of the design. I have also adapted the existing implementations accordingly.
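The Template Method idea mentioned here can be sketched as follows; the class and hook names are assumptions, not the PR's actual BaseTrainEngine.

```python
# Hedged sketch: the base class fixes the forward_backward_batch skeleton,
# while subclasses override small engine-specific hooks.
class BaseTrainEngine:
    def forward_backward_batch(self, mbs, loss_fn):
        results = []
        for mb in mbs:
            out = self._forward_mb(mb)      # hook: engine-specific forward
            loss = loss_fn(out, mb)
            self._backward_mb(loss)         # hook: engine-specific backward
            results.append(loss)
        return results

    def _forward_mb(self, mb):
        raise NotImplementedError

    def _backward_mb(self, loss):
        raise NotImplementedError


class ToyEngine(BaseTrainEngine):
    def _forward_mb(self, mb):
        return mb * 2

    def _backward_mb(self, loss):
        pass  # no-op in this sketch
```

The skeleton stays in one place, so an FSDP- or Megatron-style engine only supplies its forward/backward hooks.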

Collaborator

@rchardx rchardx left a comment


Merge now, improve later. Delaying the merge would create more conflicts and complications. Follow-up PRs will address these quality issues.

@rchardx rchardx merged commit e6ab0e8 into inclusionAI:main Dec 11, 2025
1 check passed