# Add MLflow support and expose logging configuration in TrainingArgs #680

Base branch: `main`

## Conversation
### 📝 Walkthrough

Adds MLflow integration and related logging options: new `TrainingArgs` fields for run and backend configuration, an `MLflowHandler` and MLflow wiring in the metric logger, and CLI/entrypoint propagation of the new logging flags.
### Sequence diagram

```mermaid
sequenceDiagram
    participant User
    participant CLI as main_ds.py
    participant Setup as setup_metric_logger()
    participant MLflowHandler
    participant MLflow as MLflow Backend
    User->>CLI: invoke training with --mlflow_tracking_uri / --mlflow_experiment_name / --run_name
    CLI->>Setup: call setup_metric_logger(..., mlflow_tracking_uri, mlflow_experiment_name, run_name, ...)
    Setup->>MLflowHandler: instantiate MLflowHandler(run_name, tracking_uri, experiment_name, ...)
    MLflowHandler->>MLflowHandler: _setup() (validate mlflow import, configure)
    MLflowHandler->>MLflow: set_tracking_uri(tracking_uri)
    MLflowHandler->>MLflow: set_experiment(experiment_name)
    MLflowHandler->>MLflow: start_run(run_name)
    Note over Setup,MLflowHandler: Handler registered with logging configuration
    CLI->>MLflowHandler: emit(LogRecord with metrics)
    MLflowHandler->>MLflow: log_metrics(flattened_metrics, step)
    MLflowHandler->>MLflow: end_run() on close()
```
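For orientation, the diagram maps onto a handler roughly like the sketch below. This is an illustrative reconstruction from the call names in the diagram, not the PR's actual code; the dict-shaped log records and the `step` attribute are assumptions.

```python
# Illustrative sketch of the flow in the diagram above; not the PR's
# exact code. Assumes metric LogRecords carry a dict payload and an
# optional `step` attribute.
import logging

import mlflow


class MLflowHandlerSketch(logging.Handler):
    def __init__(self, run_name=None, tracking_uri=None, experiment_name=None):
        super().__init__()
        self.run_name = run_name
        self.tracking_uri = tracking_uri
        self.experiment_name = experiment_name
        self._setup()

    def _setup(self):
        # Configure the backend, then open the run (mirrors the diagram).
        if self.tracking_uri:
            mlflow.set_tracking_uri(self.tracking_uri)
        if self.experiment_name:
            mlflow.set_experiment(self.experiment_name)
        self._run = mlflow.start_run(run_name=self.run_name)

    def emit(self, record: logging.LogRecord) -> None:
        # Forward numeric metrics; `step` is an assumed record attribute.
        if isinstance(record.msg, dict):
            metrics = {
                k: float(v)
                for k, v in record.msg.items()
                if isinstance(v, (int, float))
            }
            if metrics:
                mlflow.log_metrics(metrics, step=getattr(record, "step", None))

    def close(self) -> None:
        mlflow.end_run()
        super().close()
```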
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 3 passed
Exposes logging configuration (tensorboard, wandb, mlflow, jsonl) through
flat kwargs in sft(), osft(), and lora_sft() convenience functions.
## New Parameters
- `loggers`: List of loggers to enable (e.g., ["wandb", "mlflow", "jsonl"])
- `run_name`: Run name with placeholder support ({time}, {rank})
- `log_level`: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
- `logging_steps`: How often to log metrics
- `wandb_project`, `wandb_entity`, `wandb_run_name`: W&B configuration
- `tensorboard_log_dir`: TensorBoard output directory
- `mlflow_tracking_uri`, `mlflow_experiment_name`: MLflow configuration
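For illustration, a call using these flat kwargs might look like the sketch below. The logging kwargs mirror the list above; the module path and the non-logging arguments (`model_path`, `data_path`, `ckpt_output_dir`) are placeholders rather than verified signatures.

```python
# Hypothetical call shape; only the logging kwargs come from the list
# above. Module path and model/data arguments are assumptions.
from training_hub import sft  # assumed module path

sft(
    model_path="my-base-model",           # placeholder
    data_path="data/train.jsonl",         # placeholder
    ckpt_output_dir="checkpoints/",       # placeholder
    loggers=["wandb", "mlflow", "jsonl"],
    run_name="sft-{time}-rank{rank}",     # {time}/{rank} substituted at setup
    log_level="INFO",
    logging_steps=10,
    wandb_project="my-project",
    wandb_entity="my-team",
    mlflow_tracking_uri="http://localhost:5000",
    mlflow_experiment_name="sft-experiments",
)
```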
## Backend Support
| Logger | SFT | OSFT | LoRA |
|-------------|-----|------|------|
| wandb | Yes | Yes | Yes |
| tensorboard | Yes | No | Yes |
| mlflow | Yes | No | Yes |
| jsonl | Yes | Yes | No |
OSFT emits warnings for unsupported loggers/params and continues.
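A minimal sketch of that warn-and-continue behavior, assuming a simple filter over the requested loggers (the unsupported set is taken from the table above; the function name is hypothetical):

```python
# Sketch only: assumes OSFT filters unsupported loggers per the table above.
import warnings

OSFT_UNSUPPORTED = {"tensorboard", "mlflow"}


def filter_osft_loggers(requested: list[str]) -> list[str]:
    kept = []
    for name in requested:
        if name in OSFT_UNSUPPORTED:
            warnings.warn(f"OSFT does not support the '{name}' logger; skipping it.")
        else:
            kept.append(name)
    return kept
```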
Depends on: instructlab/training#680
Co-Authored-By: Claude Opus 4.5 <[email protected]>
## Review comments
Actionable comments posted: 1
🤖 Fix all issues with AI agents

In `src/instructlab/training/main_ds.py`:
- Around line 275-283: The call to setup_metric_logger in main() uses
unnecessary defensive getattr() for mlflow/wandb fields; replace getattr(args,
"mlflow_tracking_uri", None), getattr(args, "mlflow_experiment_name", None),
getattr(args, "wandb_project", None), and getattr(args, "wandb_entity", None)
with direct attribute access args.mlflow_tracking_uri,
args.mlflow_experiment_name, args.wandb_project, and args.wandb_entity
respectively so it matches the pattern used in run_training() and with
train_args.
🧹 Nitpick comments (2)

`src/instructlab/training/logger.py` (2):
**Lines 638-665: Unused `log_dir` parameter.**

The `log_dir` parameter is stored as `self.log_dir` but never used in `_setup()` or elsewhere. The docstring says it's "used as artifact location," but the implementation never passes it to MLflow. Either use it to set the artifact location or remove it to avoid confusion.

♻️ Option 1: Use `log_dir` as the artifact location. (Note: `mlflow.set_experiment()` does not accept an artifact location; one can only be set when the experiment is created, so the diff below uses `mlflow.create_experiment()` for that case.)

```diff
 def _setup(self):
     """Initialize the MLflow run with the configured settings."""
     if mlflow is None:
         msg = (
             "Could not initialize MLflowHandler because package mlflow could not be imported.\n"
             "Please ensure it is installed by running 'pip install mlflow'"
         )
         raise RuntimeError(msg)
     if self.tracking_uri:
         mlflow.set_tracking_uri(self.tracking_uri)
     if self.experiment_name:
-        mlflow.set_experiment(self.experiment_name)
+        # Create the experiment with the artifact location on first use;
+        # set_experiment() alone cannot set one.
+        if mlflow.get_experiment_by_name(self.experiment_name) is None:
+            mlflow.create_experiment(
+                self.experiment_name, artifact_location=str(self.log_dir)
+            )
+        mlflow.set_experiment(self.experiment_name)
     self._mlflow_run = mlflow.start_run(
         run_name=self.run_name, **self.mlflow_init_kwargs
     )
```

♻️ Option 2: Remove the unused parameter.

```diff
 def __init__(
     self,
     level: int = logging.INFO,
     run_name: str | None = None,
-    log_dir: str | os.PathLike = "logs",
     tracking_uri: str | None = None,
     experiment_name: str | None = None,
     **mlflow_init_kwargs: Any,
 ):
     """Initialize the MLflow logger and check for required dependencies.

     Args:
         level: The logging level for this handler
         run_name: Name of the run, can contain placeholders
-        log_dir: Directory where MLflow artifacts should be stored (used as artifact location)
         tracking_uri: MLflow tracking server URI (e.g., "http://localhost:5000")
         experiment_name: Name of the MLflow experiment
         **mlflow_init_kwargs: Additional keyword arguments passed to mlflow.start_run()
     """
     super().__init__(level)
     self.run_name = _substitute_placeholders(run_name)
-    self.log_dir = Path(log_dir)
     self.tracking_uri = tracking_uri
     self.experiment_name = experiment_name
     self.mlflow_init_kwargs = mlflow_init_kwargs.copy()
     self._mlflow_run = None
```

Note: If removing `log_dir`, also update `setup_metric_logger` so it no longer passes `log_dir` to the MLflow handler config.
**Lines 711-721: Consider adding a debug log for skipped non-numeric metrics.**

Non-numeric values are silently skipped. For consistency with `TensorBoardHandler` (which warns on type errors), consider emitting a message so users understand why certain values aren't appearing in MLflow metrics.

♻️ Proposed change (requires `warnings` to be imported in `logger.py`):

```diff
 # Filter to only numeric values for metrics
 metrics_dict = {}
 for k, v in flat_dict.items():
     try:
         metrics_dict[k] = float(v)
     except (ValueError, TypeError):
-        # Skip non-numeric values for metrics
-        pass
+        # Surface skipped non-numeric values instead of dropping silently
+        warnings.warn(
+            f"MLflowHandler skipping non-numeric metric '{k}' of type {type(v).__name__}",
+            stacklevel=2,
+        )
```
…logging

- Add `tensorboard_log_dir` field to `TrainingArgs` in config.py
- Update `setup_metric_logger` to use `tensorboard_log_dir` when provided
- Add a CLI argument for `tensorboard_log_dir`
- Wire `tensorboard_log_dir` through `run_training()` to the subprocess command

This allows users to specify a custom directory for TensorBoard logs, defaulting to `output_dir` if not specified.

Co-Authored-By: Claude Opus 4.5 <[email protected]>

- Replace defensive `getattr()` with direct attribute access in main_ds.py, since the args are guaranteed to exist from argparse defaults
- Remove the unused `log_dir` parameter from `MLflowHandler`
- Add debug logging for non-numeric metrics skipped by `MLflowHandler`

Co-Authored-By: Claude Opus 4.5 <[email protected]>
## Summary

- Expose logging configuration in `TrainingArgs` for programmatic API usage
- Add `wandb_project` and `wandb_entity` fields to `TrainingArgs` for consistency

## Changes
### New `TrainingArgs` fields

| Field | Type | Default | Notes |
|---|---|---|---|
| `logger_type` | `str` | `"async"` | One of `tensorboard`, `wandb`, `mlflow`, `async` |
| `run_name` | `str \| None` | `None` | Supports placeholders (`{time}`, `{rank}`, etc.) |
| `mlflow_tracking_uri` | `str \| None` | `None` | |
| `mlflow_experiment_name` | `str \| None` | `None` | |
| `wandb_project` | `str \| None` | `None` | |
| `wandb_entity` | `str \| None` | `None` | |

### New `MLflowHandler` class

Implements the same interface as `TensorBoardHandler` and `WandbHandler`:

- Logs metrics via `mlflow.log_metrics()`
- Logs parameters via `mlflow.log_params()`
- Supports `tracking_uri` and `experiment_name` configuration
- Falls back to the `MLFLOW_TRACKING_URI` and `MLFLOW_EXPERIMENT_NAME` env vars
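One way the env-var fallback could be resolved is sketched below; this is an assumption about the PR's logic, and MLflow itself also honors these variables when no explicit URI or experiment is set.

```python
# Sketch of the described env-var fallback; the PR's actual resolution
# logic may differ. Function name is hypothetical.
import os


def resolve_mlflow_settings(
    tracking_uri: str | None, experiment_name: str | None
) -> tuple[str | None, str | None]:
    # Explicit arguments win; otherwise fall back to MLflow's standard
    # environment variables. Returning None leaves MLflow's defaults alone.
    return (
        tracking_uri or os.environ.get("MLFLOW_TRACKING_URI"),
        experiment_name or os.environ.get("MLFLOW_EXPERIMENT_NAME"),
    )
```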
run_training()APIPreviously,
run_training()hardcoded the logger to"async". Now it reads fromTrainingArgs:Example Usage
## Test plan

- Verify the new `wandb_project`/`wandb_entity` fields
- Verify the `async` `logger_type` enables multiple backends simultaneously

🤖 Generated with Claude Code