From 123096e646a6de05d21d8f7ddefef265b44e3cd5 Mon Sep 17 00:00:00 2001
From: fra
Date: Thu, 7 Apr 2022 21:12:47 +0200
Subject: [PATCH] updated modeling_utils from main

---
 src/transformers/modeling_utils.py | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 471d785aa68856..c81ef06ebed853 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1487,12 +1487,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 Please refer to the mirror site for more information.
             _fast_init(`bool`, *optional*, defaults to `True`):
                 Whether or not to disable fast initialization.
-            low_cpu_mem_usage(`bool``, *optional*, defaults to `False`):
-                Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                This is an experimental feature and a subject to change at any moment.
-            torch_dtype (`str` or `torch.dtype`, *optional*):
-                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
-                will be automatically derived from the model's weights.
@@ -1502,6 +1496,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
+            low_cpu_mem_usage(`bool`, *optional*, defaults to `False`):
+                Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                This is an experimental feature and a subject to change at any moment.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+                will be automatically derived from the model's weights.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                 `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
@@ -1823,7 +1823,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             if is_sharded:
                 loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
             else:
-                state_dict = load_state_dict(resolved_archive_file)
                 loaded_state_dict_keys = [k for k in state_dict.keys()]
                 del state_dict  # free CPU memory - will reload again later
@@ -2162,14 +2161,15 @@ def find_submodule_and_param_name(model, long_key):
     state_dict = torch.load(archive_file, map_location="cpu")
 
     # materialize state_dict entries one by one on CPU
-    for k in state_dict.keys():
-        submodule, param_name = find_submodule_and_param_name(model, k)
-        if submodule is not None:
-            param_dtype = getattr(submodule, param_name).dtype
-            new_val = state_dict[k].to(param_dtype)
-            if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
-                new_val = torch.nn.Parameter(new_val)
-            setattr(submodule, param_name, new_val)
+    for k in loaded_state_dict_keys:
+        if k in state_dict:
+            submodule, param_name = find_submodule_and_param_name(model, k)
+            if submodule is not None:
+                param_dtype = getattr(submodule, param_name).dtype
+                new_val = state_dict[k].to(param_dtype)
+                if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
+                    new_val = torch.nn.Parameter(new_val)
+                setattr(submodule, param_name, new_val)
 
     del state_dict
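
Note on the first two hunks: they move the `low_cpu_mem_usage` and `torch_dtype` docstring entries so they sit directly above `kwargs`, fixing a stray double backtick after `bool` along the way. For reference, a minimal usage sketch of the two documented arguments (the model name is an arbitrary example, not part of this patch):

    from transformers import AutoModel

    # low_cpu_mem_usage tries to cap peak CPU RAM at ~1x the model size while
    # loading; torch_dtype="auto" derives the dtype from the checkpoint weights.
    model = AutoModel.from_pretrained(
        "bert-base-uncased",
        low_cpu_mem_usage=True,
        torch_dtype="auto",
    )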
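
The last hunk changes the low-memory materialization loop to iterate over `loaded_state_dict_keys` and skip keys absent from the current `state_dict`, so a checkpoint file holding only a subset of the keys (e.g. one shard) no longer trips the loop. Below is a minimal standalone sketch of that guarded loop, assuming a toy `find_submodule_and_param_name` helper and demo model that are illustrations only, not the transformers implementation:

    import torch

    def find_submodule_and_param_name(model, long_key):
        # Toy stand-in for the helper named in the hunk header: split
        # "block.sub.weight" into (owning submodule, "weight").
        if "." in long_key:
            module_path, param_name = long_key.rsplit(".", 1)
            submodule = model.get_submodule(module_path)
        else:
            submodule, param_name = model, long_key
        return (submodule, param_name) if hasattr(submodule, param_name) else (None, None)

    def materialize_from_shard(model, loaded_state_dict_keys, state_dict):
        # Walk the full key list, but only touch keys present in this shard --
        # the `if k in state_dict` guard the patch introduces.
        for k in loaded_state_dict_keys:
            if k in state_dict:
                submodule, param_name = find_submodule_and_param_name(model, k)
                if submodule is not None:
                    param_dtype = getattr(submodule, param_name).dtype
                    new_val = state_dict[k].to(param_dtype)
                    if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
                        new_val = torch.nn.Parameter(new_val)
                    setattr(submodule, param_name, new_val)

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
    all_keys = list(model.state_dict().keys())
    full = model.state_dict()
    # Pretend the checkpoint was saved as two shards, each with half the keys.
    for shard_keys in (all_keys[:2], all_keys[2:]):
        shard = {k: full[k] for k in shard_keys}
        materialize_from_shard(model, all_keys, shard)
        del shard  # drop the shard's tensors before loading the next one

With the old `for k in state_dict.keys()` iteration the loop was tied to whatever a single loaded file contained; driving it from `loaded_state_dict_keys` lets the same loop run once per checkpoint file while still covering every key of the model.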