updated modeling_utils from main
FrancescoSaverioZuppichini committed Apr 7, 2022
1 parent 9e244da commit 123096e
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions src/transformers/modeling_utils.py
@@ -1487,12 +1487,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
Please refer to the mirror site for more information.
_fast_init(`bool`, *optional*, defaults to `True`):
Whether or not to disable fast initialization.
-low_cpu_mem_usage(`bool``, *optional*, defaults to `False`):
-    Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-    This is an experimental feature and a subject to change at any moment.
-torch_dtype (`str` or `torch.dtype`, *optional*):
-    Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
-    will be automatically derived from the model's weights.
<Tip warning={true}>
@@ -1502,6 +1496,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
</Tip>
+low_cpu_mem_usage(`bool`, *optional*, defaults to `False`):
+    Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+    This is an experimental feature and a subject to change at any moment.
+torch_dtype (`str` or `torch.dtype`, *optional*):
+    Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+    will be automatically derived from the model's weights.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
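For reference, a minimal sketch of how the two options documented in this hunk are typically passed to `from_pretrained` (the checkpoint name is a placeholder, not taken from this commit):

```python
import torch
from transformers import AutoModel

# Illustrative usage: load with reduced peak CPU memory and an explicit dtype.
model = AutoModel.from_pretrained(
    "bert-base-uncased",
    low_cpu_mem_usage=True,     # experimental: avoid holding two full copies of the weights in CPU RAM
    torch_dtype=torch.float16,  # or torch_dtype="auto" to derive the dtype from the checkpoint
)
```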
@@ -1823,7 +1823,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
if is_sharded:
    loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
    state_dict = load_state_dict(resolved_archive_file)
    loaded_state_dict_keys = [k for k in state_dict.keys()]
    del state_dict  # free CPU memory - will reload again later
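In the sharded branch, `sharded_metadata["all_checkpoint_keys"]` lets the key list be collected without loading any weights at all. A rough sketch of how such a key list could be derived from a sharded checkpoint's index file (the file name and JSON layout are assumptions based on the usual sharded-checkpoint format, not taken from this diff):

```python
import json

def collect_checkpoint_keys(index_path="pytorch_model.bin.index.json"):
    # Assumed index layout: {"metadata": {...}, "weight_map": {"param.name": "shard-file.bin", ...}}
    with open(index_path) as f:
        index = json.load(f)
    # Every parameter name appears once in weight_map, so its keys are the full key list.
    return sorted(index["weight_map"].keys())
```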

@@ -2162,14 +2161,15 @@ def find_submodule_and_param_name(model, long_key):
state_dict = torch.load(archive_file, map_location="cpu")

# materialize state_dict entries one by one on CPU
-for k in state_dict.keys():
-    submodule, param_name = find_submodule_and_param_name(model, k)
-    if submodule is not None:
-        param_dtype = getattr(submodule, param_name).dtype
-        new_val = state_dict[k].to(param_dtype)
-        if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
-            new_val = torch.nn.Parameter(new_val)
-        setattr(submodule, param_name, new_val)
+for k in loaded_state_dict_keys:
+    if k in state_dict:
+        submodule, param_name = find_submodule_and_param_name(model, k)
+        if submodule is not None:
+            param_dtype = getattr(submodule, param_name).dtype
+            new_val = state_dict[k].to(param_dtype)
+            if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
+                new_val = torch.nn.Parameter(new_val)
+            setattr(submodule, param_name, new_val)

del state_dict
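The loop above resolves each flat checkpoint key (e.g. `"encoder.layer.0.output.dense.weight"`) to the module that owns the tensor plus the attribute name on that module, then materializes the value in the parameter's existing dtype. A rough, illustrative sketch of how such a lookup helper can work (not the exact implementation in this file):

```python
import torch.nn as nn

def find_submodule_and_param_name_sketch(model: nn.Module, long_key: str):
    """Walk the dotted key through the module tree; the last segment is the parameter/buffer name."""
    parts = long_key.split(".")
    submodule = model
    for attr in parts[:-1]:
        if not hasattr(submodule, attr):
            return None, parts[-1]  # key does not match this model's structure
        submodule = getattr(submodule, attr)
    return submodule, parts[-1]
```

Re-wrapping the materialized tensor in `torch.nn.Parameter` before the `setattr` keeps it registered as a parameter, so `state_dict()` and optimizers continue to see it.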

