fix loading from pretrained for sharded model with `torch_dtype="auto"` #18061

Merged
merged 1 commit on Jul 27, 2022

Conversation

NouamaneTazi
Member

Fixes the following script, which failed because `resolved_archive_file` is a list for sharded models while `load_state_dict` expects a path to a single file:

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom", torch_dtype="auto")
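
As a quick sanity check once loading works, the inferred dtype can be inspected through the model's dtype property (the printed value depends on how the checkpoint was saved; bfloat16 is only an example):

from transformers import AutoModelForCausalLM

# torch_dtype="auto" asks from_pretrained to infer the dtype from the checkpoint
# itself instead of defaulting to float32.
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom", torch_dtype="auto")
print(model.dtype)  # e.g. torch.bfloat16 if that is the dtype the shards were saved in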

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jul 7, 2022

The documentation is not available anymore as the PR was closed or merged.

@LysandreJik
Member

LysandreJik commented Jul 11, 2022

Hey @NouamaneTazi, do you have a code example that failed before and that doesn't fail anymore with your PR?

@NouamaneTazi
Member Author

NouamaneTazi commented Jul 11, 2022

Yes @LysandreJik, the script I provided did fail for me when I tried it:

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom", torch_dtype="auto") # this should fail for any sharded models

The issue was that `load_state_dict` expects a `str` or an `os.PathLike`, while `resolved_archive_file` is a list for sharded models.
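
For context, a minimal sketch of the `torch_dtype="auto"` handling after the patch (simplified; the guard condition is an assumption, only the else branch mirrors the code shown in the traceback further down this thread):

# Sketch of the relevant branch in modeling_utils.py, not the exact source.
if not is_sharded:  # assumed guard for a single-file checkpoint
    torch_dtype = get_state_dict_dtype(state_dict)
else:
    # resolved_archive_file is a list of shard paths for sharded checkpoints,
    # so load only the first shard to infer the dtype.
    one_state_dict = load_state_dict(resolved_archive_file[0])
    torch_dtype = get_state_dict_dtype(one_state_dict)
    del one_state_dict  # free CPU memory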

@LysandreJik
Member

LysandreJik commented Jul 12, 2022

Understood! It's a bit hard to play with such a large model, so I'm reproducing with lysandre/test-bert-sharded. However, it seems that it doesn't entirely fix the issue:

>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("lysandre/test-bert-sharded", torch_dtype="auto")

File ~/Workspaces/Python/transformers/src/transformers/models/auto/auto_factory.py:446, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    444 elif type(config) in cls._model_mapping.keys():
    445     model_class = _get_model_class(config, cls._model_mapping)
--> 446     return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
    447 raise ValueError(
    448     f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
    449     f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
    450 )

File ~/Workspaces/Python/transformers/src/transformers/modeling_utils.py:2040, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
   2038     torch_dtype = get_state_dict_dtype(state_dict)
   2039 else:
-> 2040     one_state_dict = load_state_dict(resolved_archive_file)
   2041     torch_dtype = get_state_dict_dtype(one_state_dict)
   2042     del one_state_dict  # free CPU memory

File ~/Workspaces/Python/transformers/src/transformers/modeling_utils.py:359, in load_state_dict(checkpoint_file)
    357 except Exception as e:
    358     try:
--> 359         with open(checkpoint_file) as f:
    360             if f.read().startswith("version"):
    361                 raise OSError(
    362                     "You seem to have cloned a repository without having git-lfs installed. Please install "
    363                     "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
    364                     "you cloned."
    365                 )

TypeError: expected str, bytes or os.PathLike object, not list

@NouamaneTazi
Member Author

This is exactly the error I got before the fix, and from your traceback it seems that the patch wasn't applied.
You have:

-> 2040     one_state_dict = load_state_dict(resolved_archive_file)

when it should be:

-> 2040     one_state_dict = load_state_dict(resolved_archive_file[0])

On my side, testing with the patch did succeed in loading your model.

@LysandreJik
Member

Ah, great catch; I had too many patch-1 branches locally. Your patch seems to work, pinging @sgugger for additional verification.

Collaborator

@sgugger left a comment

Yes, this works. Note that this is not recommended in terms of speed, as you load a shard and then discard it immediately; it's more efficient to just set the torch dtype to the value you want.
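
A short illustration of that recommendation (a sketch; bfloat16 is just an example, pass whichever dtype the checkpoint was actually saved in):

import torch
from transformers import AutoModelForCausalLM

# Passing the dtype explicitly avoids loading (and discarding) a shard just to
# infer it, which is what torch_dtype="auto" has to do for sharded checkpoints.
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom",
    torch_dtype=torch.bfloat16,  # example dtype, not necessarily the checkpoint's
)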

@NouamaneTazi
Member Author

Should I raise a warning when this method is used, @sgugger?

@sgugger
Collaborator

sgugger commented Jul 12, 2022

I don't think Stas will like the extra warning, so I'd say no ;-)
