Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add phi-3 checkpoint #1341

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
Draft

Add phi-3 checkpoint #1341

wants to merge 12 commits into from

Conversation

rasbt
Copy link
Collaborator

@rasbt rasbt commented Apr 23, 2024

  • Verify Phi-3-mini-4k-instruct configs
  • Add prompt style
  • Add other config files
  • Add test_model.py
  • Add to test_prompts.py
  • Update 2 tables in README
  • Update download_model_weights.md

@rasbt rasbt marked this pull request as draft April 23, 2024 15:14
@Andrei-Aksionov
Copy link
Collaborator

There is a modeling_*.py file.
Good luck 🙂.

@rasbt
Copy link
Collaborator Author

rasbt commented Apr 23, 2024

There is a modeling_*.py file.
Good luck 🙂.

Haha, I finally get the weights loaded but of course it's never easy ... of course it's generating gibberish

⚡ phi-3-checkpoint ~/litgpt litgpt chat --checkpoint_dir checkpoints/microsoft/Phi-3-mini-4k-instruct
Now chatting with Phi-3-mini-4k-instruct.
To exit, press 'Enter' on an empty prompt.

Seed set to 1234
>> Prompt: What do llamas eat?
>> Reply: epsonniformes }).selves }).SSIONunicívo }). EverythingFormsћassaiejalphutureievediennesenticaciónicaciónMilMinigh ninassaselvesselves exhaustselvesonnselvesktionΗracheracheionedΗ Avenoted Bij_+versionsmastevosepsselvesmobileselvesilleryassaucealphasseestoreselvesférFormsiej Mu Kaiser oppienngnatteversionsionedionedversionsSSIONectionaccoossFormassaselves_+uminatesonoSSIONológissancecenteecause_+ienn选uraleʋ Stepalphigosionaliilonverte }).ienn }).ativo Sternsonoiejuralassawnkademselves│uraleativaionedvos_+utschversionsponiej_+icacióniejiewerológvoasonverte shoutioned位ionedIdentmobi

Let the easter egg hunt begin 😭

@rasbt
Copy link
Collaborator Author

rasbt commented Apr 24, 2024

Some more tidbits via Daniel Han:

Phi 3 (3.8B) got released! The paper said it was just a Llama arch, but I found some quirks while adding this to
@UnslothAI
:

  1. Sliding window of 2047? Mistral v1 4096. So does Phi mini have SWA? (And odd num?) Max RoPE position is 4096?
  2. Upcasted RoPE? Like Gemma?
  3. Dynamic RoPE for 128K context lengths
  4. Fused MLP & QKV - need to unfuse
  5. MMLU evals are very different betw the Phi team Llama-3 team - why?

@Andrei-Aksionov
Copy link
Collaborator

Ok, it's becoming more interesting.
Somewhat I expected from LlaMA 3, but it didn't deliver.

@@ -298,6 +298,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.proj(x)


class Phi3MLP(nn.Module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be possible to not need this class at all and instead reshape the weights for LLaMAMLP in the checkpoint conversion.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

100% agree. I was thinking the same thing. Similar to OLMo, I was hoping to get it working first and then simplify from there.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New models by Apple have if-else statements for this case: https://huggingface.co/apple/OpenELM-270M-Instruct/blob/main/modeling_openelm.py#L405-L462

For simplicity we definitely shouldn't do the same.

@rasbt
Copy link
Collaborator Author

rasbt commented Apr 24, 2024

@Andrei-Aksionov
Copy link
Collaborator

Current code is an ugly state, but at least the model produces the same output as HF one.
The most notable change is that now Phi3 model doesn't use parallel_residual in contrast to Phi1.5 and Phi2.

The missing piece is the Tokenizer: it has a smaller vocab size (32k vs 50k) that was extended by 64 special tokens.
If I'm not mistaken, the current code doesn't add these tokens.

litgpt/model.py Show resolved Hide resolved
@rasbt
Copy link
Collaborator Author

rasbt commented Apr 25, 2024

The missing piece is the Tokenizer: it has a smaller vocab size (32k vs 50k) that was extended by 64 special tokens.
If I'm not mistaken, the current code doesn't add these tokens.

Yeah, that sounds about right based on the Phi-3 paper:

To best benefit the open source community, phi-3-mini is built upon a similar block structure as Llama-2 [TLI+23] and uses the same tokenizer with vocabulary size of 320641

litgpt/prompts.py Outdated Show resolved Hide resolved
litgpt/prompts.py Outdated Show resolved Hide resolved
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants