diff --git a/fast_llm/config.py b/fast_llm/config.py index 96beadcf8..815b6b00a 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -5,11 +5,11 @@ import traceback import types import typing +import warnings import yaml -from fast_llm.engine.config_utils.logging import log -from fast_llm.utils import Assert, Tag, get_type_name, header +from fast_llm.utils import Assert, Tag, get_type_name, header, log logger = logging.getLogger(__name__) @@ -270,7 +270,7 @@ class Config: __class_validated__: typing.ClassVar[bool] = True _abstract: typing.ClassVar[bool] = False _validated: bool = Field(init=False, repr=False) - _unknown_fields: tuple = Field(init=False, repr=False) + _unknown_fields: dict[str] = Field(init=False, repr=False) def __post_init__(self): """ @@ -335,7 +335,7 @@ def _validate(self): value = getattr(self, name) new_value = self._validate_nested(value, field.type, field.name, field.valid, errors, False) setattr(self, name, new_value) - for name in getattr(self, "_unknown_fields", ()): + for name in getattr(self, "_unknown_fields", {}): errors.append(f"Unknown field `{name}` in class {self._get_class_name()}") if errors: # TODO: Option to show traceback for errors. 
@@ -621,7 +621,7 @@ def _get_class_name(cls): @classmethod def from_dict( cls, - default: typing.Union["Config", dict], + default: typing.Union["Config", dict[str]], *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], strict: bool = True, ): @@ -646,7 +646,7 @@ def from_dict( @classmethod def from_flat_dict( cls, - default: dict, + default: dict[str], strict: bool = True, ): # TODO v0.2: Remove flat format @@ -655,7 +655,7 @@ def from_flat_dict( @classmethod def _from_dict( cls, - default: dict, + default: dict[str], strict: bool = True, flat: bool = False, ): @@ -691,9 +691,9 @@ def _from_dict( f"Invalid field type `{get_type_name(field.type)}` in class {cls._get_class_name()}: " + ", ".join(e.args) ) - out = cls(**out_arg_dict) # noqa - if strict and default: - out._unknown_fields = tuple(default) + out = cls(**out_arg_dict) # noqa + if strict and default: + out._unknown_fields = default.copy() if _AUTO_VALIDATE: out.validate() return out @@ -767,6 +767,12 @@ def _from_dict_dict(cls, value, type_, strict: bool): # Keys can't include configs so we only recurse on values. return {key: cls._from_dict_nested(value_, args[1], strict) for key, value_ in value.items()} + @classmethod + def _handle_renamed_field(cls, default: dict[str], old_name: str, new_name: str): + if old_name in default: + warnings.warn(f"Field `{old_name}` is deprecated in class {get_type_name(cls)}, use `{new_name}` instead.") + default[new_name] = default.pop(old_name) + def compare(self, other: "Config", log_fn: typing.Union[BaseException, typing.Callable] = ValueError): # TODO: Check classes? 
self_dict = self._to_dict(format_=_ConfigDictFormat.tuple, serializable=True) diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 8c16c3146..f105c5054 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -11,14 +11,6 @@ from fast_llm.engine.distributed.distributed import Distributed -class DatasetType(str, enum.Enum): - """ - Placeholder for future generalization to other data types. - """ - - gpt = "gpt" - - class DatasetSource(str, enum.Enum): """ An enum for the different ways to load datasets. @@ -61,41 +53,41 @@ class FimConfig(Config): Configuration for FIM. """ - fim_rate: float = Field( + rate: float = Field( default=0.0, desc="FIM rate for each sample.", hint=FieldHint.core, valid=check_field(Assert.in_range_incl, 0, 1), ) - fim_max_middle_len: int | None = Field( + max_middle_len: int | None = Field( default=None, desc="Maximum length of the middle segment in FIM.", hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) - fim_split_sample: str | None = Field( + split_sample: str | None = Field( default=None, desc="Split samples on this token and permute each fragment separately.", hint=FieldHint.feature, ) - fim_fragment_rate: float = Field( + fragment_rate: float = Field( default=0.0, desc="FIM rate for each fragment when using fim_split_sample.", hint=FieldHint.feature, valid=check_field(Assert.in_range_incl, 0, 1), ) - fim_ignore_prefix: str | None = Field( + ignore_prefix: str | None = Field( default=None, desc="Do not apply FIM to fragments that start with this prefix.", hint=FieldHint.feature, ) - fim_spm_rate: float = Field( + spm_rate: float = Field( default=0.5, desc="TODO.", hint=FieldHint.feature, valid=check_field(Assert.in_range_incl, 0, 1), ) - fim_truncate_or_pad: bool = Field( + truncate_or_pad: bool = Field( default=False, desc="TODO.", hint=FieldHint.feature, @@ -103,7 +95,7 @@ class FimConfig(Config): def _validate(self): super()._validate() - Assert.in_range_incl(self.fim_rate, 
0, 1) + Assert.in_range_incl(self.rate, 0, 1) EOD = "<|endoftext|>" @@ -117,13 +109,13 @@ class TokenizerConfig(Config): Currently, the tokenizer is only needed for FIM. """ - tokenizer_type: str = Field( + format: str = Field( default="TokenizerFromFile", desc="Unused.", hint=FieldHint.deprecated, valid=check_field(Assert.eq, TokenizerFromFile), ) - tokenizer_file: str | None = Field( + path: str | None = Field( default=None, desc="Path to the tokenizer file.", hint=FieldHint.core, @@ -181,17 +173,12 @@ class DataConfig(AbstractDataConfig): hint=FieldHint.core, valid=_validate_split, ) - dataset_type: DatasetType = Field( - default=DatasetType.gpt, - desc="Unused.", - hint=FieldHint.wip, - ) - dataset_source: DatasetSource = Field( + format: DatasetSource = Field( default=DatasetSource.list, desc="Format for the dataset definition.", hint=FieldHint.core, ) - data_path: list[str] = Field( + path: list[str] = Field( default_factory=list, desc="Path or list of paths and weights.", hint=FieldHint.core, diff --git a/fast_llm/data/data.py b/fast_llm/data/data.py index 62367507a..e58b62c4a 100644 --- a/fast_llm/data/data.py +++ b/fast_llm/data/data.py @@ -9,7 +9,7 @@ import torch import torch.utils.data -from fast_llm.data.config import AbstractData, DataConfig, DatasetSource, DatasetType +from fast_llm.data.config import AbstractData, DataConfig, DatasetSource from fast_llm.data.dataset import BlendedDataset, SampledDataset, Sampler from fast_llm.data.gpt import DummyGPTDataset, GPTDataset, GPTSampledDataset from fast_llm.data.mmap import MMapIndexedDataset @@ -67,35 +67,34 @@ def __init__( } data_base_path = None - if self._config.dataset_source == DatasetSource.file: - Assert.eq(len(self._config.data_path), 1) - data_path = pathlib.Path(self._config.data_path[0]) + if self._config.format == DatasetSource.file: + Assert.eq(len(self._config.path), 1) + data_path = pathlib.Path(self._config.path[0]) dataset_defs = json.load(data_path.open("r")) data_base_path = 
data_path.parent dataset_prefixes = [dataset_def["prefix"] for dataset_def in dataset_defs["datasets"]] dataset_weights = normalize_probs([dataset_def["weight"] for dataset_def in dataset_defs["datasets"]]) self._build_and_sample_dataset = self._build_and_sample_gpt_dataset - elif self._config.dataset_source == DatasetSource.list: - Assert.geq(len(self._config.data_path), 1) - if len(self._config.data_path) == 1: - dataset_prefixes, dataset_weights = [self._config.data_path[0].strip()], [1.0] + elif self._config.format == DatasetSource.list: + Assert.geq(len(self._config.path), 1) + if len(self._config.path) == 1: + dataset_prefixes, dataset_weights = [self._config.path[0].strip()], [1.0] else: - Assert.eq(self._config.dataset_type, DatasetType.gpt) - Assert.custom(lambda x: x % 2 == 0, len(self._config.data_path)) - dataset_prefixes = [x.strip() for x in self._config.data_path[1::2]] + Assert.custom(lambda x: x % 2 == 0, len(self._config.path)) + dataset_prefixes = [x.strip() for x in self._config.path[1::2]] assert len(dataset_prefixes) == len(set(dataset_prefixes)) - dataset_weights = normalize_probs([float(x) for x in self._config.data_path[::2]]) + dataset_weights = normalize_probs([float(x) for x in self._config.path[::2]]) self._build_and_sample_dataset = self._build_and_sample_gpt_dataset - elif self._config.dataset_source == DatasetSource.sample: - Assert.eq(len(self._config.data_path), 1) - dataset_prefixes, dataset_weights = [self._config.data_path[0].strip()], [1.0] + elif self._config.format == DatasetSource.sample: + Assert.eq(len(self._config.path), 1) + dataset_prefixes, dataset_weights = [self._config.path[0].strip()], [1.0] self._build_and_sample_dataset = self._build_and_sample_dummy_dataset - elif self._config.dataset_source == DatasetSource.random: - Assert.eq(len(self._config.data_path), 0) + elif self._config.format == DatasetSource.random: + Assert.eq(len(self._config.path), 0) dataset_prefixes, dataset_weights = [None], [1.0] 
self._build_and_sample_dataset = self._build_and_sample_dummy_dataset else: - raise NotImplementedError(self._config.dataset_source) + raise NotImplementedError(self._config.format) dataset_names = [ f"dataset_{i}_{'dummy' if prefix is None else prefix.replace('/','__')}" @@ -124,7 +123,7 @@ def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int run = get_run() Assert.leq(set(samples_per_phase), set(self._phase_split)) log_main_rank(f"Preparing {self._num_datasets} datasets. This may take several minutes.") - self._tokenizer = Tokenizer(self._config.tokenizer) if self._config.fim.fim_rate > 0 else None + self._tokenizer = Tokenizer(self._config.tokenizer) if self._config.fim.rate > 0 else None self._distributed = distributed self._cache_dir = run.dataset_cache_dir self._samples_per_phase = samples_per_phase diff --git a/fast_llm/data/fim.py b/fast_llm/data/fim.py index 84bf5ceee..1dec0c80f 100644 --- a/fast_llm/data/fim.py +++ b/fast_llm/data/fim.py @@ -17,7 +17,7 @@ def __init__(self, config: FimConfig, tokenizer: Tokenizer): self._tokenizer.vocab[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD] ) self.fim_split_sample = ( - self._tokenizer.vocab[self._config.fim_split_sample] if self._config.fim_split_sample is not None else None + self._tokenizer.vocab[self._config.split_sample] if self._config.split_sample is not None else None ) def __call__(self, sample, np_rng): @@ -61,15 +61,15 @@ def _fim_split_and_permute_sequence(self, sequence, np_rng): fragment_fim_rate: if set, apply fim with this rate to each fragment. """ if self.fim_split_sample is None: - return self._fim_permute_sequence(sequence, np_rng, self._config.fim_rate) + return self._fim_permute_sequence(sequence, np_rng, self._config.rate) # fim_split_sample is set: split the sample on this token and permute each fragment separately. # Typically, if each sample is a repository, then we split again on the file level. 
# Each fragment is a file, and we permute the files. fragment_breaks = np.argwhere(sequence == self.fim_split_sample) if fragment_breaks.shape == (0, 1): # no split token in this sample - return self._fim_permute_sequence(sequence, np_rng, self._config.fim_rate) - if not np_rng.binomial(1, self._config.fim_rate): + return self._fim_permute_sequence(sequence, np_rng, self._config.rate) + if not np_rng.binomial(1, self._config.rate): # don't do FIM preproc return sequence # Do FIM on each fragment @@ -79,12 +79,12 @@ def _fim_split_and_permute_sequence(self, sequence, np_rng): for loc in np.nditer(fragment_breaks): if loc - curr_start_position > 0: permuted = self._fim_permute_sequence( - sequence[curr_start_position:loc], np_rng, self._config.fim_fragment_rate + sequence[curr_start_position:loc], np_rng, self._config.fragment_rate ) new_samples += [permuted, [self.fim_split_sample]] curr_start_position = loc + 1 # Jump over the split token # Permute the segment after the last split token - permuted = self._fim_permute_sequence(sequence[curr_start_position:], np_rng, self._config.fim_fragment_rate) + permuted = self._fim_permute_sequence(sequence[curr_start_position:], np_rng, self._config.fragment_rate) new_samples.append(permuted) return np.concatenate(new_samples) @@ -106,10 +106,10 @@ def _fim_permute_sequence( contents = self._tokenizer.detokenize(sequence) # Do not apply FIM if the sample starts with no_fim_prefix - if self._config.fim_ignore_prefix is not None and contents.startswith(self._config.fim_ignore_prefix): + if self._config.ignore_prefix is not None and contents.startswith(self._config.ignore_prefix): return sequence - if self._config.fim_max_middle_len is None: + if self._config.max_middle_len is None: # Sample the two boundaries uniformly at random # A boundary can be =0 (prefix will be empty) # a boundary can be =len(contents) (suffix will be empty) @@ -118,7 +118,7 @@ def _fim_permute_sequence( boundaries.sort() else: # Sample a window-length - 
middle_length = np_rng.randint(low=0, high=min(self._config.fim_max_middle_len, len(contents)) + 1) + middle_length = np_rng.randint(low=0, high=min(self._config.max_middle_len, len(contents)) + 1) first_boundary = np_rng.randint(low=0, high=len(contents) - middle_length + 1) # middle_length <= Second-boundary <= len(contents) boundaries = [first_boundary, first_boundary + middle_length] @@ -134,7 +134,7 @@ def _fim_permute_sequence( # here we truncate each given segment to fit the same length as it was before # A consequence is that we never reach the end of a file? # we should rather truncate at the context-level - if self._config.fim_truncate_or_pad: + if self._config.truncate_or_pad: # need to make same length as the input. Take the 3 sentinel tokens into account new_length = suffix.shape[0] + prefix.shape[0] + middle.shape[0] + 3 diff = new_length - sequence.shape[0] @@ -145,7 +145,7 @@ def _fim_permute_sequence( elif diff < 0: # too short suffix = np.concatenate([suffix, np.full((-1 * diff), self._pad_tok_id)]) - if np_rng.binomial(1, self._config.fim_spm_rate): + if np_rng.binomial(1, self._config.spm_rate): # SPM (variant 2 from FIM paper) new_sample = np.concatenate( [[self._prefix_tok_id, self._suffix_tok_id], suffix, [self._middle_tok_id], prefix, middle] # noqa diff --git a/fast_llm/data/gpt.py b/fast_llm/data/gpt.py index 038110a9e..e70629c67 100644 --- a/fast_llm/data/gpt.py +++ b/fast_llm/data/gpt.py @@ -154,7 +154,7 @@ def __init__( ): self._dataset = dataset - if config.fim.fim_rate > 0: + if config.fim.rate > 0: assert tokenizer is not None self._fim = Fim(config.fim, tokenizer) else: diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index b6654f30d..d75aab7f1 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -10,9 +10,9 @@ class Tokenizer: """ def __init__(self, config: TokenizerConfig): - log_main_rank(f"> loading tokenizer from {config.tokenizer_file} ...") + log_main_rank(f"> loading tokenizer from 
{config.path} ...") special_tokens = [EOD] - self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=config.tokenizer_file, errors="replace", max_len=None) + self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=config.path, errors="replace", max_len=None) self.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens}) self.eod_id = self.tokenizer.vocab[EOD] # Token->id mapping for additional special-tokens diff --git a/fast_llm/engine/config_utils/data_type.py b/fast_llm/engine/config_utils/data_type.py index bf04c2d97..25aa1ea40 100644 --- a/fast_llm/engine/config_utils/data_type.py +++ b/fast_llm/engine/config_utils/data_type.py @@ -6,6 +6,7 @@ if typing.TYPE_CHECKING: import numpy as np import torch + from triton import language as tl diff --git a/fast_llm/engine/config_utils/logging.py b/fast_llm/engine/config_utils/logging.py index 1e0503cd4..06d7298b3 100644 --- a/fast_llm/engine/config_utils/logging.py +++ b/fast_llm/engine/config_utils/logging.py @@ -2,7 +2,9 @@ import logging.config import math import pathlib -import typing + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -53,27 +55,35 @@ def configure_logging( logging.config.dictConfig(logging_config) -def log(*message, log_fn: typing.Union[BaseException, typing.Callable] = logger.info, join: str = ", "): - message = join.join([str(m() if callable(m) else m) for m in message]) - if isinstance(log_fn, BaseException): - raise log_fn(message) - else: - return log_fn(message) +@config_class() +class TensorLogsConfig(Config): + save: bool = Field( + default=False, + desc="Save tensor logs to an artifact file.", + hint=FieldHint.logging, + ) + show: bool = Field( + default=True, + desc="Post all tensor logs to stdout. 
May lead to extremely large log", + hint=FieldHint.logging, + ) + max_elements: int = Field( + default=8, + desc="Maximum number of tensor values to print for each tensor when posting tensor logs to stdout.", + hint=FieldHint.logging, + valid=skip_valid_if_none(check_field(Assert.gt, 0)), + ) class TensorLogs: # A global buffer for holding logged tensor stats. _tensor_log_stats: list | None = None - max_logged_elements = 8 - verbose: bool = True - - @classmethod - def reset(cls, enabled=True): - cls._tensor_log_stats = [] if enabled else None + config: TensorLogsConfig | None = None @classmethod - def enabled(cls): - return cls._tensor_log_stats is not None + def reset(cls, config: TensorLogsConfig): + cls.config = config + cls._tensor_log_stats = [] if config.save else None @classmethod def append(cls, stats): diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index e3d58948d..740b83171 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -7,11 +7,11 @@ import yaml -from fast_llm.config import Config, Field, FieldHint, FieldVerboseLevel, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.logging import TensorLogs, configure_logging, log +from fast_llm.config import Config, Field, FieldHint, FieldVerboseLevel, config_class +from fast_llm.engine.config_utils.logging import TensorLogs, TensorLogsConfig, configure_logging from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.utils import Assert +from fast_llm.utils import Assert, log if typing.TYPE_CHECKING: from fast_llm.engine.distributed.distributed import Distributed @@ -21,17 +21,8 @@ @config_class() class RunConfig(Config): - log_interval: int = Field( - default=100, - desc="Number of iteration between each progress and metric logging.", - hint=FieldHint.logging, - valid=check_field(Assert.gt, 0), - ) - log_offset: 
int = Field( - default=1, - desc="Determine the first logging iteration, for example to log after the first iteration.", - hint=FieldHint.logging, - valid=check_field(Assert.geq, 0), + tensor_logs: TensorLogsConfig = Field( + default_factory=TensorLogsConfig, desc="Configuration for debug tensor logs.", hint=FieldHint.logging ) # TODO v0.2: Adjust (now only affects logging to file). structured_logs: bool = Field( @@ -46,92 +37,20 @@ class RunConfig(Config): hint=FieldHint.logging, ) log_timestamps: bool = Field( - default=False, desc="Add a timestamp to every Fast-LLM (structured) log.", hint=FieldHint.logging - ) - checkpoint_interval: int | None = Field( - default=None, - desc="The number of training iterations between each checkpoint.", - doc="Checkpoints are temporary saves of the model kept to enable resuming in case of a shutdown.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.gt, 0)), - ) - checkpoint_offset: int = Field( - default=0, - desc="Determine the first checkpoint iteration, if applicable.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - # Drop checkpoints if there are more than this amount. - # TODO: Set default to 5? - max_checkpoints: int | None = Field( - default=None, - desc="The maximum number of checkpoints to keep. When exceeding this value, checkpoints are deleted starting from the older ones.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.gt, 0)), - ) - # Exclude these checkpoints from the `max_checkpoints` - # (counted in training steps, must be a multiple of `checkpoint_interval`) - export_interval: int | None = Field( - default=None, - desc="The number of training iterations between each export. 
Must be a multiple of the checkpoint interval.", - doc="Export are permanent saves of the model, which may for example be kept for downstream usage such as benchmarking, for future reference, or as additional backup.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.gt, 0)), - ) - stop_interval: int | None = Field( - default=None, - desc="Perform automated shutdowns at predefined intervals.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.gt, 0)), - ) - stop_offset: int = Field( - default=0, - desc="Determine the iteration for the first automated shutdown, if applicable.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), + default=True, desc="Add a timestamp to every Fast-LLM (structured) log.", hint=FieldHint.logging ) + # TODO: Only needed for wandb? experiment_name: str | None = Field( default=None, desc="A custom name for the experiment. Default: the experiment directory name or 'default'", hint=FieldHint.feature, ) - wandb_group_name: str = Field(default="default", desc="A group name for Wandb", hint=FieldHint.feature) - wandb_project_name: str = Field(default="fast_llm", desc="A project name for Wandb", hint=FieldHint.feature) - wandb_entity_name: str | None = Field(default=None, desc="An entity (user) name for Wandb", hint=FieldHint.feature) - wandb_status_interval: int | None = Field( - default=None, - desc="The number of training iterations between each Wandb log. Must be a multiple of the logging interval.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.gt, 0)), - ) - wandb_post_alerts: bool = Field( - default=None, - desc="Post wandb status updates on status changes (run begin/end) and optionally every `wandb_status_interval` iterations. " - "The update may be posted by email and/or slack depending on the Wandb account configuration.", - hint=FieldHint.feature, - ) # Enable torch compile. 
torch_dynamo_enable: bool = Field( default=True, desc="Set to False to disable torch compile entirely. Not recommended unless there is a good reason to do so.", hint=FieldHint.expert, ) - save_tensor_logs: bool = Field( - default=False, - desc="Save tensor logs to an artifact file.", - hint=FieldHint.logging, - ) - show_tensor_logs: bool = Field( - default=True, - desc="Post all tensor logs to stdout. May lead to extremely large log", - hint=FieldHint.logging, - ) - tensor_logs_show_elements: int = Field( - default=8, - desc="Maximum number of tensor values to print for each tensor when posting tensor logs to stdout.", - hint=FieldHint.logging, - valid=skip_valid_if_none(check_field(Assert.gt, 0)), - ) enable_triton_kernels: bool = Field( default=True, desc="Global switch to allow disabling triton kernels. This parameter may be ignored when no alternative is available.", @@ -145,18 +64,9 @@ class RunConfig(Config): ) def _validate(self): - if self.wandb_post_alerts is None: - self.wandb_post_alerts = bool(self.wandb_status_interval) - super()._validate() - if self.wandb_status_interval: - assert self.wandb_post_alerts - assert self.wandb_status_interval % self.log_interval == 0 if self.experiment_dir is None: - assert not self.checkpoint_interval - if not self.checkpoint_interval: - assert not self.export_interval - elif self.export_interval: - assert self.checkpoint_interval and self.export_interval % self.checkpoint_interval == 0 + assert not self.tensor_logs.save + super()._validate() @config_class() @@ -234,69 +144,55 @@ def __init__( self._distributed = distributed # TODO: Main rank should contain the last pipeline stage so it calculates loss - self.is_main_rank = self._distributed_config.rank == _MAIN_RANK - self.is_model_parallel_main_rank = self._distributed_config.data_rank == 0 - self.is_pipeline_parallel_main_rank = ( + self._is_main_rank = self._distributed_config.rank == _MAIN_RANK + self._is_model_parallel_main_rank = 
self._distributed_config.data_rank == 0 + self._is_pipeline_parallel_main_rank = ( self._distributed_config.data_rank == 0 and self._distributed_config.tensor_rank == 0 ) config_dict = config.to_serialized() if self._config.experiment_dir is not None: - experiment_dir = self._config.experiment_dir.resolve() - self.dataset_cache_dir = experiment_dir / "dataset_cache" - self._checkpoint_dir = experiment_dir / "checkpoints" - self._export_dir = experiment_dir / "export" - if self.is_main_rank: + self._experiment_directory = self._config.experiment_dir.resolve() + self.dataset_cache_dir = self._experiment_directory / "dataset_cache" + self._checkpoint_dir = self._experiment_directory / "checkpoints" + self._export_dir = self._experiment_directory / "export" + if self._is_main_rank: self._checkpoint_dir.mkdir(exist_ok=True, parents=True) - (experiment_dir / "runs").mkdir(exist_ok=True, parents=True) - run = len(list((experiment_dir / "runs").iterdir())) - (experiment_dir / "runs" / str(run)).mkdir() - yaml.safe_dump(config_dict, (experiment_dir / "config.yaml").open("w")) + (self._experiment_directory / "runs").mkdir(exist_ok=True, parents=True) + run = len(list((self._experiment_directory / "runs").iterdir())) + (self._experiment_directory / "runs" / str(run)).mkdir() + yaml.safe_dump(config_dict, (self._experiment_directory / "config.yaml").open("w")) self.dataset_cache_dir.mkdir(exist_ok=True) else: run = 0 # Make sure all the workers agree on the run. This also acts as a barrier. 
self.index = self._broadcast_int(run) - run_dir = experiment_dir / "runs" / str(self.index) + run_dir = self._experiment_directory / "runs" / str(self.index) self._artifact_dir = run_dir / "artifacts" / str(self._distributed_config.rank) log_dir = run_dir / "logs" - self._save_tensor_logs = self._config.save_tensor_logs else: - experiment_dir, self._checkpoint_dir, self._artifact_dir, log_dir = None, None, None, None + self._experiment_directory, self._checkpoint_dir, self._artifact_dir, log_dir = None, None, None, None self.dataset_cache_dir = None self.index = None - self._save_tensor_logs = False if self._config.structured_logs: config.configure_logging(log_dir) - self.use_wandb = self._config.wandb_entity_name is not None and self.is_main_rank - self.experiment_name = self._config.experiment_name or ( - "default" if experiment_dir is None else experiment_dir.name + self._experiment_name = self._config.experiment_name or ( + "default" if self._experiment_directory is None else self._experiment_directory.name ) - if self.use_wandb: - import wandb - - # Wandb login from file - api_key_path = os.environ.get("WANDB_API_KEY_PATH") - if api_key_path: - os.environ["WANDB_API_KEY"] = pathlib.Path(api_key_path).open("r").read().strip() - wandb_path = None if experiment_dir is None else experiment_dir / "wandb_config.yaml" - if wandb_path is not None and wandb_path.is_file(): - wandb_config = yaml.safe_load(wandb_path.open("r")) - else: - wandb_config = { - "id": wandb.sdk.lib.runid.generate_id(16), - "project": self._config.wandb_project_name, - "name": self.experiment_name, - "entity": self._config.wandb_entity_name, - "group": self._config.wandb_group_name, - "save_code": False, - "resume": "allow", - } - if wandb_path is not None: - yaml.safe_dump(wandb_config, wandb_path.open("w")) - wandb.init(config=config_dict, **wandb_config) + + @property + def is_main_rank(self): + return self._is_main_rank + + @property + def experiment_directory(self): + return
self._experiment_directory + + @property + def experiment_name(self): + return self._experiment_name @property def _is_running(self): @@ -306,14 +202,13 @@ def save_logged_tensors(self, iteration: int | str): import torch assert self._is_running - if self._save_tensor_logs: - tensor_stats = TensorLogs.get() - if tensor_stats: - torch.save(tensor_stats, self.open_artifact(f"tensor_logs_{iteration}.pt", mode="wb")) - TensorLogs.reset() + tensor_stats = TensorLogs.get() + if tensor_stats: + torch.save(tensor_stats, self.open_artifact(f"tensor_logs_{iteration}.pt", mode="wb")) + TensorLogs.reset(self._config.tensor_logs) - def get_save_checkpoint_context(self, iteration: int, export: bool = False): - return self._SaveCheckpointContext(self, iteration, export) + def get_save_checkpoint_context(self, iteration: int, export: bool = False, keep: int | None = None): + return self._SaveCheckpointContext(self, iteration, export, keep) def get_load_checkpoint_context(self, iteration: int): return self._LoadCheckpointContext(self, iteration) @@ -342,16 +237,17 @@ def directory(self): return self._directory class _SaveCheckpointContext(_CheckpointContext): - def __init__(self, run: "Run", iteration: int, export: bool = False): + def __init__(self, run: "Run", iteration: int, export: bool = False, keep: int | None = None): super().__init__(run, iteration) self._export = export + self._keep = keep if self._export: self._link_directory = self._directory self._directory = self._run._export_dir / str(self._iteration) def __enter__(self): assert self._run._is_running - if self._run.is_main_rank: + if self._run._is_main_rank: logger.info(f"Saving checkpoint at iteration {self._iteration}") self._directory.mkdir(parents=True) if self._export: @@ -363,11 +259,11 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): if not exc_type: self._run.barrier(f"save {self._iteration} exit") - if self._run.is_main_rank: + if self._run._is_main_rank: # Prevent corrupted checkpoint. 
(self._directory / "ok").open("w") logger.info(f"Checkpoint saved to {self._directory}") - self._run._delete_old_checkpoints() + self._run._delete_old_checkpoints(self._keep) class _LoadCheckpointContext(_CheckpointContext): def __enter__(self): @@ -379,12 +275,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): if not exc_type: self._run.barrier(f"load {self._iteration} exit") - def _delete_old_checkpoints(self): + def _delete_old_checkpoints(self, keep: int | None): assert self._is_running - if self._config.max_checkpoints is None: + if keep is None: return checkpoints = sorted(int(path.name) for path in self._checkpoint_dir.iterdir()) - for checkpoint in checkpoints[: -self._config.max_checkpoints]: + for checkpoint in checkpoints[:-keep]: path = self._checkpoint_dir / str(checkpoint) logger.info(f"Deleting checkpoint at {path}") try: @@ -396,7 +292,7 @@ def get_last_checkpoint(self): assert self._is_running if self._checkpoint_dir is None: return None - if self.is_main_rank: + if self._is_main_rank: checkpoints = [int(path.name) for path in self._checkpoint_dir.iterdir()] iteration = max(checkpoints) if checkpoints else -1 else: @@ -404,27 +300,6 @@ def get_last_checkpoint(self): iteration = self._broadcast_int(iteration) return iteration if iteration >= 0 else None - def log_wandb_metrics(self, completed_steps: int, metrics: dict[str, dict[str, float | int]]): - assert self._is_running - # Note: metrics modified in-place - if self.use_wandb: - import wandb - - wandb.log(metrics, step=completed_steps) # noqa - - def post_wandb_alert(self, title, text, level="INFO", wait=0.001): - assert self._is_running - if self.use_wandb and self._config.wandb_post_alerts: - import wandb - - wandb.alert( - title=title() if callable(title) else title, - text=f"[{self._config.wandb_project_name}/{self.experiment_name}, run {self.index}]" - f" {text() if callable(text) else text}", - level=level, - wait_duration=wait, - ) - def open_artifact(self, name: str, mode: str | None = 
"w", verbose=True): assert self._is_running if self._artifact_dir is None: @@ -441,19 +316,11 @@ def __enter__(self): assert not self._is_running global _run _run = self - self.post_wandb_alert(f"Run started!", "", "ERROR") - if self._save_tensor_logs: - TensorLogs.reset() - TensorLogs.verbose = self._config.show_tensor_logs - TensorLogs.max_logged_elements = self._config.tensor_logs_show_elements + TensorLogs.reset(self._config.tensor_logs) def __exit__(self, exc_type, exc_val: OSError, exc_tb): assert self._is_running global _run - if exc_val: - self.post_wandb_alert(f"Run crashed!", (lambda: ", ".join(exc_val.args)), "ERROR") - else: - self.post_wandb_alert(f"Run ended!", "", "INFO") self.save_logged_tensors("none") _run = None @@ -476,7 +343,7 @@ def log_main_rank(*message, log_fn: typing.Union[BaseException, typing.Callable] def is_model_parallel_main_rank(): - return is_main_rank() if _run is None else _run.is_model_parallel_main_rank + return is_main_rank() if _run is None else _run._is_model_parallel_main_rank # Noqa def log_model_parallel_main_rank(*message, log_fn=logger.info): @@ -485,7 +352,7 @@ def log_model_parallel_main_rank(*message, log_fn=logger.info): def is_pipeline_parallel_main_rank(): - return is_main_rank() if _run is None else _run.is_pipeline_parallel_main_rank + return is_main_rank() if _run is None else _run._is_pipeline_parallel_main_rank # Noqa def log_pipeline_parallel_main_rank(*message, log_fn=logger.info): diff --git a/fast_llm/engine/huggingface/config.py b/fast_llm/engine/huggingface/config.py index f3a339e42..1adff8bdd 100644 --- a/fast_llm/engine/huggingface/config.py +++ b/fast_llm/engine/huggingface/config.py @@ -56,14 +56,14 @@ def _get_config_dict(cls, pretrained_model_name_or_path: str | os.PathLike | Pre # Get the pretrained config. 
if "pretrained" in kwargs: assert isinstance(kwargs["pretrained"], PretrainedConfig) - assert kwargs["pretrained"].pretrained_checkpoint_path == pretrained_model_name_or_path + assert kwargs["pretrained"].path == pretrained_model_name_or_path pretrained = kwargs.pop("pretrained") elif isinstance(pretrained_model_name_or_path, PretrainedConfig): pretrained = pretrained_model_name_or_path else: pretrained = PretrainedConfig( - pretrained_checkpoint_path=pathlib.Path(pretrained_model_name_or_path), - pretrained_checkpoint_type=CheckpointType.state_dict, + path=pathlib.Path(pretrained_model_name_or_path), + format=CheckpointType.state_dict, ) metadata = cls.model_config_class.load_pretrained_metadata(pretrained) updates = {} diff --git a/fast_llm/engine/huggingface/model.py b/fast_llm/engine/huggingface/model.py index 995d236f4..fa46d0e40 100644 --- a/fast_llm/engine/huggingface/model.py +++ b/fast_llm/engine/huggingface/model.py @@ -66,8 +66,8 @@ def from_pretrained( # Pretrained config. if not isinstance(pretrained_model_name_or_path, PretrainedConfig): pretrained_model_name_or_path = PretrainedCheckpointConfig( - pretrained_checkpoint_path=pathlib.Path(pretrained_model_name_or_path), - pretrained_checkpoint_type=CheckpointType.state_dict, + path=pathlib.Path(pretrained_model_name_or_path), + format=CheckpointType.state_dict, ) config_updates = {} diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 119bd208b..e1a756c08 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -189,29 +189,25 @@ class CheckpointType(str, enum.Enum): @config_class() class PretrainedConfig(Config): - pretrained_checkpoint_path: pathlib.Path | None = Field( + path: pathlib.Path | None = Field( default=None, desc="Path to the checkpoint.", hint=FieldHint.core, ) - pretrained_checkpoint_type: CheckpointType = Field( + format: CheckpointType = Field( default=CheckpointType.distributed, desc="Format of the 
checkpoint.", hint=FieldHint.core, ) - imported_model_type: str | None = Field( + imported_type: str | None = Field( default=None, desc="Model type for external models (ex. Huggingace type).", hint=FieldHint.feature, ) - use_pretrained_config: bool = Field( - default=True, - desc="Load the architecture config from the pretrained checkpoint.", - hint=FieldHint.feature, - ) - ignore_pretrained_config: bool = Field( + override_architecture: bool = Field( default=False, - desc="Ignore the pretrained checkpoint architecture config, i.e., disable verification.", + desc="Ignore the base model architecture from the pretrained checkpoint and use the provided one instead." + " May have unintended consequences.", hint=FieldHint.feature, ) load_full_base_model_config: bool = Field( @@ -232,17 +228,15 @@ def _validate(self): @property def compare_log_fn(self): - return logger.warning if self.ignore_pretrained_config else ValueError + return logger.warning if self.override_architecture else ValueError @config_class() class PretrainedCheckpointConfig(PretrainedConfig): - # Load weights from pretrained_checkpoint_path (if applicable), + # Load weights from path (if applicable), # otherwise reinitialize them (i.e. load the config only.) - load_pretrained_weights: bool = Field( - default=True, desc="Load model weights from the checkpoint.", hint=FieldHint.feature - ) - load_pretrained_optimizer: bool = Field( + load_weights: bool = Field(default=True, desc="Load model weights from the checkpoint.", hint=FieldHint.feature) + load_optimizer: bool = Field( default=False, desc="Load the optimizer state from the checkpoint.", hint=FieldHint.feature ) @@ -308,7 +302,7 @@ def from_pretrained( default: "FastLLMModelConfig" = None, ): # TODO: Add *updates? 
- assert pretrained.pretrained_checkpoint_path is not None + assert pretrained.path is not None metadata = cls.load_pretrained_metadata(pretrained) return cls.from_metadata(pretrained, metadata, default) @@ -332,7 +326,7 @@ def from_metadata( return cls._from_metadata_v0(pretrained, metadata, default, updates) pretrained_config = cls.from_dict(metadata["fast_llm_config"]) - if not pretrained.use_pretrained_config: + if pretrained.override_architecture: assert default is not None config = default.to_copy() config.base_model.compare_architecture(pretrained_config.base_model, pretrained.compare_log_fn) @@ -367,12 +361,15 @@ def _from_metadata_v0( with NoAutoValidate(): if default is None: - assert pretrained.use_pretrained_config + assert not pretrained.override_architecture config = cls(base_model=base_model_config_cls()) else: config = default.to_copy() - if pretrained.use_pretrained_config: + if pretrained.override_architecture: + config.validate() + architecture_config.compare_architecture(default.base_model, pretrained.compare_log_fn) + else: if pretrained.load_full_base_model_config: # Replace the whole config config.base_model = base_model_config_cls.from_flat_dict(metadata["model_config"]) @@ -384,9 +381,6 @@ def _from_metadata_v0( config.distributed = DistributedConfig.from_flat_dict( metadata["distributed_config"], ) - else: - config.validate() - architecture_config.compare_architecture(default.base_model, pretrained.compare_log_fn) config.validate() if updates: @@ -398,24 +392,20 @@ def load_pretrained_metadata(cls, pretrained): import yaml base_model_config_cls = cls.get_base_model_config_cls() - if pretrained.pretrained_checkpoint_type == CheckpointType.distributed: - return yaml.safe_load((pretrained.pretrained_checkpoint_path / "metadata.yaml").open("r")) - elif pretrained.pretrained_checkpoint_type == CheckpointType.state_dict: - return json.load((pretrained.pretrained_checkpoint_path / f"state_dict.safetensors.index.json").open("r"))[ - "metadata" - ] 
- elif pretrained.pretrained_checkpoint_type == CheckpointType.huggingface: - converter_class = base_model_config_cls.get_converter_class(pretrained.imported_model_type) - imported_model_config = converter_class.import_config( - converter_class.load_config(pretrained.pretrained_checkpoint_path), True - ) + if pretrained.format == CheckpointType.distributed: + return yaml.safe_load((pretrained.path / "metadata.yaml").open("r")) + elif pretrained.format == CheckpointType.state_dict: + return json.load((pretrained.path / f"state_dict.safetensors.index.json").open("r"))["metadata"] + elif pretrained.format == CheckpointType.huggingface: + converter_class = base_model_config_cls.get_converter_class(pretrained.imported_type) + imported_model_config = converter_class.import_config(converter_class.load_config(pretrained.path), True) return { "fast_llm_config": {"base_model": imported_model_config.to_serialized()}, "checkpoint_type": CheckpointType.huggingface.value, "checkpoint_version": CHECKPOINT_VERSION, } else: - raise NotImplementedError(pretrained.pretrained_checkpoint_type) + raise NotImplementedError(pretrained.format) @config_class() @@ -462,7 +452,7 @@ def base_model(self): def _validate(self): assert self.model is not None self.pretrained.validate() - if self.pretrained.pretrained_checkpoint_path is not None: + if self.pretrained.path is not None: self.model = self.model.from_pretrained(self.pretrained, default=self.model) self._setup() super()._validate() diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index 9663450a0..1e4deaa8f 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -99,23 +99,21 @@ def save_checkpoint( raise NotImplementedError(checkpoint_config.checkpoint_type) def load_pretrained_checkpoint(self, pretrained_config: PretrainedCheckpointConfig): - if pretrained_config.pretrained_checkpoint_type == CheckpointType.distributed: + 
if pretrained_config.format == CheckpointType.distributed: # TODO: Check if same format. self._load_distributed_checkpoint(pretrained_config) - elif pretrained_config.pretrained_checkpoint_type == CheckpointType.state_dict: + elif pretrained_config.format == CheckpointType.state_dict: self._load_state_dict_checkpoint(pretrained_config) - elif pretrained_config.pretrained_checkpoint_type == CheckpointType.huggingface: + elif pretrained_config.format == CheckpointType.huggingface: self._import_checkpoint(pretrained_config) else: - raise NotImplementedError(pretrained_config.pretrained_checkpoint_type) + raise NotImplementedError(pretrained_config.format) def load_distributed_checkpoint_same_format(self, directory: pathlib.Path): # TODO: Handle barriers, ok file, etc. here # TODO: More safety checks # TODO: Integrate to load_checkpoint. - pretrained_config = PretrainedCheckpointConfig( - pretrained_checkpoint_path=directory, pretrained_checkpoint_type=CheckpointType.distributed - ) + pretrained_config = PretrainedCheckpointConfig(path=directory, format=CheckpointType.distributed) metadata = self.config_class.load_pretrained_metadata(pretrained_config) with self._LoadContext(self, safe=False, load_optimizer=True, reset_pads=False) as context: Assert.eq( @@ -168,7 +166,7 @@ def from_pretrained( model.setup(Distributed(config.distributed, use_cpu=use_cpu), mode=mode) if mode.on_device: - if pretrained_config.load_pretrained_weights: + if pretrained_config.load_weights: model.load_pretrained_checkpoint(pretrained_config) else: model.initialize_weights() @@ -575,7 +573,6 @@ def _load_distributed_checkpoint(self, pretrained_config: PretrainedCheckpointCo metadata = self.config_class.load_pretrained_metadata(pretrained_config) loaded_pretrained_config = pretrained_config.to_copy( { - "use_pretrained_config": True, "load_full_base_model_config": True, "load_full_fast_llm_config": True, }, @@ -585,7 +582,7 @@ def _load_distributed_checkpoint(self, pretrained_config: 
PretrainedCheckpointCo metadata, ) with self._LoadContext( - self, safe=True, load_optimizer=pretrained_config.load_pretrained_optimizer, reset_pads=True + self, safe=True, load_optimizer=pretrained_config.load_optimizer, reset_pads=True ) as context: Assert.eq(metadata["state_shard_names"][: context.num_shards], list(context.shard_names)) @@ -595,7 +592,7 @@ def _load_distributed_checkpoint(self, pretrained_config: PretrainedCheckpointCo optimizer_state_names=context.shard_names[1:], verbose=False, ) - path = pretrained_config.pretrained_checkpoint_path / f"rank_{rank}.safetensors" + path = pretrained_config.path / f"rank_{rank}.safetensors" logger.info(f"Loading from {path}") # TODO: skip shards without overlap. with safetensors.safe_open(path, framework="pt", device=str(self._distributed.device)) as f: @@ -609,15 +606,15 @@ def _load_state_dict_checkpoint(self, pretrained_config: PretrainedCheckpointCon # TODO: Verify more distributed configs. # TODO: More safety checks with self._LoadContext( - self, safe=True, load_optimizer=pretrained_config.load_pretrained_optimizer, reset_pads=True + self, safe=True, load_optimizer=pretrained_config.load_optimizer, reset_pads=True ) as context: - index_path = pretrained_config.pretrained_checkpoint_path / f"state_dict.safetensors.index.json" + index_path = pretrained_config.path / f"state_dict.safetensors.index.json" logger.info(f"Loading index from {index_path}") file_names = set(json.load(index_path.open("r"))["weight_map"].values()) for file_name in file_names: - logger.info(f"Loading from {pretrained_config.pretrained_checkpoint_path/file_name}") + logger.info(f"Loading from {pretrained_config.path / file_name}") with safetensors.safe_open( - pretrained_config.pretrained_checkpoint_path / file_name, + pretrained_config.path / file_name, framework="pt", device=str(self._distributed.device), ) as f: @@ -632,23 +629,19 @@ def _load_state_dict_checkpoint(self, pretrained_config: PretrainedCheckpointCon def 
_import_checkpoint(self, pretrained_config: PretrainedCheckpointConfig): # TODO: Support optimizer? - assert not pretrained_config.load_pretrained_optimizer + assert not pretrained_config.load_optimizer # TODO: Verify more distributed configs. # TODO: Safety checks - converter_class = self.base_model.architecture_cls().get_converter_class(pretrained_config.imported_model_type) - converter = converter_class.from_config( - converter_class.load_config(pretrained_config.pretrained_checkpoint_path) - ) + converter_class = self.base_model.architecture_cls().get_converter_class(pretrained_config.imported_type) + converter = converter_class.from_config(converter_class.load_config(pretrained_config.path)) self._base_model_config.compare_architecture(converter.config, pretrained_config.compare_log_fn) state_dict = {} with self._LoadContext( - self, safe=True, load_optimizer=pretrained_config.load_pretrained_optimizer, reset_pads=True + self, safe=True, load_optimizer=pretrained_config.load_optimizer, reset_pads=True ) as context: - for name, tensor in converter.load_weights( - pretrained_config.pretrained_checkpoint_path, self._distributed.device - ): + for name, tensor in converter.load_weights(pretrained_config.path, self._distributed.device): assert name not in state_dict state_dict[name] = tensor for parameter_name, fast_llm_tensor in converter.convert_state_dict(state_dict, False).items(): diff --git a/fast_llm/engine/optimizer/config.py b/fast_llm/engine/optimizer/config.py index 982875fbc..3a154c9e5 100644 --- a/fast_llm/engine/optimizer/config.py +++ b/fast_llm/engine/optimizer/config.py @@ -17,94 +17,100 @@ class LearningRateStageType: @config_class() class LearningRateScheduleConfig(Config): - lr: float = Field(default=0.0001, desc="Base learning rate for the optimizer.", hint=FieldHint.core) - lr_decay_style: str = Field(default="constant", desc="The learning rate decay formula.", hint=FieldHint.feature) - lr_decay_iters: int | None = Field( + base: float = 
Field(default=0.0001, desc="Base learning rate for the optimizer.", hint=FieldHint.core) + decay_style: str = Field(default="constant", desc="The learning rate decay formula.", hint=FieldHint.feature) + decay_iterations: int | None = Field( default=None, desc="Duration of the learning rate decay, in iterations.", hint=FieldHint.feature ) - lr_decay_power: float = Field( + decay_power: float = Field( default=1.0, desc="Exponent for learning rate decay, applied on the decay step..", hint=FieldHint.feature ) - lr_warmup_iters: int = Field( + warmup_iterations: int = Field( default=0, desc="Number of iteration for the learning rate warmup.", hint=FieldHint.feature ) - min_lr: float = Field(default=0.0, desc="Learning rate at the end of decay.", hint=FieldHint.feature) - lr_schedule: str | None = Field( + minimum: float = Field(default=0.0, desc="Learning rate at the end of decay.", hint=FieldHint.feature) + schedule: str | None = Field( default=None, desc="Complex learning rate schedule encoded in a string (untested, replaces the other arguments.", hint=FieldHint.wip, ) +@config_class() +class GradientScalerConfig(Config): + constant: float | None = Field( + default=None, + desc="Constant multiplier applied to the loss. 
Setting this disables dynamic scaling.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + initial: float = Field( + default=2**16, + desc="Initial loss scale for dynamic scaling (fp16).", + hint=FieldHint.feature, + valid=check_field(Assert.gt, 0), + ) + minimum: float = Field( + default=1.0, + desc="Minimum loss scale for dynamic scaling (fp16).", + hint=FieldHint.feature, + valid=check_field(Assert.gt, 0), + ) + window: int = Field( + default=1000, + desc="Interval between dynamic scaling growth (fp16).", + hint=FieldHint.feature, + valid=check_field(Assert.gt, 0), + ) + hysteresis: int = Field( + default=2, + desc="Number of failed updates to tolerate before lowering the learning rate in dynamic scaling (fp16).", + hint=FieldHint.feature, + valid=check_field(Assert.gt, 0), + ) + + @config_class() class OptimizerConfig(Config): - lr_schedule: LearningRateScheduleConfig = Field( + learning_rate: LearningRateScheduleConfig = Field( default_factory=LearningRateScheduleConfig, desc="A schedule for the learning rate.", hint=FieldHint.core, ) + gradient_scaler: GradientScalerConfig = Field( + default_factory=GradientScalerConfig, + desc="Configuration for the fixed or dynamic gradient scaling.", + hint=FieldHint.feature, + ) weight_decay: float = Field( default=0.01, desc="Weight decay (Adamw).", hint=FieldHint.core, valid=check_field(Assert.geq, 0), ) - adam_beta1: float = Field( + beta_1: float = Field( default=0.9, desc="First Adam momentum.", hint=FieldHint.optional, valid=check_field(Assert.in_range_incl, 0, 1), ) - adam_beta2: float = Field( + beta_2: float = Field( default=0.999, desc="Second Adam Momentum.", hint=FieldHint.optional, valid=check_field(Assert.in_range_incl, 0, 1), ) - adam_eps: float = Field( + epsilon: float = Field( default=1e-8, desc="Regularizer for Adam.", hint=FieldHint.optional, valid=check_field(Assert.gt, 0) ) - clip_grad: float = Field( - default=1.0, - desc="Duration of the learning rate decay, in 
iterations.", - hint=FieldHint.feature, - valid=check_field(Assert.gt, 0), - ) - loss_scale: float | None = Field( - default=None, - desc="Constant multiplier applied to the loss (ignored in fp16).", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - initial_loss_scale: float = Field( - default=2**16, - desc="Initial loss scale for dynamic scaling (fp16).", - hint=FieldHint.feature, - valid=check_field(Assert.gt, 0), - ) - min_loss_scale: float = Field( + gradient_norm_clipping: float = Field( default=1.0, - desc="Minimum loss scale for dynamic scaling (fp16).", - hint=FieldHint.feature, - valid=check_field(Assert.gt, 0), - ) - loss_scale_window: int = Field( - default=1000, - desc="Interval between dynamic scaling growth (fp16).", + desc="Clip the gradient norm to this value.", hint=FieldHint.feature, valid=check_field(Assert.gt, 0), ) - hysteresis: int = Field( - default=2, - desc="Number of failed updates to tolerate before lowering the learning rate in dynamic scaling (fp16).", - hint=FieldHint.feature, - valid=check_field(Assert.gt, 0), - ) - lr_schedule_offset: int = Field( - default=0, desc="Offset for the learning rate schedule, in steps.", hint=FieldHint.feature - ) - default_lr_scale: float = Field( + default_learning_rate_scale: float = Field( default=1.0, desc="Default multiplier to apply to the learning rate schedule, for parameters that do not define a scale.", hint=FieldHint.feature, diff --git a/fast_llm/engine/optimizer/learning_rate.py b/fast_llm/engine/optimizer/learning_rate.py index 9a1f9e761..799918e2f 100644 --- a/fast_llm/engine/optimizer/learning_rate.py +++ b/fast_llm/engine/optimizer/learning_rate.py @@ -105,16 +105,20 @@ def __call__(self, step): def create_schedule_from_config(config: LearningRateScheduleConfig) -> LearningRateSchedule: stages = [] - if config.lr_schedule is None: - if config.lr_warmup_iters > 0: - stages.append(PowerLRStage(begin_step=0, end_step=config.lr_warmup_iters, lr=0, 
end_lr=config.lr)) - kwargs = {"begin_step": config.lr_warmup_iters, "end_step": config.lr_decay_iters, "lr": float(config.lr)} - if config.lr_decay_style != "constant": - kwargs.update(end_lr=config.min_lr, power=config.lr_decay_power) - stages.append(_STAGE_TYPE_MAP[config.lr_decay_style](**kwargs)) + if config.schedule is None: + if config.warmup_iterations > 0: + stages.append(PowerLRStage(begin_step=0, end_step=config.warmup_iterations, lr=0, end_lr=config.base)) + kwargs = { + "begin_step": config.warmup_iterations, + "end_step": config.decay_iterations, + "lr": float(config.base), + } + if config.decay_style != "constant": + kwargs.update(end_lr=config.minimum, power=config.decay_power) + stages.append(_STAGE_TYPE_MAP[config.decay_style](**kwargs)) else: begin_step = 0 - for stage_arg_str in config.lr_schedule.split(";"): + for stage_arg_str in config.schedule.split(";"): try: for stage_type, num_steps, lr, *stage_args in stage_arg_str.split(","): assert begin_step is not None diff --git a/fast_llm/engine/optimizer/optimizer.py b/fast_llm/engine/optimizer/optimizer.py index f2cf0a7f4..5a7598031 100644 --- a/fast_llm/engine/optimizer/optimizer.py +++ b/fast_llm/engine/optimizer/optimizer.py @@ -7,22 +7,22 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.optimizer.config import OptimizerConfig, ParamGroup +from fast_llm.engine.optimizer.config import GradientScalerConfig, OptimizerConfig, ParamGroup from fast_llm.engine.optimizer.learning_rate import create_schedule_from_config from fast_llm.utils import Assert -def get_grad_scaler(config: OptimizerConfig, distributed: Distributed) -> "GradScaler": - if config.loss_scale: +def get_grad_scaler(config: GradientScalerConfig, distributed: Distributed) -> "GradScaler": + if config.constant: return ConstantGradScaler( - initial_scale=config.loss_scale, + 
initial_scale=config.constant, distributed=distributed, ) elif distributed.config.training_dtype == DataType.float16: return DynamicGradScaler( - initial_scale=config.initial_loss_scale, - min_scale=config.min_loss_scale, - growth_interval=config.loss_scale_window, + initial_scale=config.initial, + min_scale=config.minimum, + growth_interval=config.window, hysteresis=config.hysteresis, distributed=distributed, ) @@ -56,15 +56,15 @@ def __init__( grads_for_norm: list[torch.Tensor], distributed: Distributed, ): - self._config = config.validate() + self._config = config self._param_groups = _merge_and_filter_groups(param_groups) self._grads_for_norm = [g for g in grads_for_norm if g.device.type != "meta" and g.numel() > 0] self._grad_norm = None if self._grads_for_norm else torch.zeros([1], device=distributed.device) self._grads = [g for group in self._param_groups for g in group.grads] - self._grad_scaler = get_grad_scaler(self._config, distributed) + self._grad_scaler = get_grad_scaler(self._config.gradient_scaler, distributed) self._noop_flag = self._grad_scaler.noop_flag self._reduce_group = distributed.world_group - self._lr_schedule = create_schedule_from_config(self._config.lr_schedule) + self._lr_schedule = create_schedule_from_config(self._config.learning_rate) def _clip_grad_norm(self): # TODO: Optimize this. @@ -78,9 +78,9 @@ def _clip_grad_norm(self): group=self._reduce_group, ) grad_norm.pow_(0.5) - if self._config.clip_grad > 0.0: + if self._config.gradient_norm_clipping > 0.0: # TODO: Use noop flag instead of clamp. 
- clip_coeff = torch.clamp_max_(self._config.clip_grad / (grad_norm + 1.0e-6), 1.0) + clip_coeff = torch.clamp_max_(self._config.gradient_norm_clipping / (grad_norm + 1.0e-6), 1.0) # if clip_coeff < 1.0: scale_(self._grads, self._noop_flag, clip_coeff) return grad_norm @@ -95,8 +95,12 @@ def step(self, metrics: dict | None = None): update_successful = self._grad_scaler.update_successful() if update_successful: self._optimizer_step += 1 - lr = self._lr_schedule(self._optimizer_step + self._config.lr_schedule_offset) - grad_norm = self._clip_grad_norm().item() if self._config.clip_grad > 0.0 or metrics is not None else None + lr = self._lr_schedule(self._optimizer_step) + grad_norm = ( + self._clip_grad_norm().item() + if self._config.gradient_norm_clipping > 0.0 or metrics is not None + else None + ) for group in self._param_groups: fused_adam( params=group.params, @@ -104,11 +108,11 @@ def step(self, metrics: dict | None = None): exp_avgs=group.exp_avgs, exp_avg_sqs=group.exp_avgs_sq, noop_flag=self._noop_flag, - lr=lr * (self._config.default_lr_scale if group.lr_scale is None else group.lr_scale), - beta1=self._config.adam_beta1 if group.beta1 is None else group.beta1, - beta2=self._config.adam_beta2 if group.beta2 is None else group.beta2, + lr=lr * (self._config.default_learning_rate_scale if group.lr_scale is None else group.lr_scale), + beta1=self._config.beta_1 if group.beta1 is None else group.beta1, + beta2=self._config.beta_2 if group.beta2 is None else group.beta2, wd=self._config.weight_decay if group.weight_decay is None else group.weight_decay, - eps=self._config.adam_eps if group.eps is None else group.eps, + eps=self._config.epsilon if group.eps is None else group.eps, step=self._optimizer_step, ) diff --git a/fast_llm/engine/run/__init__.py b/fast_llm/engine/run/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 5b8f8194e..6f2fce854 100644 --- 
a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -1,4 +1,7 @@ import argparse +import os +import shlex +import subprocess import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none @@ -14,22 +17,168 @@ from fast_llm.engine.training.trainer import Trainer +def get_interval_config_class(desc: str, offset_desc: str | None = None): + # Intervals are a common pattern, so we standardize them with this helper. + @config_class() + class IntervalConfig(Config): + interval: int | None = Field( + default=None, + desc=f"The number of training iterations between each {desc}. Setting to None will disable.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.gt, 0)), + ) + offset: int = Field( + default=0, + desc=f"Offset for the first {offset_desc or desc}.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + + def enabled(self, iteration: int | None = None): + return self.interval and (iteration is None or (iteration - self.offset) % self.interval == 0) + + def is_sub_interval(self, other: "IntervalConfig"): + if not self.enabled(): + return True + elif not other.enabled(): + return False + return self.interval % other.interval == 0 and (other.offset % other.interval) == ( + self.offset % other.interval + ) + + def assert_sub_interval(self, other: "IntervalConfig"): + assert self.is_sub_interval(other), f"{self} is not a sub-interval of {other}" + + return IntervalConfig + + @config_class() -class TrainingConfig(Config): - train_iters: int = Field( - default=0, desc="Total number of training iterations.", hint=FieldHint.core, valid=check_field(Assert.geq, 0) +class WandbAlertConfig( + get_interval_config_class( + "Wandb status post (alert). Must be a multiple of the logging interval", + "Wandb status post (alert). 
Must be compatible with the logging offset", ) - validation_iters: int = Field( - default=0, - desc="Number of iterations for each validation phase. Setting to 0 will disable the validation phase.", +): + status_updates: bool | None = Field( + default=None, + desc="Post wandb status updates on status changes (run begin/end). " + "The update may be posted by email and/or slack depending on the Wandb account configuration.", hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), ) - validation_interval: int = Field( - default=1000, - desc="Number of training steps between each validation phase.", + + def _validate(self): + if self.status_updates is None: + self.status_updates = self.enabled() + super()._validate() + + +@config_class() +class MetricsLogsConfig(get_interval_config_class("metric logs")): + pass + + +@config_class() +class WandbConfig(Config): + alert: WandbAlertConfig = Field( + default_factory=WandbAlertConfig, + desc="Configuration for Wandb alerts. The alerts may be posted by email and/or slack depending on the Wandb account configuration.", + hint=FieldHint.core, + ) + group_name: str = Field(default="default", desc="A group name for Wandb", hint=FieldHint.feature) + project_name: str = Field(default="fast_llm", desc="A project name for Wandb", hint=FieldHint.feature) + entity_name: str | None = Field(default=None, desc="An entity (user) name for Wandb", hint=FieldHint.feature) + + +@config_class() +class ValidationConfig(get_interval_config_class("validation")): + iterations: int | None = Field( + default=None, + desc="Number of iterations for each validation phase.
Setting to None will disable.", hint=FieldHint.feature, - valid=check_field(Assert.gt, 0), + valid=skip_valid_if_none(check_field(Assert.gt, 0)), + ) + + def get_completed_iterations(self, training_iterations: int, completed_validations: int = 0): + # Number of completed validation iterations + return ( + (training_iterations // self.interval + completed_validations) * self.iterations if self.enabled() else 0 + ) + + +@config_class() +class CheckpointConfig(get_interval_config_class("checkpoint")): + keep: int | None = Field( + default=5, + desc="The maximum number of checkpoints to keep. When exceeding this value, checkpoints are deleted starting from the older ones.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.gt, 0)), + ) + + +def _validate_script(value): + if isinstance(value, str): + value = shlex.split(value) + Assert.geq(len(value), 1) + return value + + +@config_class() +class CallbackConfig(Config): + script: list[str] | None = Field( + default=None, + desc="Shell script to run after.", + hint=FieldHint.feature, + valid=skip_valid_if_none(_validate_script), + ) + environment: dict[str, str] = Field( + default_factory=dict, + desc="Environment variables to add to the script.", + hint=FieldHint.feature, + ) + + def run(self): + if self.script is not None: + environment = os.environ.copy() + environment.update(self.environment) + subprocess.Popen(self.script, env=environment) + + +@config_class() +class ExportConfig(get_interval_config_class("export")): + callback: CallbackConfig = Field( + default_factory=CallbackConfig, + desc="Callback (shell script) to run after export.", + hint=FieldHint.core, + ) + + +@config_class() +class ShutdownConfig(get_interval_config_class("automated shutdown")): + pass + + +@config_class() +class TrainingConfig(Config): + validation: ValidationConfig = Field( + default_factory=ValidationConfig, + desc="Configuration for the validation phase", + hint=FieldHint.core, + ) + logs: MetricsLogsConfig = 
Field( + default_factory=MetricsLogsConfig, desc="Configuration for metric logging.", hint=FieldHint.core + ) + checkpoint: CheckpointConfig = Field( + default_factory=CheckpointConfig, desc="Configuration for checkpoints.", hint=FieldHint.core + ) + export: ExportConfig = Field( + default_factory=ExportConfig, desc="Configuration for exports.", hint=FieldHint.core + ) + shutdown: ShutdownConfig = Field( + default_factory=ShutdownConfig, desc="Configuration for automated shutdown.", hint=FieldHint.core + ) + wandb: WandbConfig = Field(default_factory=WandbConfig, desc="Configuration for Wandb.", hint=FieldHint.core) + train_iters: int = Field( + default=0, desc="Total number of training iterations.", hint=FieldHint.core, valid=check_field(Assert.geq, 0) ) test_iters: int = Field( default=0, @@ -49,12 +198,12 @@ class TrainingConfig(Config): hint=FieldHint.performance, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) - export_callback_script: str = Field(default="", desc="Shell script to run after export.", hint=FieldHint.feature) - export_callback_env: str = Field( - default="", - desc="Environment variables to add to the export script, encoded in json format.", - hint=FieldHint.feature, - ) + + def _validate(self): + super()._validate() + self.export.assert_sub_interval(self.checkpoint) + self.shutdown.assert_sub_interval(self.checkpoint) + self.wandb.alert.assert_sub_interval(self.logs) @@ -90,6 +239,11 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): hint=FieldHint.core, ) + def _validate(self): + super()._validate() + if self.run.experiment_dir is None: + assert not self.training.checkpoint.enabled() + @classmethod def get_trainer_class(cls) -> type["Trainer"]: raise NotImplementedError diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index b4aaa862f..a17eca121 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -1,10 +1,6 @@ import
abc -import json import logging import math -import os -import shlex -import subprocess import time import typing @@ -23,6 +19,7 @@ from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule from fast_llm.engine.training.config import TrainerConfig +from fast_llm.engine.training.wandb import Wandb from fast_llm.logging import format_metrics, get_memory_usage_mib, log_memory_usage from fast_llm.utils import Assert @@ -36,7 +33,9 @@ class Trainer(abc.ABC): _is_setup: bool = False _distributed: Distributed _run: Run + _wandb: Wandb _optimizer: Optimizer + _completed_steps: int def __init__(self, config: TrainerConfig): @@ -57,8 +56,9 @@ def __init__(self, config: TrainerConfig): ) steps_per_split = { PhaseType.training: self._config.training.train_iters, - PhaseType.validation: (self._config.training.train_iters // self._config.training.validation_interval + 1) - * self._config.training.validation_iters, + PhaseType.validation: self._config.training.validation.get_completed_iterations( + self._config.training.train_iters, 1 + ), PhaseType.test: self._config.training.test_iters, } self._samples_per_split = { @@ -84,6 +84,7 @@ def setup(self, distributed: Distributed, run: Run): self._is_setup = True self._distributed = distributed self._run = run + self._wandb = Wandb(self._config.training.wandb, self._run, self._config) # Setup the model. 
with torch.no_grad(): @@ -129,14 +130,14 @@ def _consumed_tokens(self): @property def _completed_validation_steps(self) -> int: # Number of validation steps performed before the current step - return ( - (self._completed_steps - 1) - // self._config.training.validation_interval - * self._config.training.validation_iters - ) + return self._config.training.validation.get_completed_iterations(self._completed_steps - 1) def run(self): assert self._is_setup + with self._wandb: + self._run_training() + + def _run_training(self): self._prepare_training_state() log_main_rank("done with setup ...") log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"After initial setup", str)) @@ -162,9 +163,9 @@ def run(self): ) formatted_metrics = format_metrics(metrics[PhaseType.test], self._loss_defs, PhaseType.test) log_main_rank(formatted_metrics) - self._run.post_wandb_alert("Testing results", formatted_metrics, "WARN") + self._wandb.alert("Testing results", formatted_metrics, "WARN") # TODO: This may erase some metrics. - self._run.log_wandb_metrics(self._completed_steps, metrics) + self._wandb.log_metrics(self._completed_steps, metrics) def _train(self): # Tracking loss. @@ -178,9 +179,6 @@ def _train(self): distributed_config=self._config.distributed, start_step=self._completed_steps ) - # The triton compilation during the first iteration breaks parallel data loading - # https://github.com/ServiceNow/Fast-LLM/issues/101, - # so we run the first iteration without it. train_iterator = self._get_data_iterator( PhaseType.training, self._completed_steps, @@ -202,10 +200,7 @@ def _train(self): while not stop: # Iteration starts at 1, so we increment at the beginning. self._completed_steps += 1 - is_logging = ( - self._config.run.log_interval - and (self._completed_steps - self._config.run.log_offset) % self._config.run.log_interval == 0 - ) + is_logging = self._config.training.logs.enabled(self._completed_steps) # TODO: Data loader hates getting all micro-batches at once. 
# (Also preprocessing adds overhead) @@ -275,13 +270,8 @@ def _train(self): metrics[PhaseType.training], self._loss_defs, PhaseType.training ) logger.info(formatted_metrics) - if ( - self._config.run.wandb_status_interval - and (self._completed_steps - self._config.run.log_offset) - % self._config.run.wandb_status_interval - == 0 - ): - self._run.post_wandb_alert("Training results", formatted_metrics, "INFO") + if self._config.training.wandb.alert.enabled(self._completed_steps): + self._wandb.alert("Training results", formatted_metrics, "INFO") advanced_iters = 0 skipped_iters = 0 @@ -294,18 +284,11 @@ def _train(self): done = self._completed_steps >= self._config.training.train_iters # TODO: Signal-based stop. - stop = done or ( - self._config.run.stop_interval - and (self._completed_steps - self._config.run.stop_offset) % self._config.run.stop_interval == 0 - ) + stop = done or self._config.training.shutdown.enabled(self._completed_steps) # Evaluation # TODO: Adjust valid iterator length. 
if PhaseType.validation in self._samples_per_split and ( - done - or ( - self._config.training.validation_interval - and self._completed_steps % self._config.training.validation_interval == 0 - ) + done or self._config.training.validation.enabled(self._completed_steps) ): if valid_iterator is None: valid_iterator = self._get_data_iterator( @@ -314,42 +297,23 @@ def _train(self): metrics[PhaseType.validation] = self._evaluate( data_iterator=valid_iterator, phase=PhaseType.validation, - num_iters=self._config.training.validation_iters, + num_iters=self._config.training.validation.iterations, begin_iter=self._completed_validation_steps, ) formatted_metrics = format_metrics( metrics[PhaseType.validation], self._loss_defs, PhaseType.validation ) log_main_rank(formatted_metrics) - if ( - self._config.run.wandb_status_interval - and (self._completed_steps - self._config.run.log_offset) - % self._config.run.wandb_status_interval - == 0 - ): - self._run.post_wandb_alert("Validation results", formatted_metrics, "INFO") + if self._config.training.wandb.alert.enabled(self._completed_steps): + self._wandb.alert("Validation results", formatted_metrics, "INFO") if is_main_rank() and metrics: - self._run.log_wandb_metrics(self._completed_steps, metrics) - - if self._config.run.checkpoint_interval and ( - stop - or ( - self._config.run.checkpoint_interval - and (self._completed_steps - self._config.run.checkpoint_offset) - % self._config.run.checkpoint_interval - == 0 - ) - ): + self._wandb.log_metrics(self._completed_steps, metrics) + + if self._config.training.checkpoint.enabled(None if stop else self._completed_steps): self._save_checkpoint( metrics, - export=self._config.run.export_interval - and ( - done - or (self._completed_steps - self._config.run.checkpoint_offset) - % self._config.run.export_interval - == 0 - ), + export=self._config.training.export.enabled(None if done else self._completed_steps), ) return done, metrics @@ -401,12 +365,10 @@ def _evaluate( def 
_prepare_training_state(self): # Setup the training state. if (last_iteration := self._run.get_last_checkpoint()) is None: - if ( - path := self._config.pretrained.pretrained_checkpoint_path - ) is not None and self._config.pretrained.load_pretrained_weights: + if (path := self._config.pretrained.path) is not None and self._config.pretrained.load_weights: log_main_rank( f"Initializing training state from pretrained checkpoint at {path}" - f" ({'loading' if self._config.pretrained.load_pretrained_optimizer else 'resetting'}" + f" ({'loading' if self._config.pretrained.load_optimizer else 'resetting'}" f" optimizer state)..." ) self._multi_stage.load_pretrained_checkpoint(self._config.pretrained) @@ -440,7 +402,9 @@ def _get_data_iterator(self, phase, completed_steps: int = 0, prefetch_factor: i def _save_checkpoint(self, metrics: dict[PhaseType, dict[str, float | int]] | None, export: bool = False): assert self._is_setup - with self._run.get_save_checkpoint_context(self._completed_steps, export) as checkpoint: + with self._run.get_save_checkpoint_context( + self._completed_steps, export, self._config.training.checkpoint.keep + ) as checkpoint: metadata = { "optimizer": self._optimizer.save(), "completed_steps": self._completed_steps, @@ -451,11 +415,8 @@ def _save_checkpoint(self, metrics: dict[PhaseType, dict[str, float | int]] | No CheckpointConfig(checkpoint_type=CheckpointType.distributed, checkpoint_path=checkpoint.directory), metadata, ) - if export and self._run.is_main_rank and self._config.training.export_callback_script: # noqa - custom_env = os.environ.copy() - if self._config.training.export_callback_env: - custom_env.update(json.loads(self._config.training.export_callback_env)) - subprocess.Popen(shlex.split(self._config.training.export_callback_script), env=custom_env) + if export and self._run.is_main_rank: # noqa + self._config.training.export.callback.run() @abc.abstractmethod def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> 
tuple[int, int]: diff --git a/fast_llm/engine/training/wandb.py b/fast_llm/engine/training/wandb.py new file mode 100644 index 000000000..d00c0e811 --- /dev/null +++ b/fast_llm/engine/training/wandb.py @@ -0,0 +1,73 @@ +import os +import pathlib + +import yaml + +from fast_llm.config import Config +from fast_llm.engine.config_utils.run import Run +from fast_llm.engine.training.config import WandbConfig + + +class Wandb: + def __init__(self, config: WandbConfig, run: Run, experiment_config: Config): + self._config = config + self._is_setup = True + self._run = run + if self._config.entity_name is not None and self._run.is_main_rank: + import wandb.sdk.lib.runid + + # Wandb login from file + api_key_path = os.environ.get("WANDB_API_KEY_PATH") + if api_key_path: + os.environ["WANDB_API_KEY"] = pathlib.Path(api_key_path).open("r").read().strip() + wandb_path = ( + None + if self._run.experiment_directory is None + else self._run.experiment_directory / "wandb_config.yaml" + ) + if wandb_path is not None and wandb_path.is_file(): + wandb_config = yaml.safe_load(wandb_path.open("r")) + else: + wandb_config = { + "id": wandb.sdk.lib.runid.generate_id(16), + "project": self._config.project_name, + "name": self._run.experiment_name, + "entity": self._config.entity_name, + "group": self._config.group_name, + "save_code": False, + "resume": "allow", + } + if wandb_path is not None: + yaml.safe_dump(wandb_config, wandb_path.open("w")) + # TODO: Does wandb work with nested configs? 
+ self._wandb = wandb.init(config=experiment_config.to_serialized(), **wandb_config) + else: + self._wandb = None + + def log_metrics(self, completed_steps: int, metrics: dict[str, dict[str, float | int]]): + # Note: metrics modified in-place + if self._wandb is not None: + import wandb + + wandb.log(metrics, step=completed_steps) # noqa + + def alert(self, title, text, level="INFO", wait=0.001): + if self._wandb is None or not self._config.alert.post_alerts: + return + + self._wandb.alert( # noqa + title=title() if callable(title) else title, + text=f"[{self._config.project_name}/{self._run.experiment_name}, run {self._run.index}]" + f" {text() if callable(text) else text}", + level=level, + wait_duration=wait, + ) + + def __enter__(self): + self.alert(f"Run started!", "", "ERROR") + + def __exit__(self, exc_type, exc_val: BaseException, exc_tb): + if exc_val: + self.alert(f"Run crashed!", (lambda: ", ".join(exc_val.args)), "ERROR") + else: + self.alert(f"Run ended!", "", "INFO") diff --git a/fast_llm/functional/triton/adam.py b/fast_llm/functional/triton/adam.py index 560a19f7d..ab91d45b8 100644 --- a/fast_llm/functional/triton/adam.py +++ b/fast_llm/functional/triton/adam.py @@ -5,11 +5,11 @@ """ import torch -import triton from torch.optim.adamw import adamw # noqa -from triton import language as tl +import triton from fast_llm.functional.config import TritonConfig +from triton import language as tl @triton.jit diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 7867bf122..f84bdb7ac 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -1,7 +1,7 @@ import torch + import triton import triton.language as tl - from fast_llm.functional.config import TritonConfig diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index 91098ded1..db8188d77 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ 
-1,9 +1,8 @@ import math import torch -import triton -from triton import language as tl +import triton from fast_llm.core.distributed import ProcessGroup from fast_llm.core.ops import gather_op from fast_llm.functional.autograd import wrap_forward_backward @@ -24,6 +23,7 @@ ) from fast_llm.functional.triton.sparse_linear import output_sparse_matmul from fast_llm.tensor import param_get_and_unset_is_zero +from triton import language as tl @triton.jit diff --git a/fast_llm/functional/triton/normalization.py b/fast_llm/functional/triton/normalization.py index 28421a759..431d7d437 100644 --- a/fast_llm/functional/triton/normalization.py +++ b/fast_llm/functional/triton/normalization.py @@ -1,7 +1,7 @@ import torch + import triton import triton.language as tl - from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.config import TritonConfig from fast_llm.tensor import param_get_and_unset_is_zero diff --git a/fast_llm/functional/triton/pointwise.py b/fast_llm/functional/triton/pointwise.py index e72d496a2..bd88326bc 100644 --- a/fast_llm/functional/triton/pointwise.py +++ b/fast_llm/functional/triton/pointwise.py @@ -4,11 +4,11 @@ """ import torch -import triton -from triton import language as tl +import triton from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import TritonConfig +from triton import language as tl @triton.jit diff --git a/fast_llm/functional/triton/rotary.py b/fast_llm/functional/triton/rotary.py index 616ae9142..a25f29bb2 100644 --- a/fast_llm/functional/triton/rotary.py +++ b/fast_llm/functional/triton/rotary.py @@ -1,10 +1,10 @@ import torch -import triton -from triton import language as tl +import triton from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.config import TritonConfig from fast_llm.utils import div +from triton import language as tl @triton.jit diff --git a/fast_llm/functional/triton/sparse_copy.py 
b/fast_llm/functional/triton/sparse_copy.py index 04df2f3eb..2583ca413 100644 --- a/fast_llm/functional/triton/sparse_copy.py +++ b/fast_llm/functional/triton/sparse_copy.py @@ -1,9 +1,9 @@ import dataclasses import torch + import triton import triton.language as tl - from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.config import MAX_DROPLESS_BLOCK_SIZE_ROW, TritonConfig diff --git a/fast_llm/functional/triton/sparse_linear.py b/fast_llm/functional/triton/sparse_linear.py index 7b724bcf5..b148193a6 100644 --- a/fast_llm/functional/triton/sparse_linear.py +++ b/fast_llm/functional/triton/sparse_linear.py @@ -1,7 +1,7 @@ import torch + import triton import triton.language as tl - from fast_llm.functional.triton.sparse_copy import SparseMap from fast_llm.utils import Assert, div diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 8fbdf78da..74684dfe8 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -36,31 +36,43 @@ class NormalizationArchitectureConfig(BaseModelArchitectureConfig): _abstract = False # TODO: Remove "normalization" from names once we have fully nested configs? 
# Normalization type - normalization_type: NormalizationType = Field( + type: NormalizationType = Field( default=NormalizationType.layer_norm, desc="The type of normalization to use, for example Layer Norm or RMS Norm.", hint=FieldHint.core, ) # TODO: Rename to normalization_epsilon - layer_norm_eps: float = Field( + epsilon: float = Field( default=1e-5, desc="Regularizer for the division.", hint=FieldHint.stability, valid=check_field(Assert.gt, 0) ) - zero_centered_normalization: bool = Field( + zero_centered: bool = Field( default=False, desc="Write the normalization weight as `w = 1 + w'`, to improve numerical accuracy when close to one.", hint=FieldHint.stability, ) + @classmethod + def _from_dict( + cls, + default: dict[str], + strict: bool = True, + flat: bool = False, + ): + cls._handle_renamed_field(default, "normalization_type", "type") + cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") + cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") + return super()._from_dict(default, strict, flat) + @config_class() class NormalizationConfig(NormalizationArchitectureConfig, BaseModelConfig): - normalization_implementation: NormalizationImplementation = Field( + implementation: NormalizationImplementation = Field( default=NormalizationImplementation.auto, desc="The implementation to use for the normalization layer.", hint=FieldHint.performance, ) # TODO: Rename to normalization_init_range - layer_norm_init_range: float = Field( + initialization_range: float = Field( default=0.0, desc="Randomize the initialization with a uniform noise. 
Used to test for issues that may not be visible with the default initialization.", hint=FieldHint.testing, @@ -73,20 +85,31 @@ def get_layer(self, hidden_dim: "TensorDim"): kwargs = { "hidden_dim": hidden_dim, - "eps": self.layer_norm_eps, - "implementation": self.normalization_implementation, - "zero_centered": self.zero_centered_normalization, + "eps": self.epsilon, + "implementation": self.implementation, + "zero_centered": self.zero_centered, } - if self.layer_norm_init_range: - mean = 0 if self.zero_centered_normalization else 1 + if self.initialization_range: + mean = 0 if self.zero_centered else 1 kwargs["weight_init_method"] = init_uniform_( - mean - self.layer_norm_init_range, mean + self.layer_norm_init_range + mean - self.initialization_range, mean + self.initialization_range ) - if self.normalization_type == NormalizationType.layer_norm: - if self.layer_norm_init_range: - kwargs["bias_init_method"] = init_uniform_(-self.layer_norm_init_range, self.layer_norm_init_range) + if self.type == NormalizationType.layer_norm: + if self.initialization_range: + kwargs["bias_init_method"] = init_uniform_(-self.initialization_range, self.initialization_range) return LayerNorm(**kwargs) - elif self.normalization_type == NormalizationType.rms_norm: + elif self.type == NormalizationType.rms_norm: return RMSNorm(**kwargs) else: - raise ValueError(self.normalization_type) + raise ValueError(self.type) + + @classmethod + def _from_dict( + cls, + default: dict[str], + strict: bool = True, + flat: bool = False, + ): + cls._handle_renamed_field(default, "normalization_implementation", "implementation") + cls._handle_renamed_field(default, "layer_norm_init_range", "initialization_range") + return super()._from_dict(default, strict, flat) diff --git a/fast_llm/logging.py b/fast_llm/logging.py index 50be8dae9..329f4ada8 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -129,7 +129,7 @@ def log_tensor( ): if level < 1: return - save_stats = TensorLogs.enabled() + 
save_stats = TensorLogs.config.save shape = tuple(tensor.shape) _, dtype = str(tensor.dtype).split("torch.") txt = [ @@ -189,7 +189,7 @@ def log_tensor( samples = tensor.flatten()[: target_samples * step : step].cpu() stats.update(samples=samples, step=step) # Crop the list in the logs. The full tensor is still in stats. - samples = [format_number(x) for x in samples.tolist()[: TensorLogs.max_logged_elements]] + samples = [format_number(x) for x in samples.tolist()[: TensorLogs.config.max_elements]] num_logged_elements = len(samples) samples = ",".join(f"{sample:10s}" for sample in samples) txt.append((f"{f'samples (step={step})':21s}", f" ({samples})", num_logged_elements * 11 + 3)) @@ -204,7 +204,7 @@ def log_tensor( prefix = "" if prefix is None else f" {prefix}=" len_ += col_len + len(prefix) + 1 out = f"{f'{out}{prefix}{str(val)}':{len_}s}" - if TensorLogs.verbose: + if TensorLogs.config.show: return log_fn(out) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index d362a9bbf..d33d0daae 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -126,7 +126,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: def _create_weight_converters(self) -> list[WeightConverter]: converters = [] num_layers = self.config.transformer.num_layers - norm_bias: bool = self.config.transformer.normalization.normalization_type == NormalizationType.layer_norm + norm_bias: bool = self.config.transformer.normalization.type == NormalizationType.layer_norm linear_bias: bool = self.config.transformer.add_linear_biases # Embedding and output @@ -209,10 +209,8 @@ class Starcoder2HuggingfaceConverter(CommonHuggingfaceConverter): def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantExportParamConverter(None, "architectures", ["Starcoder2ForCausalLM"]), - ConstantImportParamConverter( - ("transformer", "normalization", "normalization_type"), None, 
NormalizationType.layer_norm - ), - ParamConverter(("transformer", "normalization", "layer_norm_eps"), "norm_epsilon"), + ConstantImportParamConverter(("transformer", "normalization", "type"), None, NormalizationType.layer_norm), + ParamConverter(("transformer", "normalization", "epsilon"), "norm_epsilon"), ConstantImportParamConverter(("transformer", "gated"), None, False), ConstantImportParamConverter(("transformer", "add_linear_biases"), None, True), ] @@ -233,10 +231,8 @@ class CommonLlamaHuggingfaceConverter(CommonHuggingfaceConverter, abc.ABC): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ - ConstantImportParamConverter( - ("transformer", "normalization", "normalization_type"), None, NormalizationType.rms_norm - ), - ParamConverter(("transformer", "normalization", "layer_norm_eps"), "rms_norm_eps"), + ConstantImportParamConverter(("transformer", "normalization", "type"), None, NormalizationType.rms_norm), + ParamConverter(("transformer", "normalization", "epsilon"), "rms_norm_eps"), ConstantImportParamConverter(("transformer", "gated"), None, True), ConstantImportParamConverter(("transformer", "add_linear_biases"), None, False), ] diff --git a/fast_llm/profile.py b/fast_llm/profile.py index ad0f3b482..88e1f85c8 100644 --- a/fast_llm/profile.py +++ b/fast_llm/profile.py @@ -27,63 +27,59 @@ def step(self): @config_class() class ProfilingConfig(Config): - profile_cpu: bool = Field( - default=False, desc="Profile the CUDA operations on the CPU side.", hint=FieldHint.feature - ) - profile_cuda: bool = Field(default=False, desc="Profile the CUDA operations on the CPU side.", hint=FieldHint.core) - profile_skip: int = Field( + cpu: bool = Field(default=False, desc="Profile the CUDA operations on the CPU side.", hint=FieldHint.feature) + cuda: bool = Field(default=False, desc="Profile the CUDA operations on the CPU side.", hint=FieldHint.core) + skip: int = Field( default=1, desc="Skip this 
many iterations before starting the profiler for the first time.", hint=FieldHint.optional, valid=check_field(Assert.geq, 0), ) # Skip on every cycle (profiler disabled) - profile_wait: int = Field( + wait: int = Field( default=0, desc="Wait this many iterations before each profiling cycle.", hint=FieldHint.optional, valid=check_field(Assert.geq, 0), ) # Warmup on every cycle (profiler enabled, results ignored) - profile_warmup: int = Field( + warmup: int = Field( default=1, desc="Warmup the profiler for this many iterations before each profiling cycle, i.e., enable the profiler but discard the results.", hint=FieldHint.optional, valid=check_field(Assert.geq, 0), ) # Profile on every cycle (profiler enabled, results kept) - profile_cycles: int = Field( + cycles: int = Field( default=1, desc="Profile this many iterations in each profiling cycle.", hint=FieldHint.optional, valid=check_field(Assert.gt, 0), ) - profile_averages: bool = Field( + averages: bool = Field( default=False, desc="Log a table of average and total properties for each CUDA operation.", hint=FieldHint.logging, ) - profile_trace: bool = Field( + trace: bool = Field( default=False, desc="Log a table of every CUDA operation in chronological order.", hint=FieldHint.logging ) - profile_column_width: int = Field( + table_width: int = Field( default=80, desc="Target width for logged tables, in characters.", hint=FieldHint.logging, valid=check_field(Assert.geq, 40), ) # The ranks to profile (all by default) - profile_ranks: set[int] = Field( - default_factory=set, desc="Profile only on the specified ranks.", hint=FieldHint.feature - ) + ranks: set[int] = Field(default_factory=set, desc="Profile only on the specified ranks.", hint=FieldHint.feature) # Print the profile table(s), otherwise save to file. 
- profile_log: bool = Field( + log: bool = Field( default=False, desc="Log the profile tables to stdout, otherwise save them as artifacts.", hint=FieldHint.logging, ) # Export for chrome/tensorboard - profile_export: bool = Field( + export: bool = Field( default=False, desc="Export the raw profile as an artifact in chrome trace format.", doc="The profile is saved to profile_chrome_step_{step}. " @@ -92,35 +88,33 @@ class ProfilingConfig(Config): ) def _validate(self): - if isinstance(self.profile_ranks, str): + if isinstance(self.ranks, str): # This happens with yaml serialization - Assert.eq(self.profile_ranks, "set()") + Assert.eq(self.ranks, "set()") self.global_attention_layers = set() - profile_ranks = set(self.profile_ranks or []) - Assert.eq(len(profile_ranks), len(self.profile_ranks or [])) - self.profile_ranks = profile_ranks # noqa + profile_ranks = set(self.ranks or []) + Assert.eq(len(profile_ranks), len(self.ranks or [])) + self.ranks = profile_ranks # noqa def get_profiler( self, *, distributed_config: DistributedConfig | None = None, start_step: int = 0 ) -> typing.Union["torch.profiler.profile", NoProfiler]: import torch - activities = ([torch.profiler.ProfilerActivity.CPU] if self.profile_cpu else []) + ( - [torch.profiler.ProfilerActivity.CUDA] if self.profile_cuda else [] + activities = ([torch.profiler.ProfilerActivity.CPU] if self.cpu else []) + ( + [torch.profiler.ProfilerActivity.CUDA] if self.cuda else [] ) if ( not activities - or not (self.profile_averages or self.profile_trace or self.profile_export) - or not ( - distributed_config is None or not self.profile_ranks or distributed_config.rank in self.profile_ranks - ) + or not (self.averages or self.trace or self.export) + or not (distributed_config is None or not self.ranks or distributed_config.rank in self.ranks) ): return NoProfiler() schedule = torch.profiler.schedule( - skip_first=self.profile_skip, - warmup=self.profile_warmup, - wait=self.profile_wait, - active=self.profile_cycles, 
+ skip_first=self.skip, + warmup=self.warmup, + wait=self.wait, + active=self.cycles, ) return torch.profiler.profile( schedule=schedule, @@ -140,34 +134,34 @@ def trace_fn( try: step = start_step + profiler.step_num - f"self_{'cuda' if config.profile_cuda else 'cpu'}_time_total" - if config.profile_trace: + f"self_{'cuda' if config.cuda else 'cpu'}_time_total" + if config.trace: table = build_trace_table( profiler, - cuda=config.profile_cuda, - cpu=config.profile_cpu, - column_width=config.profile_column_width, + cuda=config.cuda, + cpu=config.cpu, + column_width=config.table_width, header=f"Trace for step {step}", ) - if config.profile_log: + if config.log: logger.info(table) else: run.open_artifact(f"profile_trace_step_{step}").write(table) - if config.profile_averages: + if config.averages: table = build_average_table( profiler, - cuda=config.profile_cuda, - cpu=config.profile_cpu, - column_width=config.profile_column_width, + cuda=config.cuda, + cpu=config.cpu, + column_width=config.table_width, header=f"Averages for step {step}", ) - if config.profile_log: + if config.log: logger.info(table) else: run.open_artifact(f"profile_averages_step_{step}").write(table) - if config.profile_export: + if config.export: profiler.export_chrome_trace(str(run.open_artifact(f"profile_chrome_step_{step}", mode=None))) # Store results for future use. 
diff --git a/fast_llm/tools/convert.py b/fast_llm/tools/convert.py index 17f75142d..d2305ade7 100644 --- a/fast_llm/tools/convert.py +++ b/fast_llm/tools/convert.py @@ -63,9 +63,9 @@ def _convert_model_partial( logger.info(f"Loading {self.input_type} checkpoint from {self.input_path}...") model = model_class.from_pretrained( PretrainedCheckpointConfig( - pretrained_checkpoint_path=self.input_path, - pretrained_checkpoint_type=self.input_type, - imported_model_type=self.model_type, + path=self.input_path, + format=self.input_type, + imported_type=self.model_type, ), mode=StageMode.weights, use_cpu=self.use_cpu, @@ -111,9 +111,9 @@ def run(self, model_config_class: type["FastLLMModelConfig"] | str): # Create a dummy version to determine the stage split. model = model_class.from_pretrained( PretrainedCheckpointConfig( - pretrained_checkpoint_path=self.input_path, - pretrained_checkpoint_type=self.input_type, - imported_model_type=self.model_type, - load_pretrained_weights=False, + path=self.input_path, + format=self.input_type, + imported_type=self.model_type, + load_weights=False, ), mode=StageMode.off_device, diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 8b72864c0..4d61dffdc 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -200,3 +200,11 @@ def __contains__(self, item): class LazyRegistry(Registry): def __getitem__(self, key): return super().__getitem__(key)() + + +def log(*message, log_fn: typing.Union[BaseException, typing.Callable] = logger.info, join: str = ", "): + message = join.join([str(m() if callable(m) else m) for m in message]) + if isinstance(log_fn, BaseException): + raise log_fn(message) + else: + return log_fn(message) diff --git a/tests/common.py b/tests/common.py index dee9d1143..127dfb731 100644 --- a/tests/common.py +++ b/tests/common.py @@ -33,9 +33,9 @@ CONFIG_BASE_FAST_LLM = [ - "run.log_interval=1", - "run.save_tensor_logs=True", - "run.show_tensor_logs=False", + "training.logs.interval=1", + "run.tensor_logs.save=True", + 
"run.tensor_logs.show=False", "model.base_model.transformer.num_layers=2", "model.base_model.transformer.hidden_size=1024", "model.base_model.transformer.num_attention_heads=8", @@ -51,8 +51,8 @@ "training.num_workers=4", "batch.batch_size=8", "batch.sequence_length=2048", - f"data.data_path={DATASET_PREFIX}", - "optimizer.lr_schedule.lr=0.0001", + f"data.path={DATASET_PREFIX}", + "optimizer.learning_rate.base=0.0001", ] CONFIG_BASE_MEGATRON = [ "--num-layers=2", @@ -116,7 +116,7 @@ "model.base_model.transformer.gated=True", "model.base_model.transformer.activation_type=silu", "model.base_model.transformer.add_linear_biases=False", - "model.base_model.transformer.normalization.normalization_type=rms_norm", + "model.base_model.transformer.normalization.type=rms_norm", "model.base_model.transformer.ffn_hidden_size=4096", "model.base_model.tie_word_embeddings=False", ] diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 098698331..50ba105fa 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -34,7 +34,8 @@ def test_checkpoint_and_eval(): # A baseline config (single-gpu, bf16, flash-attn). 
run_test_script( f"test_{TEST_MODEL}_checkpoint_and_eval", - CONFIG_COMMON + ["run.checkpoint_interval=1", "training.validation_interval=2", "training.validation_iters=1"], + CONFIG_COMMON + + ["training.checkpoint.interval=1", "training.validation.interval=2", "training.validation.iterations=1"], ) @@ -62,7 +63,8 @@ def _compare_resume_fn(test_path: pathlib.Path, compare_path: pathlib.Path): def test_resume(): run_test_script( f"test_{TEST_MODEL}_resume", - CONFIG_COMMON + ["run.checkpoint_interval=1", "training.validation_interval=2", "training.validation_iters=1"], + CONFIG_COMMON + + ["training.checkpoint.interval=1", "training.validation.interval=2", "training.validation.iterations=1"], compare=f"test_{TEST_MODEL}_checkpoint_and_eval", prepare_fn=_prepare_resume_fn, compare_fn=_compare_resume_fn, @@ -205,9 +207,9 @@ def test_load_pretrained_distributed_checkpoint(): yaml.safe_load((_CKPT_PATH / ".." / ".." / "config.yaml").open("r")), strict=False ) pretrained_config_ref = PretrainedCheckpointConfig( - pretrained_checkpoint_path=_CKPT_PATH, - pretrained_checkpoint_type=CheckpointType.distributed, - load_pretrained_optimizer=True, + path=_CKPT_PATH, + format=CheckpointType.distributed, + load_optimizer=True, load_full_base_model_config=True, load_full_fast_llm_config=True, ) @@ -221,16 +223,14 @@ def test_load_pretrained_distributed_checkpoint(): @pytest.mark.depends(on=["test_load_pretrained_distributed_checkpoint"]) def test_load_converted_distributed_checkpoint(): - pretrained_config_ref = PretrainedCheckpointConfig( - pretrained_checkpoint_path=_CKPT_PATH, pretrained_checkpoint_type=CheckpointType.distributed - ) + pretrained_config_ref = PretrainedCheckpointConfig(path=_CKPT_PATH, format=CheckpointType.distributed) pretrained_config_0 = PretrainedCheckpointConfig( - pretrained_checkpoint_path=_CONVERT_PATH / "distributed_0", - pretrained_checkpoint_type=CheckpointType.distributed, + path=_CONVERT_PATH / "distributed_0", + format=CheckpointType.distributed, 
) pretrained_config_1 = PretrainedCheckpointConfig( - pretrained_checkpoint_path=_CONVERT_PATH / "distributed_1", - pretrained_checkpoint_type=CheckpointType.distributed, + path=_CONVERT_PATH / "distributed_1", + format=CheckpointType.distributed, ) config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_0) @@ -245,14 +245,12 @@ def test_load_converted_distributed_checkpoint(): @pytest.mark.depends(on=["test_converted_state_dict", "test_load_pretrained_distributed_checkpoint"]) def test_load_converted_state_dict_checkpoint(): - pretrained_config_ref = PretrainedCheckpointConfig( - pretrained_checkpoint_path=_CKPT_PATH, pretrained_checkpoint_type=CheckpointType.distributed - ) + pretrained_config_ref = PretrainedCheckpointConfig(path=_CKPT_PATH, format=CheckpointType.distributed) pretrained_config_0 = PretrainedCheckpointConfig( - pretrained_checkpoint_path=_CONVERT_PATH / "state_dict_0", pretrained_checkpoint_type=CheckpointType.state_dict + path=_CONVERT_PATH / "state_dict_0", format=CheckpointType.state_dict ) pretrained_config_1 = PretrainedCheckpointConfig( - pretrained_checkpoint_path=_CONVERT_PATH / "state_dict_1", pretrained_checkpoint_type=CheckpointType.state_dict + path=_CONVERT_PATH / "state_dict_1", format=CheckpointType.state_dict ) config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_0) @@ -268,16 +266,16 @@ def test_load_converted_state_dict_checkpoint(): @pytest.mark.depends(on=["test_converted_state_dict", "test_load_pretrained_distributed_checkpoint"]) def test_load_converted_huggingface_checkpoint(): pretrained_config_ref = PretrainedCheckpointConfig( - pretrained_checkpoint_path=_CKPT_PATH, - pretrained_checkpoint_type=CheckpointType.distributed, + path=_CKPT_PATH, + format=CheckpointType.distributed, ) pretrained_config_0 = PretrainedCheckpointConfig( - pretrained_checkpoint_path=_CONVERT_PATH / 
"huggingface_0", - pretrained_checkpoint_type=CheckpointType.huggingface, + path=_CONVERT_PATH / "huggingface_0", + format=CheckpointType.huggingface, ) pretrained_config_1 = PretrainedCheckpointConfig( - pretrained_checkpoint_path=_CONVERT_PATH / "huggingface_1", - pretrained_checkpoint_type=CheckpointType.huggingface, + path=_CONVERT_PATH / "huggingface_1", + format=CheckpointType.huggingface, ) config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_0, mode=StageMode.weights) @@ -294,8 +292,8 @@ def test_load_converted_huggingface_checkpoint(): def test_run_converted_model(): model_ref = TEST_MODEL_HF_CLS.from_pretrained( PretrainedCheckpointConfig( - pretrained_checkpoint_path=_CKPT_PATH, - pretrained_checkpoint_type=CheckpointType.distributed, + path=_CKPT_PATH, + format=CheckpointType.distributed, ) ) test_input = torch.randint( @@ -305,8 +303,8 @@ def test_run_converted_model(): model_from_state_dict = TEST_MODEL_HF_CLS.from_pretrained(_CONVERT_PATH / "state_dict_0") model_from_hf = TEST_MODEL_HF_CLS.from_pretrained( PretrainedCheckpointConfig( - pretrained_checkpoint_path=_CONVERT_PATH / "huggingface_0", - pretrained_checkpoint_type=CheckpointType.huggingface, + path=_CONVERT_PATH / "huggingface_0", + format=CheckpointType.huggingface, ) ) errors = [] @@ -340,9 +338,9 @@ def test_load_pretrained_distributed_in_dp2(): f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2", CONFIG_COMMON + [ - "run.checkpoint_interval=1", + "training.checkpoint.interval=1", "training.train_iters=1", - f"pretrained.pretrained_checkpoint_path={_CONVERT_PATH / 'distributed_0'}", + f"pretrained.path={_CONVERT_PATH / 'distributed_0'}", "schedule.skip_step=True", ], num_gpus=2, @@ -355,9 +353,9 @@ def test_load_pretrained_distributed_with_config(): f"test_{TEST_MODEL}_load_pretrained_distributed_with_config", CONFIG_COMMON + [ - "run.checkpoint_interval=1", + "training.checkpoint.interval=1", 
"training.train_iters=1", - f"pretrained.pretrained_checkpoint_path={_CONVERT_PATH / 'distributed_0'}", + f"pretrained.path={_CONVERT_PATH / 'distributed_0'}", "schedule.skip_step=True", ], ) @@ -367,13 +365,13 @@ def test_load_pretrained_distributed_with_config(): def test_load_pretrained_in_dp2_match_checkpoint(): test_ckpt_path = TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoints" / "1" pretrained_config_ref = PretrainedCheckpointConfig( - pretrained_checkpoint_path=_CKPT_PATH, - pretrained_checkpoint_type=CheckpointType.distributed, + path=_CKPT_PATH, + format=CheckpointType.distributed, load_full_fast_llm_config=True, ) pretrained_config_test = PretrainedCheckpointConfig( - pretrained_checkpoint_path=test_ckpt_path, - pretrained_checkpoint_type=CheckpointType.distributed, + path=test_ckpt_path, + format=CheckpointType.distributed, load_full_fast_llm_config=True, ) config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) @@ -409,17 +407,14 @@ def test_load_pretrained_in_dp2_match_checkpoint(): def test_load_distributed_checkpoint_dp2(): # This also tests conversion which uses `FastLLMModel.from_checkpoint` pretrained_config_ref = PretrainedCheckpointConfig( - pretrained_checkpoint_path=_CKPT_PATH, - pretrained_checkpoint_type=CheckpointType.distributed, + path=_CKPT_PATH, + format=CheckpointType.distributed, load_full_base_model_config=True, load_full_fast_llm_config=True, ) pretrained_config_test = PretrainedCheckpointConfig( - pretrained_checkpoint_path=TEST_RESULTS_PATH - / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" - / "checkpoints" - / "1", - pretrained_checkpoint_type=CheckpointType.distributed, + path=TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoints" / "1", + format=CheckpointType.distributed, ) config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_test, 
mode=StageMode.weights) @@ -436,10 +431,10 @@ def test_load_pretrained_state_dict_in_dp2(): f"test_{TEST_MODEL}_load_pretrained_state_dict_in_dp2", CONFIG_COMMON + [ - "run.checkpoint_interval=1", + "training.checkpoint.interval=1", "training.train_iters=1", - f"pretrained.pretrained_checkpoint_path={_CONVERT_PATH / 'state_dict_0'}", - f"pretrained.pretrained_checkpoint_type=state_dict", + f"pretrained.path={_CONVERT_PATH / 'state_dict_0'}", + f"pretrained.format=state_dict", "schedule.skip_step=True", ], num_gpus=2, @@ -468,10 +463,10 @@ def test_load_pretrained_huggingface_in_dp2(): f"test_{TEST_MODEL}_load_pretrained_huggingface_in_dp2", CONFIG_COMMON + [ - "run.checkpoint_interval=1", + "training.checkpoint.interval=1", "training.train_iters=1", - f"pretrained.pretrained_checkpoint_path={_CONVERT_PATH / 'huggingface_0'}", - f"pretrained.pretrained_checkpoint_type=huggingface", + f"pretrained.path={_CONVERT_PATH / 'huggingface_0'}", + f"pretrained.format=huggingface", "schedule.skip_step=True", ], num_gpus=2,