diff --git a/README.md b/README.md
index b831ef600da0a..acbbf838c5558 100644
--- a/README.md
+++ b/README.md
@@ -240,6 +240,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h
1. **[BigBird-Pegasus](https://huggingface.co/docs/transformers/model_doc/bigbird_pegasus)** (from Google Research) released with the paper [Big Bird: Transformers for Longer Sequences](https://arxiv.org/abs/2007.14062) by Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, Amr Ahmed.
1. **[Blenderbot](https://huggingface.co/docs/transformers/model_doc/blenderbot)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
1. **[BlenderbotSmall](https://huggingface.co/docs/transformers/model_doc/blenderbot-small)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
+1. **[BLOOM](https://huggingface.co/docs/transformers/main/model_doc/bloom)** (from BigScience workshop) released by the [BigScience Workshop](https://bigscience.huggingface.co/).
1. **[BORT](https://huggingface.co/docs/transformers/model_doc/bort)** (from Alexa) released with the paper [Optimal Subarchitecture Extraction For BERT](https://arxiv.org/abs/2010.10499) by Adrian de Wynter and Daniel J. Perry.
1. **[ByT5](https://huggingface.co/docs/transformers/model_doc/byt5)** (from Google Research) released with the paper [ByT5: Towards a token-free future with pre-trained byte-to-byte models](https://arxiv.org/abs/2105.13626) by Linting Xue, Aditya Barua, Noah Constant, Rami Al-Rfou, Sharan Narang, Mihir Kale, Adam Roberts, Colin Raffel.
1. **[CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert)** (from Inria/Facebook/Sorbonne) released with the paper [CamemBERT: a Tasty French Language Model](https://arxiv.org/abs/1911.03894) by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot.
diff --git a/README_ko.md b/README_ko.md
index 9a5acdc5b086a..0119328eff725 100644
--- a/README_ko.md
+++ b/README_ko.md
@@ -221,6 +221,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[BigBird-RoBERTa](https://huggingface.co/docs/transformers/model_doc/big_bird)** (from Google Research) released with the paper [Big Bird: Transformers for Longer Sequences](https://arxiv.org/abs/2007.14062) by Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, Amr Ahmed.
1. **[Blenderbot](https://huggingface.co/docs/transformers/model_doc/blenderbot)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
1. **[BlenderbotSmall](https://huggingface.co/docs/transformers/model_doc/blenderbot-small)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
+1. **[BLOOM](https://huggingface.co/docs/transformers/main/model_doc/bloom)** (from BigScience workshop) released by the [BigScience Workshop](https://bigscience.huggingface.co/).
1. **[BORT](https://huggingface.co/docs/transformers/model_doc/bort)** (from Alexa) released with the paper [Optimal Subarchitecture Extraction For BERT](https://arxiv.org/abs/2010.10499) by Adrian de Wynter and Daniel J. Perry.
1. **[ByT5](https://huggingface.co/docs/transformers/model_doc/byt5)** (from Google Research) released with the paper [ByT5: Towards a token-free future with pre-trained byte-to-byte models](https://arxiv.org/abs/2105.13626) by Linting Xue, Aditya Barua, Noah Constant, Rami Al-Rfou, Sharan Narang, Mihir Kale, Adam Roberts, Colin Raffel.
1. **[CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert)** (from Inria/Facebook/Sorbonne) released with the paper [CamemBERT: a Tasty French Language Model](https://arxiv.org/abs/1911.03894) by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot.
diff --git a/README_zh-hans.md b/README_zh-hans.md
index 1c513d10433bd..0ca14ba89ff33 100644
--- a/README_zh-hans.md
+++ b/README_zh-hans.md
@@ -245,6 +245,7 @@ conda install -c huggingface transformers
1. **[BigBird-RoBERTa](https://huggingface.co/docs/transformers/model_doc/big_bird)** (来自 Google Research) 伴随论文 [Big Bird: Transformers for Longer Sequences](https://arxiv.org/abs/2007.14062) 由 Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, Amr Ahmed 发布。
1. **[Blenderbot](https://huggingface.co/docs/transformers/model_doc/blenderbot)** (来自 Facebook) 伴随论文 [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) 由 Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston 发布。
1. **[BlenderbotSmall](https://huggingface.co/docs/transformers/model_doc/blenderbot-small)** (来自 Facebook) 伴随论文 [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) 由 Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston 发布。
+1. **[BLOOM](https://huggingface.co/docs/transformers/main/model_doc/bloom)** (from BigScience workshop) released by the [BigScience Workshop](https://bigscience.huggingface.co/).
1. **[BORT](https://huggingface.co/docs/transformers/model_doc/bort)** (来自 Alexa) 伴随论文 [Optimal Subarchitecture Extraction For BERT](https://arxiv.org/abs/2010.10499) 由 Adrian de Wynter and Daniel J. Perry 发布。
1. **[ByT5](https://huggingface.co/docs/transformers/model_doc/byt5)** (来自 Google Research) 伴随论文 [ByT5: Towards a token-free future with pre-trained byte-to-byte models](https://arxiv.org/abs/2105.13626) 由 Linting Xue, Aditya Barua, Noah Constant, Rami Al-Rfou, Sharan Narang, Mihir Kale, Adam Roberts, Colin Raffel 发布。
1. **[CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert)** (来自 Inria/Facebook/Sorbonne) 伴随论文 [CamemBERT: a Tasty French Language Model](https://arxiv.org/abs/1911.03894) 由 Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot 发布。
diff --git a/README_zh-hant.md b/README_zh-hant.md
index 8276e2f9129b9..c3a8391e74c56 100644
--- a/README_zh-hant.md
+++ b/README_zh-hant.md
@@ -257,6 +257,7 @@ conda install -c huggingface transformers
1. **[BigBird-RoBERTa](https://huggingface.co/docs/transformers/model_doc/big_bird)** (from Google Research) released with the paper [Big Bird: Transformers for Longer Sequences](https://arxiv.org/abs/2007.14062) by Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, Amr Ahmed.
1. **[Blenderbot](https://huggingface.co/docs/transformers/model_doc/blenderbot)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
1. **[BlenderbotSmall](https://huggingface.co/docs/transformers/model_doc/blenderbot-small)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
+1. **[BLOOM](https://huggingface.co/docs/transformers/main/model_doc/bloom)** (from BigScience workshop) released by the [BigScience Workshop](https://bigscience.huggingface.co/).
1. **[BORT](https://huggingface.co/docs/transformers/model_doc/bort)** (from Alexa) released with the paper [Optimal Subarchitecture Extraction For BERT](https://arxiv.org/abs/2010.10499) by Adrian de Wynter and Daniel J. Perry.
1. **[ByT5](https://huggingface.co/docs/transformers/model_doc/byt5)** (from Google Research) released with the paper [ByT5: Towards a token-free future with pre-trained byte-to-byte models](https://arxiv.org/abs/2105.13626) by Linting Xue, Aditya Barua, Noah Constant, Rami Al-Rfou, Sharan Narang, Mihir Kale, Adam Roberts, Colin Raffel.
1. **[CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert)** (from Inria/Facebook/Sorbonne) released with the paper [CamemBERT: a Tasty French Language Model](https://arxiv.org/abs/1911.03894) by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot.
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index cf835c08174fc..26c25872f365d 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -174,6 +174,8 @@
title: Blenderbot
- local: model_doc/blenderbot-small
title: Blenderbot Small
+ - local: model_doc/bloom
+ title: BLOOM
- local: model_doc/bort
title: BORT
- local: model_doc/byt5
diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx
index b4e8d5154a128..8f9235906816f 100644
--- a/docs/source/en/index.mdx
+++ b/docs/source/en/index.mdx
@@ -63,6 +63,7 @@ The library currently contains JAX, PyTorch and TensorFlow implementations, pret
1. **[BigBird-Pegasus](model_doc/bigbird_pegasus)** (from Google Research) released with the paper [Big Bird: Transformers for Longer Sequences](https://arxiv.org/abs/2007.14062) by Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, Amr Ahmed.
1. **[Blenderbot](model_doc/blenderbot)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
1. **[BlenderbotSmall](model_doc/blenderbot-small)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
+1. **[BLOOM](model_doc/bloom)** (from BigScience workshop) released by the [BigScience Workshop](https://bigscience.huggingface.co/).
1. **[BORT](model_doc/bort)** (from Alexa) released with the paper [Optimal Subarchitecture Extraction For BERT](https://arxiv.org/abs/2010.10499) by Adrian de Wynter and Daniel J. Perry.
1. **[ByT5](model_doc/byt5)** (from Google Research) released with the paper [ByT5: Towards a token-free future with pre-trained byte-to-byte models](https://arxiv.org/abs/2105.13626) by Linting Xue, Aditya Barua, Noah Constant, Rami Al-Rfou, Sharan Narang, Mihir Kale, Adam Roberts, Colin Raffel.
1. **[CamemBERT](model_doc/camembert)** (from Inria/Facebook/Sorbonne) released with the paper [CamemBERT: a Tasty French Language Model](https://arxiv.org/abs/1911.03894) by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot.
@@ -189,6 +190,7 @@ Flax), PyTorch, and/or TensorFlow.
| BigBirdPegasus | ❌ | ❌ | ✅ | ❌ | ❌ |
| Blenderbot | ✅ | ✅ | ✅ | ✅ | ✅ |
| BlenderbotSmall | ✅ | ✅ | ✅ | ✅ | ✅ |
+| BLOOM | ❌ | ✅ | ✅ | ❌ | ❌ |
| CamemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| Canine | ✅ | ❌ | ✅ | ❌ | ❌ |
| CLIP | ✅ | ✅ | ✅ | ✅ | ✅ |
diff --git a/docs/source/en/model_doc/bloom.mdx b/docs/source/en/model_doc/bloom.mdx
new file mode 100644
index 0000000000000..649a270525a20
--- /dev/null
+++ b/docs/source/en/model_doc/bloom.mdx
@@ -0,0 +1,47 @@
+
+
+# BLOOM
+
+## Overview
+
+The BLOOM model and its various smaller versions were released through the [BigScience Workshop](https://bigscience.huggingface.co/). BigScience is inspired by other open-science initiatives where researchers have pooled their time and resources to collectively achieve a higher impact.
+The architecture of BLOOM is essentially similar to GPT-3 (an auto-regressive model trained for next-token prediction), but it has been trained on 46 different languages, including code.
+Several smaller versions of the model have been trained on the same dataset. BLOOM is available in the following versions (a quick usage sketch follows the list):
+
+- [bloom-350m](https://huggingface.co/bigscience/bloom-350m)
+- [bloom-760m](https://huggingface.co/bigscience/bloom-760m)
+- [bloom-1b3](https://huggingface.co/bigscience/bloom-1b3)
+- [bloom-2b5](https://huggingface.co/bigscience/bloom-2b5)
+- [bloom-6b3](https://huggingface.co/bigscience/bloom-6b3)
+- [bloom](https://huggingface.co/bigscience/bloom) (175B parameters)
+
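+As a quick sketch (assuming a PyTorch install and the smallest checkpoint above), a BLOOM checkpoint can be used for text generation like any other causal language model in the library:
+
+```python
+from transformers import AutoTokenizer, BloomForCausalLM
+
+tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-350m")
+model = BloomForCausalLM.from_pretrained("bigscience/bloom-350m")
+
+inputs = tokenizer("BigScience is a collaborative", return_tensors="pt")
+outputs = model.generate(**inputs, max_new_tokens=20)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```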
+
+## BloomConfig
+
+[[autodoc]] BloomConfig
+ - all
+
+## BloomModel
+
+[[autodoc]] BloomModel
+ - forward
+
+## BloomTokenizerFast
+
+[[autodoc]] BloomTokenizerFast
+ - all
+
+## BloomForCausalLM
+
+[[autodoc]] BloomForCausalLM
+ - forward
\ No newline at end of file
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 0afe8588d6586..f0290751ade40 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -156,6 +156,7 @@
"BlenderbotSmallConfig",
"BlenderbotSmallTokenizer",
],
+ "models.bloom": ["BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP", "BloomConfig"],
"models.bort": [],
"models.byt5": ["ByT5Tokenizer"],
"models.camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig"],
@@ -495,6 +496,7 @@
_import_structure["models.big_bird"].append("BigBirdTokenizerFast")
_import_structure["models.blenderbot"].append("BlenderbotTokenizerFast")
_import_structure["models.blenderbot_small"].append("BlenderbotSmallTokenizerFast")
+ _import_structure["models.bloom"].append("BloomTokenizerFast")
_import_structure["models.camembert"].append("CamembertTokenizerFast")
_import_structure["models.clip"].append("CLIPTokenizerFast")
_import_structure["models.convbert"].append("ConvBertTokenizerFast")
@@ -853,6 +855,14 @@
"BigBirdPegasusPreTrainedModel",
]
)
+ _import_structure["models.bloom"].extend(
+ [
+ "BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "BloomForCausalLM",
+ "BloomModel",
+ "BloomPreTrainedModel",
+ ]
+ )
_import_structure["models.blenderbot"].extend(
[
"BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2718,6 +2728,7 @@
BlenderbotSmallConfig,
BlenderbotSmallTokenizer,
)
+ from .models.bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig
from .models.byt5 import ByT5Tokenizer
from .models.camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from .models.canine import CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP, CanineConfig, CanineTokenizer
@@ -3025,6 +3036,7 @@
from .models.big_bird import BigBirdTokenizerFast
from .models.blenderbot import BlenderbotTokenizerFast
from .models.blenderbot_small import BlenderbotSmallTokenizerFast
+ from .models.bloom import BloomTokenizerFast
from .models.camembert import CamembertTokenizerFast
from .models.clip import CLIPTokenizerFast
from .models.convbert import ConvBertTokenizerFast
@@ -3340,6 +3352,12 @@
BlenderbotSmallModel,
BlenderbotSmallPreTrainedModel,
)
+ from .models.bloom import (
+ BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST,
+ BloomForCausalLM,
+ BloomModel,
+ BloomPreTrainedModel,
+ )
from .models.camembert import (
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
CamembertForCausalLM,
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index e435265a8378d..92dd515c9c77e 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -31,6 +31,7 @@
bigbird_pegasus,
blenderbot,
blenderbot_small,
+ bloom,
bort,
byt5,
camembert,
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index dbb19c55aa97d..71b0eed3f969d 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -38,6 +38,7 @@
("bigbird_pegasus", "BigBirdPegasusConfig"),
("blenderbot", "BlenderbotConfig"),
("blenderbot-small", "BlenderbotSmallConfig"),
+ ("bloom", "BloomConfig"),
("camembert", "CamembertConfig"),
("canine", "CanineConfig"),
("clip", "CLIPConfig"),
@@ -51,7 +52,6 @@
("deberta", "DebertaConfig"),
("deberta-v2", "DebertaV2Config"),
("decision_transformer", "DecisionTransformerConfig"),
- ("decision_transformer", "DecisionTransformerConfig"),
("deit", "DeiTConfig"),
("detr", "DetrConfig"),
("distilbert", "DistilBertConfig"),
@@ -153,6 +153,7 @@
("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("blenderbot", "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("blenderbot-small", "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+ ("bloom", "BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("camembert", "CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -258,6 +259,7 @@
("bigbird_pegasus", "BigBirdPegasus"),
("blenderbot", "Blenderbot"),
("blenderbot-small", "BlenderbotSmall"),
+ ("bloom", "BLOOM"),
("bort", "BORT"),
("byt5", "ByT5"),
("camembert", "CamemBERT"),
@@ -356,7 +358,6 @@
("van", "VAN"),
("vilt", "ViLT"),
("vision-encoder-decoder", "Vision Encoder decoder"),
- ("vision-encoder-decoder", "Vision Encoder decoder"),
("vision-text-dual-encoder", "VisionTextDualEncoder"),
("visual_bert", "VisualBert"),
("vit", "ViT"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index be7dc5bc9e88d..d7c8b746fed78 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -37,6 +37,7 @@
("bigbird_pegasus", "BigBirdPegasusModel"),
("blenderbot", "BlenderbotModel"),
("blenderbot-small", "BlenderbotSmallModel"),
+ ("bloom", "BloomModel"),
("camembert", "CamembertModel"),
("canine", "CanineModel"),
("clip", "CLIPModel"),
@@ -50,7 +51,6 @@
("deberta", "DebertaModel"),
("deberta-v2", "DebertaV2Model"),
("decision_transformer", "DecisionTransformerModel"),
- ("decision_transformer", "DecisionTransformerModel"),
("decision_transformer_gpt2", "DecisionTransformerGPT2Model"),
("deit", "DeiTModel"),
("detr", "DetrModel"),
@@ -142,6 +142,7 @@
("bart", "BartForConditionalGeneration"),
("bert", "BertForPreTraining"),
("big_bird", "BigBirdForPreTraining"),
+ ("bloom", "BloomForCausalLM"),
("camembert", "CamembertForMaskedLM"),
("ctrl", "CTRLLMHeadModel"),
("data2vec-text", "Data2VecTextForMaskedLM"),
@@ -192,6 +193,7 @@
("big_bird", "BigBirdForMaskedLM"),
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
+ ("bloom", "BloomForCausalLM"),
("camembert", "CamembertForMaskedLM"),
("convbert", "ConvBertForMaskedLM"),
("ctrl", "CTRLLMHeadModel"),
@@ -250,6 +252,7 @@
("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
("blenderbot", "BlenderbotForCausalLM"),
("blenderbot-small", "BlenderbotSmallForCausalLM"),
+ ("bloom", "BloomForCausalLM"),
("camembert", "CamembertForCausalLM"),
("ctrl", "CTRLLMHeadModel"),
("data2vec-text", "Data2VecTextForCausalLM"),
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index 840273aa867e9..5980eed726232 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -76,6 +76,7 @@
("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
("blenderbot", ("BlenderbotTokenizer", "BlenderbotTokenizerFast")),
("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
+ ("bloom", (None, "BloomTokenizerFast" if is_tokenizers_available() else None)),
("byt5", ("ByT5Tokenizer", None)),
(
"camembert",
diff --git a/src/transformers/models/bloom/__init__.py b/src/transformers/models/bloom/__init__.py
new file mode 100644
index 0000000000000..de67c7e387cdb
--- /dev/null
+++ b/src/transformers/models/bloom/__init__.py
@@ -0,0 +1,78 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
+
+
+_import_structure = {
+ "configuration_bloom": [
+ "BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "BloomConfig",
+ ],
+}
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["tokenization_bloom_fast"] = ["BloomTokenizerFast"]
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_bloom"] = [
+ "BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "BloomForCausalLM",
+ "BloomModel",
+ "BloomPreTrainedModel",
+ ]
+
+if TYPE_CHECKING:
+ from .configuration_bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig
+
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .tokenization_bloom_fast import BloomTokenizerFast
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_bloom import (
+ BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST,
+ BloomForCausalLM,
+ BloomModel,
+ BloomPreTrainedModel,
+ )
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/bloom/configuration_bloom.py b/src/transformers/models/bloom/configuration_bloom.py
new file mode 100644
index 0000000000000..f841d66699657
--- /dev/null
+++ b/src/transformers/models/bloom/configuration_bloom.py
@@ -0,0 +1,155 @@
+# coding=utf-8
+# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Bloom configuration"""
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "bigscience/bloom": "https://huggingface.co/bigscience/bloom/resolve/main/config.json",
+ "bigscience/bloom-350m": "https://huggingface.co/bigscience/bloom-350m/blob/main/config.json",
+ "bigscience/bloom-760m": "https://huggingface.co/bigscience/bloom-760m/blob/main/config.json",
+ "bigscience/bloom-1b3": "https://huggingface.co/bigscience/bloom-1b3/blob/main/config.json",
+ "bigscience/bloom-2b5": "https://huggingface.co/bigscience/bloom-2b5/blob/main/config.json",
+ "bigscience/bloom-6b3": "https://huggingface.co/bigscience/bloom-6b3/blob/main/config.json",
+}
+
+
+class BloomConfig(PretrainedConfig):
+ """
+ This is the configuration class to store the configuration of a [`BloomModel`]. It is used to instantiate a Bloom
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to the Bloom architecture
+ [bigscience/bloom](https://huggingface.co/bigscience/bloom).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 250880):
+            Vocabulary size of the Bloom model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`BloomModel`].
+        hidden_size (`int`, *optional*, defaults to 64):
+            Dimensionality of the embeddings and hidden states.
+        n_layer (`int`, *optional*, defaults to 2):
+            Number of hidden layers in the Transformer.
+        n_head (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer.
+        masked_softmax_fusion (`bool`, *optional*, defaults to `True`):
+            Whether to use the fused masked softmax (kept for compatibility with the original Megatron configuration).
+        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
+            The epsilon to use in the layer normalization layers.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        use_cache (`bool`, *optional*, defaults to `False`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        bos_token_id (`int`, *optional*, defaults to 1):
+            Id of the beginning-of-sequence token.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            Id of the end-of-sequence token.
+        apply_residual_connection_post_layernorm (`bool`, *optional*, defaults to `False`):
+            If enabled, use the layer norm of the hidden states as the residual in the transformer blocks.
+        hidden_dropout (`float`, *optional*, defaults to 0.0):
+            Dropout rate of the dropout function on the bias dropout.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            Dropout rate applied to the attention probabilities.
+        attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`):
+            If set to `True` and the `dtype` is set to `float16`, the input of the softmax function is scaled to
+            `fp32`.
+        pretraining_tp (`int`, *optional*, defaults to 1):
+            Experimental feature. Tensor parallelism rank used during pretraining with Megatron. Please refer to [this
+            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
+            necessary to ensure exact reproducibility of the pretraining results. Please also refer to [this
+            issue](https://github.com/pytorch/pytorch/issues/76232). Note that this is only used when
+            `slow_but_exact=True`.
+        dtype (`str`, *optional*, defaults to `"bfloat16"`):
+            Precision that has been used for the model's training in Megatron. Please load the model in the correct
+            precision by doing `model = BloomModel.from_pretrained(model_name, torch_dtype="auto")`.
+        slow_but_exact (`bool`, *optional*, defaults to `False`):
+            Experimental feature. Whether to use a slow but exact implementation of the attention mechanism. While
+            merging the TP rank tensors, due to slicing operations the results may be slightly different between the
+            model trained on Megatron and our model. Please refer to [this
+            issue](https://github.com/pytorch/pytorch/issues/76232). A solution to obtain more accurate results is to
+            enable this feature. Enabling it will slow down inference. This will probably be resolved in the future
+            once the main model has been fine-tuned with TP_rank=1.
+
+ Example:
+
+ ```python
+ >>> from transformers import BloomModel, BloomConfig
+
+ >>> # Initializing a Bloom configuration
+ >>> configuration = BloomConfig()
+
+ >>> # Initializing a model from the configuration
+ >>> model = BloomModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "bloom"
+ keys_to_ignore_at_inference = ["past_key_values"]
+    attribute_map = {
+        "num_hidden_layers": "n_layer",
+        "num_attention_heads": "n_head",
+        "hidden_size": "n_embed",
+        "dtype": "torch_dtype",
+    }
+
+ def __init__(
+ self,
+ vocab_size=250880,
+ hidden_size=64,
+ n_layer=2,
+ n_head=8,
+ masked_softmax_fusion=True,
+ layer_norm_epsilon=1e-5,
+ initializer_range=0.02,
+ use_cache=False,
+ bos_token_id=1,
+ eos_token_id=2,
+ apply_residual_connection_post_layernorm=False,
+ hidden_dropout=0.0,
+ attention_dropout=0.0,
+ attention_softmax_in_fp32=True,
+ pretraining_tp=1, # TP rank used when training with megatron
+ dtype="bfloat16",
+ slow_but_exact=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.n_layer = n_layer
+ self.n_head = n_head
+ self.masked_softmax_fusion = masked_softmax_fusion
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.initializer_range = initializer_range
+ self.use_cache = use_cache
+ self.pretraining_tp = pretraining_tp
+ self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
+ self.hidden_dropout = hidden_dropout
+ self.attention_dropout = attention_dropout
+ self.attention_softmax_in_fp32 = attention_softmax_in_fp32
+
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+ self.dtype = dtype
+ self.slow_but_exact = slow_but_exact
+
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
diff --git a/src/transformers/models/bloom/convert_bloom_original_checkpoint_to_pytorch.py b/src/transformers/models/bloom/convert_bloom_original_checkpoint_to_pytorch.py
new file mode 100644
index 0000000000000..c9cfeed6dc427
--- /dev/null
+++ b/src/transformers/models/bloom/convert_bloom_original_checkpoint_to_pytorch.py
@@ -0,0 +1,253 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert BigScience BLOOM checkpoint."""
+
+
+import argparse
+import json
+import os
+import re
+
+import torch
+
+from transformers import BloomConfig, BloomModel
+from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+
+WEIGHTS_TO_AVERAGE_ENDSWITH = [
+ "word_embeddings_layernorm.weight",
+ "word_embeddings_layernorm.bias",
+ "input_layernorm.weight",
+ "input_layernorm.bias",
+ "post_attention_layernorm.weight",
+ "post_attention_layernorm.bias",
+ "self_attention.dense.bias",
+ "mlp.dense_4h_to_h.bias",
+ "ln_f.weight",
+ "ln_f.bias",
+]
+
+WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN = [
+ "mlp.dense_4h_to_h.weight",
+ "self_attention.dense.weight",
+]
+
+
+def layer_name_mapping(key, file):
+ """Convert Megatron-DeepSpeed TP/PP weights mapping in transformers PP only"""
+ # Handle first and last layers
+ layer_rename_map = {
+ "word_embeddings.weight": "word_embeddings.weight",
+ "word_embeddings.norm.weight": "word_embeddings_layernorm.weight",
+ "word_embeddings.norm.bias": "word_embeddings_layernorm.bias",
+ "weight": "ln_f.weight",
+ "bias": "ln_f.bias",
+ }
+
+ if key in layer_rename_map:
+ return layer_rename_map[key]
+
+ # Handle transformer blocks
+ layer_number = int(re.match(r".*layer_(\d*).*", file)[1])
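+    # The first checkpoint files of a Megatron-DeepSpeed BLOOM checkpoint hold the embeddings and their
+    # layernorm, so transformer-block files start at `layer_03`; shift the index so that transformers'
+    # `h.<n>` numbering starts at 0 (assumption based on the BLOOM checkpoint layout).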
+ layer_number -= 3
+ return f"h.{layer_number}." + key
+
+
+def get_dtype_size(dtype):
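+    # Size in bytes of one element of `dtype`, parsed from the dtype's name
+    # (e.g. torch.float16 -> 16 bits -> 2 bytes); torch.bool is counted as a single bit.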
+ if dtype == torch.bool:
+ return 1 / 8
+    bit_search = re.search(r"[^\d](\d+)$", str(dtype))
+ if bit_search is None:
+ raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
+ bit_size = int(bit_search.groups()[0])
+ return bit_size // 8
+
+
+def convert_bloom_checkpoint_to_pytorch(
+ bloom_checkpoint_path, bloom_config_file, pytorch_dump_folder_path, shard_model, pretraining_tp
+):
+ # Construct model
+ if bloom_config_file == "":
+ config = BloomConfig()
+ else:
+ config = BloomConfig.from_json_file(bloom_config_file)
+
+ if shard_model:
+ file_names = os.listdir(bloom_checkpoint_path)
+        file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))
+
+ index_dict = {"weight_map": {}, "metadata": {}}
+ total_size = 0
+
+ missing_keys = None
+
+ config = BloomConfig()
+
+ for j, file in enumerate(file_names):
+ print("Processing file: {}".format(file))
+ tensors = None
+
+ for i in range(pretraining_tp):
+ # load all TP files
+ f_name = file.replace("model_00", f"model_0{i}")
+ temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu")
+
+ # Rename keys in the transformers names
+ keys = list(temp.keys())
+ for key in keys:
+ temp[layer_name_mapping(key, file)] = temp.pop(key)
+
+ if tensors is None:
+ tensors = temp
+ else:
+ for key in tensors.keys():
+ if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
+                            # We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
+ tensors[key] += temp[key]
+ else:
+ # Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
+ cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
+                            # We concatenate these weights across TP ranks
+ tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)
+
+            # Divide the averaged weights by the number of TP ranks
+ for key in tensors.keys():
+ if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
+ tensors[key] = tensors[key] / pretraining_tp
+ torch.save(
+ tensors,
+ os.path.join(
+ pytorch_dump_folder_path,
+ "pytorch_model_{}-of-{}.bin".format(str(j + 1).zfill(5), str(len(file_names)).zfill(5)),
+ ),
+ )
+
+ for key in tensors.keys():
+ value = tensors[key]
+ total_size += value.numel() * get_dtype_size(value.dtype)
+ if key not in index_dict["weight_map"]:
+ index_dict["weight_map"][key] = "pytorch_model_{}-of-{}.bin".format(
+ str(j + 1).zfill(5), str(len(file_names)).zfill(5)
+ )
+
+ config = BloomConfig()
+ pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
+ index_dict["metadata"]["total_size"] = total_size
+ with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
+ f.write(config.to_json_string())
+ with open(os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME + ".index.json"), "w", encoding="utf-8") as f:
+ json_config = json.dumps(index_dict, indent=2, sort_keys=True) + "\n"
+ f.write(json_config)
+ else:
+ model = BloomModel(config)
+
+ file_names = os.listdir(bloom_checkpoint_path)
+        file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))
+
+ missing_keys = None
+        for file in file_names:
+ tensors = None
+ for i in range(pretraining_tp):
+ # load all TP files
+ f_name = file.replace("model_00", f"model_0{i}")
+ temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu")
+
+ # Rename keys in the transformers names
+ keys = list(temp.keys())
+ for key in keys:
+ temp[layer_name_mapping(key, file)] = temp.pop(key)
+
+ if tensors is None:
+ tensors = temp
+ else:
+ for key in tensors.keys():
+                        # We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
+ if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
+ tensors[key] += temp[key]
+ else:
+ # Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
+ cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
+                            # We concatenate these weights across TP ranks
+ tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)
+
+            # Divide the averaged weights by the number of TP ranks
+ for key in tensors.keys():
+ if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
+ tensors[key] = tensors[key] / pretraining_tp
+
+ other_keys = model.load_state_dict(tensors, strict=False)
+ assert not other_keys.unexpected_keys
+ if missing_keys is None:
+ missing_keys = set(other_keys.missing_keys)
+ else:
+ missing_keys = missing_keys.intersection(set(other_keys.missing_keys))
+
+ assert not missing_keys
+
+ # Save pytorch-model
+ os.makedirs(pytorch_dump_folder_path, exist_ok=True)
+ pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
+ pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
+ print(f"Save PyTorch model to {pytorch_weights_dump_path}")
+ torch.save(model.state_dict(), pytorch_weights_dump_path)
+ print(f"Save configuration file to {pytorch_config_dump_path}")
+ with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
+ f.write(config.to_json_string())
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--bloom_checkpoint_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Path to the Megatron-LM checkpoint path.",
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
+ )
+ parser.add_argument(
+ "--bloom_config_file",
+ default="",
+ type=str,
+ help=(
+ "An optional config json file corresponding to the pre-trained model. \n"
+ "This specifies the model architecture."
+ ),
+ )
+ parser.add_argument(
+ "--shard_model",
+ action="store_true",
+ help="An optional setting to shard the output model \nThis enables sharding the converted checkpoint",
+ )
+ parser.add_argument(
+ "--pretraining_tp",
+ default=4,
+ type=int,
+ help="Pretraining TP rank that has been used when training the model in Megatron-LM \n",
+ )
+ args = parser.parse_args()
+ convert_bloom_checkpoint_to_pytorch(
+ args.bloom_checkpoint_path,
+ args.bloom_config_file,
+ args.pytorch_dump_folder_path,
+ args.shard_model,
+ args.pretraining_tp,
+ )
diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
new file mode 100644
index 0000000000000..232bbc5f22e31
--- /dev/null
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -0,0 +1,961 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch BLOOM model."""
+
+import math
+from typing import Tuple
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss, LayerNorm
+
+from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
+from ...modeling_utils import PreTrainedModel
+from ...utils import logging
+from .configuration_bloom import BloomConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "bigscience/Bloom"
+_CONFIG_FOR_DOC = "BloomConfig"
+_TOKENIZER_FOR_DOC = "BloomTokenizer"
+
+BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "bigscience/bigscience-small-testing",
+ "bigscience/bloom-350m",
+ "bigscience/bloom-760m",
+ "bigscience/bloom-1b3",
+ "bigscience/bloom-2b5",
+ "bigscience/bloom-6b3",
+ "bigscience/bloom-176b",
+]
+
+
+def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
+ """Split a tensor along its last dimension.
+
+ Args:
+ tensor: ([`torch.tensor`], *required*):
+ input tensor to split
+ num_partitions ([`int`], *required*):
+ number of partitions to split the tensor
+        contiguous_split_chunks ([`bool`], *optional*, default=`False`):
+ If True, make each chunk contiguous in memory.
+ """
+ # Get the size and dimension.
+ last_dim = tensor.dim() - 1
+ numerator, denominator = tensor.size()[last_dim], num_partitions
+ if not (numerator % denominator == 0):
+ raise ValueError(f"{numerator} is not divisible by {denominator}")
+ last_dim_size = numerator // denominator
+ # Split.
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+ # Note: torch.split does not create contiguous tensors by default.
+ if contiguous_split_chunks:
+ return tuple(chunk.contiguous() for chunk in tensor_list)
+
+ return tensor_list
+
+
+def attention_mask_func(attention_scores, attention_mask, causal_mask):
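+    # Combine the padding mask with the causal mask, then fill the masked positions of
+    # `attention_scores` in place with a large negative value so they vanish in the softmax.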
+ if attention_mask.dtype == torch.bool:
+ attention_mask_bool = ~attention_mask
+ else:
+ attention_mask_bool = (1 - attention_mask).bool()
+
+ query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
+ padded_causal_mask = (
+ attention_mask_bool[:, None, key_length - query_length : key_length, None]
+ + ~causal_mask[:, :, key_length - query_length : key_length, :key_length]
+ ).bool()
+ padded_causal_mask = padded_causal_mask + attention_mask_bool[:, None, None, :key_length].bool()
+ # Make use of floats
+ return (
+ attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
+ padded_causal_mask,
+ )
+
+
+def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16):
+ """
+    Build the ALiBi tensor (link to paper: https://arxiv.org/abs/2108.12409). Note that the alibi tensor built here is
+    not causal as the original paper describes; it relies instead on a translation invariance of softmax for a quick
+    implementation: with l being a tensor and a a fixed value, `softmax(l+a) = softmax(l)`. Based on
+    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
+
+    Args:
+        max_seq_len: (`int`, *required*):
+            max sequence length
+        n_head: (`int`, *required*):
+            number of heads
+        dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`):
+            dtype of the output tensor
+
+    Returns:
+        A tensor of shape (n_head, 1, max_seq_len).
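+
+    Example:
+
+    ```python
+    >>> # shape check for a toy setting: 2 heads, maximum sequence length 4
+    >>> build_alibi_tensor(4, 2).shape
+    torch.Size([2, 1, 4])
+    ```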
+ """
+
+ def get_slopes(n):
+ def get_slopes_power_of_2(n):
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+ ratio = start
+ return [start * ratio**i for i in range(n)]
+
+ if math.log2(n).is_integer():
+ return get_slopes_power_of_2(n)
+ else:
+ closest_power_of_2 = 2 ** math.floor(math.log2(n))
+ return (
+ get_slopes_power_of_2(closest_power_of_2)
+ + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
+ )
+
+ slopes = torch.Tensor(get_slopes(n_head)).unsqueeze(1).unsqueeze(1)
+ arange_tensor = torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0)
+ alibi = slopes * arange_tensor.expand(n_head, -1, -1)
+
+ alibi = alibi.to(dtype)
+
+ return alibi
+
+
+def pre_process_alibi_for_pad(alibi, attention_mask, num_heads):
+ """
+    Pre-process the alibi tensor for padding.
+
+    Args:
+        alibi: ([`torch.tensor`], *required*):
+            alibi tensor to pre-process
+        attention_mask: ([`torch.tensor`], *required*):
+            attention mask to pre-process
+        num_heads: ([`int`], *required*):
+            number of attention heads
+    """
+
+    # Sanity check: the attention mask may cover a different number of tokens than the alibi tensor.
+    # This usually happens when the inference is done with past_key_values.
+    # In this case we re-create the alibi tensor with the correct sequence length.
+ if attention_mask.shape[-1] != alibi.shape[-1]:
+ alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.dtype).repeat(
+ attention_mask.shape[0], 1, 1
+ )
+ # Get the indexes of the padding tokens
+ index_x0, index_y0 = torch.where(attention_mask == 0.0)
+ index_x1, index_y1 = torch.where(attention_mask == 1.0)
+
+    # Get a reference alibi tensor - we can detach it because the alibi values are not learned
+ slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.dtype)
+
+    # Loop over the batch entries that contain padding and replace their alibi slice with the reference
+    # tensor at the non-padded positions; padded positions are set to zero.
+    # This operation can be seen as shifting the alibi values past the padding.
+ for i, index in enumerate(torch.unique(index_x0)):
+ slice_to_modify = torch.zeros_like(slice_reference_alibi)
+ index_shift = index_y1[index_x1 == index]
+ shift_value = len(index_shift)
+ slice_to_modify[:, :, index_shift] = slice_reference_alibi[:, :, :shift_value]
+ alibi[index * num_heads : (index + 1) * num_heads] = slice_to_modify
+ return alibi
+
+
+def dropout_add(x, residual, prob, training):
+ """
+ Dropout add function
+
+ Args:
+ x (`torch.tensor`, *required*):
+ input tensor
+        residual (`torch.tensor`, *required*):
+            residual tensor
+ prob (`float`, *required*):
+ dropout probability
+ training (`bool`, *required*):
+ training mode
+ """
+ out = nn.functional.dropout(x, p=prob, training=training)
+ out = residual + out
+ return out
+
+
+def bloom_gelu_forward(x):
+ """
+    Custom GELU function adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to
+    make the model jittable.
+
+ Args:
+ x (`torch.tensor`, *required*):
+ input hidden states
+ """
+ return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
+
+
+def bloom_gelu_back(g, x):
+ """
+    Gradient of the tanh approximation of GELU. The gradient of the exact GELU is:
+    0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
+
+ Args:
+ g (`torch.tensor`, *required*):
+ gradient output tensor
+ x (`torch.tensor`, *required*):
+ input tensor
+ """
+    x = x[0]  # x is a tuple of 1 element; we need to unpack it first
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
+ ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
+ return ff * g
+
+
+class GeLUFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input):
+ ctx.save_for_backward(input)
+ return bloom_gelu_forward(input)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input = ctx.saved_tensors
+ tmp = bloom_gelu_back(grad_output, input)
+ return tmp
+
+
+class BloomGelu(nn.Module):
+ """
+    BloomGelu wrapper that uses the simple forward function in inference mode, to keep the model torchscriptable, and
+    the autograd function in training mode, to get accurate gradients. Partly copied from Megatron-DeepSpeed code and
+    adapted for our needs.
+
+    See here for why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
+
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ if self.training:
+ return GeLUFunction.apply(x)
+ else:
+ return bloom_gelu_forward(x)
+
+
+class BloomScaledSoftmax(nn.Module):
+ """
+    Fused operation: scaling + mask + softmax.
+
+    Args:
+        scaled_masked_softmax_fusion (`bool`, *required*):
+            flag to indicate whether to use the fused softmax kernel
+        mask_func (`function`, *required*):
+            mask function to be applied
+        softmax_in_fp32 (`bool`, *required*):
+            if `True`, softmax is performed in fp32 precision
+        scale (`float`, *required*):
+            scaling factor used in input tensor scaling
+ """
+
+ def __init__(self, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale):
+ super().__init__()
+ self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
+ self.mask_func = mask_func
+ self.softmax_in_fp32 = softmax_in_fp32
+ self.scale = scale
+
+ if not (self.scale is None or softmax_in_fp32):
+ raise ValueError("softmax should be in fp32 when scaled")
+
+ def forward(self, input, mask, max_positions):
+ input_dtype = input.dtype
+ input_in_16bit = input_dtype in [torch.float16, torch.bfloat16]
+ softmax_dtype = torch.float32 if self.softmax_in_fp32 else input_dtype
+
+ if self.scale is not None:
+ input = input * self.scale
+
+ if mask is not None:
+ mask = mask.to(input.device)
+ causal_mask = (
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
+ .view(1, 1, max_positions, max_positions)
+ .to(input.device)
+ )
+ mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
+ probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
+ else:
+ probs = nn.functional.softmax(input, dim=-1, dtype=softmax_dtype)
+
+ if input_in_16bit and self.softmax_in_fp32:
+ probs = probs.to(dtype=input_dtype)
+
+ return probs
+
+
+class BloomAttention(nn.Module):
+ def __init__(self, config, layer_number=None):
+ super().__init__()
+
+ self.pretraining_tp = config.pretraining_tp
+ self.slow_but_exact = config.slow_but_exact
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.n_head
+ self.head_dim = self.hidden_size // self.num_heads
+ self.split_size = self.hidden_size
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
+ self.masked_softmax_fusion = config.masked_softmax_fusion
+ self.hidden_dropout = config.hidden_dropout
+
+ if self.head_dim * self.num_heads != self.hidden_size:
+ raise ValueError(
+ f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+
+ # Layer-wise attention scaling
+ self.layer_number = max(1, layer_number)
+ self.norm_factor = math.sqrt(self.head_dim) * self.layer_number
+
+ # Scaled Softmax
+ self.scale_mask_softmax = BloomScaledSoftmax(
+ self.masked_softmax_fusion,
+ attention_mask_func,
+ self.attention_softmax_in_fp32,
+ self.layer_number,
+ )
+
+ self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
+ self.dense = nn.Linear(self.hidden_size, self.hidden_size)
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ layer_past=None,
+ attention_mask=None,
+ alibi=None,
+ head_mask=None,
+ use_cache=False,
+ output_attentions=False,
+ ):
+ # hidden_states: [batch_size, seq_length, hidden_size]
+ # repeat alibi tensor with the batch size
+ alibi = alibi.repeat(hidden_states.shape[0], 1, 1).to(hidden_states.device)
+
+ # apply preprocessing if the input is padded
+ if attention_mask is not None and 0 in attention_mask:
+ alibi = pre_process_alibi_for_pad(alibi, attention_mask, self.num_heads)
+
+ mixed_x_layer = self.query_key_value(hidden_states)
+
+ # [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
+ new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+
+ # [batch_size, seq_length, num_heads, 3 x head_dim] --> 3 [batch_size, seq_length, num_heads, head_dim]
+ (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
+
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
+ value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)
+
+ if use_cache is True:
+ present = (key_layer, value_layer)
+ else:
+ present = None
+
+        # [batch_size, num_heads, q_length, k_length]
+ output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1))
+
+ # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
+ query_layer = query_layer.transpose(1, 0).reshape(output_size[2], output_size[0] * output_size[1], -1)
+
+ # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
+ key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)
+
+        # slice the alibi tensor up to the key length
+ sliced_alibi = alibi[: output_size[0] * output_size[1], :, : output_size[3]]
+
+ # Raw attention scores. [batch_size * num_heads, q_length, k_length]
+ beta = 1.0 / self.layer_number
+
+ matmul_result = torch.baddbmm(
+ sliced_alibi,
+ query_layer.transpose(1, 0),
+ key_layer.transpose(1, 0).transpose(1, 2),
+ beta=beta,
+ alpha=(1.0 / self.norm_factor),
+ )
+
+ # change view to [batch_size, num_heads, q_length, k_length]
+ attention_scores = matmul_result.view(*output_size)
+
+ # attention scores and attention mask [b, np, sq, sk]
+ max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
+ attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(
+ value_layer.dtype
+ )
+ attention_probs = self.attention_dropout(attention_probs)
+
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ # context layer shape: [batch_size, num_heads, q_length, head_dim]
+ output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3))
+
+ # change view [k_length, batch_size x num_heads, head_dim]
+ value_layer = value_layer.transpose(1, 0).reshape(value_layer.size(1), output_size[0] * output_size[1], -1)
+
+ # change view [batch_size x num_heads, q_length, k_length]
+ attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
+
+ # matmul: [batch_size * num_heads, q_length, head_dim]
+ context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))
+
+ # change view [batch_size, num_heads, q_length, head_dim]
+ context_layer = context_layer.view(*output_size)
+
+        # [batch_size, num_heads, q_length, head_dim] --> [q_length, batch_size, num_heads, head_dim]
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+
+ # [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size]
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
+
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ # Output. [q_length, batch_size, hidden_size]
+
+ # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
+ if self.pretraining_tp > 1 and self.slow_but_exact:
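+            # Reproduce Megatron's tensor-parallel matmul exactly: slice the hidden dimension into one
+            # chunk per TP rank and sum the partial projections, so that the floating-point reduction
+            # order matches the pretrained model.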
+ slices = context_layer.shape[-1] / self.pretraining_tp
+ output_tensor = torch.zeros_like(context_layer)
+ for i in range(self.pretraining_tp):
+ output_tensor = output_tensor + nn.functional.linear(
+ context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
+ self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
+ )
+ else:
+ output_tensor = self.dense(context_layer)
+
+ output = output_tensor.transpose(1, 0)
+
+ output = dropout_add(output, residual, self.hidden_dropout, self.training)
+
+ outputs = (output, present)
+ if output_attentions:
+ outputs += (attention_probs,)
+
+ return outputs
+
+
+class BloomMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ hidden_size = config.hidden_size
+
+ self.pretraining_tp = config.pretraining_tp
+ self.slow_but_exact = config.slow_but_exact
+ self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
+ self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
+ self.hidden_dropout = config.hidden_dropout
+ self.gelu_impl = BloomGelu()
+
+ def forward(self, hidden_states, residual):
+ hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
+
+ if self.pretraining_tp > 1 and self.slow_but_exact:
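+            # Same TP-exact trick as in BloomAttention: sum per-rank partial matmuls so the
+            # floating-point reduction order matches Megatron's sharded linear layer.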
+ intermediate_output = torch.zeros_like(residual)
+ slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
+ for i in range(self.pretraining_tp):
+ intermediate_output = intermediate_output + nn.functional.linear(
+ hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
+ self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)],
+ )
+ else:
+ intermediate_output = self.dense_4h_to_h(hidden_states)
+
+ output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
+
+ return output
+
+
+class BloomBlock(nn.Module):
+ def __init__(self, config, layer_number=None):
+ super().__init__()
+ hidden_size = config.hidden_size
+
+ self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+ self.n_head = config.n_head
+ self.self_attention = BloomAttention(config, layer_number=layer_number)
+ self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+ self.mlp = BloomMLP(config)
+
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
+ self.hidden_dropout = config.hidden_dropout
+
+ def forward(
+ self,
+ hidden_states,
+ layer_past=None,
+ attention_mask=None,
+ head_mask=None,
+ use_cache=False,
+ output_attentions=False,
+ alibi=None,
+ ):
+ # hidden_states: [batch_size, seq_length, hidden_size]
+
+ # Layer norm at the beginning of the transformer layer.
+ layernorm_output = self.input_layernorm(hidden_states)
+
+ # Layer norm post the self attention.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+
+ # Self attention.
+ attn_outputs = self.self_attention(
+ layernorm_output,
+ residual,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ alibi=alibi,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ attention_output = attn_outputs[0]
+
+ outputs = attn_outputs[1:]
+
+ layernorm_output = self.post_attention_layernorm(attention_output)
+
+ # Get residual
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = attention_output
+
+ # MLP.
+ output = self.mlp(layernorm_output, residual)
+
+ if use_cache:
+ outputs = (output,) + outputs
+ else:
+ outputs = (output,) + outputs[1:]
+
+ return outputs # hidden_states, present, attentions
+
+
+class BloomPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
+
+ config_class = BloomConfig
+ base_model_prefix = "transformer"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["BloomBlock"]
+
+ def __init__(self, *inputs, **kwargs):
+ super().__init__(*inputs, **kwargs)
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, BloomModel):
+ module.gradient_checkpointing = value
+
+
+BLOOM_START_DOCSTRING = r"""
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`BloomConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+BLOOM_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
+ sequence tokens in the vocabulary.
+
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+ `input_ids`.
+
+ Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
+ `past_key_values`).
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
+ BLOOM_START_DOCSTRING,
+)
+class BloomModel(BloomPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.embed_dim = config.hidden_size
+ self.n_head = config.n_head
+
+ # Embedding + LN Embedding
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
+
+ self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+ # Transformer blocks
+ self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)])
+
+ # Final Layer Norm
+ self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.word_embeddings
+
+ def set_input_embeddings(self, new_embeddings):
+ self.word_embeddings = new_embeddings
+
+ @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids=None,
+ past_key_values=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if past_key_values is None:
+ past_key_values = tuple([None] * len(self.h))
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_head x N x N
+ # head_mask has shape n_layer x batch x n_head x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+
+ output_shape = input_shape + (hidden_states.size(-1),)
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ # Compute alibi tensor: check build_alibi_tensor documentation
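+ # (instead of position embeddings, ALiBi adds a per-head linear bias
+ # over key positions to the attention scores; the tensor must also
+ # cover positions held in the past key/value cache)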
+ current_sequence_length = hidden_states.shape[1]
+ if past_key_values[0] is not None:
+ current_sequence_length += past_key_values[0][0].shape[1]
+ alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype)
+
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+
+ if use_cache:
+ logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
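+ # torch.utils.checkpoint.checkpoint re-runs the forward pass during
+ # backprop and only takes positional arguments here, so the constant
+ # flags and the shared alibi tensor are bound in a closure instead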
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, use_cache, output_attentions, alibi)
+
+ return custom_forward
+
+ outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ None,
+ attention_mask,
+ head_mask[i],
+ )
+ else:
+ outputs = block(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ head_mask=head_mask[i],
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ alibi=alibi,
+ )
+
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+
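+ # the block returns (hidden_states, present, attentions) when caching
+ # and (hidden_states, attentions) otherwise, hence the shifted index below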
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+
+ # Add last hidden state
+ hidden_states = self.ln_f(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ hidden_states = hidden_states.view(output_shape)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
+ embeddings).
+ """,
+ BLOOM_START_DOCSTRING,
+)
+class BloomForCausalLM(BloomPreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.transformer = BloomModel(config)
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+ # keep only the last token of input_ids if past is defined in kwargs
+ if past:
+ input_ids = input_ids[:, -1].unsqueeze(-1)
+
+ attention_mask = kwargs.get("attention_mask", None)
+ position_ids = kwargs.get("position_ids", None)
+
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
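+ # (a cumulative sum over the mask gives each real token an increasing
+ # position, and padded positions are clamped to 1 so left-padding does
+ # not shift the prompt's positions)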
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past:
+ position_ids = position_ids[:, -1].unsqueeze(-1)
+ else:
+ position_ids = None
+ return {
+ "input_ids": input_ids,
+ "past_key_values": past,
+ "use_cache": kwargs.get("use_cache"),
+ "position_ids": position_ids,
+ "attention_mask": attention_mask,
+ }
+
+ @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=CausalLMOutputWithCrossAttentions,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids=None,
+ past_key_values=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+
+ lm_logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ @staticmethod
+ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+ """
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+ beam_idx at every generation step.
+ """
+ return tuple(
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+ for layer_past in past
+ )
diff --git a/src/transformers/models/bloom/tokenization_bloom_fast.py b/src/transformers/models/bloom/tokenization_bloom_fast.py
new file mode 100644
index 0000000000000..c9785d641bbbc
--- /dev/null
+++ b/src/transformers/models/bloom/tokenization_bloom_fast.py
@@ -0,0 +1,181 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for Bloom."""
+
+
+import json
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from tokenizers import pre_tokenizers
+
+from ...tokenization_utils_base import BatchEncoding
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+
+
+if TYPE_CHECKING:
+ from transformers.pipelines.conversational import Conversation
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"tokenizer_file": "tokenizer.json"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+ "tokenizer_file": {
+ "bigscience/tokenizer": "https://huggingface.co/bigscience/tokenizer/blob/main/tokenizer.json",
+ "bigscience/bloom-350m": "https://huggingface.co/bigscience/bloom-350m/blob/main/tokenizer.json",
+ "bigscience/bloom-760m": "https://huggingface.co/bigscience/bloom-760m/blob/main/tokenizer.json",
+ "bigscience/bloom-1b3": "https://huggingface.co/bigscience/bloom-1b3/blob/main/tokenizer.json",
+ "bigscience/bloom-2b5": "https://huggingface.co/bigscience/bloom-2b5/blob/main/tokenizer.json",
+ "bigscience/bloom-6b3": "https://huggingface.co/bigscience/bloom-2b5/blob/main/tokenizer.json",
+ "bigscience/bloom": "https://huggingface.co/bigscience/bloom/blob/main/tokenizer.json",
+ },
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "bigscience/tokenizer": 1024,
+ "bigscience/bloom-350m": 1024,
+ "bigscience/bloom-760m": 1024,
+ "bigscience/bloom-1b3": 1024,
+ "bigscience/bloom-2b5": 1024,
+ "bigscience/bloom-6b3": 1024,
+ "bigscience/bloom": 1024,
+}
+
+
+class BloomTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a "fast" Bloom tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
+ Byte-Pair-Encoding.
+
+ This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
+ be encoded differently depending on whether it is at the beginning of the sentence (without space) or not:
+
+ ```
+ >>> from transformers import BloomTokenizerFast
+ >>> tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom")
+ >>> tokenizer("Hello world")['input_ids']
+ [15496, 995]
+ >>> tokenizer(" Hello world")['input_ids']
+ [18435, 995]
+ ```
+
+ You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer, but since
+ the model was not pretrained this way, it might yield a decrease in performance.
+
+ <Tip>
+
+ When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.
+
+ </Tip>
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Path to the merges file.
+ errors (`str`, *optional*, defaults to `"replace"`):
+ Paradigm to follow when decoding bytes to UTF-8. See
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+ unk_token (`str`, *optional*, defaults to `<unk>`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ bos_token (`str`, *optional*, defaults to `<s>`):
+ The beginning of sequence token.
+ eos_token (`str`, *optional*, defaults to `</s>`):
+ The end of sequence token.
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
+ Whether or not to add an initial space to the input. This allows treating the leading word just as any
+ other word. (The Bloom tokenizer detects the beginning of words by the preceding space.)
+ trim_offsets (`bool`, *optional*, defaults to `True`):
+ Whether or not the post-processing step should trim offsets to avoid including whitespaces.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ model_input_names = ["input_ids", "attention_mask"]
+ slow_tokenizer_class = None
+
+ def __init__(
+ self,
+ vocab_file=None,
+ merges_file=None,
+ tokenizer_file=None,
+ unk_token="",
+ bos_token="",
+ eos_token="",
+ pad_token="",
+ add_prefix_space=False,
+ **kwargs
+ ):
+ super().__init__(
+ vocab_file,
+ merges_file,
+ tokenizer_file=tokenizer_file,
+ unk_token=unk_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ pad_token=pad_token,
+ add_prefix_space=add_prefix_space,
+ **kwargs,
+ )
+
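+ # the Rust backend serializes its pre-tokenizer state as JSON; if the
+ # stored `add_prefix_space` disagrees with the requested value, rebuild
+ # the pre-tokenizer (e.g. ByteLevel) with the updated state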
+ pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
+ if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
+ pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
+ pre_tok_state["add_prefix_space"] = add_prefix_space
+ self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
+
+ self.add_prefix_space = add_prefix_space
+
+ def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
+ is_split_into_words = kwargs.get("is_split_into_words", False)
+ if is_split_into_words and not self.add_prefix_space:
+ raise Exception(
+ f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True to use it with"
+ " pretokenized inputs."
+ )
+
+ return super()._batch_encode_plus(*args, **kwargs)
+
+ def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
+ is_split_into_words = kwargs.get("is_split_into_words", False)
+
+ if is_split_into_words and not self.add_prefix_space:
+ raise Exception(
+ f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True to use it with"
+ " pretokenized inputs."
+ )
+
+ return super()._encode_plus(*args, **kwargs)
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+ return tuple(files)
+
+ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
+ """This corresponds to DialoGPT variants of models."""
+ input_ids = []
+ for is_user, text in conversation.iter_texts():
+ input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
+
+ if len(input_ids) > self.model_max_length:
+ input_ids = input_ids[-self.model_max_length :]
+ return input_ids
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index e25185b1a11c9..e130d0ef91645 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -959,6 +959,30 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class BloomForCausalLM(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class BloomModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class BloomPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
diff --git a/src/transformers/utils/dummy_tokenizers_objects.py b/src/transformers/utils/dummy_tokenizers_objects.py
index c4daff7e375ab..631df9f25890c 100644
--- a/src/transformers/utils/dummy_tokenizers_objects.py
+++ b/src/transformers/utils/dummy_tokenizers_objects.py
@@ -52,6 +52,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])
+class BloomTokenizerFast(metaclass=DummyObject):
+ _backends = ["tokenizers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tokenizers"])
+
+
class CamembertTokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]
diff --git a/tests/deepspeed/test_model_zoo.py b/tests/deepspeed/test_model_zoo.py
index 189452c609d36..f5b43f4c1bcca 100644
--- a/tests/deepspeed/test_model_zoo.py
+++ b/tests/deepspeed/test_model_zoo.py
@@ -58,6 +58,7 @@
BIGBIRD_PEGASUS_TINY = "hf-internal-testing/tiny-random-bigbird_pegasus"
BIG_BIRD_TINY = "hf-internal-testing/tiny-random-big_bird"
BLENDERBOT_TINY = "hf-internal-testing/tiny-random-blenderbot"
+BLOOM_TINY = "bigscience/bigscience-small-testing"
DEBERTA_TINY = "hf-internal-testing/tiny-random-deberta"
DEBERTA_V2_TINY = "hf-internal-testing/tiny-random-deberta-v2"
DISTILBERT_TINY = "sshleifer/tiny-distilbert-base-cased"
@@ -183,6 +184,7 @@ def make_task_cmds():
"big_bird",
"bigbird_pegasus",
"blenderbot",
+ "bloom",
"gpt2",
"gpt_neo",
"gptj",
diff --git a/tests/models/bloom/__init__.py b/tests/models/bloom/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/models/bloom/test_modeling_bloom.py b/tests/models/bloom/test_modeling_bloom.py
new file mode 100644
index 0000000000000..1f5c10d2ee598
--- /dev/null
+++ b/tests/models/bloom/test_modeling_bloom.py
@@ -0,0 +1,710 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+import unittest
+
+from transformers import BloomConfig, is_torch_available
+from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
+
+from ...generation.test_generation_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST, BloomForCausalLM, BloomModel, BloomTokenizerFast
+
+
+@require_torch
+class BloomModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=14,
+ seq_length=7,
+ is_training=True,
+ use_token_type_ids=False,
+ use_input_mask=True,
+ use_labels=True,
+ use_mc_token_ids=True,
+ vocab_size=99,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=16,
+ type_sequence_label_size=2,
+ initializer_range=0.02,
+ num_labels=3,
+ num_choices=4,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_token_type_ids = use_token_type_ids
+ self.use_input_mask = use_input_mask
+ self.use_labels = use_labels
+ self.use_mc_token_ids = use_mc_token_ids
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.num_labels = num_labels
+ self.num_choices = num_choices
+ self.scope = None
+ self.bos_token_id = vocab_size - 1
+ self.eos_token_id = vocab_size - 1
+ self.pad_token_id = vocab_size - 1
+
+ def get_large_model_config(self):
+ return BloomConfig.from_pretrained("bigscience/bloom")
+
+ def prepare_config_and_inputs(self, gradient_checkpointing=False):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ config = self.get_config(gradient_checkpointing=gradient_checkpointing)
+
+ return (config, input_ids, input_mask)
+
+ def get_config(self, gradient_checkpointing=False, slow_but_exact=True):
+ return BloomConfig(
+ vocab_size=self.vocab_size,
+ seq_length=self.seq_length,
+ hidden_size=self.hidden_size,
+ n_layer=self.num_hidden_layers,
+ n_head=self.num_attention_heads,
+ resid_pdrop=self.hidden_dropout_prob,
+ attn_pdrop=self.attention_probs_dropout_prob,
+ n_positions=self.max_position_embeddings,
+ type_vocab_size=self.type_vocab_size,
+ initializer_range=self.initializer_range,
+ use_cache=True,
+ bos_token_id=self.bos_token_id,
+ eos_token_id=self.eos_token_id,
+ pad_token_id=self.pad_token_id,
+ gradient_checkpointing=gradient_checkpointing,
+ slow_but_exact=slow_but_exact,
+ dtype="float32",
+ )
+
+ def create_and_check_bloom_model(self, config, input_ids, input_mask, *args):
+ model = BloomModel(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ result = model(input_ids)
+
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+ self.parent.assertEqual(len(result.past_key_values), config.n_layer)
+
+ def create_and_check_bloom_model_past(self, config, input_ids, input_mask, *args):
+ model = BloomModel(config=config)
+
+ model.to(torch_device)
+ model.eval()
+
+ # first forward pass
+ outputs = model(input_ids, attention_mask=torch.ones_like(input_ids), use_cache=True)
+ outputs_use_cache_conf = model(input_ids, attention_mask=torch.ones_like(input_ids))
+ outputs_no_past = model(input_ids, use_cache=False, attention_mask=torch.ones_like(input_ids))
+
+ self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
+ self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
+
+ past = outputs["past_key_values"]
+
+ # create a hypothetical next token
+ next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
+
+ # append next_tokens to input_ids
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+
+ output_from_no_past = model(next_input_ids)["last_hidden_state"]
+ output_from_past = model(next_tokens, past_key_values=past)["last_hidden_state"]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def create_and_check_bloom_model_attention_mask_past(self, config, input_ids, input_mask, *args):
+ model = BloomModel(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ # create attention mask
+ attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
+ half_seq_length = self.seq_length // 2
+ attn_mask[:, half_seq_length:] = 0
+
+ # first forward pass
+ output, past = model(input_ids, attention_mask=attn_mask).to_tuple()
+
+ # create a hypothetical next token
+ next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
+
+ # change a random masked slice from input_ids
+ random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
+ random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
+ input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
+
+ # append to next input_ids and attn_mask
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ attn_mask = torch.cat(
+ [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
+ dim=1,
+ )
+
+ # get two different outputs
+ output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
+ output_from_past = model(next_tokens, past_key_values=past, attention_mask=attn_mask)["last_hidden_state"]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def create_and_check_bloom_model_past_large_inputs(self, config, input_ids, input_mask, *args):
+ model = BloomModel(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ # first forward pass
+ outputs = model(input_ids, attention_mask=input_mask, use_cache=True)
+
+ output, past = outputs.to_tuple()
+
+ # create hypothetical next tokens and a matching attention mask
+ next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+ next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
+
+ # append next_tokens to input_ids and next_mask to the attention mask
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
+
+ output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
+ output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past)[
+ "last_hidden_state"
+ ]
+ self.parent.assertTrue(output_from_past.shape[1] == next_tokens.shape[1])
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def create_and_check_lm_head_model(self, config, input_ids, input_mask, *args):
+ model = BloomForCausalLM(config)
+ model.to(torch_device)
+ model.eval()
+
+ result = model(input_ids, labels=input_ids)
+ self.parent.assertEqual(result.loss.shape, ())
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+
+ def create_and_check_forward_and_backwards(
+ self, config, input_ids, input_mask, *args, gradient_checkpointing=False
+ ):
+ model = BloomForCausalLM(config)
+ model.to(torch_device)
+ if gradient_checkpointing:
+ model.gradient_checkpointing_enable()
+
+ result = model(input_ids, labels=input_ids)
+ self.parent.assertEqual(result.loss.shape, ())
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+ result.loss.backward()
+
+ def create_and_check_bloom_weight_initialization(self, config, *args):
+ model = BloomModel(config)
+ model_std = model.config.initializer_range / math.sqrt(2 * model.config.n_layer)
+ for key in model.state_dict().keys():
+ if "c_proj" in key and "weight" in key:
+ self.parent.assertLessEqual(abs(torch.std(model.state_dict()[key]) - model_std), 0.001)
+ self.parent.assertLessEqual(abs(torch.mean(model.state_dict()[key]) - 0.0), 0.01)
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+
+ config, input_ids, input_mask = config_and_inputs
+
+ inputs_dict = {"input_ids": input_ids}
+
+ return config, inputs_dict
+
+
+@require_torch
+class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+
+ all_model_classes = (BloomModel, BloomForCausalLM) if is_torch_available() else ()
+ all_generative_model_classes = (BloomForCausalLM,) if is_torch_available() else ()
+ fx_compatible = False
+ test_missing_keys = False
+ test_pruning = False
+ test_torchscript = True # torch.autograd functions seem not to be supported
+
+ def setUp(self):
+ self.model_tester = BloomModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=BloomConfig, n_embd=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_bloom_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_bloom_model(*config_and_inputs)
+
+ def test_bloom_model_past(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_bloom_model_past(*config_and_inputs)
+
+ def test_bloom_model_att_mask_past(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_bloom_model_attention_mask_past(*config_and_inputs)
+
+ def test_bloom_model_past_large_inputs(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_bloom_model_past_large_inputs(*config_and_inputs)
+
+ def test_bloom_lm_head_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
+
+ def test_bloom_gradient_checkpointing(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
+
+ def test_bloom_weight_initialization(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_bloom_weight_initialization(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = BloomModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+ @slow
+ @require_torch_gpu
+ def test_simple_generation(self):
+ path_350m = "bigscience/bloom-350m"
+ model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda()
+ model = model.eval()
+ tokenizer = BloomTokenizerFast.from_pretrained(path_350m)
+
+ input_sentence = "I enjoy walking with my cute dog"
+ EXPECTED_OUTPUT = (
+ "I enjoy walking with my cute dog, and I love to watch the kids play. I am a very active person, and I am"
+ " a very good listener. I am a very good person, and I am a very good person. I am a"
+ )
+
+ input_ids = tokenizer.encode(input_sentence, return_tensors="pt")
+ greedy_output = model.generate(input_ids.cuda(), max_length=50)
+
+ self.assertEqual(tokenizer.decode(greedy_output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
+
+ @slow
+ @require_torch_gpu
+ def test_batch_generation(self):
+ path_350m = "bigscience/bloom-350m"
+ model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda()
+ model = model.eval()
+ tokenizer = BloomTokenizerFast.from_pretrained(path_350m, padding_side="left")
+
+ input_sentence = ["I enjoy walking with my cute dog", "I enjoy walking with my cute dog"]
+
+ input_ids = tokenizer.batch_encode_plus(input_sentence, return_tensors="pt", padding=True)
+ greedy_output = model.generate(
+ input_ids["input_ids"].cuda(), attention_mask=input_ids["attention_mask"], max_length=50, do_sample=False
+ )
+
+ self.assertEqual(
+ tokenizer.decode(greedy_output[0], skip_special_tokens=True),
+ tokenizer.decode(greedy_output[1], skip_special_tokens=True),
+ )
+
+ @slow
+ @require_torch_gpu
+ def test_batch_generation_padding(self):
+ path_350m = "bigscience/bloom-350m"
+ model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda()
+ model = model.eval()
+ tokenizer = BloomTokenizerFast.from_pretrained(path_350m, padding_side="left")
+
+ input_sentence = ["I enjoy walking with my cute dog", "Hello my name is"]
+ input_sentence_without_pad = "Hello my name is"
+
+ input_ids = tokenizer.batch_encode_plus(input_sentence, return_tensors="pt", padding=True)
+ input_ids_without_pad = tokenizer.encode(input_sentence_without_pad, return_tensors="pt")
+
+ greedy_output = model.generate(
+ input_ids["input_ids"].cuda(), attention_mask=input_ids["attention_mask"], max_length=50, do_sample=False
+ )
+ greedy_output_without_pad = model.generate(input_ids_without_pad.cuda(), max_length=50, do_sample=False)
+
+ # test token values
+ self.assertEqual(greedy_output[-1, 3:].tolist(), greedy_output_without_pad[0, :-3].tolist())
+
+ # test reconstructions
+ self.assertEqual(
+ tokenizer.decode(greedy_output[-1, 3:], skip_special_tokens=True),
+ tokenizer.decode(greedy_output_without_pad[0, :-3], skip_special_tokens=True),
+ )
+
+
+@require_torch
+class BloomEmbeddingTest(unittest.TestCase):
+ """
+ The goal here is to compare the embeddings generated by the model trained
+ using Megatron-LM with those from the transformers library, using a small GPT2-like model,
+ to ensure that the conversion from Megatron-LM to transformers has been done successfully.
+ The script compares the logits of the embedding layer and the transformer layers.
+
+ WARNING: It is expected that these logits will not have exactly the same statistics when running
+ the code on CPU as on GPU. For more info, please visit:
+ - https://github.com/pytorch/pytorch/issues/76052#issuecomment-1103193548
+ - https://discuss.pytorch.org/t/reproducibility-issue-between-intel-and-amd-cpus/144779/9
+
+
+ You need to install tokenizers following this readme:
+ - https://huggingface.co/bigscience-catalogue-data-dev/byte-level-bpe-tokenizer-no-norm-250k-whitespace-and-eos-regex-alpha-v3-dedup-lines-articles
+
+ Tokenizer used during training:
+ - https://huggingface.co/bigscience-catalogue-data-dev/byte-level-bpe-tokenizer-no-norm-250k-whitespace-and-eos-regex-alpha-v3-dedup-lines-articles
+
+ # TODO change the script (or just add skip) when building the env with tokenizers 0.12.0
+ """
+
+ def setUp(self):
+ super().setUp()
+ self.path_bigscience_model = "bigscience/bigscience-small-testing"
+
+ @require_torch
+ def test_embeddings(self):
+ model = BloomForCausalLM.from_pretrained(self.path_bigscience_model, torch_dtype="auto") # load in fp32
+ model.eval()
+
+ EMBEDDINGS_DS_BEFORE_LN_BF_16_MEAN = {
+ 3478: 0.0002307891845703125,
+ 368: -0.000568389892578125,
+ 109586: -0.0003910064697265625,
+ 35433: -0.000194549560546875,
+ 2: 0.0004138946533203125,
+ 77: 0.000659942626953125,
+ 132619: -0.00031280517578125,
+ 2175: 0.000457763671875,
+ 23714: 0.000263214111328125,
+ 73173: -0.000286102294921875,
+ 144252: 0.00052642822265625,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_BF_16_MIN = {
+ 3478: -0.00921630859375,
+ 368: -0.010009765625,
+ 109586: -0.01031494140625,
+ 35433: -0.01177978515625,
+ 2: -0.0074462890625,
+ 77: -0.00848388671875,
+ 132619: -0.009521484375,
+ 2175: -0.0074462890625,
+ 23714: -0.0145263671875,
+ 73173: -0.007415771484375,
+ 144252: -0.01007080078125,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_BF_16_MAX = {
+ 3478: 0.0128173828125,
+ 368: 0.01214599609375,
+ 109586: 0.0111083984375,
+ 35433: 0.01019287109375,
+ 2: 0.0157470703125,
+ 77: 0.0174560546875,
+ 132619: 0.0078125,
+ 2175: 0.0113525390625,
+ 23714: 0.0146484375,
+ 73173: 0.01116943359375,
+ 144252: 0.01141357421875,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_BF_16_SUM = {"value": 0.08203125}
+
+ EMBEDDINGS_DS_BEFORE_LN_F_16_MEAN = {
+ 132619: -0.00031256675720214844,
+ 3478: 0.00023090839385986328,
+ 368: -0.0005702972412109375,
+ 109586: -0.00039124488830566406,
+ 35433: -0.000194549560546875,
+ 2: 0.0004146099090576172,
+ 2175: 0.0004572868347167969,
+ 23714: 0.00026416778564453125,
+ 73173: -0.0002865791320800781,
+ 144252: 0.0005254745483398438,
+ 77: 0.0006618499755859375,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_F_16_MIN = {
+ 3478: -0.00921630859375,
+ 368: -0.010009765625,
+ 109586: -0.01031494140625,
+ 35433: -0.01177978515625,
+ 2: -0.0074462890625,
+ 77: -0.00848388671875,
+ 132619: -0.009521484375,
+ 2175: -0.0074462890625,
+ 23714: -0.0145263671875,
+ 73173: -0.007415771484375,
+ 144252: -0.01007080078125,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_F_16_MAX = {
+ 3478: 0.0128173828125,
+ 368: 0.01214599609375,
+ 109586: 0.0111083984375,
+ 35433: 0.01019287109375,
+ 2: 0.0157470703125,
+ 77: 0.0174560546875,
+ 132619: 0.0078125,
+ 2175: 0.0113525390625,
+ 23714: 0.0146484375,
+ 73173: 0.01116943359375,
+ 144252: 0.01141357421875,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_F_16_SUM = {"value": 0.0821533203125}
+
+ EMBEDDINGS_DS_BEFORE_LN_F_32_MEAN = {
+ 132619: -0.00031267106533050537,
+ 3478: 0.00023087859153747559,
+ 368: -0.0005701072514057159,
+ 109586: -0.0003911703824996948,
+ 35433: -0.0001944899559020996,
+ 2: 0.0004146844148635864,
+ 2175: 0.00045740045607089996,
+ 23714: 0.0002641640603542328,
+ 73173: -0.0002864748239517212,
+ 144252: 0.0005256589502096176,
+ 77: 0.0006617321632802486,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_F_32_MIN = {
+ 3478: -0.00921630859375,
+ 368: -0.010009765625,
+ 109586: -0.01031494140625,
+ 35433: -0.01177978515625,
+ 2: -0.0074462890625,
+ 77: -0.00848388671875,
+ 132619: -0.009521484375,
+ 2175: -0.0074462890625,
+ 23714: -0.0145263671875,
+ 73173: -0.007415771484375,
+ 144252: -0.01007080078125,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_F_32_MAX = {
+ 3478: 0.0128173828125,
+ 368: 0.01214599609375,
+ 109586: 0.0111083984375,
+ 35433: 0.01019287109375,
+ 2: 0.0157470703125,
+ 77: 0.0174560546875,
+ 132619: 0.0078125,
+ 2175: 0.0113525390625,
+ 23714: 0.0146484375,
+ 73173: 0.01116943359375,
+ 144252: 0.01141357421875,
+ }
+ EMBEDDINGS_DS_BEFORE_LN_F_32_SUM = {"value": 0.08217757940292358}
+
+ TEST_EMBEDDINGS = {
+ "torch.bfloat16": {
+ "mean": EMBEDDINGS_DS_BEFORE_LN_BF_16_MEAN,
+ "max": EMBEDDINGS_DS_BEFORE_LN_BF_16_MAX,
+ "min": EMBEDDINGS_DS_BEFORE_LN_BF_16_MIN,
+ "sum": EMBEDDINGS_DS_BEFORE_LN_BF_16_SUM,
+ },
+ "torch.float32": {
+ "mean": EMBEDDINGS_DS_BEFORE_LN_F_32_MEAN,
+ "max": EMBEDDINGS_DS_BEFORE_LN_F_32_MAX,
+ "min": EMBEDDINGS_DS_BEFORE_LN_F_32_MIN,
+ "sum": EMBEDDINGS_DS_BEFORE_LN_F_32_SUM,
+ },
+ "torch.float": {
+ "mean": EMBEDDINGS_DS_BEFORE_LN_F_32_MEAN,
+ "max": EMBEDDINGS_DS_BEFORE_LN_F_32_MAX,
+ "min": EMBEDDINGS_DS_BEFORE_LN_F_32_MIN,
+ "sum": EMBEDDINGS_DS_BEFORE_LN_F_32_SUM,
+ },
+ "torch.float16": {
+ "mean": EMBEDDINGS_DS_BEFORE_LN_F_16_MEAN,
+ "max": EMBEDDINGS_DS_BEFORE_LN_F_16_MAX,
+ "min": EMBEDDINGS_DS_BEFORE_LN_F_16_MIN,
+ "sum": EMBEDDINGS_DS_BEFORE_LN_F_16_SUM,
+ },
+ }
+
+ # fmt: off
+ EXAMPLE_IDS = [3478, 368, 109586, 35433, 2, 77, 132619, 3478, 368, 109586, 35433, 2, 2175, 23714, 73173, 144252, 2, 77, 132619, 3478]
+ # fmt: on
+
+ EMBEDDINGS_DS_AFTER_LN_MEAN = {
+ 3478: -6.580352783203125e-05,
+ 368: 0.0001316070556640625,
+ 109586: -0.00030517578125,
+ 35433: 4.00543212890625e-05,
+ 2: -7.2479248046875e-05,
+ 77: -8.96453857421875e-05,
+ 132619: 0.0001583099365234375,
+ 2175: 2.1219253540039062e-05,
+ 23714: -0.000247955322265625,
+ 73173: -0.00021839141845703125,
+ 144252: -0.0001430511474609375,
+ }
+ EMBEDDINGS_DS_AFTER_LN_MIN = {
+ 3478: -1.6953125,
+ 368: -1.6875,
+ 109586: -1.6875,
+ 35433: -2.125,
+ 2: -1.390625,
+ 77: -1.5390625,
+ 132619: -1.875,
+ 2175: -1.4609375,
+ 23714: -2.296875,
+ 73173: -1.3515625,
+ 144252: -1.78125,
+ }
+ EMBEDDINGS_DS_AFTER_LN_MAX = {
+ 3478: 2.265625,
+ 368: 2.28125,
+ 109586: 1.953125,
+ 35433: 1.90625,
+ 2: 2.703125,
+ 77: 2.828125,
+ 132619: 1.65625,
+ 2175: 2.015625,
+ 23714: 2.234375,
+ 73173: 2.171875,
+ 144252: 1.828125,
+ }
+
+ EMBEDDINGS_DS_AFTER_LN = {
+ "mean": EMBEDDINGS_DS_AFTER_LN_MEAN,
+ "min": EMBEDDINGS_DS_AFTER_LN_MIN,
+ "max": EMBEDDINGS_DS_AFTER_LN_MAX,
+ }
+
+ tensor_ids = torch.LongTensor([EXAMPLE_IDS])
+ with torch.no_grad():
+ embeddings = model.transformer.word_embeddings(tensor_ids)
+ embeddings_ln = model.transformer.word_embeddings_layernorm(embeddings)
+ # first check the embeddings before LN
+ output_dict = {"min": {}, "max": {}, "mean": {}, "sum": {"value": embeddings.sum().item()}}
+ for i, idx in enumerate(EXAMPLE_IDS):
+ output_dict["min"][idx] = embeddings.min(dim=-1).values[0][i].item()
+ output_dict["max"][idx] = embeddings.max(dim=-1).values[0][i].item()
+ output_dict["mean"][idx] = embeddings.mean(dim=-1)[0][i].item()
+
+ for key in TEST_EMBEDDINGS[str(model.dtype)].keys():
+ self.assertDictEqual(TEST_EMBEDDINGS[str(model.dtype)][key], output_dict[key])
+
+ output_dict_norm = {"min": {}, "max": {}, "mean": {}}
+ for i, idx in enumerate(EXAMPLE_IDS):
+ output_dict_norm["min"][idx] = embeddings_ln.min(dim=-1).values[0][i].item()
+ output_dict_norm["max"][idx] = embeddings_ln.max(dim=-1).values[0][i].item()
+ output_dict_norm["mean"][idx] = embeddings_ln.mean(dim=-1)[0][i].item()
+
+ # This test does not pass when places = 2
+ for i, key in enumerate(output_dict_norm.keys()):
+ for j, idx in enumerate(output_dict[key].keys()):
+ self.assertAlmostEqual(EMBEDDINGS_DS_AFTER_LN[key][idx], output_dict_norm[key][idx], places=1)
+
+ @require_torch
+ def test_hidden_states_transformers(self):
+ cuda_available = torch.cuda.is_available()
+ model = BloomModel.from_pretrained(self.path_bigscience_model, use_cache=False, torch_dtype="auto").to(
+ torch_device
+ )
+ model.eval()
+
+ # fmt: off
+ EXAMPLE_IDS = [3478, 368, 109586, 35433, 2, 77, 132619, 3478, 368, 109586, 35433, 2, 2175, 23714, 73173, 144252, 2, 77, 132619, 3478]
+ # fmt: on
+
+ MEAN_VALUE_LAST_LM = -4.3392181396484375e-05
+ MIN_MAX_DICT = {"min": -2.0625, "max": 2.75}
+ tensor_ids = torch.LongTensor([EXAMPLE_IDS])
+
+ with torch.no_grad():
+ logits = model(tensor_ids.to(torch_device))
+ output_dict = {
+ "min": logits.last_hidden_state.min(dim=-1).values[0][0].item(),
+ "max": logits.last_hidden_state.max(dim=-1).values[0][0].item(),
+ }
+
+ if cuda_available:
+ self.assertEqual(MEAN_VALUE_LAST_LM, logits.last_hidden_state.mean().item())
+ else:
+ self.assertAlmostEqual(MEAN_VALUE_LAST_LM, logits.last_hidden_state.mean().item(), places=3)
+
+ self.assertDictEqual(MIN_MAX_DICT, output_dict)
+
+ @require_torch
+ def test_logits(self):
+ cuda_available = torch.cuda.is_available()
+ model = BloomForCausalLM.from_pretrained(self.path_bigscience_model, use_cache=False, torch_dtype="auto").to(
+ torch_device
+ ) # load in bf16
+ model.eval()
+
+ # fmt: off
+ EXAMPLE_IDS = [3478, 368, 109586, 35433, 2, 77, 132619, 3478, 368, 109586, 35433, 2, 2175, 23714, 73173, 144252, 2, 77, 132619, 3478]
+ # fmt: on
+
+ MEAN_LOGITS_GPU_1 = -1.823902130126953e-05
+ MEAN_LOGITS_GPU_2 = 1.9431114196777344e-05
+
+ tensor_ids = torch.LongTensor([EXAMPLE_IDS]).to(torch_device)
+ with torch.no_grad():
+ output = model(tensor_ids).logits
+
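+ # split the 250880-dim logits into two halves of 125440 each; the
+ # GPU_1/GPU_2 reference values correspond to these per-shard means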
+ output_gpu_1, output_gpu_2 = output.split(125440, dim=-1)
+ if cuda_available:
+ self.assertEqual(output_gpu_1.mean().item(), MEAN_LOGITS_GPU_1)
+ self.assertEqual(output_gpu_2.mean().item(), MEAN_LOGITS_GPU_2)
+ else:
+ self.assertAlmostEqual(output_gpu_1.mean().item(), MEAN_LOGITS_GPU_1, places=6) # 1e-06 precision!!
+ self.assertAlmostEqual(output_gpu_2.mean().item(), MEAN_LOGITS_GPU_2, places=6)
diff --git a/tests/models/bloom/test_tokenization_bloom.py b/tests/models/bloom/test_tokenization_bloom.py
new file mode 100644
index 0000000000000..c213437a37dd0
--- /dev/null
+++ b/tests/models/bloom/test_tokenization_bloom.py
@@ -0,0 +1,129 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from datasets import load_dataset
+
+from transformers import BloomTokenizerFast
+from transformers.testing_utils import require_tokenizers
+
+from ...test_tokenization_common import TokenizerTesterMixin
+
+
+@require_tokenizers
+class BloomTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+
+ slow_tokenizer_class = None
+ rust_tokenizer_class = BloomTokenizerFast
+ tokenizer_class = BloomTokenizerFast
+ test_rust_tokenizer = True
+ test_slow_tokenizer = False
+ from_pretrained_vocab_key = "tokenizer_file"
+ special_tokens_map = {"bos_token": "", "eos_token": "", "unk_token": "", "pad_token": ""}
+
+ def setUp(self):
+ super().setUp()
+ tokenizer = BloomTokenizerFast.from_pretrained("bigscience/tokenizer")
+ tokenizer.save_pretrained(self.tmpdirname)
+
+ def get_rust_tokenizer(self, **kwargs):
+ kwargs.update(self.special_tokens_map)
+ return BloomTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
+
+ def test_encodings_from_sample_data(self):
+ """
+ Assert that the created tokens are the same as the hard-coded ones
+ """
+ tokenizer = self.get_rust_tokenizer()
+
+ INPUT_SENTENCES = ["The quick brown fox", "jumps over the lazy dog"]
+ TARGET_TOKENS = [[2175, 23714, 73173, 144252, 2], [77, 132619, 3478, 368, 109586, 35433, 2]]
+
+ computed_tokens = tokenizer.batch_encode_plus(INPUT_SENTENCES)["input_ids"]
+ self.assertListEqual(TARGET_TOKENS, computed_tokens)
+
+ decoded_tokens = tokenizer.batch_decode(computed_tokens)
+ self.assertListEqual(decoded_tokens, INPUT_SENTENCES)
+
+ def test_padding(self, max_length=6):
+ for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+ with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+ tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+ # tokenizer_r.pad_token = None # Hotfixing padding = None
+ # Simple input
+ s = "This is a simple input"
+ s2 = ["This is a simple input 1", "This is a simple input 2"]
+ p = ("This is a simple input", "This is a pair")
+ p2 = [
+ ("This is a simple input 1", "This is a simple input 2"),
+ ("This is a simple pair 1", "This is a simple pair 2"),
+ ]
+
+ # Simple input tests
+ try:
+ tokenizer_r.encode(s, max_length=max_length)
+ tokenizer_r.encode_plus(s, max_length=max_length)
+
+ tokenizer_r.batch_encode_plus(s2, max_length=max_length)
+ tokenizer_r.encode(p, max_length=max_length)
+ tokenizer_r.batch_encode_plus(p2, max_length=max_length)
+ except ValueError:
+ self.fail("Bloom Tokenizer should be able to deal with padding")
+
+ tokenizer_r.pad_token = None # Hotfixing padding = None
+ self.assertRaises(ValueError, tokenizer_r.encode, s, max_length=max_length, padding="max_length")
+
+ # Simple input
+ self.assertRaises(ValueError, tokenizer_r.encode_plus, s, max_length=max_length, padding="max_length")
+
+ # Simple input
+ self.assertRaises(
+ ValueError,
+ tokenizer_r.batch_encode_plus,
+ s2,
+ max_length=max_length,
+ padding="max_length",
+ )
+
+ # Pair input
+ self.assertRaises(ValueError, tokenizer_r.encode, p, max_length=max_length, padding="max_length")
+
+ # Pair input
+ self.assertRaises(ValueError, tokenizer_r.encode_plus, p, max_length=max_length, padding="max_length")
+
+ # Pair input
+ self.assertRaises(
+ ValueError,
+ tokenizer_r.batch_encode_plus,
+ p2,
+ max_length=max_length,
+ padding="max_length",
+ )
+
+ def test_encodings_from_xnli_dataset(self):
+ """
+ Tests the tokenizer downloaded from here:
+ - https://huggingface.co/bigscience/tokenizer/
+ """
+ tokenizer = self.get_rust_tokenizer()
+ ds = load_dataset("xnli", "all_languages", split="test", streaming=True)
+
+ sample_data = next(iter(ds))["premise"] # pick one sample
+ input_text = list(sample_data.values())
+
+ output_tokens = list(map(tokenizer.encode, input_text))
+ predicted_text = list(map(lambda x: tokenizer.decode(x, clean_up_tokenization_spaces=False), output_tokens))
+ self.assertListEqual(predicted_text, input_text)