
Adding GPT-NeoX-20B #16659

Merged
merged 42 commits into huggingface:main on May 24, 2022

Conversation

zphang
Contributor

@zphang zphang commented Apr 7, 2022

What does this PR do?

Adds GPT-NeoX-20B model and tokenizers.
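
For reference, a minimal usage sketch of the classes added here (assuming the class names GPTNeoXForCausalLM and GPTNeoXTokenizerFast and the EleutherAI/gpt-neox-20b checkpoint on the Hub; loading the 20B weights needs a large amount of memory):

from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast

model_name = "EleutherAI/gpt-neox-20b"

# Load the tokenizer and the causal LM head model added in this PR.
tokenizer = GPTNeoXTokenizerFast.from_pretrained(model_name)
model = GPTNeoXForCausalLM.from_pretrained(model_name)

# Simple sampled generation from a short prompt.
inputs = tokenizer("GPT-NeoX-20B is", return_tensors="pt")
outputs = model.generate(**inputs, max_length=30, do_sample=True)
print(tokenizer.decode(outputs[0]))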

Fixes #15642

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@LysandreJik

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@zphang zphang mentioned this pull request Apr 7, 2022
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Apr 7, 2022

The documentation is not available anymore as the PR was closed or merged.

@ViktorThink

Incredible work!

I have tested the model and it seems to work as intended. I did discover one problem with the tokenizer, though:

Here is the full script:

!git clone https://github.com/zphang/transformers
%cd transformers
!git checkout neox20b
!pip install -e .
%cd ..

from transformers import AutoModelForCausalLM, GPTNeoXTokenizer


model_name = r"EleutherAI/gpt-neox-20b"

model = AutoModelForCausalLM.from_pretrained(model_name)

tokenizer = GPTNeoXTokenizer.from_pretrained(model_name)

input_ids = tokenizer.encode("This is the input text", return_tensors="pt", add_special_tokens=False)
beam_output = model.generate(
    input_ids=input_ids,
    max_length=input_ids.shape[1] + 30,
    min_length=input_ids.shape[1] + 5,
    early_stopping=True,
    num_return_sequences=4,
    do_sample=True,
)

for j in range(4):
    output = tokenizer.decode(beam_output[j][input_ids.shape[1]:], skip_special_tokens=False)

I got the following error:

File "testing/testDecoderOnly.py", line 104, in testModelSample
ran = tokenizer.decode(beam_output[j][input_ids.shape[1]:], skip_special_tokens=False)
File "/home/ec2-user/t5-regression3/transformers/src/transformers/tokenization_utils_base.py", line 3308, in decode
**kwargs,
File "/home/ec2-user/t5-regression3/transformers/src/transformers/tokenization_utils.py", line 946, in _decode
sub_texts.append(self.convert_tokens_to_string(current_sub_text))
File "/home/ec2-user/t5-regression3/transformers/src/transformers/models/gpt2/tokenization_gpt2.py", line 266, in convert_tokens_to_string
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
File "/home/ec2-user/t5-regression3/transformers/src/transformers/models/gpt2/tokenization_gpt2.py", line 266, in
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
KeyError: ' '

@zphang
Contributor Author

zphang commented Apr 10, 2022

Hm yea, I can replicate that issue too. I'm not too familiar with the tokenization code. The fast tokenizer seems to work just fine, but the Python one (which I'm basing off GPT-2's tokenizer) seems to have some issues.

Here's a minimal reproducible version:

import transformers
model_name = "EleutherAI/gpt-neox-20b"
tokenizer_slow = transformers.GPTNeoXTokenizer.from_pretrained(model_name)
tokenizer_fast = transformers.GPTNeoXTokenizerFast.from_pretrained(model_name)
print("Fast", repr(tokenizer_fast.decode([50274])))
print("Slow", repr(tokenizer_slow.decode([50274])))
Fast '    '
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Input In [78], in <cell line: 2>()
      1 print("Fast", repr(tokenizer_fast.decode([50274])))
----> 2 print("Slow", repr(tokenizer_slow.decode([50274])))

File ~/code/transformers/src/transformers/tokenization_utils_base.py:3304, in PreTrainedTokenizerBase.decode(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
   3301 # Convert inputs to python lists
   3302 token_ids = to_py_obj(token_ids)
-> 3304 return self._decode(
   3305     token_ids=token_ids,
   3306     skip_special_tokens=skip_special_tokens,
   3307     clean_up_tokenization_spaces=clean_up_tokenization_spaces,
   3308     **kwargs,
   3309 )

File ~/code/transformers/src/transformers/tokenization_utils.py:946, in PreTrainedTokenizer._decode(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, spaces_between_special_tokens, **kwargs)
    944         current_sub_text.append(token)
    945 if current_sub_text:
--> 946     sub_texts.append(self.convert_tokens_to_string(current_sub_text))
    948 if spaces_between_special_tokens:
    949     text = " ".join(sub_texts)

File ~/code/transformers/src/transformers/models/gpt2/tokenization_gpt2.py:266, in GPT2Tokenizer.convert_tokens_to_string(self, tokens)
    264 """Converts a sequence of tokens (string) in a single string."""
    265 text = "".join(tokens)
--> 266 text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
    267 return text

File ~/code/transformers/src/transformers/models/gpt2/tokenization_gpt2.py:266, in <listcomp>(.0)
    264 """Converts a sequence of tokens (string) in a single string."""
    265 text = "".join(tokens)
--> 266 text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
    267 return text

KeyError: ' '

I believe the NeoX tokenizer handles spaces a little differently (it has special tokens for single, double, triple spaces, etc). Do you know if someone who's more familiar with tokenization code might be able to chime in?
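
A small sketch of what I mean, using the fast tokenizer to inspect how runs of spaces map to single tokens (token id 50274 is just the one from the repro above, not guaranteed to be meaningful elsewhere):

import transformers

tokenizer_fast = transformers.GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")

# Runs of spaces encode to single tokens in the NeoX vocabulary.
for n in range(1, 5):
    print(n, tokenizer_fast.encode(" " * n))

# The fast tokenizer decodes the multi-space token without issue:
print(repr(tokenizer_fast.decode([50274])))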

@ViktorThink

Great that the fast version works. Gently pinging @SaulLu and @Narsil if they have any answers.

Member

@LysandreJik LysandreJik left a comment


Thanks for your PR, @zphang, this is great! There are a few tests failing, let me give you pointers on how to solve them:

  • The check_code_quality run fails because the quality checks weren't applied. I recommend doing the following from the root of your fork: pip install -e .[quality], followed by make fixup. This should fix most of the issues, and tell you which issues remain to be solved manually.
  • There's a missing mention of GPT-NeoX-20B in the index.mdx file of the doc. Running make fix-copies from the root of your clone should solve this issue.

The rest of the issues seem to be linked to you importing many different models, most of which do not exist, in both src/transformers/models/gpt_neox/__init__.py and src/transformers/models/auto/modeling_auto.py. Left some comments where that applies.

Did you use the add-new-model-like command to add this model? What was your experience like using the script? Thanks again for your contributions!

Review comments (outdated, resolved) on:
src/transformers/models/auto/modeling_auto.py
src/transformers/models/gpt_neox/__init__.py
@zphang
Contributor Author

zphang commented Apr 11, 2022

Hey @LysandreJik, thanks for taking a look! I'll look into getting the tests to pass today.

Re: the model script, I did use the new model templating script, but many parts of it seemed to make the assumption that the model would be an encoder-decoder model (e.g. mentioning cross attention). I removed most of the other model implementations aside from CausalLM, as that's the primary format that NeoX-20B would be used for, but it looks like I missed some other references to the other model implementations. Other than that, the script was very useful in setting up the boilerplate.

@aalok-sathe

aalok-sathe commented Apr 21, 2022

added a PR to the PR to support AutoTokenizer initialization from pretrained_model_name_or_path:
zphang#1
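
Concretely, the goal is for something like the following to resolve to the GPT-NeoX tokenizer (a sketch of the intended behavior, not the exact wiring in that PR):

from transformers import AutoTokenizer

# With the GPT-NeoX tokenizer registered in the auto mappings, this
# should return the fast tokenizer for the checkpoint.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
print(type(tokenizer).__name__)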

@StellaAthena
Contributor

I have resolved the merge conflicts in the config files, but I am not confident in my understanding of how these various configs are supposed to work. I would appreciate it if someone double-checked that I didn't do anything stupid.

@zphang
Contributor Author

zphang commented May 21, 2022

Are there any further blockers to merging? It would be nice to have this merged in time for ACL next week :)

Collaborator

@sgugger sgugger left a comment


Hi @zphang. Many of the comments/suggestions on the previous reviews were just ignored. I have a few more suggestions on the style.
We will merge the model as soon as they are resolved; let us know if you need any help.

Review comments (outdated, resolved) on:
src/transformers/__init__.py
src/transformers/models/gpt_neox/modeling_gpt_neox.py
tests/models/gpt_neox/test_modeling_gpt_neox.py
@zphang
Contributor Author

zphang commented May 23, 2022

Apologies, I must have missed the previous comments. I've pushed an update with the desired changes.

@sgugger
Collaborator

sgugger commented May 23, 2022

There are still four open comments on the modeling file, if you could have a look.

@zphang
Contributor Author

zphang commented May 23, 2022

I think I got to all of them now (is there an easy way to check on the GitHub web interface?); let me know if I'm missing any.

@sgugger
Collaborator

sgugger commented May 23, 2022

I see they are closed but not addressed; maybe you forgot to push your commit?

@zphang
Contributor Author

zphang commented May 23, 2022

Terribly sorry! Pushed now.

@sgugger sgugger merged commit 71e6027 into huggingface:main May 24, 2022
@sgugger
Collaborator

sgugger commented May 24, 2022

Thanks again for all your work on this model!

@sgugger sgugger changed the title from [WIP] Adding GPT-NeoX-20B to Adding GPT-NeoX-20B on May 24, 2022
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
* initial

* first try

* working 20B

* 20B tokenizers

* Docs

* Import fixes for missing classes

* Update docs, fixup

* black formatting

* isort

* flake

* dummy objects

* documentation

* Documentation yml

* more docs

* tweaks for tests

* tokenization auto

* fix neox tests

* test

* test

* einsum

* address PR feedback

* Documentation

* Update README.md

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gpt_neox/__init__.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gpt_neox/configuration_gpt_neox.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Remove undefined LaTeX syntax

* Update to full url to avoid confusion about if that's supposed to refer to the Hub

* fix auto

* move tests

* documentation fix

* more doc fixes

* test refactor

* fix import

* fix import

* fix import

* fix import

* fix import

* style fixes

* More modeling fixes

Co-authored-by: Jason Phang <zp489@gr057.hpc.nyu.edu>
Co-authored-by: Stella Biderman <stellabiderman@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>