Don't rewrite index hints in global save planning #5348
Changes from all commits
```diff
@@ -92,7 +92,9 @@ def set_up_planner(self, state_dict: STATE_DICT_TYPE,
         for k, v in state_dict.items()
         if _is_sharded_tensor(v) or isinstance(v, _CpuShards)
     }
-    unsharded = dict(state_dict.items() - self.sharded_state_dict.items())
+    unsharded = {
+        k: v for k, v in state_dict.items() if k not in self.sharded_state_dict
+    }
     self.unsharded_state_dict = tree_map(_unwrap_xla_sharded_tensor, unsharded)
 
   def create_local_plan(self) -> SavePlan:
```
```diff
@@ -112,7 +114,8 @@ def create_global_plan(
     # Deduplicate write items across plans
     all_plans = dedup_tensors(all_plans)
 
-    global_plan, metadata = create_default_global_save_plan(all_plans)
+    global_plan, metadata = create_default_global_save_plan(
+        all_plans, rewrite_index_hints=False)
 
     # Combine mappings from all plans
     planner_data_dict = [p.planner_data for p in global_plan]
```
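For context, `create_default_global_save_plan` in `torch.distributed.checkpoint` rewrites each write item's chunk index hints by default, renumbering a tensor's chunks sequentially across the collected local plans. Below is a toy model of that renumbering, assuming (this is a conceptual sketch, not the planner's actual code) that each local plan is a list of `(fqn, shard_index)` pairs; it shows why rewriting must be skipped when the local plans, as here, already carry the true shard indices:

```python
from typing import Dict, List, Tuple

Plan = List[Tuple[str, int]]  # toy stand-in for a SavePlan: (fqn, shard index)


def renumber_index_hints(plans: List[Plan]) -> List[Plan]:
  """Toy model of rewrite_index_hints=True: give each tensor's chunks fresh
  sequential indices in plan order, discarding the locally assigned ones."""
  counters: Dict[str, int] = {}
  out = []
  for plan in plans:
    new_plan = []
    for fqn, _local_index in plan:
      idx = counters.get(fqn, 0)
      new_plan.append((fqn, idx))
      counters[fqn] = idx + 1
    out.append(new_plan)
  return out


# Rank 0 holds shard 1 of "w" and rank 1 holds shard 0; the local indices are
# already correct. Renumbering in plan order silently swaps them:
local_plans = [[("w", 1)], [("w", 0)]]
assert renumber_index_hints(local_plans) == [[("w", 0)], [("w", 1)]]
```

Passing `rewrite_index_hints=False` tells the default planner to trust the indices the local planner assigned from the actual shard layout.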
```diff
@@ -220,7 +223,9 @@ def set_up_planner(
     self.sharded_state_dict = {
         k: v for k, v in state_dict.items() if _is_sharded_tensor(v)
     }
-    unsharded = dict(state_dict.items() - self.sharded_state_dict.items())
+    unsharded = {
+        k: v for k, v in state_dict.items() if k not in self.sharded_state_dict
+    }
     self.unsharded_state_dict = tree_map(_unwrap_xla_sharded_tensor, unsharded)
 
   def create_local_plan(self) -> LoadPlan:
```
```diff
@@ -340,13 +345,12 @@ def _create_write_items_for_xla_sharded_tensor(
 def _create_write_items_for_cpu_shards(
     fqn: str, cpu_shards: _CpuShards) -> List[WriteItem]:
   items = []
-  for xla_shard in cpu_shards.shards:
+  for shard_ind, xla_shard in enumerate(cpu_shards.shards):
     prop = TensorProperties.create_from_tensor(xla_shard.data)
-    for shard_ind, indices in enumerate(xla_shard.indices):
-      write_item = _create_write_item_from_indices(fqn, shard_ind, indices,
-                                                   cpu_shards.global_shape,
-                                                   prop)
-      items.append(write_item)
+    write_item = _create_write_item_from_indices(fqn, shard_ind,
+                                                 xla_shard.indices,
+                                                 cpu_shards.global_shape, prop)
+    items.append(write_item)
   return items
```

Collaborator: Is this the bug you are fixing?

Author: Yeah, what we had originally was incorrect. It looks like the CPU tests didn't actually hit this codepath, and we weren't running on TPU CI.
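To make the fix above concrete, here is a minimal sketch using a hypothetical `Shard` stand-in for `XLAShard` (not the real class). The old loop enumerated a shard's per-dimension `indices`, so `shard_ind` counted dimensions and each slice was passed alone; the new loop enumerates the shards themselves:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Shard:
  indices: List[slice]  # one slice per tensor dimension


shards = [
    Shard(indices=[slice(0, 2), slice(0, 4)]),  # shard 0 of a 2-D tensor
    Shard(indices=[slice(2, 4), slice(0, 4)]),  # shard 1
]

# Old (buggy) loop: `shard_ind` counts dimensions within one shard, and each
# slice is passed alone as if it were the shard's full index list. Indices
# also collide across shards (both shards emit items 0 and 1).
old = [(shard_ind, indices)
       for shard in shards
       for shard_ind, indices in enumerate(shard.indices)]
# old == [(0, slice(0, 2)), (1, slice(0, 4)), (0, slice(2, 4)), (1, slice(0, 4))]

# Fixed loop: one write item per shard, tagged with a unique shard index and
# the shard's complete index list.
new = [(shard_ind, shard.indices) for shard_ind, shard in enumerate(shards)]
# new == [(0, [slice(0, 2), slice(0, 4)]), (1, [slice(2, 4), slice(0, 4)])]
```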
Collaborator: Curious on why you made this change?

Author: This is to support `_CpuShards`. The old approach tries to hash the values of the dict as well, since `.items()` is `(key, value)` pairs. When we're checkpointing on CPU, the `_CpuShards` contains `List[XLAShard]`, which isn't hashable.

Collaborator: Hmm, but `unsharded` is still a KV pair?

Author: The issue is with `state_dict.items() - self.sharded_state_dict.items()`: the `-` operator between two `dict_items` objects is a set difference, which will hash the entire `(k, v)` tuple. Using a dict comprehension, only `k` is hashed.

Collaborator: That's interesting. Thanks for the python lecture!
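The difference is easy to demonstrate; a minimal sketch with a plain list standing in for an unhashable `List[XLAShard]` value:

```python
sharded = {"a": [1, 2]}  # list value, like _CpuShards holding List[XLAShard]
state_dict = {"a": [1, 2], "b": 3}

# Old approach: `-` on dict_items builds sets of (key, value) tuples, which
# requires hashing the values and raises for unhashable ones.
try:
  unsharded = dict(state_dict.items() - sharded.items())
except TypeError as e:
  print(e)  # unhashable type: 'list'

# New approach: the dict comprehension only hashes the keys, so it works
# regardless of the value types.
unsharded = {k: v for k, v in state_dict.items() if k not in sharded}
assert unsharded == {"b": 3}
```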