49 changes: 35 additions & 14 deletions test/spmd/test_xla_distributed_checkpoint.py
@@ -5,6 +5,7 @@
import test_xla_sharding_base

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
@@ -51,20 +52,22 @@ def _same_shard_data(self, shards, others) -> bool:
return True


class ReshardingTest(DistributedCheckpointTestBase):
class EndToEndCheckpointTest(DistributedCheckpointTestBase):

def _save_and_restore(self,
model_in,
model_out,
save_planner=None,
load_planner=None,
is_sharded_cpu_state_dict=False):
is_sharded_cpu_state_dict=False,
no_dist=True,
chkpt_path=None):
"""
Checkpoint model_in using the provided save_planner and load into model_out
using the provided load_planner, and assert model_out equals model_in after
the load. If either planner is not specified, the DefaultPlanner is used.
"""
tmpdir = tempfile.mkdtemp()
chkpt_path = chkpt_path or tempfile.mkdtemp()

# Save an unsharded model using the provided save planner
model_in_state_dict = model_in.state_dict()
@@ -73,33 +76,31 @@ def _save_and_restore(self,
model_out_state_dict = model_out.state_dict()
dist_cp.save_state_dict(
state_dict=model_in_state_dict,
storage_writer=dist_cp.FileSystemWriter(tmpdir),
storage_writer=dist_cp.FileSystemWriter(chkpt_path),
planner=save_planner,
no_dist=True, # Single-host checkpoint doesn't require a process group
no_dist=no_dist,
)
# Load the checkpoint using the provided load planner
for p1, p2 in zip(model_in.parameters(), model_out.parameters()):
self.assertFalse(torch.allclose(p1, p2))

dist_cp.load_state_dict(
state_dict=model_out_state_dict,
storage_reader=dist_cp.FileSystemReader(tmpdir),
storage_reader=dist_cp.FileSystemReader(chkpt_path),
planner=load_planner,
no_dist=True, # Single-host checkpoint doesn't require a process group
no_dist=no_dist,
)
for p1, p2 in zip(model_in.parameters(), model_out.parameters()):
self.assertTrue(torch.allclose(p1, p2))

def test_unsharded_to_sharded(self):
def test_resharding_unsharded_to_sharded(self):
# Save an unsharded model using the DefaultSavePlanner and load into a
# sharded model using the SPMDLoadPlanner
model = self.SimpleLinear().to(xm.xla_device())
sharded_model = self._get_sharded_model()
self._save_and_restore(model, sharded_model, load_planner=SPMDLoadPlanner())

# TODO(jonbolin): Enable tests for resharding into coarser meshes
@unittest.skip("View assignment with virtual device is not yet supported")
def test_sharded_to_unsharded(self):
def test_resharding_sharded_to_unsharded(self):
for chkpt_on_cpu in [True, False]:
with self.subTest(chkpt_on_cpu):
model = self.SimpleLinear().to(xm.xla_device())
@@ -110,11 +111,9 @@ def test_sharded_to_unsharded(self):
save_planner=SPMDSavePlanner(),
is_sharded_cpu_state_dict=chkpt_on_cpu)

# TODO(jonbolin): Enable tests for resharding into coarser meshes
@unittest.skip("View assignment with virtual device is not yet supported")
@unittest.skipIf(xr.global_runtime_device_count() == 1,
"Multiple devices needed to change mesh")
def test_different_device_mesh(self):
def test_resharding_different_device_mesh(self):
dim = self.n_devices // 2
model1 = self._get_sharded_model(mesh_shape=(dim, self.n_devices // dim))
model2 = self._get_sharded_model(mesh_shape=(self.n_devices, 1))
@@ -124,6 +123,28 @@ def test_different_device_mesh(self):
save_planner=SPMDSavePlanner(),
load_planner=SPMDLoadPlanner())

@unittest.skipUnless(
{'CHKPT_PATH', 'MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE'
} <= os.environ.keys(),
'CHKPT_PATH and distributed config must be set for multihost checkpoint')
def test_multihost_checkpoint(self):
# Initialize the default CPU process group from the environment.
dist.init_process_group()

model1 = self._get_sharded_model(mesh_shape=(1, self.n_devices))
model2 = self._get_sharded_model(mesh_shape=(self.n_devices, 1))
# Take the checkpoint, writing to the path configured in the environment.
self._save_and_restore(
model1,
model2,
save_planner=SPMDSavePlanner(),
load_planner=SPMDLoadPlanner(),
no_dist=False,
chkpt_path=os.environ['CHKPT_PATH'])

# Destroy the CPU process group after the test
dist.destroy_process_group()
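
For context (not part of the diff): calling dist.init_process_group() with no arguments uses the env:// rendezvous, so each host must export the variables the skip condition checks. A minimal sketch with placeholder values; CHKPT_PATH should point at storage reachable from every host:

import os

# Hypothetical two-host configuration; adjust RANK per host.
os.environ['MASTER_ADDR'] = '10.0.0.1'          # address of the rank-0 host (placeholder)
os.environ['MASTER_PORT'] = '12355'             # any free port on the master (placeholder)
os.environ['RANK'] = '0'                        # this host's rank
os.environ['WORLD_SIZE'] = '2'                  # total number of hosts
os.environ['CHKPT_PATH'] = '/mnt/shared/ckpt'   # shared checkpoint directory (placeholder)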


class SPMDLoadPlannerTest(DistributedCheckpointTestBase):

1 change: 1 addition & 0 deletions test/tpu/xla_test_job.yaml
@@ -47,6 +47,7 @@ spec:
python3 /src/pytorch/xla/test/pjrt/test_runtime_tpu.py
python3 /src/pytorch/xla/test/spmd/test_xla_sharding.py
python3 /src/pytorch/xla/test/spmd/test_xla_virtual_device.py
python3 /src/pytorch/xla/test/spmd/test_xla_distributed_checkpoint.py
python3 /src/pytorch/xla/test/spmd/test_train_spmd_linear_model.py
python3 /src/pytorch/xla/test/spmd/test_spmd_xla_model_api.py
XLA_EXPERIMENTAL=nonzero:masked_select python3 /src/pytorch/xla/test/ds/test_dynamic_shape_models.py -v
22 changes: 13 additions & 9 deletions torch_xla/experimental/distributed_checkpoint.py
@@ -92,7 +92,9 @@ def set_up_planner(self, state_dict: STATE_DICT_TYPE,
for k, v in state_dict.items()
if _is_sharded_tensor(v) or isinstance(v, _CpuShards)
}
unsharded = dict(state_dict.items() - self.sharded_state_dict.items())
unsharded = {
k: v for k, v in state_dict.items() if k not in self.sharded_state_dict
}
self.unsharded_state_dict = tree_map(_unwrap_xla_sharded_tensor, unsharded)

Collaborator: Curious on why you made this change?
Collaborator (author): This is to support _CpuShards. The old approach tried to hash the values of the dict as well, since .items() yields (key, value) pairs. When we're checkpointing on CPU, the _CpuShards contains a List[XLAShard], which isn't hashable.
Collaborator: Hmm, but unsharded is still a KV pair?
Collaborator (author): The issue is with state_dict.items() - self.sharded_state_dict.items(): the - operator between two dict_items views is a set difference, which hashes the entire (k, v) tuple. With the dict comprehension, only k is checked against the sharded keys.
Collaborator: That's interesting. Thanks for the Python lecture!
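
A standalone sketch (not from the PR) of the behavior discussed above; the list value stands in for an unhashable _CpuShards:

state_dict = {'fc.weight': ['shard0', 'shard1'], 'step': 3}
sharded_state_dict = {'fc.weight': state_dict['fc.weight']}

try:
    # dict_items "-" is a set difference, so every (key, value) tuple gets hashed.
    unsharded = dict(state_dict.items() - sharded_state_dict.items())
except TypeError as e:
    print(e)  # unhashable type: 'list'

# The dict comprehension only tests keys for membership, so unhashable values are fine.
unsharded = {k: v for k, v in state_dict.items() if k not in sharded_state_dict}
print(unsharded)  # {'step': 3}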

def create_local_plan(self) -> SavePlan:
@@ -112,7 +114,8 @@ def create_global_plan(
# Deduplicate write items across plans
all_plans = dedup_tensors(all_plans)

global_plan, metadata = create_default_global_save_plan(all_plans)
global_plan, metadata = create_default_global_save_plan(
all_plans, rewrite_index_hints=False)

# Combine mappings from all plans
planner_data_dict = [p.planner_data for p in global_plan]
@@ -220,7 +223,9 @@ def set_up_planner(
self.sharded_state_dict = {
k: v for k, v in state_dict.items() if _is_sharded_tensor(v)
}
unsharded = dict(state_dict.items() - self.sharded_state_dict.items())
unsharded = {
k: v for k, v in state_dict.items() if k not in self.sharded_state_dict
}
self.unsharded_state_dict = tree_map(_unwrap_xla_sharded_tensor, unsharded)

def create_local_plan(self) -> LoadPlan:
@@ -340,13 +345,12 @@ def _create_write_items_for_xla_sharded_tensor(
def _create_write_items_for_cpu_shards(
fqn: str, cpu_shards: _CpuShards) -> List[WriteItem]:
items = []
for xla_shard in cpu_shards.shards:
for shard_ind, xla_shard in enumerate(cpu_shards.shards):
prop = TensorProperties.create_from_tensor(xla_shard.data)
for shard_ind, indices in enumerate(xla_shard.indices):
write_item = _create_write_item_from_indices(fqn, shard_ind, indices,
cpu_shards.global_shape,
prop)
items.append(write_item)
write_item = _create_write_item_from_indices(fqn, shard_ind,
xla_shard.indices,
cpu_shards.global_shape, prop)
items.append(write_item)
return items

Collaborator: Is this the bug you are fixing?
Collaborator (author): Yeah, what we had originally was incorrect. It looks like the CPU tests didn't actually hit this codepath, and we weren't running on TPU CI.
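
To make the fix concrete, a standalone toy sketch (not from the PR; ToyShard and ToyCpuShards are stand-ins for XLAShard and _CpuShards, under the assumption that a shard's .indices holds the index slices for that single shard of the global tensor). Each element of .shards should yield exactly one write item, keyed by the shard's position in the list, which is what the corrected enumerate loop does; the old code enumerated a single shard's .indices instead:

from collections import namedtuple

# Illustrative stand-ins; field names are assumptions, not the real torch_xla classes.
ToyShard = namedtuple('ToyShard', ['data', 'indices'])
ToyCpuShards = namedtuple('ToyCpuShards', ['shards', 'global_shape'])

cpu_shards = ToyCpuShards(
    shards=[
        ToyShard(data='rows 0-1', indices=[slice(0, 2), slice(0, 4)]),
        ToyShard(data='rows 2-3', indices=[slice(2, 4), slice(0, 4)]),
    ],
    global_shape=(4, 4))

# One write item per shard, keyed by the shard's position in the list.
for shard_ind, shard in enumerate(cpu_shards.shards):
    print(shard_ind, shard.indices)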

