Commit 67ab975

shahyash10, will-cromar, qihqi, cowanmeg, and mateuszlewko committed
Add _sharded_cpu_state_dict for distributed checkpointing (#5288)
* initiak commit
* Add test workflow for `xrt` branch (#5241)
* Add test workflow for `xrt` branch
* Only run for PRs targeting XRT branch
* Add function to generate stablehlo based callable from pytorch model (#5216)
* Add function to generate stablehlo based callable from pytorch model
  Added function `torch_xla.experimental.stablehlo_saved_model.export_pytorch_model`. This function will take a pytorch Module and convert it into stablehlo bytecode.
* Only run the main CI workflow on PRs targeting master and release branches (#5244)
* Only run main CI for master and release branches.
* Disabling XRT tests on main CI
* AMP for TPUs v3 (#5161)
* remove duplicate autocast_test (#5246)
* Remove `test_experimental_pjrt_tpu.py` from TPU CI (#5247)
* Install `expecttest` in xla_test_job.yaml (#5252)
* Add IAM roles for cloudbuild_editors (#5251)
* [Functionalization] Remove view in view_symint (#5231)
* [Functionalization] Remove view in view_symint
  Summary: This pull request removes views in tensor_method::view_symint.
  Test Plan:
  XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=TPU python ../test/test_view_ops.py -v -k TestViewOpsXLA.test_view_view
  PJRT_DEVICE=TPU python ../test/test_view_ops.py -v -k TestViewOpsXLA.test_view_view
* Fix linters
* fixed the test
* ran the linter
---------
Co-authored-by: Xiongfei Wei <[email protected]>
* Delete XRT from the main branch (#5240)
* Delete XRT from the main branch
* Remove dead import
* formatting
* Remove disable_xrt build option
* Fix runtime init
* Revert "Remove disable_xrt build option"
  This reverts commit ba312e7.
* Add disable XRT option back
* formatting
* Prune mesh service
* Remove obsolete test
* Remove other run server script
* Remove XRT config
* Update PJRT default device test
* Add a file I forgot to save
* if using_pjrt -> @requires_pjrt
* Remove irrelevant test case
* Remove XRT env vars
* fix md link
* formatting
* Remove extra `requires_pjrt`
* merge conflicts
* Add other autocast back
* Add nightly build for cuda 12 (#5253)
* Fix the linter command in the CI (#5254)
* fix linter command
* ran linter
* Jack cao g/fix spmd buff is null (#5256)
* Fix that non-tensor scalar can't be handled by virtual device
* add test
* comment
* Skip calling as_strided in empty_strided_symint if the input has dynamic dimensions. (#5239)
* Skip calling as_strided in empty_strided_symint.
* only return empty_symint conditionally.
* add a comment
* Add XRT nightly builds (#5261)
* Add XRT nightly builds
* remove space
* [OpenXLA] Migrate to pull XLA from OpenXLA (#5202)
  PyTorch/XLA migrate to pull XLA from OpenXLA by replacing TensorFlow with OpenXLA after deprecating XRT usage, and replace TensorFlow-pin with OpenXLA-pin to May09
* Add ToString method for both PjrtData and PjrtShardedData (#5265)
* Add ToString method for both PjrtData and PjrtShardedData
* on cpu same config will become replicated, dont't check actual op sharding type
* Update Sharded graph HLO dumping (#5266)
* Enable PjRt Client Compilation with StableHLO (#5233)
* Enable xla PjRt client compilation with StableHLO
* add XLA_STABLEHLO_COMPILE to configuration.yaml
* fix merge conflict
* dummy commit to trigger ci
* Revert "dummy commit to trigger ci"
  This reverts commit f7aec23.
* Disable Bazel remote cache for forked PR (#5259)
* disable bazel remote cache if gcloud key is empty
* remove remote cache from setup.py
* experiment with debug msg
* fix flag
* add more logs
* skip remote chache if credential file is empty
* add comment
* add logs
* add check in test and coverage script
* fix condition in coverage test
* advance branch pr
* allow remote cache if gloud file isn't specified explicitly
* remove dummy comment
* Suppress debug symbols in OpenXLA code (#5269)
* [SPMD] Sharding n-d tensor on (n+1)-d Mesh (#5268)
* Make TPU detection more robust (#5271)
* Clean bazel stuff on distutils clean. (#5274)
* Clean bazel stuff on distutils clean
* Fix python formatting
* Delete unused .so file, and .lds files (#5275)
* [OpenXLA] Delete unused .so file and .lds files
* Fix the error when export_torch_model is given a non-tensor (#5277)
  However the generated StableHLO graph still hardcodes the non-tensor value. this is not correct, will fix later.
* Dsiable test_simple_model_with_different_input_shape since it is curretnly broken by pytorch (#5282)
* Always do build_ext in python setup.py develop (#5273)
  Bazel should figure out that _XLAC.so is current or not, and trigger rebuild if any cpp files changed.
* Remove or improve several hardcoded TPU test conditions (#5272)
* Remove or improve several hardcoded TPU test conditions
* Fix test condition
* Add `runtime.host_index` (#5283)
* Make it an error if calling sizes() on a dynamic tensor. (#4998)
* Err if calling sizes() on dynamic tensor
* try to set has_symbolic_sizes_strides_
* resolve merge conflict
* enable CONTINUE_ON_ERROR
* fixed the python test test_SizeEq_should_not_compile_for_identical_symints
* fix test_index_types
* set CONTINUE_ON_ERROR to true
* remove some unwanted code.
* add a print
* directly set has_symbolic_sizes_strides_ = true
* make some fixes.
* fix empty_strided_symint
* ran linter
* change error type in the test.
* fix comments
* ran linter
* Fix the error where mark_step does not materalize tensors on SPMD:0 (#5281)
* Fix the error where mark_step does not materalize tensors on SPMD:0
* typo
* fix test_non_tensor_scalar
* Disable torch._dynamo.config.automatic_dynamic_shapes (#5285)
* Set torch._dynamo.config.automatic_dynamic_shapes to False
* Enable DynamoInferenceBasicTest.test_simple_model_with_different_input_shape
* run linter
* wrap only if sharding type is non-replicated
* Handle non-tensors
* run linter
* Call wrap_if_sharded first
* Add exception in test for unsharded tensor
* fix test
* Use torch.Tensor instead of torch.tensor
* use .cpu() only for tensors
---------
Co-authored-by: Will Cromar <[email protected]>
Co-authored-by: qihqi <[email protected]>
Co-authored-by: Meghan Cowan <[email protected]>
Co-authored-by: Mateusz Lewko <[email protected]>
Co-authored-by: Jiewen Tan <[email protected]>
Co-authored-by: Xiongfei Wei <[email protected]>
Co-authored-by: Wonjoo Lee <[email protected]>
Co-authored-by: JackCaoG <[email protected]>
Co-authored-by: Manfei <[email protected]>
Co-authored-by: Siyuan Liu <[email protected]>
Co-authored-by: stgpetrovic <[email protected]>
Co-authored-by: Mohit Khatwani <[email protected]>
1 parent 21784ce commit 67ab975

3 files changed, +66 -18 lines changed
test/spmd/test_xla_distributed_checkpoint.py

Lines changed: 21 additions & 0 deletions
@@ -1,4 +1,5 @@
 import os
+import sys
 import tempfile
 import unittest
 import test_xla_sharding_base
@@ -14,6 +15,8 @@
     create_default_global_save_plan,
 )
 from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner
+from torch_xla.experimental._distributed_checkpoint_helpers import (
+    _sharded_cpu_state_dict, _CpuShards, _is_sharded_tensor)
 
 
 class DistributedCheckpointTestBase(test_xla_sharding_base.XlaShardingTest):
@@ -244,6 +247,24 @@ def test_resolve_shard_data(self):
     self.assertTrue(torch.allclose(shard.data, resolved_data))
 
 
+class DistributedCheckpointHelpersTest(DistributedCheckpointTestBase):
+
+  def test_sharded_cpu_state_dict(self):
+    model = self.SimpleLinear().to(xm.xla_device())
+    state_dict = model.state_dict()
+    sharded_cpu_state_dict = _sharded_cpu_state_dict(state_dict)
+    self.assertCountEqual(sharded_cpu_state_dict,
+                          ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])
+    for name, param in sharded_cpu_state_dict.items():
+      if name == 'fc1.weight':
+        # _sharded_cpu_state_dict returns _CpuShards only for sharded tensors
+        if _is_sharded_tensor(param):
+          self.assertTrue(isinstance(param, _CpuShards))
+      else:
+        self.assertTrue(isinstance(param, torch.Tensor))
+        self.assertTrue(param.device == torch.device("cpu"))
+
+
 if __name__ == '__main__':
   test = unittest.main()
   sys.exit(0 if test.result.wasSuccessful() else 1)
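
Reviewer note: the new test passes a mapping to `assertCountEqual`. Iterating a dict yields its keys, so that call compares the key names only, regardless of order. Below is a minimal stdlib-only sketch of that assertion pattern. `KeyCheckExample`, `run_example`, and the plain dict standing in for the returned state dict are hypothetical names for illustration, not part of this commit.

```python
import unittest


class KeyCheckExample(unittest.TestCase):
  """Hypothetical stand-in test; the dict below replaces a real state_dict."""

  def test_keys_match_regardless_of_order(self):
    # Stand-in for the value returned by _sharded_cpu_state_dict(state_dict).
    fake_state_dict = {
        'fc2.bias': 0.0,
        'fc1.weight': 1.0,
        'fc2.weight': 2.0,
        'fc1.bias': 3.0,
    }
    # Iterating a dict yields its keys, so only the key names are compared.
    self.assertCountEqual(
        fake_state_dict,
        ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])


def run_example() -> bool:
  """Run the example test case and report whether it passed."""
  suite = unittest.defaultTestLoader.loadTestsFromTestCase(KeyCheckExample)
  result = unittest.TextTestRunner(verbosity=0).run(suite)
  return result.wasSuccessful()
```

Calling `run_example()` runs the single test without needing an XLA device.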

torch_xla/experimental/_distributed_checkpoint_helpers.py

Lines changed: 42 additions & 4 deletions
@@ -2,10 +2,14 @@
 # stable. Once the upstream makes these stable, we should take a dependency on
 # their APIs.
 
+import dataclasses
+
 import torch
+import torch_xla.experimental.xla_sharding as xs
 
 from torch.distributed.checkpoint.planner import SavePlan
 from typing import (
+    Any,
     Callable,
     Collection,
     Dict,
@@ -14,12 +18,13 @@
     MutableMapping,
     Sequence,
     Tuple,
-    TypeVar,
     Union,
     cast,
 )
-from torch.distributed.checkpoint.metadata import (
-    STATE_DICT_TYPE,)
+from torch.distributed.checkpoint.metadata import (MetadataIndex,
+                                                   STATE_DICT_TYPE)
+from torch_xla.experimental.xla_sharding import XLAShardedTensor, ShardingType
+from torch.utils._pytree import tree_map
 
 PATH_ITEM = Union[str, int]
 OBJ_PATH = Tuple[PATH_ITEM, ...]
@@ -186,4 +191,37 @@ def narrow_tensor_by_index(tensor: torch.Tensor, offsets: Sequence[int],
       # recording here for the narrow op and 'local_shard' should be a
       # leaf variable in the autograd graph.
       narrowed_tensor = narrowed_tensor.narrow(idx, offset, size)
-  return narrowed_tensor
+  return narrowed_tensor
+
+
+def _is_sharded_tensor(x: Any) -> bool:
+  """Return true if the tensor's data is sharded across multiple devices"""
+  return isinstance(
+      x, XLAShardedTensor) and x.sharding_type != ShardingType.REPLICATED
+
+
+def _unwrap_xla_sharded_tensor(x: Any) -> Any:
+  if isinstance(x, XLAShardedTensor):
+    return x.global_tensor
+  return x
+
+
+@dataclasses.dataclass
+class _CpuShards:
+  shards: List[xs.XLAShard]
+  global_shape: torch.Size
+
+
+def _sharded_cpu_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
+  """
+  Converts a state_dict on XLA device to a sharded state_dict on CPU.
+  """
+
+  def move_state_dict_to_cpu(v):
+    v = xs.wrap_if_sharded(v)
+    if not _is_sharded_tensor(v):
+      v = _unwrap_xla_sharded_tensor(v)
+      return v.cpu() if isinstance(v, torch.Tensor) else v
+    return _CpuShards(shards=v.local_shards, global_shape=v.global_tensor.shape)
+
+  return tree_map(move_state_dict_to_cpu, state_dict)
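
Reviewer note: `_sharded_cpu_state_dict` applies a per-leaf function across the state dict. Sharded leaves become a `_CpuShards` record, other tensors are moved with `.cpu()`, and non-tensors pass through unchanged. The sketch below imitates that control flow with stdlib-only stand-ins so it runs without torch or torch_xla. `FakeTensor`, `FakeShardedTensor`, `CpuShards`, `map_leaves`, and `to_cpu_leaf` are all hypothetical names, and `map_leaves` is a simplified substitute for `tree_map`, not its real implementation. The sketch also assumes shard data is already host-side, which the diff itself does not show.

```python
import dataclasses
from typing import Any, Callable, List, Tuple


@dataclasses.dataclass
class FakeTensor:
  """Hypothetical stand-in for torch.Tensor; tracks only shape and device."""
  shape: Tuple[int, ...]
  device: str = 'xla:0'

  def cpu(self) -> 'FakeTensor':
    # Return a copy placed on the host, leaving the original untouched.
    return dataclasses.replace(self, device='cpu')


@dataclasses.dataclass
class FakeShardedTensor:
  """Hypothetical stand-in for a non-replicated XLAShardedTensor."""
  global_tensor: FakeTensor
  local_shards: List[FakeTensor]


@dataclasses.dataclass
class CpuShards:
  """Same two fields as _CpuShards in the diff: shards plus global shape."""
  shards: List[FakeTensor]
  global_shape: Tuple[int, ...]


def map_leaves(fn: Callable[[Any], Any], tree: Any) -> Any:
  """Simplified substitute for tree_map: handles dict, list, and tuple only."""
  if isinstance(tree, dict):
    return {k: map_leaves(fn, v) for k, v in tree.items()}
  if isinstance(tree, (list, tuple)):
    return type(tree)(map_leaves(fn, v) for v in tree)
  return fn(tree)


def to_cpu_leaf(v: Any) -> Any:
  """Per-leaf rule mirroring move_state_dict_to_cpu in the diff."""
  if isinstance(v, FakeShardedTensor):
    # Sharded leaf: keep the local shards and record the global shape.
    return CpuShards(shards=v.local_shards, global_shape=v.global_tensor.shape)
  # Unsharded tensors move to the host; non-tensors pass through unchanged.
  return v.cpu() if isinstance(v, FakeTensor) else v


# Example input: one sharded weight, one plain tensor, one non-tensor entry.
state_dict = {
    'fc1.weight': FakeShardedTensor(
        global_tensor=FakeTensor(shape=(4, 2)),
        # Assumption: shard data is already on the host in this sketch.
        local_shards=[FakeTensor(shape=(2, 2), device='cpu')] * 2),
    'fc1.bias': FakeTensor(shape=(4,)),
    'step': 7,
}
cpu_state = map_leaves(to_cpu_leaf, state_dict)
```

After this runs, `cpu_state['fc1.weight']` is a `CpuShards` record, `cpu_state['fc1.bias']` reports device `'cpu'`, and `cpu_state['step']` is still the integer `7`.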

torch_xla/experimental/distributed_checkpoint.py

Lines changed: 3 additions & 14 deletions
@@ -33,14 +33,15 @@
 )
 from torch.distributed.checkpoint.utils import find_state_dict_object
 from torch.utils._pytree import tree_map
-from torch_xla.experimental.xla_sharding import (XLAShardedTensor, XLAShard,
-                                                 ShardingType)
+from torch_xla.experimental.xla_sharding import XLAShardedTensor, XLAShard
 from torch_xla.experimental._distributed_checkpoint_helpers import (
     FLATTEN_MAPPING,
     flatten_state_dict,
     dedup_tensors,
+    _is_sharded_tensor,
     set_element,
     narrow_tensor_by_index,
+    _unwrap_xla_sharded_tensor,
 )
 from typing import Any, Dict, List, Tuple, Union
 
@@ -373,15 +374,3 @@ def _create_xla_read_items(sharded_state_dict: STATE_DICT_TYPE,
     chunks = [_create_chunk_from_shard_index(index) for index in shard_indices]
     items.extend(create_read_items_for_chunk_list(fqn, md, chunks))
   return items
-
-
-def _is_sharded_tensor(x: Any) -> bool:
-  """Return true if the tensor's data is sharded across multiple devices"""
-  return isinstance(
-      x, XLAShardedTensor) and x.sharding_type != ShardingType.REPLICATED
-
-
-def _unwrap_xla_sharded_tensor(x: Any) -> Any:
-  if isinstance(x, XLAShardedTensor):
-    return x.global_tensor
-  return x

0 commit comments