Llama-3 Inference and Uploading to Huggingface #931

Open
fabriceyhc opened this issue May 3, 2024 · 7 comments
@fabriceyhc

I'm trying to fine-tune Llama-3-8B and 70B with LoRA on a custom drug detection dataset and upload the results to Huggingface so that they fit nicely into an existing zero-shot evaluation pipeline. My current challenge lies in converting the different checkpoints - 8B uses FullModelMetaCheckpointer, which outputs meta_model_{i}.pt, whereas 70B uses FullModelHFCheckpointer, which outputs hf_model_{idx}_{i}.pt, where i is the epoch index (over 5 epochs).

Question 1: Is it safe to assume that we only need the checkpoint with the highest i index and can delete the intermediate ones? If so, it would be handy to have a config option to conserve disk space by keeping only the most recent checkpoint.
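
Something like this hypothetical helper (not existing torchtune functionality; the filename pattern is just an assumption based on the names above) is the behavior I have in mind:

import glob
import os
import re

def keep_latest_checkpoint(output_dir: str, pattern: str = "meta_model_*.pt") -> None:
    """Delete every checkpoint except the one with the highest epoch index."""
    paths = glob.glob(os.path.join(output_dir, pattern))
    if len(paths) <= 1:
        return
    # The epoch index is the trailing integer, e.g. meta_model_5.pt -> 5.
    # (The sharded hf_model_{idx}_{i}.pt layout would need grouping by epoch
    # instead, since several shard files share the highest epoch.)
    def epoch(path: str) -> int:
        return int(re.search(r"_(\d+)\.pt$", path).group(1))
    for path in sorted(paths, key=epoch)[:-1]:
        os.remove(path)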

As discussed in #832, we need to convert the 8B model from llama format to hf format. To do this, I've had to move a lot of the contents from the original folder into the output_dir (e.g. tokenizer.model, etc.) and then run a script from transformers (here).

python src/transformers/models/llama/convert_llama_weights_to_hf.py \
--input_dir <checkpoint_dir> \
--llama_version 3 \
--model_size 8B \
--output_dir <hf_staging_dir> 

This script specifically looks for model weights in a file called consolidated.00.pth (L161), which is the original, untrained 8B Llama-3. It's not clear to me how to have it use the LoRA-merged meta_model_5.pt instead. When I tried to follow the instructions from the e2e example and just upload <checkpoint_dir> directly, HF errors out saying it can't find a file with a suitable format (e.g. pytorch_model.bin, etc.).
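
The workaround I'm considering (untested; it assumes the converter only cares about the filename and that meta_model_5.pt is in the same Meta format as the original weights) is to stage the merged checkpoint under the name the script expects:

import os
import shutil

checkpoint_dir = "<checkpoint_dir>"  # same placeholder as in the command above

# Set aside the original (untrained) weights, then let the converter pick up
# the LoRA-merged checkpoint under the expected filename.
shutil.move(os.path.join(checkpoint_dir, "consolidated.00.pth"),
            os.path.join(checkpoint_dir, "consolidated.00.pth.orig"))
shutil.copy(os.path.join(checkpoint_dir, "meta_model_5.pt"),
            os.path.join(checkpoint_dir, "consolidated.00.pth"))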

Question 2: How can we convert both the 8B and 70B versions of the LoRA fine-tuned Llama-3s so that they are suitable for inference via HF?

@monk1337

monk1337 commented May 6, 2024

@fabriceyhc You have to change the file name in the script to

        loaded = [
            torch.load(os.path.join(input_base_path, f"hf_model_{i:04d}_0.pt"), map_location="cpu")
            for i in range(1, num_shards + 1)  # start from 1 and go up to num_shards
        ]

The + 1 is needed because the shard naming convention starts at 0001 for 70B.

Hi @joecummings, could you please take a look at the script I've been working on? I made some changes to the names, but now I'm encountering new errors when I try to convert the 70B weights to hf; the conversion works fine for 8B. At the moment, the issue appears to be specific to torchtune. Once we fine-tune a model, it would be great to have a command line option to convert the pt weights to hf and upload them to Hugging Face. From there, we can proceed with other tasks, since the entire ecosystem is based around Hugging Face.

https://gist.github.com/monk1337/925a5a44c431ed1f1d3927141f31b6d2
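
For the upload half of that, something like this huggingface_hub sketch should already work today (the repo id and staging dir are placeholders; the pt-to-hf conversion still has to happen first):

from huggingface_hub import HfApi

api = HfApi()
api.create_repo("your-username/llama-3-70b-finetuned", exist_ok=True)
api.upload_folder(
    folder_path="<hf_staging_dir>",  # directory produced by the conversion step
    repo_id="your-username/llama-3-70b-finetuned",
)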

@optimass

Once we fine-tune a model, it would be great to have a command line option to convert the pt weights to hf and upload them to Hugging Face. From there, we can proceed with other tasks, since the entire ecosystem is based around Hugging Face.

totally agree that we need this!

@optimass

optimass commented May 10, 2024

https://gist.github.com/monk1337/925a5a44c431ed1f1d3927141f31b6d2
I tried this for Llama-3-8B and got the following error:

  File "/home/toolkit/ui-copilot/finetuning/utils/convert_llama_weights_to_hf2.py", line 447, in main
    write_model(
  File "/home/toolkit/ui-copilot/finetuning/utils/convert_llama_weights_to_hf2.py", line 195, in write_model
    f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
KeyError: 'layers.0.attention_norm.weight'

@kartikayk
Contributor

Hey! Sorry you're running into issues here. I didn't realize there are differences in how the 8B and 70B models are converted. Let me take a look at this in a bit.

@kartikayk
Contributor

We have this function in the checkpointer, but it seems like it isn't getting the job done, so we need to figure out why that is.

@kartikayk
Contributor

@fabriceyhc Some thoughts on the questions you asked above:

My current challenge lies in converting the different checkpoints - 8B uses FullModelMetaCheckpointer, which outputs meta_model_{i}.pt, whereas 70B uses FullModelHFCheckpointer, which outputs hf_model_{idx}_{i}.pt, where i is the epoch index (over 5 epochs)

The checkpointer used depends on the input checkpoint format. The 8B model makes use of the consolidated.00.pth file, which is in Meta format. But you can update this to use the safetensors checkpoints and the HFCheckpointer instead. This should address the discrepancy in configs between 8B and 70B.
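
As a rough sketch of what that config change maps to (treat the import path and argument names as approximate, since they may differ across torchtune versions; the shard list is the standard Llama-3-8B safetensors layout):

from torchtune.utils import FullModelHFCheckpointer

checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/Meta-Llama-3-8B",
    checkpoint_files=[
        "model-00001-of-00004.safetensors",
        "model-00002-of-00004.safetensors",
        "model-00003-of-00004.safetensors",
        "model-00004-of-00004.safetensors",
    ],
    model_type="LLAMA3",
    output_dir="/tmp/Meta-Llama-3-8B-finetuned",
)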

Is it safe to assume that we only need the checkpoint with the highest i index and can delete the intermediate ones? If so, it would be handy to have a config option to conserve disk space by keeping only the most recent checkpoint.

Yes, this is the right understanding. Adding this flag has been on our TODO list for a while now. If you'd be open to contributing this as a PR, I'd be happy to collaborate with you on the review.

How can we convert both the 8B and 70B versions of the LoRA fine-tuned Llama-3s so that they are suitable for inference via HF?

As I commented above, the FullModelCheckpointer does this conversion for you. But it seems like you're still running into issues?

@optimass

As I commented above, the FullModelCheckpointer does this conversion for you. But it seems like you're still running into issues?

Yes, it's still unclear to me how to use the FullModelHFCheckpointer's outputs with HF's APIs, in particular Text Generation Inference (TGI).
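
What I'd expect to work, once the converted weights sit next to a proper config.json, is something like this (the repo id is a placeholder):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the converted checkpoint directory like any other HF model...
model = AutoModelForCausalLM.from_pretrained("<hf_staging_dir>", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

# ...and push both to the Hub so TGI can serve them by repo id.
model.push_to_hub("your-username/llama-3-8b-drug-detection")
tokenizer.push_to_hub("your-username/llama-3-8b-drug-detection")

TGI could then be pointed at that repo via its --model-id option.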
