Honor model dtype in load_checkpoint #920
muellerzr
left a comment
Thanks, that makes sense. I don't particularly think there is a "harm" in silently pushing it out (i.e. don't advertise the bad behavior but let it still pass) in this particular case. If we do care about phasing that out perhaps leave it for a 1.0.0? (Similar to some optimizer bits we have)
Actually before merging, could it maybe be better to handle this in
            break

    if old_param is not None:
        param = param.to(old_param.dtype)
Should this not be better done in `set_module_tensor_to_device`? Or maybe additionally add a `torch_dtype` arg to `set_module_tensor_to_device` that handles the param correctly if `value=param` is used?
    else:
        for param_name, param in checkpoint.items():
            module_name = param_name
            if dtype is not None and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
This is moved to `set_module_tensor_to_device`.
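A minimal sketch of the dtype handling discussed in this thread, assuming PyTorch is installed. `cast_checkpoint_tensor` is a hypothetical helper written for illustration, not Accelerate's actual `set_module_tensor_to_device`. The rule that an explicit `dtype` takes precedence over the existing parameter's dtype is an assumption, as is leaving integer and boolean tensors untouched in both cases.

```python
from typing import Optional

import torch

# Same prefixes as in the diff above: integer and boolean tensors are never cast.
_NON_FLOAT_PREFIXES = ("torch.uint", "torch.int", "torch.bool")


def cast_checkpoint_tensor(
    value: torch.Tensor,
    old_param: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Hypothetical helper (not part of Accelerate): choose the dtype a
    checkpoint tensor should have before it is stored in the model."""
    if str(value.dtype).startswith(_NON_FLOAT_PREFIXES):
        # Index buffers, masks, etc. keep their original dtype.
        return value
    if dtype is not None:
        # Assumption: an explicitly requested dtype wins.
        return value.to(dtype)
    if old_param is not None:
        # Otherwise follow the dtype the model already uses,
        # as torch's load_state_dict does.
        return value.to(old_param.dtype)
    return value


checkpoint_weight = torch.ones(2, dtype=torch.float32)
existing_param = torch.zeros(2, dtype=torch.float16)
print(cast_checkpoint_tensor(checkpoint_weight, old_param=existing_param).dtype)
```

Keeping the decision in one place is what lets callers such as `load_checkpoint` pass tensors through without their own casting logic.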
patrickvonplaten
left a comment
Thanks a lot for adapting!
…op#310)
- After bigscience-workshop#285, `load_pretrained_block()` uses `accelerate.utils.set_module_tensor_to_device()`
- In accelerate>=0.16.0, it saves the tensor in the dtype previously used by the model instead of dtype of the weights (huggingface/accelerate#920)
- Because of that, blocks and attention caches used float32, which caused OOMs
- This PR makes `load_pretrained_block()` respect `torch_dtype` (default: `"auto"`, which means reading `torch_dtype` from `config.json`)
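The `"auto"` behavior described in the last bullet could be sketched roughly as follows, assuming PyTorch is installed. `resolve_torch_dtype` is a hypothetical helper, not the code from the referenced commit; the float32 fallback when `config.json` has no `torch_dtype` entry is an assumption.

```python
import json

import torch


def resolve_torch_dtype(torch_dtype, config: dict) -> torch.dtype:
    """Hypothetical helper: turn a user-supplied torch_dtype into a torch.dtype."""
    if torch_dtype == "auto":
        # "auto" means: read the dtype name stored in config.json.
        name = config.get("torch_dtype")
        if name is None:
            # Assumption: fall back to float32 when the config is silent.
            return torch.float32
        return getattr(torch, name)
    if isinstance(torch_dtype, str):
        # Accept names such as "float16" as well as torch.dtype objects.
        return getattr(torch, torch_dtype)
    return torch_dtype


# Stand-in for the parsed contents of a model's config.json.
config = json.loads('{"torch_dtype": "bfloat16"}')
print(resolve_torch_dtype("auto", config))
```

Passing the resolved dtype explicitly when loading each tensor avoids inheriting float32 from a freshly initialized block.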
This PR fixes a long-standing bug where our behavior differs from PyTorch's. In torch, loading a `state_dict` into a model will never change the model's dtype. Currently in Accelerate, `load_checkpoint` does the opposite: when loading a model, it converts it to the dtype of the state dict. This PR addresses that.

This PR only contains the fix for now. We still have to discuss how to maintain backward compatibility (even if this is a bug fix), because `diffusers` might be relying on this behavior, cc @patrickvonplaten
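The PyTorch behavior referred to in the description can be checked with a small sketch, assuming PyTorch is installed. The model and the float32 state dict here are made up for illustration and are not taken from the PR.

```python
import torch

# A model whose parameters are float16.
model = torch.nn.Linear(2, 3).half()

# A state dict holding the same parameters in float32,
# standing in for a checkpoint saved in full precision.
fp32_state_dict = {name: tensor.float() for name, tensor in model.state_dict().items()}

# load_state_dict copies values into the existing parameters,
# so the model keeps its own dtype.
model.load_state_dict(fp32_state_dict)
print(model.weight.dtype)  # torch.float16
```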