26 changes: 16 additions & 10 deletions fast_llm/config.py
@@ -5,11 +5,11 @@
import traceback
import types
import typing
import warnings

import yaml

from fast_llm.engine.config_utils.logging import log
from fast_llm.utils import Assert, Tag, get_type_name, header
from fast_llm.utils import Assert, Tag, get_type_name, header, log

logger = logging.getLogger(__name__)

@@ -270,7 +270,7 @@ class Config:
__class_validated__: typing.ClassVar[bool] = True
_abstract: typing.ClassVar[bool] = False
_validated: bool = Field(init=False, repr=False)
_unknown_fields: tuple = Field(init=False, repr=False)
_unknown_fields: dict[str] = Field(init=False, repr=False)

def __post_init__(self):
"""
@@ -335,7 +335,7 @@ def _validate(self):
value = getattr(self, name)
new_value = self._validate_nested(value, field.type, field.name, field.valid, errors, False)
setattr(self, name, new_value)
for name in getattr(self, "_unknown_fields", ()):
for name in getattr(self, "_unknown_fields", {}):
errors.append(f"Unknown field `{name}` in class {self._get_class_name()}")
if errors:
# TODO: Option to show traceback for errors.
@@ -621,7 +621,7 @@ def _get_class_name(cls):
@classmethod
def from_dict(
cls,
default: typing.Union["Config", dict],
default: typing.Union["Config", dict[str]],
*updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]],
strict: bool = True,
):
@@ -646,7 +646,7 @@ def from_dict(
@classmethod
def from_flat_dict(
cls,
default: dict,
default: dict[str],
strict: bool = True,
):
# TODO v0.2: Remove flat format
@@ -655,7 +655,7 @@ def _from_dict(
@classmethod
def _from_dict(
cls,
default: dict,
default: dict[str],
strict: bool = True,
flat: bool = False,
):
@@ -691,9 +691,9 @@ def _from_dict(
f"Invalid field type `{get_type_name(field.type)}` in class {cls._get_class_name()}: "
+ ", ".join(e.args)
)
out = cls(**out_arg_dict) # noqa
if strict and default:
out._unknown_fields = tuple(default)
out = cls(**out_arg_dict) # noqa
if strict and default:
out._unknown_fields = default.copy()
if _AUTO_VALIDATE:
out.validate()
return out
@@ -767,6 +767,12 @@ def _from_dict_dict(cls, value, type_, strict: bool):
# Keys can't include configs so we only recurse on values.
return {key: cls._from_dict_nested(value_, args[1], strict) for key, value_ in value.items()}

@classmethod
def _handle_renamed_field(cls, default: dict[str], old_name: str, new_name: str):
if old_name in default:
warnings.warn(f"Field `{old_name}` is deprecated in class {get_type_name(cls)}, use `{new_name}` instead.")
default[new_name] = default.pop(old_name)

def compare(self, other: "Config", log_fn: typing.Union[BaseException, typing.Callable] = ValueError):
# TODO: Check classes?
self_dict = self._to_dict(format_=_ConfigDictFormat.tuple, serializable=True)
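The two behavioural changes in this file are `_unknown_fields` becoming a dict copy of the leftover keys (rather than a tuple of names) and the new `_handle_renamed_field` helper for deprecating field names. A self-contained sketch of the intended usage pattern; the `handle_renamed_field` function and the sample dict are illustrative stand-ins, not code from this PR:

```python
# Illustrative sketch only (names and values are made up): the helper warns
# about a deprecated key and moves its value under the new key, so it gets
# parsed normally instead of ending up in `_unknown_fields`.
import warnings


def handle_renamed_field(default: dict, old_name: str, new_name: str) -> None:
    # Mirrors the classmethod added to fast_llm.config.Config.
    if old_name in default:
        warnings.warn(f"Field `{old_name}` is deprecated, use `{new_name}` instead.")
        default[new_name] = default.pop(old_name)


raw = {"fim_rate": 0.5, "typo_field": 1}
handle_renamed_field(raw, "fim_rate", "rate")
print(raw)  # {'rate': 0.5, 'typo_field': 1}

# Keeping `_unknown_fields` as a dict copy of whatever `_from_dict` did not
# consume means the validation error can name the offending keys while their
# values stay around for debugging.
unknown = {key: value for key, value in raw.items() if key != "rate"}
print(unknown)  # {'typo_field': 1}
```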
37 changes: 12 additions & 25 deletions fast_llm/data/config.py
@@ -11,14 +11,6 @@
from fast_llm.engine.distributed.distributed import Distributed


class DatasetType(str, enum.Enum):
"""
Placeholder for future generalization to other data types.
"""

gpt = "gpt"


class DatasetSource(str, enum.Enum):
"""
An enum for the different ways to load datasets.
@@ -61,49 +53,49 @@ class FimConfig(Config):
Configuration for FIM.
"""

fim_rate: float = Field(
rate: float = Field(
default=0.0,
desc="FIM rate for each sample.",
hint=FieldHint.core,
valid=check_field(Assert.in_range_incl, 0, 1),
)
fim_max_middle_len: int | None = Field(
max_middle_len: int | None = Field(
default=None,
desc="Maximum length of the middle segment in FIM.",
hint=FieldHint.feature,
valid=skip_valid_if_none(check_field(Assert.gt, 0)),
)
fim_split_sample: str | None = Field(
split_sample: str | None = Field(
default=None,
desc="Split samples on this token and permute each fragment separately.",
hint=FieldHint.feature,
)
fim_fragment_rate: float = Field(
fragment_rate: float = Field(
default=0.0,
desc="FIM rate for each fragment when using fim_split_sample.",
hint=FieldHint.feature,
valid=check_field(Assert.in_range_incl, 0, 1),
)
fim_ignore_prefix: str | None = Field(
ignore_prefix: str | None = Field(
default=None,
desc="Do not apply FIM to fragments that start with this prefix.",
hint=FieldHint.feature,
)
fim_spm_rate: float = Field(
spm_rate: float = Field(
default=0.5,
desc="TODO.",
hint=FieldHint.feature,
valid=check_field(Assert.in_range_incl, 0, 1),
)
fim_truncate_or_pad: bool = Field(
truncate_or_pad: bool = Field(
default=False,
desc="TODO.",
hint=FieldHint.feature,
)

def _validate(self):
super()._validate()
Assert.in_range_incl(self.fim_rate, 0, 1)
Assert.in_range_incl(self.rate, 0, 1)


EOD = "<|endoftext|>"
@@ -117,13 +109,13 @@ class TokenizerConfig(Config):
Currently, the tokenizer is only needed for FIM.
"""

tokenizer_type: str = Field(
format: str = Field(
default="TokenizerFromFile",
desc="Unused.",
hint=FieldHint.deprecated,
valid=check_field(Assert.eq, TokenizerFromFile),
)
tokenizer_file: str | None = Field(
path: str | None = Field(
default=None,
desc="Path to the tokenizer file.",
hint=FieldHint.core,
@@ -181,17 +173,12 @@ class DataConfig(AbstractDataConfig):
hint=FieldHint.core,
valid=_validate_split,
)
dataset_type: DatasetType = Field(
default=DatasetType.gpt,
desc="Unused.",
hint=FieldHint.wip,
)
dataset_source: DatasetSource = Field(
format: DatasetSource = Field(
default=DatasetSource.list,
desc="Format for the dataset definition.",
hint=FieldHint.core,
)
data_path: list[str] = Field(
path: list[str] = Field(
default_factory=list,
desc="Path or list of paths and weights.",
hint=FieldHint.core,
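This file drops the unused `DatasetType` enum and strips the `fim_`, `tokenizer_`, `data_`, and `dataset_` prefixes from field names, so a config dict now uses the shorter keys. A hypothetical before/after snippet (paths and values are invented; only the key renames reflect the PR):

```python
# Hypothetical example values; only the key renames mirror this PR.
old_style = {
    "dataset_source": "list",
    "data_path": ["1", "/datasets/web", "1", "/datasets/code"],
    "fim": {"fim_rate": 0.5, "fim_spm_rate": 0.5},
    "tokenizer": {"tokenizer_file": "/datasets/tokenizer.json"},
}

new_style = {
    "format": "list",                                       # was `dataset_source`
    "path": ["1", "/datasets/web", "1", "/datasets/code"],  # was `data_path`
    "fim": {"rate": 0.5, "spm_rate": 0.5},                  # was `fim_rate`, `fim_spm_rate`
    "tokenizer": {"path": "/datasets/tokenizer.json"},      # was `tokenizer_file`
}
```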
37 changes: 18 additions & 19 deletions fast_llm/data/data.py
@@ -9,7 +9,7 @@
import torch
import torch.utils.data

from fast_llm.data.config import AbstractData, DataConfig, DatasetSource, DatasetType
from fast_llm.data.config import AbstractData, DataConfig, DatasetSource
from fast_llm.data.dataset import BlendedDataset, SampledDataset, Sampler
from fast_llm.data.gpt import DummyGPTDataset, GPTDataset, GPTSampledDataset
from fast_llm.data.mmap import MMapIndexedDataset
@@ -67,35 +67,34 @@ def __init__(
}

data_base_path = None
if self._config.dataset_source == DatasetSource.file:
Assert.eq(len(self._config.data_path), 1)
data_path = pathlib.Path(self._config.data_path[0])
if self._config.format == DatasetSource.file:
Assert.eq(len(self._config.path), 1)
data_path = pathlib.Path(self._config.path[0])
dataset_defs = json.load(data_path.open("r"))
data_base_path = data_path.parent
dataset_prefixes = [dataset_def["prefix"] for dataset_def in dataset_defs["datasets"]]
dataset_weights = normalize_probs([dataset_def["weight"] for dataset_def in dataset_defs["datasets"]])
self._build_and_sample_dataset = self._build_and_sample_gpt_dataset
elif self._config.dataset_source == DatasetSource.list:
Assert.geq(len(self._config.data_path), 1)
if len(self._config.data_path) == 1:
dataset_prefixes, dataset_weights = [self._config.data_path[0].strip()], [1.0]
elif self._config.format == DatasetSource.list:
Assert.geq(len(self._config.path), 1)
if len(self._config.path) == 1:
dataset_prefixes, dataset_weights = [self._config.path[0].strip()], [1.0]
else:
Assert.eq(self._config.dataset_type, DatasetType.gpt)
Assert.custom(lambda x: x % 2 == 0, len(self._config.data_path))
dataset_prefixes = [x.strip() for x in self._config.data_path[1::2]]
Assert.custom(lambda x: x % 2 == 0, len(self._config.path))
dataset_prefixes = [x.strip() for x in self._config.path[1::2]]
assert len(dataset_prefixes) == len(set(dataset_prefixes))
dataset_weights = normalize_probs([float(x) for x in self._config.data_path[::2]])
dataset_weights = normalize_probs([float(x) for x in self._config.path[::2]])
self._build_and_sample_dataset = self._build_and_sample_gpt_dataset
elif self._config.dataset_source == DatasetSource.sample:
Assert.eq(len(self._config.data_path), 1)
dataset_prefixes, dataset_weights = [self._config.data_path[0].strip()], [1.0]
elif self._config.format == DatasetSource.sample:
Assert.eq(len(self._config.path), 1)
dataset_prefixes, dataset_weights = [self._config.path[0].strip()], [1.0]
self._build_and_sample_dataset = self._build_and_sample_dummy_dataset
elif self._config.dataset_source == DatasetSource.random:
Assert.eq(len(self._config.data_path), 0)
elif self._config.format == DatasetSource.random:
Assert.eq(len(self._config.path), 0)
dataset_prefixes, dataset_weights = [None], [1.0]
self._build_and_sample_dataset = self._build_and_sample_dummy_dataset
else:
raise NotImplementedError(self._config.dataset_source)
raise NotImplementedError(self._config.format)

dataset_names = [
f"dataset_{i}_{'dummy' if prefix is None else prefix.replace('/','__')}"
@@ -124,7 +123,7 @@ def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int
run = get_run()
Assert.leq(set(samples_per_phase), set(self._phase_split))
log_main_rank(f"Preparing {self._num_datasets} datasets. This may take several minutes.")
self._tokenizer = Tokenizer(self._config.tokenizer) if self._config.fim.fim_rate > 0 else None
self._tokenizer = Tokenizer(self._config.tokenizer) if self._config.fim.rate > 0 else None
self._distributed = distributed
self._cache_dir = run.dataset_cache_dir
self._samples_per_phase = samples_per_phase
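The handling of the renamed `format`/`path` pair is easiest to see in isolation. A standalone sketch of the `list` branch above (the function names here are illustrative, not the module's API):

```python
# Standalone sketch of the `list` format parsing shown above: `path` holds
# alternating weight/prefix entries, normalized into sampling probabilities.
def normalize_probs(weights: list[float]) -> list[float]:
    total = sum(weights)
    return [weight / total for weight in weights]


def parse_list_format(path: list[str]) -> tuple[list[str], list[float]]:
    if len(path) == 1:
        return [path[0].strip()], [1.0]
    assert len(path) % 2 == 0, "expected alternating weight/prefix entries"
    prefixes = [x.strip() for x in path[1::2]]
    weights = normalize_probs([float(x) for x in path[::2]])
    return prefixes, weights


print(parse_list_format(["3", "/datasets/web", "1", "/datasets/code"]))
# (['/datasets/web', '/datasets/code'], [0.75, 0.25])
```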
22 changes: 11 additions & 11 deletions fast_llm/data/fim.py
@@ -17,7 +17,7 @@ def __init__(self, config: FimConfig, tokenizer: Tokenizer):
self._tokenizer.vocab[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD]
)
self.fim_split_sample = (
self._tokenizer.vocab[self._config.fim_split_sample] if self._config.fim_split_sample is not None else None
self._tokenizer.vocab[self._config.split_sample] if self._config.split_sample is not None else None
)

def __call__(self, sample, np_rng):
@@ -61,15 +61,15 @@ def _fim_split_and_permute_sequence(self, sequence, np_rng):
fragment_fim_rate: if set, apply fim with this rate to each fragment.
"""
if self.fim_split_sample is None:
return self._fim_permute_sequence(sequence, np_rng, self._config.fim_rate)
return self._fim_permute_sequence(sequence, np_rng, self._config.rate)
# fim_split_sample is set: split the sample on this token and permute each fragment separately.
# Typically, if each sample is a repository, then we split again on the file level.
# Each fragment is a file, and we permute the files.
fragment_breaks = np.argwhere(sequence == self.fim_split_sample)
if fragment_breaks.shape == (0, 1):
# no split token in this sample
return self._fim_permute_sequence(sequence, np_rng, self._config.fim_rate)
if not np_rng.binomial(1, self._config.fim_rate):
return self._fim_permute_sequence(sequence, np_rng, self._config.rate)
if not np_rng.binomial(1, self._config.rate):
# don't do FIM preproc
return sequence
# Do FIM on each fragment
@@ -79,12 +79,12 @@ def _fim_split_and_permute_sequence(self, sequence, np_rng):
for loc in np.nditer(fragment_breaks):
if loc - curr_start_position > 0:
permuted = self._fim_permute_sequence(
sequence[curr_start_position:loc], np_rng, self._config.fim_fragment_rate
sequence[curr_start_position:loc], np_rng, self._config.fragment_rate
)
new_samples += [permuted, [self.fim_split_sample]]
curr_start_position = loc + 1 # Jump over the split token
# Permute the segment after the last split token
permuted = self._fim_permute_sequence(sequence[curr_start_position:], np_rng, self._config.fim_fragment_rate)
permuted = self._fim_permute_sequence(sequence[curr_start_position:], np_rng, self._config.fragment_rate)
new_samples.append(permuted)
return np.concatenate(new_samples)

@@ -106,10 +106,10 @@ def _fim_permute_sequence(
contents = self._tokenizer.detokenize(sequence)

# Do not apply FIM if the sample starts with no_fim_prefix
if self._config.fim_ignore_prefix is not None and contents.startswith(self._config.fim_ignore_prefix):
if self._config.ignore_prefix is not None and contents.startswith(self._config.ignore_prefix):
return sequence

if self._config.fim_max_middle_len is None:
if self._config.max_middle_len is None:
# Sample the two boundaries uniformly at random
# A boundary can be =0 (prefix will be empty)
# a boundary can be =len(contents) (suffix will be empty)
@@ -118,7 +118,7 @@ def _fim_permute_sequence(
boundaries.sort()
else:
# Sample a window-length
middle_length = np_rng.randint(low=0, high=min(self._config.fim_max_middle_len, len(contents)) + 1)
middle_length = np_rng.randint(low=0, high=min(self._config.max_middle_len, len(contents)) + 1)
first_boundary = np_rng.randint(low=0, high=len(contents) - middle_length + 1)
# middle_length <= Second-boundary <= len(contents)
boundaries = [first_boundary, first_boundary + middle_length]
@@ -134,7 +134,7 @@ def _fim_permute_sequence(
# here we truncate each given segment to fit the same length as it was before
# A consequence is that we never reach the end of a file?
# we should rather truncate at the context-level
if self._config.fim_truncate_or_pad:
if self._config.truncate_or_pad:
# need to make same length as the input. Take the 3 sentinel tokens into account
new_length = suffix.shape[0] + prefix.shape[0] + middle.shape[0] + 3
diff = new_length - sequence.shape[0]
@@ -145,7 +145,7 @@ def _fim_permute_sequence(
elif diff < 0: # too short
suffix = np.concatenate([suffix, np.full((-1 * diff), self._pad_tok_id)])

if np_rng.binomial(1, self._config.fim_spm_rate):
if np_rng.binomial(1, self._config.spm_rate):
# SPM (variant 2 from FIM paper)
new_sample = np.concatenate(
[[self._prefix_tok_id, self._suffix_tok_id], suffix, [self._middle_tok_id], prefix, middle] # noqa
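For reference, the `spm_rate` draw above chooses between the two orderings from the FIM paper; the SPM concatenation visible in this hunk can be reproduced in a few lines (sentinel ids and token values are invented, and the PSM branch is outside the shown diff):

```python
# Toy reproduction of the SPM ordering from _fim_permute_sequence above.
import numpy as np

PREFIX_TOK, MIDDLE_TOK, SUFFIX_TOK = 1, 2, 3  # made-up sentinel token ids
prefix = np.array([10, 11])
middle = np.array([12, 13])
suffix = np.array([14])

# SPM (variant 2 in the FIM paper): sentinel pair, then suffix, middle
# sentinel, prefix, and middle, exactly as concatenated in the hunk above.
spm = np.concatenate([[PREFIX_TOK, SUFFIX_TOK], suffix, [MIDDLE_TOK], prefix, middle])
print(spm)  # [ 1  3 14  2 10 11 12 13]
```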
2 changes: 1 addition & 1 deletion fast_llm/data/gpt.py
@@ -154,7 +154,7 @@
):
self._dataset = dataset

if config.fim.fim_rate > 0:
if config.fim.rate > 0:
assert tokenizer is not None
self._fim = Fim(config.fim, tokenizer)
else:
4 changes: 2 additions & 2 deletions fast_llm/data/tokenizer.py
@@ -10,9 +10,9 @@ class Tokenizer:
"""

def __init__(self, config: TokenizerConfig):
log_main_rank(f"> loading tokenizer from {config.tokenizer_file} ...")
log_main_rank(f"> loading tokenizer from {config.path} ...")
special_tokens = [EOD]
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=config.tokenizer_file, errors="replace", max_len=None)
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=config.path, errors="replace", max_len=None)
self.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
self.eod_id = self.tokenizer.vocab[EOD]
# Token->id mapping for additional special-tokens
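After the rename, the tokenizer file comes from `config.path`. A minimal sketch of what the constructor does with it, using the Hugging Face fast tokenizer directly (the file path is a placeholder):

```python
# Minimal sketch mirroring Tokenizer.__init__ above; the path is a placeholder.
from transformers import PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="/datasets/tokenizer.json", errors="replace", max_len=None
)
tokenizer.add_special_tokens({"additional_special_tokens": ["<|endoftext|>"]})
eod_id = tokenizer.vocab["<|endoftext|>"]
```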
1 change: 1 addition & 0 deletions fast_llm/engine/config_utils/data_type.py
@@ -6,6 +6,7 @@
if typing.TYPE_CHECKING:
import numpy as np
import torch

from triton import language as tl

