[WIP] Add OLMo #927

Open · wants to merge 9 commits into main
Conversation

@rasbt (Collaborator) commented Feb 12, 2024

Adds the popular and fully open-source OLMo models by Allen AI.

  • Implement model download
  • Test tokenizer
  • Implement HF checkpoint conversion
  • clean up HF checkpoint conversion
  • Make sure to use the right layer normalization
  • Make sure generate.py produces reasonable outputs
  • Update download and finetuning docs
  • Test pretraining
  • Test finetuning
    • Full finetuning
    • LoRA
    • Adapter
  • Add tests
  • Update README

Fixes #925

lit_gpt/config.py (outdated review thread, resolved)
@rasbt (Collaborator, Author) commented Feb 13, 2024

I'm a bit stuck with the conversion and would appreciate your advice and ideas, @carmocca or @Andrei-Aksionov!

So, here are 3 special things about OLMo:

  1. They use weight tying like in GPT-2: they reuse the WTE weight as the output projection weight. In the tensors they saved on the Hub, though, they simply duplicated that tensor, so there shouldn't be any action required. When loading the model in HuggingFace, I checked that olmo.model.transformer.wte.weight and olmo.model.transformer.ff_out.weight contain the same tensor. That should be all good here.

  2. They use a non-parametric LayerNorm, i.e., their LayerNorm doesn't have the scale (weight) and shift (bias) parameters. To avoid code changes just for this model, my workaround is to fill these with ones and zeros so that they have no effect:

        state_dict[f"transformer.h.{l}.norm_1.weight"] = torch.ones(config.n_embd)
        state_dict[f"transformer.h.{l}.norm_2.weight"] = torch.ones(config.n_embd)
        state_dict[f"transformer.h.{l}.norm_1.bias"] = torch.zeros(config.n_embd)
        state_dict[f"transformer.h.{l}.norm_2.bias"] = torch.zeros(config.n_embd)
  3. The problem is that I'm missing weights ...

The HF version looks like this, which is confusing because, as far as I can tell, the ops are not applied in that "sequential" order:

OLMoForCausalLM(
  (model): Olmo(
    (transformer): ModuleDict(
      (wte): Embedding(50304, 2048)
      (emb_drop): Dropout(p=0.0, inplace=False)
      (ln_f): LayerNorm()
      (blocks): ModuleList(
        (0-15): 16 x OlmoSequentialBlock(
          (dropout): Dropout(p=0.0, inplace=False)
          (act): SwiGLU()
          (attn_out): Linear(in_features=2048, out_features=2048, bias=False)
          (ff_out): Linear(in_features=8192, out_features=2048, bias=False)
          (rotary_emb): RotaryEmbedding()
          (attn_norm): LayerNorm()
          (ff_norm): LayerNorm()
          (att_proj): Linear(in_features=2048, out_features=6144, bias=False)
          (ff_proj): Linear(in_features=2048, out_features=16384, bias=False)
        )
      )
      (ff_out): Embedding(50304, 2048)
    )
  )
)

Unless I'm wrong, I think what happens is that ff_proj is a placeholder for the MLP FC1 and FC2 layers, i.e., the first half is FC1 and the second half is FC2. It's kind of confusing, though.

What I'm thinking is that we have to split the fc weights, which would avoid having to write custom code in the GPT model class:

    weight_map = {
        "model.transformer.wte.weight": "transformer.wte.weight",
        "model.transformer.ff_out.weight": "lm_head.weight",
        "model.transformer.blocks.{}.attn_out.weight": "transformer.h.{}.attn.proj.weight",
        "model.transformer.blocks.{}.ff_proj.weight": "transformer.h.{}.mlp.fc_1.weight", # split into fc1 and fc2
        "model.transformer.blocks.{}.att_proj.weight": "transformer.h.{}.attn.attn.weight",
        "model.transformer.blocks.{}.ff_out.weight": "transformer.h.{}.mlp.proj.weight",
    }
...

    for l in range(config.n_layer):
        # ff_proj stacks both MLP input projections along dim 0, so split at the MLP hidden size
        state_dict[f"transformer.h.{l}.mlp.fc_2.weight"] = state_dict[f"transformer.h.{l}.mlp.fc_1.weight"][config.intermediate_size:]
        state_dict[f"transformer.h.{l}.mlp.fc_1.weight"] = state_dict[f"transformer.h.{l}.mlp.fc_1.weight"][:config.intermediate_size]

Is this somehow possible with the lit-gpt SaveProxyTensor?

@carmocca (Contributor) commented

Hey! All your suggestions make sense to me. You should be able to split the combined ff linear as you suggest, especially if load_param has been called already. We also manipulate the qkv linears for llama2 checkpoints in a similar way.

However, note that your workarounds will only work for inference. During training, wte and ff_out will not be tied and the layernorm parameters won't be frozen.
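
For reference, actual tying in Lit-GPT's GPT class would be a one-liner along these lines (a minimal sketch, not part of this PR; the config name is just a placeholder):

    from lit_gpt import GPT, Config

    config = Config.from_name("pythia-70m")  # placeholder config, not an OLMo one
    model = GPT(config)
    # Point lm_head at the embedding's Parameter so both stay identical during training
    model.lm_head.weight = model.transformer.wte.weight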

@Andrei-Aksionov (Collaborator) commented

Hello @rasbt

Looks like you are correct. I just wanted to add a couple of things that I noticed while reviewing their code. For posterity, so to speak.

I also don't like it when layers are initialized in a different order than they are executed. Lit-GPT does it too: first we create lm_head and only then the transformer layers 🙃.

So the order of execution should be as follows:

# Attention
1. attn_norm
2. att_proj (2048, 6144) <- combined QKV
3. rotary_emb
4. attn_out (2048, 2048)
5. dropout
# MLP
1. ff_norm
2. ff_proj (2048, 16384) <- combined [fc_2, fc_1] / [up, gate] in LLaMA notation
3. act
4. ff_out (8192, 2048)
5. dropout
  1. Yes, they use weight_tying. It's configurable and they decided to use it. And yes, it won't work during training, although it's not difficult to add if more models use it.
  2. Their LayerNorm class supports weight and bias parameters, but this is controlled by the config. It looks like they turned off .weight and .bias in the config.
  3. This deserves a bit more explanation.
    In the LLaMAMLP class we have fc_1, fc_2, and proj. During the forward pass we apply fc_1 and fc_2 to the input separately:
    https://github.com/Lightning-AI/lit-gpt/blob/f5d68065ff621fc2cc190c05dcc4ab2cda1d1f57/lit_gpt/model.py#L286-L290

OLMo has only two layers: ff_proj and ff_out. They decided to take an approach similar to a combined QKV matrix and created an ff_proj layer that does this matmul in one go. But the place where they split the result is, I would say, unexpected: in the activation function:

def forward(self, x: torch.Tensor) -> torch.Tensor:
    x, gate = x.chunk(2, dim=-1)
    return F.silu(gate) * x

and then they apply ff_out to it.

It's important to note the way they split and then apply the activation function to one chunk. That means that:
ff_proj == [fc_2, fc_1]
ff_out == proj
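
To make that mapping concrete, here is a small self-contained sketch (not from the PR; random weights with the 1B shapes) showing that OLMo's combined ff_proj plus the chunking SwiGLU matches the LLaMAMLP formulation when the first half of ff_proj plays the role of fc_2 and the second half the role of fc_1:

    import torch
    import torch.nn.functional as F

    n_embd, hidden = 2048, 8192
    ff_proj = torch.nn.Linear(n_embd, 2 * hidden, bias=False)
    x = torch.randn(1, n_embd)

    # OLMo style: one matmul, then chunk inside the activation
    up, gate = ff_proj(x).chunk(2, dim=-1)
    olmo_out = F.silu(gate) * up

    # Lit-GPT LLaMAMLP style: two separate linears built from the two halves
    fc_2_w, fc_1_w = ff_proj.weight.split(hidden, dim=0)  # ff_proj == [fc_2, fc_1]
    litgpt_out = F.silu(x @ fc_1_w.T) * (x @ fc_2_w.T)

    torch.testing.assert_close(olmo_out, litgpt_out)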

@rasbt (Collaborator, Author) commented Feb 14, 2024

Thanks so much for the feedback @carmocca and @Andrei-Aksionov, this was super helpful! After more tinkering, I went with a custom OLMoMLP (analogous to LLaMAMLP) because I thought this was easier than the other workarounds -- both from an implementation perspective and for code readability in the future.

The weights load ok now, but for some reason, the results are garbage. E.g., for

python generate/base.py --checkpoint_dir ./checkpoints/allenai/OLMo-1b/

What food do llamas eat?lerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslerslers

And

python generate/base.py --checkpoint_dir ./checkpoints/allenai/OLMo-7b/

What food do llamas eat? nic except ' Up Area has , climate * new area even county bun dressingall Bul Index millions Di withdrawal intent except / bun ID tonnes approve welcome St/ regimes health ng est African worse Multiple; p; ques Up ( IL'' Area / p

@rasbt (Collaborator, Author) commented Feb 14, 2024

> Yes, they use weight_tying. It's configurable and they decided to use it. And yes, it won't work during training, although it's not difficult to add if more models use it.

Actually, upon further inspection they only use weight tying for the 1B model (https://huggingface.co/allenai/OLMo-1B/blob/main/config.json#L42), not for the 7B model (https://huggingface.co/allenai/OLMo-7B/blob/main/config.json#L42). I adjusted the code accordingly. Still not working well, though.
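
For reference, a throwaway sketch to double-check that flag for both checkpoints (assuming the key in the linked config.json files is named weight_tying):

    import json
    import urllib.request

    for name in ("OLMo-1B", "OLMo-7B"):
        url = f"https://huggingface.co/allenai/{name}/raw/main/config.json"
        with urllib.request.urlopen(url) as f:
            cfg = json.load(f)
        print(name, "weight_tying:", cfg.get("weight_tying"))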

@carmocca (Contributor) commented

I would strongly prefer that we don't add this new MLP class.

To debug the output, you'll have to inspect the activations for both models layer by layer to see where they diverge.
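
For anyone following along, a rough sketch of what that layer-by-layer comparison could look like with forward hooks (illustrative only; lit_model and hf_model are placeholders for the already-loaded Lit-GPT and HF OLMo models):

    import torch

    def capture_activations(model, store):
        # Record every submodule's output under its qualified name
        def make_hook(name):
            def hook(module, inputs, output):
                store[name] = output[0] if isinstance(output, tuple) else output
            return hook
        for name, module in model.named_modules():
            module.register_forward_hook(make_hook(name))

    lit_acts, hf_acts = {}, {}
    capture_activations(lit_model, lit_acts)  # lit_model: placeholder
    capture_activations(hf_model, hf_acts)    # hf_model: placeholder

    idx = torch.tensor([[1, 2, 3, 4]])  # any prompt token ids
    with torch.no_grad():
        lit_model(idx)
        hf_model(idx)

    # Then diff corresponding entries (embedding output, first block's attention, ...),
    # walking down the model until the tensors stop matching.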

@rasbt (Collaborator, Author) commented Feb 14, 2024

> I would strongly prefer that we don't add this new MLP class.

Ok! Maybe let's leave it in there until we get it to work, and then we can refactor it into one of the existing classes somehow.

@rasbt (Collaborator, Author) commented Feb 14, 2024

Just to add a note about pinpointing the difference: with Carlos's help, we found that the difference is currently in how the QKV matrix is split into queries, keys, and values.

https://github.com/Lightning-AI/lit-gpt/blob/main/lit_gpt/model.py#L195-L202

and

https://github.com/allenai/OLMo/blob/main/olmo/model.py#L687
https://github.com/allenai/OLMo/blob/main/olmo/model.py#L559-L571

In Lit-GPT, Q, K, and V are interleaved (to also support MQA), whereas in OLMo they are not interleaved.

We could potentially accommodate OLMo in Lit-GPT if we apply the steps used for Llama in the conversion script, but in reverse: https://github.com/Lightning-AI/lit-gpt/blob/main/scripts/convert_hf_checkpoint.py#L182-L186
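
A sketch of what that reversed step could look like for OLMo's fused att_proj weight (illustrative only; the function and argument names are placeholders, and since OLMo uses plain multi-head attention here, each query group holds exactly one Q, K, and V head):

    import torch

    def interleave_qkv(att_proj_w: torch.Tensor, n_head: int, head_size: int) -> torch.Tensor:
        # OLMo stores the fused projection as [all Q | all K | all V] along dim 0
        q, k, v = att_proj_w.split(n_head * head_size, dim=0)
        # Lit-GPT's attn.attn weight expects per-head interleaving: [q_0, k_0, v_0, q_1, k_1, v_1, ...]
        qs = q.split(head_size, dim=0)
        ks = k.split(head_size, dim=0)
        vs = v.split(head_size, dim=0)
        return torch.cat([t for group in zip(qs, ks, vs) for t in group], dim=0)

    # e.g. for OLMo-1B: interleave_qkv(att_proj_w, n_head=16, head_size=128)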

Successfully merging this pull request may close these issues: Adding OLMo 1B and 7B

3 participants