diff --git a/README.md b/README.md index c046393239091..d738a767c51a6 100644 --- a/README.md +++ b/README.md @@ -274,6 +274,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. **[GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. 1. **[GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj)** (from EleutherAI) released in the repository [kingoflolz/mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax/) by Ben Wang and Aran Komatsuzaki. 1. **[GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. +1. **[GPT NeoX](https://huggingface.co/docs/transformers/main/model_doc/gpt_neox)** (from EleutherAI) released with the paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745) by Sid Black, Stella Biderman, Eric Hallahan, Quentin Anthony, Leo Gao, Laurence Golding, Horace He, Connor Leahy, Kyle McDonell, Jason Phang, Michael Pieler, USVSN Sai Prashanth, Shivanshu Purohit, Laria Reynolds, Jonathan Tow, Ben Wang, Samuel Weinbach 1. **[Hubert](https://huggingface.co/docs/transformers/model_doc/hubert)** (from Facebook) released with the paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed. 1. **[I-BERT](https://huggingface.co/docs/transformers/model_doc/ibert)** (from Berkeley) released with the paper [I-BERT: Integer-only BERT Quantization](https://arxiv.org/abs/2101.01321) by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer. 1. **[ImageGPT](https://huggingface.co/docs/transformers/main/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever. diff --git a/README_ko.md b/README_ko.md index aab7a7c4bc2d6..972f9688517b5 100644 --- a/README_ko.md +++ b/README_ko.md @@ -251,6 +251,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[GLPN](https://huggingface.co/docs/transformers/main/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim. 1. **[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. 1. **[GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. +1. **[GPT NeoX](https://huggingface.co/docs/transformers/main/model_doc/gpt_neox)** (from EleutherAI) released with the paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745) by Sid Black, Stella Biderman, Eric Hallahan, Quentin Anthony, Leo Gao, Laurence Golding, Horace He, Connor Leahy, Kyle McDonell, Jason Phang, Michael Pieler, USVSN Sai Prashanth, Shivanshu Purohit, Laria Reynolds, Jonathan Tow, Ben Wang, Samuel Weinbach 1. **[GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. 1. **[GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj)** (from EleutherAI) released in the repository [kingoflolz/mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax/) by Ben Wang and Aran Komatsuzaki. 1. **[Hubert](https://huggingface.co/docs/transformers/model_doc/hubert)** (from Facebook) released with the paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed. diff --git a/README_zh-hans.md b/README_zh-hans.md index 7031fd3570bf1..8325df7cbd86c 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -275,6 +275,7 @@ conda install -c huggingface transformers 1. **[GLPN](https://huggingface.co/docs/transformers/main/model_doc/glpn)** (来自 KAIST) 伴随论文 [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) 由 Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim 发布。 1. **[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (来自 OpenAI) 伴随论文 [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) 由 Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever 发布。 1. **[GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo)** (来自 EleutherAI) 随仓库 [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) 发布。作者为 Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy 发布。 +1. **[GPT NeoX](https://huggingface.co/docs/transformers/main/model_doc/gpt_neox)** (from EleutherAI) released with the paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745) by Sid Black, Stella Biderman, Eric Hallahan, Quentin Anthony, Leo Gao, Laurence Golding, Horace He, Connor Leahy, Kyle McDonell, Jason Phang, Michael Pieler, USVSN Sai Prashanth, Shivanshu Purohit, Laria Reynolds, Jonathan Tow, Ben Wang, Samuel Weinbach 1. **[GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)** (来自 OpenAI) 伴随论文 [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) 由 Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever** 发布。 1. **[GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj)** (来自 EleutherAI) 伴随论文 [kingoflolz/mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax/) 由 Ben Wang and Aran Komatsuzaki 发布。 1. **[Hubert](https://huggingface.co/docs/transformers/model_doc/hubert)** (来自 Facebook) 伴随论文 [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) 由 Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 5971ab404917f..a19c9ba23cbf4 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -287,6 +287,7 @@ conda install -c huggingface transformers 1. **[GLPN](https://huggingface.co/docs/transformers/main/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim. 1. **[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. 1. **[GPT Neo](https://huggingface.co/docs/transformers/model_doc/gpt_neo)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. +1. **[GPT NeoX](https://huggingface.co/docs/transformers/main/model_doc/gpt_neox)** (from EleutherAI) released with the paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745) by Sid Black, Stella Biderman, Eric Hallahan, Quentin Anthony, Leo Gao, Laurence Golding, Horace He, Connor Leahy, Kyle McDonell, Jason Phang, Michael Pieler, USVSN Sai Prashanth, Shivanshu Purohit, Laria Reynolds, Jonathan Tow, Ben Wang, Samuel Weinbach 1. **[GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. 1. **[GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj)** (from EleutherAI) released with the paper [kingoflolz/mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax/) by Ben Wang and Aran Komatsuzaki. 1. **[Hubert](https://huggingface.co/docs/transformers/model_doc/hubert)** (from Facebook) released with the paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index cb67299cff4d4..c8384eaaf4dd9 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -284,6 +284,8 @@ title: GPT-J - local: model_doc/gpt_neo title: GPT Neo + - local: model_doc/gpt_neox + title: GPT NeoX - local: model_doc/hubert title: Hubert - local: model_doc/perceiver diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 35d2a99440e79..190cdaa8cc24e 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -95,6 +95,7 @@ The library currently contains JAX, PyTorch and TensorFlow implementations, pret 1. **[GPT-2](model_doc/gpt2)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. 1. **[GPT-J](model_doc/gptj)** (from EleutherAI) released in the repository [kingoflolz/mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax/) by Ben Wang and Aran Komatsuzaki. 1. **[GPT Neo](model_doc/gpt_neo)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. +1. **[GPT NeoX](model_doc/gpt_neox)** (from EleutherAI) released with the paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745) by Sid Black, Stella Biderman, Eric Hallahan, Quentin Anthony, Leo Gao, Laurence Golding, Horace He, Connor Leahy, Kyle McDonell, Jason Phang, Michael Pieler, USVSN Sai Prashanth, Shivanshu Purohit, Laria Reynolds, Jonathan Tow, Ben Wang, Samuel Weinbach 1. **[Hubert](model_doc/hubert)** (from Facebook) released with the paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed. 1. **[I-BERT](model_doc/ibert)** (from Berkeley) released with the paper [I-BERT: Integer-only BERT Quantization](https://arxiv.org/abs/2101.01321) by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer. 1. **[ImageGPT](model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever. @@ -214,6 +215,7 @@ Flax), PyTorch, and/or TensorFlow. | Funnel Transformer | ✅ | ✅ | ✅ | ✅ | ❌ | | GLPN | ❌ | ❌ | ✅ | ❌ | ❌ | | GPT Neo | ❌ | ❌ | ✅ | ❌ | ✅ | +| GPT NeoX | ❌ | ✅ | ✅ | ❌ | ❌ | | GPT-J | ❌ | ❌ | ✅ | ✅ | ✅ | | Hubert | ❌ | ❌ | ✅ | ✅ | ❌ | | I-BERT | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/gpt_neox.mdx b/docs/source/en/model_doc/gpt_neox.mdx new file mode 100644 index 0000000000000..1be8be7a6a5e6 --- /dev/null +++ b/docs/source/en/model_doc/gpt_neox.mdx @@ -0,0 +1,76 @@ + + +# GPT-NeoX + +## Overview + +We introduce GPT-NeoX-20B, a 20 billion parameter autoregressive language model trained on the Pile, whose weights will +be made freely and openly available to the public through a permissive license. It is, to the best of our knowledge, +the largest dense autoregressive model that has publicly available weights at the time of submission. In this work, +we describe GPT-NeoX-20B's architecture and training and evaluate its performance on a range of language-understanding, +mathematics, and knowledge-based tasks. We find that GPT-NeoX-20B is a particularly powerful few-shot reasoner and +gains far more in performance when evaluated five-shot than similarly sized GPT-3 and FairSeq models. We open-source +the training and evaluation code, as well as the model weights, at [https://github.com/EleutherAI/gpt-neox](https://github.com/EleutherAI/gpt-neox). + +Development of the model was led by Sid Black, Stella Biderman and Eric Hallahan, and the model was trained with +generous the support of [CoreWeave](https://www.coreweave.com/). + +GPT-NeoX-20B was trained with fp16, thus it is recommended to initialize the model as follows: + +```python +model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b").half().cuda() +``` + +GPT-NeoX-20B also has a different tokenizer from the one used in GPT-J-6B and GPT-Neo. The new tokenizer allocates +additional tokens to whitespace characters, making the model more suitable for certain tasks like code generation. + +### Generation + +The `generate()` method can be used to generate text using GPT Neo model. + +```python +>>> from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast + +>>> model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b") +>>> tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b") + +>>> prompt = "GPTNeoX20B is a 20B-parameter autoregressive Transformer model developed by EleutherAI." + +>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids + +>>> gen_tokens = model.generate( +... input_ids, +... do_sample=True, +... temperature=0.9, +... max_length=100, +... ) +>>> gen_text = tokenizer.batch_decode(gen_tokens)[0] +``` + +## GPTNeoXConfig + +[[autodoc]] GPTNeoXConfig + +## GPTNeoXTokenizerFast + +[[autodoc]] GPTNeoXTokenizerFast + +## GPTNeoXModel + +[[autodoc]] GPTNeoXModel + - forward + +## GPTNeoXForCausalLM + +[[autodoc]] GPTNeoXForCausalLM + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index aff37abbec6aa..34a57736bd8bb 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -213,6 +213,7 @@ "models.glpn": ["GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP", "GLPNConfig"], "models.gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2Tokenizer"], "models.gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"], + "models.gpt_neox": ["GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoXConfig"], "models.gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig"], "models.herbert": ["HerbertTokenizer"], "models.hubert": ["HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "HubertConfig"], @@ -501,6 +502,7 @@ _import_structure["models.fnet"].append("FNetTokenizerFast") _import_structure["models.funnel"].append("FunnelTokenizerFast") _import_structure["models.gpt2"].append("GPT2TokenizerFast") + _import_structure["models.gpt_neox"].append("GPTNeoXTokenizerFast") _import_structure["models.herbert"].append("HerbertTokenizerFast") _import_structure["models.layoutlm"].append("LayoutLMTokenizerFast") _import_structure["models.layoutlmv2"].append("LayoutLMv2TokenizerFast") @@ -1138,6 +1140,15 @@ "load_tf_weights_in_gpt_neo", ] ) + _import_structure["models.gpt_neox"].extend( + [ + "GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST", + "GPTNeoXForCausalLM", + "GPTNeoXLayer", + "GPTNeoXModel", + "GPTNeoXPreTrainedModel", + ] + ) _import_structure["models.gptj"].extend( [ "GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -2746,6 +2757,7 @@ from .models.glpn import GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP, GLPNConfig from .models.gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2Tokenizer from .models.gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig + from .models.gpt_neox import GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoXConfig from .models.gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig from .models.herbert import HerbertTokenizer from .models.hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig @@ -3001,6 +3013,7 @@ from .models.fnet import FNetTokenizerFast from .models.funnel import FunnelTokenizerFast from .models.gpt2 import GPT2TokenizerFast + from .models.gpt_neox import GPTNeoXTokenizerFast from .models.herbert import HerbertTokenizerFast from .models.layoutlm import LayoutLMTokenizerFast from .models.layoutlmv2 import LayoutLMv2TokenizerFast @@ -3532,6 +3545,13 @@ GPTNeoPreTrainedModel, load_tf_weights_in_gpt_neo, ) + from .models.gpt_neox import ( + GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST, + GPTNeoXForCausalLM, + GPTNeoXLayer, + GPTNeoXModel, + GPTNeoXPreTrainedModel, + ) from .models.gptj import ( GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST, GPTJForCausalLM, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 66910e3e0a53b..560c984923097 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -62,6 +62,7 @@ glpn, gpt2, gpt_neo, + gpt_neox, gptj, herbert, hubert, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index aa4b64fa7015c..1c8779cc64ad1 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -67,6 +67,7 @@ ("glpn", "GLPNConfig"), ("gpt2", "GPT2Config"), ("gpt_neo", "GPTNeoConfig"), + ("gpt_neox", "GPTNeoXConfig"), ("gptj", "GPTJConfig"), ("hubert", "HubertConfig"), ("ibert", "IBertConfig"), @@ -177,6 +178,7 @@ ("glpn", "GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("gpt_neox", "GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -287,6 +289,7 @@ ("glpn", "GLPN"), ("gpt2", "OpenAI GPT-2"), ("gpt_neo", "GPT Neo"), + ("gpt_neox", "GPT NeoX"), ("gptj", "GPT-J"), ("herbert", "HerBERT"), ("hubert", "Hubert"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 1e62a4ab8ed20..454a71463d0d9 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -66,6 +66,7 @@ ("glpn", "GLPNModel"), ("gpt2", "GPT2Model"), ("gpt_neo", "GPTNeoModel"), + ("gpt_neox", "GPTNeoXModel"), ("gptj", "GPTJModel"), ("hubert", "HubertModel"), ("ibert", "IBertModel"), @@ -205,6 +206,7 @@ ("funnel", "FunnelForMaskedLM"), ("gpt2", "GPT2LMHeadModel"), ("gpt_neo", "GPTNeoForCausalLM"), + ("gpt_neox", "GPTNeoXForCausalLM"), ("gptj", "GPTJForCausalLM"), ("ibert", "IBertForMaskedLM"), ("layoutlm", "LayoutLMForMaskedLM"), @@ -253,6 +255,7 @@ ("electra", "ElectraForCausalLM"), ("gpt2", "GPT2LMHeadModel"), ("gpt_neo", "GPTNeoForCausalLM"), + ("gpt_neox", "GPTNeoXForCausalLM"), ("gptj", "GPTJForCausalLM"), ("marian", "MarianForCausalLM"), ("mbart", "MBartForCausalLM"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 8c1ceeee2cdd5..13c8177ba30af 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -125,6 +125,7 @@ ("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)), ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), + ("gpt_neox", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), ("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)), ("hubert", ("Wav2Vec2CTCTokenizer", None)), diff --git a/src/transformers/models/gpt_neox/__init__.py b/src/transformers/models/gpt_neox/__init__.py new file mode 100644 index 0000000000000..ca1d2c1f42c36 --- /dev/null +++ b/src/transformers/models/gpt_neox/__init__.py @@ -0,0 +1,80 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available +from ...utils import OptionalDependencyNotAvailable + + +_import_structure = { + "configuration_gpt_neox": ["GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoXConfig"], +} + +try: + if not is_tokenizers_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["tokenization_gpt_neox_fast"] = ["GPTNeoXTokenizerFast"] + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_gpt_neox"] = [ + "GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST", + "GPTNeoXForCausalLM", + "GPTNeoXLayer", + "GPTNeoXModel", + "GPTNeoXPreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_gpt_neox import GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoXConfig + + try: + if not is_tokenizers_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .tokenization_gpt_neox_fast import GPTNeoXTokenizerFast + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_gpt_neox import ( + GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST, + GPTNeoXForCausalLM, + GPTNeoXLayer, + GPTNeoXModel, + GPTNeoXPreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/gpt_neox/configuration_gpt_neox.py b/src/transformers/models/gpt_neox/configuration_gpt_neox.py new file mode 100644 index 0000000000000..2ffb351f3b592 --- /dev/null +++ b/src/transformers/models/gpt_neox/configuration_gpt_neox.py @@ -0,0 +1,125 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" GPTNeoX model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "eleutherai/gpt-neox-20b": "https://huggingface.co/eleutherai/gpt-neox-20b/resolve/main/config.json", + # See all GPTNeoX models at https://huggingface.co/models?filter=gpt_neox +} + + +class GPTNeoXConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GPTNeoXModel`]. It is used to instantiate an + GPTNeoX model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the GPTNeoX + [eleutherai/gpt-neox-20b](https://huggingface.co/eleutherai/gpt-neox-20b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the GPTNeoX model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GPTNeoXModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimension of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + rotary_pct (`float`, *optional*, defaults to 0.25): + percentage of hidden dimensions to allocate to rotary embeddings + rotary_emb_base (`int`, *optional*, defaults to 10000) + base for computing rotary embeddings frequency + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + Example: + + ```python + >>> from transformers import GPTNeoXModel, GPTNeoXConfig + + >>> # Initializing a GPTNeoX gpt-neox-20b style configuration + >>> configuration = GPTNeoXConfig() + + >>> # Initializing a model from the gpt-neox-20b style configuration + >>> model = GPTNeoXModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "gpt_neox" + + def __init__( + self, + vocab_size=50432, + hidden_size=6144, + num_hidden_layers=44, + num_attention_heads=64, + intermediate_size=24576, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + rotary_pct=0.25, + rotary_emb_base=10000, + max_position_embeddings=2048, + initializer_range=0.02, + layer_norm_eps=1e-5, + use_cache=True, + bos_token_id=0, + eos_token_id=2, + tie_word_embeddings=False, + **kwargs + ): + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.rotary_pct = rotary_pct + self.rotary_emb_base = rotary_emb_base + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.use_cache = use_cache + self.tie_word_embeddings = tie_word_embeddings diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py new file mode 100755 index 0000000000000..6bf5762d6d97d --- /dev/null +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -0,0 +1,650 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch GPTNeoX model.""" + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...file_utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from .configuration_gpt_neox import GPTNeoXConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "gpt-neox-20b" +_CONFIG_FOR_DOC = "GPTNeoXConfig" +_TOKENIZER_FOR_DOC = "GPTNeoXTokenizerFast" + +GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "EleutherAI/gpt-neox-20b", + # See all GPTNeoX models at https://huggingface.co/models?filter=gpt_neox +] + + +class GPTNeoXPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPTNeoXConfig + base_model_prefix = "gpt_neox" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GPTNeoXModel): + module.gradient_checkpointing = value + + +class GPTNeoXAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_attention_heads + self.rotary_ndims = int(self.head_size * config.rotary_pct) + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + 1, 1, max_positions, max_positions + ), + ) + self.register_buffer("masked_bias", torch.tensor(-1e9)) + self.rotary_emb = RotaryEmbedding(self.rotary_ndims, base=config.rotary_emb_base) + self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()) + self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size) + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + + def forward( + self, + hidden_states, + attention_mask, + head_mask=None, + layer_past=None, + use_cache=False, + output_attentions=False, + ): + has_layer_past = layer_past is not None + + # Compute QKV + # Attention heads [batch, seq_len, hidden_size] + # --> [batch, seq_len, (np * 3 * head_size)] + qkv = self.query_key_value(hidden_states) + + # [batch, seq_len, (num_heads * 3 * head_size)] + # --> [batch, seq_len, num_heads, 3 * head_size] + new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size) + qkv = qkv.view(*new_qkv_shape) + + # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] + query = qkv[..., : self.head_size].permute(0, 2, 1, 3) + key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3) + value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3) + + # Compute rotary embeddings on rotary_ndims + query_rot = query[..., : self.rotary_ndims] + query_pass = query[..., self.rotary_ndims :] + key_rot = key[..., : self.rotary_ndims] + key_pass = key[..., self.rotary_ndims :] + + # Compute token offset for rotary embeddings (when decoding) + seq_len = key.shape[-2] + offset = 0 + if has_layer_past: + offset = layer_past[0].shape[-2] + seq_len += offset + cos, sin = self.rotary_emb(value, seq_len=seq_len) + query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset) + query = torch.cat((query, query_pass), dim=-1) + key = torch.cat((key, key_pass), dim=-1) + + # Cache QKV values + if has_layer_past: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + present = None if use_cache else (key, value) + + # Compute attention + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + # Reshape outputs + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) + attn_output = self.dense(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + @classmethod + def _split_heads(cls, tensor, num_attention_heads, attn_head_size): + """ + Splits hidden dim into attn_head_size and num_attention_heads + """ + # tensor: [bs, seq_len, hidden_size] + new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) + # -> [bs, seq_len, num_attention_heads, attn_head_size] + tensor = tensor.view(new_shape) + # -> [bs, num_attention_heads, seq_len, attn_head_size] + tensor = tensor.permute(0, 2, 1, 3) + return tensor + + @classmethod + def _merge_heads(cls, tensor, num_attention_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden dim + """ + # tensor [bs, num_attention_heads, seq_len, attn_head_size] + tensor = tensor.permute(0, 2, 1, 3).contiguous() + # -> [bs, seq_len, num_attention_heads, attn_head_size] + tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size) + # -> [bs, seq_len, hidden_size] + return tensor + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] + # compute causal mask from causal mask buffer + batch_size, num_attention_heads, query_length, attn_head_size = query.size() + key_length = key.size(-2) + + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() + + query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) + key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) + attn_scores = torch.einsum("bik,bjk->bij", query, key) / self.norm_factor + if torch.isnan(attn_scores).any(): + raise RuntimeError() + attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length) + + attn_scores = torch.where(causal_mask, attn_scores, self.masked_bias.to(attn_scores.dtype)) + + if attention_mask is not None: + # Apply the attention mask + attn_scores = attn_scores + attention_mask + + attn_weights = nn.functional.softmax(attn_scores, dim=-1) + if torch.isnan(attn_weights).any(): + raise RuntimeError() + attn_weights = attn_weights.to(value.dtype) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + if torch.isnan(attn_output).any(): + raise RuntimeError() + return attn_output, attn_weights + + +def attention_mask_func(attention_scores, ltor_mask): + attention_scores.masked_fill_(~ltor_mask, -10000.0) + return attention_scores + + +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, base=10000, device=None): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] + return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): + cos = cos[..., offset : q.shape[-2] + offset, :] + sin = sin[..., offset : q.shape[-2] + offset, :] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class GPTNeoXMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size) + self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + hidden_states = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dense_4h_to_h(hidden_states) + return hidden_states + + +class GPTNeoXLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = GPTNeoXAttention(config) + self.mlp = GPTNeoXMLP(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + use_cache=False, + layer_past=None, + output_attentions=False, + ): + residual = hidden_states + ln_out = self.input_layernorm(hidden_states) + attention_layer_outputs = self.attention( + ln_out, + attention_mask=attention_mask, + layer_past=layer_past, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attention_layer_outputs[0] # output_attn: a, present, (attentions) + outputs = attention_layer_outputs[1:] + + mlp_output = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = mlp_output + attn_output + residual + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions) + + +GPT_NEOX_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`~GPTNeoXConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPT_NEOX_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`GPTNeoXTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare GPTNeoX Model transformer outputting raw hidden-states without any specific head on top.", + GPT_NEOX_START_DOCSTRING, +) +class GPTNeoXModel(GPTNeoXPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)]) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_in + + def set_input_embeddings(self, value): + self.embed_in = value + + @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + + if past_key_values is None: + past_key_values = tuple([None] * self.config.num_hidden_layers) + + # Attention mask. + if attention_mask is not None: + assert batch_size > 0, "batch_size has to be defined and > 0" + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if inputs_embeds is None: + inputs_embeds = self.embed_in(input_ids) + + hidden_states = inputs_embeds + + presents = () if use_cache else None + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + outputs = layer( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask[i], + layer_past=layer_past, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + if output_attentions: + all_attentions = all_attentions + (outputs[2 if use_cache else 1],) + + hidden_states = self.final_layer_norm(hidden_states) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_attentions, + ) + + +@add_start_docstrings( + """GPTNeoX Model with a `language modeling` head on top for CLM fine-tuning.""", GPT_NEOX_START_DOCSTRING +) +class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel): + + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.gpt_neox = GPTNeoXModel(config) + self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.embed_out + + def set_output_embeddings(self, new_embeddings): + self.embed_out = new_embeddings + + @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + head_mask=None, + past_key_values=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are + only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see + `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + Returns: + + Example: + + ```python + >>> from transformers import GPTNeoXTokenizer, GPTNeoXForCausalLM, GPTNeoXConfig + >>> import torch + + >>> tokenizer = GPTNeoXTokenizer.from_pretrained("gpt-neox-20b") + >>> config = GPTNeoXConfig.from_pretrained("gpt-neox-20b") + >>> config.is_decoder = True + >>> model = GPTNeoXForCausalLM.from_pretrained("gpt-neox-20b", config=config) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.gpt_neox( + input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + lm_logits = self.embed_out(hidden_states) + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shift_logits = lm_logits[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithPast( + loss=lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past diff --git a/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py b/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py new file mode 100644 index 0000000000000..ce304b0a19719 --- /dev/null +++ b/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py @@ -0,0 +1,142 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for GPTNeoX.""" +import json +from typing import TYPE_CHECKING, List, Optional, Tuple + +from tokenizers import pre_tokenizers + +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging + + +if TYPE_CHECKING: + from transformers.pipelines.conversational import Conversation + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"} + +PRETRAINED_VOCAB_FILES_MAP = { + "tokenizer_file": { + "eleutherai/gpt-neox-20b": "https://huggingface.co/EleutherAI/gpt-neox-20b/resolve/main/tokenizer.json", + }, +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "gpt-neox-20b": 2048, +} + + +class GPTNeoXTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" GPT-NeoX-20B tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level + Byte-Pair-Encoding. + + This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will + be encoded differently whether it is at the beginning of the sentence (without space) or not: + + ``` + >>> from transformers import GPTNeoXTokenizerFast + >>> tokenizer = GPTNeoXTokenizerFast.from_pretrained("gpt2") + >>> tokenizer("Hello world")['input_ids'] + [15496, 995] + >>> tokenizer(" Hello world")['input_ids'] + [18435, 995] + ``` + + You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer, but since + the model was not pretrained this way, it might yield a decrease in performance. + + + + When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`. + + + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + unk_token (`str`, *optional*, defaults to `<|endoftext|>`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str`, *optional*, defaults to `<|endoftext|>`): + The beginning of sequence token. + eos_token (`str`, *optional*, defaults to `<|endoftext|>`): + The end of sequence token. + add_prefix_space (`bool`, *optional*, defaults to `False`): + Whether or not to add an initial space to the input. This allows to treat the leading word just as any + other word. (GPTNeoX tokenizer detect beginning of words by the preceding space). + trim_offsets (`bool`, *optional*, defaults to `True`): + Whether or not the post-processing step should trim offsets to avoid including whitespaces. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file=None, + merges_file=None, + tokenizer_file=None, + unk_token="<|endoftext|>", + bos_token="<|endoftext|>", + eos_token="<|endoftext|>", + add_prefix_space=False, + **kwargs + ): + super().__init__( + vocab_file, + merges_file, + tokenizer_file=tokenizer_file, + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + add_prefix_space=add_prefix_space, + **kwargs, + ) + + pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__()) + if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space: + pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type")) + pre_tok_state["add_prefix_space"] = add_prefix_space + self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state) + + self.add_prefix_space = add_prefix_space + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: + """This corresponds to DialoGPT variants of models.""" + input_ids = [] + for is_user, text in conversation.iter_texts(): + input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id]) + + if len(input_ids) > self.model_max_length: + input_ids = input_ids[-self.model_max_length :] + return input_ids diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 8f7e291beac6a..3fd3016c13fec 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -2135,6 +2135,37 @@ def load_tf_weights_in_gpt_neo(*args, **kwargs): requires_backends(load_tf_weights_in_gpt_neo, ["torch"]) +GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class GPTNeoXForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class GPTNeoXLayer(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class GPTNeoXModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class GPTNeoXPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/dummy_tokenizers_objects.py b/src/transformers/utils/dummy_tokenizers_objects.py index 12cec6a4a2605..a7db0e523b985 100644 --- a/src/transformers/utils/dummy_tokenizers_objects.py +++ b/src/transformers/utils/dummy_tokenizers_objects.py @@ -150,6 +150,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tokenizers"]) +class GPTNeoXTokenizerFast(metaclass=DummyObject): + _backends = ["tokenizers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tokenizers"]) + + class HerbertTokenizerFast(metaclass=DummyObject): _backends = ["tokenizers"] diff --git a/tests/models/gpt_neox/__init__.py b/tests/models/gpt_neox/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py new file mode 100644 index 0000000000000..a4fb95384e83f --- /dev/null +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -0,0 +1,245 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch GPTNeoX model. """ + + +import unittest + +from transformers import GPTNeoXConfig, is_torch_available +from transformers.testing_utils import require_torch, slow, torch_device + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask + + +if is_torch_available(): + import torch + + from transformers import GPTNeoXForCausalLM, GPTNeoXModel + from transformers.models.gpt_neox.modeling_gpt_neox import GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST + + +class GPTNeoXModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + token_labels = None + if self.use_labels: + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + + config = self.get_config() + + return config, input_ids, input_mask, token_labels + + def get_config(self): + return GPTNeoXConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + ) + + def prepare_config_and_inputs_for_decoder(self): + config, input_ids, input_mask, token_labels = self.prepare_config_and_inputs() + + config.is_decoder = True + + return config, input_ids, input_mask, token_labels + + def create_and_check_model(self, config, input_ids, input_mask): + model = GPTNeoXModel(config=config) + model.to(torch_device) + model.eval() + _ = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_model_as_decoder(self, config, input_ids, input_mask): + config.add_cross_attention = True + model = GPTNeoXModel(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_causal_lm(self, config, input_ids, input_mask, token_labels): + model = GPTNeoXForCausalLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_decoder_model_past_large_inputs(self, config, input_ids, input_mask): + config.is_decoder = True + model = GPTNeoXForCausalLM(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + outputs = model(input_ids, attention_mask=input_mask, use_cache=True) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask, output_hidden_states=True) + output_from_no_past = output_from_no_past["hidden_states"][0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + )["hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, input_mask, token_labels = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class GPTNeoXModelTest(ModelTesterMixin, unittest.TestCase): + + all_model_classes = (GPTNeoXModel, GPTNeoXForCausalLM) if is_torch_available() else () + all_generative_model_classes = (GPTNeoXForCausalLM,) if is_torch_available() else () + test_pruning = False + test_missing_keys = False + test_model_parallel = False + test_head_masking = False + + def setUp(self): + self.model_tester = GPTNeoXModelTester(self) + self.config_tester = ConfigTester(self, config_class=GPTNeoXConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(config, input_ids, input_mask) + + def test_model_as_decoder(self): + config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_model_as_decoder(config, input_ids, input_mask) + + def test_model_as_decoder_with_default_input_mask(self): + # This regression test was failing with PyTorch < 1.3 + config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs_for_decoder() + + input_mask = None + + self.model_tester.create_and_check_model_as_decoder(config, input_ids, input_mask) + + @slow + def test_model_from_pretrained(self): + for model_name in GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = GPTNeoXModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +@require_torch +class GPTNeoXModelIntegrationTest(unittest.TestCase): + @slow + def test_inference_masked_lm(self): + model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b") + input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]]) + output = model(input_ids)[0] + + vocab_size = model.config.vocab_size + + expected_shape = torch.Size((1, 6, vocab_size)) + self.assertEqual(output.shape, expected_shape) + + expected_slice = torch.tensor( + [[[33.8045, 2.3958, 34.2816], [63.7805, 4.8332, 63.5882], [66.9116, 5.2198, 63.1185]]] + ) + + self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))