fix loading from pretrained for sharded model with `torch_dtype="auto"` #18061
Conversation
Hey @NouamaneTazi, do you have a code example that failed before and that doesn't fail anymore with your PR?
Yes @LysandreJik, the script I provided did fail for me when I tried it:
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom", torch_dtype="auto")  # this should fail for any sharded model
The issue was that `resolved_archive_file` is a list for sharded models, while `load_state_dict` expects a path to a single file.
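For context, here's a minimal sketch of the dtype-inference logic this PR fixes; it is not the actual `from_pretrained` internals, and `infer_torch_dtype` is a hypothetical helper. The key point is indexing into the list of shard paths before loading:

```python
import torch

def infer_torch_dtype(resolved_archive_file):
    # For sharded checkpoints, resolved_archive_file is a list of shard
    # paths; loading just the first shard is enough to read off a dtype.
    if isinstance(resolved_archive_file, list):
        checkpoint_file = resolved_archive_file[0]  # the fix: pick one shard
    else:
        checkpoint_file = resolved_archive_file
    state_dict = torch.load(checkpoint_file, map_location="cpu")
    # Return the dtype of the first floating-point tensor found.
    for tensor in state_dict.values():
        if torch.is_tensor(tensor) and tensor.is_floating_point():
            return tensor.dtype
    return None
```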
Understood! It's a bit hard to play with such a large model, so I'm reproducing with:
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("lysandre/test-bert-sharded", torch_dtype="auto")
File ~/Workspaces/Python/transformers/src/transformers/models/auto/auto_factory.py:446, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
444 elif type(config) in cls._model_mapping.keys():
445 model_class = _get_model_class(config, cls._model_mapping)
--> 446 return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
447 raise ValueError(
448 f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
449 f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
450 )
File ~/Workspaces/Python/transformers/src/transformers/modeling_utils.py:2040, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
2038 torch_dtype = get_state_dict_dtype(state_dict)
2039 else:
-> 2040 one_state_dict = load_state_dict(resolved_archive_file)
2041 torch_dtype = get_state_dict_dtype(one_state_dict)
2042 del one_state_dict # free CPU memory
File ~/Workspaces/Python/transformers/src/transformers/modeling_utils.py:359, in load_state_dict(checkpoint_file)
357 except Exception as e:
358 try:
--> 359 with open(checkpoint_file) as f:
360 if f.read().startswith("version"):
361 raise OSError(
362 "You seem to have cloned a repository without having git-lfs installed. Please install "
363 "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
364 "you cloned."
365 )
TypeError: expected str, bytes or os.PathLike object, not list
This is exactly the error I got before the fix. And from your traceback it seems that the patch wasn't applied:
--> 2040 one_state_dict = load_state_dict(resolved_archive_file)
when it should be:
--> 2040 one_state_dict = load_state_dict(resolved_archive_file[0])
From my side, testing with the patch did succeed in loading your model.
Ah, great catch; I had too many [...]
Yes, this works. Note that this is not recommended in terms of speed, as you load a shard and then discard it immediately, so it's more efficient to just set the `torch_dtype` to the value you want.
Should I raise a warning when this method is used @sgugger?
I don't think Stas will like the extra warning, so I'd say no ;-)
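To illustrate the recommendation above, a minimal sketch of passing an explicit dtype instead of `"auto"`, so no shard needs to be loaded and discarded just to infer it (the `bfloat16` choice is an assumption about the checkpoint, not something stated in the thread):

```python
import torch
from transformers import AutoModelForCausalLM

# Passing a concrete dtype skips the shard-loading dtype inference entirely.
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom", torch_dtype=torch.bfloat16  # assumed dtype
)
```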
Fixes the following script, which failed because `resolved_archive_file` is a list for sharded models and `load_state_dict` expects a path to a single file.
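The script in question, as given earlier in the thread:

```python
from transformers import AutoModelForCausalLM

# Before this fix, torch_dtype="auto" failed for any sharded checkpoint.
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom", torch_dtype="auto")
```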