From 01d411036857e685d9d9bc387574d54dd7437e40 Mon Sep 17 00:00:00 2001 From: bzantium Date: Wed, 4 Jun 2025 08:30:20 +0900 Subject: [PATCH 1/7] implement multi-source blending for arrayrecord --- .../input_pipeline/_grain_data_processing.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/MaxText/input_pipeline/_grain_data_processing.py b/MaxText/input_pipeline/_grain_data_processing.py index dd055ed988..d5ccac93c7 100644 --- a/MaxText/input_pipeline/_grain_data_processing.py +++ b/MaxText/input_pipeline/_grain_data_processing.py @@ -33,6 +33,12 @@ from MaxText import tokenizer +def find_data_files(data_file_pattern): + data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve())) + assert len(data_files) > 0, f"No file found with pattern {data_file_pattern}." + return data_files + + def get_datasets( data_file_pattern, data_file_type, @@ -44,17 +50,23 @@ def get_datasets( grain_worker_count, ): """Load dataset from array_record files for using with grain""" - data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve())) - assert len(data_files) > 0, f"No file found with pattern {data_file_pattern}." 
- max_logging.log(f"Found {len(data_files)} files for train/eval with grain") if data_file_type == "arrayrecord": - dataset = grain.MapDataset.source(grain.ArrayRecordDataSource(data_files)) + if ";" in data_file_pattern: + data_file_patterns, weights = zip(*[pattern.split(":") for pattern in data_file_pattern.split(";")]) + assert len(data_file_patterns) == len(weights), "Number of data file patterns and weights must match" + weights = [round(weight / sum(weights), 4) for weight in weights] + dataset_list = [grain.MapDataset.source(grain.ArrayRecordDataSource(find_data_files(pattern))) for pattern in data_file_patterns] + dataset = grain.MapDataset.mix(dataset_list, weights) + else: + data_files = find_data_files(data_file_pattern) + dataset = grain.MapDataset.source(grain.ArrayRecordDataSource(data_files)) if shuffle: dataset = dataset.shuffle(seed=shuffle_seed) dataset = dataset.repeat(num_epoch) dataset = dataset[dataloading_host_index::dataloading_host_count] # sharding dataset = dataset.to_iter_dataset() elif data_file_type == "parquet": + data_files = find_data_files(data_file_pattern) dataset = grain.MapDataset.source(data_files) if shuffle: dataset = dataset.shuffle(seed=shuffle_seed) From 88ae1e5d8a4a800a6d0158e69f1d82c8879b6486 Mon Sep 17 00:00:00 2001 From: bzantium Date: Wed, 4 Jun 2025 08:30:31 +0900 Subject: [PATCH 2/7] fix typo --- MaxText/input_pipeline/_input_pipeline_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MaxText/input_pipeline/_input_pipeline_utils.py b/MaxText/input_pipeline/_input_pipeline_utils.py index d4e7ecb58c..7e08541a5c 100644 --- a/MaxText/input_pipeline/_input_pipeline_utils.py +++ b/MaxText/input_pipeline/_input_pipeline_utils.py @@ -332,7 +332,7 @@ def map(self, element): @dataclasses.dataclass class Rekey(grain.MapTransform): - """Rname keys according to a mappign dict""" + """Rename keys according to a mapping dict""" def __init__(self, mapping_dict, keep_old_keys=False): self.mapping_dict = 
mapping_dict From 26852384f4e5aaa1a03b6a3f3a34f6ac25878eb4 Mon Sep 17 00:00:00 2001 From: bzantium Date: Thu, 12 Jun 2025 08:46:44 +0900 Subject: [PATCH 3/7] add testing and documentation --- MaxText/configs/base.yml | 5 +++ MaxText/tests/grain_data_processing_test.py | 37 +++++++++++++++++++++ getting_started/Data_Input_Pipeline.md | 15 +++++++-- 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 81ec1d544e..b01a874549 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -438,6 +438,11 @@ hf_eval_split: '' hf_eval_files: '' hf_access_token: '' # for Grain input pipeline (dataset_type=grain) +# Path to grain data files. Can be a single pattern or multiple patterns with weights. +# For multiple patterns, use semicolon (;) to separate and colon (:) to specify weights. +# Example: "path/to/data1.array_record*:0.3;path/to/data2.array_record*:0.7" +# Note: When using multiple files (separated by ';'), only ArrayRecord format is supported. +# For more details, see https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#grain-input-pipeline grain_train_files: '' grain_eval_files: '' grain_file_type: 'arrayrecord' # arrayrecord or parquet diff --git a/MaxText/tests/grain_data_processing_test.py b/MaxText/tests/grain_data_processing_test.py index 262c4d8413..bbb4cb9b97 100644 --- a/MaxText/tests/grain_data_processing_test.py +++ b/MaxText/tests/grain_data_processing_test.py @@ -105,6 +105,43 @@ def get_first_batch(iterator): self.assertTrue((train_batch1["targets"] == train_batch2["targets"]).all()) +class GrainArrayRecordProcessingTestWithMultiSourceBlending(GrainArrayRecordProcessingTest): + def setUp(self): + super().setUp() + temp_dir = tempfile.gettempdir() + # We use the same dataset for testing, but you can use different datasets by changing the file patterns. 
+ grain_train_files = [ + f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.3", + f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.7", + ] + grain_train_files = ";".join(grain_train_files) + self.config = pyconfig.initialize( + [sys.argv[0], os.path.join(PKG_DIR, "configs", "base.yml")], + per_device_batch_size=1, + run_name="test", + mesh_axes=["data"], + logical_axis_rules=[["batch", "data"]], + data_sharding=["data"], + base_output_directory="gs://max-experiments/", + dataset_type="grain", + grain_train_files=grain_train_files, + tokenizer_path=os.path.join(os.path.dirname(PKG_DIR), "assets", "tokenizer"), + enable_checkpointing=False, + ) + self.mesh_shape_1d = (len(jax.devices()),) + self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) + self.process_indices = input_pipeline_interface.get_process_loading_real_data( + self.config.data_sharding, + self.config.global_batch_size_to_load, + self.config.global_batch_size_to_train_on, + self.config.max_target_length, + self.mesh, + ) + self.train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) + + + + class GrainParquetProcessingTest(unittest.TestCase): @classmethod diff --git a/getting_started/Data_Input_Pipeline.md b/getting_started/Data_Input_Pipeline.md index fc4c13c674..a9140c4e4d 100644 --- a/getting_started/Data_Input_Pipeline.md +++ b/getting_started/Data_Input_Pipeline.md @@ -102,7 +102,18 @@ bash setup_gcsfuse.sh DATASET_GCS_BUCKET=$BUCKET_NAME MOUNT_PATH=$MOUNT_PATH [FI ``` 3. Set `dataset_type=grain` and set `grain_train_files` to match the ArrayRecord files via a local path since the bucket has been mounted. 4. Tune `grain_worker_count` for performance. 
This parameter controls the number of child process used by Grain (more details in [behind_the_scene](https://github.com/google/grain/blob/main/docs/behind_the_scenes.md), [code](https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py)). If you use a large number of workers, please check your config for gcsfuse in [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/setup_gcsfuse.sh) to avoid gcsfuse throttling. -5. Example command: + +5. For multi-source blending, you can specify multiple data sources with their respective weights, using a semicolon (;) to separate the sources and a colon (:) between each file pattern and its weight. The weights are automatically normalized to sum to 1.0. For example: +``` +# Blend two data sources with 30% from the first source and 70% from the second source +grain_train_files=/tmp/gcsfuse/dataset1.array_record*:0.3;/tmp/gcsfuse/dataset2.array_record*:0.7 + +# Blend three data sources with equal weights (each will be normalized to about 0.33) +grain_train_files=/tmp/gcsfuse/dataset1.array_record*:1;/tmp/gcsfuse/dataset2.array_record*:1;/tmp/gcsfuse/dataset3.array_record*:1 +``` +Note: When using multiple data sources, only the ArrayRecord format is supported. + +6. Example command: ``` bash setup_gcsfuse.sh \ DATASET_GCS_BUCKET=maxtext-dataset \ @@ -114,7 +125,7 @@ grain_file_type=arrayrecord \ grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record* \ grain_worker_count=2 ``` -6. Using validation set for eval +7. Using validation set for eval When setting eval_interval > 0, eval will be run with a specified eval dataset. 
Example config: ``` eval_interval: 10000 From 2b694b988130259e8354e734df4c899aba4782af Mon Sep 17 00:00:00 2001 From: bzantium Date: Fri, 13 Jun 2025 09:12:58 +0900 Subject: [PATCH 4/7] rename test name --- MaxText/tests/grain_data_processing_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MaxText/tests/grain_data_processing_test.py b/MaxText/tests/grain_data_processing_test.py index bbb4cb9b97..0be94cbe12 100644 --- a/MaxText/tests/grain_data_processing_test.py +++ b/MaxText/tests/grain_data_processing_test.py @@ -105,7 +105,7 @@ def get_first_batch(iterator): self.assertTrue((train_batch1["targets"] == train_batch2["targets"]).all()) -class GrainArrayRecordProcessingTestWithMultiSourceBlending(GrainArrayRecordProcessingTest): +class GrainArrayRecordProcessingWithMultiSourceBlendingTest(GrainArrayRecordProcessingTest): def setUp(self): super().setUp() temp_dir = tempfile.gettempdir() From 4354f38a15b2a9e4c9da34df0aa47cb50d9c2009 Mon Sep 17 00:00:00 2001 From: bzantium Date: Sat, 14 Jun 2025 10:25:54 +0900 Subject: [PATCH 5/7] make weights float --- MaxText/input_pipeline/_grain_data_processing.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/MaxText/input_pipeline/_grain_data_processing.py b/MaxText/input_pipeline/_grain_data_processing.py index d5ccac93c7..106cb71200 100644 --- a/MaxText/input_pipeline/_grain_data_processing.py +++ b/MaxText/input_pipeline/_grain_data_processing.py @@ -54,8 +54,11 @@ def get_datasets( if ";" in data_file_pattern: data_file_patterns, weights = zip(*[pattern.split(":") for pattern in data_file_pattern.split(";")]) assert len(data_file_patterns) == len(weights), "Number of data file patterns and weights must match" + weights = [float(weight) for weight in weights] weights = [round(weight / sum(weights), 4) for weight in weights] - dataset_list = [grain.MapDataset.source(grain.ArrayRecordDataSource(find_data_files(pattern))) for pattern in data_file_patterns] + dataset_list = [ + 
grain.MapDataset.source(grain.ArrayRecordDataSource(find_data_files(pattern))) for pattern in data_file_patterns + ] dataset = grain.MapDataset.mix(dataset_list, weights) else: data_files = find_data_files(data_file_pattern) From 2011c8f8e6ed8d5423b95a4e08eb07ea4da50d1c Mon Sep 17 00:00:00 2001 From: bzantium Date: Sat, 14 Jun 2025 10:26:50 +0900 Subject: [PATCH 6/7] apply style --- MaxText/layers/attentions.py | 16 ++++++++-------- MaxText/layers/models.py | 12 ++++-------- MaxText/pyconfig.py | 8 ++++---- MaxText/tests/grain_data_processing_test.py | 7 +++---- MaxText/tests/integration_tests/train_tests.py | 1 + MaxText/tests/train_compile_test.py | 6 +++--- 6 files changed, 23 insertions(+), 27 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 721fb7d738..e487589752 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -874,13 +874,12 @@ def cudnn_jax_flash_attention( decoder_segment_ids: Array | None, model_mode: str = MODEL_MODE_TRAIN, ) -> Array: - """CUDNN Flash Attention with JAX SDPA API. - """ + """CUDNN Flash Attention with JAX SDPA API.""" # These imports are only meant to work in a GPU build. 
# pylint: disable=import-outside-toplevel from jax._src.cudnn.fused_attention_stablehlo import ( - dot_product_attention, - MaskType, + dot_product_attention, + MaskType, ) _, _, _, head_dim = query.shape # pylint: disable=unused-variable @@ -898,7 +897,7 @@ def cudnn_jax_flash_attention( scale=1.0, dropout_rate=self.dropout_rate, qkv_layout="BTNH", - return_residual=True + return_residual=True, ) else: return dot_product_attention( @@ -909,7 +908,7 @@ def cudnn_jax_flash_attention( scale=1.0 / math.sqrt(head_dim), dropout_rate=self.dropout_rate, qkv_layout="BTNH", - return_residual=True + return_residual=True, ) def compute_local_attention( @@ -1124,8 +1123,9 @@ def normalize_cudnn_attention(self, local_outs, local_stats): stat1 = local_stats[1].reshape((*local_stats[1].shape, 1)) global_stat = jnp.log(jnp.exp(stat0) + jnp.exp(stat1)) # # transpose stat to have shape [b, t, n, 1] for elemenwise multiplication - attn_out = local_outs[0].astype(jnp.float32) * jnp.exp(stat0 - global_stat).transpose((0, 2, 1, 3)) \ - + local_outs[1].astype(jnp.float32) * jnp.exp(stat1 - global_stat).transpose((0, 2, 1, 3)) + attn_out = local_outs[0].astype(jnp.float32) * jnp.exp(stat0 - global_stat).transpose((0, 2, 1, 3)) + local_outs[ + 1 + ].astype(jnp.float32) * jnp.exp(stat1 - global_stat).transpose((0, 2, 1, 3)) return attn_out.astype(local_stats[0].dtype) def normalize_attention(self, local_outs, local_maxes, local_sums): diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py index f542c5815f..16a3677454 100644 --- a/MaxText/layers/models.py +++ b/MaxText/layers/models.py @@ -419,9 +419,10 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metdata_axis_name, mes def get_pipeline_stage_module(self, decoder_blocks): """get pipeline stage module""" + def get_layer_to_pipeline(blocks, cfg): if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - return blocks[1] # return the sparse block + return blocks[1] # return the sparse block else: return blocks[0] @@ -530,14 
+531,9 @@ def __call__( model_mode, ) y = self.pipeline_module( - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - partition_spec=partition_spec + y, decoder_segment_ids, decoder_positions, deterministic, model_mode, partition_spec=partition_spec ) - else: # Not DeepSeek + else: # Not DeepSeek y = self.pipeline_module( y, decoder_segment_ids, decoder_positions, deterministic, model_mode, partition_spec=partition_spec ) diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 1b59d3f364..45c32a3e6a 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -829,17 +829,17 @@ def pipeline_first_axis(raw_keys): raw_keys = pipeline_first_axis(raw_keys) num_stages = int(raw_keys["ici_pipeline_parallelism"] * raw_keys["dcn_pipeline_parallelism"]) if raw_keys["pipeline_parallel_layers"] == -1: - if raw_keys["decoder_block"]=="deepseek": + if raw_keys["decoder_block"] == "deepseek": moe_layers = raw_keys["num_decoder_layers"] - raw_keys["first_num_dense_layers"] raw_keys["pipeline_parallel_layers"] = moe_layers else: raw_keys["pipeline_parallel_layers"] = raw_keys["num_decoder_layers"] else: - if raw_keys["decoder_block"]=="deepseek": + if raw_keys["decoder_block"] == "deepseek": moe_layers = raw_keys["num_decoder_layers"] - raw_keys["first_num_dense_layers"] assert ( - raw_keys["pipeline_parallel_layers"] <= moe_layers - ), f"You can only pipeline a subset of the moe decoder layers for deepseek, but you requested to pipeline {raw_keys['pipeline_parallel_layers']} with pipeline_parallel_layers and there are only {moe_layers} decoder layers." + raw_keys["pipeline_parallel_layers"] <= moe_layers + ), f"You can only pipeline a subset of the moe decoder layers for deepseek, but you requested to pipeline {raw_keys['pipeline_parallel_layers']} with pipeline_parallel_layers and there are only {moe_layers} decoder layers." 
else: assert ( raw_keys["pipeline_parallel_layers"] <= raw_keys["num_decoder_layers"] diff --git a/MaxText/tests/grain_data_processing_test.py b/MaxText/tests/grain_data_processing_test.py index 0be94cbe12..87f3f8088c 100644 --- a/MaxText/tests/grain_data_processing_test.py +++ b/MaxText/tests/grain_data_processing_test.py @@ -106,13 +106,14 @@ def get_first_batch(iterator): class GrainArrayRecordProcessingWithMultiSourceBlendingTest(GrainArrayRecordProcessingTest): + def setUp(self): super().setUp() temp_dir = tempfile.gettempdir() # We use the same dataset for testing, but you can use different datasets by changing the file patterns. grain_train_files = [ - f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.3", - f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.7", + f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.3", + f"{temp_dir}/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*:0.7", ] grain_train_files = ";".join(grain_train_files) self.config = pyconfig.initialize( @@ -140,8 +141,6 @@ def setUp(self): self.train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) - - class GrainParquetProcessingTest(unittest.TestCase): @classmethod diff --git a/MaxText/tests/integration_tests/train_tests.py b/MaxText/tests/integration_tests/train_tests.py index 36a5b89716..db567ad1a5 100644 --- a/MaxText/tests/integration_tests/train_tests.py +++ b/MaxText/tests/integration_tests/train_tests.py @@ -334,5 +334,6 @@ def test_gpu_cudnn_flash_jax(self): ] train_main(cudnn_flash_jax) + if __name__ == "__main__": absltest.main() diff --git a/MaxText/tests/train_compile_test.py b/MaxText/tests/train_compile_test.py index 712aeb2d77..36ea3deb2e 100644 --- a/MaxText/tests/train_compile_test.py +++ b/MaxText/tests/train_compile_test.py @@ -572,14 +572,14 @@ def test_moe_deepseek_pipeline_subset(self): "compile_topology_num_slices=8", "use_iota_embed=true", 
"model_name=deepseek3-671b", - "megablox=False", # dropless not yet supported (b/418313093) - "sparse_matmul=False", + "megablox=False", # dropless not yet supported (b/418313093) + "sparse_matmul=False", "capacity_factor=1", "per_device_batch_size=1", "max_target_length=2048", "pipeline_parallel_layers=56", "ici_expert_parallelism=16", - "dcn_pipeline_parallelism=8" + "dcn_pipeline_parallelism=8", ) ) From 2fa06ebaa6d0f51cf9dc18053679d42267d6a667 Mon Sep 17 00:00:00 2001 From: bzantium Date: Tue, 17 Jun 2025 09:40:58 +0900 Subject: [PATCH 7/7] add logging for files --- MaxText/input_pipeline/_grain_data_processing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/MaxText/input_pipeline/_grain_data_processing.py b/MaxText/input_pipeline/_grain_data_processing.py index 106cb71200..47b8e0539c 100644 --- a/MaxText/input_pipeline/_grain_data_processing.py +++ b/MaxText/input_pipeline/_grain_data_processing.py @@ -36,6 +36,7 @@ def find_data_files(data_file_pattern): data_files = glob.glob(str(Path(data_file_pattern).expanduser().resolve())) assert len(data_files) > 0, f"No file found with pattern {data_file_pattern}." + max_logging.log(f"Found {len(data_files)} files for train/eval with grain") return data_files