diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 81ec1d544e..b01a874549 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -438,6 +438,11 @@ hf_eval_split: '' hf_eval_files: '' hf_access_token: '' # for Grain input pipeline (dataset_type=grain) +# Path to grain data files. Can be a single pattern or multiple patterns with weights. +# For multiple patterns, use semicolon (;) to separate and colon (:) to specify weights. +# Example: "path/to/data1.array_record*:0.3;path/to/data2.array_record*:0.7" +# Note: When using multiple files (separated by ';'), only ArrayRecord format is supported. +# For more details, see https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#grain-input-pipeline grain_train_files: '' grain_eval_files: '' grain_file_type: 'arrayrecord' # arrayrecord or parquet diff --git a/MaxText/input_pipeline/_grain_data_processing.py b/MaxText/input_pipeline/_grain_data_processing.py index dd055ed988..47b8e0539c 100644 --- a/MaxText/input_pipeline/_grain_data_processing.py +++ b/MaxText/input_pipeline/_grain_data_processing.py @@ -33,6 +33,13 @@ from MaxText import tokenizer +def find_data_files(data_file_pattern): + data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve())) # expand ~ and make the pattern absolute before globbing + assert len(data_files) > 0, f"No file found with pattern {data_file_pattern}." + max_logging.log(f"Found {len(data_files)} files for train/eval with grain") + return data_files + + def get_datasets( data_file_pattern, data_file_type, @@ -44,17 +51,26 @@ def get_datasets( grain_worker_count, ): """Load dataset from array_record files for using with grain""" - data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve())) - assert len(data_files) > 0, f"No file found with pattern {data_file_pattern}." 
- max_logging.log(f"Found {len(data_files)} files for train/eval with grain") if data_file_type == "arrayrecord": - dataset = grain.MapDataset.source(grain.ArrayRecordDataSource(data_files)) + if ";" in data_file_pattern: + pairs = [pattern.rsplit(":", 1) for pattern in data_file_pattern.split(";")] # last ':' delimits the weight + assert all(len(pair) == 2 for pair in pairs), "Each data file pattern must be followed by ':<weight>'" + data_file_patterns, weights = zip(*[(pattern, float(weight)) for pattern, weight in pairs]) + weights = [round(weight / sum(weights), 4) for weight in weights] + dataset_list = [ + grain.MapDataset.source(grain.ArrayRecordDataSource(find_data_files(pattern))) for pattern in data_file_patterns + ] + dataset = grain.MapDataset.mix(dataset_list, weights) + else: + data_files = find_data_files(data_file_pattern) + dataset = grain.MapDataset.source(grain.ArrayRecordDataSource(data_files)) if shuffle: dataset = dataset.shuffle(seed=shuffle_seed) dataset = dataset.repeat(num_epoch) dataset = dataset[dataloading_host_index::dataloading_host_count] # sharding dataset = dataset.to_iter_dataset() elif data_file_type == "parquet": + data_files = find_data_files(data_file_pattern) dataset = grain.MapDataset.source(data_files) if shuffle: dataset = dataset.shuffle(seed=shuffle_seed) diff --git a/MaxText/input_pipeline/_input_pipeline_utils.py b/MaxText/input_pipeline/_input_pipeline_utils.py index d4e7ecb58c..7e08541a5c 100644 --- a/MaxText/input_pipeline/_input_pipeline_utils.py +++ b/MaxText/input_pipeline/_input_pipeline_utils.py @@ -332,7 +332,7 @@ def map(self, element): @dataclasses.dataclass class Rekey(grain.MapTransform): - """Rname keys according to a mappign dict""" + """Rename keys according to a mapping dict""" def __init__(self, mapping_dict, keep_old_keys=False): self.mapping_dict = mapping_dict diff --git a/MaxText/tests/grain_data_processing_test.py b/MaxText/tests/grain_data_processing_test.py index 262c4d8413..87f3f8088c 100644 --- a/MaxText/tests/grain_data_processing_test.py +++ 
b/MaxText/tests/grain_data_processing_test.py @@ -105,6 +105,42 @@ def get_first_batch(iterator): self.assertTrue((train_batch1["targets"] == train_batch2["targets"]).all()) +class GrainArrayRecordProcessingWithMultiSourceBlendingTest(GrainArrayRecordProcessingTest): + + def setUp(self): + super().setUp() + temp_dir = tempfile.gettempdir() + # We use the same dataset for testing, but you can use different datasets by changing the file patterns. + grain_train_files = [ + f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.3", + f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.7", + ] + grain_train_files = ";".join(grain_train_files) # "pattern:weight" entries separated by ';' + self.config = pyconfig.initialize( + [sys.argv[0], os.path.join(PKG_DIR, "configs", "base.yml")], + per_device_batch_size=1, + run_name="test", + mesh_axes=["data"], + logical_axis_rules=[["batch", "data"]], + data_sharding=["data"], + base_output_directory="gs://max-experiments/", + dataset_type="grain", + grain_train_files=grain_train_files, + tokenizer_path=os.path.join(os.path.dirname(PKG_DIR), "assets", "tokenizer"), + enable_checkpointing=False, + ) + self.mesh_shape_1d = (len(jax.devices()),) + self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) + self.process_indices = input_pipeline_interface.get_process_loading_real_data( + self.config.data_sharding, + self.config.global_batch_size_to_load, + self.config.global_batch_size_to_train_on, + self.config.max_target_length, + self.mesh, + ) + self.train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) + + class GrainParquetProcessingTest(unittest.TestCase): @classmethod diff --git a/getting_started/Data_Input_Pipeline.md b/getting_started/Data_Input_Pipeline.md index fc4c13c674..a9140c4e4d 100644 --- a/getting_started/Data_Input_Pipeline.md +++ b/getting_started/Data_Input_Pipeline.md @@ -102,7 +102,18 @@ bash setup_gcsfuse.sh 
DATASET_GCS_BUCKET=$BUCKET_NAME MOUNT_PATH=$MOUNT_PATH [FI ``` 3. Set `dataset_type=grain` and set `grain_train_files` to match the ArrayRecord files via a local path since the bucket has been mounted. 4. Tune `grain_worker_count` for performance. This parameter controls the number of child process used by Grain (more details in [behind_the_scene](https://github.com/google/grain/blob/main/docs/behind_the_scenes.md), [code](https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py)). If you use a large number of workers, please check your config for gcsfuse in [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/setup_gcsfuse.sh) to avoid gcsfuse throttling. -5. Example command: + +5. For multi-source blending, you can specify multiple data sources with their respective weights using semicolon (;) as separator and colon (:) for weights. The weights will be automatically normalized to sum to 1.0. For example: +``` +# Blend two data sources with 30% from the first source and 70% from the second source +grain_train_files=/tmp/gcsfuse/dataset1.array_record*:0.3;/tmp/gcsfuse/dataset2.array_record*:0.7 + +# Blend three data sources with equal weights (will be normalized to 0.33 each) +grain_train_files=/tmp/gcsfuse/dataset1.array_record*:1;/tmp/gcsfuse/dataset2.array_record*:1;/tmp/gcsfuse/dataset3.array_record*:1 +``` +Note: When using multiple data sources, only ArrayRecord format is supported. If you pass `grain_train_files` on the command line, quote the value so the shell does not treat `;` as a command separator. + +6. Example command: ``` bash setup_gcsfuse.sh \ DATASET_GCS_BUCKET=maxtext-dataset \ @@ -114,7 +125,7 @@ grain_file_type=arrayrecord \ grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record* \ grain_worker_count=2 ``` -6. Using validation set for eval +7. Using validation set for eval When setting eval_interval > 0, eval will be run with a specified eval dataset. Example config: ``` eval_interval: 10000