Failed to load the finetuned model with AutoModelForCausalLM.from_pretrained(name, state_dict=state_dict) #1362

Open · zhaosheng-thu opened this issue Apr 27, 2024 · 4 comments
Labels: 3rd party, bug (Something isn't working)

@zhaosheng-thu
I fine-tuned Llama 3 8B with LoRA and followed the tutorial in the repository to convert the final result into model.pth. However, when I try to load the fine-tuned weights into the model with AutoModelForCausalLM.from_pretrained, they are not applied correctly. Below is my test:

import torch
from transformers import AutoModelForCausalLM

state_dict = torch.load('out/convert/hf-llama3-instruct-esconv/model.pth')
print("state_dict: ", state_dict)
model = AutoModelForCausalLM.from_pretrained(
    'checkpoints/meta-llama/Meta-Llama-3-8B',
    device_map=device_map,  # defined elsewhere, e.g. "auto"
    torch_dtype=torch.float16,
    state_dict=state_dict,
)

print("model.weights", model.state_dict())

But I found that the state_dict returned by torch.load does not match model.state_dict(), as shown below:
torch.load:
[screenshot of the torch.load output]
model.state_dict():
[screenshot of the model.state_dict() output]

I noticed that even though I passed state_dict, from_pretrained still returns the weights of the checkpoint referenced by name. Did I make a mistake in my code, and how can I solve this? Thanks!
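
A quick sanity check along these lines is to compare a single tensor directly (a minimal sketch; lm_head.weight is just an example key that should exist in both state dicts):

# Compare one tensor from the saved file against what the model actually holds
key = "lm_head.weight"  # example key; any key present in both state dicts works
print(torch.allclose(
    state_dict[key].float().cpu(),
    model.state_dict()[key].float().cpu(),
))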

@zhaosheng-thu (Author)

I can load the weights using model.load_state_dict(), and then everything goes smoothly, but I would really like to know why from_pretrained(state_dict=state_dict) doesn't work.
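
For reference, the workaround looks roughly like this (a minimal sketch using the same paths as above):

import torch
from transformers import AutoModelForCausalLM

state_dict = torch.load('out/convert/hf-llama3-instruct-esconv/model.pth')
model = AutoModelForCausalLM.from_pretrained(
    'checkpoints/meta-llama/Meta-Llama-3-8B',
    torch_dtype=torch.float16,
)
# Explicitly loading the state dict applies the fine-tuned weights
model.load_state_dict(state_dict)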

@rasbt (Collaborator) commented Apr 29, 2024

Thanks for raising that. Maybe it's an HF thing. I will have to investigate.

rasbt self-assigned this Apr 29, 2024
carmocca added the bug and 3rd party labels Apr 29, 2024
@rasbt (Collaborator) commented Apr 29, 2024

I gave it a quick try with another model but could not reproduce it yet.

I am not sure if it's related because the differences are so big, but I wonder what the precision of the tensors in your current state dict is. Could you print the precision of the state dict, and could you also try loading it without torch_dtype=torch.float16?
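
For example, something along these lines would print the dtypes (a rough sketch, assuming state_dict is the object returned by torch.load above):

# Print the dtype of each tensor in the loaded state dict
for name, tensor in state_dict.items():
    print(name, tensor.dtype)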

EDIT: Nevermind, I can see that the precision is bfloat16 in your screenshot.

[Screenshot 2024-04-29 at 12:18:03 PM]

@rasbt (Collaborator) commented Apr 29, 2024

I also tried this with Llama 3, and it seemed to work fine there as well. Here are my steps:

litgpt download --repo_id meta-llama/Meta-Llama-3-8B-Instruct --access_token ...


litgpt finetune \
    --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B-Instruct \
    --out_dir my_llama_model \
    --train.max_steps 1 \
    --eval.max_iter 1

litgpt convert from_litgpt \
    --checkpoint_dir my_llama_model/final \
    --output_dir out/converted_llama_model/

And then in a Python session:
[Screenshot 2024-04-29 at 1:17:29 PM]

and

[Screenshot 2024-04-29 at 1:21:32 PM]
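
The check in the screenshots roughly corresponds to something like the following (a sketch, not a verbatim transcript; the paths follow the commands above):

import torch
from transformers import AutoModelForCausalLM

state_dict = torch.load("out/converted_llama_model/model.pth")
model = AutoModelForCausalLM.from_pretrained(
    "checkpoints/meta-llama/Meta-Llama-3-8B-Instruct",
    state_dict=state_dict,
)
# Spot-check that a tensor from the converted file matches the loaded model
key = next(iter(state_dict))
print(key, torch.allclose(state_dict[key].float(), model.state_dict()[key].float()))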
