From d45930be385bb981a8e3dbac542551ee51ae12ce Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 4 Oct 2022 12:24:12 +0100 Subject: [PATCH 001/102] first commit --- README.md | 1 + README_ko.md | 1 + README_zh-hans.md | 1 + README_zh-hant.md | 1 + docs/source/en/index.mdx | 2 + .../en/model_doc/switchtransformers.mdx | 101 + docs/source/en/serialization.mdx | 1 + src/transformers/__init__.py | 54 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 3 + src/transformers/models/auto/modeling_auto.py | 4 + .../models/auto/modeling_flax_auto.py | 3 + .../models/auto/modeling_tf_auto.py | 4 + .../models/auto/tokenization_auto.py | 7 + .../models/switchtransformers/__init__.py | 158 ++ .../configuration_switchtransformers.py | 168 ++ ...rmers_original_tf_checkpoint_to_pytorch.py | 60 + ..._switchtransformersx_checkpoint_to_flax.py | 234 ++ .../modeling_flax_switchtransformers.py | 1826 ++++++++++++++++ .../modeling_switchtransformers.py | 1878 +++++++++++++++++ .../modeling_tf_switchtransformers.py | 1671 +++++++++++++++ .../tokenization_switchtransformers.py | 339 +++ .../tokenization_switchtransformers_fast.py | 237 +++ src/transformers/utils/dummy_flax_objects.py | 28 + src/transformers/utils/dummy_pt_objects.py | 35 + .../utils/dummy_sentencepiece_objects.py | 7 + src/transformers/utils/dummy_tf_objects.py | 31 + .../utils/dummy_tokenizers_objects.py | 7 + tests/models/switchtransformers/__init__.py | 0 .../test_modeling_flax_switchtransformers.py | 1099 ++++++++++ .../test_modeling_switchtransformers.py | 1254 +++++++++++ .../test_modeling_tf_switchtransformers.py | 1054 +++++++++ .../test_tokenization_switchtransformers.py | 381 ++++ 33 files changed, 10651 insertions(+) create mode 100644 docs/source/en/model_doc/switchtransformers.mdx create mode 100644 src/transformers/models/switchtransformers/__init__.py create mode 100644 src/transformers/models/switchtransformers/configuration_switchtransformers.py create mode 100644 src/transformers/models/switchtransformers/convert_switchtransformers_original_tf_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/switchtransformers/convert_switchtransformersx_checkpoint_to_flax.py create mode 100644 src/transformers/models/switchtransformers/modeling_flax_switchtransformers.py create mode 100644 src/transformers/models/switchtransformers/modeling_switchtransformers.py create mode 100644 src/transformers/models/switchtransformers/modeling_tf_switchtransformers.py create mode 100644 src/transformers/models/switchtransformers/tokenization_switchtransformers.py create mode 100644 src/transformers/models/switchtransformers/tokenization_switchtransformers_fast.py create mode 100644 tests/models/switchtransformers/__init__.py create mode 100644 tests/models/switchtransformers/test_modeling_flax_switchtransformers.py create mode 100644 tests/models/switchtransformers/test_modeling_switchtransformers.py create mode 100644 tests/models/switchtransformers/test_modeling_tf_switchtransformers.py create mode 100644 tests/models/switchtransformers/test_tokenization_switchtransformers.py diff --git a/README.md b/README.md index 4eb429652ab01..d4e95468ba4f3 100644 --- a/README.md +++ b/README.md @@ -371,6 +371,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. 1. **[Swin Transformer V2](https://huggingface.co/docs/transformers/model_doc/swinv2)** (from Microsoft) released with the paper [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. +1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switchtransformers)** (from ) released with the paper []() by . 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. diff --git a/README_ko.md b/README_ko.md index c591b50417ad9..ef0d938a9405b 100644 --- a/README_ko.md +++ b/README_ko.md @@ -321,6 +321,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. 1. **[Swin Transformer V2](https://huggingface.co/docs/transformers/model_doc/swinv2)** (from Microsoft) released with the paper [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. +1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switchtransformers)** (from ) released with the paper []() by . 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. diff --git a/README_zh-hans.md b/README_zh-hans.md index 36b33982d0a13..82b042184afdb 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -345,6 +345,7 @@ conda install -c huggingface transformers 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (来自 Berkeley) 伴随论文 [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) 由 Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer 发布。 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (来自 Microsoft) 伴随论文 [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) 由 Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo 发布。 1. **[Swin Transformer V2](https://huggingface.co/docs/transformers/model_doc/swinv2)** (来自 Microsoft) 伴随论文 [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) 由 Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo 发布。 +1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switchtransformers)** (from ) released with the paper []() by . 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (来自 Google AI) 伴随论文 [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) 由 Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu 发布。 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (来自 Google AI) 伴随论文 [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) 由 Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu 发布。 1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (来自 Google AI) 伴随论文 [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) 由 Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index eef6a3589f4ef..16f5109a77795 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -357,6 +357,7 @@ conda install -c huggingface transformers 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. 1. **[Swin Transformer V2](https://huggingface.co/docs/transformers/model_doc/swinv2)** (from Microsoft) released with the paper [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. +1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switchtransformers)** (from ) released with the paper []() by . 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released with the paper [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 1b862df0b0e4c..82172cccb1c04 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -161,6 +161,7 @@ The documentation is organized into five sections: 1. **[SqueezeBERT](model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. 1. **[Swin Transformer V2](model_doc/swinv2)** (from Microsoft) released with the paper [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. +1. **[SwitchTransformers](model_doc/switchtransformers)** (from ) released with the paper []() by . 1. **[T5](model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. @@ -309,6 +310,7 @@ Flax), PyTorch, and/or TensorFlow. | SqueezeBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | Swin Transformer | ❌ | ❌ | ✅ | ✅ | ❌ | | Swin Transformer V2 | ❌ | ❌ | ✅ | ❌ | ❌ | +| SwitchTransformers | ✅ | ✅ | ✅ | ✅ | ✅ | | T5 | ✅ | ✅ | ✅ | ✅ | ✅ | | TAPAS | ✅ | ❌ | ✅ | ✅ | ❌ | | Time Series Transformer | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/switchtransformers.mdx b/docs/source/en/model_doc/switchtransformers.mdx new file mode 100644 index 0000000000000..0ba701599c7e9 --- /dev/null +++ b/docs/source/en/model_doc/switchtransformers.mdx @@ -0,0 +1,101 @@ + + +# SwitchTransformers + +## Overview + +The SwitchTransformers model was proposed in []() by . + + +The abstract from the paper is the following: + +** + +Tips: + + + +This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). +The original code can be found [here](). + + +## SwitchTransformersConfig + +[[autodoc]] SwitchTransformersConfig + +## SwitchTransformersTokenizer + +[[autodoc]] SwitchTransformersTokenizer + - build_inputs_with_special_tokens + - get_special_tokens_mask + - create_token_type_ids_from_sequences + - save_vocabulary + +## SwitchTransformersTokenizerFast + +[[autodoc]] SwitchTransformersTokenizerFast + +## SwitchTransformersModel + +[[autodoc]] SwitchTransformersModel + - forward + - parallelize + - deparallelize + +## SwitchTransformersForConditionalGeneration + +[[autodoc]] SwitchTransformersForConditionalGeneration + - forward + - parallelize + - deparallelize + +## SwitchTransformersEncoderModel + +[[autodoc]] SwitchTransformersEncoderModel + - forward + - parallelize + - deparallelize + +## TFSwitchTransformersModel + +[[autodoc]] TFSwitchTransformersModel + - call + +## TFSwitchTransformersForConditionalGeneration + +[[autodoc]] TFSwitchTransformersForConditionalGeneration + - call + +## TFSwitchTransformersEncoderModel + +[[autodoc]] TFSwitchTransformersEncoderModel + - call + +## FlaxSwitchTransformersModel + +[[autodoc]] FlaxSwitchTransformersModel + - __call__ + - encode + - decode + +## FlaxSwitchTransformersForConditionalGeneration + +[[autodoc]] FlaxSwitchTransformersForConditionalGeneration + - __call__ + - encode + - decode + +## FlaxSwitchTransformersEncoderModel + +[[autodoc]] FlaxSwitchTransformersEncoderModel + - __call__ diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index 903d35da4c4cd..2f62306914fe4 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -94,6 +94,7 @@ Ready-made configurations include the following architectures: - RoFormer - SegFormer - SqueezeBERT +- SwitchTransformers - T5 - ViT - XLM diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e5e6e6c171c0d..23a34d2011927 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -333,6 +333,7 @@ "models.squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig", "SqueezeBertTokenizer"], "models.swin": ["SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwinConfig"], "models.swinv2": ["SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Swinv2Config"], + "models.switchtransformers": ["SWITCHTRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwitchTransformersConfig"], "models.t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"], "models.tapas": ["TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP", "TapasConfig", "TapasTokenizer"], "models.tapex": ["TapexTokenizer"], @@ -533,6 +534,7 @@ _import_structure["models.rembert"].append("RemBertTokenizer") _import_structure["models.speech_to_text"].append("Speech2TextTokenizer") _import_structure["models.t5"].append("T5Tokenizer") + _import_structure["models.switchtransformers"].append("SwitchTransformersTokenizer") _import_structure["models.xglm"].append("XGLMTokenizer") _import_structure["models.xlm_prophetnet"].append("XLMProphetNetTokenizer") _import_structure["models.xlm_roberta"].append("XLMRobertaTokenizer") @@ -601,6 +603,7 @@ _import_structure["models.roformer"].append("RoFormerTokenizerFast") _import_structure["models.splinter"].append("SplinterTokenizerFast") _import_structure["models.squeezebert"].append("SqueezeBertTokenizerFast") + _import_structure["models.switchtransformers"].append("SwitchTransformersTokenizerFast") _import_structure["models.t5"].append("T5TokenizerFast") _import_structure["models.xglm"].append("XGLMTokenizerFast") _import_structure["models.xlm_roberta"].append("XLMRobertaTokenizerFast") @@ -1913,6 +1916,16 @@ "load_tf_weights_in_t5", ] ) + _import_structure["models.switchtransformers"].extend( + [ + "SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST", + "SwitchTransformersEncoderModel", + "SwitchTransformersForConditionalGeneration", + "SwitchTransformersModel", + "SwitchTransformersPreTrainedModel", + "load_tf_weights_in_switchtransformers", + ] + ) _import_structure["models.trajectory_transformer"].extend( [ "TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -2688,6 +2701,15 @@ "TFT5PreTrainedModel", ] ) + _import_structure["models.switchtransformers"].extend( + [ + "TF_SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFSwitchTransformersEncoderModel", + "TFSwitchTransformersForConditionalGeneration", + "TFSwitchTransformersModel", + "TFSwitchTransformersPreTrainedModel", + ] + ) _import_structure["models.tapas"].extend( [ "TF_TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -3008,6 +3030,14 @@ _import_structure["models.t5"].extend( ["FlaxT5EncoderModel", "FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"] ) + _import_structure["models.switchtransformers"].extend( + [ + "FlaxSwitchTransformersEncoderModel", + "FlaxSwitchTransformersForConditionalGeneration", + "FlaxSwitchTransformersModel", + "FlaxSwitchTransformersPreTrainedModel", + ] + ) _import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel") _import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"]) _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"]) @@ -3296,6 +3326,7 @@ from .models.squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig, SqueezeBertTokenizer from .models.swin import SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, SwinConfig from .models.swinv2 import SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP, Swinv2Config + from .models.switchtransformers import SWITCHTRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP, SwitchTransformersConfig from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config from .models.tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig, TapasTokenizer from .models.tapex import TapexTokenizer @@ -3477,6 +3508,7 @@ from .models.reformer import ReformerTokenizer from .models.rembert import RemBertTokenizer from .models.speech_to_text import Speech2TextTokenizer + from .models.switchtransformers import SwitchTransformersTokenizer from .models.t5 import T5Tokenizer from .models.xglm import XGLMTokenizer from .models.xlm_prophetnet import XLMProphetNetTokenizer @@ -3539,6 +3571,7 @@ from .models.roformer import RoFormerTokenizerFast from .models.splinter import SplinterTokenizerFast from .models.squeezebert import SqueezeBertTokenizerFast + from .models.switchtransformers import SwitchTransformersTokenizerFast from .models.t5 import T5TokenizerFast from .models.xglm import XGLMTokenizerFast from .models.xlm_roberta import XLMRobertaTokenizerFast @@ -4586,6 +4619,14 @@ Swinv2Model, Swinv2PreTrainedModel, ) + from .models.switchtransformers import ( + SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, + SwitchTransformersEncoderModel, + SwitchTransformersForConditionalGeneration, + SwitchTransformersModel, + SwitchTransformersPreTrainedModel, + load_tf_weights_in_switchtransformers, + ) from .models.t5 import ( T5_PRETRAINED_MODEL_ARCHIVE_LIST, T5EncoderModel, @@ -5228,6 +5269,13 @@ TFSwinModel, TFSwinPreTrainedModel, ) + from .models.switchtransformers import ( + TF_SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, + TFSwitchTransformersEncoderModel, + TFSwitchTransformersForConditionalGeneration, + TFSwitchTransformersModel, + TFSwitchTransformersPreTrainedModel, + ) from .models.t5 import ( TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST, TFT5EncoderModel, @@ -5475,6 +5523,12 @@ FlaxRoFormerPreTrainedModel, ) from .models.speech_encoder_decoder import FlaxSpeechEncoderDecoderModel + from .models.switchtransformers import ( + FlaxSwitchTransformersEncoderModel, + FlaxSwitchTransformersForConditionalGeneration, + FlaxSwitchTransformersModel, + FlaxSwitchTransformersPreTrainedModel, + ) from .models.t5 import FlaxT5EncoderModel, FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index f7979e0a77b12..34c0c50745273 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -136,6 +136,7 @@ squeezebert, swin, swinv2, + switchtransformers, t5, tapas, tapex, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 14fa334b57978..6aac35af15d6b 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -132,6 +132,7 @@ ("squeezebert", "SqueezeBertConfig"), ("swin", "SwinConfig"), ("swinv2", "Swinv2Config"), + ("switchtransformers", "SwitchTransformersConfig"), ("t5", "T5Config"), ("tapas", "TapasConfig"), ("time_series_transformer", "TimeSeriesTransformerConfig"), @@ -261,6 +262,7 @@ ("squeezebert", "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("swin", "SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("swinv2", "SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("switchtransformers", "SWITCHTRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("time_series_transformer", "TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -410,6 +412,7 @@ ("squeezebert", "SqueezeBERT"), ("swin", "Swin Transformer"), ("swinv2", "Swin Transformer V2"), + ("switchtransformers", "SwitchTransformers"), ("t5", "T5"), ("t5v1.1", "T5v1.1"), ("tapas", "TAPAS"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 8821cfb6c93e9..d28353e0bae37 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -128,6 +128,7 @@ ("squeezebert", "SqueezeBertModel"), ("swin", "SwinModel"), ("swinv2", "Swinv2Model"), + ("switchtransformers", "SwitchTransformersModel"), ("t5", "T5Model"), ("tapas", "TapasModel"), ("time_series_transformer", "TimeSeriesTransformerModel"), @@ -195,6 +196,7 @@ ("roberta", "RobertaForMaskedLM"), ("splinter", "SplinterForPreTraining"), ("squeezebert", "SqueezeBertForMaskedLM"), + ("switchtransformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), ("tapas", "TapasForMaskedLM"), ("transfo-xl", "TransfoXLLMHeadModel"), @@ -269,6 +271,7 @@ ("roformer", "RoFormerForMaskedLM"), ("speech_to_text", "Speech2TextForConditionalGeneration"), ("squeezebert", "SqueezeBertForMaskedLM"), + ("switchtransformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), ("tapas", "TapasForMaskedLM"), ("transfo-xl", "TransfoXLLMHeadModel"), @@ -491,6 +494,7 @@ ("pegasus_x", "PegasusXForConditionalGeneration"), ("plbart", "PLBartForConditionalGeneration"), ("prophetnet", "ProphetNetForConditionalGeneration"), + ("switchtransformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"), ] diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 98c5d6fb5a104..5850797ea769f 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -49,6 +49,7 @@ ("pegasus", "FlaxPegasusModel"), ("roberta", "FlaxRobertaModel"), ("roformer", "FlaxRoFormerModel"), + ("switchtransformers", "FlaxSwitchTransformersModel"), ("t5", "FlaxT5Model"), ("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"), ("vit", "FlaxViTModel"), @@ -71,6 +72,7 @@ ("mt5", "FlaxMT5ForConditionalGeneration"), ("roberta", "FlaxRobertaForMaskedLM"), ("roformer", "FlaxRoFormerForMaskedLM"), + ("switchtransformers", "FlaxSwitchTransformersForConditionalGeneration"), ("t5", "FlaxT5ForConditionalGeneration"), ("wav2vec2", "FlaxWav2Vec2ForPreTraining"), ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"), @@ -105,6 +107,7 @@ ("mbart", "FlaxMBartForConditionalGeneration"), ("mt5", "FlaxMT5ForConditionalGeneration"), ("pegasus", "FlaxPegasusForConditionalGeneration"), + ("switchtransformers", "FlaxSwitchTransformersForConditionalGeneration"), ("t5", "FlaxT5ForConditionalGeneration"), ] ) diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index e13a0754b6926..0d8c9280cdeb6 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -74,6 +74,7 @@ ("segformer", "TFSegformerModel"), ("speech_to_text", "TFSpeech2TextModel"), ("swin", "TFSwinModel"), + ("switchtransformers", "TFSwitchTransformersModel"), ("t5", "TFT5Model"), ("tapas", "TFTapasModel"), ("transfo-xl", "TFTransfoXLModel"), @@ -106,6 +107,7 @@ ("mpnet", "TFMPNetForMaskedLM"), ("openai-gpt", "TFOpenAIGPTLMHeadModel"), ("roberta", "TFRobertaForMaskedLM"), + ("switchtransformers", "TFSwitchTransformersForConditionalGeneration"), ("t5", "TFT5ForConditionalGeneration"), ("tapas", "TFTapasForMaskedLM"), ("transfo-xl", "TFTransfoXLLMHeadModel"), @@ -142,6 +144,7 @@ ("roberta", "TFRobertaForMaskedLM"), ("roformer", "TFRoFormerForMaskedLM"), ("speech_to_text", "TFSpeech2TextForConditionalGeneration"), + ("switchtransformers", "TFSwitchTransformersForConditionalGeneration"), ("t5", "TFT5ForConditionalGeneration"), ("tapas", "TFTapasForMaskedLM"), ("transfo-xl", "TFTransfoXLLMHeadModel"), @@ -246,6 +249,7 @@ ("mbart", "TFMBartForConditionalGeneration"), ("mt5", "TFMT5ForConditionalGeneration"), ("pegasus", "TFPegasusForConditionalGeneration"), + ("switchtransformers", "TFSwitchTransformersForConditionalGeneration"), ("t5", "TFT5ForConditionalGeneration"), ] ) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 97e048885e180..b205d45a72803 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -240,6 +240,13 @@ "squeezebert", ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None), ), + ( + "switchtransformers", + ( + "SwitchTransformersTokenizer" if is_sentencepiece_available() else None, + "SwitchTransformersTokenizerFast" if is_tokenizers_available() else None, + ), + ), ( "t5", ( diff --git a/src/transformers/models/switchtransformers/__init__.py b/src/transformers/models/switchtransformers/__init__.py new file mode 100644 index 0000000000000..0f656af9dfdfb --- /dev/null +++ b/src/transformers/models/switchtransformers/__init__.py @@ -0,0 +1,158 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_flax_available, + is_sentencepiece_available, + is_tf_available, + is_tokenizers_available, + is_torch_available, +) + + +_import_structure = {"configuration_switchtransformers": ["SWITCHTRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwitchTransformersConfig", "SwitchTransformersOnnxConfig"]} + +try: + if not is_sentencepiece_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["tokenization_switchtransformers"] = ["SwitchTransformersTokenizer"] + +try: + if not is_tokenizers_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["tokenization_switchtransformers_fast"] = ["SwitchTransformersTokenizerFast"] + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_switchtransformers"] = [ + "SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST", + "SwitchTransformersEncoderModel", + "SwitchTransformersForConditionalGeneration", + "SwitchTransformersModel", + "SwitchTransformersPreTrainedModel", + "load_tf_weights_in_switchtransformers", + ] + +try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_tf_switchtransformers"] = [ + "TF_SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFSwitchTransformersEncoderModel", + "TFSwitchTransformersForConditionalGeneration", + "TFSwitchTransformersModel", + "TFSwitchTransformersPreTrainedModel", + ] + +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_flax_switchtransformers"] = [ + "FlaxSwitchTransformersEncoderModel", + "FlaxSwitchTransformersForConditionalGeneration", + "FlaxSwitchTransformersModel", + "FlaxSwitchTransformersPreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_switchtransformers import SWITCHTRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP, SwitchTransformersConfig, SwitchTransformersOnnxConfig + + try: + if not is_sentencepiece_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .tokenization_switchtransformers import SwitchTransformersTokenizer + + try: + if not is_tokenizers_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .tokenization_switchtransformers_fast import SwitchTransformersTokenizerFast + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_switchtransformers import ( + SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, + SwitchTransformersEncoderModel, + SwitchTransformersForConditionalGeneration, + SwitchTransformersModel, + SwitchTransformersPreTrainedModel, + load_tf_weights_in_switchtransformers, + ) + + try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_tf_switchtransformers import ( + TF_SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, + TFSwitchTransformersEncoderModel, + TFSwitchTransformersForConditionalGeneration, + TFSwitchTransformersModel, + TFSwitchTransformersPreTrainedModel, + ) + + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_flax_switchtransformers import ( + FlaxSwitchTransformersEncoderModel, + FlaxSwitchTransformersForConditionalGeneration, + FlaxSwitchTransformersModel, + FlaxSwitchTransformersPreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/switchtransformers/configuration_switchtransformers.py b/src/transformers/models/switchtransformers/configuration_switchtransformers.py new file mode 100644 index 0000000000000..0056bc7758a96 --- /dev/null +++ b/src/transformers/models/switchtransformers/configuration_switchtransformers.py @@ -0,0 +1,168 @@ +# coding=utf-8 +# Copyright 2020, The SwitchTransformers Authors and HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" SwitchTransformers model configuration""" +from typing import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxSeq2SeqConfigWithPast +from ...utils import logging + + +logger = logging.get_logger(__name__) + +SWITCHTRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "ybelkada/switchtransformers-base": "https://huggingface.co/ybelkada/switchtransformers-base/resolve/main/config.json", +} + + + +class SwitchTransformersConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SwitchTransformersModel`] or a + [`TFSwitchTransformersModel`]. It is used to instantiate a SwitchTransformers model according to the specified + arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar + configuration to that of the SwitchTransformers + [ybelkada/switchtransformers-base](https://huggingface.co/ybelkada/switchtransformers-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Arguments: + vocab_size (`int`, *optional*, defaults to 32128): + Vocabulary size of the SwitchTransformers model. Defines the number of different tokens that can be + represented by the `inputs_ids` passed when calling [`SwitchTransformersModel`] or + [`TFSwitchTransformersModel`]. + d_model (`int`, *optional*, defaults to 512): + Size of the encoder layers and the pooler layer. + d_kv (`int`, *optional*, defaults to 64): + Size of the key, query, value projections per attention head. `d_kv` has to be equal to `d_model // + num_heads`. + d_ff (`int`, *optional*, defaults to 2048): + Size of the intermediate feed forward layer in each `SwitchTransformersBlock`. + num_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer encoder. + num_decoder_layers (`int`, *optional*): + Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set. + num_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + num_experts (`int`, *optional*, defaults to 8): + Number of experts for each SwitchTransformer layer. + relative_attention_num_buckets (`int`, *optional*, defaults to 32): + The number of buckets to use for each attention layer. + relative_attention_max_distance (`int`, *optional*, defaults to 128): + The maximum distance of the longer sequences for the bucket separation. + dropout_rate (`float`, *optional*, defaults to 0.1): + The ratio for all dropout layers. + layer_norm_eps (`float`, *optional*, defaults to 1e-6): + The epsilon used by the layer normalization layers. + initializer_factor (`float`, *optional*, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + feed_forward_proj (`string`, *optional*, defaults to `"relu"`): + Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. SwitchTransformersv1.1 + uses the `"gated-gelu"` feed forward projection. Original SwitchTransformers uses `"relu"`. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + """ + model_type = "switchtransformers" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + + def __init__( + self, + vocab_size=32128, + d_model=512, + d_kv=64, + d_ff=2048, + num_layers=6, + num_decoder_layers=None, + num_heads=8, + num_experts=8, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + dropout_rate=0.1, + layer_norm_epsilon=1e-6, + initializer_factor=1.0, + feed_forward_proj="relu", + is_encoder_decoder=True, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + **kwargs + ): + self.vocab_size = vocab_size + self.d_model = d_model + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + self.num_decoder_layers = ( + num_decoder_layers if num_decoder_layers is not None else self.num_layers + ) # default = symmetry + self.num_heads = num_heads + self.num_experts = num_experts + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.dropout_rate = dropout_rate + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_factor = initializer_factor + self.feed_forward_proj = feed_forward_proj + self.use_cache = use_cache + + act_info = self.feed_forward_proj.split("-") + self.dense_act_fn = act_info[-1] + self.is_gated_act = act_info[0] == "gated" + + if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2: + raise ValueError( + f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer." + "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. " + "'gated-gelu' or 'relu'" + ) + + # for backwards compatibility + if feed_forward_proj == "gated-gelu": + self.dense_act_fn = "gelu_new" + + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + +class SwitchTransformersOnnxConfig(OnnxSeq2SeqConfigWithPast): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = { + "input_ids": {0: "batch", 1: "encoder_sequence"}, + "attention_mask": {0: "batch", 1: "encoder_sequence"}, + } + if self.use_past: + common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence" + common_inputs["decoder_input_ids"] = {0: "batch"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} + else: + common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} + + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + + return common_inputs + + @property + def default_onnx_opset(self) -> int: + return 13 diff --git a/src/transformers/models/switchtransformers/convert_switchtransformers_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/switchtransformers/convert_switchtransformers_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000..fbb0edd2ce29e --- /dev/null +++ b/src/transformers/models/switchtransformers/convert_switchtransformers_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,60 @@ +# coding=utf-8 +# Copyright 2022 The SwitchTransformers authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert SwitchTransformers checkpoint.""" + + +import argparse + +from transformers import SwitchTransformersConfig, SwitchTransformersForConditionalGeneration, load_tf_weights_in_switchtransformers +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): + # Initialise PyTorch model + config = SwitchTransformersConfig.from_json_file(config_file) + print(f"Building PyTorch model from configuration: {config}") + model = SwitchTransformersForConditionalGeneration(config) + + # Load weights from tf checkpoint + load_tf_weights_in_switchtransformers(model, config, tf_checkpoint_path) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + model.save_pretrained(pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help=( + "The config json file corresponding to the pre-trained SwitchTransformers model. \nThis specifies the model architecture." + ), + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/switchtransformers/convert_switchtransformersx_checkpoint_to_flax.py b/src/transformers/models/switchtransformers/convert_switchtransformersx_checkpoint_to_flax.py new file mode 100644 index 0000000000000..90d923d623025 --- /dev/null +++ b/src/transformers/models/switchtransformers/convert_switchtransformersx_checkpoint_to_flax.py @@ -0,0 +1,234 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert SwitchTransformersX checkpoints from the original repository to JAX/FLAX model.""" + +import argparse + +from switchtransformersx import checkpoints +from transformers import FlaxSwitchTransformersForConditionalGeneration, SwitchTransformersConfig + + +def convert_switchtransformersx_checkpoint_to_flax(switchtransformersx_checkpoint_path, config_name, flax_dump_folder_path): + config = SwitchTransformersConfig.from_pretrained(config_name) + flax_model = FlaxSwitchTransformersForConditionalGeneration(config=config) + switchtransformersx_model = checkpoints.load_switchtransformersx_checkpoint(switchtransformersx_checkpoint_path) + + split_mlp_wi = "wi_0" in switchtransformersx_model["target"]["encoder"]["layers_0"]["mlp"] + + # Encoder + for layer_index in range(config.num_layers): + layer_name = f"layers_{str(layer_index)}" + + # Self-Attention + switchtransformersx_attention_key = switchtransformersx_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"] + switchtransformersx_attention_out = switchtransformersx_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"] + switchtransformersx_attention_query = switchtransformersx_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"] + switchtransformersx_attention_value = switchtransformersx_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"] + + # Layer Normalization + switchtransformersx_attention_layer_norm = switchtransformersx_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"] + + if split_mlp_wi: + switchtransformersx_mlp_wi_0 = switchtransformersx_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"] + switchtransformersx_mlp_wi_1 = switchtransformersx_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"] + else: + switchtransformersx_mlp_wi = switchtransformersx_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"] + + switchtransformersx_mlp_wo = switchtransformersx_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"] + + # Layer Normalization + switchtransformersx_mlp_layer_norm = switchtransformersx_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"] + + # Assigning + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"][ + "kernel" + ] = switchtransformersx_attention_key + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"][ + "kernel" + ] = switchtransformersx_attention_out + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"][ + "kernel" + ] = switchtransformersx_attention_query + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"][ + "kernel" + ] = switchtransformersx_attention_value + + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"][ + "weight" + ] = switchtransformersx_attention_layer_norm + + if split_mlp_wi: + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_0"][ + "kernel" + ] = switchtransformersx_mlp_wi_0 + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_1"][ + "kernel" + ] = switchtransformersx_mlp_wi_1 + else: + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi"][ + "kernel" + ] = switchtransformersx_mlp_wi + + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wo"][ + "kernel" + ] = switchtransformersx_mlp_wo + flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"][ + "weight" + ] = switchtransformersx_mlp_layer_norm + + # Only for layer 0: + switchtransformersx_encoder_rel_embedding = switchtransformersx_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T + flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ + "embedding" + ] = switchtransformersx_encoder_rel_embedding + + # Assigning + switchtransformersx_encoder_norm = switchtransformersx_model["target"]["encoder"]["encoder_norm"]["scale"] + flax_model.params["encoder"]["final_layer_norm"]["weight"] = switchtransformersx_encoder_norm + + # Decoder + for layer_index in range(config.num_decoder_layers): + layer_name = f"layers_{str(layer_index)}" + + # Self-Attention + switchtransformersx_attention_key = switchtransformersx_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"] + switchtransformersx_attention_out = switchtransformersx_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"] + switchtransformersx_attention_query = switchtransformersx_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"] + switchtransformersx_attention_value = switchtransformersx_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"] + + # Layer Normalization + switchtransformersx_pre_attention_layer_norm = switchtransformersx_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"][ + "scale" + ] + + # Encoder-Decoder-Attention + switchtransformersx_enc_dec_attention_key = switchtransformersx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["key"][ + "kernel" + ] + switchtransformersx_enc_dec_attention_out = switchtransformersx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["out"][ + "kernel" + ] + switchtransformersx_enc_dec_attention_query = switchtransformersx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["query"][ + "kernel" + ] + switchtransformersx_enc_dec_attention_value = switchtransformersx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["value"][ + "kernel" + ] + + # Layer Normalization + switchtransformersx_cross_layer_norm = switchtransformersx_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"] + + # MLP + if split_mlp_wi: + switchtransformersx_mlp_wi_0 = switchtransformersx_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"] + switchtransformersx_mlp_wi_1 = switchtransformersx_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"] + else: + switchtransformersx_mlp_wi = switchtransformersx_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"] + + switchtransformersx_mlp_wo = switchtransformersx_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"] + + # Layer Normalization + tx5_mlp_layer_norm = switchtransformersx_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"] + + # Assigning + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"][ + "kernel" + ] = switchtransformersx_attention_key + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"][ + "kernel" + ] = switchtransformersx_attention_out + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"][ + "kernel" + ] = switchtransformersx_attention_query + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"][ + "kernel" + ] = switchtransformersx_attention_value + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"][ + "weight" + ] = switchtransformersx_pre_attention_layer_norm + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["k"][ + "kernel" + ] = switchtransformersx_enc_dec_attention_key + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["o"][ + "kernel" + ] = switchtransformersx_enc_dec_attention_out + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["q"][ + "kernel" + ] = switchtransformersx_enc_dec_attention_query + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["v"][ + "kernel" + ] = switchtransformersx_enc_dec_attention_value + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"][ + "weight" + ] = switchtransformersx_cross_layer_norm + + if split_mlp_wi: + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_0"][ + "kernel" + ] = switchtransformersx_mlp_wi_0 + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_1"][ + "kernel" + ] = switchtransformersx_mlp_wi_1 + else: + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi"][ + "kernel" + ] = switchtransformersx_mlp_wi + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wo"][ + "kernel" + ] = switchtransformersx_mlp_wo + + flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["layer_norm"][ + "weight" + ] = tx5_mlp_layer_norm + + # Decoder Normalization + tx5_decoder_norm = switchtransformersx_model["target"]["decoder"]["decoder_norm"]["scale"] + flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm + + # Only for layer 0: + switchtransformersx_decoder_rel_embedding = switchtransformersx_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T + flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ + "embedding" + ] = switchtransformersx_decoder_rel_embedding + + # Token Embeddings + tx5_token_embeddings = switchtransformersx_model["target"]["token_embedder"]["embedding"] + flax_model.params["shared"]["embedding"] = tx5_token_embeddings + + # LM Head (only in v1.1 checkpoints) + if "logits_dense" in switchtransformersx_model["target"]["decoder"]: + flax_model.params["lm_head"]["kernel"] = switchtransformersx_model["target"]["decoder"]["logits_dense"]["kernel"] + + flax_model.save_pretrained(flax_dump_folder_path) + print("SwitchTransformersX Model was sucessfully converted!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--switchtransformersx_checkpoint_path", default=None, type=str, required=True, help="Path the TX5 checkpoint." + ) + parser.add_argument("--config_name", default=None, type=str, required=True, help="Config name of SwitchTransformers model.") + parser.add_argument( + "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model." + ) + args = parser.parse_args() + convert_switchtransformersx_checkpoint_to_flax(args.switchtransformersx_checkpoint_path, args.config_name, args.flax_dump_folder_path) diff --git a/src/transformers/models/switchtransformers/modeling_flax_switchtransformers.py b/src/transformers/models/switchtransformers/modeling_flax_switchtransformers.py new file mode 100644 index 0000000000000..de48b929ad5ac --- /dev/null +++ b/src/transformers/models/switchtransformers/modeling_flax_switchtransformers.py @@ -0,0 +1,1826 @@ +# coding=utf-8 +# Copyright 2022 SwitchTransformers Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Flax SwitchTransformers model.""" + + +import copy +from typing import Callable, Optional, Tuple + +import numpy as np + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax.random import PRNGKey + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxSeq2SeqLMOutput, + FlaxSeq2SeqModelOutput, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from .configuration_switchtransformers import SwitchTransformersConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "ybelkada/switchtransformers-base" +_CONFIG_FOR_DOC = "SwitchTransformersConfig" +_TOKENIZER_FOR_DOC = "SwitchTransformersTokenizer" + +remat = nn_partitioning.remat + + +# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right +def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray: + """ + Shift input ids one token to the right. + """ + shifted_input_ids = np.zeros_like(input_ids) + shifted_input_ids[:, 1:] = input_ids[:, :-1] + shifted_input_ids[:, 0] = decoder_start_token_id + + shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) + return shifted_input_ids + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerNorm with T5->SwitchTransformers +class FlaxSwitchTransformersLayerNorm(nn.Module): + hidden_size: int + dtype: jnp.dtype = jnp.float32 + eps: float = 1e-6 + weight_init: Callable[..., np.ndarray] = jax.nn.initializers.ones + + def setup(self): + self.weight = self.param("weight", self.weight_init, (self.hidden_size,)) + + def __call__(self, hidden_states): + """ + Construct a layernorm module in the SwitchTransformers style; No bias and no subtraction of mean. + """ + # layer norm should always be calculated in float32 + variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True) + hidden_states = hidden_states / jnp.sqrt(variance + self.eps) + + return self.weight * hidden_states + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5DenseActDense with T5->SwitchTransformers +class FlaxSwitchTransformersDenseActDense(nn.Module): + config: SwitchTransformersConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) + wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) + + self.wi = nn.Dense( + self.config.d_ff, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wi_init_std), + dtype=self.dtype, + ) + self.wo = nn.Dense( + self.config.d_model, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wo_init_std), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + self.act = ACT2FN[self.config.dense_act_fn] + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5DenseGatedActDense with T5->SwitchTransformers +class FlaxSwitchTransformersDenseGatedActDense(nn.Module): + config: SwitchTransformersConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) + wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) + + self.wi_0 = nn.Dense( + self.config.d_ff, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wi_init_std), + dtype=self.dtype, + ) + self.wi_1 = nn.Dense( + self.config.d_ff, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wi_init_std), + dtype=self.dtype, + ) + self.wo = nn.Dense( + self.config.d_model, + use_bias=False, + kernel_init=jax.nn.initializers.normal(wo_init_std), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + self.act = ACT2FN[self.config.dense_act_fn] + + def __call__(self, hidden_states, deterministic): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerFF with T5->SwitchTransformers +class FlaxSwitchTransformersLayerFF(nn.Module): + config: SwitchTransformersConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.is_gated_act: + self.DenseReluDense = FlaxSwitchTransformersDenseGatedActDense(self.config, dtype=self.dtype) + else: + self.DenseReluDense = FlaxSwitchTransformersDenseActDense(self.config, dtype=self.dtype) + + self.layer_norm = FlaxSwitchTransformersLayerNorm( + self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__(self, hidden_states, deterministic=True): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states, deterministic=deterministic) + hidden_states = hidden_states + self.dropout(forwarded_states, deterministic=deterministic) + return hidden_states + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Attention with T5->SwitchTransformers +class FlaxSwitchTransformersAttention(nn.Module): + config: SwitchTransformersConfig + has_relative_attention_bias: bool = False + causal: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.relative_attention_num_buckets = self.config.relative_attention_num_buckets + self.relative_attention_max_distance = self.config.relative_attention_max_distance + self.d_model = self.config.d_model + self.key_value_proj_dim = self.config.d_kv + self.n_heads = self.config.num_heads + self.dropout = self.config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5) + kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) + o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) + + self.q = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=jax.nn.initializers.normal(q_init_std), + dtype=self.dtype, + ) + self.k = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=jax.nn.initializers.normal(kv_init_std), + dtype=self.dtype, + ) + self.v = nn.Dense( + self.inner_dim, + use_bias=False, + kernel_init=jax.nn.initializers.normal(kv_init_std), + dtype=self.dtype, + ) + self.o = nn.Dense( + self.d_model, + use_bias=False, + kernel_init=jax.nn.initializers.normal(o_init_std), + dtype=self.dtype, + ) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embed( + self.relative_attention_num_buckets, + self.n_heads, + embedding_init=jax.nn.initializers.normal(kv_init_std), + ) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0) * num_buckets + relative_position = jnp.abs(relative_position) + else: + relative_position = -jnp.clip(relative_position, a_max=0) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact) + ) + relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1) + + relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large) + + return relative_buckets.astype("i4") + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = jnp.arange(query_length, dtype="i4")[:, None] + memory_position = jnp.arange(key_length, dtype="i4")[None, :] + + relative_position = memory_position - context_position + relative_position_bucket = self._relative_position_bucket( + relative_position, + bidirectional=(not self.causal), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + + values = self.relative_attention_bias(relative_position_bucket) + values = values.transpose((2, 0, 1))[None, :, :, :] + return values + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slighly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = jax.lax.dynamic_update_slice(cached_key.value, key, indices) + value = jax.lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions + # that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def _create_position_bias( + self, key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift + ): + cache_is_filled = self.causal and self.has_variable("cache", "cached_key") and (not init_cache) + key_length = key_states.shape[1] + query_length = key_length if cache_is_filled else query_states.shape[1] + + if self.has_relative_attention_bias: + position_bias = self.compute_bias(query_length, key_length) + elif attention_mask is not None: + position_bias = jnp.zeros_like(attention_mask) + else: + position_bias = jnp.zeros((1, self.n_heads, query_length, key_length), dtype=self.dtype) + + # if key and values are already calculated, only the last query position bias should be taken + if cache_is_filled: + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + position_bias = jax.lax.dynamic_slice( + position_bias, + (0, 0, causal_attention_mask_shift, 0), + (1, self.n_heads, seq_length, max_decoder_length), + ) + return position_bias + + def __call__( + self, + hidden_states, + attention_mask=None, + key_value_states=None, + position_bias=None, + use_cache=False, + output_attentions=False, + deterministic=True, + init_cache=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + batch_size, seq_length = hidden_states.shape[:2] + + # q, k, v projections + query_states = self.q(hidden_states) # (batch_size, n_heads, seq_length, dim_per_head) + key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states) + value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states) + + # reshape to (batch_size, seq_length, n_heads, head_dim) + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # counter-act scaling in dot_product_attention_weights function + query_states *= jnp.sqrt(query_states.shape[-1]) + + # for fast decoding causal attention mask should be shifted + causal_attention_mask_shift = ( + self.variables["cache"]["cache_index"] if (self.has_variable("cache", "cached_key") and self.causal) else 0 + ) + # create causal attention_mask; attention_mask has to be defined when model is causal + if self.causal: + causal_attention_mask = make_causal_mask(attention_mask, dtype="bool") + + # fast decoding for generate requires special attention_mask + if self.has_variable("cache", "cached_key"): + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_attention_mask = jax.lax.dynamic_slice( + causal_attention_mask, + (0, 0, causal_attention_mask_shift, 0), + (1, 1, seq_length, max_decoder_length), + ) + + # broadcast causal attention mask & attention mask to fit for merge + causal_attention_mask = jnp.broadcast_to( + causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:] + ) + attention_mask = jnp.broadcast_to( + jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape + ) + attention_mask = combine_masks(attention_mask, causal_attention_mask) + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # replace masked positions with -10_000 + if attention_mask is not None: + mask_value = jnp.finfo(self.dtype).min + attention_mask = jax.lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, mask_value).astype(self.dtype), + ) + + if position_bias is None: + # compute position bias (only for first layer) + position_bias = self._create_position_bias( + key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift + ) + + if attention_mask is not None: + position_bias = position_bias + attention_mask + + # create dropout rng + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + # Softmax(QK^T) + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=position_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + ) + + # multiply with value states + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + + # bring back to (batch_size, seq_length, d_model) + attn_output = self._merge_heads(attn_output) + + # apply output matrix + attn_output = self.o(attn_output) + + outputs = (attn_output, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + + return outputs + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerSelfAttention with T5->SwitchTransformers +class FlaxSwitchTransformersLayerSelfAttention(nn.Module): + config: SwitchTransformersConfig + has_relative_attention_bias: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.SelfAttention = FlaxSwitchTransformersAttention( + self.config, + has_relative_attention_bias=self.has_relative_attention_bias, + causal=self.config.causal, + dtype=self.dtype, + ) + self.layer_norm = FlaxSwitchTransformersLayerNorm( + self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + output_attentions=False, + deterministic=True, + init_cache=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + init_cache=init_cache, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerCrossAttention with T5->SwitchTransformers +class FlaxSwitchTransformersLayerCrossAttention(nn.Module): + config: SwitchTransformersConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.EncDecAttention = FlaxSwitchTransformersAttention( + self.config, has_relative_attention_bias=False, causal=False, dtype=self.dtype + ) + self.layer_norm = FlaxSwitchTransformersLayerNorm( + self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + output_attentions=False, + deterministic=True, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + attention_mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Block with T5->SwitchTransformers +class FlaxSwitchTransformersBlock(nn.Module): + config: SwitchTransformersConfig + has_relative_attention_bias: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.causal = self.config.causal + self.layer = ( + FlaxSwitchTransformersLayerSelfAttention( + self.config, + has_relative_attention_bias=self.has_relative_attention_bias, + name=str(0), + dtype=self.dtype, + ), + ) + feed_forward_index = 1 + if self.causal: + self.layer += (FlaxSwitchTransformersLayerCrossAttention(self.config, name=str(1), dtype=self.dtype),) + feed_forward_index += 1 + + self.layer += (FlaxSwitchTransformersLayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + output_attentions=False, + return_dict=True, + deterministic=True, + init_cache=False, + ): + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + init_cache=init_cache, + ) + hidden_states = self_attention_outputs[0] + attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights + + do_cross_attention = self.causal and encoder_hidden_states is not None + if do_cross_attention: + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + ) + hidden_states = cross_attention_outputs[0] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[1:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + outputs = outputs + attention_outputs + + # returns hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + return outputs + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerCollection with T5->SwitchTransformers +class FlaxSwitchTransformersLayerCollection(nn.Module): + config: SwitchTransformersConfig + has_relative_attention_bias: bool + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layer = FlaxSwitchTransformersBlock( + self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype + ) + + def __call__( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + output_attentions=False, + deterministic=True, + init_cache=False, + ): + return self.layer( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + output_attentions=output_attentions, + deterministic=deterministic, + init_cache=init_cache, + ) + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5BlockCollection with T5->SwitchTransformers +class FlaxSwitchTransformersBlockCollection(nn.Module): + config: SwitchTransformersConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.causal = self.config.causal + if self.gradient_checkpointing: + FlaxSwitchTransformersCheckpointLayer = remat( + FlaxSwitchTransformersLayerCollection, static_argnums=(6, 7, 8) + ) + self.blocks = [ + FlaxSwitchTransformersCheckpointLayer( + self.config, + has_relative_attention_bias=(i == 0), + dtype=self.dtype, + name=str(i), + ) + for i in range(self.config.num_layers) + ] + else: + self.blocks = [ + FlaxSwitchTransformersLayerCollection( + self.config, + has_relative_attention_bias=(i == 0), + dtype=self.dtype, + name=str(i), + ) + for i in range(self.config.num_layers) + ] + + def __call__( + self, + hidden_states=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions: bool = False, + output_hidden_states: bool = False, + deterministic: bool = True, + init_cache: bool = False, + ): + # Prepare head mask if needed + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.causal) else None + position_bias = None + encoder_decoder_position_bias = None + + for i, layer_module in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + attention_mask, + position_bias, + encoder_hidden_states, + encoder_attention_mask, + encoder_decoder_position_bias, + output_attentions, + deterministic, + init_cache, + ) + + hidden_states = layer_outputs[0] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[1] + + if self.causal and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[2],) + if self.causal: + all_cross_attentions = all_cross_attentions + (layer_outputs[4],) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Stack with T5->SwitchTransformers +class FlaxSwitchTransformersStack(nn.Module): + config: SwitchTransformersConfig + embed_tokens: nn.Embed + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.causal = self.config.causal + + self.block = FlaxSwitchTransformersBlockCollection( + self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + self.final_layer_norm = FlaxSwitchTransformersLayerNorm( + self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype + ) + self.dropout = nn.Dropout(self.config.dropout_rate) + + def __call__( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + init_cache: bool = False, + ): + hidden_states = self.embed_tokens(input_ids) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + outputs = self.block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + deterministic=deterministic, + init_cache=init_cache, + ) + + hidden_states = outputs[0] + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + # Add last layer + all_hidden_states = None + + if output_hidden_states: + all_hidden_states = outputs.hidden_states + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + if output_hidden_states: + return ( + hidden_states, + all_hidden_states, + ) + outputs[2:] + return (hidden_states,) + outputs[1:] + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +SWITCHTRANSFORMERS_ENCODE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. SWITCHTRANSFORMERS is a model with relative position + embeddings so you should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`SwitchTransformersTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + To know more on how to prepare `input_ids` for pretraining take a look a [SWITCHTRANSFORMERS + Training](./switchtransformers#training). + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +SWITCHTRANSFORMERS_DECODE_INPUTS_DOCSTRING = r""" + Args: + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`SwitchTransformersTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + For training, `decoder_input_ids` should be provided. + encoder_outputs (`tuple(tuple(jnp.ndarray)`): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the + paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. + past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +SWITCHTRANSFORMERS_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. SWITCHTRANSFORMERS is a model with relative position + embeddings so you should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`SwitchTransformersTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [SWITCHTRANSFORMERS + Training](./switchtransformers#training). + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`SwitchTransformersTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + SWITCHTRANSFORMERS uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [SWITCHTRANSFORMERS + Training](./switchtransformers#training). + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + encoder_outputs (`tuple(tuple(jnp.ndarray)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(jnp.ndarray))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class FlaxSwitchTransformersPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SwitchTransformersConfig + base_model_prefix = "transformer" + module_class: nn.Module = None + + def __init__( + self, + config: SwitchTransformersConfig, + input_shape: Tuple[int] = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + gradient_checkpointing: bool = False, + **kwargs + ): + module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + + attention_mask = jnp.ones_like(input_ids) + args = [input_ids, attention_mask] + if self.module_class not in [FlaxSwitchTransformersEncoderModule]: + decoder_input_ids = jnp.ones_like(input_ids) + decoder_attention_mask = jnp.ones_like(input_ids) + args.extend([decoder_input_ids, decoder_attention_mask]) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, + *args, + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(SWITCHTRANSFORMERS_INPUTS_DOCSTRING) + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + decoder_input_ids: jnp.ndarray = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if decoder_input_ids is None: + raise ValueError( + "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed" + " here." + ) + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # prepare decoder inputs + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + def init_cache(self, batch_size, max_length, encoder_outputs): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): + `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: + `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) + is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. + """ + # init input variables to retrieve cache + decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + init_cache=True, + method=_decoder_forward, # we only need to call the decoder to init the cache + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings(SWITCHTRANSFORMERS_ENCODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=SwitchTransformersConfig) + def encode( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import SwitchTransformersTokenizer, FlaxSwitchTransformersForConditionalGeneration + + >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + >>> model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + + >>> text = "My friends are cool but they eat too many carbs." + >>> inputs = tokenizer(text, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _encoder_forward(module, input_ids, attention_mask, **kwargs): + encode_module = module._get_encoder_module() + return encode_module(input_ids, attention_mask, **kwargs) + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + method=_encoder_forward, + ) + + @add_start_docstrings(SWITCHTRANSFORMERS_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=SwitchTransformersConfig) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import SwitchTransformersTokenizer, FlaxSwitchTransformersForConditionalGeneration + >>> import jax.numpy as jnp + + >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + >>> model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + + >>> text = "My friends are cool but they eat too many carbs." + >>> inputs = tokenizer(text, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxSwitchTransformersAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ) + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past = outputs + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past = outputs + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + +SWITCHTRANSFORMERS_START_DOCSTRING = r""" + The SWITCHTRANSFORMERS model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text + Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan + Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a + text-to-text denoising generative setting. + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`SwitchTransformersConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + + +@add_start_docstrings( + "The bare SWITCHTRANSFORMERS Model transformer outputting raw hidden-stateswithout any specific head on top.", + SWITCHTRANSFORMERS_START_DOCSTRING, +) +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Module with T5->SwitchTransformers +class FlaxSwitchTransformersModule(nn.Module): + config: SwitchTransformersConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def _get_encoder_module(self): + return self.encoder + + def _get_decoder_module(self): + return self.decoder + + def setup(self): + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0), + ) + + encoder_config = copy.deepcopy(self.config) + encoder_config.causal = False + self.encoder = FlaxSwitchTransformersStack( + encoder_config, + embed_tokens=self.shared, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + + decoder_config = copy.deepcopy(self.config) + decoder_config.causal = True + decoder_config.num_layers = self.config.num_decoder_layers + self.decoder = FlaxSwitchTransformersStack( + decoder_config, + embed_tokens=self.shared, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + + def __call__( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + deterministic: bool = True, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode if needed (training, first prediction pass) + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return FlaxSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Model with T5->SwitchTransformers +class FlaxSwitchTransformersModel(FlaxSwitchTransformersPreTrainedModel): + module_class = FlaxSwitchTransformersModule + + +append_call_sample_docstring( + FlaxSwitchTransformersModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC +) + +FLAX_SWITCHTRANSFORMERS_MODEL_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import SwitchTransformersTokenizer, FlaxSwitchTransformersModel + + >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + >>> model = FlaxSwitchTransformersModel.from_pretrained("ybelkada/switchtransformers-base") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="np" + ... ).input_ids + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids + + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for SwitchTransformersModel. + >>> # This is not needed for torch's SwitchTransformersForConditionalGeneration as it does this internally using labels arg. + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ``` +""" + + +overwrite_call_docstring(FlaxSwitchTransformersModel, SWITCHTRANSFORMERS_INPUTS_DOCSTRING + FLAX_SWITCHTRANSFORMERS_MODEL_DOCSTRING) +append_replace_return_docstrings(FlaxSwitchTransformersModel, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + + +@add_start_docstrings( + "The bare SWITCHTRANSFORMERS Model transformer outputting encoder's raw hidden-states without any specific head on top.", + SWITCHTRANSFORMERS_START_DOCSTRING, +) +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5EncoderModule with T5->SwitchTransformers +class FlaxSwitchTransformersEncoderModule(nn.Module): + config: SwitchTransformersConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0), + ) + + encoder_config = copy.deepcopy(self.config) + encoder_config.is_decoder = False + encoder_config.is_encoder_decoder = False + encoder_config.causal = False + self.encoder = FlaxSwitchTransformersStack( + encoder_config, + embed_tokens=self.shared, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + + def __call__( + self, + input_ids=None, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict: bool = True, + deterministic: bool = True, + ): + + # Encode if needed (training, first prediction pass) + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + return encoder_outputs + + +class FlaxSwitchTransformersEncoderModel(FlaxSwitchTransformersPreTrainedModel): + module_class = FlaxSwitchTransformersEncoderModule + + @add_start_docstrings_to_model_forward(SWITCHTRANSFORMERS_ENCODE_INPUTS_DOCSTRING) + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + +@add_start_docstrings("""SWITCHTRANSFORMERS Model with a `language modeling` head on top.""", SWITCHTRANSFORMERS_START_DOCSTRING) +# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5ForConditionalGenerationModule with T5->SwitchTransformers +class FlaxSwitchTransformersForConditionalGenerationModule(nn.Module): + config: SwitchTransformersConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def _get_encoder_module(self): + return self.encoder + + def _get_decoder_module(self): + return self.decoder + + def setup(self): + self.model_dim = self.config.d_model + + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.initializer_factor), + ) + + encoder_config = copy.deepcopy(self.config) + encoder_config.causal = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = FlaxSwitchTransformersStack( + encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + + decoder_config = copy.deepcopy(self.config) + decoder_config.causal = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = self.config.num_decoder_layers + self.decoder = FlaxSwitchTransformersStack( + decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + + self.lm_head = nn.Dense( + self.config.vocab_size, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_factor), + dtype=self.dtype, + ) + + def __call__( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + deterministic: bool = True, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + hidden_states = encoder_outputs[0] + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + if self.config.tie_word_embeddings: + shared_embedding = self.shared.variables["params"]["embedding"] + lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) + else: + lm_logits = self.lm_head(sequence_output) + + if not return_dict: + return (lm_logits,) + decoder_outputs[1:] + encoder_outputs + + return FlaxSeq2SeqLMOutput( + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class FlaxSwitchTransformersForConditionalGeneration(FlaxSwitchTransformersPreTrainedModel): + module_class = FlaxSwitchTransformersForConditionalGenerationModule + + @add_start_docstrings(SWITCHTRANSFORMERS_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=SwitchTransformersConfig) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import SwitchTransformersTokenizer, FlaxSwitchTransformersForConditionalGeneration + >>> import jax.numpy as jnp + + >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + >>> model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + + >>> text = "summarize: My friends are cool but they eat too many carbs." + >>> inputs = tokenizer(text, return_tensors="np") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxSwitchTransformersAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): + decoder_module = module._get_decoder_module() + decoder_outputs = decoder_module( + decoder_input_ids, + decoder_attention_mask, + **kwargs, + ) + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.config.d_model**-0.5) + + if self.config.tie_word_embeddings: + shared_embedding = module.shared.variables["params"]["embedding"] + lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) + else: + lm_logits = module.lm_head(sequence_output) + + return lm_logits, decoder_outputs + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + if past_key_values is None: + lm_logits, decoder_outputs = outputs + else: + (lm_logits, decoder_outputs), past = outputs + + if return_dict: + outputs = FlaxCausalLMOutputWithCrossAttentions( + logits=lm_logits, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + ) + else: + outputs = (lm_logits,) + decoder_outputs[1:] + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + max_length, + attention_mask: Optional[jnp.DeviceArray] = None, + decoder_attention_mask: Optional[jnp.DeviceArray] = None, + encoder_outputs=None, + **kwargs + ): + # initializing the cache + batch_size, seq_length = decoder_input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyways. + # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if decoder_attention_mask is not None: + extended_attention_mask = jax.lax.dynamic_update_slice( + extended_attention_mask, decoder_attention_mask, (0, 0) + ) + + return { + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "encoder_attention_mask": attention_mask, + "decoder_attention_mask": extended_attention_mask, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + return model_kwargs + + +FLAX_SWITCHTRANSFORMERS_CONDITIONAL_GENERATION_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import SwitchTransformersTokenizer, FlaxSwitchTransformersForConditionalGeneration + + >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + >>> model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + + >>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs." + >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors="np") + + >>> # Generate Summary + >>> summary_ids = model.generate(inputs["input_ids"]).sequences + >>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)) + ``` +""" + + +overwrite_call_docstring( + FlaxSwitchTransformersForConditionalGeneration, SWITCHTRANSFORMERS_INPUTS_DOCSTRING + FLAX_SWITCHTRANSFORMERS_CONDITIONAL_GENERATION_DOCSTRING +) +append_replace_return_docstrings( + FlaxSwitchTransformersForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC +) diff --git a/src/transformers/models/switchtransformers/modeling_switchtransformers.py b/src/transformers/models/switchtransformers/modeling_switchtransformers.py new file mode 100644 index 0000000000000..f64b1ea60d37d --- /dev/null +++ b/src/transformers/models/switchtransformers/modeling_switchtransformers.py @@ -0,0 +1,1878 @@ +# coding=utf-8 +# Copyright 2022 Mesh TensorFlow authors, SwitchTransformers Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch SwitchTransformers model.""" + + +import copy +import math +import os +import warnings +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.utils.checkpoint import checkpoint + +from ...activations import ACT2FN +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_fx_proxy, + logging, + replace_return_docstrings, +) +from ...utils.model_parallel_utils import assert_device_map, get_device_map +from .configuration_switchtransformers import SwitchTransformersConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "SwitchTransformersConfig" +_TOKENIZER_FOR_DOC = "SwitchTransformersTokenizer" +_CHECKPOINT_FOR_DOC = "ybelkada/switchtransformers-base" + +#################################################### +# This dict contains ids and associated url +# for the pretrained weights provided with the models +#################################################### +SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "ybelkada/switchtransformers-base", + # See all SwitchTransformers models at https://huggingface.co/models?filter=switchtransformers +] + + + +#################################################### +# This is a conversion method from TF 1.0 to PyTorch +# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 +#################################################### +# Copied from transformers.models.t5.modeling_t5.load_tf_weights_in_t5 with t5->switchtransformers +def load_tf_weights_in_switchtransformers(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + tf_weights[name] = array + + for txt_name in names: + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + if "_slot_" in name[-1]: + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + pointer = model + array = tf_weights[txt_name] + + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") + elif scope_names[0] == "scale": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") + elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): + pointer = getattr(pointer, f"wi_{scope_names[1]}") + continue + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if scope_names[0] not in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + if scope_names[0] != "embedding": + logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array.astype(np.float32)) + tf_weights.pop(txt_name, None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + return model + + +#################################################### +# PyTorch Models are constructed by sub-classing +# - torch.nn.Module for the layers and +# - PreTrainedModel for the models (it-self a sub-class of nn.Module) +#################################################### +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the switchtransformers models + have the following number of attention modules: + + - ybelkada/switchtransformers-base: 6 + - switchtransformers-base: 12 + - switchtransformers-large: 24 + - switchtransformers-3b: 24 + - switchtransformers-11b: 24 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using switchtransformers-3b, which has a total of 24 attention modules: + model = SwitchTransformersForConditionalGeneration.from_pretrained("switchtransformers-3b") + device_map = { + 0: [0, 1, 2], + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with switchtransformers-3b: + model = SwitchTransformersForConditionalGeneration.from_pretrained("switchtransformers-3b") + device_map = { + 0: [0, 1, 2], + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->SwitchTransformers +class SwitchTransformersLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the SwitchTransformers style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + + # SwitchTransformers uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +try: + from apex.normalization import FusedRMSNorm + + SwitchTransformersLayerNorm = FusedRMSNorm # noqa + + logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of SwitchTransformersLayerNorm") +except ImportError: + # using the normal SwitchTransformersLayerNorm + pass +except Exception: + logger.warning("discovered apex but it failed to load, falling back to SwitchTransformersLayerNorm") + pass + +ALL_LAYERNORM_LAYERS.append(SwitchTransformersLayerNorm) + + +# Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->SwitchTransformers +class SwitchTransformersDenseActDense(nn.Module): + def __init__(self, config: SwitchTransformersConfig): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->SwitchTransformers +class SwitchTransformersDenseGatedActDense(nn.Module): + def __init__(self, config: SwitchTransformersConfig): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +# TODO: Change it here to adapt it from the paper, the FF layer contains experts +# an expert is a FF layer with multiple sub-FF layers inside. +# This class should also contain a router class +# check flaxformer/architecture/moe/router.py : https://github.com/google/flaxformer/blob/main/flaxformer/architectures/moe/routing.py +class SwitchTransformersLayerFF(nn.Module): + def __init__(self, config: SwitchTransformersConfig): + super().__init__() + # TODO: check the comments above + self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->SwitchTransformers +class SwitchTransformersAttention(nn.Module): + def __init__(self, config: SwitchTransformersConfig, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + self.gradient_checkpointing = False + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + assert ( + len(past_key_value) == 2 + ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def unshape(states): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) + + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->SwitchTransformers +class SwitchTransformersLayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = SwitchTransformersAttention( + config, has_relative_attention_bias=has_relative_attention_bias + ) + self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->SwitchTransformers +class SwitchTransformersLayerCrossAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.EncDecAttention = SwitchTransformersAttention(config, has_relative_attention_bias=False) + self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5Block with T5->SwitchTransformers +class SwitchTransformersBlock(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append( + SwitchTransformersLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias) + ) + if self.is_decoder: + self.layer.append(SwitchTransformersLayerCrossAttention(config)) + + self.layer.append(SwitchTransformersLayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + + if past_key_value is not None: + if not self.is_decoder: + logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" + f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = present_key_value_state + cross_attention_outputs[1] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (present_key_value_state,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel with T5->SwitchTransformers,t5->switchtransformers +class SwitchTransformersPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SwitchTransformersConfig + load_tf_weights = load_tf_weights_in_switchtransformers + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["SwitchTransformersBlock"] + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, SwitchTransformersLayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance( + module, + (SwitchTransformersModel, SwitchTransformersForConditionalGeneration, SwitchTransformersEncoderModel), + ): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: + module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, SwitchTransformersDenseActDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, SwitchTransformersDenseGatedActDense): + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, SwitchTransformersAttention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (SwitchTransformersAttention, SwitchTransformersStack)): + module.gradient_checkpointing = value + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert decoder_start_token_id is not None, ( + "self.model.config.decoder_start_token_id has to be defined. In SwitchTransformers it is usually set to" + " the pad_token_id. See SwitchTransformers docs for more information" + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +# Copied from transformers.models.t5.modeling_t5.T5Stack with T5->SwitchTransformers +class SwitchTransformersStack(nn.Module): + def __init__(self, config, embed_tokens=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.block = nn.ModuleList( + [ + SwitchTransformersBlock(config, has_relative_attention_bias=bool(i == 0)) + for i in range(config.num_layers) + ] + ) + self.final_layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.block)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + # Load onto devices + for k, v in self.device_map.items(): + for layer in v: + cuda_device = "cuda:" + str(k) + self.block[layer] = self.block[layer].to(cuda_device) + + # Set embed_tokens to first layer + self.embed_tokens = self.embed_tokens.to(self.first_device) + # Set final layer norm to last device + self.final_layer_norm = self.final_layer_norm.to(self.last_device) + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def deparallelize(self): + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + for i in range(len(self.block)): + self.block[i] = self.block[i].to("cpu") + self.embed_tokens = self.embed_tokens.to("cpu") + self.final_layer_norm = self.final_layer_norm.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if inputs_embeds is None: + assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + if use_cache is True: + assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder" + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones( + batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + ) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +SWITCHTRANSFORMERS_START_DOCSTRING = r""" + + The SWITCHTRANSFORMERS model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text + Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan + Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a + text-to-text denoising generative setting. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`SwitchTransformersConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SWITCHTRANSFORMERS_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. SWITCHTRANSFORMERS is a model with relative position + embeddings so you should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`SwitchTransformersTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [SWITCHTRANSFORMERS + Training](./switchtransformers#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`SwitchTransformersTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + SWITCHTRANSFORMERS uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [SWITCHTRANSFORMERS + Training](./switchtransformers#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +SWITCHTRANSFORMERS_ENCODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. SWITCHTRANSFORMERS is a model with relative position + embeddings so you should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`SwitchTransformersTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + To know more on how to prepare `input_ids` for pretraining take a look a [SWITCHTRANSFORMERS + Training](./switchtransformers#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, +num_heads)`. +""" + + +@add_start_docstrings( + "The bare SWITCHTRANSFORMERS Model transformer outputting raw hidden-states without any specific head on top.", + SWITCHTRANSFORMERS_START_DOCSTRING, +) +class SwitchTransformersModel(SwitchTransformersPreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder.embed_tokens.weight", + r"decoder.embed_tokens.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + + def __init__(self, config: SwitchTransformersConfig): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = SwitchTransformersStack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = SwitchTransformersStack(decoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(SWITCHTRANSFORMERS_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import SwitchTransformersTokenizer, SwitchTransformersModel + + >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + >>> model = SwitchTransformersModel.from_pretrained("ybelkada/switchtransformers-base") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for SwitchTransformersModel. + >>> # This is not needed for torch's SwitchTransformersForConditionalGeneration as it does this internally using labels arg. + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings("""SWITCHTRANSFORMERS Model with a `language modeling` head on top.""", SWITCHTRANSFORMERS_START_DOCSTRING) +class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder.embed_tokens.weight", + r"decoder.embed_tokens.weight", + r"lm_head.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + + def __init__(self, config: SwitchTransformersConfig): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = SwitchTransformersStack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = SwitchTransformersStack(decoder_config, self.shared) + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.decoder.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(SWITCHTRANSFORMERS_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., + config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for + labels in `[0, ..., config.vocab_size]` + + Returns: + + Examples: + + ```python + >>> from transformers import SwitchTransformersTokenizer, SwitchTransformersForConditionalGeneration + + >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + >>> model = SwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + + >>> # training + >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids + >>> labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + + >>> # inference + >>> input_ids = tokenizer( + ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> outputs = model.generate(input_ids) + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + >>> # studies have shown that owning a dog is good for you. + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs + ): + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "decoder_input_ids": input_ids, + "past_key_values": past, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past + + reordered_decoder_past = () + for layer_past_states in past: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), + ) + + assert reordered_layer_past_states[0].shape == layer_past_states[0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + + +@add_start_docstrings( + "The bare SWITCHTRANSFORMERS Model transformer outputting encoder's raw hidden-states without any specific head on top.", + SWITCHTRANSFORMERS_START_DOCSTRING, +) +class SwitchTransformersEncoderModel(SwitchTransformersPreTrainedModel): + authorized_missing_keys = [ + r"encoder.embed_tokens.weight", + ] + + def __init__(self, config: SwitchTransformersConfig): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = SwitchTransformersStack(encoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(SWITCHTRANSFORMERS_ENCODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import SwitchTransformersTokenizer, SwitchTransformersEncoderModel + + >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + >>> model = SwitchTransformersEncoderModel.from_pretrained("ybelkada/switchtransformers-base") + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return encoder_outputs diff --git a/src/transformers/models/switchtransformers/modeling_tf_switchtransformers.py b/src/transformers/models/switchtransformers/modeling_tf_switchtransformers.py new file mode 100644 index 0000000000000..4e31c6319726d --- /dev/null +++ b/src/transformers/models/switchtransformers/modeling_tf_switchtransformers.py @@ -0,0 +1,1671 @@ +# coding=utf-8 +# Copyright 2022 SwitchTransformers Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" TF 2.0 SwitchTransformers model.""" + +import copy +import itertools +import math +import warnings +from typing import Optional, Tuple, Union + +import numpy as np +import tensorflow as tf +from tensorflow.compiler.tf2xla.python.xla import dynamic_slice + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutput, + TFBaseModelOutputWithPastAndCrossAttentions, + TFSeq2SeqLMOutput, + TFSeq2SeqModelOutput, +) +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFModelInputType, + TFPreTrainedModel, + TFSharedEmbeddings, + TFWrappedEmbeddings, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import shape_list, stable_softmax +from ...utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_switchtransformers import SwitchTransformersConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "SwitchTransformersConfig" +_TOKENIZER_FOR_DOC = "SwitchTransformersTokenizer" + +TF_SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "ybelkada/switchtransformers-base", + # See all SwitchTransformers models at https://huggingface.co/models?filter=switchtransformers +] + + +#################################################### +# TF 2.0 Models are constructed using Keras imperative API by sub-classing +# - tf.keras.layers.Layer for the layers and +# - TFPreTrainedModel for the models (it-self a sub-class of tf.keras.Model) +#################################################### + + +# Copied from transformers.models.t5.modeling_tf_t5.TFT5LayerNorm with T5->SwitchTransformers +class TFSwitchTransformersLayerNorm(tf.keras.layers.Layer): + def __init__(self, epsilon=1e-6, **kwargs): + """ + Construct a layernorm module in the SwitchTransformers style No bias and no subtraction of mean. + """ + super().__init__(**kwargs) + self.variance_epsilon = epsilon + + def build(self, input_shape): + """Build shared word embedding layer""" + self.weight = self.add_weight("weight", shape=(input_shape[-1],), initializer="ones") + super().build(input_shape) + + def call(self, hidden_states): + variance = tf.math.reduce_mean(tf.math.square(hidden_states), axis=-1, keepdims=True) + hidden_states = hidden_states * tf.math.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states + + +# Copied from transformers.models.t5.modeling_tf_t5.TFT5DenseActDense with T5->SwitchTransformers +class TFSwitchTransformersDenseActDense(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + wi_initializer = tf.keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (config.d_model**-0.5) + ) + wo_initializer = tf.keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (config.d_ff**-0.5) + ) + self.wi = tf.keras.layers.Dense( + config.d_ff, use_bias=False, name="wi", kernel_initializer=wi_initializer + ) # Update init weights as in flax + self.wo = tf.keras.layers.Dense( + config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer + ) # Update init weights as in flax + self.dropout = tf.keras.layers.Dropout(config.dropout_rate) + self.act = get_tf_activation(config.dense_act_fn) + + def call(self, hidden_states, training=False): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_tf_t5.TFT5DenseGatedActDense with T5->SwitchTransformers +class TFSwitchTransformersDenseGatedActDense(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + wi_initializer = tf.keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (config.d_model**-0.5) + ) + wo_initializer = tf.keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (config.d_ff**-0.5) + ) + self.wi_0 = tf.keras.layers.Dense( + config.d_ff, use_bias=False, name="wi_0", kernel_initializer=wi_initializer + ) # Update init weights as in flax + self.wi_1 = tf.keras.layers.Dense( + config.d_ff, use_bias=False, name="wi_1", kernel_initializer=wi_initializer + ) # Update init weights as in flax + self.wo = tf.keras.layers.Dense( + config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer + ) # Update init weights as in flax + self.dropout = tf.keras.layers.Dropout(config.dropout_rate) + self.act = get_tf_activation(config.dense_act_fn) + + def call(self, hidden_states, training=False): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_tf_t5.TFT5LayerFF with T5->SwitchTransformers +class TFSwitchTransformersLayerFF(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + if config.is_gated_act: + self.DenseReluDense = TFSwitchTransformersDenseGatedActDense(config, name="DenseReluDense") + else: + self.DenseReluDense = TFSwitchTransformersDenseActDense(config, name="DenseReluDense") + + self.layer_norm = TFSwitchTransformersLayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm") + self.dropout = tf.keras.layers.Dropout(config.dropout_rate) + + def call(self, hidden_states, training=False): + normed_hidden_states = self.layer_norm(hidden_states) + dense_output = self.DenseReluDense(normed_hidden_states, training=training) + hidden_states = hidden_states + self.dropout(dense_output, training=training) + return hidden_states + + +# Copied from transformers.models.t5.modeling_tf_t5.TFT5Attention with T5->SwitchTransformers +class TFSwitchTransformersAttention(tf.keras.layers.Layer): + NEW_ID = itertools.count() + + def __init__(self, config, has_relative_attention_bias=False, **kwargs): + super().__init__(**kwargs) + self.layer_id = next(TFSwitchTransformersAttention.NEW_ID) + self.is_decoder = config.is_decoder + self.use_cache = config.use_cache + self.has_relative_attention_bias = has_relative_attention_bias + self.output_attentions = config.output_attentions + + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + q_initializer = tf.keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5) + ) + k_initializer = tf.keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5) + ) + v_initializer = tf.keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5) + ) + o_initializer = tf.keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5) + ) + self.relative_attention_bias_initializer = tf.keras.initializers.RandomNormal( + mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5) + ) + + self.q = tf.keras.layers.Dense( + self.inner_dim, use_bias=False, name="q", kernel_initializer=q_initializer + ) # Update init weights as in flax + self.k = tf.keras.layers.Dense( + self.inner_dim, use_bias=False, name="k", kernel_initializer=k_initializer + ) # Update init weights as in flax + self.v = tf.keras.layers.Dense( + self.inner_dim, use_bias=False, name="v", kernel_initializer=v_initializer + ) # Update init weights as in flax + self.o = tf.keras.layers.Dense( + self.d_model, use_bias=False, name="o", kernel_initializer=o_initializer + ) # Update init weights as in flax + self.dropout = tf.keras.layers.Dropout(config.dropout_rate) + + self.pruned_heads = set() + + def build(self, input_shape): + if self.has_relative_attention_bias: + with tf.name_scope("relative_attention_bias"): + self.relative_attention_bias = self.add_weight( + name="embeddings", + shape=[self.relative_attention_num_buckets, self.n_heads], + initializer=self.relative_attention_bias_initializer, # Add initializer + ) + + return super().build(input_shape) + + def prune_heads(self, heads): + raise NotImplementedError + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + # n = -relative_position + if bidirectional: + num_buckets //= 2 + relative_buckets += ( + tf.cast(tf.math.greater(relative_position, 0), dtype=relative_position.dtype) * num_buckets + ) + relative_position = tf.math.abs(relative_position) + else: + relative_position = -tf.math.minimum(relative_position, 0) + # now n is in the range [0, inf) + max_exact = num_buckets // 2 + is_small = tf.math.less(relative_position, max_exact) + relative_position_if_large = max_exact + tf.cast( + tf.math.log(tf.cast(relative_position, tf.float32) / tf.cast(max_exact, tf.float32)) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact), + dtype=relative_position.dtype, + ) + relative_position_if_large = tf.math.minimum(relative_position_if_large, num_buckets - 1) + relative_buckets += tf.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = tf.range(query_length)[:, None] + memory_position = tf.range(key_length)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = tf.gather( + self.relative_attention_bias, relative_position_bucket + ) # shape (query_length, key_length, num_heads) + values = tf.expand_dims( + tf.transpose(values, [2, 0, 1]), axis=0 + ) # shape (1, num_heads, query_length, key_length) + return values + + def call( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + training=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, query_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = shape_list(hidden_states)[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + assert ( + len(past_key_value) == 2 + ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" + real_seq_length += shape_list(past_key_value[0])[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else shape_list(key_value_states)[1] + + def shape(hidden_states): + """projection""" + return tf.transpose( + tf.reshape(hidden_states, (batch_size, -1, self.n_heads, self.key_value_proj_dim)), perm=(0, 2, 1, 3) + ) + + def unshape(hidden_states): + """compute context""" + return tf.reshape(tf.transpose(hidden_states, perm=(0, 2, 1, 3)), (batch_size, -1, self.inner_dim)) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = tf.concat([past_key_value, hidden_states], axis=2) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, query_length, dim_per_head) + + # get key/value + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) + + # to cope with keras serialization + if self.is_decoder and use_cache: + present_key_value_state = (key_states, value_states) + else: + present_key_value_state = None + + scores = tf.einsum( + "bnqd,bnkd->bnqk", query_states, key_states + ) # (batch_size, n_heads, query_length, key_length) + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length)) + else: + position_bias = self.compute_bias(real_seq_length, key_length) + + # if key and values are already calculated we want only the last query position bias + if past_key_value is not None: + if not self.has_relative_attention_bias: + position_bias = position_bias[:, :, -seq_length:, :] + else: + # we might have a padded past structure, in which case we want to fetch the position bias slice + # right after the most recently filled past index + most_recently_filled_past_index = tf.reduce_max(tf.where(past_key_value[0][0, 0, :, 0] != 0.0)) + position_bias = dynamic_slice( + position_bias, + (0, 0, most_recently_filled_past_index + 1, 0), + (1, self.n_heads, seq_length, real_seq_length), + ) + + if mask is not None: + position_bias = tf.cast(position_bias, dtype=mask.dtype) + position_bias = position_bias + mask # (batch_size, n_heads, query_length, key_length) + + scores += position_bias + weights = stable_softmax(scores, axis=-1) # (batch_size, n_heads, query_length, key_length) + weights = self.dropout(weights, training=training) # (batch_size, n_heads, query_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.n_heads], + message=( + f"Head mask for a single layer should be of size {(self.n_heads)}, but is" + f" {shape_list(layer_head_mask)}" + ), + ) + weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * weights + + attn_output = tf.matmul(weights, value_states) # (batch_size, n_heads, query_length, dim_per_head) + + attn_output = self.o(unshape(attn_output)) + + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (weights,) + + return outputs + + +# Copied from transformers.models.t5.modeling_tf_t5.TFT5LayerSelfAttention with T5->SwitchTransformers +class TFSwitchTransformersLayerSelfAttention(tf.keras.layers.Layer): + def __init__(self, config, has_relative_attention_bias=False, **kwargs): + super().__init__(**kwargs) + self.SelfAttention = TFSwitchTransformersAttention( + config, + has_relative_attention_bias=has_relative_attention_bias, + name="SelfAttention", + ) + self.layer_norm = TFSwitchTransformersLayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm") + self.dropout = tf.keras.layers.Dropout(config.dropout_rate) + + def call( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + training=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + training=training, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], training=training) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_tf_t5.TFT5LayerCrossAttention with T5->SwitchTransformers +class TFSwitchTransformersLayerCrossAttention(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.EncDecAttention = TFSwitchTransformersAttention( + config, + has_relative_attention_bias=False, + name="EncDecAttention", + ) + self.layer_norm = TFSwitchTransformersLayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm") + self.dropout = tf.keras.layers.Dropout(config.dropout_rate) + + def call( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + query_length=None, + use_cache=False, + output_attentions=False, + training=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + training=training, + ) + hidden_states = hidden_states + self.dropout(attention_output[0], training=training) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_tf_t5.TFT5Block with T5->SwitchTransformers +class TFSwitchTransformersBlock(tf.keras.layers.Layer): + def __init__(self, config, has_relative_attention_bias=False, **kwargs): + super().__init__(**kwargs) + self.is_decoder = config.is_decoder + self.layer = [] + self.layer.append( + TFSwitchTransformersLayerSelfAttention( + config, + has_relative_attention_bias=has_relative_attention_bias, + name="layer_._0", + ) + ) + if self.is_decoder: + self.layer.append( + TFSwitchTransformersLayerCrossAttention( + config, + name="layer_._1", + ) + ) + + self.layer.append(TFSwitchTransformersLayerFF(config, name=f"layer_._{len(self.layer)}")) + + def call( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + encoder_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + training=False, + ): + + if past_key_value is not None: + assert self.is_decoder, "Only decoder can use `past_key_values`" + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (past / key) for cross attention' if expected_num_past_key_values == 4 else ''}." + f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + training=training, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + if self.is_decoder and encoder_hidden_states is not None: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = shape_list(present_key_value_state[0])[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=encoder_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + training=training, + ) + hidden_states = cross_attention_outputs[0] + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = present_key_value_state + cross_attention_outputs[1] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states, training=training) + outputs = (hidden_states,) + + # Add attentions if we output them + outputs = outputs + (present_key_value_state,) + attention_outputs + return outputs # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) + + +#################################################### +# The full model without a specific pretrained or finetuning head is +# provided as a tf.keras.layers.Layer usually called "TFSwitchTransformersMainLayer" +#################################################### +@keras_serializable +# Copied from transformers.models.t5.modeling_tf_t5.TFT5MainLayer with T5->SwitchTransformers +class TFSwitchTransformersMainLayer(tf.keras.layers.Layer): + config_class = SwitchTransformersConfig + + def __init__(self, config, embed_tokens=None, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.output_hidden_states = config.output_hidden_states + self.output_attentions = config.output_attentions + self.use_cache = config.use_cache + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.config = config + self.num_hidden_layers = config.num_layers + + self.block = [ + TFSwitchTransformersBlock(config, has_relative_attention_bias=bool(i == 0), name=f"block_._{i}") + for i in range(config.num_layers) + ] + self.final_layer_norm = TFSwitchTransformersLayerNorm( + epsilon=config.layer_norm_epsilon, name="final_layer_norm" + ) + self.dropout = tf.keras.layers.Dropout(config.dropout_rate) + + def _prune_heads(self, heads_to_prune): + raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models + + @unpack_inputs + def call( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + encoder_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + training=False, + ) -> Tuple: + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = shape_list(input_ids) + input_ids = tf.reshape(input_ids, (-1, input_shape[-1])) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if inputs_embeds is None: + assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" + # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound + # indices on GPU, returning zeros instead. This is a dangerous silent behavior. + tf.debugging.assert_less( + input_ids, + tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype), + message=( + "input_ids must be smaller than the embedding layer's input dimension (got" + f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})" + ), + ) + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = ( + shape_list(past_key_values[0][0])[2] + seq_length if past_key_values is not None else seq_length + ) + + if attention_mask is None: + attention_mask = tf.fill((batch_size, mask_seq_length), 1) + if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_seq_length = shape_list(encoder_hidden_states)[1] + encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + attention_mask = tf.cast(attention_mask, dtype=inputs_embeds.dtype) + num_dims_attention_mask = len(shape_list(attention_mask)) + if num_dims_attention_mask == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif num_dims_attention_mask == 2: + # Provided a padding mask of dimensions [batch_size, mask_seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + if self.is_decoder: + seq_ids = tf.range(mask_seq_length) + causal_mask = tf.less_equal( + tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), + seq_ids[None, :, None], + ) + causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + if past_key_values[0] is not None: + extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -1e9 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + + # SwitchTransformers has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 + # extended_attention_mask = tf.math.equal(extended_attention_mask, + # tf.transpose(extended_attention_mask, perm=(-1, -2))) + + extended_attention_mask = (1.0 - extended_attention_mask) * -1e9 + + if self.is_decoder and encoder_attention_mask is not None: + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) + num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) + if num_dims_encoder_attention_mask == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if num_dims_encoder_attention_mask == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + + # SwitchTransformers has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, + # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) + + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9 + else: + encoder_extended_attention_mask = None + + present_key_value_states = () if use_cache and self.is_decoder else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds, training=training) + + for idx, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=head_mask[idx] if head_mask is not None else None, + encoder_layer_head_mask=encoder_head_mask[idx] if encoder_head_mask is not None else None, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + training=training, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, past_key_values, (self-attention weights), + # (self-attention position bias), (cross-attention position bias), (cross-attention weights), + position_bias = layer_outputs[2] + + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + + # append next layer key value states + if present_key_value_state is not None and use_cache and self.is_decoder: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + outputs = (hidden_states,) + # need to check if is decoder here as well for special cases when using keras compile + if use_cache and self.is_decoder: + outputs = outputs + (present_key_value_states,) + if output_hidden_states: + outputs = outputs + (all_hidden_states,) + if output_attentions: + outputs = outputs + (all_attentions,) + if self.is_decoder: + outputs + (all_cross_attentions,) + return outputs # last-layer hidden state, (past_key_values), (all hidden states), (all attentions), (all_cross_attentions) + + if self.is_decoder: + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + else: + return TFBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + ) + + +#################################################### +# TFSwitchTransformersPreTrainedModel is a sub-class of tf.keras.Model +# which take care of loading and saving pretrained weights +# and various common utilities. +# Here you just need to specify a few (self-explanatory) +# pointers for your model. +#################################################### +# Copied from transformers.models.t5.modeling_tf_t5.TFT5PreTrainedModel with T5->SwitchTransformers +class TFSwitchTransformersPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SwitchTransformersConfig + base_model_prefix = "transformer" + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"decoder\Wblock[\W_0]+layer[\W_1]+EncDecAttention\Wrelative_attention_bias"] + + @property + def dummy_inputs(self): + inputs = tf.constant(DUMMY_INPUTS) + input_mask = tf.constant(DUMMY_MASK) + dummy_inputs = { + "input_ids": inputs, + "decoder_input_ids": inputs, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + @tf.function( + input_signature=[ + { + "input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"), + "attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"), + "decoder_input_ids": tf.TensorSpec((None, None), tf.int64, name="decoder_input_ids"), + "decoder_attention_mask": tf.TensorSpec((None, None), tf.int64, name="decoder_attention_mask"), + } + ] + ) + def serving(self, inputs): + output = self.call(inputs) + + return self.serving_output(output) + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + try: + self.shared.weight = value + except AttributeError: + self(self.dummy_inputs) + self.shared.weight = value + + self.shared.vocab_size = shape_list(value)[0] + # retrieve correct absolute scope for embed token wrapper + with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name: + pass + # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope. + embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name) + self.encoder.embed_tokens = embed_tokens + if hasattr(self, "decoder"): + self.decoder.embed_tokens = embed_tokens + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert decoder_start_token_id is not None, ( + "self.model.config.decoder_start_token_id has to be defined. In TF SwitchTransformers it is usually set to" + " the pad_token_id. See SwitchTransformers docs for more information" + ) + + start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id) + start_tokens = tf.cast(start_tokens, input_ids.dtype) # Ensure compatible dtypes for concatenation + shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids = tf.where( + shifted_input_ids == -100, + tf.cast(tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids.dtype), + shifted_input_ids, + ) + + # "Verify that `labels` has only positive values and -100" + assert_gte0 = tf.debugging.assert_greater_equal( + shifted_input_ids, tf.constant(0, dtype=shifted_input_ids.dtype) + ) + + # Make sure the assertion op is called by wrapping the result in an identity no-op + with tf.control_dependencies([assert_gte0]): + shifted_input_ids = tf.identity(shifted_input_ids) + + return shifted_input_ids + + +SWITCHTRANSFORMERS_START_DOCSTRING = r""" + + The SWITCHTRANSFORMERS model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text + Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan + Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a + text-to-text denoising generative setting. + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Parameters: + config ([`SwitchTransformersConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SWITCHTRANSFORMERS_INPUTS_DOCSTRING = r""" + Args: + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. SWITCHTRANSFORMERS is a model with relative position + embeddings so you should be able to pad the inputs on the right or the left. + + Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `inputs` for pretraining take a look at [SWITCHTRANSFORMERS + Training](./switchtransformers#training). + decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Provide for sequence to sequence training. SWITCHTRANSFORMERS uses the `pad_token_id` as the starting token + for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last + `decoder_input_ids` have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [SWITCHTRANSFORMERS + Training](./switchtransformers#training). + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(tf.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(tf.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + decoder_inputs_embeds (`tf.Tensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + +SWITCHTRANSFORMERS_ENCODER_INPUTS_DOCSTRING = r""" + Args: + inputs (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. SWITCHTRANSFORMERS is a model with relative position + embeddings so you should be able to pad the inputs on the right or the left. + + Indices can be obtained using [`SwitchTransformersTokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. + + To know more on how to prepare `inputs` for pre-training take a look at [SWITCHTRANSFORMERS + Training](./switchtransformers#training). + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + +_HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = tf.ones((num_layers, +num_heads))`. +""" + + +@add_start_docstrings( + "The bare SWITCHTRANSFORMERS Model transformer outputting raw hidden-stateswithout any specific head on top.", + SWITCHTRANSFORMERS_START_DOCSTRING, +) +class TFSwitchTransformersModel(TFSwitchTransformersPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.shared = TFSharedEmbeddings( + config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor + ) + + # retrieve correct absolute scope for embed token wrapper + with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name: + pass + # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope. + embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + self.encoder = TFSwitchTransformersMainLayer(encoder_config, embed_tokens, name="encoder") + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.num_layers = config.num_decoder_layers + self.decoder = TFSwitchTransformersMainLayer(decoder_config, embed_tokens, name="decoder") + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @unpack_inputs + @add_start_docstrings_to_model_forward(SWITCHTRANSFORMERS_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: Optional[TFModelInputType] = None, + attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_outputs: Optional[Union[np.ndarray, tf.Tensor]] = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[Tuple, TFSeq2SeqModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import SwitchTransformersTokenizer, TFSwitchTransformersModel + + >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + >>> model = TFSwitchTransformersModel.from_pretrained("ybelkada/switchtransformers-base") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="tf" + ... ).input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids # Batch size 1 + + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for SwitchTransformersModel. + >>> # This is not needed for torch's SwitchTransformersForConditionalGeneration as it does this internally using labels arg. + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + + >>> # forward pass + >>> outputs = model(input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids, + attention_mask=attention_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + past_key_values=None, + use_cache=False, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + hidden_states = encoder_outputs[0] + + # Decode + decoder_outputs = self.decoder( + decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + inputs_embeds=decoder_inputs_embeds, + head_mask=decoder_head_mask, + encoder_head_mask=head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + past = decoder_outputs[1] if use_cache else None + + if not return_dict: + if past is not None: + decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:] + return decoder_outputs + encoder_outputs + + return TFSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=past, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def serving_output(self, output): + pkv = tf.convert_to_tensor(output.past_key_values[1:]) if self.config.use_cache else None + dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None + dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None + enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None + enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None + + return TFSeq2SeqModelOutput( + last_hidden_state=output.last_hidden_state, + past_key_values=pkv, + decoder_hidden_states=dec_hs, + decoder_attentions=dec_attns, + encoder_last_hidden_state=output.encoder_last_hidden_state, + cross_attentions=cross_attns, + encoder_hidden_states=enc_hs, + encoder_attentions=enc_attns, + ) + + +@add_start_docstrings("""SWITCHTRANSFORMERS Model with a `language modeling` head on top.""", SWITCHTRANSFORMERS_START_DOCSTRING) +class TFSwitchTransformersForConditionalGeneration(TFSwitchTransformersPreTrainedModel, TFCausalLanguageModelingLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.model_dim = config.d_model + self.shared = TFSharedEmbeddings( + config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor + ) + + # retrieve correct absolute scope for embed token wrapper + with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name: + pass + # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope. + embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + self.encoder = TFSwitchTransformersMainLayer(encoder_config, embed_tokens, name="encoder") + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.num_layers = config.num_decoder_layers + self.decoder = TFSwitchTransformersMainLayer(decoder_config, embed_tokens, name="decoder") + + if not config.tie_word_embeddings: + lm_head_initializer = tf.keras.initializers.RandomNormal(mean=0, stddev=config.initializer_factor) + self.lm_head = tf.keras.layers.Dense( + config.vocab_size, use_bias=False, name="lm_head", kernel_initializer=lm_head_initializer + ) # Update init weights as in flax + + def get_output_embeddings(self): + if self.config.tie_word_embeddings: + return self.get_input_embeddings() + else: + # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens) + # value has a shape (num_tokens, dim) then needs to be transposed + return tf.transpose(self.lm_head.kernel) + + def set_output_embeddings(self, value): + if self.config.tie_word_embeddings: + self.set_input_embeddings(value) + else: + lm_head_initializer = tf.keras.initializers.RandomNormal(mean=0, stddev=self.config.initializer_factor) + self.lm_head = tf.keras.layers.Dense( + shape_list(value)[0], use_bias=False, name="lm_head", kernel_initializer=lm_head_initializer + ) # Update init weights as in flax + # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens) + # value has a shape (num_tokens, dim) then needs to be transposed + transposed_value = tf.transpose(value) + self.lm_head.kernel = transposed_value + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @unpack_inputs + @add_start_docstrings_to_model_forward(SWITCHTRANSFORMERS_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: Optional[TFModelInputType] = None, + attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_outputs: Optional[Union[np.ndarray, tf.Tensor]] = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + labels: Optional[Union[np.ndarray, tf.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[Tuple, TFSeq2SeqLMOutput]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., + config.vocab_size - 1]`. + + Returns: + + Examples: + + ```python + >>> from transformers import SwitchTransformersTokenizer, TFSwitchTransformersForConditionalGeneration + + >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + >>> model = TFSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + + >>> # training + >>> inputs = tokenizer("The walks in park", return_tensors="tf").input_ids + >>> labels = tokenizer(" cute dog the ", return_tensors="tf").input_ids + >>> outputs = model(inputs, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + + >>> # inference + >>> inputs = tokenizer( + ... "summarize: studies have shown that owning a dog is good for you", return_tensors="tf" + ... ).input_ids # Batch size 1 + >>> outputs = model.generate(inputs) + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + >>> # studies have shown that owning a dog is good for you + ```""" + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + hidden_states = encoder_outputs[0] + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Decode + decoder_outputs = self.decoder( + decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + inputs_embeds=decoder_inputs_embeds, + head_mask=decoder_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = decoder_outputs[0] + + # SwitchTransformersv1.1 does not tie output word embeddings and thus does not require downscaling + if self.config.tie_word_embeddings: + sequence_output = sequence_output * (self.model_dim**-0.5) + logits = self.shared(sequence_output, mode="linear") + else: + logits = self.lm_head(sequence_output) + + logits = tf.cast(logits, tf.float32) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + past = decoder_outputs[1] if use_cache else None + if not return_dict: + if past is not None: + decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:] + output = (logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True + elif isinstance(encoder_outputs, tuple): + last_hidden_state = encoder_outputs[0] + hidden_states = None + attentions = None + idx = 0 + if output_hidden_states: + idx += 1 + hidden_states = encoder_outputs[idx] + if output_attentions: + idx += 1 + attentions = encoder_outputs[idx] + + encoder_outputs = TFBaseModelOutput( + last_hidden_state=last_hidden_state, + hidden_states=hidden_states, + attentions=attentions, + ) + + return TFSeq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=past, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def serving_output(self, output): + pkv = tf.convert_to_tensor(output.past_key_values[1:]) if self.config.use_cache else None + dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None + dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None + enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None + enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None + + return TFSeq2SeqLMOutput( + logits=output.logits, + past_key_values=pkv, + decoder_hidden_states=dec_hs, + decoder_attentions=dec_attns, + cross_attentions=cross_attns, + encoder_last_hidden_state=output.encoder_last_hidden_state, + encoder_hidden_states=enc_hs, + encoder_attentions=enc_attns, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past=None, + attention_mask=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs + ): + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": None, # needs to be passed to make Keras.layer.__call__ happy + "decoder_input_ids": input_ids, + "past_key_values": past, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past + + reordered_decoder_past = () + for layer_past_states in past: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + tf.gather(layer_past_state, beam_idx, axis=0), + ) + + assert reordered_layer_past_states[0].shape == layer_past_states[0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + + +@add_start_docstrings( + "The bare SWITCHTRANSFORMERS Model transformer outputting encoder's raw hidden-stateswithout any specific head on top.", + SWITCHTRANSFORMERS_START_DOCSTRING, +) +class TFSwitchTransformersEncoderModel(TFSwitchTransformersPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.shared = TFSharedEmbeddings( + config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor + ) + + # retrieve correct absolute scope for embed token wrapper + with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name: + pass + # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope. + embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + self.encoder = TFSwitchTransformersMainLayer(encoder_config, embed_tokens, name="encoder") + + @property + def dummy_inputs(self): + return {"input_ids": tf.constant(DUMMY_INPUTS)} + + def get_encoder(self): + return self.encoder + + @unpack_inputs + @add_start_docstrings_to_model_forward(SWITCHTRANSFORMERS_ENCODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: Optional[TFModelInputType] = None, + attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[Tuple, TFBaseModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import SwitchTransformersTokenizer, TFSwitchTransformersEncoderModel + + >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + >>> model = TFSwitchTransformersEncoderModel.from_pretrained("ybelkada/switchtransformers-base") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="tf" + ... ).input_ids # Batch size 1 + >>> outputs = model(input_ids) + ```""" + + encoder_outputs = self.encoder( + input_ids, + attention_mask=attention_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + past_key_values=None, + use_cache=False, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + if not return_dict: + return encoder_outputs + + return TFBaseModelOutput( + last_hidden_state=encoder_outputs.last_hidden_state, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + @tf.function( + input_signature=[ + { + "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"), + "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"), + } + ] + ) + def serving(self, inputs): + output = self.call(inputs) + + return self.serving_output(output) + + def serving_output(self, output): + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + + return TFBaseModelOutput(last_hidden_state=output.last_hidden_state, hidden_states=hs, attentions=attns) diff --git a/src/transformers/models/switchtransformers/tokenization_switchtransformers.py b/src/transformers/models/switchtransformers/tokenization_switchtransformers.py new file mode 100644 index 0000000000000..d90721d76166a --- /dev/null +++ b/src/transformers/models/switchtransformers/tokenization_switchtransformers.py @@ -0,0 +1,339 @@ +# coding=utf-8 +# Copyright 2022 SwitchTransformers Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Tokenization class for model SwitchTransformers.""" + + +import os +import re +import warnings +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple + +import sentencepiece as spm + +from ...tokenization_utils import PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "ybelkada/switchtransformers-base": "https://huggingface.co/ybelkada/switchtransformers-base/resolve/main/spiece.model", + "switchtransformers-base": "https://huggingface.co/switchtransformers-base/resolve/main/spiece.model", + "switchtransformers-large": "https://huggingface.co/switchtransformers-large/resolve/main/spiece.model", + "switchtransformers-3b": "https://huggingface.co/switchtransformers-3b/resolve/main/spiece.model", + "switchtransformers-11b": "https://huggingface.co/switchtransformers-11b/resolve/main/spiece.model", + } +} + + +# TODO(PVP) - this should be removed in Transformers v5 +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "ybelkada/switchtransformers-base": 512, + "switchtransformers-base": 512, + "switchtransformers-large": 512, + "switchtransformers-3b": 512, + "switchtransformers-11b": 512, +} + + +class SwitchTransformersTokenizer(PreTrainedTokenizer): + """ + Construct a SwitchTransformers tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + extra_ids (`int`, *optional*, defaults to 100): + Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are + accessible as "" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are + indexed from the end of the vocabulary up to beginning ("" is the last token in the vocabulary + like in SwitchTransformers preprocessing see + [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/switchtransformers/data/preprocessors.py#L2117)). + additional_special_tokens (`List[str]`, *optional*): + Additional special tokens used by the tokenizer. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + + Attributes: + sp_model (`SentencePieceProcessor`): + The *SentencePiece* processor that is used for every conversion (string, tokens and IDs). + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + eos_token="", + unk_token="", + pad_token="", + extra_ids=100, + additional_special_tokens=None, + sp_model_kwargs: Optional[Dict[str, Any]] = None, + **kwargs + ) -> None: + # Add extra_ids to the special token list + if extra_ids > 0 and additional_special_tokens is None: + additional_special_tokens = [f"" for i in range(extra_ids)] + elif extra_ids > 0 and additional_special_tokens is not None: + # Check that we have the right number of extra_id special tokens + extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens))) + if extra_tokens != extra_ids: + raise ValueError( + f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are" + " provided to SwitchTransformersTokenizer. In this case the additional_special_tokens must include the extra_ids" + " tokens" + ) + + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + + super().__init__( + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + extra_ids=extra_ids, + additional_special_tokens=additional_special_tokens, + sp_model_kwargs=self.sp_model_kwargs, + **kwargs, + ) + + self.vocab_file = vocab_file + self._extra_ids = extra_ids + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(vocab_file) + + @staticmethod + def _eventually_correct_switchtransformers_max_length(pretrained_model_name_or_path, max_model_length, init_max_model_length): + if pretrained_model_name_or_path in SwitchTransformersTokenizer.max_model_input_sizes: + deprecated_max_model_length = SwitchTransformersTokenizer.max_model_input_sizes[pretrained_model_name_or_path] + if init_max_model_length is not None and init_max_model_length != max_model_length: + return init_max_model_length + elif init_max_model_length is None: + warnings.warn( + "This tokenizer was incorrectly instantiated with a model max length of" + f" {deprecated_max_model_length} which will be corrected in Transformers v5.\nFor now, this" + " behavior is kept to avoid breaking backwards compatibility when padding/encoding with" + " `truncation is True`.\n- Be aware that you SHOULD NOT rely on" + f" {pretrained_model_name_or_path} automatically truncating your input to" + f" {deprecated_max_model_length} when padding/encoding.\n- If you want to encode/pad to sequences" + f" longer than {deprecated_max_model_length} you can either instantiate this tokenizer with" + " `model_max_length` or pass `max_length` when encoding/padding.\n- To avoid this warning, please" + " instantiate this tokenizer with `model_max_length` set to your preferred value.", + FutureWarning, + ) + + return max_model_length + + @property + def vocab_size(self): + return self.sp_model.get_piece_size() + self._extra_ids + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + # normal case: some special tokens + if token_ids_1 is None: + return ([0] * len(token_ids_0)) + [1] + return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + + def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]: + """Do not add eos again if user already added it.""" + if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id: + warnings.warn( + f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated" + " eos tokens being added." + ) + return token_ids + else: + return token_ids + [self.eos_token_id] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. + SwitchTransformers does not make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + """ + eos = [self.eos_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + eos) * [0] + return len(token_ids_0 + eos + token_ids_1 + eos) * [0] + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A sequence has the following format: + + - single sequence: `X ` + - pair of sequences: `A B ` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + token_ids_0 = self._add_eos_if_not_present(token_ids_0) + if token_ids_1 is None: + return token_ids_0 + else: + token_ids_1 = self._add_eos_if_not_present(token_ids_1) + return token_ids_0 + token_ids_1 + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + def _tokenize(self, text: str) -> List[str]: + """Take as input a string and return a list of strings (tokens) for words/sub-words""" + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + if token.startswith("", token) + num = int(match.group(1)) + return self.vocab_size - num - 1 + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + if index < self.sp_model.get_piece_size(): + token = self.sp_model.IdToPiece(index) + else: + token = f"" + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + out_string += self.sp_model.decode_pieces(current_sub_tokens) + token + " " + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + out_string += self.sp_model.decode_pieces(current_sub_tokens) + return out_string.strip() + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) diff --git a/src/transformers/models/switchtransformers/tokenization_switchtransformers_fast.py b/src/transformers/models/switchtransformers/tokenization_switchtransformers_fast.py new file mode 100644 index 0000000000000..e9f0302f48700 --- /dev/null +++ b/src/transformers/models/switchtransformers/tokenization_switchtransformers_fast.py @@ -0,0 +1,237 @@ +# coding=utf-8 +# Copyright 2022 SwitchTransformers Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Tokenization class for model SwitchTransformers.""" + + +import os +import warnings +from shutil import copyfile +from typing import List, Optional, Tuple + +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import is_sentencepiece_available, logging + + +if is_sentencepiece_available(): + from .tokenization_switchtransformers import SwitchTransformersTokenizer +else: + SwitchTransformersTokenizer = None + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "ybelkada/switchtransformers-base": "https://huggingface.co/ybelkada/switchtransformers-base/resolve/main/spiece.model", + "switchtransformers-base": "https://huggingface.co/switchtransformers-base/resolve/main/spiece.model", + "switchtransformers-large": "https://huggingface.co/switchtransformers-large/resolve/main/spiece.model", + "switchtransformers-3b": "https://huggingface.co/switchtransformers-3b/resolve/main/spiece.model", + "switchtransformers-11b": "https://huggingface.co/switchtransformers-11b/resolve/main/spiece.model", + }, + "tokenizer_file": { + "ybelkada/switchtransformers-base": "https://huggingface.co/ybelkada/switchtransformers-base/resolve/main/tokenizer.json", + "switchtransformers-base": "https://huggingface.co/switchtransformers-base/resolve/main/tokenizer.json", + "switchtransformers-large": "https://huggingface.co/switchtransformers-large/resolve/main/tokenizer.json", + "switchtransformers-3b": "https://huggingface.co/switchtransformers-3b/resolve/main/tokenizer.json", + "switchtransformers-11b": "https://huggingface.co/switchtransformers-11b/resolve/main/tokenizer.json", + }, +} + + +# TODO(PVP) - this should be removed in Transformers v5 +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "ybelkada/switchtransformers-base": 512, + "switchtransformers-base": 512, + "switchtransformers-large": 512, + "switchtransformers-3b": 512, + "switchtransformers-11b": 512, +} + + +class SwitchTransformersTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" SwitchTransformers tokenizer (backed by HuggingFace's *tokenizers* library). Based on + [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models). + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + extra_ids (`int`, *optional*, defaults to 100): + Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are + accessible as "" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are + indexed from the end of the vocabulary up to beginning ("" is the last token in the vocabulary + like in SwitchTransformers preprocessing see + [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/switchtransformers/data/preprocessors.py#L2117)). + additional_special_tokens (`List[str]`, *optional*): + Additional special tokens used by the tokenizer. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = SwitchTransformersTokenizer + + prefix_tokens: List[int] = [] + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + eos_token="", + unk_token="", + pad_token="", + extra_ids=100, + additional_special_tokens=None, + **kwargs + ): + # Add extra_ids to the special token list + if extra_ids > 0 and additional_special_tokens is None: + additional_special_tokens = [f"" for i in range(extra_ids)] + elif extra_ids > 0 and additional_special_tokens is not None: + # Check that we have the right number of extra special tokens + extra_tokens = len(set(filter(lambda x: bool("extra_id_" in str(x)), additional_special_tokens))) + if extra_tokens != extra_ids: + raise ValueError( + f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are" + " provided to SwitchTransformersTokenizer. In this case the additional_special_tokens must include the extra_ids" + " tokens" + ) + + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + extra_ids=extra_ids, + additional_special_tokens=additional_special_tokens, + **kwargs, + ) + + self.vocab_file = vocab_file + self.can_save_slow_tokenizer = False if not self.vocab_file else True + self._extra_ids = extra_ids + + @staticmethod + def _eventually_correct_switchtransformers_max_length(pretrained_model_name_or_path, max_model_length, init_max_model_length): + if pretrained_model_name_or_path in SwitchTransformersTokenizerFast.max_model_input_sizes: + deprecated_max_model_length = SwitchTransformersTokenizerFast.max_model_input_sizes[pretrained_model_name_or_path] + if init_max_model_length is not None and init_max_model_length != max_model_length: + return init_max_model_length + elif init_max_model_length is None: + warnings.warn( + "This tokenizer was incorrectly instantiated with a model max length of" + f" {deprecated_max_model_length} which will be corrected in Transformers v5.\nFor now, this" + " behavior is kept to avoid breaking backwards compatibility when padding/encoding with" + " `truncation is True`.\n- Be aware that you SHOULD NOT rely on" + f" {pretrained_model_name_or_path} automatically truncating your input to" + f" {deprecated_max_model_length} when padding/encoding.\n- If you want to encode/pad to sequences" + f" longer than {deprecated_max_model_length} you can either instantiate this tokenizer with" + " `model_max_length` or pass `max_length` when encoding/padding.\n- To avoid this warning, please" + " instantiate this tokenizer with `model_max_length` set to your preferred value.", + FutureWarning, + ) + + return max_model_length + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " + "tokenizer." + ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + logger.info(f"Copy vocab file to {out_vocab_file}") + + return (out_vocab_file,) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A sequence has the following format: + + - single sequence: `X ` + - pair of sequences: `A B ` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + token_ids_0 = token_ids_0 + [self.eos_token_id] + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + else: + token_ids_1 = token_ids_1 + [self.eos_token_id] + return self.prefix_tokens + token_ids_0 + token_ids_1 + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. + SwitchTransformers does not make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + """ + eos = [self.eos_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + eos) * [0] + return len(token_ids_0 + eos + token_ids_1 + eos) * [0] diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 953808dab8ad7..f26f9f2625138 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -977,6 +977,34 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxSwitchTransformersEncoderModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxSwitchTransformersForConditionalGeneration(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxSwitchTransformersModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxSwitchTransformersPreTrainedModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxT5EncoderModel(metaclass=DummyObject): _backends = ["flax"] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 9f540bd283863..117974de31ba2 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4776,6 +4776,41 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class SwitchTransformersEncoderModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class SwitchTransformersForConditionalGeneration(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class SwitchTransformersModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class SwitchTransformersPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +def load_tf_weights_in_switchtransformers(*args, **kwargs): + requires_backends(load_tf_weights_in_switchtransformers, ["torch"]) + + T5_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/dummy_sentencepiece_objects.py b/src/transformers/utils/dummy_sentencepiece_objects.py index 69f0bdcb7b1aa..c73567d7fdc30 100644 --- a/src/transformers/utils/dummy_sentencepiece_objects.py +++ b/src/transformers/utils/dummy_sentencepiece_objects.py @@ -157,6 +157,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["sentencepiece"]) +class SwitchTransformersTokenizer(metaclass=DummyObject): + _backends = ["sentencepiece"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["sentencepiece"]) + + class T5Tokenizer(metaclass=DummyObject): _backends = ["sentencepiece"] diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 3acc7804687df..06df4dbd872c4 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -2200,6 +2200,37 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +TF_SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class TFSwitchTransformersEncoderModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFSwitchTransformersForConditionalGeneration(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFSwitchTransformersModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFSwitchTransformersPreTrainedModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/dummy_tokenizers_objects.py b/src/transformers/utils/dummy_tokenizers_objects.py index 8a24d9bea6b2c..48695eefde1d4 100644 --- a/src/transformers/utils/dummy_tokenizers_objects.py +++ b/src/transformers/utils/dummy_tokenizers_objects.py @@ -360,6 +360,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tokenizers"]) +class SwitchTransformersTokenizerFast(metaclass=DummyObject): + _backends = ["tokenizers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tokenizers"]) + + class T5TokenizerFast(metaclass=DummyObject): _backends = ["tokenizers"] diff --git a/tests/models/switchtransformers/__init__.py b/tests/models/switchtransformers/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/switchtransformers/test_modeling_flax_switchtransformers.py b/tests/models/switchtransformers/test_modeling_flax_switchtransformers.py new file mode 100644 index 0000000000000..ba512e626ccc9 --- /dev/null +++ b/tests/models/switchtransformers/test_modeling_flax_switchtransformers.py @@ -0,0 +1,1099 @@ +# coding=utf-8 +# Copyright 2022 Google SwitchTransformers Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile +import unittest + +import numpy as np + +import transformers +from transformers import is_flax_available +from transformers.testing_utils import ( + is_pt_flax_cross_test, + require_flax, + require_sentencepiece, + require_tokenizers, + slow, +) + +from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor + + +if is_flax_available(): + import os + + # The slow tests are often failing with OOM error on GPU + # This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed + # but will be slower as stated here https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html + os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" + + import jax + import jax.numpy as jnp + import optax + from flax.core.frozen_dict import unfreeze + from flax.training.common_utils import onehot + from flax.traverse_util import flatten_dict + from transformers import FLAX_MODEL_MAPPING, BySwitchTransformersTokenizer, SwitchTransformersConfig, SwitchTransformersTokenizer + from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model + from transformers.models.switchtransformers.modeling_flax_switchtransformers import ( + FlaxSwitchTransformersEncoderModel, + FlaxSwitchTransformersForConditionalGeneration, + FlaxSwitchTransformersModel, + shift_tokens_right, + ) + + +class FlaxSwitchTransformersModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + encoder_seq_length=7, + decoder_seq_length=9, + # For common tests + is_training=True, + use_attention_mask=True, + use_labels=True, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + d_ff=37, + relative_attention_num_buckets=8, + dropout_rate=0.1, + initializer_factor=0.002, + eos_token_id=1, + pad_token_id=0, + decoder_start_token_id=0, + scope=None, + decoder_layers=None, + ): + + self.parent = parent + self.batch_size = batch_size + self.encoder_seq_length = encoder_seq_length + self.decoder_seq_length = decoder_seq_length + # For common tests + self.seq_length = self.decoder_seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.d_ff = d_ff + self.relative_attention_num_buckets = relative_attention_num_buckets + self.dropout_rate = dropout_rate + self.initializer_factor = initializer_factor + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.decoder_start_token_id = decoder_start_token_id + self.scope = None + self.decoder_layers = decoder_layers + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) + decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + attention_mask = None + decoder_attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) + decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) + + config = SwitchTransformersConfig( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + d_ff=self.d_ff, + d_kv=self.hidden_size // self.num_attention_heads, + num_layers=self.num_hidden_layers, + num_decoder_layers=self.decoder_layers, + num_heads=self.num_attention_heads, + relative_attention_num_buckets=self.relative_attention_num_buckets, + dropout_rate=self.dropout_rate, + initializer_factor=self.initializer_factor, + eos_token_id=self.eos_token_id, + bos_token_id=self.pad_token_id, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + ) + + return ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + ) + + def create_and_check_model( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + ): + model = FlaxSwitchTransformersModel(config=config) + result = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + decoder_output = result.last_hidden_state + encoder_output = result.encoder_last_hidden_state + + self.parent.assertEqual(encoder_output.shape, (self.batch_size, self.encoder_seq_length, self.hidden_size)) + self.parent.assertEqual(decoder_output.shape, (self.batch_size, self.decoder_seq_length, self.hidden_size)) + + def check_use_cache_forward_with_attn_mask( + self, + model_class_name, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + ): + max_decoder_length = 20 + model = model_class_name(config) + + encoder_outputs = model.encode(input_ids) + + # prevent fully zero'd out attention mask + decoder_attention_mask = jnp.ones_like(decoder_attention_mask) + + decoder_attention_mask_cache = jnp.concatenate( + [ + decoder_attention_mask, + jnp.zeros((decoder_attention_mask.shape[0], max_decoder_length - decoder_attention_mask.shape[1])), + ], + axis=-1, + ) + + past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs) + + outputs_cache = model.decode( + decoder_input_ids[:, :-1], + encoder_outputs, + decoder_attention_mask=decoder_attention_mask_cache, + past_key_values=past_key_values, + ) + outputs_cache_next = model.decode( + decoder_input_ids[:, -1:], + encoder_outputs, + past_key_values=outputs_cache.past_key_values, + decoder_attention_mask=decoder_attention_mask_cache, + ) + + outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask) + + diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))) + self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}") + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + } + return config, inputs_dict + + +@require_flax +class FlaxSwitchTransformersModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase): + + all_model_classes = (FlaxSwitchTransformersModel, FlaxSwitchTransformersForConditionalGeneration) if is_flax_available() else () + all_generative_model_classes = (FlaxSwitchTransformersForConditionalGeneration,) if is_flax_available() else () + is_encoder_decoder = True + + def setUp(self): + self.model_tester = FlaxSwitchTransformersModelTester(self) + self.config_tester = ConfigTester(self, config_class=SwitchTransformersConfig, d_model=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_v1_1(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + # check that gated gelu feed forward and different word embeddings work + config = config_and_inputs[0] + config.tie_word_embeddings = False + config.feed_forward_proj = "gated-gelu" + self.model_tester.create_and_check_model(config, *config_and_inputs[1:]) + + def test_use_cache_forward_with_attn_mask(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + for model_class in self.all_model_classes: + self.model_tester.check_use_cache_forward_with_attn_mask(model_class, *config_and_inputs) + + def test_encode(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config) + + @jax.jit + def encode_jitted(input_ids, attention_mask=None, **kwargs): + return model.encode(input_ids=input_ids, attention_mask=attention_mask) + + with self.subTest("JIT Enabled"): + jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple() + + with self.subTest("JIT Disabled"): + with jax.disable_jit(): + outputs = encode_jitted(**prepared_inputs_dict).to_tuple() + + self.assertEqual(len(outputs), len(jitted_outputs)) + for jitted_output, output in zip(jitted_outputs, outputs): + self.assertEqual(jitted_output.shape, output.shape) + + def test_decode(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + model = model_class(config) + encoder_outputs = model.encode(inputs_dict["input_ids"], inputs_dict["attention_mask"]) + + prepared_inputs_dict = { + "decoder_input_ids": inputs_dict["decoder_input_ids"], + "decoder_attention_mask": inputs_dict["decoder_attention_mask"], + "encoder_outputs": encoder_outputs, + } + + @jax.jit + def decode_jitted(decoder_input_ids, decoder_attention_mask, encoder_outputs): + return model.decode( + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + ) + + with self.subTest("JIT Enabled"): + jitted_outputs = decode_jitted(**prepared_inputs_dict).to_tuple() + + with self.subTest("JIT Disabled"): + with jax.disable_jit(): + outputs = decode_jitted(**prepared_inputs_dict).to_tuple() + + self.assertEqual(len(outputs), len(jitted_outputs)) + for jitted_output, output in zip(jitted_outputs, outputs): + self.assertEqual(jitted_output.shape, output.shape) + + def test_shift_right(self): + decoder_start_token_id = 0 + pad_token_id = 1 + labels = np.arange(2, 102).reshape(5, 20) + labels[:2, 15:] = -100 + + decoder_input_ids = shift_tokens_right(labels, pad_token_id, decoder_start_token_id) + np_decoder_input_ids = np.array(decoder_input_ids) + + padded_slice = np_decoder_input_ids[:2, (15 + 1) :] + self.assertTrue((padded_slice == 1).all()) + + not_padded_slice = np_decoder_input_ids[2:, 1:] + rolled_labels = np.roll(labels[2:], 1)[:, 1:] + self.assertTrue((not_padded_slice == rolled_labels).all()) + self.assertTrue((np_decoder_input_ids[:, 0] == 0).all()) + + # overwrite since special base model prefix is used + def test_save_load_from_base(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = base_class(config) + base_params = flatten_dict(unfreeze(model.params)) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + head_model = model_class.from_pretrained(tmpdirname) + + base_param_from_head = flatten_dict(unfreeze(head_model.params)) + + for key in base_param_from_head.keys(): + max_diff = (base_params[key] - base_param_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite since special base model prefix is used + def test_save_load_to_base(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = model_class(config) + base_params_from_head = flatten_dict(unfreeze(model.params)) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + base_model = base_class.from_pretrained(tmpdirname) + + base_params = flatten_dict(unfreeze(base_model.params)) + + for key in base_params_from_head.keys(): + max_diff = (base_params[key] - base_params_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite since special base model prefix is used + @is_pt_flax_cross_test + def test_save_load_from_base_pt(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = base_class(config) + base_params = flatten_dict(unfreeze(model.params)) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + # save pt model + pt_model.save_pretrained(tmpdirname) + head_model = model_class.from_pretrained(tmpdirname, from_pt=True) + + base_param_from_head = flatten_dict(unfreeze(head_model.params)) + + for key in base_param_from_head.keys(): + max_diff = (base_params[key] - base_param_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite since special base model prefix is used + @is_pt_flax_cross_test + def test_save_load_to_base_pt(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = model_class(config) + base_params_from_head = flatten_dict(unfreeze(model.params)) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + base_model = base_class.from_pretrained(tmpdirname, from_pt=True) + + base_params = flatten_dict(unfreeze(base_model.params)) + + for key in base_params_from_head.keys(): + max_diff = (base_params[key] - base_params_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite since special base model prefix is used + @is_pt_flax_cross_test + def test_save_load_bf16_to_base_pt(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = model_class(config) + model.params = model.to_bf16(model.params) + base_params_from_head = flatten_dict(unfreeze(model.params)) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + base_model = base_class.from_pretrained(tmpdirname, from_pt=True) + + base_params = flatten_dict(unfreeze(base_model.params)) + + for key in base_params_from_head.keys(): + max_diff = (base_params[key] - base_params_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + +class FlaxSwitchTransformersEncoderOnlyModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + encoder_seq_length=7, + # For common tests + is_training=True, + use_attention_mask=True, + use_labels=True, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + d_ff=37, + relative_attention_num_buckets=8, + dropout_rate=0.1, + initializer_factor=0.002, + eos_token_id=1, + pad_token_id=0, + decoder_start_token_id=0, + scope=None, + ): + + self.parent = parent + self.batch_size = batch_size + self.encoder_seq_length = encoder_seq_length + # For common tests + self.seq_length = self.encoder_seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.d_ff = d_ff + self.relative_attention_num_buckets = relative_attention_num_buckets + self.dropout_rate = dropout_rate + self.initializer_factor = initializer_factor + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.decoder_start_token_id = decoder_start_token_id + self.scope = None + self.decoder_layers = 0 + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) + + config = SwitchTransformersConfig( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + d_ff=self.d_ff, + d_kv=self.hidden_size // self.num_attention_heads, + num_layers=self.num_hidden_layers, + num_decoder_layers=self.decoder_layers, + num_heads=self.num_attention_heads, + relative_attention_num_buckets=self.relative_attention_num_buckets, + dropout_rate=self.dropout_rate, + initializer_factor=self.initializer_factor, + eos_token_id=self.eos_token_id, + bos_token_id=self.pad_token_id, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + is_encoder_decoder=False, + ) + + return ( + config, + input_ids, + attention_mask, + ) + + def create_and_check_model( + self, + config, + input_ids, + attention_mask, + ): + model = FlaxSwitchTransformersEncoderModel(config=config) + result = model( + input_ids=input_ids, + attention_mask=attention_mask, + ) + result = model(input_ids=input_ids) + encoder_output = result.last_hidden_state + + self.parent.assertEqual(encoder_output.shape, (self.batch_size, self.encoder_seq_length, self.hidden_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_flax +class FlaxSwitchTransformersEncoderOnlyModelTest(FlaxModelTesterMixin, unittest.TestCase): + + all_model_classes = (FlaxSwitchTransformersEncoderModel,) if is_flax_available() else () + is_encoder_decoder = False + + def setUp(self): + self.model_tester = FlaxSwitchTransformersEncoderOnlyModelTester(self) + self.config_tester = ConfigTester(self, config_class=SwitchTransformersConfig, d_model=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_v1_1(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + # check that gated gelu feed forward and different word embeddings work + config = config_and_inputs[0] + config.tie_word_embeddings = False + config.feed_forward_proj = "gated-gelu" + self.model_tester.create_and_check_model(config, *config_and_inputs[1:]) + + def test_encode(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config) + + @jax.jit + def encode_jitted(input_ids, attention_mask=None, **kwargs): + return model(input_ids=input_ids, attention_mask=attention_mask) + + with self.subTest("JIT Enabled"): + jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple() + + with self.subTest("JIT Disabled"): + with jax.disable_jit(): + outputs = encode_jitted(**prepared_inputs_dict).to_tuple() + + self.assertEqual(len(outputs), len(jitted_outputs)) + for jitted_output, output in zip(jitted_outputs, outputs): + self.assertEqual(jitted_output.shape, output.shape) + + # overwrite since special base model prefix is used + def test_save_load_from_base(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = base_class(config) + base_params = flatten_dict(unfreeze(model.params)) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + head_model = model_class.from_pretrained(tmpdirname) + + base_param_from_head = flatten_dict(unfreeze(head_model.params)) + + for key in base_param_from_head.keys(): + max_diff = (base_params[key] - base_param_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite since special base model prefix is used + def test_save_load_to_base(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = model_class(config) + base_params_from_head = flatten_dict(unfreeze(model.params)) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + base_model = base_class.from_pretrained(tmpdirname) + + base_params = flatten_dict(unfreeze(base_model.params)) + + for key in base_params_from_head.keys(): + max_diff = (base_params[key] - base_params_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite since special base model prefix is used + @is_pt_flax_cross_test + def test_save_load_from_base_pt(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = base_class(config) + base_params = flatten_dict(unfreeze(model.params)) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + # save pt model + pt_model.save_pretrained(tmpdirname) + head_model = model_class.from_pretrained(tmpdirname, from_pt=True) + + base_param_from_head = flatten_dict(unfreeze(head_model.params)) + + for key in base_param_from_head.keys(): + max_diff = (base_params[key] - base_param_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite since special base model prefix is used + @is_pt_flax_cross_test + def test_save_load_to_base_pt(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = model_class(config) + base_params_from_head = flatten_dict(unfreeze(model.params)) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + base_model = base_class.from_pretrained(tmpdirname, from_pt=True) + + base_params = flatten_dict(unfreeze(base_model.params)) + + for key in base_params_from_head.keys(): + max_diff = (base_params[key] - base_params_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite since special base model prefix is used + @is_pt_flax_cross_test + def test_save_load_bf16_to_base_pt(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = model_class(config) + model.params = model.to_bf16(model.params) + base_params_from_head = flatten_dict(unfreeze(model.params)) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + base_model = base_class.from_pretrained(tmpdirname, from_pt=True) + + base_params = flatten_dict(unfreeze(base_model.params)) + + for key in base_params_from_head.keys(): + max_diff = (base_params[key] - base_params_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + +@require_sentencepiece +@require_tokenizers +@require_flax +class FlaxSwitchTransformersModelIntegrationTests(unittest.TestCase): + @slow + def test_small_integration_test(self): + """ + For comparision run: + >>> import switchtransformers # pip install switchtransformers==0.7.1 + >>> from switchtransformers.data.sentencepiece_vocabulary import SentencePieceVocabulary + + >>> path_to_mtf_small_switchtransformers_checkpoint = '' + >>> path_to_mtf_small_spm_model_path = '' + >>> switchtransformers_model = switchtransformers.models.MtfModel(model_dir=path_to_mtf_small_switchtransformers_checkpoint, batch_size=1, tpu=None) + >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) + >>> score = switchtransformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + + input_ids = tokenizer("Hello there", return_tensors="np").input_ids + labels = tokenizer("Hi I am", return_tensors="np").input_ids + + decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id, model.config.decoder_start_token_id) + + logits = model(input_ids, decoder_input_ids=decoder_input_ids).logits + + loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean() + mtf_score = -(labels.shape[-1] * loss.item()) + + EXPECTED_SCORE = -19.0845 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) + + @slow + def test_small_v1_1_integration_test(self): + """ + For comparision run: + >>> import switchtransformers # pip install switchtransformers==0.7.1 + >>> from switchtransformers.data.sentencepiece_vocabulary import SentencePieceVocabulary + + >>> path_to_mtf_small_switchtransformers_v1_1_checkpoint = '' + >>> path_to_mtf_small_spm_model_path = '' + >>> switchtransformers_model = switchtransformers.models.MtfModel(model_dir=path_to_mtf_small_switchtransformers_v1_1_checkpoint, batch_size=1, tpu=None) + >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) + >>> score = switchtransformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("google/switchtransformers-v1_1-small") + tokenizer = SwitchTransformersTokenizer.from_pretrained("google/switchtransformers-v1_1-small") + + input_ids = tokenizer("Hello there", return_tensors="np").input_ids + labels = tokenizer("Hi I am", return_tensors="np").input_ids + + decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id, model.config.decoder_start_token_id) + + logits = model(input_ids, decoder_input_ids=decoder_input_ids).logits + loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean() + + mtf_score = -(labels.shape[-1] * loss.item()) + + EXPECTED_SCORE = -59.0293 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) + + @slow + def test_small_byswitchtransformers_integration_test(self): + """ + For comparision run: + >>> import switchtransformers # pip install switchtransformers==0.9.1 + + >>> path_to_byswitchtransformers_small_checkpoint = '' + >>> switchtransformers_model = switchtransformers.models.MtfModel(model_dir=path_to_tf_checkpoint, batch_size=1, tpu=None) + >>> vocab = switchtransformers.data.ByteVocabulary() + >>> score = switchtransformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("google/byybelkada/switchtransformers-base") + tokenizer = BySwitchTransformersTokenizer.from_pretrained("google/byybelkada/switchtransformers-base") + + input_ids = tokenizer("Hello there", return_tensors="np").input_ids + labels = tokenizer("Hi I am", return_tensors="np").input_ids + + decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id, model.config.decoder_start_token_id) + + logits = model(input_ids, decoder_input_ids=decoder_input_ids).logits + loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean() + + mtf_score = -(labels.shape[-1] * loss.item()) + + EXPECTED_SCORE = -60.7397 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) + + @slow + def test_small_generation(self): + model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + model.config.max_length = 8 + model.config.num_beams = 1 + model.config.do_sample = False + tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + + input_ids = tokenizer("summarize: Hello there", return_tensors="np").input_ids + + sequences = model.generate(input_ids).sequences + + output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0] + self.assertTrue(output_str == "Hello there!") + + @slow + def test_summarization(self): + model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("switchtransformers-base") + tok = SwitchTransformersTokenizer.from_pretrained("switchtransformers-base") + + FRANCE_ARTICLE = ( # @noqa + "Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings" + " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane." + ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."' + ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s' + " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video" + " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French" + " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a" + " phone at the wreckage site. The two publications described the supposed video, but did not post it on" + " their websites. The publications said that they watched the video, which was found by a source close to" + " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported." + ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the' + " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the" + ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,' + " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said" + " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman" + " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the" + ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,' + ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be' + " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by" + " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so" + " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could" + ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin' + ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match' + ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered' + ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something' + " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the" + ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline' + " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the" + " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the" + ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of' + ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school' + " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in" + " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent" + " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and" + " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%" + ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was' + " sharing the information and documents -- including training and medical records -- with public" + " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the" + " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the" + " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash" + " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late" + " Tuesday that no visible human remains were left at the site but recovery teams would keep searching." + " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all" + " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested." + " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said." + " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew" + " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with" + " the flight school during his training were among several developments as investigators continued to" + " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa" + " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his" + ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in' + " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at" + " some point before his aviation career and underwent psychotherapy before he got his pilot's license." + " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the" + " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to" + " lose his pilot's license, a European government official briefed on the investigation told CNN on" + ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being' + " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that" + " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would" + " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had" + " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded" + " he had psychological issues, the European government official said. But no matter what details emerge" + " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic" + ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact' + " that maybe they weren't going to keep doing their job and they're upset about that and so they're" + ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to' + " also take that rage and turn it outward on 149 other people who had nothing to do with the person's" + ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight' + " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura" + " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine" + " Amiel and Anna-Maja Rappard contributed to this report." + ) + SHORTER_ARTICLE = ( + "(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on" + " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The" + " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based." + " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its" + ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East' + ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the' + " situation in Palestinian territories, paving the way for possible war crimes investigations against" + " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and" + " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the" + " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a" + ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the' + ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an' + ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge' + " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the" + ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine' + " acquires all the rights as well as responsibilities that come with being a State Party to the Statute." + ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights' + ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should' + " immediately end their pressure, and countries that support universal acceptance of the court's treaty" + ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the' + " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's" + ' decision to join a treaty to which over 100 countries around the world are members." In January, when' + " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an" + ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"' + " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a" + ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in' + ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We' + ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"' + " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the" + ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the' + " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou" + ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war' + " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry" + " will include alleged war crimes committed since June. The International Criminal Court was set up in" + " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder" + " and Faith Karimi contributed to this report." + ) + IRAN_ARTICLE = ( + "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran" + " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively" + " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger." + " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli" + " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a" + " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since" + " the announcement of the new framework will likely result in more heat than light. It will not be helped" + " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ." + " The most misleading assertion, despite universal rejection by experts, is that the negotiations'" + " objective at the outset was the total elimination of any nuclear program in Iran. That is the position" + " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it" + " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has" + " always been to structure an agreement or series of agreements so that Iran could not covertly develop a" + " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded" + " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by" + " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another" + " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite" + " sharp accusations by some in the United States and its allies, Iran denies having such a program, and" + " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's" + " continued cooperation with International Atomic Energy Agency inspections is further evidence on this" + " point, and we'll know even more about Iran's program in the coming months and years because of the deal." + " In fact, the inspections provisions that are part of this agreement are designed to protect against any" + " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that" + " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter" + " warning that a deal might be killed by Congress or a future president). This of course is not the case." + " The talks were between Iran and the five permanent members of the U.N. Security Council (United States," + " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has" + " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement" + " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran" + " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement" + " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the" + " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased" + " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes" + " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear" + " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going" + " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such" + " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the" + ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not' + " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New" + " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement" + " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement" + " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove" + " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally" + " some insist that any agreement must address Iranian missile programs, human rights violations or support" + " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are" + " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in" + " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it" + " affects the security of our negotiating partners and allies, including Israel. Those judgments should be" + " fact-based, not based on questionable assertions or dubious assumptions." + ) + ARTICLE_SUBWAY = ( + "New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A" + " year later, she got married again in Westchester County, but to a different man and without divorcing" + " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos" + ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married' + " once more, this time in the Bronx. In an application for a marriage license, she stated it was her" + ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false' + ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage' + " license application, according to court documents. Prosecutors said the marriages were part of an" + " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to" + " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was" + " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New" + " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total," + " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All" + " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be" + " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors" + " said the immigration scam involved some of her husbands, who filed for permanent residence status" + " shortly after the marriages. Any divorces happened only after such filings were approved. It was" + " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District" + " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's" + ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,' + " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his" + " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces" + " up to four years in prison. Her next court appearance is scheduled for May 18." + ) + + expected_summaries = [ + 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a' + " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one" + " magazine says . all 150 on board were killed when germanwings flight 9525 crashed .", + "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a" + " preliminary examination into the situation in the occupied Palestinian territory . as members of the" + " court, Palestinians may be subject to counter-charges as well .", + "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:" + " the debate that has already begun since the announcement of the new framework will likely result in more" + " heat than light . he says the new framework would reduce Iran's low-enriched uranium stockpile and cut" + " centrifuges . miller: if it had been, there would have been no Iranian team at the table .", + "prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two" + ' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10' + " times, with nine of her marriages occurring between 1999 and 2002 .", + ] + + dct = tok( + ["summarize: " + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]], + padding="max_length", + truncation=True, + return_tensors="np", + ) + self.assertEqual(512, dct["input_ids"].shape[1]) + + hypotheses_batch = model.generate( + **dct, + num_beams=4, + length_penalty=2.0, + max_length=142, + min_length=56, + do_sample=False, + early_stopping=True, + ).sequences + + decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False) + self.assertListEqual( + expected_summaries, + decoded, + ) diff --git a/tests/models/switchtransformers/test_modeling_switchtransformers.py b/tests/models/switchtransformers/test_modeling_switchtransformers.py new file mode 100644 index 0000000000000..31447a1d74fe6 --- /dev/null +++ b/tests/models/switchtransformers/test_modeling_switchtransformers.py @@ -0,0 +1,1254 @@ +# coding=utf-8 +# Copyright 2022 Google SwitchTransformers Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import copy +import tempfile +import unittest + +from transformers import SwitchTransformersConfig, is_torch_available +from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device +from transformers.utils import cached_property + +from ...generation.test_generation_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor + + +if is_torch_available(): + import torch + + from transformers import BySwitchTransformersTokenizer, SwitchTransformersEncoderModel, SwitchTransformersForConditionalGeneration, SwitchTransformersModel, SwitchTransformersTokenizer + from transformers.models.switchtransformers.modeling_switchtransformers import SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST + + +class SwitchTransformersModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + encoder_seq_length=7, + decoder_seq_length=9, + # For common tests + is_training=True, + use_attention_mask=True, + use_labels=True, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + d_ff=37, + relative_attention_num_buckets=8, + dropout_rate=0.1, + initializer_factor=0.002, + eos_token_id=1, + pad_token_id=0, + decoder_start_token_id=0, + scope=None, + decoder_layers=None, + ): + + self.parent = parent + self.batch_size = batch_size + self.encoder_seq_length = encoder_seq_length + self.decoder_seq_length = decoder_seq_length + # For common tests + self.seq_length = self.decoder_seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.d_ff = d_ff + self.relative_attention_num_buckets = relative_attention_num_buckets + self.dropout_rate = dropout_rate + self.initializer_factor = initializer_factor + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.decoder_start_token_id = decoder_start_token_id + self.scope = None + self.decoder_layers = decoder_layers + + def get_large_model_config(self): + return SwitchTransformersConfig.from_pretrained("switchtransformers-base") + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) + decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + attention_mask = None + decoder_attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) + decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) + + lm_labels = None + if self.use_labels: + lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + config = self.get_config() + + return ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) + + def get_pipeline_config(self): + return SwitchTransformersConfig( + vocab_size=166, # switchtransformers forces 100 extra tokens + d_model=self.hidden_size, + d_ff=self.d_ff, + d_kv=self.hidden_size // self.num_attention_heads, + num_layers=self.num_hidden_layers, + num_decoder_layers=self.decoder_layers, + num_heads=self.num_attention_heads, + relative_attention_num_buckets=self.relative_attention_num_buckets, + dropout_rate=self.dropout_rate, + initializer_factor=self.initializer_factor, + eos_token_id=self.eos_token_id, + bos_token_id=self.pad_token_id, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + ) + + def get_config(self): + return SwitchTransformersConfig( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + d_ff=self.d_ff, + d_kv=self.hidden_size // self.num_attention_heads, + num_layers=self.num_hidden_layers, + num_decoder_layers=self.decoder_layers, + num_heads=self.num_attention_heads, + relative_attention_num_buckets=self.relative_attention_num_buckets, + dropout_rate=self.dropout_rate, + initializer_factor=self.initializer_factor, + eos_token_id=self.eos_token_id, + bos_token_id=self.pad_token_id, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + ) + + def check_prepare_lm_labels_via_shift_left( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = SwitchTransformersModel(config=config) + model.to(torch_device) + model.eval() + + # make sure that lm_labels are correctly padded from the right + lm_labels.masked_fill_((lm_labels == self.decoder_start_token_id), self.eos_token_id) + + # add casaul pad token mask + triangular_mask = torch.tril(lm_labels.new_ones(lm_labels.shape)).logical_not() + lm_labels.masked_fill_(triangular_mask, self.pad_token_id) + decoder_input_ids = model._shift_right(lm_labels) + + for i, (decoder_input_ids_slice, lm_labels_slice) in enumerate(zip(decoder_input_ids, lm_labels)): + # first item + self.parent.assertEqual(decoder_input_ids_slice[0].item(), self.decoder_start_token_id) + if i < decoder_input_ids_slice.shape[-1]: + if i < decoder_input_ids.shape[-1] - 1: + # items before diagonal + self.parent.assertListEqual( + decoder_input_ids_slice[1 : i + 1].tolist(), lm_labels_slice[:i].tolist() + ) + # pad items after diagonal + if i < decoder_input_ids.shape[-1] - 2: + self.parent.assertListEqual( + decoder_input_ids_slice[i + 2 :].tolist(), lm_labels_slice[i + 1 : -1].tolist() + ) + else: + # all items after square + self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist()) + + def create_and_check_model( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = SwitchTransformersModel(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + decoder_output = result.last_hidden_state + decoder_past = result.past_key_values + encoder_output = result.encoder_last_hidden_state + + self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size)) + self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size)) + # There should be `num_layers` key value embeddings stored in decoder_past + self.parent.assertEqual(len(decoder_past), config.num_layers) + # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple + self.parent.assertEqual(len(decoder_past[0]), 4) + + def create_and_check_with_lm_head( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = SwitchTransformersForConditionalGeneration(config=config).to(torch_device).eval() + outputs = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + labels=lm_labels, + ) + self.parent.assertEqual(len(outputs), 4) + self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size)) + self.parent.assertEqual(outputs["loss"].size(), ()) + + def create_and_check_decoder_model_past( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = SwitchTransformersModel(config=config).get_decoder().to(torch_device).eval() + # first forward pass + outputs = model(input_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + output, past_key_values = outputs.to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + + output_from_no_past = model(next_input_ids)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_decoder_model_attention_mask_past( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = SwitchTransformersModel(config=config).get_decoder() + model.to(torch_device) + model.eval() + + # create attention mask + attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + half_seq_length = input_ids.shape[-1] // 2 + attn_mask[:, half_seq_length:] = 0 + + # first forward pass + output, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) + input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens + + # append to next input_ids and attn_mask + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], + dim=1, + ) + + # get two different outputs + output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = SwitchTransformersModel(config=config).get_decoder().to(torch_device).eval() + # first forward pass + outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + + output, past_key_values = outputs.to_tuple() + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([attention_mask, next_mask], dim=-1) + + output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"] + output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_generate_with_past_key_values( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = SwitchTransformersForConditionalGeneration(config=config).to(torch_device).eval() + torch.manual_seed(0) + output_without_past_cache = model.generate( + input_ids[:1], num_beams=2, max_length=5, do_sample=True, use_cache=False + ) + torch.manual_seed(0) + output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True) + self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache)) + + def create_and_check_model_fp16_forward( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = SwitchTransformersModel(config=config).to(torch_device).half().eval() + output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"] + self.parent.assertFalse(torch.isnan(output).any().item()) + + def create_and_check_encoder_decoder_shared_weights( + self, + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + for model_class in [SwitchTransformersModel, SwitchTransformersForConditionalGeneration]: + torch.manual_seed(0) + model = model_class(config=config).to(torch_device).eval() + # load state dict copies weights but does not tie them + model.encoder.load_state_dict(model.decoder.state_dict(), strict=False) + + torch.manual_seed(0) + tied_config = copy.deepcopy(config) + tied_config.tie_encoder_decoder = True + tied_model = model_class(config=tied_config).to(torch_device).eval() + + model_result = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + tied_model_result = tied_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + # check that models has less parameters + self.parent.assertLess( + sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) + ) + random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() + + # check that outputs are equal + self.parent.assertTrue( + torch.allclose( + model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4 + ) + ) + + # check that outputs after saving and loading are equal + with tempfile.TemporaryDirectory() as tmpdirname: + tied_model.save_pretrained(tmpdirname) + tied_model = model_class.from_pretrained(tmpdirname) + tied_model.to(torch_device) + tied_model.eval() + + # check that models has less parameters + self.parent.assertLess( + sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) + ) + random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() + + tied_model_result = tied_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + + # check that outputs are equal + self.parent.assertTrue( + torch.allclose( + model_result[0][0, :, random_slice_idx], + tied_model_result[0][0, :, random_slice_idx], + atol=1e-4, + ) + ) + + def check_resize_embeddings_switchtransformers_v1_1( + self, + config, + ): + prev_vocab_size = config.vocab_size + + config.tie_word_embeddings = False + model = SwitchTransformersForConditionalGeneration(config=config).to(torch_device).eval() + model.resize_token_embeddings(prev_vocab_size - 10) + + self.parent.assertEqual(model.get_input_embeddings().weight.shape[0], prev_vocab_size - 10) + self.parent.assertEqual(model.get_output_embeddings().weight.shape[0], prev_vocab_size - 10) + self.parent.assertEqual(model.config.vocab_size, prev_vocab_size - 10) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + "use_cache": False, + } + return config, inputs_dict + + +@require_torch +class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + + all_model_classes = (SwitchTransformersModel, SwitchTransformersForConditionalGeneration) if is_torch_available() else () + all_generative_model_classes = (SwitchTransformersForConditionalGeneration,) if is_torch_available() else () + all_parallelizable_model_classes = (SwitchTransformersModel, SwitchTransformersForConditionalGeneration) if is_torch_available() else () + fx_compatible = False + test_pruning = False + test_resize_embeddings = True + test_model_parallel = True + is_encoder_decoder = True + # The small SWITCHTRANSFORMERS model needs higher percentages for CPU/MP tests + model_split_percents = [0.8, 0.9] + + def setUp(self): + self.model_tester = SwitchTransformersModelTester(self) + self.config_tester = ConfigTester(self, config_class=SwitchTransformersConfig, d_model=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_shift_right(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_prepare_lm_labels_via_shift_left(*config_and_inputs) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_v1_1(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + # check that gated gelu feed forward and different word embeddings work + config = config_and_inputs[0] + config.tie_word_embeddings = False + config.feed_forward_proj = "gated-gelu" + self.model_tester.create_and_check_model(config, *config_and_inputs[1:]) + + def test_config_and_model_silu_gated(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + config = config_and_inputs[0] + config.feed_forward_proj = "gated-silu" + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_with_lm_head(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_with_lm_head(*config_and_inputs) + + def test_decoder_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) + + def test_decoder_model_past_with_attn_mask(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + + def test_decoder_model_past_with_3d_attn_mask(self): + ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) = self.model_tester.prepare_config_and_inputs() + + attention_mask = ids_tensor( + [self.model_tester.batch_size, self.model_tester.encoder_seq_length, self.model_tester.encoder_seq_length], + vocab_size=2, + ) + decoder_attention_mask = ids_tensor( + [self.model_tester.batch_size, self.model_tester.decoder_seq_length, self.model_tester.decoder_seq_length], + vocab_size=2, + ) + + self.model_tester.create_and_check_decoder_model_attention_mask_past( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + def test_generate_with_past_key_values(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_generate_with_past_key_values(*config_and_inputs) + + def test_encoder_decoder_shared_weights(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_encoder_decoder_shared_weights(*config_and_inputs) + + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") + def test_model_fp16_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) + + def test_v1_1_resize_embeddings(self): + config = self.model_tester.prepare_config_and_inputs()[0] + self.model_tester.check_resize_embeddings_switchtransformers_v1_1(config) + + @slow + def test_model_from_pretrained(self): + for model_name in SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = SwitchTransformersModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + @unittest.skip("Test has a segmentation fault on torch 1.8.0") + def test_export_to_onnx(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + model = SwitchTransformersModel(config_and_inputs[0]).to(torch_device) + with tempfile.TemporaryDirectory() as tmpdirname: + torch.onnx.export( + model, + (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]), + f"{tmpdirname}/switchtransformers_test.onnx", + export_params=True, + opset_version=9, + input_names=["input_ids", "decoder_input_ids"], + ) + + def test_generate_with_head_masking(self): + attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"] + config_and_inputs = self.model_tester.prepare_config_and_inputs() + config = config_and_inputs[0] + max_length = config_and_inputs[1].shape[-1] + 3 + model = SwitchTransformersForConditionalGeneration(config).eval() + model.to(torch_device) + + head_masking = { + "head_mask": torch.zeros(config.num_layers, config.num_heads, device=torch_device), + "decoder_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device), + "cross_attn_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device), + } + + for attn_name, (name, mask) in zip(attention_names, head_masking.items()): + head_masks = {name: mask} + # Explicitly pass decoder_head_mask as it is required from SWITCHTRANSFORMERS model when head_mask specified + if name == "head_mask": + head_masks["decoder_head_mask"] = torch.ones( + config.num_decoder_layers, config.num_heads, device=torch_device + ) + + out = model.generate( + config_and_inputs[1], + num_beams=1, + max_length=max_length, + output_attentions=True, + return_dict_in_generate=True, + **head_masks, + ) + # We check the state of decoder_attentions and cross_attentions just from the last step + attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1] + self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0) + + @unittest.skip("Does not work on the tiny model as we keep hitting edge cases.") + def test_disk_offload(self): + pass + + +class SwitchTransformersEncoderOnlyModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + encoder_seq_length=7, + # For common tests + use_attention_mask=True, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + d_ff=37, + relative_attention_num_buckets=8, + is_training=False, + dropout_rate=0.1, + initializer_factor=0.002, + is_encoder_decoder=False, + eos_token_id=1, + pad_token_id=0, + scope=None, + ): + + self.parent = parent + self.batch_size = batch_size + self.encoder_seq_length = encoder_seq_length + # For common tests + self.seq_length = self.encoder_seq_length + self.use_attention_mask = use_attention_mask + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.d_ff = d_ff + self.relative_attention_num_buckets = relative_attention_num_buckets + self.dropout_rate = dropout_rate + self.initializer_factor = initializer_factor + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.is_encoder_decoder = is_encoder_decoder + self.scope = None + self.is_training = is_training + + def get_large_model_config(self): + return SwitchTransformersConfig.from_pretrained("switchtransformers-base") + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) + + config = SwitchTransformersConfig( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + d_ff=self.d_ff, + d_kv=self.hidden_size // self.num_attention_heads, + num_layers=self.num_hidden_layers, + num_heads=self.num_attention_heads, + relative_attention_num_buckets=self.relative_attention_num_buckets, + dropout_rate=self.dropout_rate, + initializer_factor=self.initializer_factor, + eos_token_id=self.eos_token_id, + bos_token_id=self.pad_token_id, + pad_token_id=self.pad_token_id, + is_encoder_decoder=self.is_encoder_decoder, + ) + + return ( + config, + input_ids, + attention_mask, + ) + + def create_and_check_model( + self, + config, + input_ids, + attention_mask, + ): + model = SwitchTransformersEncoderModel(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids=input_ids, + attention_mask=attention_mask, + ) + result = model(input_ids=input_ids) + encoder_output = result.last_hidden_state + + self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size)) + + def create_and_check_model_fp16_forward( + self, + config, + input_ids, + attention_mask, + ): + model = SwitchTransformersEncoderModel(config=config).to(torch_device).half().eval() + output = model(input_ids, attention_mask=attention_mask)["last_hidden_state"] + self.parent.assertFalse(torch.isnan(output).any().item()) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +class SwitchTransformersEncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (SwitchTransformersEncoderModel,) if is_torch_available() else () + test_pruning = False + test_resize_embeddings = False + test_model_parallel = True + all_parallelizable_model_classes = (SwitchTransformersEncoderModel,) if is_torch_available() else () + + def setUp(self): + self.model_tester = SwitchTransformersEncoderOnlyModelTester(self) + self.config_tester = ConfigTester(self, config_class=SwitchTransformersConfig, d_model=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") + def test_model_fp16_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) + + +def use_task_specific_params(model, task): + model.config.update(model.config.task_specific_params[task]) + + +@require_torch +@require_sentencepiece +@require_tokenizers +class SwitchTransformersModelIntegrationTests(unittest.TestCase): + @cached_property + def model(self): + return SwitchTransformersForConditionalGeneration.from_pretrained("switchtransformers-base").to(torch_device) + + @cached_property + def tokenizer(self): + return SwitchTransformersTokenizer.from_pretrained("switchtransformers-base") + + @slow + def test_small_generation(self): + model = SwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base").to(torch_device) + model.config.max_length = 8 + model.config.num_beams = 1 + model.config.do_sample = False + tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + + input_ids = tokenizer("summarize: Hello there", return_tensors="pt").input_ids.to(torch_device) + + sequences = model.generate(input_ids) + + output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0] + self.assertTrue(output_str == "Hello there!") + + @slow + def test_small_integration_test(self): + """ + For comparision run: + >>> import switchtransformers # pip install switchtransformers==0.7.1 + >>> from switchtransformers.data.sentencepiece_vocabulary import SentencePieceVocabulary + + >>> path_to_mtf_small_switchtransformers_checkpoint = '' + >>> path_to_mtf_small_spm_model_path = '' + >>> switchtransformers_model = switchtransformers.models.MtfModel(model_dir=path_to_mtf_small_switchtransformers_checkpoint, batch_size=1, tpu=None) + >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) + >>> score = switchtransformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = SwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base").to(torch_device) + tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + + input_ids = tokenizer("Hello there", return_tensors="pt").input_ids + labels = tokenizer("Hi I am", return_tensors="pt").input_ids + + loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss + mtf_score = -(labels.shape[-1] * loss.item()) + + EXPECTED_SCORE = -19.0845 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) + + @slow + def test_small_v1_1_integration_test(self): + """ + For comparision run: + >>> import switchtransformers # pip install switchtransformers==0.7.1 + >>> from switchtransformers.data.sentencepiece_vocabulary import SentencePieceVocabulary + + >>> path_to_mtf_small_switchtransformers_v1_1_checkpoint = '' + >>> path_to_mtf_small_spm_model_path = '' + >>> switchtransformers_model = switchtransformers.models.MtfModel(model_dir=path_to_mtf_small_switchtransformers_v1_1_checkpoint, batch_size=1, tpu=None) + >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) + >>> score = switchtransformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switchtransformers-v1_1-small").to(torch_device) + tokenizer = SwitchTransformersTokenizer.from_pretrained("google/switchtransformers-v1_1-small") + + input_ids = tokenizer("Hello there", return_tensors="pt").input_ids + labels = tokenizer("Hi I am", return_tensors="pt").input_ids + + loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss + mtf_score = -(labels.shape[-1] * loss.item()) + + EXPECTED_SCORE = -59.0293 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) + + @slow + def test_small_byswitchtransformers_integration_test(self): + """ + For comparision run: + >>> import switchtransformers # pip install switchtransformers==0.9.1 + + >>> path_to_byswitchtransformers_small_checkpoint = '' + >>> switchtransformers_model = switchtransformers.models.MtfModel(model_dir=path_to_tf_checkpoint, batch_size=1, tpu=None) + >>> vocab = switchtransformers.data.ByteVocabulary() + >>> score = switchtransformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = SwitchTransformersForConditionalGeneration.from_pretrained("google/byybelkada/switchtransformers-base").to(torch_device) + tokenizer = BySwitchTransformersTokenizer.from_pretrained("google/byybelkada/switchtransformers-base") + + input_ids = tokenizer("Hello there", return_tensors="pt").input_ids + labels = tokenizer("Hi I am", return_tensors="pt").input_ids + + loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss + mtf_score = -(labels.shape[-1] * loss.item()) + + EXPECTED_SCORE = -60.7397 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) + + @slow + def test_summarization(self): + model = self.model + tok = self.tokenizer + + FRANCE_ARTICLE = ( # @noqa + "Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings" + " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane." + ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."' + ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s' + " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video" + " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French" + " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a" + " phone at the wreckage site. The two publications described the supposed video, but did not post it on" + " their websites. The publications said that they watched the video, which was found by a source close to" + " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported." + ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the' + " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the" + ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,' + " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said" + " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman" + " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the" + ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,' + ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be' + " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by" + " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so" + " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could" + ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin' + ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match' + ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered' + ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something' + " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the" + ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline' + " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the" + " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the" + ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of' + ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school' + " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in" + " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent" + " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and" + " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%" + ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was' + " sharing the information and documents -- including training and medical records -- with public" + " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the" + " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the" + " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash" + " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late" + " Tuesday that no visible human remains were left at the site but recovery teams would keep searching." + " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all" + " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested." + " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said." + " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew" + " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with" + " the flight school during his training were among several developments as investigators continued to" + " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa" + " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his" + ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in' + " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at" + " some point before his aviation career and underwent psychotherapy before he got his pilot's license." + " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the" + " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to" + " lose his pilot's license, a European government official briefed on the investigation told CNN on" + ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being' + " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that" + " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would" + " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had" + " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded" + " he had psychological issues, the European government official said. But no matter what details emerge" + " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic" + ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact' + " that maybe they weren't going to keep doing their job and they're upset about that and so they're" + ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to' + " also take that rage and turn it outward on 149 other people who had nothing to do with the person's" + ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight' + " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura" + " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine" + " Amiel and Anna-Maja Rappard contributed to this report." + ) + SHORTER_ARTICLE = ( + "(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on" + " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The" + " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based." + " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its" + ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East' + ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the' + " situation in Palestinian territories, paving the way for possible war crimes investigations against" + " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and" + " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the" + " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a" + ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the' + ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an' + ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge' + " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the" + ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine' + " acquires all the rights as well as responsibilities that come with being a State Party to the Statute." + ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights' + ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should' + " immediately end their pressure, and countries that support universal acceptance of the court's treaty" + ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the' + " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's" + ' decision to join a treaty to which over 100 countries around the world are members." In January, when' + " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an" + ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"' + " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a" + ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in' + ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We' + ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"' + " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the" + ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the' + " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou" + ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war' + " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry" + " will include alleged war crimes committed since June. The International Criminal Court was set up in" + " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder" + " and Faith Karimi contributed to this report." + ) + IRAN_ARTICLE = ( + "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran" + " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively" + " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger." + " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli" + " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a" + " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since" + " the announcement of the new framework will likely result in more heat than light. It will not be helped" + " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ." + " The most misleading assertion, despite universal rejection by experts, is that the negotiations'" + " objective at the outset was the total elimination of any nuclear program in Iran. That is the position" + " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it" + " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has" + " always been to structure an agreement or series of agreements so that Iran could not covertly develop a" + " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded" + " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by" + " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another" + " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite" + " sharp accusations by some in the United States and its allies, Iran denies having such a program, and" + " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's" + " continued cooperation with International Atomic Energy Agency inspections is further evidence on this" + " point, and we'll know even more about Iran's program in the coming months and years because of the deal." + " In fact, the inspections provisions that are part of this agreement are designed to protect against any" + " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that" + " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter" + " warning that a deal might be killed by Congress or a future president). This of course is not the case." + " The talks were between Iran and the five permanent members of the U.N. Security Council (United States," + " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has" + " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement" + " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran" + " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement" + " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the" + " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased" + " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes" + " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear" + " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going" + " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such" + " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the" + ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not' + " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New" + " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement" + " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement" + " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove" + " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally" + " some insist that any agreement must address Iranian missile programs, human rights violations or support" + " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are" + " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in" + " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it" + " affects the security of our negotiating partners and allies, including Israel. Those judgments should be" + " fact-based, not based on questionable assertions or dubious assumptions." + ) + ARTICLE_SUBWAY = ( + "New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A" + " year later, she got married again in Westchester County, but to a different man and without divorcing" + " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos" + ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married' + " once more, this time in the Bronx. In an application for a marriage license, she stated it was her" + ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false' + ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage' + " license application, according to court documents. Prosecutors said the marriages were part of an" + " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to" + " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was" + " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New" + " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total," + " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All" + " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be" + " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors" + " said the immigration scam involved some of her husbands, who filed for permanent residence status" + " shortly after the marriages. Any divorces happened only after such filings were approved. It was" + " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District" + " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's" + ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,' + " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his" + " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces" + " up to four years in prison. Her next court appearance is scheduled for May 18." + ) + + expected_summaries = [ + 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a' + " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one" + " magazine says .", + "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a" + " preliminary examination into the situation in the occupied Palestinian territory . as members of the" + " court, Palestinians may be subject to counter-charges as well .", + "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:" + " the debate that has already begun since the announcement of the new framework will likely result in more" + " heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and" + " implement a rigorous inspection regime .", + "prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two" + ' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10' + " times, with nine of her marriages occurring between 1999 and 2002 .", + ] + + use_task_specific_params(model, "summarization") + + dct = tok( + [model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]], + padding="max_length", + truncation=True, + return_tensors="pt", + ).to(torch_device) + self.assertEqual(512, dct["input_ids"].shape[1]) + + hypotheses_batch = model.generate( + **dct, + num_beams=4, + length_penalty=2.0, + max_length=142, + min_length=56, + no_repeat_ngram_size=3, + do_sample=False, + early_stopping=True, + ) + + decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False) + self.assertListEqual( + expected_summaries, + decoded, + ) + + @slow + def test_translation_en_to_de(self): + model = self.model + tok = self.tokenizer + use_task_specific_params(model, "translation_en_to_de") + + en_text = '"Luigi often said to me that he never wanted the brothers to end up in court", she wrote.' + expected_translation = ( + '"Luigi sagte mir oft, dass er nie wollte, dass die Brüder am Gericht sitzen", schrieb sie.' + ) + + input_ids = tok.encode(model.config.prefix + en_text, return_tensors="pt") + input_ids = input_ids.to(torch_device) + output = model.generate(input_ids) + translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) + self.assertEqual(translation, expected_translation) + + @slow + def test_translation_en_to_fr(self): + model = self.model # switchtransformers-base + tok = self.tokenizer + use_task_specific_params(model, "translation_en_to_fr") + + en_text = ( + ' This image section from an infrared recording by the Spitzer telescope shows a "family portrait" of' + " countless generations of stars: the oldest stars are seen as blue dots. " + ) + + input_ids = tok.encode(model.config.prefix + en_text, return_tensors="pt") + input_ids = input_ids.to(torch_device) + + output = model.generate( + input_ids=input_ids, + num_beams=4, + length_penalty=2.0, + max_length=100, + no_repeat_ngram_size=3, + do_sample=False, + early_stopping=True, + ) + translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) + new_truncated_translation = ( + "Cette section d'images provenant de l'enregistrement infrarouge effectué par le télescope Spitzer montre " + "un " + "« portrait familial » de générations innombrables d’étoiles : les plus anciennes sont observées " + "sous forme " + "de points bleus." + ) + + self.assertEqual(translation, new_truncated_translation) + + @slow + def test_translation_en_to_ro(self): + model = self.model + tok = self.tokenizer + use_task_specific_params(model, "translation_en_to_ro") + en_text = "Taco Bell said it plans to add 2,000 locations in the US by 2022." + expected_translation = "Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022." + + inputs = tok(model.config.prefix + en_text, return_tensors="pt").to(torch_device) + output = model.generate(**inputs) + translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) + self.assertEqual(translation, expected_translation) + + +@require_torch +class TestAsymmetricSwitchTransformers(unittest.TestCase): + def build_model_and_check_forward_pass(self, **kwargs): + tester = SwitchTransformersModelTester(self, **kwargs) + config, *inputs = tester.prepare_config_and_inputs() + ( + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) = inputs + model = SwitchTransformersForConditionalGeneration(config=config).to(torch_device).eval() + outputs = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + labels=lm_labels, + ) + # outputs = model(*inputs) + assert len(outputs) == 4 + assert outputs["logits"].size() == (tester.batch_size, tester.decoder_seq_length, tester.vocab_size) + assert outputs["loss"].size() == () + return model + + def test_small_decoder(self): + # num_hidden_layers is passed to SwitchTransformersConfig as num_layers + model = self.build_model_and_check_forward_pass(decoder_layers=1, num_hidden_layers=2) + assert len(model.encoder.block) == 2 + assert len(model.decoder.block) == 1 + + def test_defaulting_to_symmetry(self): + # num_hidden_layers is passed to SwitchTransformersConfig as num_layers + model = self.build_model_and_check_forward_pass(num_hidden_layers=2) + assert len(model.decoder.block) == len(model.encoder.block) == 2 diff --git a/tests/models/switchtransformers/test_modeling_tf_switchtransformers.py b/tests/models/switchtransformers/test_modeling_tf_switchtransformers.py new file mode 100644 index 0000000000000..b24b0dae2ea0e --- /dev/null +++ b/tests/models/switchtransformers/test_modeling_tf_switchtransformers.py @@ -0,0 +1,1054 @@ +# coding=utf-8 +# Copyright 2022 Google SwitchTransformers Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from transformers import SwitchTransformersConfig, is_tf_available +from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow, tooslow +from transformers.utils import cached_property + +from ...test_configuration_common import ConfigTester +from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask + + +if is_tf_available(): + import tensorflow as tf + + from transformers import BySwitchTransformersTokenizer, SwitchTransformersTokenizer, TFSwitchTransformersEncoderModel, TFSwitchTransformersForConditionalGeneration, TFSwitchTransformersModel + + +class TFSwitchTransformersModelTester: + def __init__( + self, + parent, + ): + self.parent = parent + self.batch_size = 13 + self.seq_length = 7 + self.is_training = True + self.use_input_mask = True + self.use_labels = True + self.vocab_size = 99 + self.n_positions = 14 + self.hidden_size = 32 + self.num_hidden_layers = 5 + self.num_attention_heads = 4 + self.d_ff = 37 + self.relative_attention_num_buckets = 8 + self.dropout_rate = 0.1 + self.initializer_factor = 0.002 + self.eos_token_id = 1 + self.pad_token_id = 0 + self.scope = None + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + token_labels = None + if self.use_labels: + token_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + config = SwitchTransformersConfig( + vocab_size=self.vocab_size, + n_positions=self.n_positions, + d_model=self.hidden_size, + d_ff=self.d_ff, + d_kv=self.hidden_size // self.num_attention_heads, + num_layers=self.num_hidden_layers, + num_heads=self.num_attention_heads, + relative_attention_num_buckets=self.relative_attention_num_buckets, + dropout_rate=self.dropout_rate, + initializer_factor=self.initializer_factor, + eos_token_id=self.eos_token_id, + bos_token_id=self.pad_token_id, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.pad_token_id, + ) + + return (config, input_ids, input_mask, token_labels) + + def create_and_check_switchtransformers_model(self, config, input_ids, input_mask, token_labels): + model = TFSwitchTransformersModel(config=config) + inputs = { + "input_ids": input_ids, + "decoder_input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + result = model(inputs) + + result = model(input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids) + decoder_output = result.last_hidden_state + decoder_past = result.past_key_values + encoder_output = result.encoder_last_hidden_state + self.parent.assertListEqual(list(encoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size]) + self.parent.assertListEqual(list(decoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size]) + # There should be `num_layers` key value embeddings stored in decoder_past[1] + self.parent.assertEqual(len(decoder_past), config.num_layers) + # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple + self.parent.assertEqual(len(decoder_past[0]), 4) + + def create_and_check_switchtransformers_with_lm_head(self, config, input_ids, input_mask, token_labels): + model = TFSwitchTransformersForConditionalGeneration(config=config) + inputs_dict = { + "input_ids": input_ids, + "decoder_input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + + result = model(inputs_dict) + + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_switchtransformers_decoder_model_past(self, config, input_ids, decoder_input_ids, attention_mask): + model = TFSwitchTransformersModel(config=config).get_decoder() + + input_ids = input_ids[:1, :] + self.batch_size = 1 + + # first forward pass + outputs = model(input_ids, use_cache=True) + + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + + output_from_no_past = model(next_input_ids)[0] + output_from_past = model(next_tokens, past_key_values=outputs.past_key_values)[0] + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx] + output_from_past_slice = output_from_past[:, 0, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3) + + def create_and_check_switchtransformers_decoder_model_attention_mask_past( + self, config, input_ids, decoder_input_ids, attention_mask + ): + model = TFSwitchTransformersModel(config=config).get_decoder() + + # create attention mask + half_seq_length = self.seq_length // 2 + attn_mask_begin = tf.ones((self.batch_size, half_seq_length), dtype=tf.int32) + attn_mask_end = tf.zeros((self.batch_size, self.seq_length - half_seq_length), dtype=tf.int32) + attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1) + + # first forward pass + outputs = model(input_ids, attention_mask=attn_mask, use_cache=True) + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size) + vector_condition = tf.range(self.seq_length) == (self.seq_length - random_seq_idx_to_change) + condition = tf.transpose( + tf.broadcast_to(tf.expand_dims(vector_condition, -1), (self.seq_length, self.batch_size)) + ) + input_ids = tf.where(condition, random_other_next_tokens, input_ids) + + # append to next input_ids and attn_mask + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + attn_mask = tf.concat( + [attn_mask, tf.ones((attn_mask.shape[0], 1), dtype=tf.int32)], + axis=1, + ) + + # get two different outputs + output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0] + output_from_past = model(next_tokens, past_key_values=outputs.past_key_values, attention_mask=attn_mask)[0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).numpy().item() + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx] + output_from_past_slice = output_from_past[:, 0, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3) + + def create_and_check_switchtransformers_decoder_model_past_large_inputs( + self, config, input_ids, decoder_input_ids, attention_mask + ): + model = TFSwitchTransformersModel(config=config).get_decoder() + + input_ids = input_ids[:1, :] + attention_mask = attention_mask[:1, :] + self.batch_size = 1 + + # first forward pass + outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_attn_mask = ids_tensor((self.batch_size, 3), 2) + + # append to next input_ids and + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1) + + output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)[0] + output_from_past = model( + next_tokens, attention_mask=next_attention_mask, past_key_values=outputs.past_key_values + )[0] + + self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1]) + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx] + output_from_past_slice = output_from_past[:, :, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + (config, input_ids, input_mask, token_labels) = config_and_inputs + inputs_dict = { + "input_ids": input_ids, + "decoder_input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return config, inputs_dict + + +@require_tf +class TFSwitchTransformersModelTest(TFModelTesterMixin, unittest.TestCase): + + is_encoder_decoder = True + all_model_classes = (TFSwitchTransformersModel, TFSwitchTransformersForConditionalGeneration) if is_tf_available() else () + all_generative_model_classes = (TFSwitchTransformersForConditionalGeneration,) if is_tf_available() else () + test_onnx = False + + def setUp(self): + self.model_tester = TFSwitchTransformersModelTester(self) + self.config_tester = ConfigTester(self, config_class=SwitchTransformersConfig, d_model=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_switchtransformers_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_switchtransformers_model(*config_and_inputs) + + def test_switchtransformers_model_v1_1(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + config = config_and_inputs[0] + config.tie_word_embeddings = False + config.feed_forward_proj = "gated-gelu" + self.model_tester.create_and_check_switchtransformers_model(config, *config_and_inputs[1:]) + + def test_with_lm_head(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_switchtransformers_with_lm_head(*config_and_inputs) + + def test_switchtransformers_decoder_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_switchtransformers_decoder_model_past(*config_and_inputs) + + def test_switchtransformers_decoder_model_past_with_attn_mask(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_switchtransformers_decoder_model_attention_mask_past(*config_and_inputs) + + def test_switchtransformers_decoder_model_past_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + + # `create_and_check_switchtransformers_decoder_model_past_large_inputs` has special inputs: + # (config, input_ids, decoder_input_ids, attention_mask) + # and we have to prepare it correctly here. + config, input_ids, input_mask, token_labels = config_and_inputs + config_and_inputs = (config, input_ids, None, input_mask) + + self.model_tester.create_and_check_switchtransformers_decoder_model_past_large_inputs(*config_and_inputs) + + def test_model_common_attributes(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) + + if model_class in self.all_generative_model_classes: + x = model.get_output_embeddings() + assert isinstance(x, tf.keras.layers.Layer) + name = model.get_bias() + assert name is None + else: + x = model.get_output_embeddings() + assert x is None + name = model.get_bias() + assert name is None + + @tooslow + def test_saved_model_creation(self): + pass + + @slow + def test_model_from_pretrained(self): + model = TFSwitchTransformersModel.from_pretrained("ybelkada/switchtransformers-base") + self.assertIsNotNone(model) + + def test_generate_with_headmasking(self): + # TODO: Fix head-masking according to PyTorch SwitchTransformers model + pass + + @slow + def test_resize_embeddings(self): + model = TFSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + original_vocab_size = model.get_input_embeddings().weight.shape[0] + # the vocab size is defined in the model config + self.assertEqual(original_vocab_size, model.config.vocab_size) + + tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + tokenizer.add_special_tokens({"bos_token": "", "eos_token": ""}) + model._resize_token_embeddings(len(tokenizer)) + # the vocab size is now resized to the length of the tokenizer, which is different from the original size + self.assertEqual(model.get_input_embeddings().weight.shape[0], len(tokenizer)) + self.assertNotEqual(model.get_input_embeddings().weight.shape[0], original_vocab_size) + + # This test is run in `TFSwitchTransformersEncoderOnlyModelTest`, where the main layer has the same inputs as the model + @unittest.skip(reason="The inputs of the Main Layer are different.") + def test_keras_save_load(self): + pass + + +class TFSwitchTransformersEncoderOnlyModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + encoder_seq_length=7, + # For common tests + use_attention_mask=True, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + d_ff=37, + relative_attention_num_buckets=8, + is_training=False, + dropout_rate=0.1, + initializer_factor=0.002, + is_encoder_decoder=False, + eos_token_id=1, + pad_token_id=0, + scope=None, + ): + + self.parent = parent + self.batch_size = batch_size + self.encoder_seq_length = encoder_seq_length + # For common tests + self.seq_length = self.encoder_seq_length + self.use_attention_mask = use_attention_mask + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.d_ff = d_ff + self.relative_attention_num_buckets = relative_attention_num_buckets + self.dropout_rate = dropout_rate + self.initializer_factor = initializer_factor + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.is_encoder_decoder = is_encoder_decoder + self.scope = None + self.is_training = is_training + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) + + config = SwitchTransformersConfig( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + d_ff=self.d_ff, + d_kv=self.hidden_size // self.num_attention_heads, + num_layers=self.num_hidden_layers, + num_heads=self.num_attention_heads, + relative_attention_num_buckets=self.relative_attention_num_buckets, + dropout_rate=self.dropout_rate, + initializer_factor=self.initializer_factor, + eos_token_id=self.eos_token_id, + bos_token_id=self.pad_token_id, + pad_token_id=self.pad_token_id, + is_encoder_decoder=self.is_encoder_decoder, + ) + + return ( + config, + input_ids, + attention_mask, + ) + + def create_and_check_model( + self, + config, + input_ids, + attention_mask, + ): + model = TFSwitchTransformersEncoderModel(config=config) + result = model( + input_ids=input_ids, + attention_mask=attention_mask, + ) + result = model(input_ids=input_ids) + encoder_output = result.last_hidden_state + + self.parent.assertEqual(encoder_output.shape, (self.batch_size, self.encoder_seq_length, self.hidden_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +class TFSwitchTransformersEncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase): + is_encoder_decoder = False + all_model_classes = (TFSwitchTransformersEncoderModel,) if is_tf_available() else () + test_onnx = False + + def setUp(self): + self.model_tester = TFSwitchTransformersEncoderOnlyModelTester(self) + self.config_tester = ConfigTester(self, config_class=SwitchTransformersConfig, d_model=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + # is not able to be part of a pipeline + def test_train_pipeline_custom_model(self): + pass + + +@require_tf +@require_sentencepiece +@require_tokenizers +class TFSwitchTransformersGenerationIntegrationTests(unittest.TestCase): + @slow + def test_greedy_xla_generate_simple(self): + model = TFSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + + # two examples with different lengths to confirm that attention masks are operational in XLA + sentences = [ + "Translate English to German: Today is a beautiful day.", + "Translate English to German: I have four cats, three dogs, two birds, and a horse.", + ] + input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids + + xla_generate = tf.function(model.generate, jit_compile=True) + + output_ids = model.generate(input_ids) + output_ids_xla = xla_generate(input_ids) + + output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True) + + expected_output_string = [ + "Heute ist ein schöner Tag.", + "Ich habe vier Katzen, drei Hunde, zwei Vögel und ein Pferd.", + ] + + self.assertListEqual(expected_output_string, output_strings) + self.assertListEqual(expected_output_string, output_strings_xla) + + @slow + def test_greedy_generate(self): + model = TFSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + + sentences = ["Yesterday, my name was", "Today is a beautiful day and"] + input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids + + generation_kwargs = { + "bad_words_ids": [tokenizer("my").input_ids, tokenizer("ein schöner").input_ids], + "no_repeat_ngram_size": 3, + "do_sample": False, + "repetition_penalty": 2.2, + } + + output_ids = model.generate(input_ids, **generation_kwargs) + + output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + + expected_output_string = ["Yesterday, my name was", "Heute ist ein schöne Tag und"] + + self.assertListEqual(expected_output_string, output_strings) + + @slow + def test_sample_xla_generate_simple(self): + # NOTE: due to the small numerical differences that are natural when we compile to XLA, sampling the same + # output out of the same seed is far from guaranteed. We can, however, confirm that the results are sensible + # and that we can seed both versions. + + # forces the generation to happen on CPU, to avoid GPU-related quirks + with tf.device(":/CPU:0"): + model = TFSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + + sentence = "Translate English to German: I have two bananas" + input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids + expected_output_string = ["Ich habe zwei Bananen"] + expected_output_string_xla = ["Ich habe 2 Bananen"] + + # seed set -> deterministic sampling sequence -> deterministic generation + output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0]) + output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + self.assertListEqual(expected_output_string, output_strings) + + xla_generate = tf.function(model.generate, jit_compile=True) + # seed set -> deterministic sampling sequence -> deterministic generation + output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0]) + output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True) + self.assertListEqual(expected_output_string_xla, output_strings_xla) + + @slow + def test_sample_generate(self): + model = TFSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + + sentences = ["I really love my", "Translate English to German: the transformers are truly amazing"] + input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids + + generation_kwargs = { + "do_sample": True, + "bad_words_ids": [tokenizer("my").input_ids, tokenizer("ein schöner").input_ids], + "no_repeat_ngram_size": 3, + "repetition_penalty": 2.2, + "temperature": 0.8, + "top_k": 500, + "top_p": 0.9, + "seed": [20, 0], # seed set -> deterministic sampling sequence -> deterministic generation + } + + # forces the generation to happen on CPU, to avoid GPU-related quirks + with tf.device(":/CPU:0"): + output_ids = model.generate(input_ids, **generation_kwargs) + + output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + + expected_output_string = ["- I really love my way of this.", "die Transformatoren sind wirklich erstaunlich"] + + self.assertListEqual(expected_output_string, output_strings) + + @slow + def test_beam_search_xla_generate_simple(self): + model = TFSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + + # tests XLA with task specific arguments + task_specific_config = getattr(model.config, "task_specific_params", {}) + translation_config = task_specific_config.get("translation_en_to_fr", {}) + model.config.update(translation_config) + + # two examples with different lengths to confirm that attention masks are operational in XLA + sentences = [ + model.config.prefix + "Today is a beautiful day.", + model.config.prefix + "I have four cats, three dogs, two birds, and a horse.", + ] + input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids + + xla_generate = tf.function(model.generate, jit_compile=True) + + output_ids = model.generate(input_ids, num_beams=2) + output_ids_xla = xla_generate(input_ids, num_beams=2) + + output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True) + + expected_output_string = [ + "Aujourd'hui est une belle journée.", + "J'ai quatre chats, trois chiens, deux oiseaux et un cheval.", + ] + + self.assertListEqual(expected_output_string, output_strings) + self.assertListEqual(expected_output_string, output_strings_xla) + + @slow + def test_beam_search_generate(self): + model = TFSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + + sentences = ["I really love my", "Translate English to German: the transformers are truly amazing"] + input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids + + generation_kwargs = { + "bad_words_ids": [tokenizer("my").input_ids, tokenizer("ein schöner").input_ids], + "no_repeat_ngram_size": 3, + "do_sample": False, + "repetition_penalty": 2.2, + "num_beams": 4, + } + + output_ids = model.generate(input_ids, **generation_kwargs) + + output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + + expected_output_string = ["Ich liebe es so sehr!", "die Transformatoren sind wirklich erstaunlich"] + self.assertListEqual(expected_output_string, output_strings) + + +@require_tf +@require_sentencepiece +@require_tokenizers +class TFSwitchTransformersModelIntegrationTests(unittest.TestCase): + @cached_property + def model(self): + return TFSwitchTransformersForConditionalGeneration.from_pretrained("switchtransformers-base") + + @slow + def test_small_integration_test(self): + """ + For comparision run: + >>> import switchtransformers # pip install switchtransformers==0.7.1 + >>> from switchtransformers.data.sentencepiece_vocabulary import SentencePieceVocabulary + + >>> path_to_mtf_small_switchtransformers_checkpoint = '' + >>> path_to_mtf_small_spm_model_path = '' + >>> switchtransformers_model = switchtransformers.models.MtfModel(model_dir=path_to_mtf_small_switchtransformers_checkpoint, batch_size=1, tpu=None) + >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) + >>> score = switchtransformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = TFSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + + input_ids = tokenizer("Hello there", return_tensors="tf").input_ids + labels = tokenizer("Hi I am", return_tensors="tf").input_ids + + loss = model(input_ids, labels=labels).loss + mtf_score = -tf.math.reduce_mean(loss).numpy() + + EXPECTED_SCORE = -4.771147 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) + + @slow + def test_small_v1_1_integration_test(self): + """ + For comparision run: + >>> import switchtransformers # pip install switchtransformers==0.7.1 + >>> from switchtransformers.data.sentencepiece_vocabulary import SentencePieceVocabulary + + >>> path_to_mtf_small_switchtransformers_v1.1_checkpoint = '' + >>> path_to_mtf_small_spm_model_path = '' + >>> switchtransformers_model = switchtransformers.models.MtfModel(model_dir=path_to_mtf_small_switchtransformers_v1.1_checkpoint, batch_size=1, tpu=None) + >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) + >>> score = switchtransformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = TFSwitchTransformersForConditionalGeneration.from_pretrained("google/switchtransformers-v1_1-small") + tokenizer = SwitchTransformersTokenizer.from_pretrained("google/switchtransformers-v1_1-small") + + input_ids = tokenizer("Hello there", return_tensors="tf").input_ids + labels = tokenizer("Hi I am", return_tensors="tf").input_ids + + loss = model(input_ids, labels=labels).loss + mtf_score = -tf.math.reduce_mean(loss).numpy() + + EXPECTED_SCORE = -14.757326 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) + + @slow + def test_small_byswitchtransformers_integration_test(self): + """ + For comparision run: + >>> import switchtransformers # pip install switchtransformers==0.9.1 + + >>> path_to_byswitchtransformers_small_checkpoint = '' + >>> switchtransformers_model = switchtransformers.models.MtfModel(model_dir=path_to_tf_checkpoint, batch_size=1, tpu=None) + >>> vocab = switchtransformers.data.ByteVocabulary() + >>> score = switchtransformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = TFSwitchTransformersForConditionalGeneration.from_pretrained("google/byybelkada/switchtransformers-base") + tokenizer = BySwitchTransformersTokenizer.from_pretrained("google/byybelkada/switchtransformers-base") + + input_ids = tokenizer("Hello there", return_tensors="tf").input_ids + labels = tokenizer("Hi I am", return_tensors="tf").input_ids + + loss = model(input_ids, labels=labels).loss + mtf_score = -tf.math.reduce_mean(loss).numpy() + + EXPECTED_SCORE = -7.592465 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) + + @slow + def test_summarization(self): + model = self.model + tok = SwitchTransformersTokenizer.from_pretrained("switchtransformers-base") + + FRANCE_ARTICLE = ( # @noqa + "Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings" + " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane." + ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."' + ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s' + " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video" + " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French" + " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a" + " phone at the wreckage site. The two publications described the supposed video, but did not post it on" + " their websites. The publications said that they watched the video, which was found by a source close to" + " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported." + ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the' + " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the" + ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,' + " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said" + " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman" + " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the" + ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,' + ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be' + " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by" + " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so" + " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could" + ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin' + ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match' + ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered' + ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something' + " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the" + ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline' + " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the" + " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the" + ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of' + ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school' + " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in" + " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent" + " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and" + " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%" + ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was' + " sharing the information and documents -- including training and medical records -- with public" + " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the" + " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the" + " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash" + " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late" + " Tuesday that no visible human remains were left at the site but recovery teams would keep searching." + " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all" + " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested." + " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said." + " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew" + " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with" + " the flight school during his training were among several developments as investigators continued to" + " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa" + " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his" + ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in' + " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at" + " some point before his aviation career and underwent psychotherapy before he got his pilot's license." + " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the" + " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to" + " lose his pilot's license, a European government official briefed on the investigation told CNN on" + ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being' + " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that" + " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would" + " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had" + " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded" + " he had psychological issues, the European government official said. But no matter what details emerge" + " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic" + ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact' + " that maybe they weren't going to keep doing their job and they're upset about that and so they're" + ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to' + " also take that rage and turn it outward on 149 other people who had nothing to do with the person's" + ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight' + " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura" + " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine" + " Amiel and Anna-Maja Rappard contributed to this report." + ) + + SHORTER_ARTICLE = ( + "(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on" + " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The" + " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based." + " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its" + ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East' + ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the' + " situation in Palestinian territories, paving the way for possible war crimes investigations against" + " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and" + " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the" + " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a" + ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the' + ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an' + ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge' + " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the" + ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine' + " acquires all the rights as well as responsibilities that come with being a State Party to the Statute." + ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights' + ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should' + " immediately end their pressure, and countries that support universal acceptance of the court's treaty" + ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the' + " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's" + ' decision to join a treaty to which over 100 countries around the world are members." In January, when' + " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an" + ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"' + " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a" + ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in' + ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We' + ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"' + " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the" + ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the' + " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou" + ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war' + " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry" + " will include alleged war crimes committed since June. The International Criminal Court was set up in" + " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder" + " and Faith Karimi contributed to this report." + ) + + IRAN_ARTICLE = ( + "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran" + " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively" + " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger." + " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli" + " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a" + " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since" + " the announcement of the new framework will likely result in more heat than light. It will not be helped" + " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ." + " The most misleading assertion, despite universal rejection by experts, is that the negotiations'" + " objective at the outset was the total elimination of any nuclear program in Iran. That is the position" + " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it" + " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has" + " always been to structure an agreement or series of agreements so that Iran could not covertly develop a" + " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded" + " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by" + " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another" + " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite" + " sharp accusations by some in the United States and its allies, Iran denies having such a program, and" + " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's" + " continued cooperation with International Atomic Energy Agency inspections is further evidence on this" + " point, and we'll know even more about Iran's program in the coming months and years because of the deal." + " In fact, the inspections provisions that are part of this agreement are designed to protect against any" + " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that" + " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter" + " warning that a deal might be killed by Congress or a future president). This of course is not the case." + " The talks were between Iran and the five permanent members of the U.N. Security Council (United States," + " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has" + " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement" + " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran" + " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement" + " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the" + " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased" + " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes" + " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear" + " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going" + " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such" + " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the" + ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not' + " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New" + " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement" + " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement" + " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove" + " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally" + " some insist that any agreement must address Iranian missile programs, human rights violations or support" + " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are" + " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in" + " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it" + " affects the security of our negotiating partners and allies, including Israel. Those judgments should be" + " fact-based, not based on questionable assertions or dubious assumptions." + ) + + ARTICLE_SUBWAY = ( + "New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A" + " year later, she got married again in Westchester County, but to a different man and without divorcing" + " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos" + ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married' + " once more, this time in the Bronx. In an application for a marriage license, she stated it was her" + ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false' + ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage' + " license application, according to court documents. Prosecutors said the marriages were part of an" + " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to" + " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was" + " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New" + " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total," + " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All" + " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be" + " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors" + " said the immigration scam involved some of her husbands, who filed for permanent residence status" + " shortly after the marriages. Any divorces happened only after such filings were approved. It was" + " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District" + " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's" + ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,' + " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his" + " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces" + " up to four years in prison. Her next court appearance is scheduled for May 18." + ) + + expected_summaries = [ + 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a' + " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one" + " magazine says .", + "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a" + " preliminary examination into the situation in the occupied Palestinian territory . as members of the" + " court, Palestinians may be subject to counter-charges as well .", + "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:" + " the debate that has already begun since the announcement of the new framework will likely result in more" + " heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and" + " implement a rigorous inspection regime .", + "prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two" + ' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10' + " times, with nine of her marriages occurring between 1999 and 2002 .", + ] + + task_specific_config = getattr(model.config, "task_specific_params", {}) + summarization_config = task_specific_config.get("summarization", {}) + model.config.update(summarization_config) + + dct = tok( + [model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]], + max_length=512, + padding="max_length", + truncation=True, + return_tensors="tf", + ) + self.assertEqual(512, dct["input_ids"].shape[1]) + + hypotheses_batch = model.generate( + input_ids=dct["input_ids"], + attention_mask=dct["attention_mask"], + num_beams=4, + length_penalty=2.0, + max_length=142, + min_length=56, + no_repeat_ngram_size=3, + do_sample=False, + early_stopping=True, + ) + + decoded = [ + tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch + ] + + self.assertListEqual( + expected_summaries, + decoded, + ) + + @slow + def test_translation_en_to_de(self): + tok = SwitchTransformersTokenizer.from_pretrained("switchtransformers-base") + model = self.model + + task_specific_config = getattr(model.config, "task_specific_params", {}) + translation_config = task_specific_config.get("translation_en_to_de", {}) + self.model.config.update(translation_config) + + original_input = '"Luigi often said to me that he never wanted the brothers to end up in court", she wrote.' + expected_translation = ( + '"Luigi sagte mir oft, dass er nie wollte, dass die Brüder am Gericht sitzen", schrieb sie.' + ) + + input_ids = tok.encode(model.config.prefix + original_input, return_tensors="tf") + + output = model.generate( + input_ids=input_ids, + num_beams=4, + length_penalty=2.0, + max_length=50, + no_repeat_ngram_size=3, + do_sample=False, + early_stopping=True, + ) + translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) + + self.assertEqual(translation, expected_translation) + + @slow + def test_translation_en_to_fr(self): + model = self.model + tok = SwitchTransformersTokenizer.from_pretrained("switchtransformers-base") + + task_specific_config = getattr(model.config, "task_specific_params", {}) + translation_config = task_specific_config.get("translation_en_to_fr", {}) + model.config.update(translation_config) + + en_text = ( + ' This image section from an infrared recording by the Spitzer telescope shows a "family portrait" of' + " countless generations of stars: the oldest stars are seen as blue dots. " + ) + + new_truncated_translation = ( + "Cette section d'images provenant de l'enregistrement infrarouge effectué par le télescope Spitzer montre " + "un " + "« portrait familial » de générations innombrables d’étoiles : les plus anciennes sont observées " + "sous forme " + "de points bleus." + ) + + input_ids = tok(model.config.prefix + en_text, return_tensors="tf").input_ids + + output = model.generate( + input_ids=input_ids, + num_beams=4, + length_penalty=2.0, + max_length=100, + no_repeat_ngram_size=3, + do_sample=False, + early_stopping=True, + ) + translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) + + self.assertEqual(translation, new_truncated_translation) + + @slow + def test_translation_en_to_ro(self): + model = self.model + tok = SwitchTransformersTokenizer.from_pretrained("switchtransformers-base") + + task_specific_config = getattr(model.config, "task_specific_params", {}) + translation_config = task_specific_config.get("translation_en_to_ro", {}) + model.config.update(translation_config) + + original_input = "Taco Bell said it plans to add 2,000 locations in the US by 2022." + expected_translation = "Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022." + + input_ids = tok.encode(model.config.prefix + original_input, return_tensors="tf") + + output = model.generate( + input_ids=input_ids, + num_beams=4, + length_penalty=2.0, + max_length=50, + no_repeat_ngram_size=3, + do_sample=False, + early_stopping=True, + ) + translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) + + self.assertEqual(translation, expected_translation) diff --git a/tests/models/switchtransformers/test_tokenization_switchtransformers.py b/tests/models/switchtransformers/test_tokenization_switchtransformers.py new file mode 100644 index 0000000000000..8ed4c6f80d78a --- /dev/null +++ b/tests/models/switchtransformers/test_tokenization_switchtransformers.py @@ -0,0 +1,381 @@ +# coding=utf-8 +# Copyright 2022 Google SwitchTransformers Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +import tempfile +import unittest + +from transformers import SPIECE_UNDERLINE, AddedToken, BatchEncoding, SwitchTransformersTokenizer, SwitchTransformersTokenizerFast +from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow +from transformers.utils import cached_property, is_tf_available, is_torch_available + +from ...test_tokenization_common import TokenizerTesterMixin + + +SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model") + +if is_torch_available(): + FRAMEWORK = "pt" +elif is_tf_available(): + FRAMEWORK = "tf" +else: + FRAMEWORK = "jax" + + +@require_sentencepiece +@require_tokenizers +class SwitchTransformersTokenizationTest(TokenizerTesterMixin, unittest.TestCase): + + tokenizer_class = SwitchTransformersTokenizer + rust_tokenizer_class = SwitchTransformersTokenizerFast + test_rust_tokenizer = True + test_sentencepiece = True + + def setUp(self): + super().setUp() + + # We have a SentencePiece fixture for testing + tokenizer = SwitchTransformersTokenizer(SAMPLE_VOCAB) + tokenizer.save_pretrained(self.tmpdirname) + + def test_convert_token_and_id(self): + """Test ``_convert_token_to_id`` and ``_convert_id_to_token``.""" + token = "" + token_id = 1 + + self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id) + self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token) + + def test_get_vocab(self): + vocab_keys = list(self.get_tokenizer().get_vocab().keys()) + + self.assertEqual(vocab_keys[0], "") + self.assertEqual(vocab_keys[1], "") + self.assertEqual(vocab_keys[-1], "") + self.assertEqual(len(vocab_keys), 1_101) + + def test_vocab_size(self): + self.assertEqual(self.get_tokenizer().vocab_size, 1_100) + + def test_full_tokenizer(self): + tokenizer = SwitchTransformersTokenizer(SAMPLE_VOCAB) + + tokens = tokenizer.tokenize("This is a test") + self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"]) + + self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382]) + + tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.") + self.assertListEqual( + tokens, + [ + SPIECE_UNDERLINE + "I", + SPIECE_UNDERLINE + "was", + SPIECE_UNDERLINE + "b", + "or", + "n", + SPIECE_UNDERLINE + "in", + SPIECE_UNDERLINE + "", + "9", + "2", + "0", + "0", + "0", + ",", + SPIECE_UNDERLINE + "and", + SPIECE_UNDERLINE + "this", + SPIECE_UNDERLINE + "is", + SPIECE_UNDERLINE + "f", + "al", + "s", + "é", + ".", + ], + ) + ids = tokenizer.convert_tokens_to_ids(tokens) + self.assertListEqual(ids, [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4]) + + back_tokens = tokenizer.convert_ids_to_tokens(ids) + self.assertListEqual( + back_tokens, + [ + SPIECE_UNDERLINE + "I", + SPIECE_UNDERLINE + "was", + SPIECE_UNDERLINE + "b", + "or", + "n", + SPIECE_UNDERLINE + "in", + SPIECE_UNDERLINE + "", + "", + "2", + "0", + "0", + "0", + ",", + SPIECE_UNDERLINE + "and", + SPIECE_UNDERLINE + "this", + SPIECE_UNDERLINE + "is", + SPIECE_UNDERLINE + "f", + "al", + "s", + "", + ".", + ], + ) + + @cached_property + def switchtransformers_base_tokenizer(self): + return SwitchTransformersTokenizer.from_pretrained("switchtransformers-base") + + @cached_property + def switchtransformers_base_tokenizer_fast(self): + return SwitchTransformersTokenizerFast.from_pretrained("switchtransformers-base") + + def get_tokenizer(self, **kwargs) -> SwitchTransformersTokenizer: + return self.tokenizer_class.from_pretrained(self.tmpdirname, pad_token=None, **kwargs) + + def get_rust_tokenizer(self, **kwargs) -> SwitchTransformersTokenizerFast: + return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, pad_token=None, **kwargs) + + def test_rust_and_python_full_tokenizers(self): + if not self.test_rust_tokenizer: + return + + tokenizer = self.get_tokenizer() + rust_tokenizer = self.get_rust_tokenizer() + + sequence = "I was born in 92000, and this is falsé." + + tokens = tokenizer.tokenize(sequence) + rust_tokens = rust_tokenizer.tokenize(sequence) + self.assertListEqual(tokens, rust_tokens) + + ids = tokenizer.encode(sequence, add_special_tokens=False) + rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False) + self.assertListEqual(ids, rust_ids) + + rust_tokenizer = self.get_rust_tokenizer() + ids = tokenizer.encode(sequence) + rust_ids = rust_tokenizer.encode(sequence) + self.assertListEqual(ids, rust_ids) + + def test_eos_treatment(self): + tokenizer = self.switchtransformers_base_tokenizer + batch_with_eos_added = tokenizer(["hi", "I went to the gym", ""]) + batch_without_eos_added = tokenizer(["hi", "I went to the gym", ""]) + self.assertListEqual(batch_with_eos_added["input_ids"], batch_without_eos_added["input_ids"]) + + def test_prepare_batch(self): + tokenizer = self.switchtransformers_base_tokenizer + src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] + expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id] + batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK) + self.assertIsInstance(batch, BatchEncoding) + + if FRAMEWORK != "jax": + result = list(batch.input_ids.numpy()[0]) + else: + result = list(batch.input_ids.tolist()[0]) + + self.assertListEqual(expected_src_tokens, result) + + self.assertEqual((2, 9), batch.input_ids.shape) + self.assertEqual((2, 9), batch.attention_mask.shape) + + def test_empty_target_text(self): + tokenizer = self.switchtransformers_base_tokenizer + src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] + batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK) + # check if input_ids are returned and no decoder_input_ids + self.assertIn("input_ids", batch) + self.assertIn("attention_mask", batch) + self.assertNotIn("decoder_input_ids", batch) + self.assertNotIn("decoder_attention_mask", batch) + + def test_max_length(self): + tokenizer = self.switchtransformers_base_tokenizer + tgt_text = [ + "Summary of the text.", + "Another summary.", + ] + targets = tokenizer( + text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK + ) + self.assertEqual(32, targets["input_ids"].shape[1]) + + def test_outputs_not_longer_than_maxlen(self): + tokenizer = self.switchtransformers_base_tokenizer + + batch = tokenizer( + ["I am a small frog" * 1000, "I am a small frog"], padding=True, truncation=True, return_tensors=FRAMEWORK + ) + self.assertIsInstance(batch, BatchEncoding) + # Since SwitchTransformers does NOT have a max input length, + # this test should be changed to the following in Transformers v5: + # self.assertEqual(batch.input_ids.shape, (2, 8001)) + self.assertEqual(batch.input_ids.shape, (2, 512)) + + def test_eos_in_input(self): + tokenizer = self.switchtransformers_base_tokenizer + src_text = ["A long paragraph for summarization. "] + tgt_text = ["Summary of the text. "] + expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, 1] + expected_tgt_tokens = [20698, 13, 8, 1499, 5, 1] + + batch = tokenizer(src_text, text_target=tgt_text) + + self.assertEqual(expected_src_tokens, batch["input_ids"][0]) + self.assertEqual(expected_tgt_tokens, batch["labels"][0]) + + def test_token_type_ids(self): + src_text_1 = ["A first paragraph for summarization."] + src_text_2 = ["A second paragraph for summarization."] + + fast_token_type_ids = self.switchtransformers_base_tokenizer_fast( + src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True + ).token_type_ids + slow_token_type_ids = self.switchtransformers_base_tokenizer( + src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True + ).token_type_ids + + self.assertEqual(slow_token_type_ids, fast_token_type_ids) + self.assertEqual(len(slow_token_type_ids[0]), 18) + + def test_fast_and_slow_same_result(self): + src_text = " Today is nice day " + tgt_ids = [0, 1960, 19, 2, 1245, 239, 1] + tgt_text = " Today is nice day" + + fast_ids = self.switchtransformers_base_tokenizer_fast(src_text, add_special_tokens=False).input_ids + slow_ids = self.switchtransformers_base_tokenizer(src_text, add_special_tokens=False).input_ids + self.assertEqual(tgt_ids, fast_ids) + self.assertEqual(tgt_ids, slow_ids) + + fast_text = self.switchtransformers_base_tokenizer_fast.decode(fast_ids) + slow_text = self.switchtransformers_base_tokenizer.decode(fast_ids) + self.assertEqual(tgt_text, fast_text) + self.assertEqual(tgt_text, slow_text) + + def test_special_tokens_initialization(self): + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + + added_tokens = [f"" for i in range(100)] + [AddedToken("", lstrip=True)] + + tokenizer_r = self.rust_tokenizer_class.from_pretrained( + pretrained_name, additional_special_tokens=added_tokens, **kwargs + ) + tokenizer_cr = self.rust_tokenizer_class.from_pretrained( + pretrained_name, additional_special_tokens=added_tokens, **kwargs, from_slow=True + ) + tokenizer_p = self.tokenizer_class.from_pretrained( + pretrained_name, additional_special_tokens=added_tokens, **kwargs + ) + + p_output = tokenizer_p.encode("Hey this is a token") + r_output = tokenizer_r.encode("Hey this is a token") + cr_output = tokenizer_cr.encode("Hey this is a token") + + special_token_id = tokenizer_r.encode("", add_special_tokens=False)[0] + + self.assertEqual(p_output, r_output) + self.assertEqual(cr_output, r_output) + self.assertTrue(special_token_id in p_output) + self.assertTrue(special_token_id in r_output) + self.assertTrue(special_token_id in cr_output) + + def test_special_tokens_initialization_with_non_empty_additional_special_tokens(self): + tokenizer_list = [] + if self.test_slow_tokenizer: + tokenizer_list.append((self.tokenizer_class, self.get_tokenizer())) + + if self.test_rust_tokenizer: + tokenizer_list.append((self.rust_tokenizer_class, self.get_rust_tokenizer())) + + for tokenizer_class, tokenizer_utils in tokenizer_list: + + with tempfile.TemporaryDirectory() as tmp_dir: + tokenizer_utils.save_pretrained(tmp_dir) + + with open(os.path.join(tmp_dir, "special_tokens_map.json"), encoding="utf-8") as json_file: + special_tokens_map = json.load(json_file) + + with open(os.path.join(tmp_dir, "tokenizer_config.json"), encoding="utf-8") as json_file: + tokenizer_config = json.load(json_file) + + added_tokens_extra_ids = [f"" for i in range(100)] + + special_tokens_map["additional_special_tokens"] = added_tokens_extra_ids + [ + "an_additional_special_token" + ] + tokenizer_config["additional_special_tokens"] = added_tokens_extra_ids + [ + "an_additional_special_token" + ] + + with open(os.path.join(tmp_dir, "special_tokens_map.json"), "w", encoding="utf-8") as outfile: + json.dump(special_tokens_map, outfile) + with open(os.path.join(tmp_dir, "tokenizer_config.json"), "w", encoding="utf-8") as outfile: + json.dump(tokenizer_config, outfile) + + # the following checks allow us to verify that our test works as expected, i.e. that the tokenizer takes + # into account the new value of additional_special_tokens given in the "tokenizer_config.json" and + # "special_tokens_map.json" files + tokenizer_without_change_in_init = tokenizer_class.from_pretrained( + tmp_dir, + ) + self.assertIn( + "an_additional_special_token", tokenizer_without_change_in_init.additional_special_tokens + ) + # self.assertIn("an_additional_special_token",tokenizer_without_change_in_init.get_vocab()) # BySwitchTransformersTokenization no vocab + self.assertEqual( + ["an_additional_special_token"], + tokenizer_without_change_in_init.convert_ids_to_tokens( + tokenizer_without_change_in_init.convert_tokens_to_ids(["an_additional_special_token"]) + ), + ) + + # Now we test that we can change the value of additional_special_tokens in the from_pretrained + new_added_tokens = added_tokens_extra_ids + [AddedToken("a_new_additional_special_token", lstrip=True)] + tokenizer = tokenizer_class.from_pretrained( + tmp_dir, + additional_special_tokens=new_added_tokens, + ) + + self.assertIn("a_new_additional_special_token", tokenizer.additional_special_tokens) + self.assertEqual( + ["a_new_additional_special_token"], + tokenizer.convert_ids_to_tokens( + tokenizer.convert_tokens_to_ids(["a_new_additional_special_token"]) + ), + ) + + # overwritten from `test_tokenization_common` since SwitchTransformers has no max length + def test_pretrained_model_lists(self): + # We should have at least one default checkpoint for each tokenizer + # We should specify the max input length as well (used in some part to list the pretrained checkpoints) + self.assertGreaterEqual(len(self.tokenizer_class.pretrained_vocab_files_map), 1) + self.assertGreaterEqual(len(list(self.tokenizer_class.pretrained_vocab_files_map.values())[0]), 1) + + @slow + def test_tokenizer_integration(self): + # fmt: off + expected_encoding = {'input_ids': [[31220, 7, 41, 14034, 801, 38, 3, 102, 63, 17, 127, 524, 18, 7031, 2032, 277, 11, 3, 102, 63, 17, 127, 524, 18, 2026, 17, 10761, 18, 7041, 61, 795, 879, 18, 19681, 4648, 7, 41, 12920, 382, 6, 350, 6383, 4949, 6, 2158, 12920, 382, 9, 6, 3, 4, 11160, 6, 2043, 17153, 279, 49, 17, 6, 3, 4, 434, 9688, 11439, 21, 6869, 10509, 17725, 41, 567, 9138, 61, 11, 6869, 10509, 11946, 41, 18207, 517, 61, 28, 147, 3538, 1220, 7140, 10761, 2250, 16, 910, 1220, 8024, 11, 1659, 1413, 32, 883, 2020, 344, 2215, 226, 6, 12901, 382, 127, 524, 11, 4738, 7, 127, 15390, 5, 1], [272, 24203, 19, 876, 12, 554, 18, 9719, 1659, 2647, 26352, 6497, 7, 45, 73, 9339, 400, 26, 1499, 57, 22801, 10760, 30, 321, 646, 11, 269, 2625, 16, 66, 7500, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [37, 1704, 4216, 3, 20400, 4418, 7, 147, 8, 19743, 1782, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]} # noqa: E501 + # fmt: on + + self.tokenizer_integration_test_util( + expected_encoding=expected_encoding, + model_name="switchtransformers-base", + revision="5a7ff2d8f5117c194c7e32ec1ccbf04642cca99b", + ) From 59c6512eee5ad9a3850391a0a7c52cf64458f254 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 4 Oct 2022 16:04:07 +0100 Subject: [PATCH 002/102] add more comments --- .../models/switchtransformers/modeling_switchtransformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/switchtransformers/modeling_switchtransformers.py b/src/transformers/models/switchtransformers/modeling_switchtransformers.py index f64b1ea60d37d..c2324c0b78774 100644 --- a/src/transformers/models/switchtransformers/modeling_switchtransformers.py +++ b/src/transformers/models/switchtransformers/modeling_switchtransformers.py @@ -275,7 +275,7 @@ def forward(self, hidden_states): pass ALL_LAYERNORM_LAYERS.append(SwitchTransformersLayerNorm) - +# TODO: this has to be changed with the experts # Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->SwitchTransformers class SwitchTransformersDenseActDense(nn.Module): From 09068706a7c611ca26eb18cb8ca877be50d7a497 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 4 Oct 2022 17:46:35 +0100 Subject: [PATCH 003/102] add router v1 --- .../models/switchtransformers/__init__.py | 14 +- .../configuration_switchtransformers.py | 5 +- ...rmers_original_tf_checkpoint_to_pytorch.py | 9 +- ..._switchtransformersx_checkpoint_to_flax.py | 122 ++- .../modeling_flax_switchtransformers.py | 26 +- .../modeling_switchtransformers.py | 17 +- .../modeling_tf_switchtransformers.py | 7 +- .../models/switchtransformers/router.py | 215 +++++ .../models/switchtransformers/router_flax.py | 759 ++++++++++++++++++ .../tokenization_switchtransformers.py | 16 +- .../tokenization_switchtransformers_fast.py | 20 +- .../test_modeling_flax_switchtransformers.py | 15 +- .../test_modeling_switchtransformers.py | 485 ++--------- .../test_modeling_tf_switchtransformers.py | 20 +- .../test_tokenization_switchtransformers.py | 8 +- 15 files changed, 1249 insertions(+), 489 deletions(-) create mode 100644 src/transformers/models/switchtransformers/router.py create mode 100644 src/transformers/models/switchtransformers/router_flax.py diff --git a/src/transformers/models/switchtransformers/__init__.py b/src/transformers/models/switchtransformers/__init__.py index 0f656af9dfdfb..615827cb82a32 100644 --- a/src/transformers/models/switchtransformers/__init__.py +++ b/src/transformers/models/switchtransformers/__init__.py @@ -29,7 +29,13 @@ ) -_import_structure = {"configuration_switchtransformers": ["SWITCHTRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwitchTransformersConfig", "SwitchTransformersOnnxConfig"]} +_import_structure = { + "configuration_switchtransformers": [ + "SWITCHTRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP", + "SwitchTransformersConfig", + "SwitchTransformersOnnxConfig", + ] +} try: if not is_sentencepiece_available(): @@ -91,7 +97,11 @@ if TYPE_CHECKING: - from .configuration_switchtransformers import SWITCHTRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP, SwitchTransformersConfig, SwitchTransformersOnnxConfig + from .configuration_switchtransformers import ( + SWITCHTRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP, + SwitchTransformersConfig, + SwitchTransformersOnnxConfig, + ) try: if not is_sentencepiece_available(): diff --git a/src/transformers/models/switchtransformers/configuration_switchtransformers.py b/src/transformers/models/switchtransformers/configuration_switchtransformers.py index 0056bc7758a96..2acd5a177bfd5 100644 --- a/src/transformers/models/switchtransformers/configuration_switchtransformers.py +++ b/src/transformers/models/switchtransformers/configuration_switchtransformers.py @@ -23,11 +23,12 @@ logger = logging.get_logger(__name__) SWITCHTRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "ybelkada/switchtransformers-base": "https://huggingface.co/ybelkada/switchtransformers-base/resolve/main/config.json", + "ybelkada/switchtransformers-base": ( + "https://huggingface.co/ybelkada/switchtransformers-base/resolve/main/config.json" + ), } - class SwitchTransformersConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`SwitchTransformersModel`] or a diff --git a/src/transformers/models/switchtransformers/convert_switchtransformers_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/switchtransformers/convert_switchtransformers_original_tf_checkpoint_to_pytorch.py index fbb0edd2ce29e..47081a2a0c578 100644 --- a/src/transformers/models/switchtransformers/convert_switchtransformers_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/switchtransformers/convert_switchtransformers_original_tf_checkpoint_to_pytorch.py @@ -17,7 +17,11 @@ import argparse -from transformers import SwitchTransformersConfig, SwitchTransformersForConditionalGeneration, load_tf_weights_in_switchtransformers +from transformers import ( + SwitchTransformersConfig, + SwitchTransformersForConditionalGeneration, + load_tf_weights_in_switchtransformers, +) from transformers.utils import logging @@ -50,7 +54,8 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du type=str, required=True, help=( - "The config json file corresponding to the pre-trained SwitchTransformers model. \nThis specifies the model architecture." + "The config json file corresponding to the pre-trained SwitchTransformers model. \nThis specifies the" + " model architecture." ), ) parser.add_argument( diff --git a/src/transformers/models/switchtransformers/convert_switchtransformersx_checkpoint_to_flax.py b/src/transformers/models/switchtransformers/convert_switchtransformersx_checkpoint_to_flax.py index 90d923d623025..0413f20b476af 100644 --- a/src/transformers/models/switchtransformers/convert_switchtransformersx_checkpoint_to_flax.py +++ b/src/transformers/models/switchtransformers/convert_switchtransformersx_checkpoint_to_flax.py @@ -21,7 +21,9 @@ from transformers import FlaxSwitchTransformersForConditionalGeneration, SwitchTransformersConfig -def convert_switchtransformersx_checkpoint_to_flax(switchtransformersx_checkpoint_path, config_name, flax_dump_folder_path): +def convert_switchtransformersx_checkpoint_to_flax( + switchtransformersx_checkpoint_path, config_name, flax_dump_folder_path +): config = SwitchTransformersConfig.from_pretrained(config_name) flax_model = FlaxSwitchTransformersForConditionalGeneration(config=config) switchtransformersx_model = checkpoints.load_switchtransformersx_checkpoint(switchtransformersx_checkpoint_path) @@ -33,24 +35,42 @@ def convert_switchtransformersx_checkpoint_to_flax(switchtransformersx_checkpoin layer_name = f"layers_{str(layer_index)}" # Self-Attention - switchtransformersx_attention_key = switchtransformersx_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"] - switchtransformersx_attention_out = switchtransformersx_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"] - switchtransformersx_attention_query = switchtransformersx_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"] - switchtransformersx_attention_value = switchtransformersx_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"] + switchtransformersx_attention_key = switchtransformersx_model["target"]["encoder"][layer_name]["attention"][ + "key" + ]["kernel"] + switchtransformersx_attention_out = switchtransformersx_model["target"]["encoder"][layer_name]["attention"][ + "out" + ]["kernel"] + switchtransformersx_attention_query = switchtransformersx_model["target"]["encoder"][layer_name]["attention"][ + "query" + ]["kernel"] + switchtransformersx_attention_value = switchtransformersx_model["target"]["encoder"][layer_name]["attention"][ + "value" + ]["kernel"] # Layer Normalization - switchtransformersx_attention_layer_norm = switchtransformersx_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"] + switchtransformersx_attention_layer_norm = switchtransformersx_model["target"]["encoder"][layer_name][ + "pre_attention_layer_norm" + ]["scale"] if split_mlp_wi: - switchtransformersx_mlp_wi_0 = switchtransformersx_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"] - switchtransformersx_mlp_wi_1 = switchtransformersx_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"] + switchtransformersx_mlp_wi_0 = switchtransformersx_model["target"]["encoder"][layer_name]["mlp"]["wi_0"][ + "kernel" + ] + switchtransformersx_mlp_wi_1 = switchtransformersx_model["target"]["encoder"][layer_name]["mlp"]["wi_1"][ + "kernel" + ] else: - switchtransformersx_mlp_wi = switchtransformersx_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"] + switchtransformersx_mlp_wi = switchtransformersx_model["target"]["encoder"][layer_name]["mlp"]["wi"][ + "kernel" + ] switchtransformersx_mlp_wo = switchtransformersx_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"] # Layer Normalization - switchtransformersx_mlp_layer_norm = switchtransformersx_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"] + switchtransformersx_mlp_layer_norm = switchtransformersx_model["target"]["encoder"][layer_name][ + "pre_mlp_layer_norm" + ]["scale"] # Assigning flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"][ @@ -90,7 +110,9 @@ def convert_switchtransformersx_checkpoint_to_flax(switchtransformersx_checkpoin ] = switchtransformersx_mlp_layer_norm # Only for layer 0: - switchtransformersx_encoder_rel_embedding = switchtransformersx_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T + switchtransformersx_encoder_rel_embedding = switchtransformersx_model["target"]["encoder"]["relpos_bias"][ + "rel_embedding" + ].T flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ "embedding" ] = switchtransformersx_encoder_rel_embedding @@ -104,39 +126,55 @@ def convert_switchtransformersx_checkpoint_to_flax(switchtransformersx_checkpoin layer_name = f"layers_{str(layer_index)}" # Self-Attention - switchtransformersx_attention_key = switchtransformersx_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"] - switchtransformersx_attention_out = switchtransformersx_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"] - switchtransformersx_attention_query = switchtransformersx_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"] - switchtransformersx_attention_value = switchtransformersx_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"] + switchtransformersx_attention_key = switchtransformersx_model["target"]["decoder"][layer_name][ + "self_attention" + ]["key"]["kernel"] + switchtransformersx_attention_out = switchtransformersx_model["target"]["decoder"][layer_name][ + "self_attention" + ]["out"]["kernel"] + switchtransformersx_attention_query = switchtransformersx_model["target"]["decoder"][layer_name][ + "self_attention" + ]["query"]["kernel"] + switchtransformersx_attention_value = switchtransformersx_model["target"]["decoder"][layer_name][ + "self_attention" + ]["value"]["kernel"] # Layer Normalization - switchtransformersx_pre_attention_layer_norm = switchtransformersx_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"][ - "scale" - ] + switchtransformersx_pre_attention_layer_norm = switchtransformersx_model["target"]["decoder"][layer_name][ + "pre_self_attention_layer_norm" + ]["scale"] # Encoder-Decoder-Attention - switchtransformersx_enc_dec_attention_key = switchtransformersx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["key"][ - "kernel" - ] - switchtransformersx_enc_dec_attention_out = switchtransformersx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["out"][ - "kernel" - ] - switchtransformersx_enc_dec_attention_query = switchtransformersx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["query"][ - "kernel" - ] - switchtransformersx_enc_dec_attention_value = switchtransformersx_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["value"][ - "kernel" - ] + switchtransformersx_enc_dec_attention_key = switchtransformersx_model["target"]["decoder"][layer_name][ + "encoder_decoder_attention" + ]["key"]["kernel"] + switchtransformersx_enc_dec_attention_out = switchtransformersx_model["target"]["decoder"][layer_name][ + "encoder_decoder_attention" + ]["out"]["kernel"] + switchtransformersx_enc_dec_attention_query = switchtransformersx_model["target"]["decoder"][layer_name][ + "encoder_decoder_attention" + ]["query"]["kernel"] + switchtransformersx_enc_dec_attention_value = switchtransformersx_model["target"]["decoder"][layer_name][ + "encoder_decoder_attention" + ]["value"]["kernel"] # Layer Normalization - switchtransformersx_cross_layer_norm = switchtransformersx_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"] + switchtransformersx_cross_layer_norm = switchtransformersx_model["target"]["decoder"][layer_name][ + "pre_cross_attention_layer_norm" + ]["scale"] # MLP if split_mlp_wi: - switchtransformersx_mlp_wi_0 = switchtransformersx_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"] - switchtransformersx_mlp_wi_1 = switchtransformersx_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"] + switchtransformersx_mlp_wi_0 = switchtransformersx_model["target"]["decoder"][layer_name]["mlp"]["wi_0"][ + "kernel" + ] + switchtransformersx_mlp_wi_1 = switchtransformersx_model["target"]["decoder"][layer_name]["mlp"]["wi_1"][ + "kernel" + ] else: - switchtransformersx_mlp_wi = switchtransformersx_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"] + switchtransformersx_mlp_wi = switchtransformersx_model["target"]["decoder"][layer_name]["mlp"]["wi"][ + "kernel" + ] switchtransformersx_mlp_wo = switchtransformersx_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"] @@ -203,7 +241,9 @@ def convert_switchtransformersx_checkpoint_to_flax(switchtransformersx_checkpoin flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm # Only for layer 0: - switchtransformersx_decoder_rel_embedding = switchtransformersx_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T + switchtransformersx_decoder_rel_embedding = switchtransformersx_model["target"]["decoder"]["relpos_bias"][ + "rel_embedding" + ].T flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ "embedding" ] = switchtransformersx_decoder_rel_embedding @@ -214,7 +254,9 @@ def convert_switchtransformersx_checkpoint_to_flax(switchtransformersx_checkpoin # LM Head (only in v1.1 checkpoints) if "logits_dense" in switchtransformersx_model["target"]["decoder"]: - flax_model.params["lm_head"]["kernel"] = switchtransformersx_model["target"]["decoder"]["logits_dense"]["kernel"] + flax_model.params["lm_head"]["kernel"] = switchtransformersx_model["target"]["decoder"]["logits_dense"][ + "kernel" + ] flax_model.save_pretrained(flax_dump_folder_path) print("SwitchTransformersX Model was sucessfully converted!") @@ -226,9 +268,13 @@ def convert_switchtransformersx_checkpoint_to_flax(switchtransformersx_checkpoin parser.add_argument( "--switchtransformersx_checkpoint_path", default=None, type=str, required=True, help="Path the TX5 checkpoint." ) - parser.add_argument("--config_name", default=None, type=str, required=True, help="Config name of SwitchTransformers model.") + parser.add_argument( + "--config_name", default=None, type=str, required=True, help="Config name of SwitchTransformers model." + ) parser.add_argument( "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model." ) args = parser.parse_args() - convert_switchtransformersx_checkpoint_to_flax(args.switchtransformersx_checkpoint_path, args.config_name, args.flax_dump_folder_path) + convert_switchtransformersx_checkpoint_to_flax( + args.switchtransformersx_checkpoint_path, args.config_name, args.flax_dump_folder_path + ) diff --git a/src/transformers/models/switchtransformers/modeling_flax_switchtransformers.py b/src/transformers/models/switchtransformers/modeling_flax_switchtransformers.py index de48b929ad5ac..c9e3442fd687f 100644 --- a/src/transformers/models/switchtransformers/modeling_flax_switchtransformers.py +++ b/src/transformers/models/switchtransformers/modeling_flax_switchtransformers.py @@ -1151,7 +1151,9 @@ def _encoder_forward(module, input_ids, attention_mask, **kwargs): ) @add_start_docstrings(SWITCHTRANSFORMERS_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=SwitchTransformersConfig) + @replace_return_docstrings( + output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=SwitchTransformersConfig + ) def decode( self, decoder_input_ids, @@ -1423,12 +1425,17 @@ class FlaxSwitchTransformersModel(FlaxSwitchTransformersPreTrainedModel): """ -overwrite_call_docstring(FlaxSwitchTransformersModel, SWITCHTRANSFORMERS_INPUTS_DOCSTRING + FLAX_SWITCHTRANSFORMERS_MODEL_DOCSTRING) -append_replace_return_docstrings(FlaxSwitchTransformersModel, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) +overwrite_call_docstring( + FlaxSwitchTransformersModel, SWITCHTRANSFORMERS_INPUTS_DOCSTRING + FLAX_SWITCHTRANSFORMERS_MODEL_DOCSTRING +) +append_replace_return_docstrings( + FlaxSwitchTransformersModel, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC +) @add_start_docstrings( - "The bare SWITCHTRANSFORMERS Model transformer outputting encoder's raw hidden-states without any specific head on top.", + "The bare SWITCHTRANSFORMERS Model transformer outputting encoder's raw hidden-states without any specific head on" + " top.", SWITCHTRANSFORMERS_START_DOCSTRING, ) # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5EncoderModule with T5->SwitchTransformers @@ -1518,7 +1525,9 @@ def __call__( ) -@add_start_docstrings("""SWITCHTRANSFORMERS Model with a `language modeling` head on top.""", SWITCHTRANSFORMERS_START_DOCSTRING) +@add_start_docstrings( + """SWITCHTRANSFORMERS Model with a `language modeling` head on top.""", SWITCHTRANSFORMERS_START_DOCSTRING +) # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5ForConditionalGenerationModule with T5->SwitchTransformers class FlaxSwitchTransformersForConditionalGenerationModule(nn.Module): config: SwitchTransformersConfig @@ -1633,7 +1642,9 @@ class FlaxSwitchTransformersForConditionalGeneration(FlaxSwitchTransformersPreTr module_class = FlaxSwitchTransformersForConditionalGenerationModule @add_start_docstrings(SWITCHTRANSFORMERS_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=SwitchTransformersConfig) + @replace_return_docstrings( + output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=SwitchTransformersConfig + ) def decode( self, decoder_input_ids, @@ -1819,7 +1830,8 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs): overwrite_call_docstring( - FlaxSwitchTransformersForConditionalGeneration, SWITCHTRANSFORMERS_INPUTS_DOCSTRING + FLAX_SWITCHTRANSFORMERS_CONDITIONAL_GENERATION_DOCSTRING + FlaxSwitchTransformersForConditionalGeneration, + SWITCHTRANSFORMERS_INPUTS_DOCSTRING + FLAX_SWITCHTRANSFORMERS_CONDITIONAL_GENERATION_DOCSTRING, ) append_replace_return_docstrings( FlaxSwitchTransformersForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC diff --git a/src/transformers/models/switchtransformers/modeling_switchtransformers.py b/src/transformers/models/switchtransformers/modeling_switchtransformers.py index c2324c0b78774..78c55f539f330 100644 --- a/src/transformers/models/switchtransformers/modeling_switchtransformers.py +++ b/src/transformers/models/switchtransformers/modeling_switchtransformers.py @@ -64,7 +64,6 @@ ] - #################################################### # This is a conversion method from TF 1.0 to PyTorch # More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 @@ -275,8 +274,9 @@ def forward(self, hidden_states): pass ALL_LAYERNORM_LAYERS.append(SwitchTransformersLayerNorm) -# TODO: this has to be changed with the experts + +# TODO: this has to be changed with the experts # Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->SwitchTransformers class SwitchTransformersDenseActDense(nn.Module): def __init__(self, config: SwitchTransformersConfig): @@ -314,9 +314,9 @@ def forward(self, hidden_states): # TODO: Change it here to adapt it from the paper, the FF layer contains experts -# an expert is a FF layer with multiple sub-FF layers inside. -# This class should also contain a router class -# check flaxformer/architecture/moe/router.py : https://github.com/google/flaxformer/blob/main/flaxformer/architectures/moe/routing.py +# an expert is a FF layer with multiple sub-FF layers inside. +# This class should also contain a router class +# check flaxformer/architecture/moe/router.py : https://github.com/google/flaxformer/blob/main/flaxformer/architectures/moe/routing.py class SwitchTransformersLayerFF(nn.Module): def __init__(self, config: SwitchTransformersConfig): super().__init__() @@ -1482,7 +1482,9 @@ def forward( ) -@add_start_docstrings("""SWITCHTRANSFORMERS Model with a `language modeling` head on top.""", SWITCHTRANSFORMERS_START_DOCSTRING) +@add_start_docstrings( + """SWITCHTRANSFORMERS Model with a `language modeling` head on top.""", SWITCHTRANSFORMERS_START_DOCSTRING +) class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel): _keys_to_ignore_on_load_missing = [ r"encoder.embed_tokens.weight", @@ -1774,7 +1776,8 @@ def _reorder_cache(self, past, beam_idx): @add_start_docstrings( - "The bare SWITCHTRANSFORMERS Model transformer outputting encoder's raw hidden-states without any specific head on top.", + "The bare SWITCHTRANSFORMERS Model transformer outputting encoder's raw hidden-states without any specific head on" + " top.", SWITCHTRANSFORMERS_START_DOCSTRING, ) class SwitchTransformersEncoderModel(SwitchTransformersPreTrainedModel): diff --git a/src/transformers/models/switchtransformers/modeling_tf_switchtransformers.py b/src/transformers/models/switchtransformers/modeling_tf_switchtransformers.py index 4e31c6319726d..01faf07112a76 100644 --- a/src/transformers/models/switchtransformers/modeling_tf_switchtransformers.py +++ b/src/transformers/models/switchtransformers/modeling_tf_switchtransformers.py @@ -1292,7 +1292,9 @@ def serving_output(self, output): ) -@add_start_docstrings("""SWITCHTRANSFORMERS Model with a `language modeling` head on top.""", SWITCHTRANSFORMERS_START_DOCSTRING) +@add_start_docstrings( + """SWITCHTRANSFORMERS Model with a `language modeling` head on top.""", SWITCHTRANSFORMERS_START_DOCSTRING +) class TFSwitchTransformersForConditionalGeneration(TFSwitchTransformersPreTrainedModel, TFCausalLanguageModelingLoss): def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) @@ -1569,7 +1571,8 @@ def _reorder_cache(self, past, beam_idx): @add_start_docstrings( - "The bare SWITCHTRANSFORMERS Model transformer outputting encoder's raw hidden-stateswithout any specific head on top.", + "The bare SWITCHTRANSFORMERS Model transformer outputting encoder's raw hidden-stateswithout any specific head on" + " top.", SWITCHTRANSFORMERS_START_DOCSTRING, ) class TFSwitchTransformersEncoderModel(TFSwitchTransformersPreTrainedModel): diff --git a/src/transformers/models/switchtransformers/router.py b/src/transformers/models/switchtransformers/router.py new file mode 100644 index 0000000000000..ebddebc6bc2f8 --- /dev/null +++ b/src/transformers/models/switchtransformers/router.py @@ -0,0 +1,215 @@ +# coding=utf-8 +# Copyright 2022 Mesh TensorFlow authors, SwitchTransformers Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Tuple + +import torch +import torch.nn as nn + +# Output classes + +RouterOutput = Any + +@dataclass +class RouterIndices: + r""" + Dispatch indices and combine weights for scatter/gather-based routing. + + Attributes: + dispatch_indices: [num_groups, tokens_per_group, + num_selected_experts, 2] dispatch indices indicating, for each token, its preferred expert and its priority in + that expert's buffer. + combine_weights: [num_groups, tokens_per_group, num_selected_experts] + combine weights used for scaling expert outputs with the router's dispatch probability/confidence. + auxiliary_loss: Load balancing loss for router. + router_z_loss: Router z-loss. Encourages router logits to remain small in an + effort to improve stability. + """ + dispatch_indices: torch.Tensor + combine_weights: torch.Tensor + auxiliary_loss: float + router_z_loss: float = 0. + +@dataclass +class RouterMask: + r""" + Dispatch and combine torch.Tensors for expert routing with masked matmuls. + + Attributes: + dispatch_mask: [num_groups, tokens_per_group, num_experts, + expert_capacity] dispatch torch.Tensor that is 1 if the token gets routed to the corresponding expert, and 0 + otherwise. + combine_torch.Tensor: [num_groups, tokens_per_group, num_experts, + expert_capacity] combine torch.Tensor used for combining expert outputs and scaling with router probability. + auxiliary_loss: Load balancing loss for router. + router_z_loss: Router z-loss. Encourages router logits to remain small in an + effort to improve stability. + """ + dispatch_mask: torch.Tensor + combine_array: torch.Tensor + auxiliary_loss: float + router_z_loss: float = 0. + +# Router loss + +def _router_z_loss(router_logits: torch.Tensor) -> float: + r""" + Compute router z-loss implemented in PyTorch. + + The router z-loss was introduced in Designing Effective Sparse Expert Models (https://arxiv.org/abs/2202.08906). It + encourages router logits to remain small in an effort to improve stability. + + Args: + router_logits: [num_groups, tokens_per_group, num_experts] router + logits. + + Returns: + Scalar router z-loss. + """ + num_groups, tokens_per_group, _ = router_logits.shape + log_z = torch.logsumexp(router_logits, dim=-1) + z_loss = log_z**2 + return torch.sum(z_loss) / (num_groups * tokens_per_group) + + +def _load_balancing_loss(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961). This function implements the loss function presented in + equations (4) - (6). It aims to penalize those cases where the routing between experts is unbalanced. + + Args: + router_probs: Probability assigned to each expert per token. Shape: + [num_groups, tokens_per_group, num_experts]. + expert_indices: [num_groups, tokens_per_group, num_selected_experts] + indices identifying the top num_selected_experts for a given token. + + Returns: + The auxiliary loss. + """ + num_experts = router_probs.shape[-1] + + # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. + # cast the expert indices to int64, otherwise one-hot encoding will fail + if expert_indices.dtype != torch.int64: + expert_indices = expert_indices.to(torch.int64) + expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts) + + # For a given token, determine if it was routed to a given expert. + # Shape: [num_groups, tokens_per_group, num_experts] + expert_mask = torch.max(expert_mask, axis=-2).values + + # cast to float32 otherwise mean will fail + expert_mask = expert_mask.to(torch.float32) + tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2) + + router_prob_per_group_and_expert = torch.mean(router_probs, axis=-2) + return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2) + +# Router classes + +class Router(nn.Module): + """ + Abstract base router class, defining router API and inner workings. + + Attributes: + router_weights: Configurable module used to compute router logits from token + inputs. + jitter_noise: Amplitude of jitter noise applied to router logits. + dtype: Numeric float type for returned combine torch.Tensor. All actual + computations are performed in float32 of the input for stability. + ignore_padding_tokens: Whether to ignore padding tokens during routing. Note + that some routers (e.g. TokensChooseMaskedRouter) will completely ignore padding tokens, while others (e.g. + TokensChooseScatterRouter and ExpertsChooseMaskedRouter) will simply down-weight the probability of selecting + padding tokens. + """ + + def __init__(self, config, **kwargs): + super().__init__() + self.num_experts = config.num_experts + self.router_weights = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias) + self.jitter_noise = config.router_jitter_noise + self.ignore_padding_tokens = config.router_ignore_padding_tokens + + def _compute_router_probabilities(self, token_inputs: torch.Tensor, num_experts: int, apply_jitter: bool) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Computes router probabilities from input tokens. + + Args: + token_inputs: [num_groups, tokens_per_group, hidden_dim] from which + router probabilities are computed. + num_experts: Number of experts. + apply_jitter: If true, apply jitter noise. + + Returns: + - [num_groups, tokens_per_group, num_experts] probabilities for + each token and expert. Used for routing tokens to experts. + - [num_groups, tokens_per_group, num_experts] raw router logits. + Used for computing router z-loss. + """ + # For remainder of routing computation we use float32 to ensure stability. + # See the discussion of "selective precision" in + # https://arxiv.org/abs/2101.03961. + token_inputs = token_inputs.to(torch.float32) + + if apply_jitter and self.jitter_noise > 0: + token_inputs *= torch.random.uniform( + token_inputs.shape, + token_inputs.dtype, + minval=1.0 - self.jitter_noise, + maxval=1.0 + self.jitter_noise) + + # Shape: [num_groups, tokens_per_group, num_experts] + router_logits = self.router_weights(token_inputs, num_experts) + + router_probabilities = torch.nn.softmax(router_logits, axis=-1) + + return router_probabilities, router_logits + + + def forward(self, token_inputs: torch.Tensor, expert_capacity: int, apply_jitter: bool = True) -> RouterOutput: + r""" + Args: + Computes dispatch and combine torch.Tensors for routing to experts. + token_inputs: [num_groups, tokens_per_group, hidden_dim] inputs to + send to experts. + num_experts: Number of experts. + expert_capacity: Each group will send this many tokens to each expert. + apply_jitter: If true, apply jitter noise during routing. + Returns: + Router indices or mask torch.Tensors (depending on router type). + """ + router_probs, router_logits = self._compute_router_probabilities(token_inputs, self.num_experts, apply_jitter) + + + if self.ignore_padding_tokens: + # To identify non-padding tokens, we rely on the fact that padding tokens + # in the inputs have already been masked in the default T5 architecture. + # See + # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L315 + # and + # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L603. + padding_mask = jnp.torch.Tensor((jnp.sum(jnp.abs(token_inputs), axis=-1) > 0), + dtype=token_inputs.dtype) + router_logits *= jnp.expand_dims(padding_mask, axis=-1) + else: + padding_mask = None + + instructions = self._compute_routing_instructions(router_probs, + padding_mask, + expert_capacity) + + return instructions.replace(router_z_loss=_router_z_loss(router_logits)) \ No newline at end of file diff --git a/src/transformers/models/switchtransformers/router_flax.py b/src/transformers/models/switchtransformers/router_flax.py new file mode 100644 index 0000000000000..fbe446a1319cc --- /dev/null +++ b/src/transformers/models/switchtransformers/router_flax.py @@ -0,0 +1,759 @@ +# Copyright 2022 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Mixture of Experts routing mechanisms.""" + +from typing import Any, Iterable, Optional, Sequence, Tuple, Union + +import flax +from flax import linen as nn +from flax.linen import partitioning as flax_partitioning +import jax +import jax.numpy as jnp + +# from flaxformer.components import dense +# from flaxformer.types import Array +# from flaxformer.types import DType +# from flaxformer.types import Initializer + +RouterOutput = Any +Array = Any +DType = Any +Initializer = Any + +# Switch Transformer (https://arxiv.org/abs/2101.03961) suggests using +# nn.initializers.variance_scaling(0.1, "fan_in", "truncated_normal") +# scaling throughout MoE models, but we find slightly better results adopting +# typical normally-distributed scaling for the router specifically. +default_kernel_init = nn.initializers.normal(stddev=2e-2) +default_bias_init = nn.initializers.zeros + + +@flax.struct.dataclass +class RouterIndices: + """Dispatch indices and combine weights for scatter/gather-based routing. + + Attributes: + dispatch_indices: [num_groups, tokens_per_group, + num_selected_experts, 2] dispatch indices indicating, for each token, its preferred expert and its priority in + that expert's buffer. + combine_weights: [num_groups, tokens_per_group, num_selected_experts] + combine weights used for scaling expert outputs with the router's dispatch probability/confidence. + auxiliary_loss: Load balancing loss for router. + router_z_loss: Router z-loss. Encourages router logits to remain small in an + effort to improve stability. + """ + dispatch_indices: Array + combine_weights: Array + auxiliary_loss: float + router_z_loss: float = 0. + + +@flax.struct.dataclass +class RouterMask: + """Dispatch and combine arrays for expert routing with masked matmuls. + + Attributes: + dispatch_mask: [num_groups, tokens_per_group, num_experts, + expert_capacity] dispatch array that is 1 if the token gets routed to the corresponding expert, and 0 otherwise. + combine_array: [num_groups, tokens_per_group, num_experts, + expert_capacity] combine array used for combining expert outputs and scaling with router probability. + auxiliary_loss: Load balancing loss for router. + router_z_loss: Router z-loss. Encourages router logits to remain small in an + effort to improve stability. + """ + dispatch_mask: Array + combine_array: Array + auxiliary_loss: float + router_z_loss: float = 0. + + +def _favor_one_hot_slices() -> bool: + """Returns true iff running on TPUs.""" + return jax.default_backend() == 'tpu' or jax.devices()[0].platform == 'tpu' + + +def _take_along_axis(array: Array, indices: Array, axis: int) -> Array: + """Takes values from the input array by matching 1D index and data slices. + + This function serves the same purpose as jax.numpy.take_along_axis, except that it uses one-hot matrix + multiplications under the hood on TPUs: (1) On TPUs, we use one-hot matrix multiplications to select elements from + the + array; this is particularly helpful for avoiding erroneous all-gather ops when running under pjit. + (2) Otherwise, we fall back to jax.numpy.take_along_axis. + + Notes: + - To simplify matters in case (1), we only support slices along the second or last dimensions. + - We may wish to revisit (1) for very large arrays. + + Args: + array: Source array. + indices: Indices to take along each 1D slice of array. + axis: Axis along which to take 1D slices. + + Returns: + The indexed result. + """ + if array.ndim != indices.ndim: + raise ValueError( + 'indices and array must have the same number of dimensions; ' + f'{indices.ndim} vs. {array.ndim}.') + + if (axis != -1 and axis != array.ndim - 1 and # Not last dimension + axis != 1 and axis != -array.ndim + 1): # Not second dimension + raise ValueError( + 'Only slices along the second or last dimension are supported; ' + f'array.ndim = {array.ndim}, while axis = {axis}.') + + if _favor_one_hot_slices(): + one_hot_length = array.shape[axis] + one_hot_indices = jax.nn.one_hot(indices, one_hot_length, axis=axis) + + if axis == -1 or array.ndim == 1: + # Take i elements from last dimension (s). + # We must use HIGHEST precision to accurately reproduce indexing + # operations with matrix multiplications. + result = jnp.einsum( + '...s,...is->...i', + array, + one_hot_indices, + precision=jax.lax.Precision.HIGHEST) + else: + # Take i elements from second dimension (s). We assume here that we always + # want to slice along the second dimension. + # We must use HIGHEST precision to accurately reproduce indexing + # operations with matrix multiplications. + result = jnp.einsum( + 'ns...,nis...->ni...', + array, + one_hot_indices, + precision=jax.lax.Precision.HIGHEST) + return jax.lax.convert_element_type(result, array.dtype) + else: + return jnp.take_along_axis(array, indices, axis=axis) + + +def _top_k(array: Array, k: int) -> Tuple[Array, Array]: + """Returns top k values and their indices along the last axis of the array. + + This function serves the same purpose as jax.lax.top_k, but in a more XLA friendly manner for TPUs: (1) On TPUs, we + use one-hot matrix multiplications to select the top k values. + This convoluted way of obtaining the top k values is generally faster on TPUs, and, for pjit in particular, + avoids adding extra all-gather ops during backpropagation. + (2) Otherwise, we fall back to jax.lax.top_k (and its underlying scatter op). + + Args: + array: Source array. + k: Number of top values to select. + + Returns: + - Top k values + - Associated top k indices. + """ + if _favor_one_hot_slices(): + top_k_indices = jax.lax.top_k(array, k)[-1] + top_k_values = _take_along_axis(array, top_k_indices, axis=-1) + return top_k_values, top_k_indices + else: + return jax.lax.top_k(array, k) + + +class RouterWeights(nn.Module): + """Router module converting token inputs to router logits. + + Attributes: + use_bias: Whether or not to use the bias term in computing the logits. + dtype: Numerical float type for router logit computation. + kernel_init: Initialization scheme for kernel. + bias_init: Initialization scheme for bias. + precision: XLA precision for array computations. + axis: Axes along which to apply the dense router weights transformation. + Defaults to final axis (typically the "hidden dimension"). + kernel_axis_names: Logical axis names to use for kernel sharding. + reshape_kernel: Whether to reshape the kernel parameter to 2D for Adafactor. + """ + use_bias: bool = True + dtype: DType = jnp.bfloat16 + kernel_init: Initializer = default_kernel_init + bias_init: Initializer = default_bias_init + precision: jax.lax.Precision = jax.lax.Precision.DEFAULT + axis: Union[Iterable[int], int] = -1 + kernel_axis_names: Sequence[str] = ('embed', 'unmodeled') + reshape_kernel: bool = True + + @nn.compact + def __call__(self, token_inputs: Array, num_experts: int) -> Array: + """Applies RouterWeights module. + + Args: + token_inputs: Flattened batch of tokens with shape [num_groups, + group_size, hidden_dim]. + num_experts: Number of experts. + + Returns: + Router logits with shape [num_groups, group_size, num_experts]. + """ + return dense.DenseGeneral( + features=num_experts, + axis=self.axis, + use_bias=self.use_bias, + dtype=self.dtype, + kernel_init=self.kernel_init, + bias_init=self.bias_init, + precision=self.precision, + kernel_axis_names=self.kernel_axis_names, + reshape_kernel=self.reshape_kernel, + name='w')( + token_inputs) + + +class Router(nn.Module): + """Abstract base router class, defining router API and inner workings. + + Attributes: + router_weights: Configurable module used to compute router logits from token + inputs. + jitter_noise: Amplitude of jitter noise applied to router logits. + dtype: Numeric float type for returned combine array. All actual + computations are performed in float32 of the input for stability. + ignore_padding_tokens: Whether to ignore padding tokens during routing. Note + that some routers (e.g. TokensChooseMaskedRouter) will completely ignore padding tokens, while others (e.g. + TokensChooseScatterRouter and ExpertsChooseMaskedRouter) will simply down-weight the probability of selecting + padding tokens. + """ + router_weights: RouterWeights + jitter_noise: float + dtype: jnp.dtype + ignore_padding_tokens: bool + + def __call__(self, + token_inputs: Array, + num_experts: int, + expert_capacity: int, + apply_jitter: bool = True) -> RouterOutput: + """Computes dispatch and combine arrays for routing to experts. + + Args: + token_inputs: [num_groups, tokens_per_group, hidden_dim] inputs to + send to experts. + num_experts: Number of experts. + expert_capacity: Each group will send this many tokens to each expert. + apply_jitter: If true, apply jitter noise during routing. + + Returns: + Router indices or mask arrays (depending on router type). + """ + token_inputs = flax_partitioning.with_sharding_constraint( + token_inputs, ('batch', 'length', 'embed')) + router_probs, router_logits = self._compute_router_probabilities( + token_inputs, num_experts, apply_jitter) + router_probs = flax_partitioning.with_sharding_constraint( + router_probs, ('batch', 'length', 'unmodeled')) + router_logits = flax_partitioning.with_sharding_constraint( + router_logits, ('batch', 'length', 'unmodeled')) + + if self.ignore_padding_tokens: + # To identify non-padding tokens, we rely on the fact that padding tokens + # in the inputs have already been masked in the default T5 architecture. + # See + # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L315 + # and + # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L603. + padding_mask = jnp.array((jnp.sum(jnp.abs(token_inputs), axis=-1) > 0), + dtype=token_inputs.dtype) + router_logits *= jnp.expand_dims(padding_mask, axis=-1) + else: + padding_mask = None + + instructions = self._compute_routing_instructions(router_probs, + padding_mask, + expert_capacity) + + return instructions.replace(router_z_loss=_router_z_loss(router_logits)) + + def _compute_router_probabilities(self, token_inputs: Array, num_experts: int, + apply_jitter: bool) -> Tuple[Array, Array]: + """Computes router probabilities from input tokens. + + Args: + token_inputs: [num_groups, tokens_per_group, hidden_dim] from which + router probabilities are computed. + num_experts: Number of experts. + apply_jitter: If true, apply jitter noise. + + Returns: + - [num_groups, tokens_per_group, num_experts] probabilities for each token and expert. Used for routing + tokens to experts. + - [num_groups, tokens_per_group, num_experts] raw router logits. Used for computing router z-loss. + """ + # For remainder of routing computation we use float32 to ensure stability. + # See the discussion of "selective precision" in + # https://arxiv.org/abs/2101.03961. + token_inputs = jax.lax.convert_element_type(token_inputs, jnp.float32) + + if apply_jitter and self.jitter_noise > 0: + token_inputs *= jax.random.uniform( + self.make_rng('jitter'), + token_inputs.shape, + token_inputs.dtype, + minval=1.0 - self.jitter_noise, + maxval=1.0 + self.jitter_noise) + + # Shape: [num_groups, tokens_per_group, num_experts] + router_logits = self.router_weights(token_inputs, num_experts) + + router_probabilities = jax.nn.softmax(router_logits, axis=-1) + + return router_probabilities, router_logits + + def _compute_routing_instructions(self, router_probs: Array, + padding_mask: Optional[Array], + expert_capacity: int) -> RouterOutput: + """Computes instructions for routing inputs to experts.""" + raise NotImplementedError( + 'Router is an abstract class that should be subclassed.') + + +class ScatterRouter(Router): + """Abstract base router class for scatter dispatch routers. + + ScatterRouter(s) return RouterIndices containing dispatch indices and combine weights for sending token inputs (via + scatter) and receiving outputs (via gather) to and from experts. + + Scatter-based routing is generally faster than masked matmul routing on CPUs and GPUs. + """ + + def _compute_routing_instructions(self, router_probs: Array, + padding_mask: Optional[Array], + expert_capacity: int) -> RouterIndices: + """Computes instructions for routing inputs to experts. + + Args: + router_probs: [num_groups, tokens_per_group, num_experts] + probabilities used to determine the routing of tokens to the experts. + padding_mask: [num_groups, tokens_per_group] padding logit mask + used to identify padding tokens that should be ignored by the router. + expert_capacity: Each group will send this many tokens to each expert. + + Returns: + Router indices containing dispatch indices and combine weights. + """ + raise NotImplementedError( + 'ScatterRouter is an abstract class that should be subclassed.') + + +class MaskedRouter(Router): + """Abstract base router class for masked matmul dispatch routers. + + MaskedRouter(s) return RouterMask(s) containing a dispatch mask and combine array for sending and receiving (via + masked matmuls) inputs and outputs to and from experts. + + Routing using masked matmuls is generally faster than scatter-based routing on TPUs. + """ + + def _compute_routing_instructions(self, router_probs: Array, + padding_mask: Optional[Array], + expert_capacity: int) -> RouterMask: + """Computes masks for the top-k experts per token. + + Args: + router_probs: [num_groups, tokens_per_group, num_experts] + probabilities used to determine the routing of tokens to the experts. + padding_mask: [num_groups, tokens_per_group] padding logit mask + used to identify padding tokens that should be ignored by the router. + expert_capacity: Each group will send this many tokens to each expert. + + Returns: + Router mask arrays. + """ + raise NotImplementedError( + 'MaskedRouter is an abstract class that should be subclassed.') + + +class TokensChooseScatterRouter(ScatterRouter): + """Scatter router using tokens choose top-k experts assignment. + + This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE + (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then routed + to their choice of expert until the expert's expert_capacity is reached. There is no guarantee that each token is + processed by an expert, or that each expert receives at least one token. + + Attributes: + num_selected_experts: Maximum number of experts to which each token is + routed. Tokens may be routed to fewer experts if particular experts are oversubscribed / reach capacity. + batch_prioritized_routing: Whether or not to use Batch Prioritized Routing + (BPR), originally introduced in V-MoE (https://arxiv.org/abs/2106.05974). + With BPR, we prioritize routing those top-k tokens with the highest router probability, rather than simply + using each tokens left-to-right ordering in the batch. This prioritization is important because the expert's + have limited capacity. + """ + num_selected_experts: int + batch_prioritized_routing: bool + + def _compute_routing_instructions(self, router_probs: Array, + padding_mask: Optional[Array], + expert_capacity: int) -> RouterIndices: + """Computes dispatch indices and combine weights for the top-k experts. + + Args: + router_probs: [num_groups, tokens_per_group, num_experts] + probabilities used to determine the routing of tokens to the experts. + padding_mask: [num_groups, tokens_per_group] padding logit mask + used to identify padding tokens that should be down-weighted by the router. + expert_capacity: Each group will send this many tokens to each expert. + + Returns: + Dispatch indices and combine weights for scatter/gather-based routing. + """ + num_groups, tokens_per_group, num_experts = router_probs.shape + + if padding_mask is not None: + # Because `expert_indices` are directly used for scatter-based routing, we + # mask probabilities corresponding to tokens before the top-k operation. + # Note that, unlike for mask-based tokens-choose routing, the + # (down-weighted) padding tokens may still be selected. + router_probs *= jnp.expand_dims(padding_mask, axis=-1) + + # Top-k router probability and corresponding expert indices for each token. + # Shape: [num_groups, tokens_per_group, num_selected_experts]. + combine_weights, expert_indices = _top_k( + router_probs, k=self.num_selected_experts) + + auxiliary_loss = _load_balancing_loss(router_probs, expert_indices) + + if self.batch_prioritized_routing: + # Sort tokens according to their routing probability per token group, so + # that the highest probability tokens are routed first. + token_ordering = jnp.argsort(-combine_weights[..., 0], axis=-1) + expert_indices = _take_along_axis( + expert_indices, jnp.expand_dims(token_ordering, axis=-1), axis=-2) + + # Identify each token's preferred expert. + # Make num_selected_experts the leading axis to ensure that top-1 choices + # have priority over top-2 choices, which have priority over top-3 + # choices... + preferred_experts = jnp.swapaxes(expert_indices, 1, 2) + # Shape: [num_groups, num_selected_experts * tokens_per_group] + preferred_experts = preferred_experts.reshape(num_groups, -1) + + # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. + expert_mask = jax.nn.one_hot( + preferred_experts, num_experts, dtype=jnp.int32) + + # Experts have a fixed capacity that we cannot exceed. A token's priority + # within the expert's buffer is given by the masked, cumulative capacity of + # its target expert. + # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. + token_priority = jnp.cumsum(expert_mask, axis=1) * expert_mask - 1.0 + # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts]. + token_priority = token_priority.reshape( + (num_groups, self.num_selected_experts, -1, num_experts)) + # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. + token_priority = jnp.swapaxes(token_priority, 1, 2) + # For each token, across all experts, select the only non-negative + # (unmasked) priority. Shape: [num_groups, tokens_per_group, + # num_selected_experts]. + token_priority = jnp.max(token_priority, axis=-1) + + # Return to original index shape. + preferred_experts = preferred_experts.reshape(num_groups, + self.num_selected_experts, + tokens_per_group) + # Shape: [num_groups, tokens_per_group, num_selected_experts] + preferred_experts = jnp.swapaxes(preferred_experts, 1, 2) + + if self.batch_prioritized_routing: + # Place tokens in their original ordering. + inverse_token_ordering = jnp.argsort(token_ordering, axis=-1) + preferred_experts = _take_along_axis( + preferred_experts, + jnp.expand_dims(inverse_token_ordering, axis=-1), + axis=-2) + token_priority = _take_along_axis( + token_priority, + jnp.expand_dims(inverse_token_ordering, axis=-1), + axis=-2) + + # Mask out tokens that overflow the maximum expert capacities. + # Shape: [num_groups, tokens_per_group, num_selected_experts]. + combine_weights *= token_priority < expert_capacity + + # Expert index and priority within the expert capacity buffer. + # Shape: [num_groups, tokens_per_group, num_selected_experts, 2]. + dispatch_indices = jnp.stack([preferred_experts, token_priority], axis=-1) + + # Return to default dtype now that router computation is complete. + combine_weights = jax.lax.convert_element_type(combine_weights, self.dtype) + dispatch_indices = jax.lax.convert_element_type(dispatch_indices, jnp.int32) + + return RouterIndices(dispatch_indices, combine_weights, auxiliary_loss) + + +class TokensChooseMaskedRouter(MaskedRouter): + """Masked matmul router using tokens choose top-k experts assignment. + + This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE + (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then routed + to their choice of expert until the expert's expert_capacity is reached. There is no guarantee that each token is + processed by an expert, or that each expert receives at least one token. + + Attributes: + num_selected_experts: Maximum number of experts to which each token is + routed. Tokens may be routed to fewer experts if particular experts are oversubscribed / reach capacity. + batch_prioritized_routing: Whether or not to use Batch Prioritized Routing + (BPR), originally introduced in V-MoE (https://arxiv.org/abs/2106.05974). + With BPR, we prioritize routing those top-k tokens with the highest router probability, rather than simply + using each tokens left-to-right ordering in the batch. This prioritization is important because the experts + have limited capacity. + """ + num_selected_experts: int + batch_prioritized_routing: bool + + def _compute_routing_instructions(self, router_probs: Array, + padding_mask: Optional[Array], + expert_capacity: int) -> RouterMask: + """Computes masks for the top-k experts per token. + + Args: + router_probs: [num_groups, tokens_per_group, num_experts] + probabilities used to determine the routing of tokens to the experts. + padding_mask: [num_groups, tokens_per_group] padding logit mask + used to identify padding tokens that should be ignored by the router. + expert_capacity: Each group will send this many tokens to each expert. + + Returns: + Dispatch and combine arrays for routing with masked matmuls. + """ + num_groups, _, num_experts = router_probs.shape + + # Top-k router probability and corresponding expert indices for each token. + # Shape: [num_groups, tokens_per_group, num_selected_experts]. + expert_gate, expert_index = _top_k( + router_probs, k=self.num_selected_experts) + + if padding_mask is not None: + # Mask applied to gate. Exclude choices corresponding to padding tokens. + gate_mask = jnp.expand_dims(padding_mask, axis=-1) + expert_gate *= gate_mask + + # Set `expert_index` elements corresponding to padding to negative + # numbers. Negative `expert_index` elements will ultimately be dropped in + # the one_hot conversion to the `expert_mask`. + # First convert nonzero padding elements to negative values. + expert_index *= 2 * gate_mask - 1. + # Handle zero padding elements by negatively shifting all padding. + expert_index += jnp.repeat( + gate_mask - 1., self.num_selected_experts, axis=-1) + + # To correctly compute load balancing loss, we also mask out probs. + router_probs *= gate_mask + + auxiliary_loss = _load_balancing_loss(router_probs, expert_index) + + if self.batch_prioritized_routing: + # Sort tokens according to their routing probability per group, so that + # the highest probability tokens are routed first. + permutation = jnp.argsort(-expert_gate[..., 0], axis=-1) + # Shape: [num_groups, tokens_per_group, num_selected_experts] + expert_index = _take_along_axis( + expert_index, jnp.expand_dims(permutation, axis=-1), axis=-2) + + # Make num_selected_experts the leading axis to ensure that top-1 choices + # have priority over top-2 choices, which have priority over top-3 choices, + # etc. + expert_index = jnp.swapaxes(expert_index, 1, 2) + # Shape: [num_groups, num_selected_experts * tokens_per_group] + expert_index = expert_index.reshape(num_groups, -1) + + # Create mask out of indices. + # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. + expert_mask = jax.nn.one_hot(expert_index, num_experts, dtype=jnp.int32) + + # Experts have a fixed capacity that we cannot exceed. A token's priority + # within the expert's buffer is given by the masked, cumulative capacity of + # its target expert. + # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. + token_priority = jnp.cumsum(expert_mask, axis=1) * expert_mask - 1.0 + # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts]. + token_priority = token_priority.reshape( + (num_groups, self.num_selected_experts, -1, num_experts)) + # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. + token_priority = jnp.swapaxes(token_priority, 1, 2) + # For each token, across all selected experts, select the only non-negative + # (unmasked) priority. Now, for group G routing to expert E, token T has + # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E + # is its targeted expert. + # Shape: [num_groups, tokens_per_group, num_experts]. + token_priority = jnp.max(token_priority, axis=2) + + if self.batch_prioritized_routing: + # Place token priorities in original ordering of tokens. + inv_permutation = jnp.argsort(permutation, axis=-1) + token_priority = _take_along_axis( + token_priority, jnp.expand_dims(inv_permutation, axis=-1), axis=-2) + + # Token T can only be routed to expert E if its priority is positive and + # less than the expert capacity. One-hot matrix will ignore indices outside + # the range [0, expert_capacity). + # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]. + dispatch_mask = jax.nn.one_hot( + token_priority, expert_capacity, dtype=jnp.bool_) + + # The combine array will be used for combining expert outputs, scaled by the + # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, + # expert_capacity]. + combine_array = jnp.einsum( + '...te,...tec->...tec', + router_probs, + dispatch_mask, + precision=jax.lax.Precision.DEFAULT) + + # Return to default dtype now that router computation is complete. + combine_array = jax.lax.convert_element_type(combine_array, self.dtype) + + return RouterMask(dispatch_mask, combine_array, auxiliary_loss) + + +class ExpertsChooseMaskedRouter(MaskedRouter): + """Masked matmul router using experts choose tokens assignment. + + This router uses the same mechanism as in Mixture-of-Experts with Expert Choice (https://arxiv.org/abs/2202.09368): + each expert selects its top expert_capacity tokens. An individual token may be processed by multiple experts or none + at all. + + Note: "experts choose routing" should not be used in decoder blocks because it breaks the autoregressive behavior -- + the model will learn to cheat by using future token information to improve current token predictions. + """ + + def _compute_routing_instructions(self, router_probs: Array, + padding_mask: Optional[Array], + expert_capacity: int) -> RouterMask: + """Computes masks for the highest probability token per expert. + + Args: + router_probs: [num_groups, tokens_per_group, num_experts] + probabilities used to determine the routing of tokens to the experts. + padding_mask: [num_groups, tokens_per_group] padding logit mask + used to identify padding tokens that should be down-weighted by the router. + expert_capacity: Each group will send this many tokens to each expert. + + Returns: + Dispatch and combine arrays for routing with masked matmuls. + """ + tokens_per_group = router_probs.shape[1] + + if padding_mask is not None: + # Because experts choose tokens, we mask probabilities corresponding to + # tokens before the top-k operation. Note that, unlike for masked-based + # tokens-choose routing, the experts here may still choose to select the + # (down-weighted) padding tokens. + router_probs *= jnp.expand_dims(padding_mask, axis=-1) + + # vmap over group dimension. + router_probs_t = jax.vmap(lambda m: m.transpose())(router_probs) + + # Top expert_capacity router probability and corresponding token indices for + # each expert. Shapes: [num_groups, num_experts, expert_capacity]. + expert_gate, expert_index = _top_k(router_probs_t, k=expert_capacity) + + # Convert to one-hot mask of expert indices for each token in each group. + # Shape: [num_groups, num_experts, expert_capacity, tokens_per_group]. + dispatch_mask = jax.nn.one_hot( + expert_index, tokens_per_group, dtype=jnp.int32) + + # Move axes to conform with shape expected by MoeLayer API. + # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity] + dispatch_mask = jnp.moveaxis(dispatch_mask, 3, 1) + + # The combine array will be used for combining expert outputs, scaled by the + # router probabilities. Shape: [num_groups, num_experts, tokens_per_group, + # expert_capacity]. + combine_array = jnp.einsum( + '...ec,...tec->...tec', + expert_gate, + dispatch_mask, + precision=jax.lax.Precision.DEFAULT) + + # Return to default dtype now that router computation is complete. + combine_array = jax.lax.convert_element_type(combine_array, self.dtype) + + # Each expert is choosing tokens until it reaches full capacity, so we don't + # need an auxiliary loading balancing loss for expert choice routing. + auxiliary_loss = 0.0 + + return RouterMask(dispatch_mask, combine_array, auxiliary_loss) + + +def _load_balancing_loss(router_probs: Array, expert_indices: Array) -> float: + """Computes auxiliary load balancing loss as in Switch Transformer. + + See Switch Transformer (https://arxiv.org/abs/2101.03961). This function implements the loss function presented in + equations (4) - (6). It aims to penalize those cases where the routing between experts is unbalanced. + + Args: + router_probs: Probability assigned to each expert per token. Shape: + [num_groups, tokens_per_group, num_experts]. + expert_indices: [num_groups, tokens_per_group, num_selected_experts] + indices identifying the top num_selected_experts for a given token. + + Returns: + The auxiliary loss. + """ + num_experts = router_probs.shape[-1] + + # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. + expert_mask = jax.nn.one_hot(expert_indices, num_experts, dtype=jnp.int32) + # For a given token, determine if it was routed to a given expert. + # Shape: [num_groups, tokens_per_group, num_experts] + expert_mask = jnp.max(expert_mask, axis=-2) + + tokens_per_group_and_expert = jnp.mean( + expert_mask, dtype=jnp.float32, axis=-2) + router_prob_per_group_and_expert = jnp.mean( + router_probs, dtype=jnp.float32, axis=-2) + return jnp.mean( + tokens_per_group_and_expert * router_prob_per_group_and_expert, + dtype=jnp.float32) * num_experts**2 + + +def _router_z_loss(router_logits: Array) -> float: + """Compute router z-loss. + + The router z-loss was introduced in Designing Effective Sparse Expert Models (https://arxiv.org/abs/2202.08906). It + encourages router logits to remain small in an effort to improve stability. + + Args: + router_logits: [num_groups, tokens_per_group, num_experts] router + logits. + + Returns: + Scalar router z-loss. + """ + num_groups, tokens_per_group, _ = router_logits.shape + log_z = jax.nn.logsumexp(router_logits, axis=-1) + z_loss = log_z**2 + return jnp.sum(z_loss, dtype=jnp.float32) / (num_groups * tokens_per_group) + + +num_tokens = 5 +num_experts = 2 +num_selected_experts = 1 +rng = jax.random.PRNGKey(0) + +router_probs = jax.random.uniform( + rng, (num_tokens, num_experts), minval=0, maxval=1) +expert_indices = jax.random.randint( + rng, (num_tokens, num_selected_experts), minval=0, maxval=2) + +loss = _load_balancing_loss(router_probs, expert_indices) \ No newline at end of file diff --git a/src/transformers/models/switchtransformers/tokenization_switchtransformers.py b/src/transformers/models/switchtransformers/tokenization_switchtransformers.py index d90721d76166a..d1235520a12da 100644 --- a/src/transformers/models/switchtransformers/tokenization_switchtransformers.py +++ b/src/transformers/models/switchtransformers/tokenization_switchtransformers.py @@ -33,7 +33,9 @@ PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "ybelkada/switchtransformers-base": "https://huggingface.co/ybelkada/switchtransformers-base/resolve/main/spiece.model", + "ybelkada/switchtransformers-base": ( + "https://huggingface.co/ybelkada/switchtransformers-base/resolve/main/spiece.model" + ), "switchtransformers-base": "https://huggingface.co/switchtransformers-base/resolve/main/spiece.model", "switchtransformers-large": "https://huggingface.co/switchtransformers-large/resolve/main/spiece.model", "switchtransformers-3b": "https://huggingface.co/switchtransformers-3b/resolve/main/spiece.model", @@ -132,8 +134,8 @@ def __init__( if extra_tokens != extra_ids: raise ValueError( f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are" - " provided to SwitchTransformersTokenizer. In this case the additional_special_tokens must include the extra_ids" - " tokens" + " provided to SwitchTransformersTokenizer. In this case the additional_special_tokens must" + " include the extra_ids tokens" ) self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs @@ -155,9 +157,13 @@ def __init__( self.sp_model.Load(vocab_file) @staticmethod - def _eventually_correct_switchtransformers_max_length(pretrained_model_name_or_path, max_model_length, init_max_model_length): + def _eventually_correct_switchtransformers_max_length( + pretrained_model_name_or_path, max_model_length, init_max_model_length + ): if pretrained_model_name_or_path in SwitchTransformersTokenizer.max_model_input_sizes: - deprecated_max_model_length = SwitchTransformersTokenizer.max_model_input_sizes[pretrained_model_name_or_path] + deprecated_max_model_length = SwitchTransformersTokenizer.max_model_input_sizes[ + pretrained_model_name_or_path + ] if init_max_model_length is not None and init_max_model_length != max_model_length: return init_max_model_length elif init_max_model_length is None: diff --git a/src/transformers/models/switchtransformers/tokenization_switchtransformers_fast.py b/src/transformers/models/switchtransformers/tokenization_switchtransformers_fast.py index e9f0302f48700..0edf71fa3f285 100644 --- a/src/transformers/models/switchtransformers/tokenization_switchtransformers_fast.py +++ b/src/transformers/models/switchtransformers/tokenization_switchtransformers_fast.py @@ -36,14 +36,18 @@ PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "ybelkada/switchtransformers-base": "https://huggingface.co/ybelkada/switchtransformers-base/resolve/main/spiece.model", + "ybelkada/switchtransformers-base": ( + "https://huggingface.co/ybelkada/switchtransformers-base/resolve/main/spiece.model" + ), "switchtransformers-base": "https://huggingface.co/switchtransformers-base/resolve/main/spiece.model", "switchtransformers-large": "https://huggingface.co/switchtransformers-large/resolve/main/spiece.model", "switchtransformers-3b": "https://huggingface.co/switchtransformers-3b/resolve/main/spiece.model", "switchtransformers-11b": "https://huggingface.co/switchtransformers-11b/resolve/main/spiece.model", }, "tokenizer_file": { - "ybelkada/switchtransformers-base": "https://huggingface.co/ybelkada/switchtransformers-base/resolve/main/tokenizer.json", + "ybelkada/switchtransformers-base": ( + "https://huggingface.co/ybelkada/switchtransformers-base/resolve/main/tokenizer.json" + ), "switchtransformers-base": "https://huggingface.co/switchtransformers-base/resolve/main/tokenizer.json", "switchtransformers-large": "https://huggingface.co/switchtransformers-large/resolve/main/tokenizer.json", "switchtransformers-3b": "https://huggingface.co/switchtransformers-3b/resolve/main/tokenizer.json", @@ -127,8 +131,8 @@ def __init__( if extra_tokens != extra_ids: raise ValueError( f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are" - " provided to SwitchTransformersTokenizer. In this case the additional_special_tokens must include the extra_ids" - " tokens" + " provided to SwitchTransformersTokenizer. In this case the additional_special_tokens must" + " include the extra_ids tokens" ) super().__init__( @@ -147,9 +151,13 @@ def __init__( self._extra_ids = extra_ids @staticmethod - def _eventually_correct_switchtransformers_max_length(pretrained_model_name_or_path, max_model_length, init_max_model_length): + def _eventually_correct_switchtransformers_max_length( + pretrained_model_name_or_path, max_model_length, init_max_model_length + ): if pretrained_model_name_or_path in SwitchTransformersTokenizerFast.max_model_input_sizes: - deprecated_max_model_length = SwitchTransformersTokenizerFast.max_model_input_sizes[pretrained_model_name_or_path] + deprecated_max_model_length = SwitchTransformersTokenizerFast.max_model_input_sizes[ + pretrained_model_name_or_path + ] if init_max_model_length is not None and init_max_model_length != max_model_length: return init_max_model_length elif init_max_model_length is None: diff --git a/tests/models/switchtransformers/test_modeling_flax_switchtransformers.py b/tests/models/switchtransformers/test_modeling_flax_switchtransformers.py index ba512e626ccc9..6966a376faa40 100644 --- a/tests/models/switchtransformers/test_modeling_flax_switchtransformers.py +++ b/tests/models/switchtransformers/test_modeling_flax_switchtransformers.py @@ -46,7 +46,12 @@ from flax.core.frozen_dict import unfreeze from flax.training.common_utils import onehot from flax.traverse_util import flatten_dict - from transformers import FLAX_MODEL_MAPPING, BySwitchTransformersTokenizer, SwitchTransformersConfig, SwitchTransformersTokenizer + from transformers import ( + FLAX_MODEL_MAPPING, + BySwitchTransformersTokenizer, + SwitchTransformersConfig, + SwitchTransformersTokenizer, + ) from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model from transformers.models.switchtransformers.modeling_flax_switchtransformers import ( FlaxSwitchTransformersEncoderModel, @@ -229,7 +234,9 @@ def prepare_config_and_inputs_for_common(self): @require_flax class FlaxSwitchTransformersModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase): - all_model_classes = (FlaxSwitchTransformersModel, FlaxSwitchTransformersForConditionalGeneration) if is_flax_available() else () + all_model_classes = ( + (FlaxSwitchTransformersModel, FlaxSwitchTransformersForConditionalGeneration) if is_flax_available() else () + ) all_generative_model_classes = (FlaxSwitchTransformersForConditionalGeneration,) if is_flax_available() else () is_encoder_decoder = True @@ -834,7 +841,9 @@ def test_small_byswitchtransformers_integration_test(self): >>> score = switchtransformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) """ - model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("google/byybelkada/switchtransformers-base") + model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained( + "google/byybelkada/switchtransformers-base" + ) tokenizer = BySwitchTransformersTokenizer.from_pretrained("google/byybelkada/switchtransformers-base") input_ids = tokenizer("Hello there", return_tensors="np").input_ids diff --git a/tests/models/switchtransformers/test_modeling_switchtransformers.py b/tests/models/switchtransformers/test_modeling_switchtransformers.py index 31447a1d74fe6..5e58f9d07f943 100644 --- a/tests/models/switchtransformers/test_modeling_switchtransformers.py +++ b/tests/models/switchtransformers/test_modeling_switchtransformers.py @@ -19,8 +19,7 @@ import unittest from transformers import SwitchTransformersConfig, is_torch_available -from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device -from transformers.utils import cached_property +from transformers.testing_utils import require_torch, slow, torch_device from ...generation.test_generation_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -30,8 +29,15 @@ if is_torch_available(): import torch - from transformers import BySwitchTransformersTokenizer, SwitchTransformersEncoderModel, SwitchTransformersForConditionalGeneration, SwitchTransformersModel, SwitchTransformersTokenizer - from transformers.models.switchtransformers.modeling_switchtransformers import SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST + from transformers import ( + SwitchTransformersEncoderModel, + SwitchTransformersForConditionalGeneration, + SwitchTransformersModel, + ) + from transformers.models.switchtransformers.modeling_switchtransformers import ( + SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, + ) + from transformers.models.switchtransformers.router import _load_balancing_loss, _router_z_loss class SwitchTransformersModelTester: @@ -507,9 +513,13 @@ def prepare_config_and_inputs_for_common(self): @require_torch class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): - all_model_classes = (SwitchTransformersModel, SwitchTransformersForConditionalGeneration) if is_torch_available() else () + all_model_classes = ( + (SwitchTransformersModel, SwitchTransformersForConditionalGeneration) if is_torch_available() else () + ) all_generative_model_classes = (SwitchTransformersForConditionalGeneration,) if is_torch_available() else () - all_parallelizable_model_classes = (SwitchTransformersModel, SwitchTransformersForConditionalGeneration) if is_torch_available() else () + all_parallelizable_model_classes = ( + (SwitchTransformersModel, SwitchTransformersForConditionalGeneration) if is_torch_available() else () + ) fx_compatible = False test_pruning = False test_resize_embeddings = True @@ -813,410 +823,6 @@ def use_task_specific_params(model, task): model.config.update(model.config.task_specific_params[task]) -@require_torch -@require_sentencepiece -@require_tokenizers -class SwitchTransformersModelIntegrationTests(unittest.TestCase): - @cached_property - def model(self): - return SwitchTransformersForConditionalGeneration.from_pretrained("switchtransformers-base").to(torch_device) - - @cached_property - def tokenizer(self): - return SwitchTransformersTokenizer.from_pretrained("switchtransformers-base") - - @slow - def test_small_generation(self): - model = SwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base").to(torch_device) - model.config.max_length = 8 - model.config.num_beams = 1 - model.config.do_sample = False - tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") - - input_ids = tokenizer("summarize: Hello there", return_tensors="pt").input_ids.to(torch_device) - - sequences = model.generate(input_ids) - - output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0] - self.assertTrue(output_str == "Hello there!") - - @slow - def test_small_integration_test(self): - """ - For comparision run: - >>> import switchtransformers # pip install switchtransformers==0.7.1 - >>> from switchtransformers.data.sentencepiece_vocabulary import SentencePieceVocabulary - - >>> path_to_mtf_small_switchtransformers_checkpoint = '' - >>> path_to_mtf_small_spm_model_path = '' - >>> switchtransformers_model = switchtransformers.models.MtfModel(model_dir=path_to_mtf_small_switchtransformers_checkpoint, batch_size=1, tpu=None) - >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) - >>> score = switchtransformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) - """ - - model = SwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base").to(torch_device) - tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") - - input_ids = tokenizer("Hello there", return_tensors="pt").input_ids - labels = tokenizer("Hi I am", return_tensors="pt").input_ids - - loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss - mtf_score = -(labels.shape[-1] * loss.item()) - - EXPECTED_SCORE = -19.0845 - self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) - - @slow - def test_small_v1_1_integration_test(self): - """ - For comparision run: - >>> import switchtransformers # pip install switchtransformers==0.7.1 - >>> from switchtransformers.data.sentencepiece_vocabulary import SentencePieceVocabulary - - >>> path_to_mtf_small_switchtransformers_v1_1_checkpoint = '' - >>> path_to_mtf_small_spm_model_path = '' - >>> switchtransformers_model = switchtransformers.models.MtfModel(model_dir=path_to_mtf_small_switchtransformers_v1_1_checkpoint, batch_size=1, tpu=None) - >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) - >>> score = switchtransformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) - """ - - model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switchtransformers-v1_1-small").to(torch_device) - tokenizer = SwitchTransformersTokenizer.from_pretrained("google/switchtransformers-v1_1-small") - - input_ids = tokenizer("Hello there", return_tensors="pt").input_ids - labels = tokenizer("Hi I am", return_tensors="pt").input_ids - - loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss - mtf_score = -(labels.shape[-1] * loss.item()) - - EXPECTED_SCORE = -59.0293 - self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) - - @slow - def test_small_byswitchtransformers_integration_test(self): - """ - For comparision run: - >>> import switchtransformers # pip install switchtransformers==0.9.1 - - >>> path_to_byswitchtransformers_small_checkpoint = '' - >>> switchtransformers_model = switchtransformers.models.MtfModel(model_dir=path_to_tf_checkpoint, batch_size=1, tpu=None) - >>> vocab = switchtransformers.data.ByteVocabulary() - >>> score = switchtransformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) - """ - - model = SwitchTransformersForConditionalGeneration.from_pretrained("google/byybelkada/switchtransformers-base").to(torch_device) - tokenizer = BySwitchTransformersTokenizer.from_pretrained("google/byybelkada/switchtransformers-base") - - input_ids = tokenizer("Hello there", return_tensors="pt").input_ids - labels = tokenizer("Hi I am", return_tensors="pt").input_ids - - loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss - mtf_score = -(labels.shape[-1] * loss.item()) - - EXPECTED_SCORE = -60.7397 - self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) - - @slow - def test_summarization(self): - model = self.model - tok = self.tokenizer - - FRANCE_ARTICLE = ( # @noqa - "Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings" - " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane." - ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."' - ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s' - " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video" - " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French" - " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a" - " phone at the wreckage site. The two publications described the supposed video, but did not post it on" - " their websites. The publications said that they watched the video, which was found by a source close to" - " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported." - ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the' - " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the" - ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,' - " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said" - " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman" - " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the" - ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,' - ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be' - " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by" - " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so" - " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could" - ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin' - ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match' - ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered' - ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something' - " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the" - ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline' - " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the" - " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the" - ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of' - ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school' - " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in" - " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent" - " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and" - " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%" - ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was' - " sharing the information and documents -- including training and medical records -- with public" - " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the" - " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the" - " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash" - " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late" - " Tuesday that no visible human remains were left at the site but recovery teams would keep searching." - " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all" - " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested." - " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said." - " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew" - " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with" - " the flight school during his training were among several developments as investigators continued to" - " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa" - " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his" - ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in' - " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at" - " some point before his aviation career and underwent psychotherapy before he got his pilot's license." - " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the" - " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to" - " lose his pilot's license, a European government official briefed on the investigation told CNN on" - ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being' - " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that" - " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would" - " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had" - " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded" - " he had psychological issues, the European government official said. But no matter what details emerge" - " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic" - ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact' - " that maybe they weren't going to keep doing their job and they're upset about that and so they're" - ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to' - " also take that rage and turn it outward on 149 other people who had nothing to do with the person's" - ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight' - " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura" - " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine" - " Amiel and Anna-Maja Rappard contributed to this report." - ) - SHORTER_ARTICLE = ( - "(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on" - " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The" - " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based." - " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its" - ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East' - ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the' - " situation in Palestinian territories, paving the way for possible war crimes investigations against" - " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and" - " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the" - " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a" - ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the' - ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an' - ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge' - " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the" - ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine' - " acquires all the rights as well as responsibilities that come with being a State Party to the Statute." - ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights' - ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should' - " immediately end their pressure, and countries that support universal acceptance of the court's treaty" - ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the' - " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's" - ' decision to join a treaty to which over 100 countries around the world are members." In January, when' - " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an" - ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"' - " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a" - ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in' - ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We' - ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"' - " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the" - ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the' - " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou" - ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war' - " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry" - " will include alleged war crimes committed since June. The International Criminal Court was set up in" - " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder" - " and Faith Karimi contributed to this report." - ) - IRAN_ARTICLE = ( - "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran" - " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively" - " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger." - " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli" - " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a" - " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since" - " the announcement of the new framework will likely result in more heat than light. It will not be helped" - " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ." - " The most misleading assertion, despite universal rejection by experts, is that the negotiations'" - " objective at the outset was the total elimination of any nuclear program in Iran. That is the position" - " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it" - " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has" - " always been to structure an agreement or series of agreements so that Iran could not covertly develop a" - " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded" - " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by" - " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another" - " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite" - " sharp accusations by some in the United States and its allies, Iran denies having such a program, and" - " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's" - " continued cooperation with International Atomic Energy Agency inspections is further evidence on this" - " point, and we'll know even more about Iran's program in the coming months and years because of the deal." - " In fact, the inspections provisions that are part of this agreement are designed to protect against any" - " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that" - " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter" - " warning that a deal might be killed by Congress or a future president). This of course is not the case." - " The talks were between Iran and the five permanent members of the U.N. Security Council (United States," - " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has" - " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement" - " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran" - " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement" - " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the" - " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased" - " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes" - " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear" - " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going" - " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such" - " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the" - ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not' - " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New" - " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement" - " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement" - " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove" - " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally" - " some insist that any agreement must address Iranian missile programs, human rights violations or support" - " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are" - " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in" - " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it" - " affects the security of our negotiating partners and allies, including Israel. Those judgments should be" - " fact-based, not based on questionable assertions or dubious assumptions." - ) - ARTICLE_SUBWAY = ( - "New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A" - " year later, she got married again in Westchester County, but to a different man and without divorcing" - " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos" - ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married' - " once more, this time in the Bronx. In an application for a marriage license, she stated it was her" - ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false' - ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage' - " license application, according to court documents. Prosecutors said the marriages were part of an" - " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to" - " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was" - " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New" - " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total," - " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All" - " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be" - " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors" - " said the immigration scam involved some of her husbands, who filed for permanent residence status" - " shortly after the marriages. Any divorces happened only after such filings were approved. It was" - " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District" - " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's" - ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,' - " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his" - " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces" - " up to four years in prison. Her next court appearance is scheduled for May 18." - ) - - expected_summaries = [ - 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a' - " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one" - " magazine says .", - "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a" - " preliminary examination into the situation in the occupied Palestinian territory . as members of the" - " court, Palestinians may be subject to counter-charges as well .", - "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:" - " the debate that has already begun since the announcement of the new framework will likely result in more" - " heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and" - " implement a rigorous inspection regime .", - "prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two" - ' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10' - " times, with nine of her marriages occurring between 1999 and 2002 .", - ] - - use_task_specific_params(model, "summarization") - - dct = tok( - [model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]], - padding="max_length", - truncation=True, - return_tensors="pt", - ).to(torch_device) - self.assertEqual(512, dct["input_ids"].shape[1]) - - hypotheses_batch = model.generate( - **dct, - num_beams=4, - length_penalty=2.0, - max_length=142, - min_length=56, - no_repeat_ngram_size=3, - do_sample=False, - early_stopping=True, - ) - - decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False) - self.assertListEqual( - expected_summaries, - decoded, - ) - - @slow - def test_translation_en_to_de(self): - model = self.model - tok = self.tokenizer - use_task_specific_params(model, "translation_en_to_de") - - en_text = '"Luigi often said to me that he never wanted the brothers to end up in court", she wrote.' - expected_translation = ( - '"Luigi sagte mir oft, dass er nie wollte, dass die Brüder am Gericht sitzen", schrieb sie.' - ) - - input_ids = tok.encode(model.config.prefix + en_text, return_tensors="pt") - input_ids = input_ids.to(torch_device) - output = model.generate(input_ids) - translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) - self.assertEqual(translation, expected_translation) - - @slow - def test_translation_en_to_fr(self): - model = self.model # switchtransformers-base - tok = self.tokenizer - use_task_specific_params(model, "translation_en_to_fr") - - en_text = ( - ' This image section from an infrared recording by the Spitzer telescope shows a "family portrait" of' - " countless generations of stars: the oldest stars are seen as blue dots. " - ) - - input_ids = tok.encode(model.config.prefix + en_text, return_tensors="pt") - input_ids = input_ids.to(torch_device) - - output = model.generate( - input_ids=input_ids, - num_beams=4, - length_penalty=2.0, - max_length=100, - no_repeat_ngram_size=3, - do_sample=False, - early_stopping=True, - ) - translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) - new_truncated_translation = ( - "Cette section d'images provenant de l'enregistrement infrarouge effectué par le télescope Spitzer montre " - "un " - "« portrait familial » de générations innombrables d’étoiles : les plus anciennes sont observées " - "sous forme " - "de points bleus." - ) - - self.assertEqual(translation, new_truncated_translation) - - @slow - def test_translation_en_to_ro(self): - model = self.model - tok = self.tokenizer - use_task_specific_params(model, "translation_en_to_ro") - en_text = "Taco Bell said it plans to add 2,000 locations in the US by 2022." - expected_translation = "Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022." - - inputs = tok(model.config.prefix + en_text, return_tensors="pt").to(torch_device) - output = model.generate(**inputs) - translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) - self.assertEqual(translation, expected_translation) - - @require_torch class TestAsymmetricSwitchTransformers(unittest.TestCase): def build_model_and_check_forward_pass(self, **kwargs): @@ -1252,3 +858,62 @@ def test_defaulting_to_symmetry(self): # num_hidden_layers is passed to SwitchTransformersConfig as num_layers model = self.build_model_and_check_forward_pass(num_hidden_layers=2) assert len(model.decoder.block) == len(model.encoder.block) == 2 + + +class SwitchTransformerRouterTest(unittest.TestCase): + r""" + Switch Transformers has different blocks from classic transformer based models. + The Swift MLP contains a Router class, that has to be tested to check if it is correctly implemented + + Original implementation of the routers here: + + """ + + def test_equivalency_balancy_loss(self): + r""" + This test checks if the balancy loss is correctly implemented + as in the original implementation of the Switch Transformer . + """ + router_probs = torch.Tensor( + [ + [0.35490513, 0.60419905], + [0.4275843, 0.23061597], + [0.32985854, 0.43953657], + [0.25099766, 0.27730572], + [0.7678207, 0.71474564], + ] + ) + + expert_indices = torch.Tensor([[0], [1], [1], [0], [0]]).to(torch.int32) + + loss = _load_balancing_loss(router_probs, expert_indices) + self.assertAlmostEqual(loss.item(), 0.8741045, places=5) + + def test_equivalency_router_z_loss(self): + r""" + This test checks if the router z loss is correctly implemented + as in the original implementation of the Switch Transformer . + """ + logits = torch.Tensor( + [ + [ + [-4.2124424, 3.891939, -3.6481273, 1.8849981], + [0.32625437, 2.918651, 0.84758997, -4.556842], + [-3.32062, 4.6977115, -0.15439987, 0.44086337], + [3.4467149, 4.3436565, -4.7224274, -4.264637], + [-2.224406, -2.5318158, -1.3832569, 1.1891162], + [-2.320062, -0.44705987, 4.289819, -0.00662684], + ], + [ + [0.99470854, -0.6992364, 0.25503993, 4.2952085], + [3.5937333, -3.2408535, -4.298278, 4.426601], + [0.7669008, 2.6588762, 2.4505413, 4.6051874], + [0.23330331, -3.0845237, 0.6262374, -2.9865491], + [0.7595146, -2.1099675, -4.155346, -2.8326452], + [2.3771453, 1.004138, -3.1781673, 0.7581556], + ], + ] + ) + + loss = _router_z_loss(logits) + self.assertAlmostEqual(loss.item(), 13.786719, places=5) diff --git a/tests/models/switchtransformers/test_modeling_tf_switchtransformers.py b/tests/models/switchtransformers/test_modeling_tf_switchtransformers.py index b24b0dae2ea0e..048c1d7e0b4ce 100644 --- a/tests/models/switchtransformers/test_modeling_tf_switchtransformers.py +++ b/tests/models/switchtransformers/test_modeling_tf_switchtransformers.py @@ -26,7 +26,13 @@ if is_tf_available(): import tensorflow as tf - from transformers import BySwitchTransformersTokenizer, SwitchTransformersTokenizer, TFSwitchTransformersEncoderModel, TFSwitchTransformersForConditionalGeneration, TFSwitchTransformersModel + from transformers import ( + BySwitchTransformersTokenizer, + SwitchTransformersTokenizer, + TFSwitchTransformersEncoderModel, + TFSwitchTransformersForConditionalGeneration, + TFSwitchTransformersModel, + ) class TFSwitchTransformersModelTester: @@ -115,7 +121,9 @@ def create_and_check_switchtransformers_with_lm_head(self, config, input_ids, in self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) - def create_and_check_switchtransformers_decoder_model_past(self, config, input_ids, decoder_input_ids, attention_mask): + def create_and_check_switchtransformers_decoder_model_past( + self, config, input_ids, decoder_input_ids, attention_mask + ): model = TFSwitchTransformersModel(config=config).get_decoder() input_ids = input_ids[:1, :] @@ -242,7 +250,9 @@ def prepare_config_and_inputs_for_common(self): class TFSwitchTransformersModelTest(TFModelTesterMixin, unittest.TestCase): is_encoder_decoder = True - all_model_classes = (TFSwitchTransformersModel, TFSwitchTransformersForConditionalGeneration) if is_tf_available() else () + all_model_classes = ( + (TFSwitchTransformersModel, TFSwitchTransformersForConditionalGeneration) if is_tf_available() else () + ) all_generative_model_classes = (TFSwitchTransformersForConditionalGeneration,) if is_tf_available() else () test_onnx = False @@ -702,7 +712,9 @@ def test_small_byswitchtransformers_integration_test(self): >>> score = switchtransformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) """ - model = TFSwitchTransformersForConditionalGeneration.from_pretrained("google/byybelkada/switchtransformers-base") + model = TFSwitchTransformersForConditionalGeneration.from_pretrained( + "google/byybelkada/switchtransformers-base" + ) tokenizer = BySwitchTransformersTokenizer.from_pretrained("google/byybelkada/switchtransformers-base") input_ids = tokenizer("Hello there", return_tensors="tf").input_ids diff --git a/tests/models/switchtransformers/test_tokenization_switchtransformers.py b/tests/models/switchtransformers/test_tokenization_switchtransformers.py index 8ed4c6f80d78a..7ba607e27b5e9 100644 --- a/tests/models/switchtransformers/test_tokenization_switchtransformers.py +++ b/tests/models/switchtransformers/test_tokenization_switchtransformers.py @@ -17,7 +17,13 @@ import tempfile import unittest -from transformers import SPIECE_UNDERLINE, AddedToken, BatchEncoding, SwitchTransformersTokenizer, SwitchTransformersTokenizerFast +from transformers import ( + SPIECE_UNDERLINE, + AddedToken, + BatchEncoding, + SwitchTransformersTokenizer, + SwitchTransformersTokenizerFast, +) from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow from transformers.utils import cached_property, is_tf_available, is_torch_available From 9c7643cc15b34d9a77c54cec46aa2f680f7a8f3a Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 4 Oct 2022 18:21:21 +0100 Subject: [PATCH 004/102] clean up - remove `tf` modeling files --- .../models/switchtransformers/__init__.py | 27 - .../modeling_switchtransformers.py | 285 --- .../modeling_tf_switchtransformers.py | 1674 ----------------- .../models/switchtransformers/router.py | 72 +- .../models/switchtransformers/router_flax.py | 1215 ++++++------ .../test_modeling_tf_switchtransformers.py | 1066 ----------- 6 files changed, 630 insertions(+), 3709 deletions(-) delete mode 100644 src/transformers/models/switchtransformers/modeling_tf_switchtransformers.py delete mode 100644 tests/models/switchtransformers/test_modeling_tf_switchtransformers.py diff --git a/src/transformers/models/switchtransformers/__init__.py b/src/transformers/models/switchtransformers/__init__.py index 615827cb82a32..44e99f74e80d3 100644 --- a/src/transformers/models/switchtransformers/__init__.py +++ b/src/transformers/models/switchtransformers/__init__.py @@ -68,19 +68,6 @@ "load_tf_weights_in_switchtransformers", ] -try: - if not is_tf_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["modeling_tf_switchtransformers"] = [ - "TF_SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST", - "TFSwitchTransformersEncoderModel", - "TFSwitchTransformersForConditionalGeneration", - "TFSwitchTransformersModel", - "TFSwitchTransformersPreTrainedModel", - ] try: if not is_flax_available(): @@ -134,20 +121,6 @@ load_tf_weights_in_switchtransformers, ) - try: - if not is_tf_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .modeling_tf_switchtransformers import ( - TF_SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, - TFSwitchTransformersEncoderModel, - TFSwitchTransformersForConditionalGeneration, - TFSwitchTransformersModel, - TFSwitchTransformersPreTrainedModel, - ) - try: if not is_flax_available(): raise OptionalDependencyNotAvailable() diff --git a/src/transformers/models/switchtransformers/modeling_switchtransformers.py b/src/transformers/models/switchtransformers/modeling_switchtransformers.py index 78c55f539f330..15383d0ebc98a 100644 --- a/src/transformers/models/switchtransformers/modeling_switchtransformers.py +++ b/src/transformers/models/switchtransformers/modeling_switchtransformers.py @@ -17,7 +17,6 @@ import copy import math -import os import warnings from typing import Optional, Tuple, Union @@ -44,7 +43,6 @@ logging, replace_return_docstrings, ) -from ...utils.model_parallel_utils import assert_device_map, get_device_map from .configuration_switchtransformers import SwitchTransformersConfig @@ -64,175 +62,6 @@ ] -#################################################### -# This is a conversion method from TF 1.0 to PyTorch -# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 -#################################################### -# Copied from transformers.models.t5.modeling_t5.load_tf_weights_in_t5 with t5->switchtransformers -def load_tf_weights_in_switchtransformers(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - tf_weights = {} - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - tf_weights[name] = array - - for txt_name in names: - name = txt_name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - tf_weights.pop(txt_name, None) - continue - if "_slot_" in name[-1]: - logger.info(f"Skipping {'/'.join(name)}") - tf_weights.pop(txt_name, None) - continue - pointer = model - array = tf_weights[txt_name] - - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] in ["kernel", "scale", "embedding"]: - pointer = getattr(pointer, "weight") - elif scope_names[0] == "self_attention": - pointer = getattr(pointer, "layer") - pointer = pointer[0] - elif scope_names[0] == "enc_dec_attention": - pointer = getattr(pointer, "layer") - pointer = pointer[1] - elif scope_names[0] == "dense_relu_dense": - pointer = getattr(pointer, "layer") - pointer = pointer[2] - elif scope_names[0] == "rms_norm": - if hasattr(pointer, "layer_norm"): - pointer = getattr(pointer, "layer_norm") - elif hasattr(pointer, "final_layer_norm"): - pointer = getattr(pointer, "final_layer_norm") - elif scope_names[0] == "scale": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - elif scope_names[0] == "decoder" and name[1] == "logits": - continue - elif scope_names[0] == "logits": - pointer = getattr(pointer, "lm_head") - elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): - pointer = getattr(pointer, f"wi_{scope_names[1]}") - continue - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if scope_names[0] not in ["kernel", "scale", "embedding"]: - pointer = getattr(pointer, "weight") - if scope_names[0] != "embedding": - logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") - array = np.transpose(array) - try: - assert ( - pointer.shape == array.shape - ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array.astype(np.float32)) - tf_weights.pop(txt_name, None) - - logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") - return model - - -#################################################### -# PyTorch Models are constructed by sub-classing -# - torch.nn.Module for the layers and -# - PreTrainedModel for the models (it-self a sub-class of nn.Module) -#################################################### -PARALLELIZE_DOCSTRING = r""" - This is an experimental feature and is a subject to change at a moment's notice. - - Uses a device map to distribute attention modules of the model across several devices. If no device map is given, - it will evenly distribute blocks across all devices. - - Args: - device_map (`Dict[int, list]`, optional, defaults to None): - A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always - automatically mapped to the first device (for esoteric reasons). That means that the first device should - have fewer attention modules mapped to it than other devices. For reference, the switchtransformers models - have the following number of attention modules: - - - ybelkada/switchtransformers-base: 6 - - switchtransformers-base: 12 - - switchtransformers-large: 24 - - switchtransformers-3b: 24 - - switchtransformers-11b: 24 - - Example: - - ```python - # Here is an example of a device map on a machine with 4 GPUs using switchtransformers-3b, which has a total of 24 attention modules: - model = SwitchTransformersForConditionalGeneration.from_pretrained("switchtransformers-3b") - device_map = { - 0: [0, 1, 2], - 1: [3, 4, 5, 6, 7, 8, 9], - 2: [10, 11, 12, 13, 14, 15, 16], - 3: [17, 18, 19, 20, 21, 22, 23], - } - model.parallelize(device_map) - ``` -""" -DEPARALLELIZE_DOCSTRING = r""" - Moves the model to cpu from a model parallel state. - - Example: - - ```python - # On a 4 GPU machine with switchtransformers-3b: - model = SwitchTransformersForConditionalGeneration.from_pretrained("switchtransformers-3b") - device_map = { - 0: [0, 1, 2], - 1: [3, 4, 5, 6, 7, 8, 9], - 2: [10, 11, 12, 13, 14, 15, 16], - 3: [17, 18, 19, 20, 21, 22, 23], - } - model.parallelize(device_map) # Splits the model across several devices - model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() - ``` -""" - - # Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->SwitchTransformers class SwitchTransformersLayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -260,19 +89,6 @@ def forward(self, hidden_states): return self.weight * hidden_states -try: - from apex.normalization import FusedRMSNorm - - SwitchTransformersLayerNorm = FusedRMSNorm # noqa - - logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of SwitchTransformersLayerNorm") -except ImportError: - # using the normal SwitchTransformersLayerNorm - pass -except Exception: - logger.warning("discovered apex but it failed to load, falling back to SwitchTransformersLayerNorm") - pass - ALL_LAYERNORM_LAYERS.append(SwitchTransformersLayerNorm) @@ -748,7 +564,6 @@ def forward( return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) -# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel with T5->SwitchTransformers,t5->switchtransformers class SwitchTransformersPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -756,7 +571,6 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel): """ config_class = SwitchTransformersConfig - load_tf_weights = load_tf_weights_in_switchtransformers base_model_prefix = "transformer" is_parallelizable = True supports_gradient_checkpointing = True @@ -850,7 +664,6 @@ def _shift_right(self, input_ids): return shifted_input_ids -# Copied from transformers.models.t5.modeling_t5.T5Stack with T5->SwitchTransformers class SwitchTransformersStack(nn.Module): def __init__(self, config, embed_tokens=None): super().__init__(config) @@ -874,39 +687,6 @@ def __init__(self, config, embed_tokens=None): self.device_map = None self.gradient_checkpointing = False - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - # Check validity of device_map - self.device_map = ( - get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map - ) - assert_device_map(self.device_map, len(self.block)) - self.model_parallel = True - self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) - self.last_device = "cuda:" + str(max(self.device_map.keys())) - # Load onto devices - for k, v in self.device_map.items(): - for layer in v: - cuda_device = "cuda:" + str(k) - self.block[layer] = self.block[layer].to(cuda_device) - - # Set embed_tokens to first layer - self.embed_tokens = self.embed_tokens.to(self.first_device) - # Set final layer norm to last device - self.final_layer_norm = self.final_layer_norm.to(self.last_device) - - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def deparallelize(self): - self.model_parallel = False - self.device_map = None - self.first_device = "cpu" - self.last_device = "cpu" - for i in range(len(self.block)): - self.block[i] = self.block[i].to("cpu") - self.embed_tokens = self.embed_tokens.to("cpu") - self.final_layer_norm = self.final_layer_norm.to("cpu") - torch.cuda.empty_cache() - def get_input_embeddings(self): return self.embed_tokens @@ -1323,28 +1103,6 @@ def __init__(self, config: SwitchTransformersConfig): self.model_parallel = False self.device_map = None - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - self.device_map = ( - get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) - if device_map is None - else device_map - ) - assert_device_map(self.device_map, len(self.encoder.block)) - self.encoder.parallelize(self.device_map) - self.decoder.parallelize(self.device_map) - self.model_parallel = True - - @add_start_docstrings(DEPARALLELIZE_DOCSTRING) - def deparallelize(self): - self.encoder.deparallelize() - self.decoder.deparallelize() - self.encoder = self.encoder.to("cpu") - self.decoder = self.decoder.to("cpu") - self.model_parallel = False - self.device_map = None - torch.cuda.empty_cache() - def get_input_embeddings(self): return self.shared @@ -1522,30 +1280,6 @@ def __init__(self, config: SwitchTransformersConfig): self.model_parallel = False self.device_map = None - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - self.device_map = ( - get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) - if device_map is None - else device_map - ) - assert_device_map(self.device_map, len(self.encoder.block)) - self.encoder.parallelize(self.device_map) - self.decoder.parallelize(self.device_map) - self.lm_head = self.lm_head.to(self.decoder.first_device) - self.model_parallel = True - - @add_start_docstrings(DEPARALLELIZE_DOCSTRING) - def deparallelize(self): - self.encoder.deparallelize() - self.decoder.deparallelize() - self.encoder = self.encoder.to("cpu") - self.decoder = self.decoder.to("cpu") - self.lm_head = self.lm_head.to("cpu") - self.model_parallel = False - self.device_map = None - torch.cuda.empty_cache() - def get_input_embeddings(self): return self.shared @@ -1801,25 +1535,6 @@ def __init__(self, config: SwitchTransformersConfig): self.model_parallel = False self.device_map = None - @add_start_docstrings(PARALLELIZE_DOCSTRING) - def parallelize(self, device_map=None): - self.device_map = ( - get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) - if device_map is None - else device_map - ) - assert_device_map(self.device_map, len(self.encoder.block)) - self.encoder.parallelize(self.device_map) - self.model_parallel = True - - @add_start_docstrings(DEPARALLELIZE_DOCSTRING) - def deparallelize(self): - self.encoder.deparallelize() - self.encoder = self.encoder.to("cpu") - self.model_parallel = False - self.device_map = None - torch.cuda.empty_cache() - def get_input_embeddings(self): return self.shared diff --git a/src/transformers/models/switchtransformers/modeling_tf_switchtransformers.py b/src/transformers/models/switchtransformers/modeling_tf_switchtransformers.py deleted file mode 100644 index 01faf07112a76..0000000000000 --- a/src/transformers/models/switchtransformers/modeling_tf_switchtransformers.py +++ /dev/null @@ -1,1674 +0,0 @@ -# coding=utf-8 -# Copyright 2022 SwitchTransformers Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" TF 2.0 SwitchTransformers model.""" - -import copy -import itertools -import math -import warnings -from typing import Optional, Tuple, Union - -import numpy as np -import tensorflow as tf -from tensorflow.compiler.tf2xla.python.xla import dynamic_slice - -from ...activations_tf import get_tf_activation -from ...modeling_tf_outputs import ( - TFBaseModelOutput, - TFBaseModelOutputWithPastAndCrossAttentions, - TFSeq2SeqLMOutput, - TFSeq2SeqModelOutput, -) -from ...modeling_tf_utils import ( - TFCausalLanguageModelingLoss, - TFModelInputType, - TFPreTrainedModel, - TFSharedEmbeddings, - TFWrappedEmbeddings, - keras_serializable, - unpack_inputs, -) -from ...tf_utils import shape_list, stable_softmax -from ...utils import ( - DUMMY_INPUTS, - DUMMY_MASK, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_switchtransformers import SwitchTransformersConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "SwitchTransformersConfig" -_TOKENIZER_FOR_DOC = "SwitchTransformersTokenizer" - -TF_SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "ybelkada/switchtransformers-base", - # See all SwitchTransformers models at https://huggingface.co/models?filter=switchtransformers -] - - -#################################################### -# TF 2.0 Models are constructed using Keras imperative API by sub-classing -# - tf.keras.layers.Layer for the layers and -# - TFPreTrainedModel for the models (it-self a sub-class of tf.keras.Model) -#################################################### - - -# Copied from transformers.models.t5.modeling_tf_t5.TFT5LayerNorm with T5->SwitchTransformers -class TFSwitchTransformersLayerNorm(tf.keras.layers.Layer): - def __init__(self, epsilon=1e-6, **kwargs): - """ - Construct a layernorm module in the SwitchTransformers style No bias and no subtraction of mean. - """ - super().__init__(**kwargs) - self.variance_epsilon = epsilon - - def build(self, input_shape): - """Build shared word embedding layer""" - self.weight = self.add_weight("weight", shape=(input_shape[-1],), initializer="ones") - super().build(input_shape) - - def call(self, hidden_states): - variance = tf.math.reduce_mean(tf.math.square(hidden_states), axis=-1, keepdims=True) - hidden_states = hidden_states * tf.math.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states - - -# Copied from transformers.models.t5.modeling_tf_t5.TFT5DenseActDense with T5->SwitchTransformers -class TFSwitchTransformersDenseActDense(tf.keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - wi_initializer = tf.keras.initializers.RandomNormal( - mean=0, stddev=config.initializer_factor * (config.d_model**-0.5) - ) - wo_initializer = tf.keras.initializers.RandomNormal( - mean=0, stddev=config.initializer_factor * (config.d_ff**-0.5) - ) - self.wi = tf.keras.layers.Dense( - config.d_ff, use_bias=False, name="wi", kernel_initializer=wi_initializer - ) # Update init weights as in flax - self.wo = tf.keras.layers.Dense( - config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer - ) # Update init weights as in flax - self.dropout = tf.keras.layers.Dropout(config.dropout_rate) - self.act = get_tf_activation(config.dense_act_fn) - - def call(self, hidden_states, training=False): - hidden_states = self.wi(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = self.wo(hidden_states) - return hidden_states - - -# Copied from transformers.models.t5.modeling_tf_t5.TFT5DenseGatedActDense with T5->SwitchTransformers -class TFSwitchTransformersDenseGatedActDense(tf.keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - wi_initializer = tf.keras.initializers.RandomNormal( - mean=0, stddev=config.initializer_factor * (config.d_model**-0.5) - ) - wo_initializer = tf.keras.initializers.RandomNormal( - mean=0, stddev=config.initializer_factor * (config.d_ff**-0.5) - ) - self.wi_0 = tf.keras.layers.Dense( - config.d_ff, use_bias=False, name="wi_0", kernel_initializer=wi_initializer - ) # Update init weights as in flax - self.wi_1 = tf.keras.layers.Dense( - config.d_ff, use_bias=False, name="wi_1", kernel_initializer=wi_initializer - ) # Update init weights as in flax - self.wo = tf.keras.layers.Dense( - config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer - ) # Update init weights as in flax - self.dropout = tf.keras.layers.Dropout(config.dropout_rate) - self.act = get_tf_activation(config.dense_act_fn) - - def call(self, hidden_states, training=False): - hidden_gelu = self.act(self.wi_0(hidden_states)) - hidden_linear = self.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = self.dropout(hidden_states, training=training) - hidden_states = self.wo(hidden_states) - return hidden_states - - -# Copied from transformers.models.t5.modeling_tf_t5.TFT5LayerFF with T5->SwitchTransformers -class TFSwitchTransformersLayerFF(tf.keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - if config.is_gated_act: - self.DenseReluDense = TFSwitchTransformersDenseGatedActDense(config, name="DenseReluDense") - else: - self.DenseReluDense = TFSwitchTransformersDenseActDense(config, name="DenseReluDense") - - self.layer_norm = TFSwitchTransformersLayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm") - self.dropout = tf.keras.layers.Dropout(config.dropout_rate) - - def call(self, hidden_states, training=False): - normed_hidden_states = self.layer_norm(hidden_states) - dense_output = self.DenseReluDense(normed_hidden_states, training=training) - hidden_states = hidden_states + self.dropout(dense_output, training=training) - return hidden_states - - -# Copied from transformers.models.t5.modeling_tf_t5.TFT5Attention with T5->SwitchTransformers -class TFSwitchTransformersAttention(tf.keras.layers.Layer): - NEW_ID = itertools.count() - - def __init__(self, config, has_relative_attention_bias=False, **kwargs): - super().__init__(**kwargs) - self.layer_id = next(TFSwitchTransformersAttention.NEW_ID) - self.is_decoder = config.is_decoder - self.use_cache = config.use_cache - self.has_relative_attention_bias = has_relative_attention_bias - self.output_attentions = config.output_attentions - - self.relative_attention_num_buckets = config.relative_attention_num_buckets - self.relative_attention_max_distance = config.relative_attention_max_distance - self.d_model = config.d_model - self.key_value_proj_dim = config.d_kv - self.n_heads = config.num_heads - self.inner_dim = self.n_heads * self.key_value_proj_dim - - # Mesh TensorFlow initialization to avoid scaling before softmax - q_initializer = tf.keras.initializers.RandomNormal( - mean=0, stddev=config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5) - ) - k_initializer = tf.keras.initializers.RandomNormal( - mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5) - ) - v_initializer = tf.keras.initializers.RandomNormal( - mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5) - ) - o_initializer = tf.keras.initializers.RandomNormal( - mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5) - ) - self.relative_attention_bias_initializer = tf.keras.initializers.RandomNormal( - mean=0, stddev=config.initializer_factor * (self.inner_dim**-0.5) - ) - - self.q = tf.keras.layers.Dense( - self.inner_dim, use_bias=False, name="q", kernel_initializer=q_initializer - ) # Update init weights as in flax - self.k = tf.keras.layers.Dense( - self.inner_dim, use_bias=False, name="k", kernel_initializer=k_initializer - ) # Update init weights as in flax - self.v = tf.keras.layers.Dense( - self.inner_dim, use_bias=False, name="v", kernel_initializer=v_initializer - ) # Update init weights as in flax - self.o = tf.keras.layers.Dense( - self.d_model, use_bias=False, name="o", kernel_initializer=o_initializer - ) # Update init weights as in flax - self.dropout = tf.keras.layers.Dropout(config.dropout_rate) - - self.pruned_heads = set() - - def build(self, input_shape): - if self.has_relative_attention_bias: - with tf.name_scope("relative_attention_bias"): - self.relative_attention_bias = self.add_weight( - name="embeddings", - shape=[self.relative_attention_num_buckets, self.n_heads], - initializer=self.relative_attention_bias_initializer, # Add initializer - ) - - return super().build(input_shape) - - def prune_heads(self, heads): - raise NotImplementedError - - @staticmethod - def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): - """ - Adapted from Mesh Tensorflow: - https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - - Translate relative position to a bucket number for relative attention. The relative position is defined as - memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for - small absolute relative_position and larger buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. - This should allow for more graceful generalization to longer sequences than the model has been trained on - - Args: - relative_position: an int32 Tensor - bidirectional: a boolean - whether the attention is bidirectional - num_buckets: an integer - max_distance: an integer - - Returns: - a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) - """ - relative_buckets = 0 - # n = -relative_position - if bidirectional: - num_buckets //= 2 - relative_buckets += ( - tf.cast(tf.math.greater(relative_position, 0), dtype=relative_position.dtype) * num_buckets - ) - relative_position = tf.math.abs(relative_position) - else: - relative_position = -tf.math.minimum(relative_position, 0) - # now n is in the range [0, inf) - max_exact = num_buckets // 2 - is_small = tf.math.less(relative_position, max_exact) - relative_position_if_large = max_exact + tf.cast( - tf.math.log(tf.cast(relative_position, tf.float32) / tf.cast(max_exact, tf.float32)) - / math.log(max_distance / max_exact) - * (num_buckets - max_exact), - dtype=relative_position.dtype, - ) - relative_position_if_large = tf.math.minimum(relative_position_if_large, num_buckets - 1) - relative_buckets += tf.where(is_small, relative_position, relative_position_if_large) - return relative_buckets - - def compute_bias(self, query_length, key_length): - """Compute binned relative position bias""" - context_position = tf.range(query_length)[:, None] - memory_position = tf.range(key_length)[None, :] - relative_position = memory_position - context_position # shape (query_length, key_length) - relative_position_bucket = self._relative_position_bucket( - relative_position, - bidirectional=(not self.is_decoder), - num_buckets=self.relative_attention_num_buckets, - max_distance=self.relative_attention_max_distance, - ) - values = tf.gather( - self.relative_attention_bias, relative_position_bucket - ) # shape (query_length, key_length, num_heads) - values = tf.expand_dims( - tf.transpose(values, [2, 0, 1]), axis=0 - ) # shape (1, num_heads, query_length, key_length) - return values - - def call( - self, - hidden_states, - mask=None, - key_value_states=None, - position_bias=None, - past_key_value=None, - layer_head_mask=None, - query_length=None, - use_cache=False, - training=False, - output_attentions=False, - ): - """ - Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). - """ - # Input is (batch_size, query_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) - batch_size, seq_length = shape_list(hidden_states)[:2] - - real_seq_length = seq_length - - if past_key_value is not None: - assert ( - len(past_key_value) == 2 - ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" - real_seq_length += shape_list(past_key_value[0])[2] if query_length is None else query_length - - key_length = real_seq_length if key_value_states is None else shape_list(key_value_states)[1] - - def shape(hidden_states): - """projection""" - return tf.transpose( - tf.reshape(hidden_states, (batch_size, -1, self.n_heads, self.key_value_proj_dim)), perm=(0, 2, 1, 3) - ) - - def unshape(hidden_states): - """compute context""" - return tf.reshape(tf.transpose(hidden_states, perm=(0, 2, 1, 3)), (batch_size, -1, self.inner_dim)) - - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = tf.concat([past_key_value, hidden_states], axis=2) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, query_length, dim_per_head) - - # get key/value - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) - - # to cope with keras serialization - if self.is_decoder and use_cache: - present_key_value_state = (key_states, value_states) - else: - present_key_value_state = None - - scores = tf.einsum( - "bnqd,bnkd->bnqk", query_states, key_states - ) # (batch_size, n_heads, query_length, key_length) - - if position_bias is None: - if not self.has_relative_attention_bias: - position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length)) - else: - position_bias = self.compute_bias(real_seq_length, key_length) - - # if key and values are already calculated we want only the last query position bias - if past_key_value is not None: - if not self.has_relative_attention_bias: - position_bias = position_bias[:, :, -seq_length:, :] - else: - # we might have a padded past structure, in which case we want to fetch the position bias slice - # right after the most recently filled past index - most_recently_filled_past_index = tf.reduce_max(tf.where(past_key_value[0][0, 0, :, 0] != 0.0)) - position_bias = dynamic_slice( - position_bias, - (0, 0, most_recently_filled_past_index + 1, 0), - (1, self.n_heads, seq_length, real_seq_length), - ) - - if mask is not None: - position_bias = tf.cast(position_bias, dtype=mask.dtype) - position_bias = position_bias + mask # (batch_size, n_heads, query_length, key_length) - - scores += position_bias - weights = stable_softmax(scores, axis=-1) # (batch_size, n_heads, query_length, key_length) - weights = self.dropout(weights, training=training) # (batch_size, n_heads, query_length, key_length) - - # Mask heads if we want to - if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.n_heads], - message=( - f"Head mask for a single layer should be of size {(self.n_heads)}, but is" - f" {shape_list(layer_head_mask)}" - ), - ) - weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * weights - - attn_output = tf.matmul(weights, value_states) # (batch_size, n_heads, query_length, dim_per_head) - - attn_output = self.o(unshape(attn_output)) - - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) - - if output_attentions: - outputs = outputs + (weights,) - - return outputs - - -# Copied from transformers.models.t5.modeling_tf_t5.TFT5LayerSelfAttention with T5->SwitchTransformers -class TFSwitchTransformersLayerSelfAttention(tf.keras.layers.Layer): - def __init__(self, config, has_relative_attention_bias=False, **kwargs): - super().__init__(**kwargs) - self.SelfAttention = TFSwitchTransformersAttention( - config, - has_relative_attention_bias=has_relative_attention_bias, - name="SelfAttention", - ) - self.layer_norm = TFSwitchTransformersLayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm") - self.dropout = tf.keras.layers.Dropout(config.dropout_rate) - - def call( - self, - hidden_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - use_cache=False, - output_attentions=False, - training=False, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.SelfAttention( - normed_hidden_states, - mask=attention_mask, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - training=training, - ) - hidden_states = hidden_states + self.dropout(attention_output[0], training=training) - outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them - return outputs - - -# Copied from transformers.models.t5.modeling_tf_t5.TFT5LayerCrossAttention with T5->SwitchTransformers -class TFSwitchTransformersLayerCrossAttention(tf.keras.layers.Layer): - def __init__(self, config, **kwargs): - super().__init__(**kwargs) - self.EncDecAttention = TFSwitchTransformersAttention( - config, - has_relative_attention_bias=False, - name="EncDecAttention", - ) - self.layer_norm = TFSwitchTransformersLayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm") - self.dropout = tf.keras.layers.Dropout(config.dropout_rate) - - def call( - self, - hidden_states, - key_value_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - query_length=None, - use_cache=False, - output_attentions=False, - training=False, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.EncDecAttention( - normed_hidden_states, - mask=attention_mask, - key_value_states=key_value_states, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=past_key_value, - query_length=query_length, - use_cache=use_cache, - output_attentions=output_attentions, - training=training, - ) - hidden_states = hidden_states + self.dropout(attention_output[0], training=training) - outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them - return outputs - - -# Copied from transformers.models.t5.modeling_tf_t5.TFT5Block with T5->SwitchTransformers -class TFSwitchTransformersBlock(tf.keras.layers.Layer): - def __init__(self, config, has_relative_attention_bias=False, **kwargs): - super().__init__(**kwargs) - self.is_decoder = config.is_decoder - self.layer = [] - self.layer.append( - TFSwitchTransformersLayerSelfAttention( - config, - has_relative_attention_bias=has_relative_attention_bias, - name="layer_._0", - ) - ) - if self.is_decoder: - self.layer.append( - TFSwitchTransformersLayerCrossAttention( - config, - name="layer_._1", - ) - ) - - self.layer.append(TFSwitchTransformersLayerFF(config, name=f"layer_._{len(self.layer)}")) - - def call( - self, - hidden_states, - attention_mask=None, - position_bias=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - encoder_decoder_position_bias=None, - layer_head_mask=None, - encoder_layer_head_mask=None, - past_key_value=None, - use_cache=False, - output_attentions=False, - training=False, - ): - - if past_key_value is not None: - assert self.is_decoder, "Only decoder can use `past_key_values`" - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (past / key) for cross attention' if expected_num_past_key_values == 4 else ''}." - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - - self_attention_outputs = self.layer[0]( - hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - training=training, - ) - hidden_states, present_key_value_state = self_attention_outputs[:2] - attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights - - if self.is_decoder and encoder_hidden_states is not None: - # the actual query length is unknown for cross attention - # if using past key value states. Need to inject it here - if present_key_value_state is not None: - query_length = shape_list(present_key_value_state[0])[2] - else: - query_length = None - - cross_attention_outputs = self.layer[1]( - hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - position_bias=encoder_decoder_position_bias, - layer_head_mask=encoder_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, - use_cache=use_cache, - output_attentions=output_attentions, - training=training, - ) - hidden_states = cross_attention_outputs[0] - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - - # Keep cross-attention outputs and relative position weights - attention_outputs = attention_outputs + cross_attention_outputs[2:] - - # Apply Feed Forward layer - hidden_states = self.layer[-1](hidden_states, training=training) - outputs = (hidden_states,) - - # Add attentions if we output them - outputs = outputs + (present_key_value_state,) + attention_outputs - return outputs # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) - - -#################################################### -# The full model without a specific pretrained or finetuning head is -# provided as a tf.keras.layers.Layer usually called "TFSwitchTransformersMainLayer" -#################################################### -@keras_serializable -# Copied from transformers.models.t5.modeling_tf_t5.TFT5MainLayer with T5->SwitchTransformers -class TFSwitchTransformersMainLayer(tf.keras.layers.Layer): - config_class = SwitchTransformersConfig - - def __init__(self, config, embed_tokens=None, **kwargs): - super().__init__(**kwargs) - - self.config = config - self.output_hidden_states = config.output_hidden_states - self.output_attentions = config.output_attentions - self.use_cache = config.use_cache - - self.embed_tokens = embed_tokens - self.is_decoder = config.is_decoder - - self.config = config - self.num_hidden_layers = config.num_layers - - self.block = [ - TFSwitchTransformersBlock(config, has_relative_attention_bias=bool(i == 0), name=f"block_._{i}") - for i in range(config.num_layers) - ] - self.final_layer_norm = TFSwitchTransformersLayerNorm( - epsilon=config.layer_norm_epsilon, name="final_layer_norm" - ) - self.dropout = tf.keras.layers.Dropout(config.dropout_rate) - - def _prune_heads(self, heads_to_prune): - raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models - - @unpack_inputs - def call( - self, - input_ids=None, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - inputs_embeds=None, - head_mask=None, - encoder_head_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - training=False, - ) -> Tuple: - - if input_ids is not None and inputs_embeds is not None: - err_msg_prefix = "decoder_" if self.is_decoder else "" - raise ValueError( - f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = shape_list(input_ids) - input_ids = tf.reshape(input_ids, (-1, input_shape[-1])) - elif inputs_embeds is not None: - input_shape = shape_list(inputs_embeds)[:-1] - else: - err_msg_prefix = "decoder_" if self.is_decoder else "" - raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") - - if inputs_embeds is None: - assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" - # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound - # indices on GPU, returning zeros instead. This is a dangerous silent behavior. - tf.debugging.assert_less( - input_ids, - tf.cast(self.embed_tokens.vocab_size, dtype=input_ids.dtype), - message=( - "input_ids must be smaller than the embedding layer's input dimension (got" - f" {tf.math.reduce_max(input_ids)} >= {self.embed_tokens.vocab_size})" - ), - ) - inputs_embeds = self.embed_tokens(input_ids) - - batch_size, seq_length = input_shape - - # required mask seq length can be calculated via length of past - mask_seq_length = ( - shape_list(past_key_values[0][0])[2] + seq_length if past_key_values is not None else seq_length - ) - - if attention_mask is None: - attention_mask = tf.fill((batch_size, mask_seq_length), 1) - if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = shape_list(encoder_hidden_states)[1] - encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1) - - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - attention_mask = tf.cast(attention_mask, dtype=inputs_embeds.dtype) - num_dims_attention_mask = len(shape_list(attention_mask)) - if num_dims_attention_mask == 3: - extended_attention_mask = attention_mask[:, None, :, :] - elif num_dims_attention_mask == 2: - # Provided a padding mask of dimensions [batch_size, mask_seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - if self.is_decoder: - seq_ids = tf.range(mask_seq_length) - causal_mask = tf.less_equal( - tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), - seq_ids[None, :, None], - ) - causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) - extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] - if past_key_values[0] is not None: - extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] - else: - extended_attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -1e9 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - - # SwitchTransformers has a mask that can compare sequence ids, we can simulate this here with this transposition - # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 - # extended_attention_mask = tf.math.equal(extended_attention_mask, - # tf.transpose(extended_attention_mask, perm=(-1, -2))) - - extended_attention_mask = (1.0 - extended_attention_mask) * -1e9 - - if self.is_decoder and encoder_attention_mask is not None: - # If a 2D ou 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) - num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) - if num_dims_encoder_attention_mask == 3: - encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] - if num_dims_encoder_attention_mask == 2: - encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] - - # SwitchTransformers has a mask that can compare sequence ids, we can simulate this here with this transposition - # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 - # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, - # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) - - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9 - else: - encoder_extended_attention_mask = None - - present_key_value_states = () if use_cache and self.is_decoder else None - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - all_cross_attentions = () if (output_attentions and self.is_decoder) else None - position_bias = None - encoder_decoder_position_bias = None - - hidden_states = self.dropout(inputs_embeds, training=training) - - for idx, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - layer_outputs = layer_module( - hidden_states, - attention_mask=extended_attention_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=head_mask[idx] if head_mask is not None else None, - encoder_layer_head_mask=encoder_head_mask[idx] if encoder_head_mask is not None else None, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - training=training, - ) - - # layer_outputs is a tuple with: - # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) - hidden_states, present_key_value_state = layer_outputs[:2] - - # We share the position biases between the layers - the first layer store them - # layer_outputs = hidden-states, past_key_values, (self-attention weights), - # (self-attention position bias), (cross-attention position bias), (cross-attention weights), - position_bias = layer_outputs[2] - - if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - - # append next layer key value states - if present_key_value_state is not None and use_cache and self.is_decoder: - present_key_value_states = present_key_value_states + (present_key_value_state,) - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[3],) - if self.is_decoder: - all_cross_attentions = all_cross_attentions + (layer_outputs[5],) - - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.dropout(hidden_states, training=training) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - outputs = (hidden_states,) - # need to check if is decoder here as well for special cases when using keras compile - if use_cache and self.is_decoder: - outputs = outputs + (present_key_value_states,) - if output_hidden_states: - outputs = outputs + (all_hidden_states,) - if output_attentions: - outputs = outputs + (all_attentions,) - if self.is_decoder: - outputs + (all_cross_attentions,) - return outputs # last-layer hidden state, (past_key_values), (all hidden states), (all attentions), (all_cross_attentions) - - if self.is_decoder: - return TFBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=present_key_value_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - else: - return TFBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - ) - - -#################################################### -# TFSwitchTransformersPreTrainedModel is a sub-class of tf.keras.Model -# which take care of loading and saving pretrained weights -# and various common utilities. -# Here you just need to specify a few (self-explanatory) -# pointers for your model. -#################################################### -# Copied from transformers.models.t5.modeling_tf_t5.TFT5PreTrainedModel with T5->SwitchTransformers -class TFSwitchTransformersPreTrainedModel(TFPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = SwitchTransformersConfig - base_model_prefix = "transformer" - # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model - _keys_to_ignore_on_load_unexpected = [r"decoder\Wblock[\W_0]+layer[\W_1]+EncDecAttention\Wrelative_attention_bias"] - - @property - def dummy_inputs(self): - inputs = tf.constant(DUMMY_INPUTS) - input_mask = tf.constant(DUMMY_MASK) - dummy_inputs = { - "input_ids": inputs, - "decoder_input_ids": inputs, - "decoder_attention_mask": input_mask, - } - return dummy_inputs - - @tf.function( - input_signature=[ - { - "input_ids": tf.TensorSpec((None, None), tf.int64, name="input_ids"), - "attention_mask": tf.TensorSpec((None, None), tf.int64, name="attention_mask"), - "decoder_input_ids": tf.TensorSpec((None, None), tf.int64, name="decoder_input_ids"), - "decoder_attention_mask": tf.TensorSpec((None, None), tf.int64, name="decoder_attention_mask"), - } - ] - ) - def serving(self, inputs): - output = self.call(inputs) - - return self.serving_output(output) - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, value): - try: - self.shared.weight = value - except AttributeError: - self(self.dummy_inputs) - self.shared.weight = value - - self.shared.vocab_size = shape_list(value)[0] - # retrieve correct absolute scope for embed token wrapper - with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name: - pass - # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope. - embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name) - self.encoder.embed_tokens = embed_tokens - if hasattr(self, "decoder"): - self.decoder.embed_tokens = embed_tokens - - def _shift_right(self, input_ids): - decoder_start_token_id = self.config.decoder_start_token_id - pad_token_id = self.config.pad_token_id - - assert decoder_start_token_id is not None, ( - "self.model.config.decoder_start_token_id has to be defined. In TF SwitchTransformers it is usually set to" - " the pad_token_id. See SwitchTransformers docs for more information" - ) - - start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id) - start_tokens = tf.cast(start_tokens, input_ids.dtype) # Ensure compatible dtypes for concatenation - shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) - - assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids = tf.where( - shifted_input_ids == -100, - tf.cast(tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids.dtype), - shifted_input_ids, - ) - - # "Verify that `labels` has only positive values and -100" - assert_gte0 = tf.debugging.assert_greater_equal( - shifted_input_ids, tf.constant(0, dtype=shifted_input_ids.dtype) - ) - - # Make sure the assertion op is called by wrapping the result in an identity no-op - with tf.control_dependencies([assert_gte0]): - shifted_input_ids = tf.identity(shifted_input_ids) - - return shifted_input_ids - - -SWITCHTRANSFORMERS_START_DOCSTRING = r""" - - The SWITCHTRANSFORMERS model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text - Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan - Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a - text-to-text denoising generative setting. - - This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it - as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and - behavior. - - - - TensorFlow models and layers in `transformers` accept two formats as input: - - - having all inputs as keyword arguments (like PyTorch models), or - - having all inputs as a list, tuple or dict in the first positional argument. - - The reason the second format is supported is that Keras methods prefer this format when passing inputs to models - and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just - pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second - format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with - the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first - positional argument: - - - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` - - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: - `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` - - a dictionary with one or several input Tensors associated to the input names given in the docstring: - `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` - - Note that when creating models and layers with - [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry - about any of this, as you can just pass inputs like you would to any other Python function! - - - - Parameters: - config ([`SwitchTransformersConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -SWITCHTRANSFORMERS_INPUTS_DOCSTRING = r""" - Args: - input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. SWITCHTRANSFORMERS is a model with relative position - embeddings so you should be able to pad the inputs on the right or the left. - - Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - [What are input IDs?](../glossary#input-ids) - - To know more on how to prepare `inputs` for pretraining take a look at [SWITCHTRANSFORMERS - Training](./switchtransformers#training). - decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Provide for sequence to sequence training. SWITCHTRANSFORMERS uses the `pad_token_id` as the starting token - for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last - `decoder_input_ids` have to be input (see `past_key_values`). - - To know more on how to prepare `decoder_input_ids` for pretraining take a look at [SWITCHTRANSFORMERS - Training](./switchtransformers#training). - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, - 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - decoder_head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, - 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - encoder_outputs (`tuple(tuple(tf.FloatTensor)`, *optional*): - Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at - the output of the last layer of the encoder. Used in the cross-attention of the decoder. - past_key_values (`tuple(tuple(tf.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - decoder_inputs_embeds (`tf.Tensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded - representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be - input (see `past_key_values`). This is useful if you want more control over how to convert - `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - - If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value - of `inputs_embeds`. - use_cache (`bool`, *optional*, defaults to `True`): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the - config will be used instead. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. This argument can be used only in eager mode, in graph mode the value in the config will be - used instead. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in - eager mode, in graph mode the value will always be set to True. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - -SWITCHTRANSFORMERS_ENCODER_INPUTS_DOCSTRING = r""" - Args: - inputs (`tf.Tensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. SWITCHTRANSFORMERS is a model with relative position - embeddings so you should be able to pad the inputs on the right or the left. - - Indices can be obtained using [`SwitchTransformersTokenizer`]. See [`PreTrainedTokenizer.__call__`] and - [`PreTrainedTokenizer.encode`] for details. - - To know more on how to prepare `inputs` for pre-training take a look at [SWITCHTRANSFORMERS - Training](./switchtransformers#training). - attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - training (`bool`, *optional*, defaults to `False`): - Whether or not to use the model in training mode (some modules like dropout modules have different - behaviors between training and evaluation). -""" - -_HEAD_MASK_WARNING_MSG = """ -The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, -`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. -If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = tf.ones((num_layers, -num_heads))`. -""" - - -@add_start_docstrings( - "The bare SWITCHTRANSFORMERS Model transformer outputting raw hidden-stateswithout any specific head on top.", - SWITCHTRANSFORMERS_START_DOCSTRING, -) -class TFSwitchTransformersModel(TFSwitchTransformersPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.shared = TFSharedEmbeddings( - config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor - ) - - # retrieve correct absolute scope for embed token wrapper - with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name: - pass - # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope. - embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name) - - encoder_config = copy.deepcopy(config) - encoder_config.use_cache = False - self.encoder = TFSwitchTransformersMainLayer(encoder_config, embed_tokens, name="encoder") - - decoder_config = copy.deepcopy(config) - decoder_config.is_decoder = True - decoder_config.num_layers = config.num_decoder_layers - self.decoder = TFSwitchTransformersMainLayer(decoder_config, embed_tokens, name="decoder") - - def get_encoder(self): - return self.encoder - - def get_decoder(self): - return self.decoder - - @unpack_inputs - @add_start_docstrings_to_model_forward(SWITCHTRANSFORMERS_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: Optional[TFModelInputType] = None, - attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, - decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, - decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, - head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, - decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, - encoder_outputs: Optional[Union[np.ndarray, tf.Tensor]] = None, - past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, - inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, - decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - training: Optional[bool] = False, - ) -> Union[Tuple, TFSeq2SeqModelOutput]: - r""" - Returns: - - Examples: - - ```python - >>> from transformers import SwitchTransformersTokenizer, TFSwitchTransformersModel - - >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") - >>> model = TFSwitchTransformersModel.from_pretrained("ybelkada/switchtransformers-base") - - >>> input_ids = tokenizer( - ... "Studies have been shown that owning a dog is good for you", return_tensors="tf" - ... ).input_ids # Batch size 1 - >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids # Batch size 1 - - >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for SwitchTransformersModel. - >>> # This is not needed for torch's SwitchTransformersForConditionalGeneration as it does this internally using labels arg. - >>> decoder_input_ids = model._shift_right(decoder_input_ids) - - >>> # forward pass - >>> outputs = model(input_ids, decoder_input_ids=decoder_input_ids) - >>> last_hidden_states = outputs.last_hidden_state - ```""" - # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask - if head_mask is not None and decoder_head_mask is None: - warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning) - decoder_head_mask = head_mask - - # Encode if needed (training, first prediction pass) - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids, - attention_mask=attention_mask, - encoder_hidden_states=None, - encoder_attention_mask=None, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - past_key_values=None, - use_cache=False, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - hidden_states = encoder_outputs[0] - - # Decode - decoder_outputs = self.decoder( - decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=hidden_states, - encoder_attention_mask=attention_mask, - inputs_embeds=decoder_inputs_embeds, - head_mask=decoder_head_mask, - encoder_head_mask=head_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - past = decoder_outputs[1] if use_cache else None - - if not return_dict: - if past is not None: - decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:] - return decoder_outputs + encoder_outputs - - return TFSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=past, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - def serving_output(self, output): - pkv = tf.convert_to_tensor(output.past_key_values[1:]) if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - - return TFSeq2SeqModelOutput( - last_hidden_state=output.last_hidden_state, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - cross_attentions=cross_attns, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - ) - - -@add_start_docstrings( - """SWITCHTRANSFORMERS Model with a `language modeling` head on top.""", SWITCHTRANSFORMERS_START_DOCSTRING -) -class TFSwitchTransformersForConditionalGeneration(TFSwitchTransformersPreTrainedModel, TFCausalLanguageModelingLoss): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.model_dim = config.d_model - self.shared = TFSharedEmbeddings( - config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor - ) - - # retrieve correct absolute scope for embed token wrapper - with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name: - pass - # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope. - embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name) - - encoder_config = copy.deepcopy(config) - encoder_config.use_cache = False - self.encoder = TFSwitchTransformersMainLayer(encoder_config, embed_tokens, name="encoder") - - decoder_config = copy.deepcopy(config) - decoder_config.is_decoder = True - decoder_config.num_layers = config.num_decoder_layers - self.decoder = TFSwitchTransformersMainLayer(decoder_config, embed_tokens, name="decoder") - - if not config.tie_word_embeddings: - lm_head_initializer = tf.keras.initializers.RandomNormal(mean=0, stddev=config.initializer_factor) - self.lm_head = tf.keras.layers.Dense( - config.vocab_size, use_bias=False, name="lm_head", kernel_initializer=lm_head_initializer - ) # Update init weights as in flax - - def get_output_embeddings(self): - if self.config.tie_word_embeddings: - return self.get_input_embeddings() - else: - # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens) - # value has a shape (num_tokens, dim) then needs to be transposed - return tf.transpose(self.lm_head.kernel) - - def set_output_embeddings(self, value): - if self.config.tie_word_embeddings: - self.set_input_embeddings(value) - else: - lm_head_initializer = tf.keras.initializers.RandomNormal(mean=0, stddev=self.config.initializer_factor) - self.lm_head = tf.keras.layers.Dense( - shape_list(value)[0], use_bias=False, name="lm_head", kernel_initializer=lm_head_initializer - ) # Update init weights as in flax - # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens) - # value has a shape (num_tokens, dim) then needs to be transposed - transposed_value = tf.transpose(value) - self.lm_head.kernel = transposed_value - - def get_encoder(self): - return self.encoder - - def get_decoder(self): - return self.decoder - - @unpack_inputs - @add_start_docstrings_to_model_forward(SWITCHTRANSFORMERS_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: Optional[TFModelInputType] = None, - attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, - decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, - decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, - head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, - decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, - encoder_outputs: Optional[Union[np.ndarray, tf.Tensor]] = None, - past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, - inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, - decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, - labels: Optional[Union[np.ndarray, tf.Tensor]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - training: Optional[bool] = False, - ) -> Union[Tuple, TFSeq2SeqLMOutput]: - r""" - labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., - config.vocab_size - 1]`. - - Returns: - - Examples: - - ```python - >>> from transformers import SwitchTransformersTokenizer, TFSwitchTransformersForConditionalGeneration - - >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") - >>> model = TFSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") - - >>> # training - >>> inputs = tokenizer("The walks in park", return_tensors="tf").input_ids - >>> labels = tokenizer(" cute dog the ", return_tensors="tf").input_ids - >>> outputs = model(inputs, labels=labels) - >>> loss = outputs.loss - >>> logits = outputs.logits - - >>> # inference - >>> inputs = tokenizer( - ... "summarize: studies have shown that owning a dog is good for you", return_tensors="tf" - ... ).input_ids # Batch size 1 - >>> outputs = model.generate(inputs) - >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) - >>> # studies have shown that owning a dog is good for you - ```""" - # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask - if head_mask is not None and decoder_head_mask is None: - warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning) - decoder_head_mask = head_mask - - # Encode if needed (training, first prediction pass) - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - hidden_states = encoder_outputs[0] - - if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: - # get decoder inputs from shifting lm labels to the right - decoder_input_ids = self._shift_right(labels) - - # Decode - decoder_outputs = self.decoder( - decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=hidden_states, - encoder_attention_mask=attention_mask, - inputs_embeds=decoder_inputs_embeds, - head_mask=decoder_head_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - sequence_output = decoder_outputs[0] - - # SwitchTransformersv1.1 does not tie output word embeddings and thus does not require downscaling - if self.config.tie_word_embeddings: - sequence_output = sequence_output * (self.model_dim**-0.5) - logits = self.shared(sequence_output, mode="linear") - else: - logits = self.lm_head(sequence_output) - - logits = tf.cast(logits, tf.float32) - - loss = None if labels is None else self.hf_compute_loss(labels, logits) - - past = decoder_outputs[1] if use_cache else None - if not return_dict: - if past is not None: - decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:] - output = (logits,) + decoder_outputs[1:] + encoder_outputs - return ((loss,) + output) if loss is not None else output - - # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True - elif isinstance(encoder_outputs, tuple): - last_hidden_state = encoder_outputs[0] - hidden_states = None - attentions = None - idx = 0 - if output_hidden_states: - idx += 1 - hidden_states = encoder_outputs[idx] - if output_attentions: - idx += 1 - attentions = encoder_outputs[idx] - - encoder_outputs = TFBaseModelOutput( - last_hidden_state=last_hidden_state, - hidden_states=hidden_states, - attentions=attentions, - ) - - return TFSeq2SeqLMOutput( - loss=loss, - logits=logits, - past_key_values=past, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - def serving_output(self, output): - pkv = tf.convert_to_tensor(output.past_key_values[1:]) if self.config.use_cache else None - dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None - dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None - cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None - enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None - enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None - - return TFSeq2SeqLMOutput( - logits=output.logits, - past_key_values=pkv, - decoder_hidden_states=dec_hs, - decoder_attentions=dec_attns, - cross_attentions=cross_attns, - encoder_last_hidden_state=output.encoder_last_hidden_state, - encoder_hidden_states=enc_hs, - encoder_attentions=enc_attns, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past=None, - attention_mask=None, - decoder_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs - ): - - # cut decoder_input_ids if past is used - if past is not None: - input_ids = input_ids[:, -1:] - - return { - "input_ids": None, # needs to be passed to make Keras.layer.__call__ happy - "decoder_input_ids": input_ids, - "past_key_values": past, - "encoder_outputs": encoder_outputs, - "attention_mask": attention_mask, - "decoder_attention_mask": decoder_attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "use_cache": use_cache, - } - - def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): - return self._shift_right(labels) - - def _reorder_cache(self, past, beam_idx): - # if decoder past is not included in output - # speedy decoding is disabled and no need to reorder - if past is None: - logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") - return past - - reordered_decoder_past = () - for layer_past_states in past: - # get the correct batch idx from layer past batch dim - # batch dim of `past` is at 2nd position - reordered_layer_past_states = () - for layer_past_state in layer_past_states: - # need to set correct `past` for each of the four key / value states - reordered_layer_past_states = reordered_layer_past_states + ( - tf.gather(layer_past_state, beam_idx, axis=0), - ) - - assert reordered_layer_past_states[0].shape == layer_past_states[0].shape - assert len(reordered_layer_past_states) == len(layer_past_states) - - reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) - return reordered_decoder_past - - -@add_start_docstrings( - "The bare SWITCHTRANSFORMERS Model transformer outputting encoder's raw hidden-stateswithout any specific head on" - " top.", - SWITCHTRANSFORMERS_START_DOCSTRING, -) -class TFSwitchTransformersEncoderModel(TFSwitchTransformersPreTrainedModel): - def __init__(self, config, *inputs, **kwargs): - super().__init__(config, *inputs, **kwargs) - self.shared = TFSharedEmbeddings( - config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor - ) - - # retrieve correct absolute scope for embed token wrapper - with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name: - pass - # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope. - embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name) - - encoder_config = copy.deepcopy(config) - encoder_config.use_cache = False - self.encoder = TFSwitchTransformersMainLayer(encoder_config, embed_tokens, name="encoder") - - @property - def dummy_inputs(self): - return {"input_ids": tf.constant(DUMMY_INPUTS)} - - def get_encoder(self): - return self.encoder - - @unpack_inputs - @add_start_docstrings_to_model_forward(SWITCHTRANSFORMERS_ENCODER_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC) - def call( - self, - input_ids: Optional[TFModelInputType] = None, - attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, - head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, - inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - training: Optional[bool] = False, - ) -> Union[Tuple, TFBaseModelOutput]: - r""" - Returns: - - Examples: - - ```python - >>> from transformers import SwitchTransformersTokenizer, TFSwitchTransformersEncoderModel - - >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") - >>> model = TFSwitchTransformersEncoderModel.from_pretrained("ybelkada/switchtransformers-base") - - >>> input_ids = tokenizer( - ... "Studies have been shown that owning a dog is good for you", return_tensors="tf" - ... ).input_ids # Batch size 1 - >>> outputs = model(input_ids) - ```""" - - encoder_outputs = self.encoder( - input_ids, - attention_mask=attention_mask, - encoder_hidden_states=None, - encoder_attention_mask=None, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - past_key_values=None, - use_cache=False, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - ) - - if not return_dict: - return encoder_outputs - - return TFBaseModelOutput( - last_hidden_state=encoder_outputs.last_hidden_state, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - @tf.function( - input_signature=[ - { - "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"), - "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"), - } - ] - ) - def serving(self, inputs): - output = self.call(inputs) - - return self.serving_output(output) - - def serving_output(self, output): - hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None - attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None - - return TFBaseModelOutput(last_hidden_state=output.last_hidden_state, hidden_states=hs, attentions=attns) diff --git a/src/transformers/models/switchtransformers/router.py b/src/transformers/models/switchtransformers/router.py index ebddebc6bc2f8..23ca739b3f211 100644 --- a/src/transformers/models/switchtransformers/router.py +++ b/src/transformers/models/switchtransformers/router.py @@ -18,10 +18,12 @@ import torch import torch.nn as nn + # Output classes RouterOutput = Any + @dataclass class RouterIndices: r""" @@ -40,7 +42,8 @@ class RouterIndices: dispatch_indices: torch.Tensor combine_weights: torch.Tensor auxiliary_loss: float - router_z_loss: float = 0. + router_z_loss: float = 0.0 + @dataclass class RouterMask: @@ -60,10 +63,12 @@ class RouterMask: dispatch_mask: torch.Tensor combine_array: torch.Tensor auxiliary_loss: float - router_z_loss: float = 0. + router_z_loss: float = 0.0 + # Router loss + def _router_z_loss(router_logits: torch.Tensor) -> float: r""" Compute router z-loss implemented in PyTorch. @@ -107,7 +112,7 @@ def _load_balancing_loss(router_probs: torch.Tensor, expert_indices: torch.Tenso if expert_indices.dtype != torch.int64: expert_indices = expert_indices.to(torch.int64) expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts) - + # For a given token, determine if it was routed to a given expert. # Shape: [num_groups, tokens_per_group, num_experts] expert_mask = torch.max(expert_mask, axis=-2).values @@ -115,12 +120,14 @@ def _load_balancing_loss(router_probs: torch.Tensor, expert_indices: torch.Tenso # cast to float32 otherwise mean will fail expert_mask = expert_mask.to(torch.float32) tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2) - + router_prob_per_group_and_expert = torch.mean(router_probs, axis=-2) return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2) + # Router classes + class Router(nn.Module): """ Abstract base router class, defining router API and inner workings. @@ -143,8 +150,10 @@ def __init__(self, config, **kwargs): self.router_weights = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias) self.jitter_noise = config.router_jitter_noise self.ignore_padding_tokens = config.router_ignore_padding_tokens - - def _compute_router_probabilities(self, token_inputs: torch.Tensor, num_experts: int, apply_jitter: bool) -> Tuple[torch.Tensor, torch.Tensor]: + + def _compute_router_probabilities( + self, token_inputs: torch.Tensor, num_experts: int, apply_jitter: bool + ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Computes router probabilities from input tokens. @@ -167,10 +176,8 @@ def _compute_router_probabilities(self, token_inputs: torch.Tensor, num_experts: if apply_jitter and self.jitter_noise > 0: token_inputs *= torch.random.uniform( - token_inputs.shape, - token_inputs.dtype, - minval=1.0 - self.jitter_noise, - maxval=1.0 + self.jitter_noise) + token_inputs.shape, token_inputs.dtype, minval=1.0 - self.jitter_noise, maxval=1.0 + self.jitter_noise + ) # Shape: [num_groups, tokens_per_group, num_experts] router_logits = self.router_weights(token_inputs, num_experts) @@ -179,37 +186,32 @@ def _compute_router_probabilities(self, token_inputs: torch.Tensor, num_experts: return router_probabilities, router_logits - def forward(self, token_inputs: torch.Tensor, expert_capacity: int, apply_jitter: bool = True) -> RouterOutput: r""" Args: Computes dispatch and combine torch.Tensors for routing to experts. - token_inputs: [num_groups, tokens_per_group, hidden_dim] inputs to - send to experts. - num_experts: Number of experts. - expert_capacity: Each group will send this many tokens to each expert. - apply_jitter: If true, apply jitter noise during routing. + token_inputs: [num_groups, tokens_per_group, hidden_dim] inputs to send to experts. num_experts: + Number of experts. expert_capacity: Each group will send this many tokens to each expert. apply_jitter: If + true, apply jitter noise during routing. Returns: Router indices or mask torch.Tensors (depending on router type). """ router_probs, router_logits = self._compute_router_probabilities(token_inputs, self.num_experts, apply_jitter) - - if self.ignore_padding_tokens: - # To identify non-padding tokens, we rely on the fact that padding tokens - # in the inputs have already been masked in the default T5 architecture. - # See - # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L315 - # and - # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L603. - padding_mask = jnp.torch.Tensor((jnp.sum(jnp.abs(token_inputs), axis=-1) > 0), - dtype=token_inputs.dtype) - router_logits *= jnp.expand_dims(padding_mask, axis=-1) - else: - padding_mask = None - - instructions = self._compute_routing_instructions(router_probs, - padding_mask, - expert_capacity) - - return instructions.replace(router_z_loss=_router_z_loss(router_logits)) \ No newline at end of file + # Flax code for reference + # if self.ignore_padding_tokens: + # # To identify non-padding tokens, we rely on the fact that padding tokens + # # in the inputs have already been masked in the default T5 architecture. + # # See + # # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L315 + # # and + # # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L603. + # padding_mask = jnp.torch.Tensor((jnp.sum(jnp.abs(token_inputs), axis=-1) > 0), dtype=token_inputs.dtype) + # router_logits *= jnp.expand_dims(padding_mask, axis=-1) + # else: + # padding_mask = None + + # instructions = self._compute_routing_instructions(router_probs, padding_mask, expert_capacity) + + # return instructions.replace(router_z_loss=_router_z_loss(router_logits)) + pass diff --git a/src/transformers/models/switchtransformers/router_flax.py b/src/transformers/models/switchtransformers/router_flax.py index fbe446a1319cc..d1060279716e0 100644 --- a/src/transformers/models/switchtransformers/router_flax.py +++ b/src/transformers/models/switchtransformers/router_flax.py @@ -17,10 +17,11 @@ from typing import Any, Iterable, Optional, Sequence, Tuple, Union import flax -from flax import linen as nn -from flax.linen import partitioning as flax_partitioning import jax import jax.numpy as jnp +from flax import linen as nn +from flax.linen import partitioning as flax_partitioning + # from flaxformer.components import dense # from flaxformer.types import Array @@ -42,708 +43,680 @@ @flax.struct.dataclass class RouterIndices: - """Dispatch indices and combine weights for scatter/gather-based routing. - - Attributes: - dispatch_indices: [num_groups, tokens_per_group, - num_selected_experts, 2] dispatch indices indicating, for each token, its preferred expert and its priority in - that expert's buffer. - combine_weights: [num_groups, tokens_per_group, num_selected_experts] - combine weights used for scaling expert outputs with the router's dispatch probability/confidence. - auxiliary_loss: Load balancing loss for router. - router_z_loss: Router z-loss. Encourages router logits to remain small in an - effort to improve stability. - """ - dispatch_indices: Array - combine_weights: Array - auxiliary_loss: float - router_z_loss: float = 0. + """Dispatch indices and combine weights for scatter/gather-based routing. + + Attributes: + dispatch_indices: [num_groups, tokens_per_group, + num_selected_experts, 2] dispatch indices indicating, for each token, its preferred expert and its priority in + that expert's buffer. + combine_weights: [num_groups, tokens_per_group, num_selected_experts] + combine weights used for scaling expert outputs with the router's dispatch probability/confidence. + auxiliary_loss: Load balancing loss for router. + router_z_loss: Router z-loss. Encourages router logits to remain small in an + effort to improve stability. + """ + + dispatch_indices: Array + combine_weights: Array + auxiliary_loss: float + router_z_loss: float = 0.0 @flax.struct.dataclass class RouterMask: - """Dispatch and combine arrays for expert routing with masked matmuls. - - Attributes: - dispatch_mask: [num_groups, tokens_per_group, num_experts, - expert_capacity] dispatch array that is 1 if the token gets routed to the corresponding expert, and 0 otherwise. - combine_array: [num_groups, tokens_per_group, num_experts, - expert_capacity] combine array used for combining expert outputs and scaling with router probability. - auxiliary_loss: Load balancing loss for router. - router_z_loss: Router z-loss. Encourages router logits to remain small in an - effort to improve stability. - """ - dispatch_mask: Array - combine_array: Array - auxiliary_loss: float - router_z_loss: float = 0. + """Dispatch and combine arrays for expert routing with masked matmuls. + + Attributes: + dispatch_mask: [num_groups, tokens_per_group, num_experts, + expert_capacity] dispatch array that is 1 if the token gets routed to the corresponding expert, and 0 + otherwise. + combine_array: [num_groups, tokens_per_group, num_experts, + expert_capacity] combine array used for combining expert outputs and scaling with router probability. + auxiliary_loss: Load balancing loss for router. + router_z_loss: Router z-loss. Encourages router logits to remain small in an + effort to improve stability. + """ + + dispatch_mask: Array + combine_array: Array + auxiliary_loss: float + router_z_loss: float = 0.0 def _favor_one_hot_slices() -> bool: - """Returns true iff running on TPUs.""" - return jax.default_backend() == 'tpu' or jax.devices()[0].platform == 'tpu' + """Returns true iff running on TPUs.""" + return jax.default_backend() == "tpu" or jax.devices()[0].platform == "tpu" def _take_along_axis(array: Array, indices: Array, axis: int) -> Array: - """Takes values from the input array by matching 1D index and data slices. - - This function serves the same purpose as jax.numpy.take_along_axis, except that it uses one-hot matrix - multiplications under the hood on TPUs: (1) On TPUs, we use one-hot matrix multiplications to select elements from - the - array; this is particularly helpful for avoiding erroneous all-gather ops when running under pjit. - (2) Otherwise, we fall back to jax.numpy.take_along_axis. - - Notes: - - To simplify matters in case (1), we only support slices along the second or last dimensions. - - We may wish to revisit (1) for very large arrays. - - Args: - array: Source array. - indices: Indices to take along each 1D slice of array. - axis: Axis along which to take 1D slices. - - Returns: - The indexed result. - """ - if array.ndim != indices.ndim: - raise ValueError( - 'indices and array must have the same number of dimensions; ' - f'{indices.ndim} vs. {array.ndim}.') - - if (axis != -1 and axis != array.ndim - 1 and # Not last dimension - axis != 1 and axis != -array.ndim + 1): # Not second dimension - raise ValueError( - 'Only slices along the second or last dimension are supported; ' - f'array.ndim = {array.ndim}, while axis = {axis}.') - - if _favor_one_hot_slices(): - one_hot_length = array.shape[axis] - one_hot_indices = jax.nn.one_hot(indices, one_hot_length, axis=axis) - - if axis == -1 or array.ndim == 1: - # Take i elements from last dimension (s). - # We must use HIGHEST precision to accurately reproduce indexing - # operations with matrix multiplications. - result = jnp.einsum( - '...s,...is->...i', - array, - one_hot_indices, - precision=jax.lax.Precision.HIGHEST) - else: - # Take i elements from second dimension (s). We assume here that we always - # want to slice along the second dimension. - # We must use HIGHEST precision to accurately reproduce indexing - # operations with matrix multiplications. - result = jnp.einsum( - 'ns...,nis...->ni...', - array, - one_hot_indices, - precision=jax.lax.Precision.HIGHEST) - return jax.lax.convert_element_type(result, array.dtype) - else: - return jnp.take_along_axis(array, indices, axis=axis) + """Takes values from the input array by matching 1D index and data slices. + This function serves the same purpose as jax.numpy.take_along_axis, except that it uses one-hot matrix + multiplications under the hood on TPUs: (1) On TPUs, we use one-hot matrix multiplications to select elements from + the + array; this is particularly helpful for avoiding erroneous all-gather ops when running under pjit. + (2) Otherwise, we fall back to jax.numpy.take_along_axis. -def _top_k(array: Array, k: int) -> Tuple[Array, Array]: - """Returns top k values and their indices along the last axis of the array. - - This function serves the same purpose as jax.lax.top_k, but in a more XLA friendly manner for TPUs: (1) On TPUs, we - use one-hot matrix multiplications to select the top k values. - This convoluted way of obtaining the top k values is generally faster on TPUs, and, for pjit in particular, - avoids adding extra all-gather ops during backpropagation. - (2) Otherwise, we fall back to jax.lax.top_k (and its underlying scatter op). - - Args: - array: Source array. - k: Number of top values to select. - - Returns: - - Top k values - - Associated top k indices. - """ - if _favor_one_hot_slices(): - top_k_indices = jax.lax.top_k(array, k)[-1] - top_k_values = _take_along_axis(array, top_k_indices, axis=-1) - return top_k_values, top_k_indices - else: - return jax.lax.top_k(array, k) - - -class RouterWeights(nn.Module): - """Router module converting token inputs to router logits. - - Attributes: - use_bias: Whether or not to use the bias term in computing the logits. - dtype: Numerical float type for router logit computation. - kernel_init: Initialization scheme for kernel. - bias_init: Initialization scheme for bias. - precision: XLA precision for array computations. - axis: Axes along which to apply the dense router weights transformation. - Defaults to final axis (typically the "hidden dimension"). - kernel_axis_names: Logical axis names to use for kernel sharding. - reshape_kernel: Whether to reshape the kernel parameter to 2D for Adafactor. - """ - use_bias: bool = True - dtype: DType = jnp.bfloat16 - kernel_init: Initializer = default_kernel_init - bias_init: Initializer = default_bias_init - precision: jax.lax.Precision = jax.lax.Precision.DEFAULT - axis: Union[Iterable[int], int] = -1 - kernel_axis_names: Sequence[str] = ('embed', 'unmodeled') - reshape_kernel: bool = True - - @nn.compact - def __call__(self, token_inputs: Array, num_experts: int) -> Array: - """Applies RouterWeights module. + Notes: + - To simplify matters in case (1), we only support slices along the second or last dimensions. + - We may wish to revisit (1) for very large arrays. Args: - token_inputs: Flattened batch of tokens with shape [num_groups, - group_size, hidden_dim]. - num_experts: Number of experts. + array: Source array. + indices: Indices to take along each 1D slice of array. + axis: Axis along which to take 1D slices. Returns: - Router logits with shape [num_groups, group_size, num_experts]. + The indexed result. """ - return dense.DenseGeneral( - features=num_experts, - axis=self.axis, - use_bias=self.use_bias, - dtype=self.dtype, - kernel_init=self.kernel_init, - bias_init=self.bias_init, - precision=self.precision, - kernel_axis_names=self.kernel_axis_names, - reshape_kernel=self.reshape_kernel, - name='w')( - token_inputs) + if array.ndim != indices.ndim: + raise ValueError( + f"indices and array must have the same number of dimensions; {indices.ndim} vs. {array.ndim}." + ) + + if ( + axis != -1 and axis != array.ndim - 1 and axis != 1 and axis != -array.ndim + 1 # Not last dimension + ): # Not second dimension + raise ValueError( + "Only slices along the second or last dimension are supported; " + f"array.ndim = {array.ndim}, while axis = {axis}." + ) + + if _favor_one_hot_slices(): + one_hot_length = array.shape[axis] + one_hot_indices = jax.nn.one_hot(indices, one_hot_length, axis=axis) + + if axis == -1 or array.ndim == 1: + # Take i elements from last dimension (s). + # We must use HIGHEST precision to accurately reproduce indexing + # operations with matrix multiplications. + result = jnp.einsum("...s,...is->...i", array, one_hot_indices, precision=jax.lax.Precision.HIGHEST) + else: + # Take i elements from second dimension (s). We assume here that we always + # want to slice along the second dimension. + # We must use HIGHEST precision to accurately reproduce indexing + # operations with matrix multiplications. + result = jnp.einsum("ns...,nis...->ni...", array, one_hot_indices, precision=jax.lax.Precision.HIGHEST) + return jax.lax.convert_element_type(result, array.dtype) + else: + return jnp.take_along_axis(array, indices, axis=axis) -class Router(nn.Module): - """Abstract base router class, defining router API and inner workings. - - Attributes: - router_weights: Configurable module used to compute router logits from token - inputs. - jitter_noise: Amplitude of jitter noise applied to router logits. - dtype: Numeric float type for returned combine array. All actual - computations are performed in float32 of the input for stability. - ignore_padding_tokens: Whether to ignore padding tokens during routing. Note - that some routers (e.g. TokensChooseMaskedRouter) will completely ignore padding tokens, while others (e.g. - TokensChooseScatterRouter and ExpertsChooseMaskedRouter) will simply down-weight the probability of selecting - padding tokens. - """ - router_weights: RouterWeights - jitter_noise: float - dtype: jnp.dtype - ignore_padding_tokens: bool - - def __call__(self, - token_inputs: Array, - num_experts: int, - expert_capacity: int, - apply_jitter: bool = True) -> RouterOutput: - """Computes dispatch and combine arrays for routing to experts. +def _top_k(array: Array, k: int) -> Tuple[Array, Array]: + """Returns top k values and their indices along the last axis of the array. + + This function serves the same purpose as jax.lax.top_k, but in a more XLA friendly manner for TPUs: (1) On TPUs, we + use one-hot matrix multiplications to select the top k values. + This convoluted way of obtaining the top k values is generally faster on TPUs, and, for pjit in particular, + avoids adding extra all-gather ops during backpropagation. + (2) Otherwise, we fall back to jax.lax.top_k (and its underlying scatter op). Args: - token_inputs: [num_groups, tokens_per_group, hidden_dim] inputs to - send to experts. - num_experts: Number of experts. - expert_capacity: Each group will send this many tokens to each expert. - apply_jitter: If true, apply jitter noise during routing. + array: Source array. + k: Number of top values to select. Returns: - Router indices or mask arrays (depending on router type). + - Top k values + - Associated top k indices. """ - token_inputs = flax_partitioning.with_sharding_constraint( - token_inputs, ('batch', 'length', 'embed')) - router_probs, router_logits = self._compute_router_probabilities( - token_inputs, num_experts, apply_jitter) - router_probs = flax_partitioning.with_sharding_constraint( - router_probs, ('batch', 'length', 'unmodeled')) - router_logits = flax_partitioning.with_sharding_constraint( - router_logits, ('batch', 'length', 'unmodeled')) - - if self.ignore_padding_tokens: - # To identify non-padding tokens, we rely on the fact that padding tokens - # in the inputs have already been masked in the default T5 architecture. - # See - # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L315 - # and - # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L603. - padding_mask = jnp.array((jnp.sum(jnp.abs(token_inputs), axis=-1) > 0), - dtype=token_inputs.dtype) - router_logits *= jnp.expand_dims(padding_mask, axis=-1) + if _favor_one_hot_slices(): + top_k_indices = jax.lax.top_k(array, k)[-1] + top_k_values = _take_along_axis(array, top_k_indices, axis=-1) + return top_k_values, top_k_indices else: - padding_mask = None - - instructions = self._compute_routing_instructions(router_probs, - padding_mask, - expert_capacity) + return jax.lax.top_k(array, k) - return instructions.replace(router_z_loss=_router_z_loss(router_logits)) - def _compute_router_probabilities(self, token_inputs: Array, num_experts: int, - apply_jitter: bool) -> Tuple[Array, Array]: - """Computes router probabilities from input tokens. - - Args: - token_inputs: [num_groups, tokens_per_group, hidden_dim] from which - router probabilities are computed. - num_experts: Number of experts. - apply_jitter: If true, apply jitter noise. - - Returns: - - [num_groups, tokens_per_group, num_experts] probabilities for each token and expert. Used for routing - tokens to experts. - - [num_groups, tokens_per_group, num_experts] raw router logits. Used for computing router z-loss. +class RouterWeights(nn.Module): + """Router module converting token inputs to router logits. + + Attributes: + use_bias: Whether or not to use the bias term in computing the logits. + dtype: Numerical float type for router logit computation. + kernel_init: Initialization scheme for kernel. + bias_init: Initialization scheme for bias. + precision: XLA precision for array computations. + axis: Axes along which to apply the dense router weights transformation. + Defaults to final axis (typically the "hidden dimension"). + kernel_axis_names: Logical axis names to use for kernel sharding. + reshape_kernel: Whether to reshape the kernel parameter to 2D for Adafactor. """ - # For remainder of routing computation we use float32 to ensure stability. - # See the discussion of "selective precision" in - # https://arxiv.org/abs/2101.03961. - token_inputs = jax.lax.convert_element_type(token_inputs, jnp.float32) - - if apply_jitter and self.jitter_noise > 0: - token_inputs *= jax.random.uniform( - self.make_rng('jitter'), - token_inputs.shape, - token_inputs.dtype, - minval=1.0 - self.jitter_noise, - maxval=1.0 + self.jitter_noise) - # Shape: [num_groups, tokens_per_group, num_experts] - router_logits = self.router_weights(token_inputs, num_experts) + use_bias: bool = True + dtype: DType = jnp.bfloat16 + kernel_init: Initializer = default_kernel_init + bias_init: Initializer = default_bias_init + precision: jax.lax.Precision = jax.lax.Precision.DEFAULT + axis: Union[Iterable[int], int] = -1 + kernel_axis_names: Sequence[str] = ("embed", "unmodeled") + reshape_kernel: bool = True + + @nn.compact + def __call__(self, token_inputs: Array, num_experts: int) -> Array: + """Applies RouterWeights module. + + Args: + token_inputs: Flattened batch of tokens with shape [num_groups, + group_size, hidden_dim]. + num_experts: Number of experts. + + Returns: + Router logits with shape [num_groups, group_size, num_experts]. + """ + # Flax code for reference + # return dense.DenseGeneral( + # features=num_experts, + # axis=self.axis, + # use_bias=self.use_bias, + # dtype=self.dtype, + # kernel_init=self.kernel_init, + # bias_init=self.bias_init, + # precision=self.precision, + # kernel_axis_names=self.kernel_axis_names, + # reshape_kernel=self.reshape_kernel, + # name="w", + # )(token_inputs) + pass - router_probabilities = jax.nn.softmax(router_logits, axis=-1) - return router_probabilities, router_logits +class Router(nn.Module): + """Abstract base router class, defining router API and inner workings. + + Attributes: + router_weights: Configurable module used to compute router logits from token + inputs. + jitter_noise: Amplitude of jitter noise applied to router logits. + dtype: Numeric float type for returned combine array. All actual + computations are performed in float32 of the input for stability. + ignore_padding_tokens: Whether to ignore padding tokens during routing. Note + that some routers (e.g. TokensChooseMaskedRouter) will completely ignore padding tokens, while others (e.g. + TokensChooseScatterRouter and ExpertsChooseMaskedRouter) will simply down-weight the probability of selecting + padding tokens. + """ - def _compute_routing_instructions(self, router_probs: Array, - padding_mask: Optional[Array], - expert_capacity: int) -> RouterOutput: - """Computes instructions for routing inputs to experts.""" - raise NotImplementedError( - 'Router is an abstract class that should be subclassed.') + router_weights: RouterWeights + jitter_noise: float + dtype: jnp.dtype + ignore_padding_tokens: bool + + def __call__( + self, token_inputs: Array, num_experts: int, expert_capacity: int, apply_jitter: bool = True + ) -> RouterOutput: + """Computes dispatch and combine arrays for routing to experts. + + Args: + token_inputs: [num_groups, tokens_per_group, hidden_dim] inputs to + send to experts. + num_experts: Number of experts. + expert_capacity: Each group will send this many tokens to each expert. + apply_jitter: If true, apply jitter noise during routing. + + Returns: + Router indices or mask arrays (depending on router type). + """ + token_inputs = flax_partitioning.with_sharding_constraint(token_inputs, ("batch", "length", "embed")) + router_probs, router_logits = self._compute_router_probabilities(token_inputs, num_experts, apply_jitter) + router_probs = flax_partitioning.with_sharding_constraint(router_probs, ("batch", "length", "unmodeled")) + router_logits = flax_partitioning.with_sharding_constraint(router_logits, ("batch", "length", "unmodeled")) + + if self.ignore_padding_tokens: + # To identify non-padding tokens, we rely on the fact that padding tokens + # in the inputs have already been masked in the default T5 architecture. + # See + # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L315 + # and + # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L603. + padding_mask = jnp.array((jnp.sum(jnp.abs(token_inputs), axis=-1) > 0), dtype=token_inputs.dtype) + router_logits *= jnp.expand_dims(padding_mask, axis=-1) + else: + padding_mask = None + + instructions = self._compute_routing_instructions(router_probs, padding_mask, expert_capacity) + + return instructions.replace(router_z_loss=_router_z_loss(router_logits)) + + def _compute_router_probabilities( + self, token_inputs: Array, num_experts: int, apply_jitter: bool + ) -> Tuple[Array, Array]: + """Computes router probabilities from input tokens. + + Args: + token_inputs: [num_groups, tokens_per_group, hidden_dim] from which + router probabilities are computed. + num_experts: Number of experts. + apply_jitter: If true, apply jitter noise. + + Returns: + - [num_groups, tokens_per_group, num_experts] probabilities for each token and expert. Used for + routing tokens to experts. + - [num_groups, tokens_per_group, num_experts] raw router logits. Used for computing router z-loss. + """ + # For remainder of routing computation we use float32 to ensure stability. + # See the discussion of "selective precision" in + # https://arxiv.org/abs/2101.03961. + token_inputs = jax.lax.convert_element_type(token_inputs, jnp.float32) + + if apply_jitter and self.jitter_noise > 0: + token_inputs *= jax.random.uniform( + self.make_rng("jitter"), + token_inputs.shape, + token_inputs.dtype, + minval=1.0 - self.jitter_noise, + maxval=1.0 + self.jitter_noise, + ) + + # Shape: [num_groups, tokens_per_group, num_experts] + router_logits = self.router_weights(token_inputs, num_experts) + + router_probabilities = jax.nn.softmax(router_logits, axis=-1) + + return router_probabilities, router_logits + + def _compute_routing_instructions( + self, router_probs: Array, padding_mask: Optional[Array], expert_capacity: int + ) -> RouterOutput: + """Computes instructions for routing inputs to experts.""" + raise NotImplementedError("Router is an abstract class that should be subclassed.") class ScatterRouter(Router): - """Abstract base router class for scatter dispatch routers. + """Abstract base router class for scatter dispatch routers. - ScatterRouter(s) return RouterIndices containing dispatch indices and combine weights for sending token inputs (via - scatter) and receiving outputs (via gather) to and from experts. + ScatterRouter(s) return RouterIndices containing dispatch indices and combine weights for sending token inputs (via + scatter) and receiving outputs (via gather) to and from experts. - Scatter-based routing is generally faster than masked matmul routing on CPUs and GPUs. - """ + Scatter-based routing is generally faster than masked matmul routing on CPUs and GPUs. + """ - def _compute_routing_instructions(self, router_probs: Array, - padding_mask: Optional[Array], - expert_capacity: int) -> RouterIndices: - """Computes instructions for routing inputs to experts. + def _compute_routing_instructions( + self, router_probs: Array, padding_mask: Optional[Array], expert_capacity: int + ) -> RouterIndices: + """Computes instructions for routing inputs to experts. - Args: - router_probs: [num_groups, tokens_per_group, num_experts] - probabilities used to determine the routing of tokens to the experts. - padding_mask: [num_groups, tokens_per_group] padding logit mask - used to identify padding tokens that should be ignored by the router. - expert_capacity: Each group will send this many tokens to each expert. + Args: + router_probs: [num_groups, tokens_per_group, num_experts] + probabilities used to determine the routing of tokens to the experts. + padding_mask: [num_groups, tokens_per_group] padding logit mask + used to identify padding tokens that should be ignored by the router. + expert_capacity: Each group will send this many tokens to each expert. - Returns: - Router indices containing dispatch indices and combine weights. - """ - raise NotImplementedError( - 'ScatterRouter is an abstract class that should be subclassed.') + Returns: + Router indices containing dispatch indices and combine weights. + """ + raise NotImplementedError("ScatterRouter is an abstract class that should be subclassed.") class MaskedRouter(Router): - """Abstract base router class for masked matmul dispatch routers. + """Abstract base router class for masked matmul dispatch routers. - MaskedRouter(s) return RouterMask(s) containing a dispatch mask and combine array for sending and receiving (via - masked matmuls) inputs and outputs to and from experts. + MaskedRouter(s) return RouterMask(s) containing a dispatch mask and combine array for sending and receiving (via + masked matmuls) inputs and outputs to and from experts. - Routing using masked matmuls is generally faster than scatter-based routing on TPUs. - """ + Routing using masked matmuls is generally faster than scatter-based routing on TPUs. + """ - def _compute_routing_instructions(self, router_probs: Array, - padding_mask: Optional[Array], - expert_capacity: int) -> RouterMask: - """Computes masks for the top-k experts per token. + def _compute_routing_instructions( + self, router_probs: Array, padding_mask: Optional[Array], expert_capacity: int + ) -> RouterMask: + """Computes masks for the top-k experts per token. - Args: - router_probs: [num_groups, tokens_per_group, num_experts] - probabilities used to determine the routing of tokens to the experts. - padding_mask: [num_groups, tokens_per_group] padding logit mask - used to identify padding tokens that should be ignored by the router. - expert_capacity: Each group will send this many tokens to each expert. + Args: + router_probs: [num_groups, tokens_per_group, num_experts] + probabilities used to determine the routing of tokens to the experts. + padding_mask: [num_groups, tokens_per_group] padding logit mask + used to identify padding tokens that should be ignored by the router. + expert_capacity: Each group will send this many tokens to each expert. - Returns: - Router mask arrays. - """ - raise NotImplementedError( - 'MaskedRouter is an abstract class that should be subclassed.') + Returns: + Router mask arrays. + """ + raise NotImplementedError("MaskedRouter is an abstract class that should be subclassed.") class TokensChooseScatterRouter(ScatterRouter): - """Scatter router using tokens choose top-k experts assignment. - - This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE - (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then routed - to their choice of expert until the expert's expert_capacity is reached. There is no guarantee that each token is - processed by an expert, or that each expert receives at least one token. - - Attributes: - num_selected_experts: Maximum number of experts to which each token is - routed. Tokens may be routed to fewer experts if particular experts are oversubscribed / reach capacity. - batch_prioritized_routing: Whether or not to use Batch Prioritized Routing - (BPR), originally introduced in V-MoE (https://arxiv.org/abs/2106.05974). - With BPR, we prioritize routing those top-k tokens with the highest router probability, rather than simply - using each tokens left-to-right ordering in the batch. This prioritization is important because the expert's - have limited capacity. - """ - num_selected_experts: int - batch_prioritized_routing: bool - - def _compute_routing_instructions(self, router_probs: Array, - padding_mask: Optional[Array], - expert_capacity: int) -> RouterIndices: - """Computes dispatch indices and combine weights for the top-k experts. - - Args: - router_probs: [num_groups, tokens_per_group, num_experts] - probabilities used to determine the routing of tokens to the experts. - padding_mask: [num_groups, tokens_per_group] padding logit mask - used to identify padding tokens that should be down-weighted by the router. - expert_capacity: Each group will send this many tokens to each expert. - - Returns: - Dispatch indices and combine weights for scatter/gather-based routing. + """Scatter router using tokens choose top-k experts assignment. + + This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE + (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then + routed to their choice of expert until the expert's expert_capacity is reached. There is no guarantee that each + token is processed by an expert, or that each expert receives at least one token. + + Attributes: + num_selected_experts: Maximum number of experts to which each token is + routed. Tokens may be routed to fewer experts if particular experts are oversubscribed / reach capacity. + batch_prioritized_routing: Whether or not to use Batch Prioritized Routing + (BPR), originally introduced in V-MoE (https://arxiv.org/abs/2106.05974). + With BPR, we prioritize routing those top-k tokens with the highest router probability, rather than simply + using each tokens left-to-right ordering in the batch. This prioritization is important because the expert's + have limited capacity. """ - num_groups, tokens_per_group, num_experts = router_probs.shape - - if padding_mask is not None: - # Because `expert_indices` are directly used for scatter-based routing, we - # mask probabilities corresponding to tokens before the top-k operation. - # Note that, unlike for mask-based tokens-choose routing, the - # (down-weighted) padding tokens may still be selected. - router_probs *= jnp.expand_dims(padding_mask, axis=-1) - - # Top-k router probability and corresponding expert indices for each token. - # Shape: [num_groups, tokens_per_group, num_selected_experts]. - combine_weights, expert_indices = _top_k( - router_probs, k=self.num_selected_experts) - - auxiliary_loss = _load_balancing_loss(router_probs, expert_indices) - - if self.batch_prioritized_routing: - # Sort tokens according to their routing probability per token group, so - # that the highest probability tokens are routed first. - token_ordering = jnp.argsort(-combine_weights[..., 0], axis=-1) - expert_indices = _take_along_axis( - expert_indices, jnp.expand_dims(token_ordering, axis=-1), axis=-2) - - # Identify each token's preferred expert. - # Make num_selected_experts the leading axis to ensure that top-1 choices - # have priority over top-2 choices, which have priority over top-3 - # choices... - preferred_experts = jnp.swapaxes(expert_indices, 1, 2) - # Shape: [num_groups, num_selected_experts * tokens_per_group] - preferred_experts = preferred_experts.reshape(num_groups, -1) - - # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. - expert_mask = jax.nn.one_hot( - preferred_experts, num_experts, dtype=jnp.int32) - - # Experts have a fixed capacity that we cannot exceed. A token's priority - # within the expert's buffer is given by the masked, cumulative capacity of - # its target expert. - # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. - token_priority = jnp.cumsum(expert_mask, axis=1) * expert_mask - 1.0 - # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts]. - token_priority = token_priority.reshape( - (num_groups, self.num_selected_experts, -1, num_experts)) - # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. - token_priority = jnp.swapaxes(token_priority, 1, 2) - # For each token, across all experts, select the only non-negative - # (unmasked) priority. Shape: [num_groups, tokens_per_group, - # num_selected_experts]. - token_priority = jnp.max(token_priority, axis=-1) - - # Return to original index shape. - preferred_experts = preferred_experts.reshape(num_groups, - self.num_selected_experts, - tokens_per_group) - # Shape: [num_groups, tokens_per_group, num_selected_experts] - preferred_experts = jnp.swapaxes(preferred_experts, 1, 2) - - if self.batch_prioritized_routing: - # Place tokens in their original ordering. - inverse_token_ordering = jnp.argsort(token_ordering, axis=-1) - preferred_experts = _take_along_axis( - preferred_experts, - jnp.expand_dims(inverse_token_ordering, axis=-1), - axis=-2) - token_priority = _take_along_axis( - token_priority, - jnp.expand_dims(inverse_token_ordering, axis=-1), - axis=-2) - - # Mask out tokens that overflow the maximum expert capacities. - # Shape: [num_groups, tokens_per_group, num_selected_experts]. - combine_weights *= token_priority < expert_capacity - - # Expert index and priority within the expert capacity buffer. - # Shape: [num_groups, tokens_per_group, num_selected_experts, 2]. - dispatch_indices = jnp.stack([preferred_experts, token_priority], axis=-1) - - # Return to default dtype now that router computation is complete. - combine_weights = jax.lax.convert_element_type(combine_weights, self.dtype) - dispatch_indices = jax.lax.convert_element_type(dispatch_indices, jnp.int32) - - return RouterIndices(dispatch_indices, combine_weights, auxiliary_loss) + num_selected_experts: int + batch_prioritized_routing: bool + + def _compute_routing_instructions( + self, router_probs: Array, padding_mask: Optional[Array], expert_capacity: int + ) -> RouterIndices: + """Computes dispatch indices and combine weights for the top-k experts. + + Args: + router_probs: [num_groups, tokens_per_group, num_experts] + probabilities used to determine the routing of tokens to the experts. + padding_mask: [num_groups, tokens_per_group] padding logit mask + used to identify padding tokens that should be down-weighted by the router. + expert_capacity: Each group will send this many tokens to each expert. + + Returns: + Dispatch indices and combine weights for scatter/gather-based routing. + """ + num_groups, tokens_per_group, num_experts = router_probs.shape + + if padding_mask is not None: + # Because `expert_indices` are directly used for scatter-based routing, we + # mask probabilities corresponding to tokens before the top-k operation. + # Note that, unlike for mask-based tokens-choose routing, the + # (down-weighted) padding tokens may still be selected. + router_probs *= jnp.expand_dims(padding_mask, axis=-1) + + # Top-k router probability and corresponding expert indices for each token. + # Shape: [num_groups, tokens_per_group, num_selected_experts]. + combine_weights, expert_indices = _top_k(router_probs, k=self.num_selected_experts) + + auxiliary_loss = _load_balancing_loss(router_probs, expert_indices) + + if self.batch_prioritized_routing: + # Sort tokens according to their routing probability per token group, so + # that the highest probability tokens are routed first. + token_ordering = jnp.argsort(-combine_weights[..., 0], axis=-1) + expert_indices = _take_along_axis(expert_indices, jnp.expand_dims(token_ordering, axis=-1), axis=-2) + + # Identify each token's preferred expert. + # Make num_selected_experts the leading axis to ensure that top-1 choices + # have priority over top-2 choices, which have priority over top-3 + # choices... + preferred_experts = jnp.swapaxes(expert_indices, 1, 2) + # Shape: [num_groups, num_selected_experts * tokens_per_group] + preferred_experts = preferred_experts.reshape(num_groups, -1) + + # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. + expert_mask = jax.nn.one_hot(preferred_experts, num_experts, dtype=jnp.int32) + + # Experts have a fixed capacity that we cannot exceed. A token's priority + # within the expert's buffer is given by the masked, cumulative capacity of + # its target expert. + # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. + token_priority = jnp.cumsum(expert_mask, axis=1) * expert_mask - 1.0 + # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts]. + token_priority = token_priority.reshape((num_groups, self.num_selected_experts, -1, num_experts)) + # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. + token_priority = jnp.swapaxes(token_priority, 1, 2) + # For each token, across all experts, select the only non-negative + # (unmasked) priority. Shape: [num_groups, tokens_per_group, + # num_selected_experts]. + token_priority = jnp.max(token_priority, axis=-1) + + # Return to original index shape. + preferred_experts = preferred_experts.reshape(num_groups, self.num_selected_experts, tokens_per_group) + # Shape: [num_groups, tokens_per_group, num_selected_experts] + preferred_experts = jnp.swapaxes(preferred_experts, 1, 2) + + if self.batch_prioritized_routing: + # Place tokens in their original ordering. + inverse_token_ordering = jnp.argsort(token_ordering, axis=-1) + preferred_experts = _take_along_axis( + preferred_experts, jnp.expand_dims(inverse_token_ordering, axis=-1), axis=-2 + ) + token_priority = _take_along_axis( + token_priority, jnp.expand_dims(inverse_token_ordering, axis=-1), axis=-2 + ) + + # Mask out tokens that overflow the maximum expert capacities. + # Shape: [num_groups, tokens_per_group, num_selected_experts]. + combine_weights *= token_priority < expert_capacity + + # Expert index and priority within the expert capacity buffer. + # Shape: [num_groups, tokens_per_group, num_selected_experts, 2]. + dispatch_indices = jnp.stack([preferred_experts, token_priority], axis=-1) + + # Return to default dtype now that router computation is complete. + combine_weights = jax.lax.convert_element_type(combine_weights, self.dtype) + dispatch_indices = jax.lax.convert_element_type(dispatch_indices, jnp.int32) + + return RouterIndices(dispatch_indices, combine_weights, auxiliary_loss) -class TokensChooseMaskedRouter(MaskedRouter): - """Masked matmul router using tokens choose top-k experts assignment. - - This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE - (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then routed - to their choice of expert until the expert's expert_capacity is reached. There is no guarantee that each token is - processed by an expert, or that each expert receives at least one token. - - Attributes: - num_selected_experts: Maximum number of experts to which each token is - routed. Tokens may be routed to fewer experts if particular experts are oversubscribed / reach capacity. - batch_prioritized_routing: Whether or not to use Batch Prioritized Routing - (BPR), originally introduced in V-MoE (https://arxiv.org/abs/2106.05974). - With BPR, we prioritize routing those top-k tokens with the highest router probability, rather than simply - using each tokens left-to-right ordering in the batch. This prioritization is important because the experts - have limited capacity. - """ - num_selected_experts: int - batch_prioritized_routing: bool - - def _compute_routing_instructions(self, router_probs: Array, - padding_mask: Optional[Array], - expert_capacity: int) -> RouterMask: - """Computes masks for the top-k experts per token. - Args: - router_probs: [num_groups, tokens_per_group, num_experts] - probabilities used to determine the routing of tokens to the experts. - padding_mask: [num_groups, tokens_per_group] padding logit mask - used to identify padding tokens that should be ignored by the router. - expert_capacity: Each group will send this many tokens to each expert. - - Returns: - Dispatch and combine arrays for routing with masked matmuls. +class TokensChooseMaskedRouter(MaskedRouter): + """Masked matmul router using tokens choose top-k experts assignment. + + This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE + (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then + routed to their choice of expert until the expert's expert_capacity is reached. There is no guarantee that each + token is processed by an expert, or that each expert receives at least one token. + + Attributes: + num_selected_experts: Maximum number of experts to which each token is + routed. Tokens may be routed to fewer experts if particular experts are oversubscribed / reach capacity. + batch_prioritized_routing: Whether or not to use Batch Prioritized Routing + (BPR), originally introduced in V-MoE (https://arxiv.org/abs/2106.05974). + With BPR, we prioritize routing those top-k tokens with the highest router probability, rather than simply + using each tokens left-to-right ordering in the batch. This prioritization is important because the experts + have limited capacity. """ - num_groups, _, num_experts = router_probs.shape - - # Top-k router probability and corresponding expert indices for each token. - # Shape: [num_groups, tokens_per_group, num_selected_experts]. - expert_gate, expert_index = _top_k( - router_probs, k=self.num_selected_experts) - - if padding_mask is not None: - # Mask applied to gate. Exclude choices corresponding to padding tokens. - gate_mask = jnp.expand_dims(padding_mask, axis=-1) - expert_gate *= gate_mask - - # Set `expert_index` elements corresponding to padding to negative - # numbers. Negative `expert_index` elements will ultimately be dropped in - # the one_hot conversion to the `expert_mask`. - # First convert nonzero padding elements to negative values. - expert_index *= 2 * gate_mask - 1. - # Handle zero padding elements by negatively shifting all padding. - expert_index += jnp.repeat( - gate_mask - 1., self.num_selected_experts, axis=-1) - - # To correctly compute load balancing loss, we also mask out probs. - router_probs *= gate_mask - - auxiliary_loss = _load_balancing_loss(router_probs, expert_index) - - if self.batch_prioritized_routing: - # Sort tokens according to their routing probability per group, so that - # the highest probability tokens are routed first. - permutation = jnp.argsort(-expert_gate[..., 0], axis=-1) - # Shape: [num_groups, tokens_per_group, num_selected_experts] - expert_index = _take_along_axis( - expert_index, jnp.expand_dims(permutation, axis=-1), axis=-2) - - # Make num_selected_experts the leading axis to ensure that top-1 choices - # have priority over top-2 choices, which have priority over top-3 choices, - # etc. - expert_index = jnp.swapaxes(expert_index, 1, 2) - # Shape: [num_groups, num_selected_experts * tokens_per_group] - expert_index = expert_index.reshape(num_groups, -1) - - # Create mask out of indices. - # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. - expert_mask = jax.nn.one_hot(expert_index, num_experts, dtype=jnp.int32) - - # Experts have a fixed capacity that we cannot exceed. A token's priority - # within the expert's buffer is given by the masked, cumulative capacity of - # its target expert. - # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. - token_priority = jnp.cumsum(expert_mask, axis=1) * expert_mask - 1.0 - # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts]. - token_priority = token_priority.reshape( - (num_groups, self.num_selected_experts, -1, num_experts)) - # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. - token_priority = jnp.swapaxes(token_priority, 1, 2) - # For each token, across all selected experts, select the only non-negative - # (unmasked) priority. Now, for group G routing to expert E, token T has - # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E - # is its targeted expert. - # Shape: [num_groups, tokens_per_group, num_experts]. - token_priority = jnp.max(token_priority, axis=2) - - if self.batch_prioritized_routing: - # Place token priorities in original ordering of tokens. - inv_permutation = jnp.argsort(permutation, axis=-1) - token_priority = _take_along_axis( - token_priority, jnp.expand_dims(inv_permutation, axis=-1), axis=-2) - - # Token T can only be routed to expert E if its priority is positive and - # less than the expert capacity. One-hot matrix will ignore indices outside - # the range [0, expert_capacity). - # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]. - dispatch_mask = jax.nn.one_hot( - token_priority, expert_capacity, dtype=jnp.bool_) - - # The combine array will be used for combining expert outputs, scaled by the - # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, - # expert_capacity]. - combine_array = jnp.einsum( - '...te,...tec->...tec', - router_probs, - dispatch_mask, - precision=jax.lax.Precision.DEFAULT) - - # Return to default dtype now that router computation is complete. - combine_array = jax.lax.convert_element_type(combine_array, self.dtype) - - return RouterMask(dispatch_mask, combine_array, auxiliary_loss) + + num_selected_experts: int + batch_prioritized_routing: bool + + def _compute_routing_instructions( + self, router_probs: Array, padding_mask: Optional[Array], expert_capacity: int + ) -> RouterMask: + """Computes masks for the top-k experts per token. + + Args: + router_probs: [num_groups, tokens_per_group, num_experts] + probabilities used to determine the routing of tokens to the experts. + padding_mask: [num_groups, tokens_per_group] padding logit mask + used to identify padding tokens that should be ignored by the router. + expert_capacity: Each group will send this many tokens to each expert. + + Returns: + Dispatch and combine arrays for routing with masked matmuls. + """ + num_groups, _, num_experts = router_probs.shape + + # Top-k router probability and corresponding expert indices for each token. + # Shape: [num_groups, tokens_per_group, num_selected_experts]. + expert_gate, expert_index = _top_k(router_probs, k=self.num_selected_experts) + + if padding_mask is not None: + # Mask applied to gate. Exclude choices corresponding to padding tokens. + gate_mask = jnp.expand_dims(padding_mask, axis=-1) + expert_gate *= gate_mask + + # Set `expert_index` elements corresponding to padding to negative + # numbers. Negative `expert_index` elements will ultimately be dropped in + # the one_hot conversion to the `expert_mask`. + # First convert nonzero padding elements to negative values. + expert_index *= 2 * gate_mask - 1.0 + # Handle zero padding elements by negatively shifting all padding. + expert_index += jnp.repeat(gate_mask - 1.0, self.num_selected_experts, axis=-1) + + # To correctly compute load balancing loss, we also mask out probs. + router_probs *= gate_mask + + auxiliary_loss = _load_balancing_loss(router_probs, expert_index) + + if self.batch_prioritized_routing: + # Sort tokens according to their routing probability per group, so that + # the highest probability tokens are routed first. + permutation = jnp.argsort(-expert_gate[..., 0], axis=-1) + # Shape: [num_groups, tokens_per_group, num_selected_experts] + expert_index = _take_along_axis(expert_index, jnp.expand_dims(permutation, axis=-1), axis=-2) + + # Make num_selected_experts the leading axis to ensure that top-1 choices + # have priority over top-2 choices, which have priority over top-3 choices, + # etc. + expert_index = jnp.swapaxes(expert_index, 1, 2) + # Shape: [num_groups, num_selected_experts * tokens_per_group] + expert_index = expert_index.reshape(num_groups, -1) + + # Create mask out of indices. + # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. + expert_mask = jax.nn.one_hot(expert_index, num_experts, dtype=jnp.int32) + + # Experts have a fixed capacity that we cannot exceed. A token's priority + # within the expert's buffer is given by the masked, cumulative capacity of + # its target expert. + # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. + token_priority = jnp.cumsum(expert_mask, axis=1) * expert_mask - 1.0 + # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts]. + token_priority = token_priority.reshape((num_groups, self.num_selected_experts, -1, num_experts)) + # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. + token_priority = jnp.swapaxes(token_priority, 1, 2) + # For each token, across all selected experts, select the only non-negative + # (unmasked) priority. Now, for group G routing to expert E, token T has + # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E + # is its targeted expert. + # Shape: [num_groups, tokens_per_group, num_experts]. + token_priority = jnp.max(token_priority, axis=2) + + if self.batch_prioritized_routing: + # Place token priorities in original ordering of tokens. + inv_permutation = jnp.argsort(permutation, axis=-1) + token_priority = _take_along_axis(token_priority, jnp.expand_dims(inv_permutation, axis=-1), axis=-2) + + # Token T can only be routed to expert E if its priority is positive and + # less than the expert capacity. One-hot matrix will ignore indices outside + # the range [0, expert_capacity). + # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]. + dispatch_mask = jax.nn.one_hot(token_priority, expert_capacity, dtype=jnp.bool_) + + # The combine array will be used for combining expert outputs, scaled by the + # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, + # expert_capacity]. + combine_array = jnp.einsum( + "...te,...tec->...tec", router_probs, dispatch_mask, precision=jax.lax.Precision.DEFAULT + ) + + # Return to default dtype now that router computation is complete. + combine_array = jax.lax.convert_element_type(combine_array, self.dtype) + + return RouterMask(dispatch_mask, combine_array, auxiliary_loss) class ExpertsChooseMaskedRouter(MaskedRouter): - """Masked matmul router using experts choose tokens assignment. + """Masked matmul router using experts choose tokens assignment. - This router uses the same mechanism as in Mixture-of-Experts with Expert Choice (https://arxiv.org/abs/2202.09368): - each expert selects its top expert_capacity tokens. An individual token may be processed by multiple experts or none - at all. + This router uses the same mechanism as in Mixture-of-Experts with Expert Choice (https://arxiv.org/abs/2202.09368): + each expert selects its top expert_capacity tokens. An individual token may be processed by multiple experts or + none at all. - Note: "experts choose routing" should not be used in decoder blocks because it breaks the autoregressive behavior -- - the model will learn to cheat by using future token information to improve current token predictions. - """ + Note: "experts choose routing" should not be used in decoder blocks because it breaks the autoregressive behavior + -- the model will learn to cheat by using future token information to improve current token predictions. + """ - def _compute_routing_instructions(self, router_probs: Array, - padding_mask: Optional[Array], - expert_capacity: int) -> RouterMask: - """Computes masks for the highest probability token per expert. + def _compute_routing_instructions( + self, router_probs: Array, padding_mask: Optional[Array], expert_capacity: int + ) -> RouterMask: + """Computes masks for the highest probability token per expert. - Args: - router_probs: [num_groups, tokens_per_group, num_experts] - probabilities used to determine the routing of tokens to the experts. - padding_mask: [num_groups, tokens_per_group] padding logit mask - used to identify padding tokens that should be down-weighted by the router. - expert_capacity: Each group will send this many tokens to each expert. + Args: + router_probs: [num_groups, tokens_per_group, num_experts] + probabilities used to determine the routing of tokens to the experts. + padding_mask: [num_groups, tokens_per_group] padding logit mask + used to identify padding tokens that should be down-weighted by the router. + expert_capacity: Each group will send this many tokens to each expert. - Returns: - Dispatch and combine arrays for routing with masked matmuls. - """ - tokens_per_group = router_probs.shape[1] + Returns: + Dispatch and combine arrays for routing with masked matmuls. + """ + tokens_per_group = router_probs.shape[1] - if padding_mask is not None: - # Because experts choose tokens, we mask probabilities corresponding to - # tokens before the top-k operation. Note that, unlike for masked-based - # tokens-choose routing, the experts here may still choose to select the - # (down-weighted) padding tokens. - router_probs *= jnp.expand_dims(padding_mask, axis=-1) + if padding_mask is not None: + # Because experts choose tokens, we mask probabilities corresponding to + # tokens before the top-k operation. Note that, unlike for masked-based + # tokens-choose routing, the experts here may still choose to select the + # (down-weighted) padding tokens. + router_probs *= jnp.expand_dims(padding_mask, axis=-1) - # vmap over group dimension. - router_probs_t = jax.vmap(lambda m: m.transpose())(router_probs) + # vmap over group dimension. + router_probs_t = jax.vmap(lambda m: m.transpose())(router_probs) - # Top expert_capacity router probability and corresponding token indices for - # each expert. Shapes: [num_groups, num_experts, expert_capacity]. - expert_gate, expert_index = _top_k(router_probs_t, k=expert_capacity) + # Top expert_capacity router probability and corresponding token indices for + # each expert. Shapes: [num_groups, num_experts, expert_capacity]. + expert_gate, expert_index = _top_k(router_probs_t, k=expert_capacity) - # Convert to one-hot mask of expert indices for each token in each group. - # Shape: [num_groups, num_experts, expert_capacity, tokens_per_group]. - dispatch_mask = jax.nn.one_hot( - expert_index, tokens_per_group, dtype=jnp.int32) + # Convert to one-hot mask of expert indices for each token in each group. + # Shape: [num_groups, num_experts, expert_capacity, tokens_per_group]. + dispatch_mask = jax.nn.one_hot(expert_index, tokens_per_group, dtype=jnp.int32) - # Move axes to conform with shape expected by MoeLayer API. - # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity] - dispatch_mask = jnp.moveaxis(dispatch_mask, 3, 1) + # Move axes to conform with shape expected by MoeLayer API. + # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity] + dispatch_mask = jnp.moveaxis(dispatch_mask, 3, 1) - # The combine array will be used for combining expert outputs, scaled by the - # router probabilities. Shape: [num_groups, num_experts, tokens_per_group, - # expert_capacity]. - combine_array = jnp.einsum( - '...ec,...tec->...tec', - expert_gate, - dispatch_mask, - precision=jax.lax.Precision.DEFAULT) + # The combine array will be used for combining expert outputs, scaled by the + # router probabilities. Shape: [num_groups, num_experts, tokens_per_group, + # expert_capacity]. + combine_array = jnp.einsum( + "...ec,...tec->...tec", expert_gate, dispatch_mask, precision=jax.lax.Precision.DEFAULT + ) - # Return to default dtype now that router computation is complete. - combine_array = jax.lax.convert_element_type(combine_array, self.dtype) + # Return to default dtype now that router computation is complete. + combine_array = jax.lax.convert_element_type(combine_array, self.dtype) - # Each expert is choosing tokens until it reaches full capacity, so we don't - # need an auxiliary loading balancing loss for expert choice routing. - auxiliary_loss = 0.0 + # Each expert is choosing tokens until it reaches full capacity, so we don't + # need an auxiliary loading balancing loss for expert choice routing. + auxiliary_loss = 0.0 - return RouterMask(dispatch_mask, combine_array, auxiliary_loss) + return RouterMask(dispatch_mask, combine_array, auxiliary_loss) def _load_balancing_loss(router_probs: Array, expert_indices: Array) -> float: - """Computes auxiliary load balancing loss as in Switch Transformer. + """Computes auxiliary load balancing loss as in Switch Transformer. - See Switch Transformer (https://arxiv.org/abs/2101.03961). This function implements the loss function presented in - equations (4) - (6). It aims to penalize those cases where the routing between experts is unbalanced. + See Switch Transformer (https://arxiv.org/abs/2101.03961). This function implements the loss function presented in + equations (4) - (6). It aims to penalize those cases where the routing between experts is unbalanced. - Args: - router_probs: Probability assigned to each expert per token. Shape: - [num_groups, tokens_per_group, num_experts]. - expert_indices: [num_groups, tokens_per_group, num_selected_experts] - indices identifying the top num_selected_experts for a given token. + Args: + router_probs: Probability assigned to each expert per token. Shape: + [num_groups, tokens_per_group, num_experts]. + expert_indices: [num_groups, tokens_per_group, num_selected_experts] + indices identifying the top num_selected_experts for a given token. - Returns: - The auxiliary loss. - """ - num_experts = router_probs.shape[-1] + Returns: + The auxiliary loss. + """ + num_experts = router_probs.shape[-1] - # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. - expert_mask = jax.nn.one_hot(expert_indices, num_experts, dtype=jnp.int32) - # For a given token, determine if it was routed to a given expert. - # Shape: [num_groups, tokens_per_group, num_experts] - expert_mask = jnp.max(expert_mask, axis=-2) + # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. + expert_mask = jax.nn.one_hot(expert_indices, num_experts, dtype=jnp.int32) + # For a given token, determine if it was routed to a given expert. + # Shape: [num_groups, tokens_per_group, num_experts] + expert_mask = jnp.max(expert_mask, axis=-2) - tokens_per_group_and_expert = jnp.mean( - expert_mask, dtype=jnp.float32, axis=-2) - router_prob_per_group_and_expert = jnp.mean( - router_probs, dtype=jnp.float32, axis=-2) - return jnp.mean( - tokens_per_group_and_expert * router_prob_per_group_and_expert, - dtype=jnp.float32) * num_experts**2 + tokens_per_group_and_expert = jnp.mean(expert_mask, dtype=jnp.float32, axis=-2) + router_prob_per_group_and_expert = jnp.mean(router_probs, dtype=jnp.float32, axis=-2) + return ( + jnp.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert, dtype=jnp.float32) * num_experts**2 + ) def _router_z_loss(router_logits: Array) -> float: - """Compute router z-loss. + """Compute router z-loss. - The router z-loss was introduced in Designing Effective Sparse Expert Models (https://arxiv.org/abs/2202.08906). It - encourages router logits to remain small in an effort to improve stability. + The router z-loss was introduced in Designing Effective Sparse Expert Models (https://arxiv.org/abs/2202.08906). + It encourages router logits to remain small in an effort to improve stability. - Args: - router_logits: [num_groups, tokens_per_group, num_experts] router - logits. + Args: + router_logits: [num_groups, tokens_per_group, num_experts] router + logits. - Returns: - Scalar router z-loss. - """ - num_groups, tokens_per_group, _ = router_logits.shape - log_z = jax.nn.logsumexp(router_logits, axis=-1) - z_loss = log_z**2 - return jnp.sum(z_loss, dtype=jnp.float32) / (num_groups * tokens_per_group) + Returns: + Scalar router z-loss. + """ + num_groups, tokens_per_group, _ = router_logits.shape + log_z = jax.nn.logsumexp(router_logits, axis=-1) + z_loss = log_z**2 + return jnp.sum(z_loss, dtype=jnp.float32) / (num_groups * tokens_per_group) num_tokens = 5 @@ -751,9 +724,7 @@ def _router_z_loss(router_logits: Array) -> float: num_selected_experts = 1 rng = jax.random.PRNGKey(0) -router_probs = jax.random.uniform( - rng, (num_tokens, num_experts), minval=0, maxval=1) -expert_indices = jax.random.randint( - rng, (num_tokens, num_selected_experts), minval=0, maxval=2) +router_probs = jax.random.uniform(rng, (num_tokens, num_experts), minval=0, maxval=1) +expert_indices = jax.random.randint(rng, (num_tokens, num_selected_experts), minval=0, maxval=2) -loss = _load_balancing_loss(router_probs, expert_indices) \ No newline at end of file +loss = _load_balancing_loss(router_probs, expert_indices) diff --git a/tests/models/switchtransformers/test_modeling_tf_switchtransformers.py b/tests/models/switchtransformers/test_modeling_tf_switchtransformers.py deleted file mode 100644 index 048c1d7e0b4ce..0000000000000 --- a/tests/models/switchtransformers/test_modeling_tf_switchtransformers.py +++ /dev/null @@ -1,1066 +0,0 @@ -# coding=utf-8 -# Copyright 2022 Google SwitchTransformers Authors and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -from transformers import SwitchTransformersConfig, is_tf_available -from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow, tooslow -from transformers.utils import cached_property - -from ...test_configuration_common import ConfigTester -from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask - - -if is_tf_available(): - import tensorflow as tf - - from transformers import ( - BySwitchTransformersTokenizer, - SwitchTransformersTokenizer, - TFSwitchTransformersEncoderModel, - TFSwitchTransformersForConditionalGeneration, - TFSwitchTransformersModel, - ) - - -class TFSwitchTransformersModelTester: - def __init__( - self, - parent, - ): - self.parent = parent - self.batch_size = 13 - self.seq_length = 7 - self.is_training = True - self.use_input_mask = True - self.use_labels = True - self.vocab_size = 99 - self.n_positions = 14 - self.hidden_size = 32 - self.num_hidden_layers = 5 - self.num_attention_heads = 4 - self.d_ff = 37 - self.relative_attention_num_buckets = 8 - self.dropout_rate = 0.1 - self.initializer_factor = 0.002 - self.eos_token_id = 1 - self.pad_token_id = 0 - self.scope = None - - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - - input_mask = None - if self.use_input_mask: - input_mask = random_attention_mask([self.batch_size, self.seq_length]) - - token_labels = None - if self.use_labels: - token_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - - config = SwitchTransformersConfig( - vocab_size=self.vocab_size, - n_positions=self.n_positions, - d_model=self.hidden_size, - d_ff=self.d_ff, - d_kv=self.hidden_size // self.num_attention_heads, - num_layers=self.num_hidden_layers, - num_heads=self.num_attention_heads, - relative_attention_num_buckets=self.relative_attention_num_buckets, - dropout_rate=self.dropout_rate, - initializer_factor=self.initializer_factor, - eos_token_id=self.eos_token_id, - bos_token_id=self.pad_token_id, - pad_token_id=self.pad_token_id, - decoder_start_token_id=self.pad_token_id, - ) - - return (config, input_ids, input_mask, token_labels) - - def create_and_check_switchtransformers_model(self, config, input_ids, input_mask, token_labels): - model = TFSwitchTransformersModel(config=config) - inputs = { - "input_ids": input_ids, - "decoder_input_ids": input_ids, - "decoder_attention_mask": input_mask, - } - result = model(inputs) - - result = model(input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids) - decoder_output = result.last_hidden_state - decoder_past = result.past_key_values - encoder_output = result.encoder_last_hidden_state - self.parent.assertListEqual(list(encoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size]) - self.parent.assertListEqual(list(decoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size]) - # There should be `num_layers` key value embeddings stored in decoder_past[1] - self.parent.assertEqual(len(decoder_past), config.num_layers) - # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple - self.parent.assertEqual(len(decoder_past[0]), 4) - - def create_and_check_switchtransformers_with_lm_head(self, config, input_ids, input_mask, token_labels): - model = TFSwitchTransformersForConditionalGeneration(config=config) - inputs_dict = { - "input_ids": input_ids, - "decoder_input_ids": input_ids, - "decoder_attention_mask": input_mask, - } - - result = model(inputs_dict) - - self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) - - def create_and_check_switchtransformers_decoder_model_past( - self, config, input_ids, decoder_input_ids, attention_mask - ): - model = TFSwitchTransformersModel(config=config).get_decoder() - - input_ids = input_ids[:1, :] - self.batch_size = 1 - - # first forward pass - outputs = model(input_ids, use_cache=True) - - outputs_use_cache_conf = model(input_ids) - outputs_no_past = model(input_ids, use_cache=False) - - self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) - self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) - - # create hypothetical next token and extent to next_input_ids - next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) - - # append to next input_ids and - next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) - - output_from_no_past = model(next_input_ids)[0] - output_from_past = model(next_tokens, past_key_values=outputs.past_key_values)[0] - - # select random slice - random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) - output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx] - output_from_past_slice = output_from_past[:, 0, random_slice_idx] - - # test that outputs are equal for slice - tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3) - - def create_and_check_switchtransformers_decoder_model_attention_mask_past( - self, config, input_ids, decoder_input_ids, attention_mask - ): - model = TFSwitchTransformersModel(config=config).get_decoder() - - # create attention mask - half_seq_length = self.seq_length // 2 - attn_mask_begin = tf.ones((self.batch_size, half_seq_length), dtype=tf.int32) - attn_mask_end = tf.zeros((self.batch_size, self.seq_length - half_seq_length), dtype=tf.int32) - attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1) - - # first forward pass - outputs = model(input_ids, attention_mask=attn_mask, use_cache=True) - - # create hypothetical next token and extent to next_input_ids - next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) - - # change a random masked slice from input_ids - random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy() + 1 - random_other_next_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size) - vector_condition = tf.range(self.seq_length) == (self.seq_length - random_seq_idx_to_change) - condition = tf.transpose( - tf.broadcast_to(tf.expand_dims(vector_condition, -1), (self.seq_length, self.batch_size)) - ) - input_ids = tf.where(condition, random_other_next_tokens, input_ids) - - # append to next input_ids and attn_mask - next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) - attn_mask = tf.concat( - [attn_mask, tf.ones((attn_mask.shape[0], 1), dtype=tf.int32)], - axis=1, - ) - - # get two different outputs - output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0] - output_from_past = model(next_tokens, past_key_values=outputs.past_key_values, attention_mask=attn_mask)[0] - - # select random slice - random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).numpy().item() - output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx] - output_from_past_slice = output_from_past[:, 0, random_slice_idx] - - # test that outputs are equal for slice - tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3) - - def create_and_check_switchtransformers_decoder_model_past_large_inputs( - self, config, input_ids, decoder_input_ids, attention_mask - ): - model = TFSwitchTransformersModel(config=config).get_decoder() - - input_ids = input_ids[:1, :] - attention_mask = attention_mask[:1, :] - self.batch_size = 1 - - # first forward pass - outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) - - # create hypothetical next token and extent to next_input_ids - next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) - next_attn_mask = ids_tensor((self.batch_size, 3), 2) - - # append to next input_ids and - next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) - next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1) - - output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)[0] - output_from_past = model( - next_tokens, attention_mask=next_attention_mask, past_key_values=outputs.past_key_values - )[0] - - self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1]) - - # select random slice - random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) - output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx] - output_from_past_slice = output_from_past[:, :, random_slice_idx] - - # test that outputs are equal for slice - tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3) - - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - (config, input_ids, input_mask, token_labels) = config_and_inputs - inputs_dict = { - "input_ids": input_ids, - "decoder_input_ids": input_ids, - "decoder_attention_mask": input_mask, - } - return config, inputs_dict - - -@require_tf -class TFSwitchTransformersModelTest(TFModelTesterMixin, unittest.TestCase): - - is_encoder_decoder = True - all_model_classes = ( - (TFSwitchTransformersModel, TFSwitchTransformersForConditionalGeneration) if is_tf_available() else () - ) - all_generative_model_classes = (TFSwitchTransformersForConditionalGeneration,) if is_tf_available() else () - test_onnx = False - - def setUp(self): - self.model_tester = TFSwitchTransformersModelTester(self) - self.config_tester = ConfigTester(self, config_class=SwitchTransformersConfig, d_model=37) - - def test_config(self): - self.config_tester.run_common_tests() - - def test_switchtransformers_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_switchtransformers_model(*config_and_inputs) - - def test_switchtransformers_model_v1_1(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - config = config_and_inputs[0] - config.tie_word_embeddings = False - config.feed_forward_proj = "gated-gelu" - self.model_tester.create_and_check_switchtransformers_model(config, *config_and_inputs[1:]) - - def test_with_lm_head(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_switchtransformers_with_lm_head(*config_and_inputs) - - def test_switchtransformers_decoder_model_past(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_switchtransformers_decoder_model_past(*config_and_inputs) - - def test_switchtransformers_decoder_model_past_with_attn_mask(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_switchtransformers_decoder_model_attention_mask_past(*config_and_inputs) - - def test_switchtransformers_decoder_model_past_large_inputs(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - - # `create_and_check_switchtransformers_decoder_model_past_large_inputs` has special inputs: - # (config, input_ids, decoder_input_ids, attention_mask) - # and we have to prepare it correctly here. - config, input_ids, input_mask, token_labels = config_and_inputs - config_and_inputs = (config, input_ids, None, input_mask) - - self.model_tester.create_and_check_switchtransformers_decoder_model_past_large_inputs(*config_and_inputs) - - def test_model_common_attributes(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) - - if model_class in self.all_generative_model_classes: - x = model.get_output_embeddings() - assert isinstance(x, tf.keras.layers.Layer) - name = model.get_bias() - assert name is None - else: - x = model.get_output_embeddings() - assert x is None - name = model.get_bias() - assert name is None - - @tooslow - def test_saved_model_creation(self): - pass - - @slow - def test_model_from_pretrained(self): - model = TFSwitchTransformersModel.from_pretrained("ybelkada/switchtransformers-base") - self.assertIsNotNone(model) - - def test_generate_with_headmasking(self): - # TODO: Fix head-masking according to PyTorch SwitchTransformers model - pass - - @slow - def test_resize_embeddings(self): - model = TFSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") - original_vocab_size = model.get_input_embeddings().weight.shape[0] - # the vocab size is defined in the model config - self.assertEqual(original_vocab_size, model.config.vocab_size) - - tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") - tokenizer.add_special_tokens({"bos_token": "", "eos_token": ""}) - model._resize_token_embeddings(len(tokenizer)) - # the vocab size is now resized to the length of the tokenizer, which is different from the original size - self.assertEqual(model.get_input_embeddings().weight.shape[0], len(tokenizer)) - self.assertNotEqual(model.get_input_embeddings().weight.shape[0], original_vocab_size) - - # This test is run in `TFSwitchTransformersEncoderOnlyModelTest`, where the main layer has the same inputs as the model - @unittest.skip(reason="The inputs of the Main Layer are different.") - def test_keras_save_load(self): - pass - - -class TFSwitchTransformersEncoderOnlyModelTester: - def __init__( - self, - parent, - vocab_size=99, - batch_size=13, - encoder_seq_length=7, - # For common tests - use_attention_mask=True, - hidden_size=32, - num_hidden_layers=5, - num_attention_heads=4, - d_ff=37, - relative_attention_num_buckets=8, - is_training=False, - dropout_rate=0.1, - initializer_factor=0.002, - is_encoder_decoder=False, - eos_token_id=1, - pad_token_id=0, - scope=None, - ): - - self.parent = parent - self.batch_size = batch_size - self.encoder_seq_length = encoder_seq_length - # For common tests - self.seq_length = self.encoder_seq_length - self.use_attention_mask = use_attention_mask - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.d_ff = d_ff - self.relative_attention_num_buckets = relative_attention_num_buckets - self.dropout_rate = dropout_rate - self.initializer_factor = initializer_factor - self.eos_token_id = eos_token_id - self.pad_token_id = pad_token_id - self.is_encoder_decoder = is_encoder_decoder - self.scope = None - self.is_training = is_training - - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) - - attention_mask = None - if self.use_attention_mask: - attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) - - config = SwitchTransformersConfig( - vocab_size=self.vocab_size, - d_model=self.hidden_size, - d_ff=self.d_ff, - d_kv=self.hidden_size // self.num_attention_heads, - num_layers=self.num_hidden_layers, - num_heads=self.num_attention_heads, - relative_attention_num_buckets=self.relative_attention_num_buckets, - dropout_rate=self.dropout_rate, - initializer_factor=self.initializer_factor, - eos_token_id=self.eos_token_id, - bos_token_id=self.pad_token_id, - pad_token_id=self.pad_token_id, - is_encoder_decoder=self.is_encoder_decoder, - ) - - return ( - config, - input_ids, - attention_mask, - ) - - def create_and_check_model( - self, - config, - input_ids, - attention_mask, - ): - model = TFSwitchTransformersEncoderModel(config=config) - result = model( - input_ids=input_ids, - attention_mask=attention_mask, - ) - result = model(input_ids=input_ids) - encoder_output = result.last_hidden_state - - self.parent.assertEqual(encoder_output.shape, (self.batch_size, self.encoder_seq_length, self.hidden_size)) - - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - attention_mask, - ) = config_and_inputs - - inputs_dict = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - return config, inputs_dict - - -class TFSwitchTransformersEncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase): - is_encoder_decoder = False - all_model_classes = (TFSwitchTransformersEncoderModel,) if is_tf_available() else () - test_onnx = False - - def setUp(self): - self.model_tester = TFSwitchTransformersEncoderOnlyModelTester(self) - self.config_tester = ConfigTester(self, config_class=SwitchTransformersConfig, d_model=37) - - def test_config(self): - self.config_tester.run_common_tests() - - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - # is not able to be part of a pipeline - def test_train_pipeline_custom_model(self): - pass - - -@require_tf -@require_sentencepiece -@require_tokenizers -class TFSwitchTransformersGenerationIntegrationTests(unittest.TestCase): - @slow - def test_greedy_xla_generate_simple(self): - model = TFSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") - tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") - - # two examples with different lengths to confirm that attention masks are operational in XLA - sentences = [ - "Translate English to German: Today is a beautiful day.", - "Translate English to German: I have four cats, three dogs, two birds, and a horse.", - ] - input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids - - xla_generate = tf.function(model.generate, jit_compile=True) - - output_ids = model.generate(input_ids) - output_ids_xla = xla_generate(input_ids) - - output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) - output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True) - - expected_output_string = [ - "Heute ist ein schöner Tag.", - "Ich habe vier Katzen, drei Hunde, zwei Vögel und ein Pferd.", - ] - - self.assertListEqual(expected_output_string, output_strings) - self.assertListEqual(expected_output_string, output_strings_xla) - - @slow - def test_greedy_generate(self): - model = TFSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") - tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") - - sentences = ["Yesterday, my name was", "Today is a beautiful day and"] - input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids - - generation_kwargs = { - "bad_words_ids": [tokenizer("my").input_ids, tokenizer("ein schöner").input_ids], - "no_repeat_ngram_size": 3, - "do_sample": False, - "repetition_penalty": 2.2, - } - - output_ids = model.generate(input_ids, **generation_kwargs) - - output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) - - expected_output_string = ["Yesterday, my name was", "Heute ist ein schöne Tag und"] - - self.assertListEqual(expected_output_string, output_strings) - - @slow - def test_sample_xla_generate_simple(self): - # NOTE: due to the small numerical differences that are natural when we compile to XLA, sampling the same - # output out of the same seed is far from guaranteed. We can, however, confirm that the results are sensible - # and that we can seed both versions. - - # forces the generation to happen on CPU, to avoid GPU-related quirks - with tf.device(":/CPU:0"): - model = TFSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") - tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") - - sentence = "Translate English to German: I have two bananas" - input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids - expected_output_string = ["Ich habe zwei Bananen"] - expected_output_string_xla = ["Ich habe 2 Bananen"] - - # seed set -> deterministic sampling sequence -> deterministic generation - output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0]) - output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) - self.assertListEqual(expected_output_string, output_strings) - - xla_generate = tf.function(model.generate, jit_compile=True) - # seed set -> deterministic sampling sequence -> deterministic generation - output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0]) - output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True) - self.assertListEqual(expected_output_string_xla, output_strings_xla) - - @slow - def test_sample_generate(self): - model = TFSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") - tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") - - sentences = ["I really love my", "Translate English to German: the transformers are truly amazing"] - input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids - - generation_kwargs = { - "do_sample": True, - "bad_words_ids": [tokenizer("my").input_ids, tokenizer("ein schöner").input_ids], - "no_repeat_ngram_size": 3, - "repetition_penalty": 2.2, - "temperature": 0.8, - "top_k": 500, - "top_p": 0.9, - "seed": [20, 0], # seed set -> deterministic sampling sequence -> deterministic generation - } - - # forces the generation to happen on CPU, to avoid GPU-related quirks - with tf.device(":/CPU:0"): - output_ids = model.generate(input_ids, **generation_kwargs) - - output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) - - expected_output_string = ["- I really love my way of this.", "die Transformatoren sind wirklich erstaunlich"] - - self.assertListEqual(expected_output_string, output_strings) - - @slow - def test_beam_search_xla_generate_simple(self): - model = TFSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") - tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") - - # tests XLA with task specific arguments - task_specific_config = getattr(model.config, "task_specific_params", {}) - translation_config = task_specific_config.get("translation_en_to_fr", {}) - model.config.update(translation_config) - - # two examples with different lengths to confirm that attention masks are operational in XLA - sentences = [ - model.config.prefix + "Today is a beautiful day.", - model.config.prefix + "I have four cats, three dogs, two birds, and a horse.", - ] - input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids - - xla_generate = tf.function(model.generate, jit_compile=True) - - output_ids = model.generate(input_ids, num_beams=2) - output_ids_xla = xla_generate(input_ids, num_beams=2) - - output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) - output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True) - - expected_output_string = [ - "Aujourd'hui est une belle journée.", - "J'ai quatre chats, trois chiens, deux oiseaux et un cheval.", - ] - - self.assertListEqual(expected_output_string, output_strings) - self.assertListEqual(expected_output_string, output_strings_xla) - - @slow - def test_beam_search_generate(self): - model = TFSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") - tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") - - sentences = ["I really love my", "Translate English to German: the transformers are truly amazing"] - input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids - - generation_kwargs = { - "bad_words_ids": [tokenizer("my").input_ids, tokenizer("ein schöner").input_ids], - "no_repeat_ngram_size": 3, - "do_sample": False, - "repetition_penalty": 2.2, - "num_beams": 4, - } - - output_ids = model.generate(input_ids, **generation_kwargs) - - output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) - - expected_output_string = ["Ich liebe es so sehr!", "die Transformatoren sind wirklich erstaunlich"] - self.assertListEqual(expected_output_string, output_strings) - - -@require_tf -@require_sentencepiece -@require_tokenizers -class TFSwitchTransformersModelIntegrationTests(unittest.TestCase): - @cached_property - def model(self): - return TFSwitchTransformersForConditionalGeneration.from_pretrained("switchtransformers-base") - - @slow - def test_small_integration_test(self): - """ - For comparision run: - >>> import switchtransformers # pip install switchtransformers==0.7.1 - >>> from switchtransformers.data.sentencepiece_vocabulary import SentencePieceVocabulary - - >>> path_to_mtf_small_switchtransformers_checkpoint = '' - >>> path_to_mtf_small_spm_model_path = '' - >>> switchtransformers_model = switchtransformers.models.MtfModel(model_dir=path_to_mtf_small_switchtransformers_checkpoint, batch_size=1, tpu=None) - >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) - >>> score = switchtransformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) - """ - - model = TFSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") - tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") - - input_ids = tokenizer("Hello there", return_tensors="tf").input_ids - labels = tokenizer("Hi I am", return_tensors="tf").input_ids - - loss = model(input_ids, labels=labels).loss - mtf_score = -tf.math.reduce_mean(loss).numpy() - - EXPECTED_SCORE = -4.771147 - self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) - - @slow - def test_small_v1_1_integration_test(self): - """ - For comparision run: - >>> import switchtransformers # pip install switchtransformers==0.7.1 - >>> from switchtransformers.data.sentencepiece_vocabulary import SentencePieceVocabulary - - >>> path_to_mtf_small_switchtransformers_v1.1_checkpoint = '' - >>> path_to_mtf_small_spm_model_path = '' - >>> switchtransformers_model = switchtransformers.models.MtfModel(model_dir=path_to_mtf_small_switchtransformers_v1.1_checkpoint, batch_size=1, tpu=None) - >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) - >>> score = switchtransformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) - """ - - model = TFSwitchTransformersForConditionalGeneration.from_pretrained("google/switchtransformers-v1_1-small") - tokenizer = SwitchTransformersTokenizer.from_pretrained("google/switchtransformers-v1_1-small") - - input_ids = tokenizer("Hello there", return_tensors="tf").input_ids - labels = tokenizer("Hi I am", return_tensors="tf").input_ids - - loss = model(input_ids, labels=labels).loss - mtf_score = -tf.math.reduce_mean(loss).numpy() - - EXPECTED_SCORE = -14.757326 - self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) - - @slow - def test_small_byswitchtransformers_integration_test(self): - """ - For comparision run: - >>> import switchtransformers # pip install switchtransformers==0.9.1 - - >>> path_to_byswitchtransformers_small_checkpoint = '' - >>> switchtransformers_model = switchtransformers.models.MtfModel(model_dir=path_to_tf_checkpoint, batch_size=1, tpu=None) - >>> vocab = switchtransformers.data.ByteVocabulary() - >>> score = switchtransformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) - """ - - model = TFSwitchTransformersForConditionalGeneration.from_pretrained( - "google/byybelkada/switchtransformers-base" - ) - tokenizer = BySwitchTransformersTokenizer.from_pretrained("google/byybelkada/switchtransformers-base") - - input_ids = tokenizer("Hello there", return_tensors="tf").input_ids - labels = tokenizer("Hi I am", return_tensors="tf").input_ids - - loss = model(input_ids, labels=labels).loss - mtf_score = -tf.math.reduce_mean(loss).numpy() - - EXPECTED_SCORE = -7.592465 - self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) - - @slow - def test_summarization(self): - model = self.model - tok = SwitchTransformersTokenizer.from_pretrained("switchtransformers-base") - - FRANCE_ARTICLE = ( # @noqa - "Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings" - " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane." - ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."' - ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s' - " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video" - " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French" - " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a" - " phone at the wreckage site. The two publications described the supposed video, but did not post it on" - " their websites. The publications said that they watched the video, which was found by a source close to" - " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported." - ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the' - " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the" - ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,' - " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said" - " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman" - " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the" - ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,' - ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be' - " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by" - " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so" - " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could" - ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin' - ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match' - ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered' - ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something' - " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the" - ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline' - " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the" - " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the" - ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of' - ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school' - " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in" - " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent" - " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and" - " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%" - ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was' - " sharing the information and documents -- including training and medical records -- with public" - " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the" - " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the" - " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash" - " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late" - " Tuesday that no visible human remains were left at the site but recovery teams would keep searching." - " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all" - " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested." - " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said." - " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew" - " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with" - " the flight school during his training were among several developments as investigators continued to" - " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa" - " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his" - ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in' - " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at" - " some point before his aviation career and underwent psychotherapy before he got his pilot's license." - " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the" - " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to" - " lose his pilot's license, a European government official briefed on the investigation told CNN on" - ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being' - " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that" - " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would" - " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had" - " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded" - " he had psychological issues, the European government official said. But no matter what details emerge" - " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic" - ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact' - " that maybe they weren't going to keep doing their job and they're upset about that and so they're" - ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to' - " also take that rage and turn it outward on 149 other people who had nothing to do with the person's" - ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight' - " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura" - " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine" - " Amiel and Anna-Maja Rappard contributed to this report." - ) - - SHORTER_ARTICLE = ( - "(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on" - " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The" - " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based." - " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its" - ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East' - ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the' - " situation in Palestinian territories, paving the way for possible war crimes investigations against" - " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and" - " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the" - " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a" - ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the' - ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an' - ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge' - " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the" - ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine' - " acquires all the rights as well as responsibilities that come with being a State Party to the Statute." - ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights' - ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should' - " immediately end their pressure, and countries that support universal acceptance of the court's treaty" - ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the' - " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's" - ' decision to join a treaty to which over 100 countries around the world are members." In January, when' - " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an" - ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"' - " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a" - ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in' - ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We' - ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"' - " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the" - ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the' - " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou" - ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war' - " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry" - " will include alleged war crimes committed since June. The International Criminal Court was set up in" - " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder" - " and Faith Karimi contributed to this report." - ) - - IRAN_ARTICLE = ( - "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran" - " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively" - " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger." - " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli" - " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a" - " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since" - " the announcement of the new framework will likely result in more heat than light. It will not be helped" - " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ." - " The most misleading assertion, despite universal rejection by experts, is that the negotiations'" - " objective at the outset was the total elimination of any nuclear program in Iran. That is the position" - " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it" - " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has" - " always been to structure an agreement or series of agreements so that Iran could not covertly develop a" - " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded" - " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by" - " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another" - " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite" - " sharp accusations by some in the United States and its allies, Iran denies having such a program, and" - " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's" - " continued cooperation with International Atomic Energy Agency inspections is further evidence on this" - " point, and we'll know even more about Iran's program in the coming months and years because of the deal." - " In fact, the inspections provisions that are part of this agreement are designed to protect against any" - " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that" - " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter" - " warning that a deal might be killed by Congress or a future president). This of course is not the case." - " The talks were between Iran and the five permanent members of the U.N. Security Council (United States," - " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has" - " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement" - " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran" - " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement" - " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the" - " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased" - " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes" - " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear" - " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going" - " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such" - " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the" - ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not' - " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New" - " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement" - " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement" - " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove" - " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally" - " some insist that any agreement must address Iranian missile programs, human rights violations or support" - " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are" - " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in" - " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it" - " affects the security of our negotiating partners and allies, including Israel. Those judgments should be" - " fact-based, not based on questionable assertions or dubious assumptions." - ) - - ARTICLE_SUBWAY = ( - "New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A" - " year later, she got married again in Westchester County, but to a different man and without divorcing" - " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos" - ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married' - " once more, this time in the Bronx. In an application for a marriage license, she stated it was her" - ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false' - ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage' - " license application, according to court documents. Prosecutors said the marriages were part of an" - " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to" - " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was" - " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New" - " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total," - " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All" - " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be" - " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors" - " said the immigration scam involved some of her husbands, who filed for permanent residence status" - " shortly after the marriages. Any divorces happened only after such filings were approved. It was" - " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District" - " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's" - ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,' - " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his" - " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces" - " up to four years in prison. Her next court appearance is scheduled for May 18." - ) - - expected_summaries = [ - 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a' - " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one" - " magazine says .", - "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a" - " preliminary examination into the situation in the occupied Palestinian territory . as members of the" - " court, Palestinians may be subject to counter-charges as well .", - "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:" - " the debate that has already begun since the announcement of the new framework will likely result in more" - " heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and" - " implement a rigorous inspection regime .", - "prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two" - ' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10' - " times, with nine of her marriages occurring between 1999 and 2002 .", - ] - - task_specific_config = getattr(model.config, "task_specific_params", {}) - summarization_config = task_specific_config.get("summarization", {}) - model.config.update(summarization_config) - - dct = tok( - [model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]], - max_length=512, - padding="max_length", - truncation=True, - return_tensors="tf", - ) - self.assertEqual(512, dct["input_ids"].shape[1]) - - hypotheses_batch = model.generate( - input_ids=dct["input_ids"], - attention_mask=dct["attention_mask"], - num_beams=4, - length_penalty=2.0, - max_length=142, - min_length=56, - no_repeat_ngram_size=3, - do_sample=False, - early_stopping=True, - ) - - decoded = [ - tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch - ] - - self.assertListEqual( - expected_summaries, - decoded, - ) - - @slow - def test_translation_en_to_de(self): - tok = SwitchTransformersTokenizer.from_pretrained("switchtransformers-base") - model = self.model - - task_specific_config = getattr(model.config, "task_specific_params", {}) - translation_config = task_specific_config.get("translation_en_to_de", {}) - self.model.config.update(translation_config) - - original_input = '"Luigi often said to me that he never wanted the brothers to end up in court", she wrote.' - expected_translation = ( - '"Luigi sagte mir oft, dass er nie wollte, dass die Brüder am Gericht sitzen", schrieb sie.' - ) - - input_ids = tok.encode(model.config.prefix + original_input, return_tensors="tf") - - output = model.generate( - input_ids=input_ids, - num_beams=4, - length_penalty=2.0, - max_length=50, - no_repeat_ngram_size=3, - do_sample=False, - early_stopping=True, - ) - translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) - - self.assertEqual(translation, expected_translation) - - @slow - def test_translation_en_to_fr(self): - model = self.model - tok = SwitchTransformersTokenizer.from_pretrained("switchtransformers-base") - - task_specific_config = getattr(model.config, "task_specific_params", {}) - translation_config = task_specific_config.get("translation_en_to_fr", {}) - model.config.update(translation_config) - - en_text = ( - ' This image section from an infrared recording by the Spitzer telescope shows a "family portrait" of' - " countless generations of stars: the oldest stars are seen as blue dots. " - ) - - new_truncated_translation = ( - "Cette section d'images provenant de l'enregistrement infrarouge effectué par le télescope Spitzer montre " - "un " - "« portrait familial » de générations innombrables d’étoiles : les plus anciennes sont observées " - "sous forme " - "de points bleus." - ) - - input_ids = tok(model.config.prefix + en_text, return_tensors="tf").input_ids - - output = model.generate( - input_ids=input_ids, - num_beams=4, - length_penalty=2.0, - max_length=100, - no_repeat_ngram_size=3, - do_sample=False, - early_stopping=True, - ) - translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) - - self.assertEqual(translation, new_truncated_translation) - - @slow - def test_translation_en_to_ro(self): - model = self.model - tok = SwitchTransformersTokenizer.from_pretrained("switchtransformers-base") - - task_specific_config = getattr(model.config, "task_specific_params", {}) - translation_config = task_specific_config.get("translation_en_to_ro", {}) - model.config.update(translation_config) - - original_input = "Taco Bell said it plans to add 2,000 locations in the US by 2022." - expected_translation = "Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022." - - input_ids = tok.encode(model.config.prefix + original_input, return_tensors="tf") - - output = model.generate( - input_ids=input_ids, - num_beams=4, - length_penalty=2.0, - max_length=50, - no_repeat_ngram_size=3, - do_sample=False, - early_stopping=True, - ) - translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False) - - self.assertEqual(translation, expected_translation) From cddfce7965fb6b133f7bc345227f7aff033dd40e Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 4 Oct 2022 18:23:56 +0100 Subject: [PATCH 005/102] clean up - remove `tf` modeling files --- docs/source/en/index.mdx | 2 +- .../en/model_doc/switchtransformers.mdx | 14 --------- src/transformers/__init__.py | 16 ---------- .../models/auto/modeling_tf_auto.py | 4 --- .../configuration_switchtransformers.py | 4 +-- src/transformers/utils/dummy_tf_objects.py | 31 ------------------- 6 files changed, 3 insertions(+), 68 deletions(-) diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 82172cccb1c04..54d6f934ef209 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -310,7 +310,7 @@ Flax), PyTorch, and/or TensorFlow. | SqueezeBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | Swin Transformer | ❌ | ❌ | ✅ | ✅ | ❌ | | Swin Transformer V2 | ❌ | ❌ | ✅ | ❌ | ❌ | -| SwitchTransformers | ✅ | ✅ | ✅ | ✅ | ✅ | +| SwitchTransformers | ✅ | ✅ | ✅ | ❌ | ✅ | | T5 | ✅ | ✅ | ✅ | ✅ | ✅ | | TAPAS | ✅ | ❌ | ✅ | ✅ | ❌ | | Time Series Transformer | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/switchtransformers.mdx b/docs/source/en/model_doc/switchtransformers.mdx index 0ba701599c7e9..e602041cb52eb 100644 --- a/docs/source/en/model_doc/switchtransformers.mdx +++ b/docs/source/en/model_doc/switchtransformers.mdx @@ -66,20 +66,6 @@ The original code can be found [here](). - parallelize - deparallelize -## TFSwitchTransformersModel - -[[autodoc]] TFSwitchTransformersModel - - call - -## TFSwitchTransformersForConditionalGeneration - -[[autodoc]] TFSwitchTransformersForConditionalGeneration - - call - -## TFSwitchTransformersEncoderModel - -[[autodoc]] TFSwitchTransformersEncoderModel - - call ## FlaxSwitchTransformersModel diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 23a34d2011927..ebc2792e6d1e2 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2701,15 +2701,6 @@ "TFT5PreTrainedModel", ] ) - _import_structure["models.switchtransformers"].extend( - [ - "TF_SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST", - "TFSwitchTransformersEncoderModel", - "TFSwitchTransformersForConditionalGeneration", - "TFSwitchTransformersModel", - "TFSwitchTransformersPreTrainedModel", - ] - ) _import_structure["models.tapas"].extend( [ "TF_TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -5269,13 +5260,6 @@ TFSwinModel, TFSwinPreTrainedModel, ) - from .models.switchtransformers import ( - TF_SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, - TFSwitchTransformersEncoderModel, - TFSwitchTransformersForConditionalGeneration, - TFSwitchTransformersModel, - TFSwitchTransformersPreTrainedModel, - ) from .models.t5 import ( TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST, TFT5EncoderModel, diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index 0d8c9280cdeb6..e13a0754b6926 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -74,7 +74,6 @@ ("segformer", "TFSegformerModel"), ("speech_to_text", "TFSpeech2TextModel"), ("swin", "TFSwinModel"), - ("switchtransformers", "TFSwitchTransformersModel"), ("t5", "TFT5Model"), ("tapas", "TFTapasModel"), ("transfo-xl", "TFTransfoXLModel"), @@ -107,7 +106,6 @@ ("mpnet", "TFMPNetForMaskedLM"), ("openai-gpt", "TFOpenAIGPTLMHeadModel"), ("roberta", "TFRobertaForMaskedLM"), - ("switchtransformers", "TFSwitchTransformersForConditionalGeneration"), ("t5", "TFT5ForConditionalGeneration"), ("tapas", "TFTapasForMaskedLM"), ("transfo-xl", "TFTransfoXLLMHeadModel"), @@ -144,7 +142,6 @@ ("roberta", "TFRobertaForMaskedLM"), ("roformer", "TFRoFormerForMaskedLM"), ("speech_to_text", "TFSpeech2TextForConditionalGeneration"), - ("switchtransformers", "TFSwitchTransformersForConditionalGeneration"), ("t5", "TFT5ForConditionalGeneration"), ("tapas", "TFTapasForMaskedLM"), ("transfo-xl", "TFTransfoXLLMHeadModel"), @@ -249,7 +246,6 @@ ("mbart", "TFMBartForConditionalGeneration"), ("mt5", "TFMT5ForConditionalGeneration"), ("pegasus", "TFPegasusForConditionalGeneration"), - ("switchtransformers", "TFSwitchTransformersForConditionalGeneration"), ("t5", "TFT5ForConditionalGeneration"), ] ) diff --git a/src/transformers/models/switchtransformers/configuration_switchtransformers.py b/src/transformers/models/switchtransformers/configuration_switchtransformers.py index 2acd5a177bfd5..c2c6f9c8aa32c 100644 --- a/src/transformers/models/switchtransformers/configuration_switchtransformers.py +++ b/src/transformers/models/switchtransformers/configuration_switchtransformers.py @@ -32,7 +32,7 @@ class SwitchTransformersConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`SwitchTransformersModel`] or a - [`TFSwitchTransformersModel`]. It is used to instantiate a SwitchTransformers model according to the specified + [`FlaxSwitchTransformersModel`]. It is used to instantiate a SwitchTransformers model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the SwitchTransformers [ybelkada/switchtransformers-base](https://huggingface.co/ybelkada/switchtransformers-base) architecture. @@ -44,7 +44,7 @@ class SwitchTransformersConfig(PretrainedConfig): vocab_size (`int`, *optional*, defaults to 32128): Vocabulary size of the SwitchTransformers model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`SwitchTransformersModel`] or - [`TFSwitchTransformersModel`]. + [`FlaxSwitchTransformersModel`]. d_model (`int`, *optional*, defaults to 512): Size of the encoder layers and the pooler layer. d_kv (`int`, *optional*, defaults to 64): diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 06df4dbd872c4..3acc7804687df 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -2200,37 +2200,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) -TF_SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST = None - - -class TFSwitchTransformersEncoderModel(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFSwitchTransformersForConditionalGeneration(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFSwitchTransformersModel(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - -class TFSwitchTransformersPreTrainedModel(metaclass=DummyObject): - _backends = ["tf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tf"]) - - TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST = None From 85c34e9c37768486ddcf375c7808065159fc6588 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 4 Oct 2022 18:49:16 +0100 Subject: [PATCH 006/102] clean up --- src/transformers/__init__.py | 2 -- src/transformers/models/switchtransformers/__init__.py | 2 -- ...witchtransformers_original_tf_checkpoint_to_pytorch.py | 8 ++------ src/transformers/utils/dummy_pt_objects.py | 4 ---- 4 files changed, 2 insertions(+), 14 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index ebc2792e6d1e2..1eedbf264995e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1923,7 +1923,6 @@ "SwitchTransformersForConditionalGeneration", "SwitchTransformersModel", "SwitchTransformersPreTrainedModel", - "load_tf_weights_in_switchtransformers", ] ) _import_structure["models.trajectory_transformer"].extend( @@ -4616,7 +4615,6 @@ SwitchTransformersForConditionalGeneration, SwitchTransformersModel, SwitchTransformersPreTrainedModel, - load_tf_weights_in_switchtransformers, ) from .models.t5 import ( T5_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/models/switchtransformers/__init__.py b/src/transformers/models/switchtransformers/__init__.py index 44e99f74e80d3..ccc257fc917bb 100644 --- a/src/transformers/models/switchtransformers/__init__.py +++ b/src/transformers/models/switchtransformers/__init__.py @@ -65,7 +65,6 @@ "SwitchTransformersForConditionalGeneration", "SwitchTransformersModel", "SwitchTransformersPreTrainedModel", - "load_tf_weights_in_switchtransformers", ] @@ -118,7 +117,6 @@ SwitchTransformersForConditionalGeneration, SwitchTransformersModel, SwitchTransformersPreTrainedModel, - load_tf_weights_in_switchtransformers, ) try: diff --git a/src/transformers/models/switchtransformers/convert_switchtransformers_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/switchtransformers/convert_switchtransformers_original_tf_checkpoint_to_pytorch.py index 47081a2a0c578..1bf81cf11de8c 100644 --- a/src/transformers/models/switchtransformers/convert_switchtransformers_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/switchtransformers/convert_switchtransformers_original_tf_checkpoint_to_pytorch.py @@ -17,11 +17,7 @@ import argparse -from transformers import ( - SwitchTransformersConfig, - SwitchTransformersForConditionalGeneration, - load_tf_weights_in_switchtransformers, -) +from transformers import SwitchTransformersConfig, SwitchTransformersForConditionalGeneration from transformers.utils import logging @@ -35,7 +31,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du model = SwitchTransformersForConditionalGeneration(config) # Load weights from tf checkpoint - load_tf_weights_in_switchtransformers(model, config, tf_checkpoint_path) + # load_tf_weights_in_switchtransformers(model, config, tf_checkpoint_path) # Save pytorch-model print(f"Save PyTorch model to {pytorch_dump_path}") diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 117974de31ba2..662a0bee6183c 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4807,10 +4807,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -def load_tf_weights_in_switchtransformers(*args, **kwargs): - requires_backends(load_tf_weights_in_switchtransformers, ["torch"]) - - T5_PRETRAINED_MODEL_ARCHIVE_LIST = None From d5d092ca838afc189cd82f3ad35e9316a46aadf4 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 5 Oct 2022 00:39:57 +0100 Subject: [PATCH 007/102] v0 routers --- .../configuration_switchtransformers.py | 23 ++ .../models/switchtransformers/router.py | 320 ++++++++++++++++-- .../models/switchtransformers/router_flax.py | 164 +++++++-- .../test_modeling_switchtransformers.py | 6 +- 4 files changed, 466 insertions(+), 47 deletions(-) diff --git a/src/transformers/models/switchtransformers/configuration_switchtransformers.py b/src/transformers/models/switchtransformers/configuration_switchtransformers.py index c2c6f9c8aa32c..8a58bfdf29341 100644 --- a/src/transformers/models/switchtransformers/configuration_switchtransformers.py +++ b/src/transformers/models/switchtransformers/configuration_switchtransformers.py @@ -60,6 +60,16 @@ class SwitchTransformersConfig(PretrainedConfig): Number of attention heads for each attention layer in the Transformer encoder. num_experts (`int`, *optional*, defaults to 8): Number of experts for each SwitchTransformer layer. + router_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the router. + router_jitter_noise (`float`, *optional*, defaults to 0.1): + Amount of noise to add to the router. + router_ignore_padding_tokens (`bool`, *optional*, defaults to `False`): + Whether to ignore padding tokens when routing. + batch_prioritized_routing (`bool`, *optional*, defaults to `False`): + Whether to use batch prioritized routing. + num_selected_experts (`int`, *optional*, defaults to 2): + Number of experts to select for each token. relative_attention_num_buckets (`int`, *optional*, defaults to 32): The number of buckets to use for each attention layer. relative_attention_max_distance (`int`, *optional*, defaults to 128): @@ -91,6 +101,11 @@ def __init__( num_decoder_layers=None, num_heads=8, num_experts=8, + router_bias=False, + router_jitter_noise=0.01, + num_selected_experts=2, + router_ignore_padding_tokens=False, + batch_prioritized_routing=False, relative_attention_num_buckets=32, relative_attention_max_distance=128, dropout_rate=0.1, @@ -113,8 +128,16 @@ def __init__( ) # default = symmetry self.num_heads = num_heads self.num_experts = num_experts + self.router_bias = router_bias + self.router_jitter_noise = router_jitter_noise + + self.router_ignore_padding_tokens = router_ignore_padding_tokens self.relative_attention_num_buckets = relative_attention_num_buckets self.relative_attention_max_distance = relative_attention_max_distance + self.batch_prioritized_routing = batch_prioritized_routing + + self.num_selected_experts = num_selected_experts + self.dropout_rate = dropout_rate self.layer_norm_epsilon = layer_norm_epsilon self.initializer_factor = initializer_factor diff --git a/src/transformers/models/switchtransformers/router.py b/src/transformers/models/switchtransformers/router.py index 23ca739b3f211..f1f35e0834167 100644 --- a/src/transformers/models/switchtransformers/router.py +++ b/src/transformers/models/switchtransformers/router.py @@ -13,11 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Any, Tuple +from typing import Any, Optional, Tuple import torch import torch.nn as nn +from transformers.models.switchtransformers.configuration_switchtransformers import SwitchTransformersConfig + # Output classes @@ -69,7 +71,7 @@ class RouterMask: # Router loss -def _router_z_loss(router_logits: torch.Tensor) -> float: +def router_z_loss_func(router_logits: torch.Tensor) -> float: r""" Compute router z-loss implemented in PyTorch. @@ -89,7 +91,7 @@ def _router_z_loss(router_logits: torch.Tensor) -> float: return torch.sum(z_loss) / (num_groups * tokens_per_group) -def _load_balancing_loss(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float: +def load_balancing_loss_fun(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float: r""" Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. @@ -175,14 +177,22 @@ def _compute_router_probabilities( token_inputs = token_inputs.to(torch.float32) if apply_jitter and self.jitter_noise > 0: - token_inputs *= torch.random.uniform( - token_inputs.shape, token_inputs.dtype, minval=1.0 - self.jitter_noise, maxval=1.0 + self.jitter_noise - ) + # Get the lower and upper bound of the uniform distribution + # Adapted from: https://stackoverflow.com/questions/44328530/how-to-get-a-uniform-distribution-in-a-range-r1-r2-in-pytorch + distrib_lower_bound = 1.0 - self.jitter_noise + distrib_upper_bound = 1.0 + self.jitter_noise + + uniform_distrib = ( + torch.rand(token_inputs.shape) * (distrib_lower_bound - distrib_upper_bound) + ) + distrib_upper_bound + + # Multiply the token inputs by the uniform distribution - adding some noise + token_inputs *= uniform_distrib # Shape: [num_groups, tokens_per_group, num_experts] - router_logits = self.router_weights(token_inputs, num_experts) + router_logits = self.router_weights(token_inputs) - router_probabilities = torch.nn.softmax(router_logits, axis=-1) + router_probabilities = torch.nn.Softmax(dim=-1)(router_logits) return router_probabilities, router_logits @@ -199,19 +209,281 @@ def forward(self, token_inputs: torch.Tensor, expert_capacity: int, apply_jitter router_probs, router_logits = self._compute_router_probabilities(token_inputs, self.num_experts, apply_jitter) # Flax code for reference - # if self.ignore_padding_tokens: - # # To identify non-padding tokens, we rely on the fact that padding tokens - # # in the inputs have already been masked in the default T5 architecture. - # # See - # # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L315 - # # and - # # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L603. - # padding_mask = jnp.torch.Tensor((jnp.sum(jnp.abs(token_inputs), axis=-1) > 0), dtype=token_inputs.dtype) - # router_logits *= jnp.expand_dims(padding_mask, axis=-1) - # else: - # padding_mask = None - - # instructions = self._compute_routing_instructions(router_probs, padding_mask, expert_capacity) - - # return instructions.replace(router_z_loss=_router_z_loss(router_logits)) - pass + if self.ignore_padding_tokens: + # To identify non-padding tokens, we rely on the fact that padding tokens + # in the inputs have already been masked in the default T5 architecture. + # See + # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L315 + # and + # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L603. + padding_mask = torch.Tensor((jnp.sum(jnp.abs(token_inputs), axis=-1) > 0), dtype=token_inputs.dtype) + router_logits *= jnp.expand_dims(padding_mask, axis=-1) + else: + padding_mask = None + + instructions = self._compute_routing_instructions(router_probs, padding_mask, expert_capacity) + + return instructions.replace(router_z_loss=router_z_loss_func(router_logits)) + + +class MaskedRouter(Router): + """Abstract base router class for masked matmul dispatch routers. + + MaskedRouter(s) return RouterMask(s) containing a dispatch mask and combine array for sending and receiving (via + masked matmuls) inputs and outputs to and from experts. + + Routing using masked matmuls is generally faster than scatter-based routing on TPUs. + """ + + def _compute_routing_instructions( + self, router_probs: torch.Tensor, padding_mask: Optional[torch.Tensor], expert_capacity: int + ) -> RouterMask: + """Computes masks for the top-k experts per token. + + Args: + router_probs: [num_groups, tokens_per_group, num_experts] + probabilities used to determine the routing of tokens to the experts. + padding_mask: [num_groups, tokens_per_group] padding logit mask + used to identify padding tokens that should be ignored by the router. + expert_capacity: Each group will send this many tokens to each expert. + + Returns: + Router mask arrays. + """ + raise NotImplementedError("MaskedRouter is an abstract class that should be subclassed.") + + +class ExpertsChooseMaskedRouter(MaskedRouter): + """Masked matmul router using experts choose tokens assignment. + + This router uses the same mechanism as in Mixture-of-Experts with Expert Choice (https://arxiv.org/abs/2202.09368): + each expert selects its top expert_capacity tokens. An individual token may be processed by multiple experts or + none at all. + + Note: "experts choose routing" should not be used in decoder blocks because it breaks the autoregressive behavior + -- the model will learn to cheat by using future token information to improve current token predictions. + """ + + def _compute_routing_instructions( + self, router_probs: torch.Tensor, padding_mask: Optional[torch.Tensor], expert_capacity: int + ) -> RouterMask: + """Computes masks for the highest probability token per expert. + + Args: + router_probs: [num_groups, tokens_per_group, num_experts] + probabilities used to determine the routing of tokens to the experts. + padding_mask: [num_groups, tokens_per_group] padding logit mask + used to identify padding tokens that should be down-weighted by the router. + expert_capacity: Each group will send this many tokens to each expert. + + Returns: + Dispatch and combine arrays for routing with masked matmuls. + """ + tokens_per_group = router_probs.shape[1] + + if padding_mask is not None: + # Because experts choose tokens, we mask probabilities corresponding to + # tokens before the top-k operation. Note that, unlike for masked-based + # tokens-choose routing, the experts here may still choose to select the + # (down-weighted) padding tokens. + router_probs *= jnp.expand_dims(padding_mask, axis=-1) + + # vmap over group dimension. + router_probs_t = jax.vmap(lambda m: m.transpose())(router_probs) + + # Top expert_capacity router probability and corresponding token indices for + # each expert. Shapes: [num_groups, num_experts, expert_capacity]. + expert_gate, expert_index = _top_k(router_probs_t, k=expert_capacity) + + # Convert to one-hot mask of expert indices for each token in each group. + # Shape: [num_groups, num_experts, expert_capacity, tokens_per_group]. + dispatch_mask = jax.nn.one_hot(expert_index, tokens_per_group, dtype=jnp.int32) + + # Move axes to conform with shape expected by MoeLayer API. + # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity] + dispatch_mask = jnp.moveaxis(dispatch_mask, 3, 1) + + # The combine array will be used for combining expert outputs, scaled by the + # router probabilities. Shape: [num_groups, num_experts, tokens_per_group, + # expert_capacity]. + combine_array = jnp.einsum( + "...ec,...tec->...tec", expert_gate, dispatch_mask, precision=jax.lax.Precision.DEFAULT + ) + + # Return to default dtype now that router computation is complete. + combine_array = jax.lax.convert_element_type(combine_array, self.dtype) + + # Each expert is choosing tokens until it reaches full capacity, so we don't + # need an auxiliary loading balancing loss for expert choice routing. + auxiliary_loss = 0.0 + + return RouterMask(dispatch_mask, combine_array, auxiliary_loss) + + +class TokensChooseMaskedRouter(MaskedRouter): + """ + Masked matmul router using tokens choose top-k experts assignment. + + This router uses the same mechanism as in Switch Transformer + (https://arxiv.org/abs/2101.03961) and V-MoE + (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are + sorted by router_probs and then routed to their choice of expert until the + expert's expert_capacity is reached. There is no guarantee that each token is + processed by an expert, or that each expert receives at least one token. + + Attributes: + num_selected_experts: Maximum number of experts to which each token is + routed. Tokens may be routed to fewer experts if particular experts are + oversubscribed / reach capacity. + batch_prioritized_routing: Whether or not to use Batch Prioritized Routing + (BPR), originally introduced in V-MoE (https://arxiv.org/abs/2106.05974). + With BPR, we prioritize routing those top-k tokens with the highest + router probability, rather than simply using each tokens left-to-right + ordering in the batch. This prioritization is important because the + experts have limited capacity. + """ + + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self.num_selected_experts = config.num_selected_experts + self.batch_prioritized_routing = config.batch_prioritized_routing + + def _compute_routing_instructions( + self, router_probs: torch.Tensor, padding_mask: Optional[torch.Tensor], expert_capacity: int + ) -> RouterMask: + """ + Computes masks for the top-k experts per token. + + Args: + router_probs: [num_groups, tokens_per_group, num_experts] + probabilities used to determine the routing of tokens to the experts. + padding_mask: [num_groups, tokens_per_group] padding logit mask + used to identify padding tokens that should be ignored by the router. + expert_capacity: Each group will send this many tokens to each expert. + + Returns: + Dispatch and combine arrays for routing with masked matmuls. + """ + num_groups, _, num_experts = router_probs.shape + + # Top-k router probability and corresponding expert indices for each token. + # Shape: [num_groups, tokens_per_group, num_selected_experts]. + expert_gate, expert_index = torch.topk(router_probs, k=self.num_selected_experts) + + if padding_mask is not None: + # Mask applied to gate. Exclude choices corresponding to padding tokens. + gate_mask = jnp.expand_dims(padding_mask, axis=-1) + expert_gate *= gate_mask + + # Set `expert_index` elements corresponding to padding to negative + # numbers. Negative `expert_index` elements will ultimately be dropped in + # the one_hot conversion to the `expert_mask`. + # First convert nonzero padding elements to negative values. + expert_index *= 2 * gate_mask - 1.0 + # Handle zero padding elements by negatively shifting all padding. + expert_index += jnp.repeat(gate_mask - 1.0, self.num_selected_experts, axis=-1) + + # To correctly compute load balancing loss, we also mask out probs. + router_probs *= gate_mask + + auxiliary_loss = load_balancing_loss_fun(router_probs, expert_index) + + if self.batch_prioritized_routing: + # Sort tokens according to their routing probability per group, so that + # the highest probability tokens are routed first. + permutation = jnp.argsort(-expert_gate[..., 0], axis=-1) + # Shape: [num_groups, tokens_per_group, num_selected_experts] + expert_index = _take_along_axis(expert_index, jnp.expand_dims(permutation, axis=-1), axis=-2) + + # Make num_selected_experts the leading axis to ensure that top-1 choices + # have priority over top-2 choices, which have priority over top-3 choices, + # etc. + expert_index = expert_index.permute((0, 2, 1)) + # Shape: [num_groups, num_selected_experts * tokens_per_group] + expert_index = expert_index.reshape(num_groups, -1) + + # Create mask out of indices. + # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. + expert_mask = torch.nn.functional.one_hot(expert_index, num_experts) + + # Experts have a fixed capacity that we cannot exceed. A token's priority + # within the expert's buffer is given by the masked, cumulative capacity of + # its target expert. + # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. + token_priority = torch.cumsum(expert_mask, axis=1) * expert_mask - 1.0 + # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts]. + token_priority = token_priority.reshape((num_groups, self.num_selected_experts, -1, num_experts)) + # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. + token_priority = token_priority.permute((0, 2, 1, 3)) + # For each token, across all selected experts, select the only non-negative + # (unmasked) priority. Now, for group G routing to expert E, token T has + # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E + # is its targeted expert. + # Shape: [num_groups, tokens_per_group, num_experts]. + token_priority = torch.max(token_priority, axis=2).values + + if self.batch_prioritized_routing: + # Place token priorities in original ordering of tokens. + inv_permutation = jnp.argsort(permutation, axis=-1) + token_priority = _take_along_axis(token_priority, jnp.expand_dims(inv_permutation, axis=-1), axis=-2) + + # Token T can only be routed to expert E if its priority is positive and + # less than the expert capacity. One-hot matrix will ignore indices outside + # the range [0, expert_capacity). + # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]. + token_priority_mask = token_priority > 0 + token_priority = token_priority * token_priority_mask + + dispatch_mask = torch.nn.functional.one_hot(token_priority.long(), expert_capacity + 1)[..., 1:] + + # The combine array will be used for combining expert outputs, scaled by the + # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, + # expert_capacity]. + combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) + + # Return to default dtype now that router computation is complete. + combine_array = combine_array.to(torch.float32) + + return RouterMask(dispatch_mask, combine_array, auxiliary_loss) + + +num_groups = 2 +tokens_per_group = 3 +hidden_dim = 4 +num_experts = 2 +num_selected_experts = 1 # Switch routing case +expert_capacity = 1 # Total capacity = 2*2*1 = 4 < num_tokens +jitter_noise = 0.0 + +input_tokens = torch.Tensor( + [ + [ + [0.6433916, 0.18188512, 0.02240455, 0.563781], + [0.5526401, 0.0958724, 0.34253013, 0.03644359], + [0.08744538, 0.7909105, 0.35205448, 0.53364205], + ], + [ + [0.02900076, 0.4168595, 0.5802449, 0.91486526], + [0.27414513, 0.14991808, 0.9383501, 0.5209162], + [0.51207185, 0.90618336, 0.7309413, 0.95533276], + ], + ] +) + +config = SwitchTransformersConfig( + num_experts=num_experts, + hidden_size=hidden_dim, + num_selected_experts=num_selected_experts, + router_jitter_noise=jitter_noise, + expert_capacity=expert_capacity, + batch_prioritized_routing=False, +) +model = TokensChooseMaskedRouter(config) + +model.router_weights.weight = torch.nn.Parameter( + torch.Tensor( + [[0.02008116, 0.00620062], [-0.00811031, -0.00031623], [-0.03542127, 0.02703803], [0.02335377, -0.02971946]], + ).t() +) + +model(input_tokens, expert_capacity=expert_capacity) diff --git a/src/transformers/models/switchtransformers/router_flax.py b/src/transformers/models/switchtransformers/router_flax.py index d1060279716e0..3c37d3e938627 100644 --- a/src/transformers/models/switchtransformers/router_flax.py +++ b/src/transformers/models/switchtransformers/router_flax.py @@ -204,19 +204,19 @@ def __call__(self, token_inputs: Array, num_experts: int) -> Array: Router logits with shape [num_groups, group_size, num_experts]. """ # Flax code for reference - # return dense.DenseGeneral( - # features=num_experts, - # axis=self.axis, - # use_bias=self.use_bias, - # dtype=self.dtype, - # kernel_init=self.kernel_init, - # bias_init=self.bias_init, - # precision=self.precision, - # kernel_axis_names=self.kernel_axis_names, - # reshape_kernel=self.reshape_kernel, - # name="w", - # )(token_inputs) - pass + return nn.Dense( + features=num_experts, + axis=self.axis, + use_bias=self.use_bias, + dtype=self.dtype, + kernel_init=self.kernel_init, + bias_init=self.bias_init, + precision=self.precision, + kernel_axis_names=self.kernel_axis_names, + reshape_kernel=self.reshape_kernel, + name="w", + )(token_inputs) + # pass class Router(nn.Module): @@ -670,6 +670,130 @@ def _compute_routing_instructions( return RouterMask(dispatch_mask, combine_array, auxiliary_loss) +class TokensChooseMaskedRouter(MaskedRouter): + """ + Masked matmul router using tokens choose top-k experts assignment. + + This router uses the same mechanism as in Switch Transformer + (https://arxiv.org/abs/2101.03961) and V-MoE + (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are + sorted by router_probs and then routed to their choice of expert until the + expert's expert_capacity is reached. There is no guarantee that each token is + processed by an expert, or that each expert receives at least one token. + + Attributes: + num_selected_experts: Maximum number of experts to which each token is + routed. Tokens may be routed to fewer experts if particular experts are + oversubscribed / reach capacity. + batch_prioritized_routing: Whether or not to use Batch Prioritized Routing + (BPR), originally introduced in V-MoE (https://arxiv.org/abs/2106.05974). + With BPR, we prioritize routing those top-k tokens with the highest + router probability, rather than simply using each tokens left-to-right + ordering in the batch. This prioritization is important because the + experts have limited capacity. + """ + + num_selected_experts: int + batch_prioritized_routing: bool + + def _compute_routing_instructions( + self, router_probs: Array, padding_mask: Optional[Array], expert_capacity: int + ) -> RouterMask: + """ + Computes masks for the top-k experts per token. + + Args: + router_probs: [num_groups, tokens_per_group, num_experts] + probabilities used to determine the routing of tokens to the experts. + padding_mask: [num_groups, tokens_per_group] padding logit mask + used to identify padding tokens that should be ignored by the router. + expert_capacity: Each group will send this many tokens to each expert. + + Returns: + Dispatch and combine arrays for routing with masked matmuls. + """ + num_groups, _, num_experts = router_probs.shape + + # Top-k router probability and corresponding expert indices for each token. + # Shape: [num_groups, tokens_per_group, num_selected_experts]. + expert_gate, expert_index = _top_k(router_probs, k=self.num_selected_experts) + + if padding_mask is not None: + # Mask applied to gate. Exclude choices corresponding to padding tokens. + gate_mask = jnp.expand_dims(padding_mask, axis=-1) + expert_gate *= gate_mask + + # Set `expert_index` elements corresponding to padding to negative + # numbers. Negative `expert_index` elements will ultimately be dropped in + # the one_hot conversion to the `expert_mask`. + # First convert nonzero padding elements to negative values. + expert_index *= 2 * gate_mask - 1.0 + # Handle zero padding elements by negatively shifting all padding. + expert_index += jnp.repeat(gate_mask - 1.0, self.num_selected_experts, axis=-1) + + # To correctly compute load balancing loss, we also mask out probs. + router_probs *= gate_mask + + auxiliary_loss = _load_balancing_loss(router_probs, expert_index) + + if self.batch_prioritized_routing: + # Sort tokens according to their routing probability per group, so that + # the highest probability tokens are routed first. + permutation = jnp.argsort(-expert_gate[..., 0], axis=-1) + # Shape: [num_groups, tokens_per_group, num_selected_experts] + expert_index = _take_along_axis(expert_index, jnp.expand_dims(permutation, axis=-1), axis=-2) + + # Make num_selected_experts the leading axis to ensure that top-1 choices + # have priority over top-2 choices, which have priority over top-3 choices, + # etc. + expert_index = jnp.swapaxes(expert_index, 1, 2) + # Shape: [num_groups, num_selected_experts * tokens_per_group] + expert_index = expert_index.reshape(num_groups, -1) + + # Create mask out of indices. + # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. + expert_mask = jax.nn.one_hot(expert_index, num_experts, dtype=jnp.int32) + + # Experts have a fixed capacity that we cannot exceed. A token's priority + # within the expert's buffer is given by the masked, cumulative capacity of + # its target expert. + # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. + token_priority = jnp.cumsum(expert_mask, axis=1) * expert_mask - 1.0 + # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts]. + token_priority = token_priority.reshape((num_groups, self.num_selected_experts, -1, num_experts)) + # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. + token_priority = jnp.swapaxes(token_priority, 1, 2) + # For each token, across all selected experts, select the only non-negative + # (unmasked) priority. Now, for group G routing to expert E, token T has + # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E + # is its targeted expert. + # Shape: [num_groups, tokens_per_group, num_experts]. + token_priority = jnp.max(token_priority, axis=2) + + if self.batch_prioritized_routing: + # Place token priorities in original ordering of tokens. + inv_permutation = jnp.argsort(permutation, axis=-1) + token_priority = _take_along_axis(token_priority, jnp.expand_dims(inv_permutation, axis=-1), axis=-2) + + # Token T can only be routed to expert E if its priority is positive and + # less than the expert capacity. One-hot matrix will ignore indices outside + # the range [0, expert_capacity). + # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]. + dispatch_mask = jax.nn.one_hot(token_priority, expert_capacity, dtype=jnp.bool_) + + # The combine array will be used for combining expert outputs, scaled by the + # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, + # expert_capacity]. + combine_array = jnp.einsum( + "...te,...tec->...tec", router_probs, dispatch_mask, precision=jax.lax.Precision.DEFAULT + ) + + # Return to default dtype now that router computation is complete. + combine_array = jax.lax.convert_element_type(combine_array, self.dtype) + + return RouterMask(dispatch_mask, combine_array, auxiliary_loss) + + def _load_balancing_loss(router_probs: Array, expert_indices: Array) -> float: """Computes auxiliary load balancing loss as in Switch Transformer. @@ -719,12 +843,12 @@ def _router_z_loss(router_logits: Array) -> float: return jnp.sum(z_loss, dtype=jnp.float32) / (num_groups * tokens_per_group) -num_tokens = 5 -num_experts = 2 -num_selected_experts = 1 -rng = jax.random.PRNGKey(0) +# num_tokens = 5 +# num_experts = 2 +# num_selected_experts = 1 +# rng = jax.random.PRNGKey(0) -router_probs = jax.random.uniform(rng, (num_tokens, num_experts), minval=0, maxval=1) -expert_indices = jax.random.randint(rng, (num_tokens, num_selected_experts), minval=0, maxval=2) +# router_probs = jax.random.uniform(rng, (num_tokens, num_experts), minval=0, maxval=1) +# expert_indices = jax.random.randint(rng, (num_tokens, num_selected_experts), minval=0, maxval=2) -loss = _load_balancing_loss(router_probs, expert_indices) +# loss = _load_balancing_loss(router_probs, expert_indices) diff --git a/tests/models/switchtransformers/test_modeling_switchtransformers.py b/tests/models/switchtransformers/test_modeling_switchtransformers.py index 5e58f9d07f943..74c306f5621a0 100644 --- a/tests/models/switchtransformers/test_modeling_switchtransformers.py +++ b/tests/models/switchtransformers/test_modeling_switchtransformers.py @@ -37,7 +37,7 @@ from transformers.models.switchtransformers.modeling_switchtransformers import ( SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, ) - from transformers.models.switchtransformers.router import _load_balancing_loss, _router_z_loss + from transformers.models.switchtransformers.router import load_balancing_loss_func, router_z_loss_func class SwitchTransformersModelTester: @@ -886,7 +886,7 @@ def test_equivalency_balancy_loss(self): expert_indices = torch.Tensor([[0], [1], [1], [0], [0]]).to(torch.int32) - loss = _load_balancing_loss(router_probs, expert_indices) + loss = load_balancing_loss_func(router_probs, expert_indices) self.assertAlmostEqual(loss.item(), 0.8741045, places=5) def test_equivalency_router_z_loss(self): @@ -915,5 +915,5 @@ def test_equivalency_router_z_loss(self): ] ) - loss = _router_z_loss(logits) + loss = router_z_loss_func(logits) self.assertAlmostEqual(loss.item(), 13.786719, places=5) From 62d34bdef7ed0481bdf1a760b5e722cad9bddf59 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 6 Oct 2022 17:23:28 +0100 Subject: [PATCH 008/102] added more router - Implemented `ExpertsChooseMaskedRouter` - added tests - 2 more routers to implement --- .../models/switchtransformers/router.py | 227 ++++++++++++------ .../models/switchtransformers/router_flax.py | 139 +---------- .../test_modeling_switchtransformers.py | 150 +++++++++++- 3 files changed, 306 insertions(+), 210 deletions(-) diff --git a/src/transformers/models/switchtransformers/router.py b/src/transformers/models/switchtransformers/router.py index f1f35e0834167..f87273c6497b8 100644 --- a/src/transformers/models/switchtransformers/router.py +++ b/src/transformers/models/switchtransformers/router.py @@ -12,20 +12,54 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass +from dataclasses import dataclass, replace from typing import Any, Optional, Tuple import torch import torch.nn as nn -from transformers.models.switchtransformers.configuration_switchtransformers import SwitchTransformersConfig +# from transformers.models.switchtransformers.configuration_switchtransformers import SwitchTransformersConfig -# Output classes +# Output classes RouterOutput = Any +def _jax_one_hot(tensor, num_classes, axis=-1, dtype=torch.bool): + r""" + This function mimics the behavior of jax.nn.functional.one_hot in PyTorch. It takes a tensor of indices, the number + of desired classes and the axis to apply the one-hot encoding. If any value is outside the range [0, num_classes), + it will be set to zeros. + """ + if tensor.is_floating_point(): + raise "Input tensor for one hot encoding must be an `int32` or `int64`" + + if axis >= len(tensor.shape): + raise "Axis is out of bounds" + + if axis == -1: + axis = len(tensor.shape) + elif axis < -1: + raise "Axis must be greater than -1" + else: + axis = axis + 1 + + # Get the final output shape + output_shape = list(tensor.shape) + output_shape.insert(axis, num_classes) + + # Create an empty output of zeros + out = torch.zeros(tuple(output_shape), dtype=dtype) + + # Mask out the places where it is outside the range [0, num_classes) + # kudos to GitHub copilot for this line + mask = (tensor >= 0) & (tensor < num_classes) + out[mask, tensor[mask]] = 1 + + return out + + @dataclass class RouterIndices: r""" @@ -216,14 +250,14 @@ def forward(self, token_inputs: torch.Tensor, expert_capacity: int, apply_jitter # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L315 # and # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L603. - padding_mask = torch.Tensor((jnp.sum(jnp.abs(token_inputs), axis=-1) > 0), dtype=token_inputs.dtype) - router_logits *= jnp.expand_dims(padding_mask, axis=-1) + padding_mask = torch.Tensor((torch.sum(torch.abs(token_inputs), axis=-1) > 0)).to(token_inputs.dtype) + router_logits *= padding_mask.unsqueeze(-1) else: padding_mask = None instructions = self._compute_routing_instructions(router_probs, padding_mask, expert_capacity) - return instructions.replace(router_z_loss=router_z_loss_func(router_logits)) + return replace(instructions, router_z_loss=router_z_loss_func(router_logits)) class MaskedRouter(Router): @@ -280,38 +314,38 @@ def _compute_routing_instructions( Dispatch and combine arrays for routing with masked matmuls. """ tokens_per_group = router_probs.shape[1] + default_dtype = router_probs.dtype if padding_mask is not None: # Because experts choose tokens, we mask probabilities corresponding to # tokens before the top-k operation. Note that, unlike for masked-based # tokens-choose routing, the experts here may still choose to select the # (down-weighted) padding tokens. - router_probs *= jnp.expand_dims(padding_mask, axis=-1) + router_probs *= padding_mask.unsqueeze(-1) # vmap over group dimension. - router_probs_t = jax.vmap(lambda m: m.transpose())(router_probs) + # router_probs_t = router_probs.t() + router_probs_t = router_probs.permute(0, 2, 1) # Top expert_capacity router probability and corresponding token indices for # each expert. Shapes: [num_groups, num_experts, expert_capacity]. - expert_gate, expert_index = _top_k(router_probs_t, k=expert_capacity) + expert_gate, expert_index = torch.topk(router_probs_t, k=expert_capacity) # Convert to one-hot mask of expert indices for each token in each group. # Shape: [num_groups, num_experts, expert_capacity, tokens_per_group]. - dispatch_mask = jax.nn.one_hot(expert_index, tokens_per_group, dtype=jnp.int32) + dispatch_mask = _jax_one_hot(expert_index, tokens_per_group, dtype=torch.int32) # Move axes to conform with shape expected by MoeLayer API. # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity] - dispatch_mask = jnp.moveaxis(dispatch_mask, 3, 1) + dispatch_mask = torch.moveaxis(dispatch_mask, 3, 1) # The combine array will be used for combining expert outputs, scaled by the # router probabilities. Shape: [num_groups, num_experts, tokens_per_group, # expert_capacity]. - combine_array = jnp.einsum( - "...ec,...tec->...tec", expert_gate, dispatch_mask, precision=jax.lax.Precision.DEFAULT - ) + combine_array = torch.einsum("...ec,...tec->...tec", expert_gate, dispatch_mask) # Return to default dtype now that router computation is complete. - combine_array = jax.lax.convert_element_type(combine_array, self.dtype) + combine_array = combine_array.to(default_dtype) # Each expert is choosing tokens until it reaches full capacity, so we don't # need an auxiliary loading balancing loss for expert choice routing. @@ -324,23 +358,18 @@ class TokensChooseMaskedRouter(MaskedRouter): """ Masked matmul router using tokens choose top-k experts assignment. - This router uses the same mechanism as in Switch Transformer - (https://arxiv.org/abs/2101.03961) and V-MoE - (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are - sorted by router_probs and then routed to their choice of expert until the - expert's expert_capacity is reached. There is no guarantee that each token is - processed by an expert, or that each expert receives at least one token. + This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE + (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then + routed to their choice of expert until the expert's expert_capacity is reached. There is no guarantee that each + token is processed by an expert, or that each expert receives at least one token. Attributes: num_selected_experts: Maximum number of experts to which each token is - routed. Tokens may be routed to fewer experts if particular experts are - oversubscribed / reach capacity. + routed. Tokens may be routed to fewer experts if particular experts are oversubscribed / reach capacity. batch_prioritized_routing: Whether or not to use Batch Prioritized Routing - (BPR), originally introduced in V-MoE (https://arxiv.org/abs/2106.05974). - With BPR, we prioritize routing those top-k tokens with the highest - router probability, rather than simply using each tokens left-to-right - ordering in the batch. This prioritization is important because the - experts have limited capacity. + (BPR), originally introduced in V-MoE (https://arxiv.org/abs/2106.05974). With BPR, we prioritize routing those + top-k tokens with the highest router probability, rather than simply using each tokens left-to-right ordering + in the batch. This prioritization is important because the experts have limited capacity. """ def __init__(self, config, **kwargs): @@ -372,16 +401,16 @@ def _compute_routing_instructions( if padding_mask is not None: # Mask applied to gate. Exclude choices corresponding to padding tokens. - gate_mask = jnp.expand_dims(padding_mask, axis=-1) + gate_mask = padding_mask.unsqueeze(-1).to(expert_index.dtype) expert_gate *= gate_mask # Set `expert_index` elements corresponding to padding to negative # numbers. Negative `expert_index` elements will ultimately be dropped in # the one_hot conversion to the `expert_mask`. # First convert nonzero padding elements to negative values. - expert_index *= 2 * gate_mask - 1.0 + expert_index *= (2 * gate_mask) - 1 # Handle zero padding elements by negatively shifting all padding. - expert_index += jnp.repeat(gate_mask - 1.0, self.num_selected_experts, axis=-1) + expert_index += (gate_mask - 1).repeat(1, 1, self.num_selected_experts) # To correctly compute load balancing loss, we also mask out probs. router_probs *= gate_mask @@ -391,9 +420,9 @@ def _compute_routing_instructions( if self.batch_prioritized_routing: # Sort tokens according to their routing probability per group, so that # the highest probability tokens are routed first. - permutation = jnp.argsort(-expert_gate[..., 0], axis=-1) + permutation = torch.argsort(-expert_gate[..., 0], dim=-1) # Shape: [num_groups, tokens_per_group, num_selected_experts] - expert_index = _take_along_axis(expert_index, jnp.expand_dims(permutation, axis=-1), axis=-2) + expert_index = torch.take_along_dim(expert_index, permutation.unsqueeze(-1), dim=-2) # Make num_selected_experts the leading axis to ensure that top-1 choices # have priority over top-2 choices, which have priority over top-3 choices, @@ -424,22 +453,23 @@ def _compute_routing_instructions( if self.batch_prioritized_routing: # Place token priorities in original ordering of tokens. - inv_permutation = jnp.argsort(permutation, axis=-1) - token_priority = _take_along_axis(token_priority, jnp.expand_dims(inv_permutation, axis=-1), axis=-2) + inv_permutation = torch.argsort(permutation, dim=-1) + token_priority = torch.take_along_dim(token_priority, inv_permutation.unsqueeze(-1), dim=-2) # Token T can only be routed to expert E if its priority is positive and # less than the expert capacity. One-hot matrix will ignore indices outside # the range [0, expert_capacity). # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]. - token_priority_mask = token_priority > 0 - token_priority = token_priority * token_priority_mask + # token_priority = token_priority * (token_priority > 0) - dispatch_mask = torch.nn.functional.one_hot(token_priority.long(), expert_capacity + 1)[..., 1:] + # dispatch_mask = torch.nn.functional.one_hot(token_priority.long(), expert_capacity + 1)[..., 1:] + dispatch_mask = _jax_one_hot(token_priority.long(), expert_capacity, axis=-1) # The combine array will be used for combining expert outputs, scaled by the # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, # expert_capacity]. combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) + # combine_array = torch.einsum("...te,...te->...te", router_probs, dispatch_mask) # Return to default dtype now that router computation is complete. combine_array = combine_array.to(torch.float32) @@ -447,43 +477,84 @@ def _compute_routing_instructions( return RouterMask(dispatch_mask, combine_array, auxiliary_loss) -num_groups = 2 -tokens_per_group = 3 -hidden_dim = 4 -num_experts = 2 -num_selected_experts = 1 # Switch routing case -expert_capacity = 1 # Total capacity = 2*2*1 = 4 < num_tokens -jitter_noise = 0.0 - -input_tokens = torch.Tensor( - [ - [ - [0.6433916, 0.18188512, 0.02240455, 0.563781], - [0.5526401, 0.0958724, 0.34253013, 0.03644359], - [0.08744538, 0.7909105, 0.35205448, 0.53364205], - ], - [ - [0.02900076, 0.4168595, 0.5802449, 0.91486526], - [0.27414513, 0.14991808, 0.9383501, 0.5209162], - [0.51207185, 0.90618336, 0.7309413, 0.95533276], - ], - ] -) - -config = SwitchTransformersConfig( - num_experts=num_experts, - hidden_size=hidden_dim, - num_selected_experts=num_selected_experts, - router_jitter_noise=jitter_noise, - expert_capacity=expert_capacity, - batch_prioritized_routing=False, -) -model = TokensChooseMaskedRouter(config) - -model.router_weights.weight = torch.nn.Parameter( - torch.Tensor( - [[0.02008116, 0.00620062], [-0.00811031, -0.00031623], [-0.03542127, 0.02703803], [0.02335377, -0.02971946]], - ).t() -) - -model(input_tokens, expert_capacity=expert_capacity) +# num_groups = 2 +# tokens_per_group = 4 +# hidden_dim = 3 +# num_experts = 2 +# expert_capacity = 2 # Total capacity = 2*2*1 = 4 < num_tokens +# jitter_noise = 0.0 + +# input_tokens = torch.Tensor( +# [[[0.6433916 , 0.18188512, 0.02240455], +# [0.563781 , 0.5526401 , 0.0958724 ], +# [0.34253013, 0.03644359, 0.08744538], +# [0.7909105 , 0.35205448, 0.53364205]], + +# [[0.02900076, 0.4168595 , 0.5802449 ], +# [0.91486526, 0.27414513, 0.14991808], +# [0.9383501 , 0.5209162 , 0.51207185], +# [0.90618336, 0.7309413 , 0.95533276]]] +# ) + +# config = SwitchTransformersConfig( +# num_experts=num_experts, +# hidden_size=hidden_dim, +# router_jitter_noise=jitter_noise, +# expert_capacity=expert_capacity, +# batch_prioritized_routing=False, +# ) +# # model = TokensChooseMaskedRouter(config) +# model = ExpertsChooseMaskedRouter(config) + +# model.router_weights.weight = torch.nn.Parameter( +# torch.Tensor([[-0.00107201, 0.01544739], +# [-0.0087319 , 0.01314363], +# [ 0.03530733, 0.03709853]]).t() +# ) + +# model(input_tokens, expert_capacity=expert_capacity) + + +# hidden_dim = 4 +# num_experts = 2 +# num_selected_experts = 1 # Switch routing case +# expert_capacity = 1 # Total capacity = 2*2*1 = 4 < num_tokens +# jitter_noise = 0.0 + +# input_tokens = torch.Tensor( +# [ +# [ +# [0.6433916, 0.18188512, 0.02240455, 0.563781], +# [0.5526401, 0.0958724, 0.34253013, 0.03644359], +# [0.08744538, 0.7909105, 0.35205448, 0.53364205], +# ], +# [ +# [0.02900076, 0.4168595, 0.5802449, 0.91486526], +# [0.27414513, 0.14991808, 0.9383501, 0.5209162], +# [0.51207185, 0.90618336, 0.7309413, 0.95533276], +# ], +# ] +# ) + +# config = SwitchTransformersConfig( +# num_experts=num_experts, +# hidden_size=hidden_dim, +# num_selected_experts=num_selected_experts, +# router_jitter_noise=jitter_noise, +# expert_capacity=expert_capacity, +# batch_prioritized_routing=False, +# ) +# model = TokensChooseMaskedRouter(config) + +# model.router_weights.weight = torch.nn.Parameter( +# torch.Tensor( +# [ +# [0.02008116, 0.00620062], +# [-0.00811031, -0.00031623], +# [-0.03542127, 0.02703803], +# [0.02335377, -0.02971946], +# ], +# ).t() +# ) + +# output = model(input_tokens, expert_capacity=expert_capacity) diff --git a/src/transformers/models/switchtransformers/router_flax.py b/src/transformers/models/switchtransformers/router_flax.py index 3c37d3e938627..2a7480bb5ba35 100644 --- a/src/transformers/models/switchtransformers/router_flax.py +++ b/src/transformers/models/switchtransformers/router_flax.py @@ -485,124 +485,6 @@ def _compute_routing_instructions( return RouterIndices(dispatch_indices, combine_weights, auxiliary_loss) -class TokensChooseMaskedRouter(MaskedRouter): - """Masked matmul router using tokens choose top-k experts assignment. - - This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE - (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then - routed to their choice of expert until the expert's expert_capacity is reached. There is no guarantee that each - token is processed by an expert, or that each expert receives at least one token. - - Attributes: - num_selected_experts: Maximum number of experts to which each token is - routed. Tokens may be routed to fewer experts if particular experts are oversubscribed / reach capacity. - batch_prioritized_routing: Whether or not to use Batch Prioritized Routing - (BPR), originally introduced in V-MoE (https://arxiv.org/abs/2106.05974). - With BPR, we prioritize routing those top-k tokens with the highest router probability, rather than simply - using each tokens left-to-right ordering in the batch. This prioritization is important because the experts - have limited capacity. - """ - - num_selected_experts: int - batch_prioritized_routing: bool - - def _compute_routing_instructions( - self, router_probs: Array, padding_mask: Optional[Array], expert_capacity: int - ) -> RouterMask: - """Computes masks for the top-k experts per token. - - Args: - router_probs: [num_groups, tokens_per_group, num_experts] - probabilities used to determine the routing of tokens to the experts. - padding_mask: [num_groups, tokens_per_group] padding logit mask - used to identify padding tokens that should be ignored by the router. - expert_capacity: Each group will send this many tokens to each expert. - - Returns: - Dispatch and combine arrays for routing with masked matmuls. - """ - num_groups, _, num_experts = router_probs.shape - - # Top-k router probability and corresponding expert indices for each token. - # Shape: [num_groups, tokens_per_group, num_selected_experts]. - expert_gate, expert_index = _top_k(router_probs, k=self.num_selected_experts) - - if padding_mask is not None: - # Mask applied to gate. Exclude choices corresponding to padding tokens. - gate_mask = jnp.expand_dims(padding_mask, axis=-1) - expert_gate *= gate_mask - - # Set `expert_index` elements corresponding to padding to negative - # numbers. Negative `expert_index` elements will ultimately be dropped in - # the one_hot conversion to the `expert_mask`. - # First convert nonzero padding elements to negative values. - expert_index *= 2 * gate_mask - 1.0 - # Handle zero padding elements by negatively shifting all padding. - expert_index += jnp.repeat(gate_mask - 1.0, self.num_selected_experts, axis=-1) - - # To correctly compute load balancing loss, we also mask out probs. - router_probs *= gate_mask - - auxiliary_loss = _load_balancing_loss(router_probs, expert_index) - - if self.batch_prioritized_routing: - # Sort tokens according to their routing probability per group, so that - # the highest probability tokens are routed first. - permutation = jnp.argsort(-expert_gate[..., 0], axis=-1) - # Shape: [num_groups, tokens_per_group, num_selected_experts] - expert_index = _take_along_axis(expert_index, jnp.expand_dims(permutation, axis=-1), axis=-2) - - # Make num_selected_experts the leading axis to ensure that top-1 choices - # have priority over top-2 choices, which have priority over top-3 choices, - # etc. - expert_index = jnp.swapaxes(expert_index, 1, 2) - # Shape: [num_groups, num_selected_experts * tokens_per_group] - expert_index = expert_index.reshape(num_groups, -1) - - # Create mask out of indices. - # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. - expert_mask = jax.nn.one_hot(expert_index, num_experts, dtype=jnp.int32) - - # Experts have a fixed capacity that we cannot exceed. A token's priority - # within the expert's buffer is given by the masked, cumulative capacity of - # its target expert. - # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. - token_priority = jnp.cumsum(expert_mask, axis=1) * expert_mask - 1.0 - # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts]. - token_priority = token_priority.reshape((num_groups, self.num_selected_experts, -1, num_experts)) - # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. - token_priority = jnp.swapaxes(token_priority, 1, 2) - # For each token, across all selected experts, select the only non-negative - # (unmasked) priority. Now, for group G routing to expert E, token T has - # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E - # is its targeted expert. - # Shape: [num_groups, tokens_per_group, num_experts]. - token_priority = jnp.max(token_priority, axis=2) - - if self.batch_prioritized_routing: - # Place token priorities in original ordering of tokens. - inv_permutation = jnp.argsort(permutation, axis=-1) - token_priority = _take_along_axis(token_priority, jnp.expand_dims(inv_permutation, axis=-1), axis=-2) - - # Token T can only be routed to expert E if its priority is positive and - # less than the expert capacity. One-hot matrix will ignore indices outside - # the range [0, expert_capacity). - # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]. - dispatch_mask = jax.nn.one_hot(token_priority, expert_capacity, dtype=jnp.bool_) - - # The combine array will be used for combining expert outputs, scaled by the - # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, - # expert_capacity]. - combine_array = jnp.einsum( - "...te,...tec->...tec", router_probs, dispatch_mask, precision=jax.lax.Precision.DEFAULT - ) - - # Return to default dtype now that router computation is complete. - combine_array = jax.lax.convert_element_type(combine_array, self.dtype) - - return RouterMask(dispatch_mask, combine_array, auxiliary_loss) - - class ExpertsChooseMaskedRouter(MaskedRouter): """Masked matmul router using experts choose tokens assignment. @@ -674,23 +556,18 @@ class TokensChooseMaskedRouter(MaskedRouter): """ Masked matmul router using tokens choose top-k experts assignment. - This router uses the same mechanism as in Switch Transformer - (https://arxiv.org/abs/2101.03961) and V-MoE - (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are - sorted by router_probs and then routed to their choice of expert until the - expert's expert_capacity is reached. There is no guarantee that each token is - processed by an expert, or that each expert receives at least one token. + This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE + (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then + routed to their choice of expert until the expert's expert_capacity is reached. There is no guarantee that each + token is processed by an expert, or that each expert receives at least one token. Attributes: num_selected_experts: Maximum number of experts to which each token is - routed. Tokens may be routed to fewer experts if particular experts are - oversubscribed / reach capacity. + routed. Tokens may be routed to fewer experts if particular experts are oversubscribed / reach capacity. batch_prioritized_routing: Whether or not to use Batch Prioritized Routing - (BPR), originally introduced in V-MoE (https://arxiv.org/abs/2106.05974). - With BPR, we prioritize routing those top-k tokens with the highest - router probability, rather than simply using each tokens left-to-right - ordering in the batch. This prioritization is important because the - experts have limited capacity. + (BPR), originally introduced in V-MoE (https://arxiv.org/abs/2106.05974). With BPR, we prioritize routing those + top-k tokens with the highest router probability, rather than simply using each tokens left-to-right ordering + in the batch. This prioritization is important because the experts have limited capacity. """ num_selected_experts: int diff --git a/tests/models/switchtransformers/test_modeling_switchtransformers.py b/tests/models/switchtransformers/test_modeling_switchtransformers.py index 74c306f5621a0..cba49ba4d123c 100644 --- a/tests/models/switchtransformers/test_modeling_switchtransformers.py +++ b/tests/models/switchtransformers/test_modeling_switchtransformers.py @@ -37,7 +37,12 @@ from transformers.models.switchtransformers.modeling_switchtransformers import ( SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, ) - from transformers.models.switchtransformers.router import load_balancing_loss_func, router_z_loss_func + from transformers.models.switchtransformers.router import ( + ExpertsChooseMaskedRouter, + TokensChooseMaskedRouter, + load_balancing_loss_func, + router_z_loss_func, + ) class SwitchTransformersModelTester: @@ -917,3 +922,146 @@ def test_equivalency_router_z_loss(self): loss = router_z_loss_func(logits) self.assertAlmostEqual(loss.item(), 13.786719, places=5) + + def test_equivalency_token_chose_masked_router(self): + r""" + This test tests the equivalency between the `TokensChooseMaskedRouter` + originally implemented from here: TODO: provide link + """ + hidden_dim = 4 + num_experts = 2 + num_selected_experts = 1 # Switch routing case + expert_capacity = 1 # Total capacity = 2*2*1 = 4 < num_tokens + jitter_noise = 0.0 + + input_tokens = torch.Tensor( + [ + [ + [0.6433916, 0.18188512, 0.02240455, 0.563781], + [0.5526401, 0.0958724, 0.34253013, 0.03644359], + [0.08744538, 0.7909105, 0.35205448, 0.53364205], + ], + [ + [0.02900076, 0.4168595, 0.5802449, 0.91486526], + [0.27414513, 0.14991808, 0.9383501, 0.5209162], + [0.51207185, 0.90618336, 0.7309413, 0.95533276], + ], + ] + ) + + config = SwitchTransformersConfig( + num_experts=num_experts, + hidden_size=hidden_dim, + num_selected_experts=num_selected_experts, + router_jitter_noise=jitter_noise, + expert_capacity=expert_capacity, + batch_prioritized_routing=False, + ) + model = TokensChooseMaskedRouter(config) + + model.router_weights.weight = torch.nn.Parameter( + torch.Tensor( + [ + [0.02008116, 0.00620062], + [-0.00811031, -0.00031623], + [-0.03542127, 0.02703803], + [0.02335377, -0.02971946], + ], + ).t() + ) + + output = model(input_tokens, expert_capacity=expert_capacity) + + expected_dispatch_mask = torch.Tensor( + [ + [[[True], [False]], [[False], [True]], [[False], [False]]], + [[[True], [False]], [[False], [True]], [[False], [False]]], + ] + ) + + expected_combine_array = torch.Tensor( + [ + [[[0.5090], [0.0000]], [[0.0000], [0.5031]], [[0.0000], [0.0000]]], + [[[0.5024], [0.0000]], [[0.0000], [0.5071]], [[0.0000], [0.0000]]], + ] + ) + + self.assertAlmostEqual(output.auxiliary_loss.item(), 1.000308, places=5) + self.assertAlmostEqual(output.router_z_loss.item(), 0.4789799, places=5) + + self.assertTrue(torch.allclose(output.dispatch_mask, expected_dispatch_mask)) + self.assertTrue(torch.allclose(output.combine_array, expected_combine_array)) + + def test_equivalency_experts_chose_masked_router(self): + r""" + This test tests the equivalency between the `ExpertsChooseMaskedRouter` + originally implemented from here: TODO: provide link + """ + hidden_dim = 3 + num_experts = 2 + expert_capacity = 2 # Total capacity = 2*2*1 = 4 < num_tokens + jitter_noise = 0.0 + + input_tokens = torch.Tensor( + [ + [ + [0.6433916, 0.18188512, 0.02240455], + [0.563781, 0.5526401, 0.0958724], + [0.34253013, 0.03644359, 0.08744538], + [0.7909105, 0.35205448, 0.53364205], + ], + [ + [0.02900076, 0.4168595, 0.5802449], + [0.91486526, 0.27414513, 0.14991808], + [0.9383501, 0.5209162, 0.51207185], + [0.90618336, 0.7309413, 0.95533276], + ], + ] + ) + + config = SwitchTransformersConfig( + num_experts=num_experts, + hidden_size=hidden_dim, + router_jitter_noise=jitter_noise, + expert_capacity=expert_capacity, + batch_prioritized_routing=False, + ) + + model = ExpertsChooseMaskedRouter(config) + + model.router_weights.weight = torch.nn.Parameter( + torch.Tensor([[-0.00107201, 0.01544739], [-0.0087319, 0.01314363], [0.03530733, 0.03709853]]).t() + ) + + output = model(input_tokens, expert_capacity=expert_capacity) + + self.assertEqual(output.auxiliary_loss, 0.0) + self.assertAlmostEqual(output.router_z_loss.item(), 0.507016, places=5) + + expected_dispatch_mask = torch.Tensor( + [ + [[[0, 1], [0, 0]], [[0, 0], [0, 1]], [[1, 0], [0, 0]], [[0, 0], [1, 0]]], + [[[1, 0], [0, 0]], [[0, 1], [0, 0]], [[0, 0], [0, 1]], [[0, 0], [1, 0]]], + ] + ) + + self.assertTrue(torch.allclose(output.dispatch_mask, expected_dispatch_mask)) + + expected_combined_array = torch.Tensor( + [ + [ + [[0.0000, 0.4963], [0.0000, 0.0000]], + [[0.0000, 0.0000], [0.0000, 0.5054]], + [[0.4983, 0.0000], [0.0000, 0.0000]], + [[0.0000, 0.0000], [0.5054, 0.0000]], + ], + [ + [[0.4973, 0.0000], [0.0000, 0.0000]], + [[0.0000, 0.4947], [0.0000, 0.0000]], + [[0.0000, 0.0000], [0.0000, 0.5070]], + [[0.0000, 0.0000], [0.5082, 0.0000]], + ], + ] + ) + + self.assertTrue(torch.allclose(output.combine_array, expected_combined_array)) From 2eea820946338d4ffb65e733786270082e11d8e8 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 6 Oct 2022 18:47:54 +0100 Subject: [PATCH 009/102] last router --- .../models/switchtransformers/router.py | 188 +++++++++++++++++- 1 file changed, 186 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/switchtransformers/router.py b/src/transformers/models/switchtransformers/router.py index f87273c6497b8..c5fb4a78c4816 100644 --- a/src/transformers/models/switchtransformers/router.py +++ b/src/transformers/models/switchtransformers/router.py @@ -125,7 +125,7 @@ def router_z_loss_func(router_logits: torch.Tensor) -> float: return torch.sum(z_loss) / (num_groups * tokens_per_group) -def load_balancing_loss_fun(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float: +def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float: r""" Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. @@ -415,7 +415,7 @@ def _compute_routing_instructions( # To correctly compute load balancing loss, we also mask out probs. router_probs *= gate_mask - auxiliary_loss = load_balancing_loss_fun(router_probs, expert_index) + auxiliary_loss = load_balancing_loss_func(router_probs, expert_index) if self.batch_prioritized_routing: # Sort tokens according to their routing probability per group, so that @@ -477,6 +477,149 @@ def _compute_routing_instructions( return RouterMask(dispatch_mask, combine_array, auxiliary_loss) +class ScatterRouter(Router): + """Abstract base router class for scatter dispatch routers. + + ScatterRouter(s) return RouterIndices containing dispatch indices and combine weights for sending token inputs (via + scatter) and receiving outputs (via gather) to and from experts. + + Scatter-based routing is generally faster than masked matmul routing on CPUs and GPUs. + """ + + def _compute_routing_instructions( + self, router_probs: torch.Tensor, padding_mask: Optional[torch.Tensor], expert_capacity: int + ) -> RouterIndices: + """Computes instructions for routing inputs to experts. + + Args: + router_probs: [num_groups, tokens_per_group, num_experts] + probabilities used to determine the routing of tokens to the experts. + padding_mask: [num_groups, tokens_per_group] padding logit mask + used to identify padding tokens that should be ignored by the router. + expert_capacity: Each group will send this many tokens to each expert. + + Returns: + Router indices containing dispatch indices and combine weights. + """ + raise NotImplementedError("ScatterRouter is an abstract class that should be subclassed.") + + +class TokensChooseScatterRouter(ScatterRouter): + """Scatter router using tokens choose top-k experts assignment. + + This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE + (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then + routed to their choice of expert until the expert's expert_capacity is reached. There is no guarantee that each + token is processed by an expert, or that each expert receives at least one token. + + Attributes: + num_selected_experts: Maximum number of experts to which each token is + routed. Tokens may be routed to fewer experts if particular experts are oversubscribed / reach capacity. + batch_prioritized_routing: Whether or not to use Batch Prioritized Routing + (BPR), originally introduced in V-MoE (https://arxiv.org/abs/2106.05974). + With BPR, we prioritize routing those top-k tokens with the highest router probability, rather than simply + using each tokens left-to-right ordering in the batch. This prioritization is important because the expert's + have limited capacity. + """ + + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self.num_selected_experts = config.num_selected_experts + self.batch_prioritized_routing = config.batch_prioritized_routing + + def _compute_routing_instructions( + self, router_probs: torch.Tensor, padding_mask: Optional[torch.Tensor], expert_capacity: int + ) -> RouterIndices: + """Computes dispatch indices and combine weights for the top-k experts. + + Args: + router_probs: [num_groups, tokens_per_group, num_experts] + probabilities used to determine the routing of tokens to the experts. + padding_mask: [num_groups, tokens_per_group] padding logit mask + used to identify padding tokens that should be down-weighted by the router. + expert_capacity: Each group will send this many tokens to each expert. + + Returns: + Dispatch indices and combine weights for scatter/gather-based routing. + """ + original_dtype = router_probs.dtype + num_groups, tokens_per_group, num_experts = router_probs.shape + + if padding_mask is not None: + # Because experts choose tokens, we mask probabilities corresponding to + # tokens before the top-k operation. Note that, unlike for masked-based + # tokens-choose routing, the experts here may still choose to select the + # (down-weighted) padding tokens. + router_probs *= padding_mask.unsqueeze(-1) + + # Top-k router probability and corresponding expert indices for each token. + # Shape: [num_groups, tokens_per_group, num_selected_experts]. + combine_weights, expert_indices = torch.topk(router_probs, k=self.num_selected_experts) + + auxiliary_loss = load_balancing_loss_func(router_probs, expert_indices) + + if self.batch_prioritized_routing: + # Sort tokens according to their routing probability per group, so that + # the highest probability tokens are routed first. + token_ordering = torch.argsort(-combine_weights[..., 0], dim=-1) + # Shape: [num_groups, tokens_per_group, num_selected_experts] + expert_indices = torch.take_along_dim(expert_indices, token_ordering.unsqueeze(-1), dim=-2) + + # Identify each token's preferred expert. + # Make num_selected_experts the leading axis to ensure that top-1 choices + # have priority over top-2 choices, which have priority over top-3 + # choices... + preferred_experts = expert_indices.permute(0, 2, 1) + # Shape: [num_groups, num_selected_experts * tokens_per_group] + preferred_experts = preferred_experts.reshape(num_groups, -1) + + # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. + expert_mask = _jax_one_hot(preferred_experts, num_experts, dtype=torch.int32) + + # Experts have a fixed capacity that we cannot exceed. A token's priority + # within the expert's buffer is given by the masked, cumulative capacity of + # its target expert. + # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. + token_priority = torch.cumsum(expert_mask, axis=1) * expert_mask - 1.0 + # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts]. + token_priority = token_priority.reshape((num_groups, self.num_selected_experts, -1, num_experts)) + # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. + token_priority = token_priority.permute((0, 2, 1, 3)) + # For each token, across all selected experts, select the only non-negative + # (unmasked) priority. Now, for group G routing to expert E, token T has + # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E + # is its targeted expert. + # Shape: [num_groups, tokens_per_group, num_experts]. + token_priority = torch.max(token_priority, axis=-1).values + + # Return to original index shape. + preferred_experts = preferred_experts.reshape(num_groups, self.num_selected_experts, tokens_per_group) + # Shape: [num_groups, tokens_per_group, num_selected_experts] + preferred_experts = preferred_experts.permute(0, 2, 1) + + if self.batch_prioritized_routing: + # Place token priorities in original ordering of tokens. + inv_permutation = torch.argsort(token_ordering, dim=-1) + preferred_experts = torch.take_along_dim( + preferred_experts.unsqueeze(-1), inv_permutation.unsqueeze(-1), dim=-2 + ) + token_priority = torch.take_along_dim(token_priority.unsqueeze(-1), inv_permutation.unsqueeze(-1), dim=-2) + + # Mask out tokens that overflow the maximum expert capacities. + # Shape: [num_groups, tokens_per_group, num_selected_experts]. + combine_weights *= token_priority < expert_capacity + + # Expert index and priority within the expert capacity buffer. + # Shape: [num_groups, tokens_per_group, num_selected_experts, 2]. + dispatch_indices = torch.stack([preferred_experts, token_priority], dim=-1) + + # Return to default dtype now that router computation is complete. + combine_weights = combine_weights.to(original_dtype) + dispatch_indices = dispatch_indices.to(torch.int32) + + return RouterIndices(dispatch_indices, combine_weights, auxiliary_loss) + + # num_groups = 2 # tokens_per_group = 4 # hidden_dim = 3 @@ -558,3 +701,44 @@ def _compute_routing_instructions( # ) # output = model(input_tokens, expert_capacity=expert_capacity) + + +# num_groups = 2 +# tokens_per_group = 4 +# hidden_dim = 3 +# num_experts = 3 +# num_selected_experts = 1 +# expert_capacity = 2 +# jitter_noise = 0.0 + +# input_tokens = torch.Tensor( +# [[[0.6433916 , 0.18188512, 0.02240455], +# [0.563781 , 0.5526401 , 0.0958724 ], +# [0.34253013, 0.03644359, 0.08744538], +# [0.7909105 , 0.35205448, 0.53364205]], + +# [[0.02900076, 0.4168595 , 0.5802449 ], +# [0.91486526, 0.27414513, 0.14991808], +# [0.9383501 , 0.5209162 , 0.51207185], +# [0.90618336, 0.7309413 , 0.95533276]]] +# ) + +# config = SwitchTransformersConfig( +# num_experts=num_experts, +# hidden_size=hidden_dim, +# num_selected_experts=num_selected_experts, +# router_jitter_noise=jitter_noise, +# expert_capacity=expert_capacity, +# batch_prioritized_routing=False, +# ) +# model = TokensChooseScatterRouter(config) + +# model.router_weights.weight = torch.nn.Parameter( +# torch.Tensor( +# [[ 0.02736656, -0.00253537, 0.04682618], +# [ 0.00928149, 0.04933621, -0.00275501], +# [ 0.00751786, 0.04295348, -0.00503795]], +# ).t() +# ) + +# output = model(input_tokens, expert_capacity=expert_capacity) From a65c7e47307e80e632443409cb9879f94da76f29 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Sat, 8 Oct 2022 19:58:52 +0100 Subject: [PATCH 010/102] improved docstring - completed the docstring in `router.py` - added more args in the config --- .../configuration_switchtransformers.py | 8 + .../models/switchtransformers/router.py | 250 +++++++++++------- 2 files changed, 164 insertions(+), 94 deletions(-) diff --git a/src/transformers/models/switchtransformers/configuration_switchtransformers.py b/src/transformers/models/switchtransformers/configuration_switchtransformers.py index 8a58bfdf29341..54953c033cfc9 100644 --- a/src/transformers/models/switchtransformers/configuration_switchtransformers.py +++ b/src/transformers/models/switchtransformers/configuration_switchtransformers.py @@ -66,6 +66,9 @@ class SwitchTransformersConfig(PretrainedConfig): Amount of noise to add to the router. router_ignore_padding_tokens (`bool`, *optional*, defaults to `False`): Whether to ignore padding tokens when routing. + router_dtype (`str`, *optional*, default to `float32`): + The `dtype` used for the routers. It is preferable to keep the `dtype` to `float32` as specified in the + "selective precision" discussion in https://arxiv.org/abs/2101.03961. batch_prioritized_routing (`bool`, *optional*, defaults to `False`): Whether to use batch prioritized routing. num_selected_experts (`int`, *optional*, defaults to 2): @@ -103,6 +106,7 @@ def __init__( num_experts=8, router_bias=False, router_jitter_noise=0.01, + router_dtype="float32", num_selected_experts=2, router_ignore_padding_tokens=False, batch_prioritized_routing=False, @@ -130,6 +134,10 @@ def __init__( self.num_experts = num_experts self.router_bias = router_bias self.router_jitter_noise = router_jitter_noise + self.router_dtype = router_dtype + + if router_dtype not in ["float16", "float32", "bfloat16"]: + raise ValueError("""Please select a correct `router_dtype` from ["float16", "float32", "bfloat16"].""") self.router_ignore_padding_tokens = router_ignore_padding_tokens self.relative_attention_num_buckets = relative_attention_num_buckets diff --git a/src/transformers/models/switchtransformers/router.py b/src/transformers/models/switchtransformers/router.py index c5fb4a78c4816..b08ea57238e85 100644 --- a/src/transformers/models/switchtransformers/router.py +++ b/src/transformers/models/switchtransformers/router.py @@ -31,6 +31,16 @@ def _jax_one_hot(tensor, num_classes, axis=-1, dtype=torch.bool): This function mimics the behavior of jax.nn.functional.one_hot in PyTorch. It takes a tensor of indices, the number of desired classes and the axis to apply the one-hot encoding. If any value is outside the range [0, num_classes), it will be set to zeros. + + Args: + tensor (`torch.Tensor`): + Input tensor + num_classes (`int`): + Number of classes to process for one hot encoding + axis (`int`, *optional*): + The lookup axis to check for one-hot encoding + dtype (`torch.dtype`, *optional*): + Output `dtype`. The one hot encoded vector will be casted to this dtype """ if tensor.is_floating_point(): raise "Input tensor for one hot encoding must be an `int32` or `int64`" @@ -63,17 +73,19 @@ def _jax_one_hot(tensor, num_classes, axis=-1, dtype=torch.bool): @dataclass class RouterIndices: r""" - Dispatch indices and combine weights for scatter/gather-based routing. + A dataclass wrapper to store the dispatch indices and combine weights for scatter/gather-based routing. Attributes: - dispatch_indices: [num_groups, tokens_per_group, - num_selected_experts, 2] dispatch indices indicating, for each token, its preferred expert and its priority in - that expert's buffer. - combine_weights: [num_groups, tokens_per_group, num_selected_experts] - combine weights used for scaling expert outputs with the router's dispatch probability/confidence. - auxiliary_loss: Load balancing loss for router. - router_z_loss: Router z-loss. Encourages router logits to remain small in an - effort to improve stability. + dispatch_indices (`torch.Tensor`): + A tensor of size [`num_groups`, `tokens_per_group`, `num_selected_experts`, 2] dispatch indices indicating, + for each token, its preferred expert and its priority in that expert's buffer. + combine_weights (`torch.Tensor`): + A tensor of size [num_groups, tokens_per_group, num_selected_experts] combine weights used for scaling + expert outputs with the router's dispatch probability/confidence. + auxiliary_loss (`float`): + Load balancing loss for router. + router_z_loss (`float`): + Router z-loss. Encourages router logits to remain small in an effort to improve stability. """ dispatch_indices: torch.Tensor combine_weights: torch.Tensor @@ -87,14 +99,16 @@ class RouterMask: Dispatch and combine torch.Tensors for expert routing with masked matmuls. Attributes: - dispatch_mask: [num_groups, tokens_per_group, num_experts, - expert_capacity] dispatch torch.Tensor that is 1 if the token gets routed to the corresponding expert, and 0 - otherwise. - combine_torch.Tensor: [num_groups, tokens_per_group, num_experts, - expert_capacity] combine torch.Tensor used for combining expert outputs and scaling with router probability. - auxiliary_loss: Load balancing loss for router. - router_z_loss: Router z-loss. Encourages router logits to remain small in an - effort to improve stability. + dispatch_mask (`torch.Tensor`): + A mask tensor of shape [num_groups, tokens_per_group, num_experts, expert_capacity] that is 1 if the token + gets routed to the corresponding expert, and 0 otherwise. + combine_array (`torch.Tensor`): + A tensor of shape [num_groups, tokens_per_group, num_experts, expert_capacity] combine torch.Tensor used + for combining expert outputs and scaling with router probability. + auxiliary_loss (`float`): + Load balancing loss for router. + router_z_loss (`float`): + Router z-loss. Encourages router logits to remain small in an effort to improve stability. """ dispatch_mask: torch.Tensor combine_array: torch.Tensor @@ -113,11 +127,11 @@ def router_z_loss_func(router_logits: torch.Tensor) -> float: encourages router logits to remain small in an effort to improve stability. Args: - router_logits: [num_groups, tokens_per_group, num_experts] router - logits. + router_logits (`float`): + Input logits of shape [num_groups, tokens_per_group, num_experts] Returns: - Scalar router z-loss. + Scalar router z-loss. """ num_groups, tokens_per_group, _ = router_logits.shape log_z = torch.logsumexp(router_logits, dim=-1) @@ -133,10 +147,11 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T equations (4) - (6). It aims to penalize those cases where the routing between experts is unbalanced. Args: - router_probs: Probability assigned to each expert per token. Shape: - [num_groups, tokens_per_group, num_experts]. - expert_indices: [num_groups, tokens_per_group, num_selected_experts] - indices identifying the top num_selected_experts for a given token. + router_probs (`torch.Tensor`): + Probability assigned to each expert per token. Shape: [num_groups, tokens_per_group, num_experts]. + expert_indices (`torch.Tensor`): + Indices tensor of shape [num_groups, tokens_per_group, num_selected_experts] identifying the top + num_selected_experts for a given token. Returns: The auxiliary loss. @@ -169,15 +184,17 @@ class Router(nn.Module): Abstract base router class, defining router API and inner workings. Attributes: - router_weights: Configurable module used to compute router logits from token - inputs. - jitter_noise: Amplitude of jitter noise applied to router logits. - dtype: Numeric float type for returned combine torch.Tensor. All actual - computations are performed in float32 of the input for stability. - ignore_padding_tokens: Whether to ignore padding tokens during routing. Note - that some routers (e.g. TokensChooseMaskedRouter) will completely ignore padding tokens, while others (e.g. - TokensChooseScatterRouter and ExpertsChooseMaskedRouter) will simply down-weight the probability of selecting - padding tokens. + router_weights (`torch.nn.Module`): + Configurable module used to compute router logits from token inputs. + jitter_noise (`float`): + Amplitude of jitter noise applied to router logits. + dtype (`torch.dtype`): + Numeric float type for returned combine torch.Tensor. All actual computations are performed in float32 of + the input for stability. + ignore_padding_tokens (`bool`): + Whether to ignore padding tokens during routing. Note that some routers (e.g. `TokensChooseMaskedRouter`) + will completely ignore padding tokens, while others (e.g. `TokensChooseScatterRouter` and + `ExpertsChooseMaskedRouter`) will simply down-weight the probability of selecting padding tokens. """ def __init__(self, config, **kwargs): @@ -186,6 +203,7 @@ def __init__(self, config, **kwargs): self.router_weights = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias) self.jitter_noise = config.router_jitter_noise self.ignore_padding_tokens = config.router_ignore_padding_tokens + self.dtype = getattr(torch, config.router_dtype) def _compute_router_probabilities( self, token_inputs: torch.Tensor, num_experts: int, apply_jitter: bool @@ -194,21 +212,27 @@ def _compute_router_probabilities( Computes router probabilities from input tokens. Args: - token_inputs: [num_groups, tokens_per_group, hidden_dim] from which - router probabilities are computed. - num_experts: Number of experts. - apply_jitter: If true, apply jitter noise. + token_inputs (`torch.Tensor`): + [num_groups, tokens_per_group, hidden_dim] from which router probabilities are computed. + num_experts (`int`): + Number of experts. + apply_jitter (`bool`): + If true, apply jitter noise. Returns: - - [num_groups, tokens_per_group, num_experts] probabilities for - each token and expert. Used for routing tokens to experts. - - [num_groups, tokens_per_group, num_experts] raw router logits. - Used for computing router z-loss. + router_probabilities (`torch.Tensor`): + Tensor of shape [num_groups, tokens_per_group, num_experts] corresponding to the probabilities for each + token and expert. Used for routing tokens to experts. + router_logits (`torch.Tensor`): + Logits tensor of shape [num_groups, tokens_per_group, num_experts] corresponding to raw router logits. + This is used later for computing router z-loss. """ # For remainder of routing computation we use float32 to ensure stability. # See the discussion of "selective precision" in # https://arxiv.org/abs/2101.03961. - token_inputs = token_inputs.to(torch.float32) + # We also store the previous dtype to cast back the output to the previous dtype + self.input_tokens_dtype = token_inputs.dtype + token_inputs = token_inputs.to(self.dtype) if apply_jitter and self.jitter_noise > 0: # Get the lower and upper bound of the uniform distribution @@ -230,8 +254,18 @@ def _compute_router_probabilities( return router_probabilities, router_logits - def forward(self, token_inputs: torch.Tensor, expert_capacity: int, apply_jitter: bool = True) -> RouterOutput: + def forward( + self, token_inputs: torch.Tensor, expert_capacity: int, apply_jitter: bool = True, **kwargs + ) -> RouterOutput: r""" + Generic forward function for every Router class. Each Router expects to have the same input hidden states + (`token_inputs`) corresponding to the hidden states for each token, the `expert_capacity` corresponding to the + number of tokens the Router will send to each expert, some Routers can send up to few tokens to each expert. + + Each Router works as the following: it expects the hidden states for each token, gets the `router_probs` and + `router_logits` from the `router_weights`. This will assign for each token, the raw probability to be assigned + to an expert. Then each Router class will have to define its own `_compute_routing_instructions`. + Args: Computes dispatch and combine torch.Tensors for routing to experts. token_inputs: [num_groups, tokens_per_group, hidden_dim] inputs to send to experts. num_experts: @@ -255,13 +289,24 @@ def forward(self, token_inputs: torch.Tensor, expert_capacity: int, apply_jitter else: padding_mask = None - instructions = self._compute_routing_instructions(router_probs, padding_mask, expert_capacity) + instructions = self._compute_routing_instructions(router_probs, padding_mask, expert_capacity, **kwargs) + # We cast back the output to the previous dtype + instructions = instructions.to(self.input_tokens_dtype) return replace(instructions, router_z_loss=router_z_loss_func(router_logits)) + def _compute_routing_instructions(self, router_probs, padding_mask, expert_capacity): + raise NotImplementedError( + """ + The forward function cannot be called from the `Router` super-class. Please call an appropriate Router + class that inherits from the `Router` class (for example `ExpertsChooseMaskedRouter`) + """ + ) + class MaskedRouter(Router): - """Abstract base router class for masked matmul dispatch routers. + """ + Abstract base router class for masked matmul dispatch routers. MaskedRouter(s) return RouterMask(s) containing a dispatch mask and combine array for sending and receiving (via masked matmuls) inputs and outputs to and from experts. @@ -272,14 +317,19 @@ class MaskedRouter(Router): def _compute_routing_instructions( self, router_probs: torch.Tensor, padding_mask: Optional[torch.Tensor], expert_capacity: int ) -> RouterMask: - """Computes masks for the top-k experts per token. + """ + Computes masks for the top-k experts per token. This has to be implemented for each subclass of MaskedRouter + routers. Args: - router_probs: [num_groups, tokens_per_group, num_experts] - probabilities used to determine the routing of tokens to the experts. - padding_mask: [num_groups, tokens_per_group] padding logit mask - used to identify padding tokens that should be ignored by the router. - expert_capacity: Each group will send this many tokens to each expert. + router_probs (`torch.Tensor`): + Input router probabilities of shape [num_groups, tokens_per_group, num_experts] this corresponds to the + probabilities used to determine the routing of tokens to the experts. + padding_mask (`torch.Tensor`): + Padding mask tensor of shape [num_groups, tokens_per_group] a mask used to identify padding tokens that + should be ignored by the router. + expert_capacity (`int`): + Each group will send this many tokens to each expert. Returns: Router mask arrays. @@ -288,7 +338,8 @@ def _compute_routing_instructions( class ExpertsChooseMaskedRouter(MaskedRouter): - """Masked matmul router using experts choose tokens assignment. + """ + Masked matmul router using experts choose tokens assignment. This router uses the same mechanism as in Mixture-of-Experts with Expert Choice (https://arxiv.org/abs/2202.09368): each expert selects its top expert_capacity tokens. An individual token may be processed by multiple experts or @@ -304,17 +355,19 @@ def _compute_routing_instructions( """Computes masks for the highest probability token per expert. Args: - router_probs: [num_groups, tokens_per_group, num_experts] - probabilities used to determine the routing of tokens to the experts. - padding_mask: [num_groups, tokens_per_group] padding logit mask - used to identify padding tokens that should be down-weighted by the router. - expert_capacity: Each group will send this many tokens to each expert. + router_probs (`torch.Tensor`): + Raw router probabilities of shape [num_groups, tokens_per_group, num_experts] used to determine the + routing of tokens to the experts. + padding_mask (`torch.Tensor`): + padding mask tensor of shape [num_groups, tokens_per_group] used to identify padding tokens that should + be down-weighted by the router. + expert_capacity (`int`): + Each group will send this many tokens to each expert. Returns: Dispatch and combine arrays for routing with masked matmuls. """ tokens_per_group = router_probs.shape[1] - default_dtype = router_probs.dtype if padding_mask is not None: # Because experts choose tokens, we mask probabilities corresponding to @@ -344,9 +397,6 @@ def _compute_routing_instructions( # expert_capacity]. combine_array = torch.einsum("...ec,...tec->...tec", expert_gate, dispatch_mask) - # Return to default dtype now that router computation is complete. - combine_array = combine_array.to(default_dtype) - # Each expert is choosing tokens until it reaches full capacity, so we don't # need an auxiliary loading balancing loss for expert choice routing. auxiliary_loss = 0.0 @@ -364,12 +414,14 @@ class TokensChooseMaskedRouter(MaskedRouter): token is processed by an expert, or that each expert receives at least one token. Attributes: - num_selected_experts: Maximum number of experts to which each token is - routed. Tokens may be routed to fewer experts if particular experts are oversubscribed / reach capacity. - batch_prioritized_routing: Whether or not to use Batch Prioritized Routing - (BPR), originally introduced in V-MoE (https://arxiv.org/abs/2106.05974). With BPR, we prioritize routing those - top-k tokens with the highest router probability, rather than simply using each tokens left-to-right ordering - in the batch. This prioritization is important because the experts have limited capacity. + num_selected_experts (`int`): + Maximum number of experts to which each token is routed. Tokens may be routed to fewer experts if particular + experts are oversubscribed / reach capacity. + batch_prioritized_routing (`bool`): + Whether or not to use Batch Prioritized Routing (BPR), originally introduced in V-MoE + (https://arxiv.org/abs/2106.05974). With BPR, we prioritize routing those top-k tokens with the highest router + probability, rather than simply using each tokens left-to-right ordering in the batch. This prioritization is + important because the experts have limited capacity. """ def __init__(self, config, **kwargs): @@ -384,11 +436,14 @@ def _compute_routing_instructions( Computes masks for the top-k experts per token. Args: - router_probs: [num_groups, tokens_per_group, num_experts] - probabilities used to determine the routing of tokens to the experts. - padding_mask: [num_groups, tokens_per_group] padding logit mask - used to identify padding tokens that should be ignored by the router. - expert_capacity: Each group will send this many tokens to each expert. + router_probs (`torch.Tensor`): + Router raw probabilities tensor of shape [num_groups, tokens_per_group, num_experts] used to determine + the routing of tokens to the experts. + padding_mask (`torch.Tensor`): + Padding mask of shape [num_groups, tokens_per_group] used to identify padding tokens that should be + ignored by the router. + expert_capacity (`int`): + Each group will send this many tokens to each expert. Returns: Dispatch and combine arrays for routing with masked matmuls. @@ -478,7 +533,8 @@ def _compute_routing_instructions( class ScatterRouter(Router): - """Abstract base router class for scatter dispatch routers. + """ + Abstract base router class for scatter dispatch routers. ScatterRouter(s) return RouterIndices containing dispatch indices and combine weights for sending token inputs (via scatter) and receiving outputs (via gather) to and from experts. @@ -492,20 +548,24 @@ def _compute_routing_instructions( """Computes instructions for routing inputs to experts. Args: - router_probs: [num_groups, tokens_per_group, num_experts] - probabilities used to determine the routing of tokens to the experts. - padding_mask: [num_groups, tokens_per_group] padding logit mask - used to identify padding tokens that should be ignored by the router. - expert_capacity: Each group will send this many tokens to each expert. + router_probs (`torch.Tensor`): + Router raw probabilities tensor of shape [num_groups, tokens_per_group, num_experts] used to determine + the routing of tokens to the experts. + padding_mask (`torch.Tensor`): + Padding mask of shape [num_groups, tokens_per_group] used to identify padding tokens that should be + ignored by the router. + expert_capacity (`int`): + Each group will send this many tokens to each expert. Returns: - Router indices containing dispatch indices and combine weights. + Router indices containing dispatch indices and combine weights. """ raise NotImplementedError("ScatterRouter is an abstract class that should be subclassed.") class TokensChooseScatterRouter(ScatterRouter): - """Scatter router using tokens choose top-k experts assignment. + """ + Scatter router using tokens choose top-k experts assignment. This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then @@ -513,13 +573,14 @@ class TokensChooseScatterRouter(ScatterRouter): token is processed by an expert, or that each expert receives at least one token. Attributes: - num_selected_experts: Maximum number of experts to which each token is - routed. Tokens may be routed to fewer experts if particular experts are oversubscribed / reach capacity. - batch_prioritized_routing: Whether or not to use Batch Prioritized Routing - (BPR), originally introduced in V-MoE (https://arxiv.org/abs/2106.05974). - With BPR, we prioritize routing those top-k tokens with the highest router probability, rather than simply - using each tokens left-to-right ordering in the batch. This prioritization is important because the expert's - have limited capacity. + num_selected_experts (`int`): + Maximum number of experts to which each token is routed. Tokens may be routed to fewer experts if + particular experts are oversubscribed / reach capacity. + batch_prioritized_routing (`bool`): + Whether or not to use Batch Prioritized Routing BPR), originally introduced in V-MoE + (https://arxiv.org/abs/2106.05974). With BPR, we prioritize routing those top-k tokens with the highest + router probability, rather than simply using each tokens left-to-right ordering in the batch. This + prioritization is important because the expert's have limited capacity. """ def __init__(self, config, **kwargs): @@ -533,16 +594,18 @@ def _compute_routing_instructions( """Computes dispatch indices and combine weights for the top-k experts. Args: - router_probs: [num_groups, tokens_per_group, num_experts] - probabilities used to determine the routing of tokens to the experts. - padding_mask: [num_groups, tokens_per_group] padding logit mask - used to identify padding tokens that should be down-weighted by the router. - expert_capacity: Each group will send this many tokens to each expert. + router_probs (`torch.Tensor`): + Router raw probabilities tensor of shape [num_groups, tokens_per_group, num_experts] used to determine + the routing of tokens to the experts. + padding_mask (`torch.Tensor`): + Padding mask of shape [num_groups, tokens_per_group] used to identify padding tokens that should be + ignored by the router. + expert_capacity (`int`): + Each group will send this many tokens to each expert. Returns: Dispatch indices and combine weights for scatter/gather-based routing. """ - original_dtype = router_probs.dtype num_groups, tokens_per_group, num_experts = router_probs.shape if padding_mask is not None: @@ -614,7 +677,6 @@ def _compute_routing_instructions( dispatch_indices = torch.stack([preferred_experts, token_priority], dim=-1) # Return to default dtype now that router computation is complete. - combine_weights = combine_weights.to(original_dtype) dispatch_indices = dispatch_indices.to(torch.int32) return RouterIndices(dispatch_indices, combine_weights, auxiliary_loss) From 7f1026daf45fe0fb09ce588d362ced367e4c7519 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Sun, 9 Oct 2022 02:37:25 +0200 Subject: [PATCH 011/102] v0 sparse mlp --- .../configuration_switchtransformers.py | 36 ++++++++-- .../modeling_switchtransformers.py | 66 ++++++++++++++++--- 2 files changed, 86 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/switchtransformers/configuration_switchtransformers.py b/src/transformers/models/switchtransformers/configuration_switchtransformers.py index 54953c033cfc9..fe0eff59898f5 100644 --- a/src/transformers/models/switchtransformers/configuration_switchtransformers.py +++ b/src/transformers/models/switchtransformers/configuration_switchtransformers.py @@ -52,10 +52,14 @@ class SwitchTransformersConfig(PretrainedConfig): num_heads`. d_ff (`int`, *optional*, defaults to 2048): Size of the intermediate feed forward layer in each `SwitchTransformersBlock`. - num_layers (`int`, *optional*, defaults to 6): - Number of hidden layers in the Transformer encoder. - num_decoder_layers (`int`, *optional*): + num_encoder_layers (`int`, *optional*, defaults to 12): + Number of dense hidden layers in the Transformer encoder layer. + num_sparse_encoder_layers (`int`, *optional*, defaults to 6): + Number of sparse (MoE) dense hidden layers in the Transformer encoder layer. + num_decoder_layers (`int`, *optional*, defaults to 12): Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set. + num_sparse_decoder_layers (`int`, *optional*, defaults to 12): + Number of sparse (MoE) dense hidden layers in the Transformer decoder layer. num_heads (`int`, *optional*, defaults to 8): Number of attention heads for each attention layer in the Transformer encoder. num_experts (`int`, *optional*, defaults to 8): @@ -100,8 +104,10 @@ def __init__( d_model=512, d_kv=64, d_ff=2048, - num_layers=6, - num_decoder_layers=None, + num_encoder_layers=12, + num_sparse_encoder_layers=6, + num_decoder_layers=12, + num_sparse_decoder_layers=6, num_heads=8, num_experts=8, router_bias=False, @@ -126,10 +132,26 @@ def __init__( self.d_model = d_model self.d_kv = d_kv self.d_ff = d_ff - self.num_layers = num_layers + self.num_encoder_layers = num_encoder_layers + self.num_sparse_encoder_layers = num_sparse_encoder_layers + self.num_decoder_layers = ( - num_decoder_layers if num_decoder_layers is not None else self.num_layers + num_decoder_layers if num_decoder_layers is not None else self.num_encoder_layers ) # default = symmetry + self.num_sparse_decoder_layers = num_sparse_decoder_layers + + # This tells us, each how many encoder layer we'll have to set a sparse layer. + if self.num_sparse_encoder_layers > 0: + self.encoder_sparse_step = self.num_encoder_layer % self.num_sparse_encoder_layers + else: + self.encoder_sparse_step = self.num_encoder_layer # HACK: this will create 0 sparse layers + + # This tells us, each how many encoder layer we'll have to set a sparse layer. + if self.num_sparse_decoder_layers > 0: + self.decoder_sparse_step = self.num_decoder_layer % self.num_sparse_decoder_layers + else: + self.decoder_sparse_step = self.num_decoder_layer # HACK: this will create 0 sparse layers + self.num_heads = num_heads self.num_experts = num_experts self.router_bias = router_bias diff --git a/src/transformers/models/switchtransformers/modeling_switchtransformers.py b/src/transformers/models/switchtransformers/modeling_switchtransformers.py index 15383d0ebc98a..4444525128bd3 100644 --- a/src/transformers/models/switchtransformers/modeling_switchtransformers.py +++ b/src/transformers/models/switchtransformers/modeling_switchtransformers.py @@ -44,6 +44,7 @@ replace_return_docstrings, ) from .configuration_switchtransformers import SwitchTransformersConfig +from .router import ExpertsChooseMaskedRouter, TokensChooseMaskedRouter, TokensChooseScatterRouter logger = logging.get_logger(__name__) @@ -134,19 +135,63 @@ def forward(self, hidden_states): # This class should also contain a router class # check flaxformer/architecture/moe/router.py : https://github.com/google/flaxformer/blob/main/flaxformer/architectures/moe/routing.py class SwitchTransformersLayerFF(nn.Module): - def __init__(self, config: SwitchTransformersConfig): + def __init__(self, config: SwitchTransformersConfig, is_sparse=False): super().__init__() # TODO: check the comments above + self.is_sparse = is_sparse + if self.is_sparse: + self.mlp = SwitchTransformersDenseActDense(config) + else: + self.mlp = SwitchTransformersSparseMLP(config) self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) def forward(self, hidden_states): forwarded_states = self.layer_norm(hidden_states) - forwarded_states = self.DenseReluDense(forwarded_states) + forwarded_states = self.mlp(forwarded_states) hidden_states = hidden_states + self.dropout(forwarded_states) return hidden_states +class SwitchTransformersSparseMLP(nn.Module): + r""" + Implementation of the Switch Transformers Sparse MLP module + + TODO: Add a LOT of details here + """ + + def __init__(self, config: SwitchTransformersConfig): + super().__init__() + # Step 1: Get the correct router + self.router = self._get_router(config) + + # Step 2: Get the experts + self.experts = None # TODO: figure out how this is done in t5x... + + def _get_router(self, config): + r""" + For now two types of Router are supported: + - Masked Routers + - Scatter Routers + In total the list of supported Routers are the following: + + """ + if config.router_type.lower() == "tokens_masked": + return TokensChooseMaskedRouter(config) + elif config.router_type.lower() == "tokens_scatter": + return TokensChooseScatterRouter(config) + elif config.router_type.lower() == "experts_masked": + return ExpertsChooseMaskedRouter(config) + else: + raise NotImplementedError( + f"{config.router_type.lower()} not implemented ! Please chose a router among [tokens_masked," + " tokens_scatter, experts_masked]" + ) + + def forward(self, hidden_states): + pass + + # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->SwitchTransformers class SwitchTransformersAttention(nn.Module): def __init__(self, config: SwitchTransformersConfig, has_relative_attention_bias=False): @@ -448,11 +493,11 @@ def forward( return outputs -# Copied from transformers.models.t5.modeling_t5.T5Block with T5->SwitchTransformers class SwitchTransformersBlock(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, is_sparse=False): super().__init__() self.is_decoder = config.is_decoder + self.is_sparse = is_sparse self.layer = nn.ModuleList() self.layer.append( SwitchTransformersLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias) @@ -460,7 +505,7 @@ def __init__(self, config, has_relative_attention_bias=False): if self.is_decoder: self.layer.append(SwitchTransformersLayerCrossAttention(config)) - self.layer.append(SwitchTransformersLayerFF(config)) + self.layer.append(SwitchTransformersLayerFF(config, is_sparse=self.is_sparse)) def forward( self, @@ -665,15 +710,18 @@ def _shift_right(self, input_ids): class SwitchTransformersStack(nn.Module): - def __init__(self, config, embed_tokens=None): + def __init__(self, config, embed_tokens=None, sparse_step=1): super().__init__(config) self.embed_tokens = embed_tokens self.is_decoder = config.is_decoder + # TODO: change this, actually you can have a block full of sparse layers... self.block = nn.ModuleList( [ - SwitchTransformersBlock(config, has_relative_attention_bias=bool(i == 0)) + SwitchTransformersBlock( + config, has_relative_attention_bias=bool(i == 0), is_sparse=(i % sparse_step == 0) + ) for i in range(config.num_layers) ] ) @@ -1088,13 +1136,13 @@ def __init__(self, config: SwitchTransformersConfig): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.is_encoder_decoder = False - self.encoder = SwitchTransformersStack(encoder_config, self.shared) + self.encoder = SwitchTransformersStack(encoder_config, self.shared, encoder_config.encoder_sparse_step) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.is_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers - self.decoder = SwitchTransformersStack(decoder_config, self.shared) + self.decoder = SwitchTransformersStack(decoder_config, self.shared, decoder_config.decoder_sparse_step) # Initialize weights and apply final processing self.post_init() From 484767087ed83fae8970a3f692a4c3afdddeb34c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 12 Oct 2022 14:44:15 +0000 Subject: [PATCH 012/102] replace wrong naming --- README.md | 2 +- README_es.md | 1 + README_ko.md | 2 +- README_zh-hans.md | 2 +- README_zh-hant.md | 2 +- docs/source/en/index.mdx | 2 +- docs/source/en/serialization.mdx | 2 +- src/transformers/__init__.py | 24 ++-- src/transformers/models/__init__.py | 2 +- .../models/auto/configuration_auto.py | 6 +- src/transformers/models/auto/modeling_auto.py | 8 +- .../models/auto/modeling_flax_auto.py | 6 +- .../models/auto/tokenization_auto.py | 2 +- .../__init__.py | 28 ++-- .../configuration_switch_transformers.py} | 12 +- ...mers_original_tf_checkpoint_to_pytorch.py} | 2 +- ...witch_transformersx_checkpoint_to_flax.py} | 134 +++++++++--------- .../modeling_flax_switch_transformers.py} | 80 +++++------ .../modeling_switch_transformers.py} | 64 ++++----- .../router.py | 2 +- .../router_flax.py | 0 .../tokenization_switch_transformers.py} | 26 ++-- .../tokenization_switch_transformers_fast.py} | 40 +++--- src/transformers/utils/dummy_pt_objects.py | 2 +- .../__init__.py | 0 ...test_modeling_flax_switch_transformers.py} | 54 +++---- .../test_modeling_switch_transformers.py} | 24 ++-- .../test_tokenization_switch_transformers.py} | 34 ++--- 28 files changed, 282 insertions(+), 281 deletions(-) rename src/transformers/models/{switchtransformers => switch_transformers}/__init__.py (78%) rename src/transformers/models/{switchtransformers/configuration_switchtransformers.py => switch_transformers/configuration_switch_transformers.py} (96%) rename src/transformers/models/{switchtransformers/convert_switchtransformers_original_tf_checkpoint_to_pytorch.py => switch_transformers/convert_switch_transformers_original_tf_checkpoint_to_pytorch.py} (96%) rename src/transformers/models/{switchtransformers/convert_switchtransformersx_checkpoint_to_flax.py => switch_transformers/convert_switch_transformersx_checkpoint_to_flax.py} (58%) rename src/transformers/models/{switchtransformers/modeling_flax_switchtransformers.py => switch_transformers/modeling_flax_switch_transformers.py} (96%) rename src/transformers/models/{switchtransformers/modeling_switchtransformers.py => switch_transformers/modeling_switch_transformers.py} (97%) rename src/transformers/models/{switchtransformers => switch_transformers}/router.py (99%) rename src/transformers/models/{switchtransformers => switch_transformers}/router_flax.py (100%) rename src/transformers/models/{switchtransformers/tokenization_switchtransformers.py => switch_transformers/tokenization_switch_transformers.py} (94%) rename src/transformers/models/{switchtransformers/tokenization_switchtransformers_fast.py => switch_transformers/tokenization_switch_transformers_fast.py} (85%) rename tests/models/{switchtransformers => switch_transformers}/__init__.py (100%) rename tests/models/{switchtransformers/test_modeling_flax_switchtransformers.py => switch_transformers/test_modeling_flax_switch_transformers.py} (96%) rename tests/models/{switchtransformers/test_modeling_switchtransformers.py => switch_transformers/test_modeling_switch_transformers.py} (97%) rename tests/models/{switchtransformers/test_tokenization_switchtransformers.py => switch_transformers/test_tokenization_switch_transformers.py} (94%) diff --git a/README.md b/README.md index a3ea3b10f70fb..cee378c477e56 100644 --- a/README.md +++ b/README.md @@ -373,7 +373,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. 1. **[Swin Transformer V2](https://huggingface.co/docs/transformers/model_doc/swinv2)** (from Microsoft) released with the paper [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. -1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switchtransformers)** (from ) released with the paper []() by . +1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switch_transformers)** (from ) released with the paper []() by . 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. diff --git a/README_es.md b/README_es.md index 733b22aba55a5..ca38c4c082a7b 100644 --- a/README_es.md +++ b/README_es.md @@ -373,6 +373,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. 1. **[Swin Transformer V2](https://huggingface.co/docs/transformers/model_doc/swinv2)** (from Microsoft) released with the paper [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. +1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switch_transformers)** (from ) released with the paper []() by . 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. diff --git a/README_ko.md b/README_ko.md index e7d798bf49ee0..e1693faa9e660 100644 --- a/README_ko.md +++ b/README_ko.md @@ -323,7 +323,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. 1. **[Swin Transformer V2](https://huggingface.co/docs/transformers/model_doc/swinv2)** (from Microsoft) released with the paper [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. -1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switchtransformers)** (from ) released with the paper []() by . +1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switch_transformers)** (from ) released with the paper []() by . 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. diff --git a/README_zh-hans.md b/README_zh-hans.md index 6e4faf72f05ba..923a018be2be8 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -347,7 +347,7 @@ conda install -c huggingface transformers 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (来自 Berkeley) 伴随论文 [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) 由 Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer 发布。 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (来自 Microsoft) 伴随论文 [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) 由 Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo 发布。 1. **[Swin Transformer V2](https://huggingface.co/docs/transformers/model_doc/swinv2)** (来自 Microsoft) 伴随论文 [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) 由 Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo 发布。 -1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switchtransformers)** (from ) released with the paper []() by . +1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switch_transformers)** (from ) released with the paper []() by . 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (来自 Google AI) 伴随论文 [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) 由 Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu 发布。 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (来自 Google AI) 伴随论文 [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) 由 Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu 发布。 1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (来自 Google AI) 伴随论文 [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) 由 Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 18efd0f91bdd7..6459d15509068 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -359,7 +359,7 @@ conda install -c huggingface transformers 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. 1. **[Swin Transformer V2](https://huggingface.co/docs/transformers/model_doc/swinv2)** (from Microsoft) released with the paper [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. -1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switchtransformers)** (from ) released with the paper []() by . +1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switch_transformers)** (from ) released with the paper []() by . 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released with the paper [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index aa5e062bcba49..c12d63addb12d 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -162,7 +162,7 @@ The documentation is organized into five sections: 1. **[SqueezeBERT](model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. 1. **[Swin Transformer V2](model_doc/swinv2)** (from Microsoft) released with the paper [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. -1. **[SwitchTransformers](model_doc/switchtransformers)** (from ) released with the paper []() by . +1. **[Switch Transformers](model_doc/switch_transformers)** (from ) released with the paper []() by . 1. **[T5](model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index 837b09fdb836b..3f653c32bdd82 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -94,8 +94,8 @@ Ready-made configurations include the following architectures: - RoFormer - SegFormer - SqueezeBERT -- SwitchTransformers - Swin Transformer +- SwitchTransformers - T5 - Vision Encoder decoder - ViT diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b6b9fde57c3f9..38ceb39cb348c 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -334,7 +334,7 @@ "models.squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig", "SqueezeBertTokenizer"], "models.swin": ["SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwinConfig"], "models.swinv2": ["SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Swinv2Config"], - "models.switchtransformers": ["SWITCHTRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwitchTransformersConfig"], + "models.switch_transformers": ["SWITCH_TRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwitchTransformersConfig"], "models.t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"], "models.tapas": ["TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP", "TapasConfig", "TapasTokenizer"], "models.tapex": ["TapexTokenizer"], @@ -544,7 +544,7 @@ _import_structure["models.rembert"].append("RemBertTokenizer") _import_structure["models.speech_to_text"].append("Speech2TextTokenizer") _import_structure["models.t5"].append("T5Tokenizer") - _import_structure["models.switchtransformers"].append("SwitchTransformersTokenizer") + _import_structure["models.switch_transformers"].append("SwitchTransformersTokenizer") _import_structure["models.xglm"].append("XGLMTokenizer") _import_structure["models.xlm_prophetnet"].append("XLMProphetNetTokenizer") _import_structure["models.xlm_roberta"].append("XLMRobertaTokenizer") @@ -613,7 +613,7 @@ _import_structure["models.roformer"].append("RoFormerTokenizerFast") _import_structure["models.splinter"].append("SplinterTokenizerFast") _import_structure["models.squeezebert"].append("SqueezeBertTokenizerFast") - _import_structure["models.switchtransformers"].append("SwitchTransformersTokenizerFast") + _import_structure["models.switch_transformers"].append("SwitchTransformersTokenizerFast") _import_structure["models.t5"].append("T5TokenizerFast") _import_structure["models.xglm"].append("XGLMTokenizerFast") _import_structure["models.xlm_roberta"].append("XLMRobertaTokenizerFast") @@ -1951,9 +1951,9 @@ "load_tf_weights_in_t5", ] ) - _import_structure["models.switchtransformers"].extend( + _import_structure["models.switch_transformers"].extend( [ - "SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST", + "SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST", "SwitchTransformersEncoderModel", "SwitchTransformersForConditionalGeneration", "SwitchTransformersModel", @@ -3074,7 +3074,7 @@ _import_structure["models.t5"].extend( ["FlaxT5EncoderModel", "FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"] ) - _import_structure["models.switchtransformers"].extend( + _import_structure["models.switch_transformers"].extend( [ "FlaxSwitchTransformersEncoderModel", "FlaxSwitchTransformersForConditionalGeneration", @@ -3371,7 +3371,7 @@ from .models.squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig, SqueezeBertTokenizer from .models.swin import SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, SwinConfig from .models.swinv2 import SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP, Swinv2Config - from .models.switchtransformers import SWITCHTRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP, SwitchTransformersConfig + from .models.switch_transformers import SWITCH_TRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP, SwitchTransformersConfig from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config from .models.tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig, TapasTokenizer from .models.tapex import TapexTokenizer @@ -3562,7 +3562,7 @@ from .models.reformer import ReformerTokenizer from .models.rembert import RemBertTokenizer from .models.speech_to_text import Speech2TextTokenizer - from .models.switchtransformers import SwitchTransformersTokenizer + from .models.switch_transformers import SwitchTransformersTokenizer from .models.t5 import T5Tokenizer from .models.xglm import XGLMTokenizer from .models.xlm_prophetnet import XLMProphetNetTokenizer @@ -3625,7 +3625,7 @@ from .models.roformer import RoFormerTokenizerFast from .models.splinter import SplinterTokenizerFast from .models.squeezebert import SqueezeBertTokenizerFast - from .models.switchtransformers import SwitchTransformersTokenizerFast + from .models.switch_transformers import SwitchTransformersTokenizerFast from .models.t5 import T5TokenizerFast from .models.xglm import XGLMTokenizerFast from .models.xlm_roberta import XLMRobertaTokenizerFast @@ -4688,8 +4688,8 @@ Swinv2Model, Swinv2PreTrainedModel, ) - from .models.switchtransformers import ( - SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, + from .models.switch_transformers import ( + SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, SwitchTransformersEncoderModel, SwitchTransformersForConditionalGeneration, SwitchTransformersModel, @@ -5605,7 +5605,7 @@ FlaxRoFormerPreTrainedModel, ) from .models.speech_encoder_decoder import FlaxSpeechEncoderDecoderModel - from .models.switchtransformers import ( + from .models.switch_transformers import ( FlaxSwitchTransformersEncoderModel, FlaxSwitchTransformersForConditionalGeneration, FlaxSwitchTransformersModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 83d43ba95c229..e71891a657987 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -137,7 +137,7 @@ squeezebert, swin, swinv2, - switchtransformers, + switch_transformers, t5, tapas, tapex, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 74f949f34d050..7a74598b69a05 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -133,7 +133,7 @@ ("squeezebert", "SqueezeBertConfig"), ("swin", "SwinConfig"), ("swinv2", "Swinv2Config"), - ("switchtransformers", "SwitchTransformersConfig"), + ("switch_transformers", "SwitchTransformersConfig"), ("t5", "T5Config"), ("tapas", "TapasConfig"), ("time_series_transformer", "TimeSeriesTransformerConfig"), @@ -265,7 +265,7 @@ ("squeezebert", "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("swin", "SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("swinv2", "SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("switchtransformers", "SWITCHTRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("switch_transformers", "SWITCH_TRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("time_series_transformer", "TIME_SERIES_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -417,7 +417,7 @@ ("squeezebert", "SqueezeBERT"), ("swin", "Swin Transformer"), ("swinv2", "Swin Transformer V2"), - ("switchtransformers", "SwitchTransformers"), + ("switch_transformers", "SwitchTransformers"), ("t5", "T5"), ("t5v1.1", "T5v1.1"), ("tapas", "TAPAS"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 8ed85b00199a3..b1bf115b85b55 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -129,7 +129,7 @@ ("squeezebert", "SqueezeBertModel"), ("swin", "SwinModel"), ("swinv2", "Swinv2Model"), - ("switchtransformers", "SwitchTransformersModel"), + ("switch_transformers", "SwitchTransformersModel"), ("t5", "T5Model"), ("tapas", "TapasModel"), ("time_series_transformer", "TimeSeriesTransformerModel"), @@ -198,7 +198,7 @@ ("roberta", "RobertaForMaskedLM"), ("splinter", "SplinterForPreTraining"), ("squeezebert", "SqueezeBertForMaskedLM"), - ("switchtransformers", "SwitchTransformersForConditionalGeneration"), + ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), ("tapas", "TapasForMaskedLM"), ("transfo-xl", "TransfoXLLMHeadModel"), @@ -273,7 +273,7 @@ ("roformer", "RoFormerForMaskedLM"), ("speech_to_text", "Speech2TextForConditionalGeneration"), ("squeezebert", "SqueezeBertForMaskedLM"), - ("switchtransformers", "SwitchTransformersForConditionalGeneration"), + ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), ("tapas", "TapasForMaskedLM"), ("transfo-xl", "TransfoXLLMHeadModel"), @@ -511,7 +511,7 @@ ("pegasus_x", "PegasusXForConditionalGeneration"), ("plbart", "PLBartForConditionalGeneration"), ("prophetnet", "ProphetNetForConditionalGeneration"), - ("switchtransformers", "SwitchTransformersForConditionalGeneration"), + ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"), ] diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 5850797ea769f..ea720ca6b0732 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -49,7 +49,7 @@ ("pegasus", "FlaxPegasusModel"), ("roberta", "FlaxRobertaModel"), ("roformer", "FlaxRoFormerModel"), - ("switchtransformers", "FlaxSwitchTransformersModel"), + ("switch_transformers", "FlaxSwitchTransformersModel"), ("t5", "FlaxT5Model"), ("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"), ("vit", "FlaxViTModel"), @@ -72,7 +72,7 @@ ("mt5", "FlaxMT5ForConditionalGeneration"), ("roberta", "FlaxRobertaForMaskedLM"), ("roformer", "FlaxRoFormerForMaskedLM"), - ("switchtransformers", "FlaxSwitchTransformersForConditionalGeneration"), + ("switch_transformers", "FlaxSwitchTransformersForConditionalGeneration"), ("t5", "FlaxT5ForConditionalGeneration"), ("wav2vec2", "FlaxWav2Vec2ForPreTraining"), ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"), @@ -107,7 +107,7 @@ ("mbart", "FlaxMBartForConditionalGeneration"), ("mt5", "FlaxMT5ForConditionalGeneration"), ("pegasus", "FlaxPegasusForConditionalGeneration"), - ("switchtransformers", "FlaxSwitchTransformersForConditionalGeneration"), + ("switch_transformers", "FlaxSwitchTransformersForConditionalGeneration"), ("t5", "FlaxT5ForConditionalGeneration"), ] ) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index b0fa33c6d5415..13194b116dd4c 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -242,7 +242,7 @@ ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None), ), ( - "switchtransformers", + "switch_transformers", ( "SwitchTransformersTokenizer" if is_sentencepiece_available() else None, "SwitchTransformersTokenizerFast" if is_tokenizers_available() else None, diff --git a/src/transformers/models/switchtransformers/__init__.py b/src/transformers/models/switch_transformers/__init__.py similarity index 78% rename from src/transformers/models/switchtransformers/__init__.py rename to src/transformers/models/switch_transformers/__init__.py index ccc257fc917bb..5f5ee32c89f9e 100644 --- a/src/transformers/models/switchtransformers/__init__.py +++ b/src/transformers/models/switch_transformers/__init__.py @@ -30,8 +30,8 @@ _import_structure = { - "configuration_switchtransformers": [ - "SWITCHTRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP", + "configuration_switch_transformers": [ + "SWITCH_TRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwitchTransformersConfig", "SwitchTransformersOnnxConfig", ] @@ -43,7 +43,7 @@ except OptionalDependencyNotAvailable: pass else: - _import_structure["tokenization_switchtransformers"] = ["SwitchTransformersTokenizer"] + _import_structure["tokenization_switch_transformers"] = ["SwitchTransformersTokenizer"] try: if not is_tokenizers_available(): @@ -51,7 +51,7 @@ except OptionalDependencyNotAvailable: pass else: - _import_structure["tokenization_switchtransformers_fast"] = ["SwitchTransformersTokenizerFast"] + _import_structure["tokenization_switch_transformers_fast"] = ["SwitchTransformersTokenizerFast"] try: if not is_torch_available(): @@ -59,8 +59,8 @@ except OptionalDependencyNotAvailable: pass else: - _import_structure["modeling_switchtransformers"] = [ - "SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST", + _import_structure["modeling_switch_transformers"] = [ + "SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST", "SwitchTransformersEncoderModel", "SwitchTransformersForConditionalGeneration", "SwitchTransformersModel", @@ -74,7 +74,7 @@ except OptionalDependencyNotAvailable: pass else: - _import_structure["modeling_flax_switchtransformers"] = [ + _import_structure["modeling_flax_switch_transformers"] = [ "FlaxSwitchTransformersEncoderModel", "FlaxSwitchTransformersForConditionalGeneration", "FlaxSwitchTransformersModel", @@ -83,8 +83,8 @@ if TYPE_CHECKING: - from .configuration_switchtransformers import ( - SWITCHTRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP, + from .configuration_switch_transformers import ( + SWITCH_TRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP, SwitchTransformersConfig, SwitchTransformersOnnxConfig, ) @@ -95,7 +95,7 @@ except OptionalDependencyNotAvailable: pass else: - from .tokenization_switchtransformers import SwitchTransformersTokenizer + from .tokenization_switch_transformers import SwitchTransformersTokenizer try: if not is_tokenizers_available(): @@ -103,7 +103,7 @@ except OptionalDependencyNotAvailable: pass else: - from .tokenization_switchtransformers_fast import SwitchTransformersTokenizerFast + from .tokenization_switch_transformers_fast import SwitchTransformersTokenizerFast try: if not is_torch_available(): @@ -111,8 +111,8 @@ except OptionalDependencyNotAvailable: pass else: - from .modeling_switchtransformers import ( - SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, + from .modeling_switch_transformers import ( + SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, SwitchTransformersEncoderModel, SwitchTransformersForConditionalGeneration, SwitchTransformersModel, @@ -125,7 +125,7 @@ except OptionalDependencyNotAvailable: pass else: - from .modeling_flax_switchtransformers import ( + from .modeling_flax_switch_transformers import ( FlaxSwitchTransformersEncoderModel, FlaxSwitchTransformersForConditionalGeneration, FlaxSwitchTransformersModel, diff --git a/src/transformers/models/switchtransformers/configuration_switchtransformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py similarity index 96% rename from src/transformers/models/switchtransformers/configuration_switchtransformers.py rename to src/transformers/models/switch_transformers/configuration_switch_transformers.py index fe0eff59898f5..f209cdd38edb2 100644 --- a/src/transformers/models/switchtransformers/configuration_switchtransformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" SwitchTransformers model configuration""" +""" Switch Transformers model configuration""" from typing import Mapping from ...configuration_utils import PretrainedConfig @@ -22,9 +22,9 @@ logger = logging.get_logger(__name__) -SWITCHTRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "ybelkada/switchtransformers-base": ( - "https://huggingface.co/ybelkada/switchtransformers-base/resolve/main/config.json" +SWITCH_TRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "ybelkada/switch_transformers-base": ( + "https://huggingface.co/ybelkada/switch_transformers-base/resolve/main/config.json" ), } @@ -35,7 +35,7 @@ class SwitchTransformersConfig(PretrainedConfig): [`FlaxSwitchTransformersModel`]. It is used to instantiate a SwitchTransformers model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the SwitchTransformers - [ybelkada/switchtransformers-base](https://huggingface.co/ybelkada/switchtransformers-base) architecture. + [ybelkada/switch_transformers-base](https://huggingface.co/ybelkada/switch_transformers-base) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -94,7 +94,7 @@ class SwitchTransformersConfig(PretrainedConfig): use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). """ - model_type = "switchtransformers" + model_type = "switch_transformers" keys_to_ignore_at_inference = ["past_key_values"] attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} diff --git a/src/transformers/models/switchtransformers/convert_switchtransformers_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/switch_transformers/convert_switch_transformers_original_tf_checkpoint_to_pytorch.py similarity index 96% rename from src/transformers/models/switchtransformers/convert_switchtransformers_original_tf_checkpoint_to_pytorch.py rename to src/transformers/models/switch_transformers/convert_switch_transformers_original_tf_checkpoint_to_pytorch.py index 1bf81cf11de8c..509d622daea58 100644 --- a/src/transformers/models/switchtransformers/convert_switchtransformers_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/switch_transformers/convert_switch_transformers_original_tf_checkpoint_to_pytorch.py @@ -31,7 +31,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du model = SwitchTransformersForConditionalGeneration(config) # Load weights from tf checkpoint - # load_tf_weights_in_switchtransformers(model, config, tf_checkpoint_path) + # load_tf_weights_in_switch_transformers(model, config, tf_checkpoint_path) # Save pytorch-model print(f"Save PyTorch model to {pytorch_dump_path}") diff --git a/src/transformers/models/switchtransformers/convert_switchtransformersx_checkpoint_to_flax.py b/src/transformers/models/switch_transformers/convert_switch_transformersx_checkpoint_to_flax.py similarity index 58% rename from src/transformers/models/switchtransformers/convert_switchtransformersx_checkpoint_to_flax.py rename to src/transformers/models/switch_transformers/convert_switch_transformersx_checkpoint_to_flax.py index 0413f20b476af..d9e59f3ee91cc 100644 --- a/src/transformers/models/switchtransformers/convert_switchtransformersx_checkpoint_to_flax.py +++ b/src/transformers/models/switch_transformers/convert_switch_transformersx_checkpoint_to_flax.py @@ -17,244 +17,244 @@ import argparse -from switchtransformersx import checkpoints +from switch_transformersx import checkpoints from transformers import FlaxSwitchTransformersForConditionalGeneration, SwitchTransformersConfig -def convert_switchtransformersx_checkpoint_to_flax( - switchtransformersx_checkpoint_path, config_name, flax_dump_folder_path +def convert_switch_transformersx_checkpoint_to_flax( + switch_transformersx_checkpoint_path, config_name, flax_dump_folder_path ): config = SwitchTransformersConfig.from_pretrained(config_name) flax_model = FlaxSwitchTransformersForConditionalGeneration(config=config) - switchtransformersx_model = checkpoints.load_switchtransformersx_checkpoint(switchtransformersx_checkpoint_path) + switch_transformersx_model = checkpoints.load_switch_transformersx_checkpoint(switch_transformersx_checkpoint_path) - split_mlp_wi = "wi_0" in switchtransformersx_model["target"]["encoder"]["layers_0"]["mlp"] + split_mlp_wi = "wi_0" in switch_transformersx_model["target"]["encoder"]["layers_0"]["mlp"] # Encoder for layer_index in range(config.num_layers): layer_name = f"layers_{str(layer_index)}" # Self-Attention - switchtransformersx_attention_key = switchtransformersx_model["target"]["encoder"][layer_name]["attention"][ + switch_transformersx_attention_key = switch_transformersx_model["target"]["encoder"][layer_name]["attention"][ "key" ]["kernel"] - switchtransformersx_attention_out = switchtransformersx_model["target"]["encoder"][layer_name]["attention"][ + switch_transformersx_attention_out = switch_transformersx_model["target"]["encoder"][layer_name]["attention"][ "out" ]["kernel"] - switchtransformersx_attention_query = switchtransformersx_model["target"]["encoder"][layer_name]["attention"][ + switch_transformersx_attention_query = switch_transformersx_model["target"]["encoder"][layer_name]["attention"][ "query" ]["kernel"] - switchtransformersx_attention_value = switchtransformersx_model["target"]["encoder"][layer_name]["attention"][ + switch_transformersx_attention_value = switch_transformersx_model["target"]["encoder"][layer_name]["attention"][ "value" ]["kernel"] # Layer Normalization - switchtransformersx_attention_layer_norm = switchtransformersx_model["target"]["encoder"][layer_name][ + switch_transformersx_attention_layer_norm = switch_transformersx_model["target"]["encoder"][layer_name][ "pre_attention_layer_norm" ]["scale"] if split_mlp_wi: - switchtransformersx_mlp_wi_0 = switchtransformersx_model["target"]["encoder"][layer_name]["mlp"]["wi_0"][ + switch_transformersx_mlp_wi_0 = switch_transformersx_model["target"]["encoder"][layer_name]["mlp"]["wi_0"][ "kernel" ] - switchtransformersx_mlp_wi_1 = switchtransformersx_model["target"]["encoder"][layer_name]["mlp"]["wi_1"][ + switch_transformersx_mlp_wi_1 = switch_transformersx_model["target"]["encoder"][layer_name]["mlp"]["wi_1"][ "kernel" ] else: - switchtransformersx_mlp_wi = switchtransformersx_model["target"]["encoder"][layer_name]["mlp"]["wi"][ + switch_transformersx_mlp_wi = switch_transformersx_model["target"]["encoder"][layer_name]["mlp"]["wi"][ "kernel" ] - switchtransformersx_mlp_wo = switchtransformersx_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"] + switch_transformersx_mlp_wo = switch_transformersx_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"] # Layer Normalization - switchtransformersx_mlp_layer_norm = switchtransformersx_model["target"]["encoder"][layer_name][ + switch_transformersx_mlp_layer_norm = switch_transformersx_model["target"]["encoder"][layer_name][ "pre_mlp_layer_norm" ]["scale"] # Assigning flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"][ "kernel" - ] = switchtransformersx_attention_key + ] = switch_transformersx_attention_key flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"][ "kernel" - ] = switchtransformersx_attention_out + ] = switch_transformersx_attention_out flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"][ "kernel" - ] = switchtransformersx_attention_query + ] = switch_transformersx_attention_query flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"][ "kernel" - ] = switchtransformersx_attention_value + ] = switch_transformersx_attention_value flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"][ "weight" - ] = switchtransformersx_attention_layer_norm + ] = switch_transformersx_attention_layer_norm if split_mlp_wi: flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_0"][ "kernel" - ] = switchtransformersx_mlp_wi_0 + ] = switch_transformersx_mlp_wi_0 flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_1"][ "kernel" - ] = switchtransformersx_mlp_wi_1 + ] = switch_transformersx_mlp_wi_1 else: flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi"][ "kernel" - ] = switchtransformersx_mlp_wi + ] = switch_transformersx_mlp_wi flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wo"][ "kernel" - ] = switchtransformersx_mlp_wo + ] = switch_transformersx_mlp_wo flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"][ "weight" - ] = switchtransformersx_mlp_layer_norm + ] = switch_transformersx_mlp_layer_norm # Only for layer 0: - switchtransformersx_encoder_rel_embedding = switchtransformersx_model["target"]["encoder"]["relpos_bias"][ + switch_transformersx_encoder_rel_embedding = switch_transformersx_model["target"]["encoder"]["relpos_bias"][ "rel_embedding" ].T flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ "embedding" - ] = switchtransformersx_encoder_rel_embedding + ] = switch_transformersx_encoder_rel_embedding # Assigning - switchtransformersx_encoder_norm = switchtransformersx_model["target"]["encoder"]["encoder_norm"]["scale"] - flax_model.params["encoder"]["final_layer_norm"]["weight"] = switchtransformersx_encoder_norm + switch_transformersx_encoder_norm = switch_transformersx_model["target"]["encoder"]["encoder_norm"]["scale"] + flax_model.params["encoder"]["final_layer_norm"]["weight"] = switch_transformersx_encoder_norm # Decoder for layer_index in range(config.num_decoder_layers): layer_name = f"layers_{str(layer_index)}" # Self-Attention - switchtransformersx_attention_key = switchtransformersx_model["target"]["decoder"][layer_name][ + switch_transformersx_attention_key = switch_transformersx_model["target"]["decoder"][layer_name][ "self_attention" ]["key"]["kernel"] - switchtransformersx_attention_out = switchtransformersx_model["target"]["decoder"][layer_name][ + switch_transformersx_attention_out = switch_transformersx_model["target"]["decoder"][layer_name][ "self_attention" ]["out"]["kernel"] - switchtransformersx_attention_query = switchtransformersx_model["target"]["decoder"][layer_name][ + switch_transformersx_attention_query = switch_transformersx_model["target"]["decoder"][layer_name][ "self_attention" ]["query"]["kernel"] - switchtransformersx_attention_value = switchtransformersx_model["target"]["decoder"][layer_name][ + switch_transformersx_attention_value = switch_transformersx_model["target"]["decoder"][layer_name][ "self_attention" ]["value"]["kernel"] # Layer Normalization - switchtransformersx_pre_attention_layer_norm = switchtransformersx_model["target"]["decoder"][layer_name][ + switch_transformersx_pre_attention_layer_norm = switch_transformersx_model["target"]["decoder"][layer_name][ "pre_self_attention_layer_norm" ]["scale"] # Encoder-Decoder-Attention - switchtransformersx_enc_dec_attention_key = switchtransformersx_model["target"]["decoder"][layer_name][ + switch_transformersx_enc_dec_attention_key = switch_transformersx_model["target"]["decoder"][layer_name][ "encoder_decoder_attention" ]["key"]["kernel"] - switchtransformersx_enc_dec_attention_out = switchtransformersx_model["target"]["decoder"][layer_name][ + switch_transformersx_enc_dec_attention_out = switch_transformersx_model["target"]["decoder"][layer_name][ "encoder_decoder_attention" ]["out"]["kernel"] - switchtransformersx_enc_dec_attention_query = switchtransformersx_model["target"]["decoder"][layer_name][ + switch_transformersx_enc_dec_attention_query = switch_transformersx_model["target"]["decoder"][layer_name][ "encoder_decoder_attention" ]["query"]["kernel"] - switchtransformersx_enc_dec_attention_value = switchtransformersx_model["target"]["decoder"][layer_name][ + switch_transformersx_enc_dec_attention_value = switch_transformersx_model["target"]["decoder"][layer_name][ "encoder_decoder_attention" ]["value"]["kernel"] # Layer Normalization - switchtransformersx_cross_layer_norm = switchtransformersx_model["target"]["decoder"][layer_name][ + switch_transformersx_cross_layer_norm = switch_transformersx_model["target"]["decoder"][layer_name][ "pre_cross_attention_layer_norm" ]["scale"] # MLP if split_mlp_wi: - switchtransformersx_mlp_wi_0 = switchtransformersx_model["target"]["decoder"][layer_name]["mlp"]["wi_0"][ + switch_transformersx_mlp_wi_0 = switch_transformersx_model["target"]["decoder"][layer_name]["mlp"]["wi_0"][ "kernel" ] - switchtransformersx_mlp_wi_1 = switchtransformersx_model["target"]["decoder"][layer_name]["mlp"]["wi_1"][ + switch_transformersx_mlp_wi_1 = switch_transformersx_model["target"]["decoder"][layer_name]["mlp"]["wi_1"][ "kernel" ] else: - switchtransformersx_mlp_wi = switchtransformersx_model["target"]["decoder"][layer_name]["mlp"]["wi"][ + switch_transformersx_mlp_wi = switch_transformersx_model["target"]["decoder"][layer_name]["mlp"]["wi"][ "kernel" ] - switchtransformersx_mlp_wo = switchtransformersx_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"] + switch_transformersx_mlp_wo = switch_transformersx_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"] # Layer Normalization - tx5_mlp_layer_norm = switchtransformersx_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"] + tx5_mlp_layer_norm = switch_transformersx_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"] # Assigning flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"][ "kernel" - ] = switchtransformersx_attention_key + ] = switch_transformersx_attention_key flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"][ "kernel" - ] = switchtransformersx_attention_out + ] = switch_transformersx_attention_out flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"][ "kernel" - ] = switchtransformersx_attention_query + ] = switch_transformersx_attention_query flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"][ "kernel" - ] = switchtransformersx_attention_value + ] = switch_transformersx_attention_value flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"][ "weight" - ] = switchtransformersx_pre_attention_layer_norm + ] = switch_transformersx_pre_attention_layer_norm flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["k"][ "kernel" - ] = switchtransformersx_enc_dec_attention_key + ] = switch_transformersx_enc_dec_attention_key flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["o"][ "kernel" - ] = switchtransformersx_enc_dec_attention_out + ] = switch_transformersx_enc_dec_attention_out flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["q"][ "kernel" - ] = switchtransformersx_enc_dec_attention_query + ] = switch_transformersx_enc_dec_attention_query flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["v"][ "kernel" - ] = switchtransformersx_enc_dec_attention_value + ] = switch_transformersx_enc_dec_attention_value flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"][ "weight" - ] = switchtransformersx_cross_layer_norm + ] = switch_transformersx_cross_layer_norm if split_mlp_wi: flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_0"][ "kernel" - ] = switchtransformersx_mlp_wi_0 + ] = switch_transformersx_mlp_wi_0 flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_1"][ "kernel" - ] = switchtransformersx_mlp_wi_1 + ] = switch_transformersx_mlp_wi_1 else: flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi"][ "kernel" - ] = switchtransformersx_mlp_wi + ] = switch_transformersx_mlp_wi flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wo"][ "kernel" - ] = switchtransformersx_mlp_wo + ] = switch_transformersx_mlp_wo flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["layer_norm"][ "weight" ] = tx5_mlp_layer_norm # Decoder Normalization - tx5_decoder_norm = switchtransformersx_model["target"]["decoder"]["decoder_norm"]["scale"] + tx5_decoder_norm = switch_transformersx_model["target"]["decoder"]["decoder_norm"]["scale"] flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm # Only for layer 0: - switchtransformersx_decoder_rel_embedding = switchtransformersx_model["target"]["decoder"]["relpos_bias"][ + switch_transformersx_decoder_rel_embedding = switch_transformersx_model["target"]["decoder"]["relpos_bias"][ "rel_embedding" ].T flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ "embedding" - ] = switchtransformersx_decoder_rel_embedding + ] = switch_transformersx_decoder_rel_embedding # Token Embeddings - tx5_token_embeddings = switchtransformersx_model["target"]["token_embedder"]["embedding"] + tx5_token_embeddings = switch_transformersx_model["target"]["token_embedder"]["embedding"] flax_model.params["shared"]["embedding"] = tx5_token_embeddings # LM Head (only in v1.1 checkpoints) - if "logits_dense" in switchtransformersx_model["target"]["decoder"]: - flax_model.params["lm_head"]["kernel"] = switchtransformersx_model["target"]["decoder"]["logits_dense"][ + if "logits_dense" in switch_transformersx_model["target"]["decoder"]: + flax_model.params["lm_head"]["kernel"] = switch_transformersx_model["target"]["decoder"]["logits_dense"][ "kernel" ] @@ -266,7 +266,7 @@ def convert_switchtransformersx_checkpoint_to_flax( parser = argparse.ArgumentParser() # Required parameters parser.add_argument( - "--switchtransformersx_checkpoint_path", default=None, type=str, required=True, help="Path the TX5 checkpoint." + "--switch_transformersx_checkpoint_path", default=None, type=str, required=True, help="Path the TX5 checkpoint." ) parser.add_argument( "--config_name", default=None, type=str, required=True, help="Config name of SwitchTransformers model." @@ -275,6 +275,6 @@ def convert_switchtransformersx_checkpoint_to_flax( "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model." ) args = parser.parse_args() - convert_switchtransformersx_checkpoint_to_flax( - args.switchtransformersx_checkpoint_path, args.config_name, args.flax_dump_folder_path + convert_switch_transformersx_checkpoint_to_flax( + args.switch_transformersx_checkpoint_path, args.config_name, args.flax_dump_folder_path ) diff --git a/src/transformers/models/switchtransformers/modeling_flax_switchtransformers.py b/src/transformers/models/switch_transformers/modeling_flax_switch_transformers.py similarity index 96% rename from src/transformers/models/switchtransformers/modeling_flax_switchtransformers.py rename to src/transformers/models/switch_transformers/modeling_flax_switch_transformers.py index c9e3442fd687f..130352c287b87 100644 --- a/src/transformers/models/switchtransformers/modeling_flax_switchtransformers.py +++ b/src/transformers/models/switch_transformers/modeling_flax_switch_transformers.py @@ -45,12 +45,12 @@ overwrite_call_docstring, ) from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_switchtransformers import SwitchTransformersConfig +from .configuration_switch_transformers import SwitchTransformersConfig logger = logging.get_logger(__name__) -_CHECKPOINT_FOR_DOC = "ybelkada/switchtransformers-base" +_CHECKPOINT_FOR_DOC = "ybelkada/switch_transformers-base" _CONFIG_FOR_DOC = "SwitchTransformersConfig" _TOKENIZER_FOR_DOC = "SwitchTransformersTokenizer" @@ -817,17 +817,17 @@ def __call__( ) -SWITCHTRANSFORMERS_ENCODE_INPUTS_DOCSTRING = r""" +SWITCH_TRANSFORMERS_ENCODE_INPUTS_DOCSTRING = r""" Args: input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. SWITCHTRANSFORMERS is a model with relative position + Indices of input sequence tokens in the vocabulary. SWITCH_TRANSFORMERS is a model with relative position embeddings so you should be able to pad the inputs on both the right and the left. Indices can be obtained using [`SwitchTransformersTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for detail. - To know more on how to prepare `input_ids` for pretraining take a look a [SWITCHTRANSFORMERS - Training](./switchtransformers#training). + To know more on how to prepare `input_ids` for pretraining take a look a [SWITCH_TRANSFORMERS + Training](./switch_transformers#training). attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: @@ -845,7 +845,7 @@ def __call__( Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ -SWITCHTRANSFORMERS_DECODE_INPUTS_DOCSTRING = r""" +SWITCH_TRANSFORMERS_DECODE_INPUTS_DOCSTRING = r""" Args: decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`): Indices of decoder input sequence tokens in the vocabulary. @@ -887,10 +887,10 @@ def __call__( """ -SWITCHTRANSFORMERS_INPUTS_DOCSTRING = r""" +SWITCH_TRANSFORMERS_INPUTS_DOCSTRING = r""" Args: input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. SWITCHTRANSFORMERS is a model with relative position + Indices of input sequence tokens in the vocabulary. SWITCH_TRANSFORMERS is a model with relative position embeddings so you should be able to pad the inputs on both the right and the left. Indices can be obtained using [`SwitchTransformersTokenizer`]. See [`PreTrainedTokenizer.encode`] and @@ -898,8 +898,8 @@ def __call__( [What are input IDs?](../glossary#input-ids) - To know more on how to prepare `input_ids` for pretraining take a look a [SWITCHTRANSFORMERS - Training](./switchtransformers#training). + To know more on how to prepare `input_ids` for pretraining take a look a [SWITCH_TRANSFORMERS + Training](./switch_transformers#training). attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: @@ -915,12 +915,12 @@ def __call__( [What are decoder input IDs?](../glossary#decoder-input-ids) - SWITCHTRANSFORMERS uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If + SWITCH_TRANSFORMERS uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). - To know more on how to prepare `decoder_input_ids` for pretraining take a look at [SWITCHTRANSFORMERS - Training](./switchtransformers#training). + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [SWITCH_TRANSFORMERS + Training](./switch_transformers#training). decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default. @@ -1006,7 +1006,7 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz else: return random_params - @add_start_docstrings_to_model_forward(SWITCHTRANSFORMERS_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(SWITCH_TRANSFORMERS_INPUTS_DOCSTRING) def __call__( self, input_ids: jnp.ndarray, @@ -1092,7 +1092,7 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs ) return unfreeze(init_variables["cache"]) - @add_start_docstrings(SWITCHTRANSFORMERS_ENCODE_INPUTS_DOCSTRING) + @add_start_docstrings(SWITCH_TRANSFORMERS_ENCODE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=SwitchTransformersConfig) def encode( self, @@ -1113,8 +1113,8 @@ def encode( ```python >>> from transformers import SwitchTransformersTokenizer, FlaxSwitchTransformersForConditionalGeneration - >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") - >>> model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switch_transformers-base") + >>> model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switch_transformers-base") >>> text = "My friends are cool but they eat too many carbs." >>> inputs = tokenizer(text, return_tensors="np") @@ -1150,7 +1150,7 @@ def _encoder_forward(module, input_ids, attention_mask, **kwargs): method=_encoder_forward, ) - @add_start_docstrings(SWITCHTRANSFORMERS_DECODE_INPUTS_DOCSTRING) + @add_start_docstrings(SWITCH_TRANSFORMERS_DECODE_INPUTS_DOCSTRING) @replace_return_docstrings( output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=SwitchTransformersConfig ) @@ -1177,8 +1177,8 @@ def decode( >>> from transformers import SwitchTransformersTokenizer, FlaxSwitchTransformersForConditionalGeneration >>> import jax.numpy as jnp - >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") - >>> model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switch_transformers-base") + >>> model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switch_transformers-base") >>> text = "My friends are cool but they eat too many carbs." >>> inputs = tokenizer(text, return_tensors="np") @@ -1256,8 +1256,8 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs return outputs -SWITCHTRANSFORMERS_START_DOCSTRING = r""" - The SWITCHTRANSFORMERS model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text +SWITCH_TRANSFORMERS_START_DOCSTRING = r""" + The SWITCH_TRANSFORMERS model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a text-to-text denoising generative setting. @@ -1297,8 +1297,8 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs @add_start_docstrings( - "The bare SWITCHTRANSFORMERS Model transformer outputting raw hidden-stateswithout any specific head on top.", - SWITCHTRANSFORMERS_START_DOCSTRING, + "The bare SWITCH_TRANSFORMERS Model transformer outputting raw hidden-stateswithout any specific head on top.", + SWITCH_TRANSFORMERS_START_DOCSTRING, ) # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Module with T5->SwitchTransformers class FlaxSwitchTransformersModule(nn.Module): @@ -1398,7 +1398,7 @@ class FlaxSwitchTransformersModel(FlaxSwitchTransformersPreTrainedModel): FlaxSwitchTransformersModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC ) -FLAX_SWITCHTRANSFORMERS_MODEL_DOCSTRING = """ +FLAX_SWITCH_TRANSFORMERS_MODEL_DOCSTRING = """ Returns: Example: @@ -1406,8 +1406,8 @@ class FlaxSwitchTransformersModel(FlaxSwitchTransformersPreTrainedModel): ```python >>> from transformers import SwitchTransformersTokenizer, FlaxSwitchTransformersModel - >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") - >>> model = FlaxSwitchTransformersModel.from_pretrained("ybelkada/switchtransformers-base") + >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switch_transformers-base") + >>> model = FlaxSwitchTransformersModel.from_pretrained("ybelkada/switch_transformers-base") >>> input_ids = tokenizer( ... "Studies have been shown that owning a dog is good for you", return_tensors="np" @@ -1426,7 +1426,7 @@ class FlaxSwitchTransformersModel(FlaxSwitchTransformersPreTrainedModel): overwrite_call_docstring( - FlaxSwitchTransformersModel, SWITCHTRANSFORMERS_INPUTS_DOCSTRING + FLAX_SWITCHTRANSFORMERS_MODEL_DOCSTRING + FlaxSwitchTransformersModel, SWITCH_TRANSFORMERS_INPUTS_DOCSTRING + FLAX_SWITCH_TRANSFORMERS_MODEL_DOCSTRING ) append_replace_return_docstrings( FlaxSwitchTransformersModel, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC @@ -1434,9 +1434,9 @@ class FlaxSwitchTransformersModel(FlaxSwitchTransformersPreTrainedModel): @add_start_docstrings( - "The bare SWITCHTRANSFORMERS Model transformer outputting encoder's raw hidden-states without any specific head on" + "The bare SWITCH_TRANSFORMERS Model transformer outputting encoder's raw hidden-states without any specific head on" " top.", - SWITCHTRANSFORMERS_START_DOCSTRING, + SWITCH_TRANSFORMERS_START_DOCSTRING, ) # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5EncoderModule with T5->SwitchTransformers class FlaxSwitchTransformersEncoderModule(nn.Module): @@ -1488,7 +1488,7 @@ def __call__( class FlaxSwitchTransformersEncoderModel(FlaxSwitchTransformersPreTrainedModel): module_class = FlaxSwitchTransformersEncoderModule - @add_start_docstrings_to_model_forward(SWITCHTRANSFORMERS_ENCODE_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(SWITCH_TRANSFORMERS_ENCODE_INPUTS_DOCSTRING) def __call__( self, input_ids: jnp.ndarray, @@ -1526,7 +1526,7 @@ def __call__( @add_start_docstrings( - """SWITCHTRANSFORMERS Model with a `language modeling` head on top.""", SWITCHTRANSFORMERS_START_DOCSTRING + """SWITCH_TRANSFORMERS Model with a `language modeling` head on top.""", SWITCH_TRANSFORMERS_START_DOCSTRING ) # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5ForConditionalGenerationModule with T5->SwitchTransformers class FlaxSwitchTransformersForConditionalGenerationModule(nn.Module): @@ -1641,7 +1641,7 @@ def __call__( class FlaxSwitchTransformersForConditionalGeneration(FlaxSwitchTransformersPreTrainedModel): module_class = FlaxSwitchTransformersForConditionalGenerationModule - @add_start_docstrings(SWITCHTRANSFORMERS_DECODE_INPUTS_DOCSTRING) + @add_start_docstrings(SWITCH_TRANSFORMERS_DECODE_INPUTS_DOCSTRING) @replace_return_docstrings( output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=SwitchTransformersConfig ) @@ -1668,8 +1668,8 @@ def decode( >>> from transformers import SwitchTransformersTokenizer, FlaxSwitchTransformersForConditionalGeneration >>> import jax.numpy as jnp - >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") - >>> model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switch_transformers-base") + >>> model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switch_transformers-base") >>> text = "summarize: My friends are cool but they eat too many carbs." >>> inputs = tokenizer(text, return_tensors="np") @@ -1808,7 +1808,7 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs): return model_kwargs -FLAX_SWITCHTRANSFORMERS_CONDITIONAL_GENERATION_DOCSTRING = """ +FLAX_SWITCH_TRANSFORMERS_CONDITIONAL_GENERATION_DOCSTRING = """ Returns: Example: @@ -1816,8 +1816,8 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs): ```python >>> from transformers import SwitchTransformersTokenizer, FlaxSwitchTransformersForConditionalGeneration - >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") - >>> model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switch_transformers-base") + >>> model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switch_transformers-base") >>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs." >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors="np") @@ -1831,7 +1831,7 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs): overwrite_call_docstring( FlaxSwitchTransformersForConditionalGeneration, - SWITCHTRANSFORMERS_INPUTS_DOCSTRING + FLAX_SWITCHTRANSFORMERS_CONDITIONAL_GENERATION_DOCSTRING, + SWITCH_TRANSFORMERS_INPUTS_DOCSTRING + FLAX_SWITCH_TRANSFORMERS_CONDITIONAL_GENERATION_DOCSTRING, ) append_replace_return_docstrings( FlaxSwitchTransformersForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC diff --git a/src/transformers/models/switchtransformers/modeling_switchtransformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py similarity index 97% rename from src/transformers/models/switchtransformers/modeling_switchtransformers.py rename to src/transformers/models/switch_transformers/modeling_switch_transformers.py index 4444525128bd3..6b4aa78be6b85 100644 --- a/src/transformers/models/switchtransformers/modeling_switchtransformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -43,7 +43,7 @@ logging, replace_return_docstrings, ) -from .configuration_switchtransformers import SwitchTransformersConfig +from .configuration_switch_transformers import SwitchTransformersConfig from .router import ExpertsChooseMaskedRouter, TokensChooseMaskedRouter, TokensChooseScatterRouter @@ -51,15 +51,15 @@ _CONFIG_FOR_DOC = "SwitchTransformersConfig" _TOKENIZER_FOR_DOC = "SwitchTransformersTokenizer" -_CHECKPOINT_FOR_DOC = "ybelkada/switchtransformers-base" +_CHECKPOINT_FOR_DOC = "ybelkada/switch_transformers-base" #################################################### # This dict contains ids and associated url # for the pretrained weights provided with the models #################################################### -SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "ybelkada/switchtransformers-base", - # See all SwitchTransformers models at https://huggingface.co/models?filter=switchtransformers +SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "ybelkada/switch_transformers-base", + # See all SwitchTransformers models at https://huggingface.co/models?filter=switch_transformers ] @@ -952,9 +952,9 @@ def custom_forward(*inputs): ) -SWITCHTRANSFORMERS_START_DOCSTRING = r""" +SWITCH_TRANSFORMERS_START_DOCSTRING = r""" - The SWITCHTRANSFORMERS model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text + The SWITCH_TRANSFORMERS model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a text-to-text denoising generative setting. @@ -973,10 +973,10 @@ def custom_forward(*inputs): configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ -SWITCHTRANSFORMERS_INPUTS_DOCSTRING = r""" +SWITCH_TRANSFORMERS_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. SWITCHTRANSFORMERS is a model with relative position + Indices of input sequence tokens in the vocabulary. SWITCH_TRANSFORMERS is a model with relative position embeddings so you should be able to pad the inputs on both the right and the left. Indices can be obtained using [`SwitchTransformersTokenizer`]. See [`PreTrainedTokenizer.encode`] and @@ -984,8 +984,8 @@ def custom_forward(*inputs): [What are input IDs?](../glossary#input-ids) - To know more on how to prepare `input_ids` for pretraining take a look a [SWITCHTRANSFORMERS - Training](./switchtransformers#training). + To know more on how to prepare `input_ids` for pretraining take a look a [SWITCH_TRANSFORMERS + Training](./switch_transformers#training). attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: @@ -1001,12 +1001,12 @@ def custom_forward(*inputs): [What are decoder input IDs?](../glossary#decoder-input-ids) - SWITCHTRANSFORMERS uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If + SWITCH_TRANSFORMERS uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). - To know more on how to prepare `decoder_input_ids` for pretraining take a look at [SWITCHTRANSFORMERS - Training](./switchtransformers#training). + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [SWITCH_TRANSFORMERS + Training](./switch_transformers#training). decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default. @@ -1068,17 +1068,17 @@ def custom_forward(*inputs): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ -SWITCHTRANSFORMERS_ENCODER_INPUTS_DOCSTRING = r""" +SWITCH_TRANSFORMERS_ENCODER_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. SWITCHTRANSFORMERS is a model with relative position + Indices of input sequence tokens in the vocabulary. SWITCH_TRANSFORMERS is a model with relative position embeddings so you should be able to pad the inputs on both the right and the left. Indices can be obtained using [`SwitchTransformersTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for detail. - To know more on how to prepare `input_ids` for pretraining take a look a [SWITCHTRANSFORMERS - Training](./switchtransformers#training). + To know more on how to prepare `input_ids` for pretraining take a look a [SWITCH_TRANSFORMERS + Training](./switch_transformers#training). attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: @@ -1116,8 +1116,8 @@ def custom_forward(*inputs): @add_start_docstrings( - "The bare SWITCHTRANSFORMERS Model transformer outputting raw hidden-states without any specific head on top.", - SWITCHTRANSFORMERS_START_DOCSTRING, + "The bare SWITCH_TRANSFORMERS Model transformer outputting raw hidden-states without any specific head on top.", + SWITCH_TRANSFORMERS_START_DOCSTRING, ) class SwitchTransformersModel(SwitchTransformersPreTrainedModel): _keys_to_ignore_on_load_missing = [ @@ -1173,7 +1173,7 @@ class PreTrainedModel for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) - @add_start_docstrings_to_model_forward(SWITCHTRANSFORMERS_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(SWITCH_TRANSFORMERS_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) def forward( self, @@ -1201,8 +1201,8 @@ def forward( ```python >>> from transformers import SwitchTransformersTokenizer, SwitchTransformersModel - >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") - >>> model = SwitchTransformersModel.from_pretrained("ybelkada/switchtransformers-base") + >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switch_transformers-base") + >>> model = SwitchTransformersModel.from_pretrained("ybelkada/switch_transformers-base") >>> input_ids = tokenizer( ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" @@ -1289,7 +1289,7 @@ def forward( @add_start_docstrings( - """SWITCHTRANSFORMERS Model with a `language modeling` head on top.""", SWITCHTRANSFORMERS_START_DOCSTRING + """SWITCH_TRANSFORMERS Model with a `language modeling` head on top.""", SWITCH_TRANSFORMERS_START_DOCSTRING ) class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel): _keys_to_ignore_on_load_missing = [ @@ -1348,7 +1348,7 @@ def get_encoder(self): def get_decoder(self): return self.decoder - @add_start_docstrings_to_model_forward(SWITCHTRANSFORMERS_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(SWITCH_TRANSFORMERS_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) def forward( self, @@ -1382,8 +1382,8 @@ def forward( ```python >>> from transformers import SwitchTransformersTokenizer, SwitchTransformersForConditionalGeneration - >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") - >>> model = SwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switch_transformers-base") + >>> model = SwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switch_transformers-base") >>> # training >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids @@ -1558,9 +1558,9 @@ def _reorder_cache(self, past, beam_idx): @add_start_docstrings( - "The bare SWITCHTRANSFORMERS Model transformer outputting encoder's raw hidden-states without any specific head on" + "The bare SWITCH_TRANSFORMERS Model transformer outputting encoder's raw hidden-states without any specific head on" " top.", - SWITCHTRANSFORMERS_START_DOCSTRING, + SWITCH_TRANSFORMERS_START_DOCSTRING, ) class SwitchTransformersEncoderModel(SwitchTransformersPreTrainedModel): authorized_missing_keys = [ @@ -1601,7 +1601,7 @@ class PreTrainedModel for layer, heads in heads_to_prune.items(): self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads) - @add_start_docstrings_to_model_forward(SWITCHTRANSFORMERS_ENCODER_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(SWITCH_TRANSFORMERS_ENCODER_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) def forward( self, @@ -1621,8 +1621,8 @@ def forward( ```python >>> from transformers import SwitchTransformersTokenizer, SwitchTransformersEncoderModel - >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") - >>> model = SwitchTransformersEncoderModel.from_pretrained("ybelkada/switchtransformers-base") + >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switch_transformers-base") + >>> model = SwitchTransformersEncoderModel.from_pretrained("ybelkada/switch_transformers-base") >>> input_ids = tokenizer( ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" ... ).input_ids # Batch size 1 diff --git a/src/transformers/models/switchtransformers/router.py b/src/transformers/models/switch_transformers/router.py similarity index 99% rename from src/transformers/models/switchtransformers/router.py rename to src/transformers/models/switch_transformers/router.py index b08ea57238e85..ee1b8d3c000ce 100644 --- a/src/transformers/models/switchtransformers/router.py +++ b/src/transformers/models/switch_transformers/router.py @@ -19,7 +19,7 @@ import torch.nn as nn -# from transformers.models.switchtransformers.configuration_switchtransformers import SwitchTransformersConfig +# from transformers.models.switch_transformers.configuration_switch_transformers import SwitchTransformersConfig # Output classes diff --git a/src/transformers/models/switchtransformers/router_flax.py b/src/transformers/models/switch_transformers/router_flax.py similarity index 100% rename from src/transformers/models/switchtransformers/router_flax.py rename to src/transformers/models/switch_transformers/router_flax.py diff --git a/src/transformers/models/switchtransformers/tokenization_switchtransformers.py b/src/transformers/models/switch_transformers/tokenization_switch_transformers.py similarity index 94% rename from src/transformers/models/switchtransformers/tokenization_switchtransformers.py rename to src/transformers/models/switch_transformers/tokenization_switch_transformers.py index d1235520a12da..a2652b5008435 100644 --- a/src/transformers/models/switchtransformers/tokenization_switchtransformers.py +++ b/src/transformers/models/switch_transformers/tokenization_switch_transformers.py @@ -33,24 +33,24 @@ PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "ybelkada/switchtransformers-base": ( - "https://huggingface.co/ybelkada/switchtransformers-base/resolve/main/spiece.model" + "ybelkada/switch_transformers-base": ( + "https://huggingface.co/ybelkada/switch_transformers-base/resolve/main/spiece.model" ), - "switchtransformers-base": "https://huggingface.co/switchtransformers-base/resolve/main/spiece.model", - "switchtransformers-large": "https://huggingface.co/switchtransformers-large/resolve/main/spiece.model", - "switchtransformers-3b": "https://huggingface.co/switchtransformers-3b/resolve/main/spiece.model", - "switchtransformers-11b": "https://huggingface.co/switchtransformers-11b/resolve/main/spiece.model", + "switch_transformers-base": "https://huggingface.co/switch_transformers-base/resolve/main/spiece.model", + "switch_transformers-large": "https://huggingface.co/switch_transformers-large/resolve/main/spiece.model", + "switch_transformers-3b": "https://huggingface.co/switch_transformers-3b/resolve/main/spiece.model", + "switch_transformers-11b": "https://huggingface.co/switch_transformers-11b/resolve/main/spiece.model", } } # TODO(PVP) - this should be removed in Transformers v5 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "ybelkada/switchtransformers-base": 512, - "switchtransformers-base": 512, - "switchtransformers-large": 512, - "switchtransformers-3b": 512, - "switchtransformers-11b": 512, + "ybelkada/switch_transformers-base": 512, + "switch_transformers-base": 512, + "switch_transformers-large": 512, + "switch_transformers-3b": 512, + "switch_transformers-11b": 512, } @@ -85,7 +85,7 @@ class SwitchTransformersTokenizer(PreTrainedTokenizer): accessible as "" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are indexed from the end of the vocabulary up to beginning ("" is the last token in the vocabulary like in SwitchTransformers preprocessing see - [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/switchtransformers/data/preprocessors.py#L2117)). + [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/switch_transformers/data/preprocessors.py#L2117)). additional_special_tokens (`List[str]`, *optional*): Additional special tokens used by the tokenizer. sp_model_kwargs (`dict`, *optional*): @@ -157,7 +157,7 @@ def __init__( self.sp_model.Load(vocab_file) @staticmethod - def _eventually_correct_switchtransformers_max_length( + def _eventually_correct_switch_transformers_max_length( pretrained_model_name_or_path, max_model_length, init_max_model_length ): if pretrained_model_name_or_path in SwitchTransformersTokenizer.max_model_input_sizes: diff --git a/src/transformers/models/switchtransformers/tokenization_switchtransformers_fast.py b/src/transformers/models/switch_transformers/tokenization_switch_transformers_fast.py similarity index 85% rename from src/transformers/models/switchtransformers/tokenization_switchtransformers_fast.py rename to src/transformers/models/switch_transformers/tokenization_switch_transformers_fast.py index 0edf71fa3f285..758498ac2f632 100644 --- a/src/transformers/models/switchtransformers/tokenization_switchtransformers_fast.py +++ b/src/transformers/models/switch_transformers/tokenization_switch_transformers_fast.py @@ -25,7 +25,7 @@ if is_sentencepiece_available(): - from .tokenization_switchtransformers import SwitchTransformersTokenizer + from .tokenization_switch_transformers import SwitchTransformersTokenizer else: SwitchTransformersTokenizer = None @@ -36,33 +36,33 @@ PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "ybelkada/switchtransformers-base": ( - "https://huggingface.co/ybelkada/switchtransformers-base/resolve/main/spiece.model" + "ybelkada/switch_transformers-base": ( + "https://huggingface.co/ybelkada/switch_transformers-base/resolve/main/spiece.model" ), - "switchtransformers-base": "https://huggingface.co/switchtransformers-base/resolve/main/spiece.model", - "switchtransformers-large": "https://huggingface.co/switchtransformers-large/resolve/main/spiece.model", - "switchtransformers-3b": "https://huggingface.co/switchtransformers-3b/resolve/main/spiece.model", - "switchtransformers-11b": "https://huggingface.co/switchtransformers-11b/resolve/main/spiece.model", + "switch_transformers-base": "https://huggingface.co/switch_transformers-base/resolve/main/spiece.model", + "switch_transformers-large": "https://huggingface.co/switch_transformers-large/resolve/main/spiece.model", + "switch_transformers-3b": "https://huggingface.co/switch_transformers-3b/resolve/main/spiece.model", + "switch_transformers-11b": "https://huggingface.co/switch_transformers-11b/resolve/main/spiece.model", }, "tokenizer_file": { - "ybelkada/switchtransformers-base": ( - "https://huggingface.co/ybelkada/switchtransformers-base/resolve/main/tokenizer.json" + "ybelkada/switch_transformers-base": ( + "https://huggingface.co/ybelkada/switch_transformers-base/resolve/main/tokenizer.json" ), - "switchtransformers-base": "https://huggingface.co/switchtransformers-base/resolve/main/tokenizer.json", - "switchtransformers-large": "https://huggingface.co/switchtransformers-large/resolve/main/tokenizer.json", - "switchtransformers-3b": "https://huggingface.co/switchtransformers-3b/resolve/main/tokenizer.json", - "switchtransformers-11b": "https://huggingface.co/switchtransformers-11b/resolve/main/tokenizer.json", + "switch_transformers-base": "https://huggingface.co/switch_transformers-base/resolve/main/tokenizer.json", + "switch_transformers-large": "https://huggingface.co/switch_transformers-large/resolve/main/tokenizer.json", + "switch_transformers-3b": "https://huggingface.co/switch_transformers-3b/resolve/main/tokenizer.json", + "switch_transformers-11b": "https://huggingface.co/switch_transformers-11b/resolve/main/tokenizer.json", }, } # TODO(PVP) - this should be removed in Transformers v5 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "ybelkada/switchtransformers-base": 512, - "switchtransformers-base": 512, - "switchtransformers-large": 512, - "switchtransformers-3b": 512, - "switchtransformers-11b": 512, + "ybelkada/switch_transformers-base": 512, + "switch_transformers-base": 512, + "switch_transformers-large": 512, + "switch_transformers-3b": 512, + "switch_transformers-11b": 512, } @@ -98,7 +98,7 @@ class SwitchTransformersTokenizerFast(PreTrainedTokenizerFast): accessible as "" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are indexed from the end of the vocabulary up to beginning ("" is the last token in the vocabulary like in SwitchTransformers preprocessing see - [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/switchtransformers/data/preprocessors.py#L2117)). + [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/switch_transformers/data/preprocessors.py#L2117)). additional_special_tokens (`List[str]`, *optional*): Additional special tokens used by the tokenizer. """ @@ -151,7 +151,7 @@ def __init__( self._extra_ids = extra_ids @staticmethod - def _eventually_correct_switchtransformers_max_length( + def _eventually_correct_switch_transformers_max_length( pretrained_model_name_or_path, max_model_length, init_max_model_length ): if pretrained_model_name_or_path in SwitchTransformersTokenizerFast.max_model_input_sizes: diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index d197e738b1b3c..e66cd1c34d4ba 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4855,7 +4855,7 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST = None +SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST = None class SwitchTransformersEncoderModel(metaclass=DummyObject): diff --git a/tests/models/switchtransformers/__init__.py b/tests/models/switch_transformers/__init__.py similarity index 100% rename from tests/models/switchtransformers/__init__.py rename to tests/models/switch_transformers/__init__.py diff --git a/tests/models/switchtransformers/test_modeling_flax_switchtransformers.py b/tests/models/switch_transformers/test_modeling_flax_switch_transformers.py similarity index 96% rename from tests/models/switchtransformers/test_modeling_flax_switchtransformers.py rename to tests/models/switch_transformers/test_modeling_flax_switch_transformers.py index 6966a376faa40..e714397770a36 100644 --- a/tests/models/switchtransformers/test_modeling_flax_switchtransformers.py +++ b/tests/models/switch_transformers/test_modeling_flax_switch_transformers.py @@ -53,7 +53,7 @@ SwitchTransformersTokenizer, ) from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model - from transformers.models.switchtransformers.modeling_flax_switchtransformers import ( + from transformers.models.switch_transformers.modeling_flax_switch_transformers import ( FlaxSwitchTransformersEncoderModel, FlaxSwitchTransformersForConditionalGeneration, FlaxSwitchTransformersModel, @@ -773,18 +773,18 @@ class FlaxSwitchTransformersModelIntegrationTests(unittest.TestCase): def test_small_integration_test(self): """ For comparision run: - >>> import switchtransformers # pip install switchtransformers==0.7.1 - >>> from switchtransformers.data.sentencepiece_vocabulary import SentencePieceVocabulary + >>> import switch_transformers # pip install switch_transformers==0.7.1 + >>> from switch_transformers.data.sentencepiece_vocabulary import SentencePieceVocabulary - >>> path_to_mtf_small_switchtransformers_checkpoint = '' + >>> path_to_mtf_small_switch_transformers_checkpoint = '' >>> path_to_mtf_small_spm_model_path = '' - >>> switchtransformers_model = switchtransformers.models.MtfModel(model_dir=path_to_mtf_small_switchtransformers_checkpoint, batch_size=1, tpu=None) + >>> switch_transformers_model = switch_transformers.models.MtfModel(model_dir=path_to_mtf_small_switch_transformers_checkpoint, batch_size=1, tpu=None) >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) - >>> score = switchtransformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + >>> score = switch_transformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) """ - model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") - tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switch_transformers-base") + tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switch_transformers-base") input_ids = tokenizer("Hello there", return_tensors="np").input_ids labels = tokenizer("Hi I am", return_tensors="np").input_ids @@ -803,18 +803,18 @@ def test_small_integration_test(self): def test_small_v1_1_integration_test(self): """ For comparision run: - >>> import switchtransformers # pip install switchtransformers==0.7.1 - >>> from switchtransformers.data.sentencepiece_vocabulary import SentencePieceVocabulary + >>> import switch_transformers # pip install switch_transformers==0.7.1 + >>> from switch_transformers.data.sentencepiece_vocabulary import SentencePieceVocabulary - >>> path_to_mtf_small_switchtransformers_v1_1_checkpoint = '' + >>> path_to_mtf_small_switch_transformers_v1_1_checkpoint = '' >>> path_to_mtf_small_spm_model_path = '' - >>> switchtransformers_model = switchtransformers.models.MtfModel(model_dir=path_to_mtf_small_switchtransformers_v1_1_checkpoint, batch_size=1, tpu=None) + >>> switch_transformers_model = switch_transformers.models.MtfModel(model_dir=path_to_mtf_small_switch_transformers_v1_1_checkpoint, batch_size=1, tpu=None) >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) - >>> score = switchtransformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + >>> score = switch_transformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) """ - model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("google/switchtransformers-v1_1-small") - tokenizer = SwitchTransformersTokenizer.from_pretrained("google/switchtransformers-v1_1-small") + model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("google/switch_transformers-v1_1-small") + tokenizer = SwitchTransformersTokenizer.from_pretrained("google/switch_transformers-v1_1-small") input_ids = tokenizer("Hello there", return_tensors="np").input_ids labels = tokenizer("Hi I am", return_tensors="np").input_ids @@ -830,21 +830,21 @@ def test_small_v1_1_integration_test(self): self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) @slow - def test_small_byswitchtransformers_integration_test(self): + def test_small_byswitch_transformers_integration_test(self): """ For comparision run: - >>> import switchtransformers # pip install switchtransformers==0.9.1 + >>> import switch_transformers # pip install switch_transformers==0.9.1 - >>> path_to_byswitchtransformers_small_checkpoint = '' - >>> switchtransformers_model = switchtransformers.models.MtfModel(model_dir=path_to_tf_checkpoint, batch_size=1, tpu=None) - >>> vocab = switchtransformers.data.ByteVocabulary() - >>> score = switchtransformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + >>> path_to_byswitch_transformers_small_checkpoint = '' + >>> switch_transformers_model = switch_transformers.models.MtfModel(model_dir=path_to_tf_checkpoint, batch_size=1, tpu=None) + >>> vocab = switch_transformers.data.ByteVocabulary() + >>> score = switch_transformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) """ model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained( - "google/byybelkada/switchtransformers-base" + "google/byybelkada/switch_transformers-base" ) - tokenizer = BySwitchTransformersTokenizer.from_pretrained("google/byybelkada/switchtransformers-base") + tokenizer = BySwitchTransformersTokenizer.from_pretrained("google/byybelkada/switch_transformers-base") input_ids = tokenizer("Hello there", return_tensors="np").input_ids labels = tokenizer("Hi I am", return_tensors="np").input_ids @@ -861,11 +861,11 @@ def test_small_byswitchtransformers_integration_test(self): @slow def test_small_generation(self): - model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switchtransformers-base") + model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switch_transformers-base") model.config.max_length = 8 model.config.num_beams = 1 model.config.do_sample = False - tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switchtransformers-base") + tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switch_transformers-base") input_ids = tokenizer("summarize: Hello there", return_tensors="np").input_ids @@ -876,8 +876,8 @@ def test_small_generation(self): @slow def test_summarization(self): - model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("switchtransformers-base") - tok = SwitchTransformersTokenizer.from_pretrained("switchtransformers-base") + model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("switch_transformers-base") + tok = SwitchTransformersTokenizer.from_pretrained("switch_transformers-base") FRANCE_ARTICLE = ( # @noqa "Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings" diff --git a/tests/models/switchtransformers/test_modeling_switchtransformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py similarity index 97% rename from tests/models/switchtransformers/test_modeling_switchtransformers.py rename to tests/models/switch_transformers/test_modeling_switch_transformers.py index cba49ba4d123c..b6c6f2ceef506 100644 --- a/tests/models/switchtransformers/test_modeling_switchtransformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -34,10 +34,10 @@ SwitchTransformersForConditionalGeneration, SwitchTransformersModel, ) - from transformers.models.switchtransformers.modeling_switchtransformers import ( - SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, + from transformers.models.switch_transformers.modeling_switch_transformers import ( + SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, ) - from transformers.models.switchtransformers.router import ( + from transformers.models.switch_transformers.router import ( ExpertsChooseMaskedRouter, TokensChooseMaskedRouter, load_balancing_loss_func, @@ -95,7 +95,7 @@ def __init__( self.decoder_layers = decoder_layers def get_large_model_config(self): - return SwitchTransformersConfig.from_pretrained("switchtransformers-base") + return SwitchTransformersConfig.from_pretrained("switch_transformers-base") def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) @@ -124,7 +124,7 @@ def prepare_config_and_inputs(self): def get_pipeline_config(self): return SwitchTransformersConfig( - vocab_size=166, # switchtransformers forces 100 extra tokens + vocab_size=166, # switch_transformers forces 100 extra tokens d_model=self.hidden_size, d_ff=self.d_ff, d_kv=self.hidden_size // self.num_attention_heads, @@ -480,7 +480,7 @@ def create_and_check_encoder_decoder_shared_weights( ) ) - def check_resize_embeddings_switchtransformers_v1_1( + def check_resize_embeddings_switch_transformers_v1_1( self, config, ): @@ -530,7 +530,7 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, unitt test_resize_embeddings = True test_model_parallel = True is_encoder_decoder = True - # The small SWITCHTRANSFORMERS model needs higher percentages for CPU/MP tests + # The small SWITCH_TRANSFORMERS model needs higher percentages for CPU/MP tests model_split_percents = [0.8, 0.9] def setUp(self): @@ -621,11 +621,11 @@ def test_model_fp16_forward(self): def test_v1_1_resize_embeddings(self): config = self.model_tester.prepare_config_and_inputs()[0] - self.model_tester.check_resize_embeddings_switchtransformers_v1_1(config) + self.model_tester.check_resize_embeddings_switch_transformers_v1_1(config) @slow def test_model_from_pretrained(self): - for model_name in SWITCHTRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + for model_name in SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: model = SwitchTransformersModel.from_pretrained(model_name) self.assertIsNotNone(model) @@ -637,7 +637,7 @@ def test_export_to_onnx(self): torch.onnx.export( model, (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]), - f"{tmpdirname}/switchtransformers_test.onnx", + f"{tmpdirname}/switch_transformers_test.onnx", export_params=True, opset_version=9, input_names=["input_ids", "decoder_input_ids"], @@ -659,7 +659,7 @@ def test_generate_with_head_masking(self): for attn_name, (name, mask) in zip(attention_names, head_masking.items()): head_masks = {name: mask} - # Explicitly pass decoder_head_mask as it is required from SWITCHTRANSFORMERS model when head_mask specified + # Explicitly pass decoder_head_mask as it is required from SWITCH_TRANSFORMERS model when head_mask specified if name == "head_mask": head_masks["decoder_head_mask"] = torch.ones( config.num_decoder_layers, config.num_heads, device=torch_device @@ -726,7 +726,7 @@ def __init__( self.is_training = is_training def get_large_model_config(self): - return SwitchTransformersConfig.from_pretrained("switchtransformers-base") + return SwitchTransformersConfig.from_pretrained("switch_transformers-base") def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) diff --git a/tests/models/switchtransformers/test_tokenization_switchtransformers.py b/tests/models/switch_transformers/test_tokenization_switch_transformers.py similarity index 94% rename from tests/models/switchtransformers/test_tokenization_switchtransformers.py rename to tests/models/switch_transformers/test_tokenization_switch_transformers.py index 7ba607e27b5e9..3a9fd41a9f708 100644 --- a/tests/models/switchtransformers/test_tokenization_switchtransformers.py +++ b/tests/models/switch_transformers/test_tokenization_switch_transformers.py @@ -142,12 +142,12 @@ def test_full_tokenizer(self): ) @cached_property - def switchtransformers_base_tokenizer(self): - return SwitchTransformersTokenizer.from_pretrained("switchtransformers-base") + def switch_transformers_base_tokenizer(self): + return SwitchTransformersTokenizer.from_pretrained("switch_transformers-base") @cached_property - def switchtransformers_base_tokenizer_fast(self): - return SwitchTransformersTokenizerFast.from_pretrained("switchtransformers-base") + def switch_transformers_base_tokenizer_fast(self): + return SwitchTransformersTokenizerFast.from_pretrained("switch_transformers-base") def get_tokenizer(self, **kwargs) -> SwitchTransformersTokenizer: return self.tokenizer_class.from_pretrained(self.tmpdirname, pad_token=None, **kwargs) @@ -178,13 +178,13 @@ def test_rust_and_python_full_tokenizers(self): self.assertListEqual(ids, rust_ids) def test_eos_treatment(self): - tokenizer = self.switchtransformers_base_tokenizer + tokenizer = self.switch_transformers_base_tokenizer batch_with_eos_added = tokenizer(["hi", "I went to the gym", ""]) batch_without_eos_added = tokenizer(["hi", "I went to the gym", ""]) self.assertListEqual(batch_with_eos_added["input_ids"], batch_without_eos_added["input_ids"]) def test_prepare_batch(self): - tokenizer = self.switchtransformers_base_tokenizer + tokenizer = self.switch_transformers_base_tokenizer src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id] batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK) @@ -201,7 +201,7 @@ def test_prepare_batch(self): self.assertEqual((2, 9), batch.attention_mask.shape) def test_empty_target_text(self): - tokenizer = self.switchtransformers_base_tokenizer + tokenizer = self.switch_transformers_base_tokenizer src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK) # check if input_ids are returned and no decoder_input_ids @@ -211,7 +211,7 @@ def test_empty_target_text(self): self.assertNotIn("decoder_attention_mask", batch) def test_max_length(self): - tokenizer = self.switchtransformers_base_tokenizer + tokenizer = self.switch_transformers_base_tokenizer tgt_text = [ "Summary of the text.", "Another summary.", @@ -222,7 +222,7 @@ def test_max_length(self): self.assertEqual(32, targets["input_ids"].shape[1]) def test_outputs_not_longer_than_maxlen(self): - tokenizer = self.switchtransformers_base_tokenizer + tokenizer = self.switch_transformers_base_tokenizer batch = tokenizer( ["I am a small frog" * 1000, "I am a small frog"], padding=True, truncation=True, return_tensors=FRAMEWORK @@ -234,7 +234,7 @@ def test_outputs_not_longer_than_maxlen(self): self.assertEqual(batch.input_ids.shape, (2, 512)) def test_eos_in_input(self): - tokenizer = self.switchtransformers_base_tokenizer + tokenizer = self.switch_transformers_base_tokenizer src_text = ["A long paragraph for summarization. "] tgt_text = ["Summary of the text. "] expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, 1] @@ -249,10 +249,10 @@ def test_token_type_ids(self): src_text_1 = ["A first paragraph for summarization."] src_text_2 = ["A second paragraph for summarization."] - fast_token_type_ids = self.switchtransformers_base_tokenizer_fast( + fast_token_type_ids = self.switch_transformers_base_tokenizer_fast( src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True ).token_type_ids - slow_token_type_ids = self.switchtransformers_base_tokenizer( + slow_token_type_ids = self.switch_transformers_base_tokenizer( src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True ).token_type_ids @@ -264,13 +264,13 @@ def test_fast_and_slow_same_result(self): tgt_ids = [0, 1960, 19, 2, 1245, 239, 1] tgt_text = " Today is nice day" - fast_ids = self.switchtransformers_base_tokenizer_fast(src_text, add_special_tokens=False).input_ids - slow_ids = self.switchtransformers_base_tokenizer(src_text, add_special_tokens=False).input_ids + fast_ids = self.switch_transformers_base_tokenizer_fast(src_text, add_special_tokens=False).input_ids + slow_ids = self.switch_transformers_base_tokenizer(src_text, add_special_tokens=False).input_ids self.assertEqual(tgt_ids, fast_ids) self.assertEqual(tgt_ids, slow_ids) - fast_text = self.switchtransformers_base_tokenizer_fast.decode(fast_ids) - slow_text = self.switchtransformers_base_tokenizer.decode(fast_ids) + fast_text = self.switch_transformers_base_tokenizer_fast.decode(fast_ids) + slow_text = self.switch_transformers_base_tokenizer.decode(fast_ids) self.assertEqual(tgt_text, fast_text) self.assertEqual(tgt_text, slow_text) @@ -382,6 +382,6 @@ def test_tokenizer_integration(self): self.tokenizer_integration_test_util( expected_encoding=expected_encoding, - model_name="switchtransformers-base", + model_name="switch_transformers-base", revision="5a7ff2d8f5117c194c7e32ec1ccbf04642cca99b", ) From eeb28771f5de3fb0f8617a3b94bde186c3c15654 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 13 Oct 2022 17:57:04 +0100 Subject: [PATCH 013/102] forward pass run --- docs/source/en/index.mdx | 2 +- .../configuration_switch_transformers.py | 17 +++-- ...switch_transformersx_checkpoint_to_flax.py | 26 +++++--- .../modeling_flax_switch_transformers.py | 12 ++-- .../modeling_switch_transformers.py | 65 ++++++++++++++----- .../models/switch_transformers/router.py | 12 +++- utils/check_repo.py | 1 + 7 files changed, 98 insertions(+), 37 deletions(-) diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index c12d63addb12d..2a74d4a5c1f57 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -162,7 +162,7 @@ The documentation is organized into five sections: 1. **[SqueezeBERT](model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. 1. **[Swin Transformer V2](model_doc/swinv2)** (from Microsoft) released with the paper [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. -1. **[Switch Transformers](model_doc/switch_transformers)** (from ) released with the paper []() by . +1. **[SwitchTransformers](model_doc/switch_transformers)** (from ) released with the paper []() by . 1. **[T5](model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index f209cdd38edb2..94c46f5fae325 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -52,6 +52,9 @@ class SwitchTransformersConfig(PretrainedConfig): num_heads`. d_ff (`int`, *optional*, defaults to 2048): Size of the intermediate feed forward layer in each `SwitchTransformersBlock`. + expert_capacity (`int`, *optional*, defaults to 1): + Number of tokens that can be stored in each expert. If set to 1, the model will behave like a regular + Transformer. num_encoder_layers (`int`, *optional*, defaults to 12): Number of dense hidden layers in the Transformer encoder layer. num_sparse_encoder_layers (`int`, *optional*, defaults to 6): @@ -64,6 +67,8 @@ class SwitchTransformersConfig(PretrainedConfig): Number of attention heads for each attention layer in the Transformer encoder. num_experts (`int`, *optional*, defaults to 8): Number of experts for each SwitchTransformer layer. + router_type (`str`, *optional*, defaults to `tokens_masked`): + Router type - choice between `tokens_masked` and `tokens_scatter`, `experts_masked`. router_bias (`bool`, *optional*, defaults to `True`): Whether to add a bias to the router. router_jitter_noise (`float`, *optional*, defaults to 0.1): @@ -110,6 +115,8 @@ def __init__( num_sparse_decoder_layers=6, num_heads=8, num_experts=8, + expert_capacity=1, + router_type="tokens_masked", router_bias=False, router_jitter_noise=0.01, router_dtype="float32", @@ -142,18 +149,20 @@ def __init__( # This tells us, each how many encoder layer we'll have to set a sparse layer. if self.num_sparse_encoder_layers > 0: - self.encoder_sparse_step = self.num_encoder_layer % self.num_sparse_encoder_layers + self.encoder_sparse_step = self.num_encoder_layers // self.num_sparse_encoder_layers else: - self.encoder_sparse_step = self.num_encoder_layer # HACK: this will create 0 sparse layers + self.encoder_sparse_step = self.num_encoder_layers # HACK: this will create 0 sparse layers # This tells us, each how many encoder layer we'll have to set a sparse layer. if self.num_sparse_decoder_layers > 0: - self.decoder_sparse_step = self.num_decoder_layer % self.num_sparse_decoder_layers + self.decoder_sparse_step = self.num_decoder_layers // self.num_sparse_decoder_layers else: - self.decoder_sparse_step = self.num_decoder_layer # HACK: this will create 0 sparse layers + self.decoder_sparse_step = self.num_decoder_layers # HACK: this will create 0 sparse layers self.num_heads = num_heads + self.router_type = router_type self.num_experts = num_experts + self.expert_capacity = expert_capacity self.router_bias = router_bias self.router_jitter_noise = router_jitter_noise self.router_dtype = router_dtype diff --git a/src/transformers/models/switch_transformers/convert_switch_transformersx_checkpoint_to_flax.py b/src/transformers/models/switch_transformers/convert_switch_transformersx_checkpoint_to_flax.py index d9e59f3ee91cc..51011f4a19091 100644 --- a/src/transformers/models/switch_transformers/convert_switch_transformersx_checkpoint_to_flax.py +++ b/src/transformers/models/switch_transformers/convert_switch_transformersx_checkpoint_to_flax.py @@ -41,12 +41,12 @@ def convert_switch_transformersx_checkpoint_to_flax( switch_transformersx_attention_out = switch_transformersx_model["target"]["encoder"][layer_name]["attention"][ "out" ]["kernel"] - switch_transformersx_attention_query = switch_transformersx_model["target"]["encoder"][layer_name]["attention"][ - "query" - ]["kernel"] - switch_transformersx_attention_value = switch_transformersx_model["target"]["encoder"][layer_name]["attention"][ - "value" - ]["kernel"] + switch_transformersx_attention_query = switch_transformersx_model["target"]["encoder"][layer_name][ + "attention" + ]["query"]["kernel"] + switch_transformersx_attention_value = switch_transformersx_model["target"]["encoder"][layer_name][ + "attention" + ]["value"]["kernel"] # Layer Normalization switch_transformersx_attention_layer_norm = switch_transformersx_model["target"]["encoder"][layer_name][ @@ -65,7 +65,9 @@ def convert_switch_transformersx_checkpoint_to_flax( "kernel" ] - switch_transformersx_mlp_wo = switch_transformersx_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"] + switch_transformersx_mlp_wo = switch_transformersx_model["target"]["encoder"][layer_name]["mlp"]["wo"][ + "kernel" + ] # Layer Normalization switch_transformersx_mlp_layer_norm = switch_transformersx_model["target"]["encoder"][layer_name][ @@ -176,7 +178,9 @@ def convert_switch_transformersx_checkpoint_to_flax( "kernel" ] - switch_transformersx_mlp_wo = switch_transformersx_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"] + switch_transformersx_mlp_wo = switch_transformersx_model["target"]["decoder"][layer_name]["mlp"]["wo"][ + "kernel" + ] # Layer Normalization tx5_mlp_layer_norm = switch_transformersx_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"] @@ -266,7 +270,11 @@ def convert_switch_transformersx_checkpoint_to_flax( parser = argparse.ArgumentParser() # Required parameters parser.add_argument( - "--switch_transformersx_checkpoint_path", default=None, type=str, required=True, help="Path the TX5 checkpoint." + "--switch_transformersx_checkpoint_path", + default=None, + type=str, + required=True, + help="Path the TX5 checkpoint.", ) parser.add_argument( "--config_name", default=None, type=str, required=True, help="Config name of SwitchTransformers model." diff --git a/src/transformers/models/switch_transformers/modeling_flax_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_flax_switch_transformers.py index 130352c287b87..c43086021bfb6 100644 --- a/src/transformers/models/switch_transformers/modeling_flax_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_flax_switch_transformers.py @@ -1257,10 +1257,10 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs SWITCH_TRANSFORMERS_START_DOCSTRING = r""" - The SWITCH_TRANSFORMERS model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text - Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan - Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a - text-to-text denoising generative setting. + The SWITCH_TRANSFORMERS model was proposed in [Exploring the Limits of Transfer Learning with a Unified + Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine + Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer + pre-trained in a text-to-text denoising generative setting. This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -1434,8 +1434,8 @@ class FlaxSwitchTransformersModel(FlaxSwitchTransformersPreTrainedModel): @add_start_docstrings( - "The bare SWITCH_TRANSFORMERS Model transformer outputting encoder's raw hidden-states without any specific head on" - " top.", + "The bare SWITCH_TRANSFORMERS Model transformer outputting encoder's raw hidden-states without any specific head" + " on top.", SWITCH_TRANSFORMERS_START_DOCSTRING, ) # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5EncoderModule with T5->SwitchTransformers diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 6b4aa78be6b85..36396f5daf368 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -44,7 +44,7 @@ replace_return_docstrings, ) from .configuration_switch_transformers import SwitchTransformersConfig -from .router import ExpertsChooseMaskedRouter, TokensChooseMaskedRouter, TokensChooseScatterRouter +from .router import ExpertsChooseMaskedRouter, RouterMask, TokensChooseMaskedRouter logger = logging.get_logger(__name__) @@ -153,6 +153,30 @@ def forward(self, hidden_states): return hidden_states +class SwitchTransformersExpert(nn.Module): + def __init__(self, config: SwitchTransformersConfig): + super().__init__() + self.weights = nn.ModuleDict( + { + "expert_{}".format(i): nn.Linear(config.d_model, config.d_model, bias=False) + for i in range(config.num_experts) + } + ) + + def forward(self, hidden_states, indices): + r""" + Args: + hidden_states (`torch.FloatTensor`, **required**): + Input to the layer of shape :obj:`(batch_size, sequence_length, hidden_size)`. + indices (:obj:`torch.LongTensor`, **required**): + Indices of the experts of shape :obj:`(batch_size, )` to use for each input in the batch. + """ + for i in range(len(self.weights)): + expert_indices = (indices[:, :, i, :] == 1).squeeze(-1) + hidden_states[expert_indices] = self.weights["expert_{}".format(i)](hidden_states[expert_indices]) + return hidden_states + + class SwitchTransformersSparseMLP(nn.Module): r""" Implementation of the Switch Transformers Sparse MLP module @@ -166,7 +190,9 @@ def __init__(self, config: SwitchTransformersConfig): self.router = self._get_router(config) # Step 2: Get the experts - self.experts = None # TODO: figure out how this is done in t5x... + self.experts = SwitchTransformersExpert(config) + + self.expert_capacity = config.expert_capacity def _get_router(self, config): r""" @@ -178,8 +204,9 @@ def _get_router(self, config): """ if config.router_type.lower() == "tokens_masked": return TokensChooseMaskedRouter(config) - elif config.router_type.lower() == "tokens_scatter": - return TokensChooseScatterRouter(config) + # TODO: completely remove the scatter implementations + # elif config.router_type.lower() == "tokens_scatter": + # return TokensChooseScatterRouter(config) elif config.router_type.lower() == "experts_masked": return ExpertsChooseMaskedRouter(config) else: @@ -189,7 +216,12 @@ def _get_router(self, config): ) def forward(self, hidden_states): - pass + expert_indices = self.router(hidden_states, expert_capacity=self.expert_capacity) + if not isinstance(expert_indices, RouterMask): + raise NotImplementedError("Only MaskedRouter is supported for now") + masked_indices = expert_indices.dispatch_mask + hidden_states = self.experts(hidden_states, masked_indices) + return hidden_states # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->SwitchTransformers @@ -709,13 +741,15 @@ def _shift_right(self, input_ids): return shifted_input_ids -class SwitchTransformersStack(nn.Module): - def __init__(self, config, embed_tokens=None, sparse_step=1): +class SwitchTransformersStack(SwitchTransformersPreTrainedModel): + def __init__(self, config, embed_tokens=None): super().__init__(config) self.embed_tokens = embed_tokens self.is_decoder = config.is_decoder + sparse_step = config.decoder_sparse_step if self.is_decoder else config.encoder_sparse_step + # TODO: change this, actually you can have a block full of sparse layers... self.block = nn.ModuleList( [ @@ -954,10 +988,10 @@ def custom_forward(*inputs): SWITCH_TRANSFORMERS_START_DOCSTRING = r""" - The SWITCH_TRANSFORMERS model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text - Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan - Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a - text-to-text denoising generative setting. + The SWITCH_TRANSFORMERS model was proposed in [Exploring the Limits of Transfer Learning with a Unified + Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine + Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer + pre-trained in a text-to-text denoising generative setting. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -1136,13 +1170,13 @@ def __init__(self, config: SwitchTransformersConfig): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.is_encoder_decoder = False - self.encoder = SwitchTransformersStack(encoder_config, self.shared, encoder_config.encoder_sparse_step) + self.encoder = SwitchTransformersStack(encoder_config, self.shared) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.is_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers - self.decoder = SwitchTransformersStack(decoder_config, self.shared, decoder_config.decoder_sparse_step) + self.decoder = SwitchTransformersStack(decoder_config, self.shared) # Initialize weights and apply final processing self.post_init() @@ -1310,6 +1344,7 @@ def __init__(self, config: SwitchTransformersConfig): encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False encoder_config.use_cache = False + encoder_config.num_layers = config.num_encoder_layers encoder_config.is_encoder_decoder = False self.encoder = SwitchTransformersStack(encoder_config, self.shared) @@ -1558,8 +1593,8 @@ def _reorder_cache(self, past, beam_idx): @add_start_docstrings( - "The bare SWITCH_TRANSFORMERS Model transformer outputting encoder's raw hidden-states without any specific head on" - " top.", + "The bare SWITCH_TRANSFORMERS Model transformer outputting encoder's raw hidden-states without any specific head" + " on top.", SWITCH_TRANSFORMERS_START_DOCSTRING, ) class SwitchTransformersEncoderModel(SwitchTransformersPreTrainedModel): diff --git a/src/transformers/models/switch_transformers/router.py b/src/transformers/models/switch_transformers/router.py index ee1b8d3c000ce..f97a36065a973 100644 --- a/src/transformers/models/switch_transformers/router.py +++ b/src/transformers/models/switch_transformers/router.py @@ -80,8 +80,8 @@ class RouterIndices: A tensor of size [`num_groups`, `tokens_per_group`, `num_selected_experts`, 2] dispatch indices indicating, for each token, its preferred expert and its priority in that expert's buffer. combine_weights (`torch.Tensor`): - A tensor of size [num_groups, tokens_per_group, num_selected_experts] combine weights used for scaling - expert outputs with the router's dispatch probability/confidence. + A tensor of size [`num_groups`, `tokens_per_group`, `num_selected_experts`] combine weights used for + scaling expert outputs with the router's dispatch probability/confidence. auxiliary_loss (`float`): Load balancing loss for router. router_z_loss (`float`): @@ -92,6 +92,11 @@ class RouterIndices: auxiliary_loss: float router_z_loss: float = 0.0 + def to(self, device): + return replace( + self, dispatch_mask=self.dispatch_indices.to(device), combine_array=self.combine_weights.to(device) + ) + @dataclass class RouterMask: @@ -115,6 +120,9 @@ class RouterMask: auxiliary_loss: float router_z_loss: float = 0.0 + def to(self, device): + return replace(self, dispatch_mask=self.dispatch_mask.to(device), combine_array=self.combine_array.to(device)) + # Router loss diff --git a/utils/check_repo.py b/utils/check_repo.py index 988eb499aa302..0509d12f7fed3 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -39,6 +39,7 @@ "LongT5Stack", "RealmBertModel", "T5Stack", + "SwitchTransformersStack", "TFDPRSpanPredictor", ] From 5f42b6b4553203eb54f5e4c202a956d946df8a8d Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 14 Oct 2022 12:54:50 +0000 Subject: [PATCH 014/102] update MOE layer --- .../modeling_switch_transformers.py | 50 +++++++++---------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 36396f5daf368..568054c1de76d 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -93,7 +93,6 @@ def forward(self, hidden_states): ALL_LAYERNORM_LAYERS.append(SwitchTransformersLayerNorm) -# TODO: this has to be changed with the experts # Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->SwitchTransformers class SwitchTransformersDenseActDense(nn.Module): def __init__(self, config: SwitchTransformersConfig): @@ -139,7 +138,7 @@ def __init__(self, config: SwitchTransformersConfig, is_sparse=False): super().__init__() # TODO: check the comments above self.is_sparse = is_sparse - if self.is_sparse: + if not self.is_sparse: self.mlp = SwitchTransformersDenseActDense(config) else: self.mlp = SwitchTransformersSparseMLP(config) @@ -153,15 +152,13 @@ def forward(self, hidden_states): return hidden_states -class SwitchTransformersExpert(nn.Module): - def __init__(self, config: SwitchTransformersConfig): +class SwitchTransformersMOExpertLayer(nn.Module): + def __init__(self, config: SwitchTransformersConfig, expert_class:nn.Module = SwitchTransformersDenseActDense): super().__init__() - self.weights = nn.ModuleDict( - { - "expert_{}".format(i): nn.Linear(config.d_model, config.d_model, bias=False) - for i in range(config.num_experts) - } - ) + self.experts = nn.ModuleDict() + + for idx in range(config.num_experts): + self.experts[f"expert_{idx}"] = expert_class(config) def forward(self, hidden_states, indices): r""" @@ -171,16 +168,19 @@ def forward(self, hidden_states, indices): indices (:obj:`torch.LongTensor`, **required**): Indices of the experts of shape :obj:`(batch_size, )` to use for each input in the batch. """ - for i in range(len(self.weights)): - expert_indices = (indices[:, :, i, :] == 1).squeeze(-1) - hidden_states[expert_indices] = self.weights["expert_{}".format(i)](hidden_states[expert_indices]) + for idx, expert in enumerate(self.experts.values()): + # 1. Get the index of the tokens that are routed to the current expert + expert_indices = torch.eq(indices[:, :, idx, :], 1).squeeze(-1) + # 2. Update hidden states + hidden_states[expert_indices] = expert(hidden_states[expert_indices]) return hidden_states class SwitchTransformersSparseMLP(nn.Module): r""" Implementation of the Switch Transformers Sparse MLP module - + We purposely create a `SwitchTransformersMOExpertLayer` in order to give freedom to people if they want to + change this layer (by changing the agregation for example). TODO: Add a LOT of details here """ @@ -190,7 +190,7 @@ def __init__(self, config: SwitchTransformersConfig): self.router = self._get_router(config) # Step 2: Get the experts - self.experts = SwitchTransformersExpert(config) + self.experts = SwitchTransformersMOExpertLayer(config) self.expert_capacity = config.expert_capacity @@ -202,23 +202,19 @@ def _get_router(self, config): In total the list of supported Routers are the following: """ + # TODO, use a ALL_ROUTER_TYPE map instead of havind all the ifs? then just if None raise error. if config.router_type.lower() == "tokens_masked": return TokensChooseMaskedRouter(config) - # TODO: completely remove the scatter implementations - # elif config.router_type.lower() == "tokens_scatter": - # return TokensChooseScatterRouter(config) elif config.router_type.lower() == "experts_masked": return ExpertsChooseMaskedRouter(config) else: raise NotImplementedError( - f"{config.router_type.lower()} not implemented ! Please chose a router among [tokens_masked," - " tokens_scatter, experts_masked]" + f"{config.router_type.lower()} not implemented ! Please chose a router in `{"tokens_masked", + " "experts_masked"}" ) def forward(self, hidden_states): expert_indices = self.router(hidden_states, expert_capacity=self.expert_capacity) - if not isinstance(expert_indices, RouterMask): - raise NotImplementedError("Only MaskedRouter is supported for now") masked_indices = expert_indices.dispatch_mask hidden_states = self.experts(hidden_states, masked_indices) return hidden_states @@ -751,14 +747,14 @@ def __init__(self, config, embed_tokens=None): sparse_step = config.decoder_sparse_step if self.is_decoder else config.encoder_sparse_step # TODO: change this, actually you can have a block full of sparse layers... - self.block = nn.ModuleList( - [ + self.block = nn.ModuleList() + for i in range(config.num_layers) : + self.block.append( SwitchTransformersBlock( config, has_relative_attention_bias=bool(i == 0), is_sparse=(i % sparse_step == 0) ) - for i in range(config.num_layers) - ] - ) + ) + self.final_layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) From 60ab566d9284b18108cb579c53c594e191a843f5 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 14 Oct 2022 12:55:22 +0000 Subject: [PATCH 015/102] small router update --- src/transformers/models/switch_transformers/router.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/models/switch_transformers/router.py b/src/transformers/models/switch_transformers/router.py index f97a36065a973..20976f0e3f71a 100644 --- a/src/transformers/models/switch_transformers/router.py +++ b/src/transformers/models/switch_transformers/router.py @@ -186,7 +186,6 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T # Router classes - class Router(nn.Module): """ Abstract base router class, defining router API and inner workings. @@ -539,7 +538,6 @@ def _compute_routing_instructions( return RouterMask(dispatch_mask, combine_array, auxiliary_loss) - class ScatterRouter(Router): """ Abstract base router class for scatter dispatch routers. From ae2fbc4d154fcb5e831d5c6aac9640d059a3bfce Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 14 Oct 2022 12:57:05 +0000 Subject: [PATCH 016/102] fixup --- .../modeling_switch_transformers.py | 16 ++++++++-------- .../models/switch_transformers/router.py | 2 ++ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 568054c1de76d..833562d7c3b59 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -153,10 +153,10 @@ def forward(self, hidden_states): class SwitchTransformersMOExpertLayer(nn.Module): - def __init__(self, config: SwitchTransformersConfig, expert_class:nn.Module = SwitchTransformersDenseActDense): + def __init__(self, config: SwitchTransformersConfig, expert_class: nn.Module = SwitchTransformersDenseActDense): super().__init__() self.experts = nn.ModuleDict() - + for idx in range(config.num_experts): self.experts[f"expert_{idx}"] = expert_class(config) @@ -171,7 +171,7 @@ def forward(self, hidden_states, indices): for idx, expert in enumerate(self.experts.values()): # 1. Get the index of the tokens that are routed to the current expert expert_indices = torch.eq(indices[:, :, idx, :], 1).squeeze(-1) - # 2. Update hidden states + # 2. Update hidden states hidden_states[expert_indices] = expert(hidden_states[expert_indices]) return hidden_states @@ -179,7 +179,7 @@ def forward(self, hidden_states, indices): class SwitchTransformersSparseMLP(nn.Module): r""" Implementation of the Switch Transformers Sparse MLP module - We purposely create a `SwitchTransformersMOExpertLayer` in order to give freedom to people if they want to + We purposely create a `SwitchTransformersMOExpertLayer` in order to give freedom to people if they want to change this layer (by changing the agregation for example). TODO: Add a LOT of details here """ @@ -202,15 +202,15 @@ def _get_router(self, config): In total the list of supported Routers are the following: """ - # TODO, use a ALL_ROUTER_TYPE map instead of havind all the ifs? then just if None raise error. + # TODO, use a ALL_ROUTER_TYPE map instead of havind all the ifs? then just if None raise error. if config.router_type.lower() == "tokens_masked": return TokensChooseMaskedRouter(config) elif config.router_type.lower() == "experts_masked": return ExpertsChooseMaskedRouter(config) else: raise NotImplementedError( - f"{config.router_type.lower()} not implemented ! Please chose a router in `{"tokens_masked", - " "experts_masked"}" + f"{config.router_type.lower()} not implemented ! Please chose a router in " + "`{'tokens_masked','experts_masked'}`" ) def forward(self, hidden_states): @@ -748,7 +748,7 @@ def __init__(self, config, embed_tokens=None): # TODO: change this, actually you can have a block full of sparse layers... self.block = nn.ModuleList() - for i in range(config.num_layers) : + for i in range(config.num_layers): self.block.append( SwitchTransformersBlock( config, has_relative_attention_bias=bool(i == 0), is_sparse=(i % sparse_step == 0) diff --git a/src/transformers/models/switch_transformers/router.py b/src/transformers/models/switch_transformers/router.py index 20976f0e3f71a..f97a36065a973 100644 --- a/src/transformers/models/switch_transformers/router.py +++ b/src/transformers/models/switch_transformers/router.py @@ -186,6 +186,7 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T # Router classes + class Router(nn.Module): """ Abstract base router class, defining router API and inner workings. @@ -538,6 +539,7 @@ def _compute_routing_instructions( return RouterMask(dispatch_mask, combine_array, auxiliary_loss) + class ScatterRouter(Router): """ Abstract base router class for scatter dispatch routers. From d7ba596d46ccee5de64fde449c9f56015965c52c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 14 Oct 2022 13:06:11 +0000 Subject: [PATCH 017/102] consistency --- .../modeling_switch_transformers.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 833562d7c3b59..6f6aa51aa9877 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -44,7 +44,7 @@ replace_return_docstrings, ) from .configuration_switch_transformers import SwitchTransformersConfig -from .router import ExpertsChooseMaskedRouter, RouterMask, TokensChooseMaskedRouter +from .router import ExpertsChooseMaskedRouter, TokensChooseMaskedRouter logger = logging.get_logger(__name__) @@ -164,9 +164,9 @@ def forward(self, hidden_states, indices): r""" Args: hidden_states (`torch.FloatTensor`, **required**): - Input to the layer of shape :obj:`(batch_size, sequence_length, hidden_size)`. - indices (:obj:`torch.LongTensor`, **required**): - Indices of the experts of shape :obj:`(batch_size, )` to use for each input in the batch. + Input to the layer of shape `(batch_size, sequence_length, hidden_size)`. + indices (`torch.LongTensor`, **required**): + Indices of the experts of shape `(batch_size, )` to use for each input in the batch. """ for idx, expert in enumerate(self.experts.values()): # 1. Get the index of the tokens that are routed to the current expert @@ -178,9 +178,8 @@ def forward(self, hidden_states, indices): class SwitchTransformersSparseMLP(nn.Module): r""" - Implementation of the Switch Transformers Sparse MLP module - We purposely create a `SwitchTransformersMOExpertLayer` in order to give freedom to people if they want to - change this layer (by changing the agregation for example). + Implementation of the Switch Transformers Sparse MLP module We purposely create a `SwitchTransformersMOExpertLayer` + in order to give freedom to people if they want to change this layer (by changing the agregation for example). TODO: Add a LOT of details here """ From 1ae556368bbc4fcc0f8fdf5ed6140ca58097e555 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 14 Oct 2022 13:19:21 +0000 Subject: [PATCH 018/102] remove scatter router --- .../models/switch_transformers/router.py | 209 ++---------------- 1 file changed, 20 insertions(+), 189 deletions(-) diff --git a/src/transformers/models/switch_transformers/router.py b/src/transformers/models/switch_transformers/router.py index f97a36065a973..1925641ed4e6e 100644 --- a/src/transformers/models/switch_transformers/router.py +++ b/src/transformers/models/switch_transformers/router.py @@ -185,7 +185,7 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T # Router classes - +# TODO not a big fan of 3 level of class Router(nn.Module): """ @@ -306,46 +306,26 @@ def forward( def _compute_routing_instructions(self, router_probs, padding_mask, expert_capacity): raise NotImplementedError( """ - The forward function cannot be called from the `Router` super-class. Please call an appropriate Router - class that inherits from the `Router` class (for example `ExpertsChooseMaskedRouter`) + Computes masks for the top-k experts per token. This has to be implemented for each subclass of MaskedRouter + routers. + + Args: + router_probs (`torch.Tensor`): + Input router probabilities of shape [num_groups, tokens_per_group, num_experts] this corresponds to the + probabilities used to determine the routing of tokens to the experts. + padding_mask (`torch.Tensor`): + Padding mask tensor of shape [num_groups, tokens_per_group] a mask used to identify padding tokens that + should be ignored by the router. + expert_capacity (`int`): + Each group will send this many tokens to each expert. + + Returns: + Router mask arrays. """ ) -class MaskedRouter(Router): - """ - Abstract base router class for masked matmul dispatch routers. - - MaskedRouter(s) return RouterMask(s) containing a dispatch mask and combine array for sending and receiving (via - masked matmuls) inputs and outputs to and from experts. - - Routing using masked matmuls is generally faster than scatter-based routing on TPUs. - """ - - def _compute_routing_instructions( - self, router_probs: torch.Tensor, padding_mask: Optional[torch.Tensor], expert_capacity: int - ) -> RouterMask: - """ - Computes masks for the top-k experts per token. This has to be implemented for each subclass of MaskedRouter - routers. - - Args: - router_probs (`torch.Tensor`): - Input router probabilities of shape [num_groups, tokens_per_group, num_experts] this corresponds to the - probabilities used to determine the routing of tokens to the experts. - padding_mask (`torch.Tensor`): - Padding mask tensor of shape [num_groups, tokens_per_group] a mask used to identify padding tokens that - should be ignored by the router. - expert_capacity (`int`): - Each group will send this many tokens to each expert. - - Returns: - Router mask arrays. - """ - raise NotImplementedError("MaskedRouter is an abstract class that should be subclassed.") - - -class ExpertsChooseMaskedRouter(MaskedRouter): +class ExpertsChooseMaskedRouter(Router): """ Masked matmul router using experts choose tokens assignment. @@ -412,7 +392,7 @@ def _compute_routing_instructions( return RouterMask(dispatch_mask, combine_array, auxiliary_loss) -class TokensChooseMaskedRouter(MaskedRouter): +class TokensChooseMaskedRouter(Router): """ Masked matmul router using tokens choose top-k experts assignment. @@ -525,12 +505,14 @@ def _compute_routing_instructions( # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]. # token_priority = token_priority * (token_priority > 0) + # TODO can we improve the function name or use torch's? # dispatch_mask = torch.nn.functional.one_hot(token_priority.long(), expert_capacity + 1)[..., 1:] dispatch_mask = _jax_one_hot(token_priority.long(), expert_capacity, axis=-1) # The combine array will be used for combining expert outputs, scaled by the # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, # expert_capacity]. + # TODO can we use more understandable code here? combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) # combine_array = torch.einsum("...te,...te->...te", router_probs, dispatch_mask) @@ -539,157 +521,6 @@ def _compute_routing_instructions( return RouterMask(dispatch_mask, combine_array, auxiliary_loss) - -class ScatterRouter(Router): - """ - Abstract base router class for scatter dispatch routers. - - ScatterRouter(s) return RouterIndices containing dispatch indices and combine weights for sending token inputs (via - scatter) and receiving outputs (via gather) to and from experts. - - Scatter-based routing is generally faster than masked matmul routing on CPUs and GPUs. - """ - - def _compute_routing_instructions( - self, router_probs: torch.Tensor, padding_mask: Optional[torch.Tensor], expert_capacity: int - ) -> RouterIndices: - """Computes instructions for routing inputs to experts. - - Args: - router_probs (`torch.Tensor`): - Router raw probabilities tensor of shape [num_groups, tokens_per_group, num_experts] used to determine - the routing of tokens to the experts. - padding_mask (`torch.Tensor`): - Padding mask of shape [num_groups, tokens_per_group] used to identify padding tokens that should be - ignored by the router. - expert_capacity (`int`): - Each group will send this many tokens to each expert. - - Returns: - Router indices containing dispatch indices and combine weights. - """ - raise NotImplementedError("ScatterRouter is an abstract class that should be subclassed.") - - -class TokensChooseScatterRouter(ScatterRouter): - """ - Scatter router using tokens choose top-k experts assignment. - - This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE - (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then - routed to their choice of expert until the expert's expert_capacity is reached. There is no guarantee that each - token is processed by an expert, or that each expert receives at least one token. - - Attributes: - num_selected_experts (`int`): - Maximum number of experts to which each token is routed. Tokens may be routed to fewer experts if - particular experts are oversubscribed / reach capacity. - batch_prioritized_routing (`bool`): - Whether or not to use Batch Prioritized Routing BPR), originally introduced in V-MoE - (https://arxiv.org/abs/2106.05974). With BPR, we prioritize routing those top-k tokens with the highest - router probability, rather than simply using each tokens left-to-right ordering in the batch. This - prioritization is important because the expert's have limited capacity. - """ - - def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) - self.num_selected_experts = config.num_selected_experts - self.batch_prioritized_routing = config.batch_prioritized_routing - - def _compute_routing_instructions( - self, router_probs: torch.Tensor, padding_mask: Optional[torch.Tensor], expert_capacity: int - ) -> RouterIndices: - """Computes dispatch indices and combine weights for the top-k experts. - - Args: - router_probs (`torch.Tensor`): - Router raw probabilities tensor of shape [num_groups, tokens_per_group, num_experts] used to determine - the routing of tokens to the experts. - padding_mask (`torch.Tensor`): - Padding mask of shape [num_groups, tokens_per_group] used to identify padding tokens that should be - ignored by the router. - expert_capacity (`int`): - Each group will send this many tokens to each expert. - - Returns: - Dispatch indices and combine weights for scatter/gather-based routing. - """ - num_groups, tokens_per_group, num_experts = router_probs.shape - - if padding_mask is not None: - # Because experts choose tokens, we mask probabilities corresponding to - # tokens before the top-k operation. Note that, unlike for masked-based - # tokens-choose routing, the experts here may still choose to select the - # (down-weighted) padding tokens. - router_probs *= padding_mask.unsqueeze(-1) - - # Top-k router probability and corresponding expert indices for each token. - # Shape: [num_groups, tokens_per_group, num_selected_experts]. - combine_weights, expert_indices = torch.topk(router_probs, k=self.num_selected_experts) - - auxiliary_loss = load_balancing_loss_func(router_probs, expert_indices) - - if self.batch_prioritized_routing: - # Sort tokens according to their routing probability per group, so that - # the highest probability tokens are routed first. - token_ordering = torch.argsort(-combine_weights[..., 0], dim=-1) - # Shape: [num_groups, tokens_per_group, num_selected_experts] - expert_indices = torch.take_along_dim(expert_indices, token_ordering.unsqueeze(-1), dim=-2) - - # Identify each token's preferred expert. - # Make num_selected_experts the leading axis to ensure that top-1 choices - # have priority over top-2 choices, which have priority over top-3 - # choices... - preferred_experts = expert_indices.permute(0, 2, 1) - # Shape: [num_groups, num_selected_experts * tokens_per_group] - preferred_experts = preferred_experts.reshape(num_groups, -1) - - # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. - expert_mask = _jax_one_hot(preferred_experts, num_experts, dtype=torch.int32) - - # Experts have a fixed capacity that we cannot exceed. A token's priority - # within the expert's buffer is given by the masked, cumulative capacity of - # its target expert. - # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. - token_priority = torch.cumsum(expert_mask, axis=1) * expert_mask - 1.0 - # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts]. - token_priority = token_priority.reshape((num_groups, self.num_selected_experts, -1, num_experts)) - # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. - token_priority = token_priority.permute((0, 2, 1, 3)) - # For each token, across all selected experts, select the only non-negative - # (unmasked) priority. Now, for group G routing to expert E, token T has - # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E - # is its targeted expert. - # Shape: [num_groups, tokens_per_group, num_experts]. - token_priority = torch.max(token_priority, axis=-1).values - - # Return to original index shape. - preferred_experts = preferred_experts.reshape(num_groups, self.num_selected_experts, tokens_per_group) - # Shape: [num_groups, tokens_per_group, num_selected_experts] - preferred_experts = preferred_experts.permute(0, 2, 1) - - if self.batch_prioritized_routing: - # Place token priorities in original ordering of tokens. - inv_permutation = torch.argsort(token_ordering, dim=-1) - preferred_experts = torch.take_along_dim( - preferred_experts.unsqueeze(-1), inv_permutation.unsqueeze(-1), dim=-2 - ) - token_priority = torch.take_along_dim(token_priority.unsqueeze(-1), inv_permutation.unsqueeze(-1), dim=-2) - - # Mask out tokens that overflow the maximum expert capacities. - # Shape: [num_groups, tokens_per_group, num_selected_experts]. - combine_weights *= token_priority < expert_capacity - - # Expert index and priority within the expert capacity buffer. - # Shape: [num_groups, tokens_per_group, num_selected_experts, 2]. - dispatch_indices = torch.stack([preferred_experts, token_priority], dim=-1) - - # Return to default dtype now that router computation is complete. - dispatch_indices = dispatch_indices.to(torch.int32) - - return RouterIndices(dispatch_indices, combine_weights, auxiliary_loss) - - # num_groups = 2 # tokens_per_group = 4 # hidden_dim = 3 From 60ec29910a65836336a7f12944be640b5eaba053 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 14 Oct 2022 13:35:23 +0000 Subject: [PATCH 019/102] remove abstract layer --- .../modeling_switch_transformers.py | 41 +++++-------------- .../models/switch_transformers/router.py | 20 +++++---- 2 files changed, 22 insertions(+), 39 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 6f6aa51aa9877..407156f17d876 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -152,44 +152,20 @@ def forward(self, hidden_states): return hidden_states -class SwitchTransformersMOExpertLayer(nn.Module): - def __init__(self, config: SwitchTransformersConfig, expert_class: nn.Module = SwitchTransformersDenseActDense): - super().__init__() - self.experts = nn.ModuleDict() - - for idx in range(config.num_experts): - self.experts[f"expert_{idx}"] = expert_class(config) - - def forward(self, hidden_states, indices): - r""" - Args: - hidden_states (`torch.FloatTensor`, **required**): - Input to the layer of shape `(batch_size, sequence_length, hidden_size)`. - indices (`torch.LongTensor`, **required**): - Indices of the experts of shape `(batch_size, )` to use for each input in the batch. - """ - for idx, expert in enumerate(self.experts.values()): - # 1. Get the index of the tokens that are routed to the current expert - expert_indices = torch.eq(indices[:, :, idx, :], 1).squeeze(-1) - # 2. Update hidden states - hidden_states[expert_indices] = expert(hidden_states[expert_indices]) - return hidden_states - - class SwitchTransformersSparseMLP(nn.Module): r""" - Implementation of the Switch Transformers Sparse MLP module We purposely create a `SwitchTransformersMOExpertLayer` - in order to give freedom to people if they want to change this layer (by changing the agregation for example). - TODO: Add a LOT of details here + Implementation of the Switch Transformers Sparse MLP module. TODO: Add a LOT of details here """ - def __init__(self, config: SwitchTransformersConfig): + def __init__(self, config: SwitchTransformersConfig, expert_class: nn.Module = SwitchTransformersDenseActDense): super().__init__() # Step 1: Get the correct router self.router = self._get_router(config) # Step 2: Get the experts - self.experts = SwitchTransformersMOExpertLayer(config) + self.experts = nn.ModuleDict() + for idx in range(config.num_experts): + self.experts[f"expert_{idx}"] = expert_class(config) self.expert_capacity = config.expert_capacity @@ -215,7 +191,12 @@ def _get_router(self, config): def forward(self, hidden_states): expert_indices = self.router(hidden_states, expert_capacity=self.expert_capacity) masked_indices = expert_indices.dispatch_mask - hidden_states = self.experts(hidden_states, masked_indices) + + for idx, expert in enumerate(self.experts.values()): + # 1. Get the index of the tokens that are routed to the current expert + expert_indices = torch.eq(masked_indices[:, :, idx, :], 1).squeeze(-1) + # 2. Update hidden states + hidden_states[expert_indices] = expert(hidden_states[expert_indices]) return hidden_states diff --git a/src/transformers/models/switch_transformers/router.py b/src/transformers/models/switch_transformers/router.py index 1925641ed4e6e..bbff43d04636a 100644 --- a/src/transformers/models/switch_transformers/router.py +++ b/src/transformers/models/switch_transformers/router.py @@ -185,7 +185,8 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T # Router classes -# TODO not a big fan of 3 level of +# TODO not a big fan of 3 level of + class Router(nn.Module): """ @@ -306,16 +307,16 @@ def forward( def _compute_routing_instructions(self, router_probs, padding_mask, expert_capacity): raise NotImplementedError( """ - Computes masks for the top-k experts per token. This has to be implemented for each subclass of MaskedRouter - routers. + Computes masks for the top-k experts per token. This has to be implemented for each subclass of + MaskedRouter routers. Args: router_probs (`torch.Tensor`): - Input router probabilities of shape [num_groups, tokens_per_group, num_experts] this corresponds to the - probabilities used to determine the routing of tokens to the experts. + Input router probabilities of shape [num_groups, tokens_per_group, num_experts] this + corresponds to the probabilities used to determine the routing of tokens to the experts. padding_mask (`torch.Tensor`): - Padding mask tensor of shape [num_groups, tokens_per_group] a mask used to identify padding tokens that - should be ignored by the router. + Padding mask tensor of shape [num_groups, tokens_per_group] a mask used to identify padding + tokens that should be ignored by the router. expert_capacity (`int`): Each group will send this many tokens to each expert. @@ -505,14 +506,14 @@ def _compute_routing_instructions( # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]. # token_priority = token_priority * (token_priority > 0) - # TODO can we improve the function name or use torch's? + # TODO can we improve the function name or use torch's? # dispatch_mask = torch.nn.functional.one_hot(token_priority.long(), expert_capacity + 1)[..., 1:] dispatch_mask = _jax_one_hot(token_priority.long(), expert_capacity, axis=-1) # The combine array will be used for combining expert outputs, scaled by the # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, # expert_capacity]. - # TODO can we use more understandable code here? + # TODO can we use more understandable code here? combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) # combine_array = torch.einsum("...te,...te->...te", router_probs, dispatch_mask) @@ -521,6 +522,7 @@ def _compute_routing_instructions( return RouterMask(dispatch_mask, combine_array, auxiliary_loss) + # num_groups = 2 # tokens_per_group = 4 # hidden_dim = 3 From 6181bfa17d77fb95c6f39184a424353307dd4d83 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 14 Oct 2022 14:57:49 +0000 Subject: [PATCH 020/102] update test and model for integration testing --- .../configuration_switch_transformers.py | 13 +++---- .../modeling_switch_transformers.py | 11 +++--- .../models/switch_transformers/router.py | 5 +-- .../test_modeling_switch_transformers.py | 36 +++++++++++++++++-- 4 files changed, 47 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index 94c46f5fae325..11253c3580666 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -55,7 +55,7 @@ class SwitchTransformersConfig(PretrainedConfig): expert_capacity (`int`, *optional*, defaults to 1): Number of tokens that can be stored in each expert. If set to 1, the model will behave like a regular Transformer. - num_encoder_layers (`int`, *optional*, defaults to 12): + num_layers (`int`, *optional*, defaults to 12): Number of dense hidden layers in the Transformer encoder layer. num_sparse_encoder_layers (`int`, *optional*, defaults to 6): Number of sparse (MoE) dense hidden layers in the Transformer encoder layer. @@ -109,7 +109,7 @@ def __init__( d_model=512, d_kv=64, d_ff=2048, - num_encoder_layers=12, + num_layers=12, num_sparse_encoder_layers=6, num_decoder_layers=12, num_sparse_decoder_layers=6, @@ -139,19 +139,20 @@ def __init__( self.d_model = d_model self.d_kv = d_kv self.d_ff = d_ff - self.num_encoder_layers = num_encoder_layers + self.num_sparse_encoder_layers = num_sparse_encoder_layers + self.num_layers = num_layers self.num_decoder_layers = ( - num_decoder_layers if num_decoder_layers is not None else self.num_encoder_layers + num_decoder_layers if num_decoder_layers is not None else self.num_layers ) # default = symmetry self.num_sparse_decoder_layers = num_sparse_decoder_layers # This tells us, each how many encoder layer we'll have to set a sparse layer. if self.num_sparse_encoder_layers > 0: - self.encoder_sparse_step = self.num_encoder_layers // self.num_sparse_encoder_layers + self.encoder_sparse_step = self.num_layers // self.num_sparse_encoder_layers else: - self.encoder_sparse_step = self.num_encoder_layers # HACK: this will create 0 sparse layers + self.encoder_sparse_step = self.num_layers # HACK: this will create 0 sparse layers # This tells us, each how many encoder layer we'll have to set a sparse layer. if self.num_sparse_decoder_layers > 0: diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 407156f17d876..6636da65ceaac 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -725,14 +725,13 @@ def __init__(self, config, embed_tokens=None): self.is_decoder = config.is_decoder sparse_step = config.decoder_sparse_step if self.is_decoder else config.encoder_sparse_step - - # TODO: change this, actually you can have a block full of sparse layers... + config.num_layers = config.num_decoder_layers if self.is_decoder else config.num_layers self.block = nn.ModuleList() for i in range(config.num_layers): + + is_sparse = (i % sparse_step == 0) if sparse_step > 0 else False self.block.append( - SwitchTransformersBlock( - config, has_relative_attention_bias=bool(i == 0), is_sparse=(i % sparse_step == 0) - ) + SwitchTransformersBlock(config, has_relative_attention_bias=bool(i == 0), is_sparse=is_sparse) ) self.final_layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) @@ -1151,7 +1150,6 @@ def __init__(self, config: SwitchTransformersConfig): decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.is_encoder_decoder = False - decoder_config.num_layers = config.num_decoder_layers self.decoder = SwitchTransformersStack(decoder_config, self.shared) # Initialize weights and apply final processing @@ -1320,7 +1318,6 @@ def __init__(self, config: SwitchTransformersConfig): encoder_config = copy.deepcopy(config) encoder_config.is_decoder = False encoder_config.use_cache = False - encoder_config.num_layers = config.num_encoder_layers encoder_config.is_encoder_decoder = False self.encoder = SwitchTransformersStack(encoder_config, self.shared) diff --git a/src/transformers/models/switch_transformers/router.py b/src/transformers/models/switch_transformers/router.py index bbff43d04636a..6c4a96b41f074 100644 --- a/src/transformers/models/switch_transformers/router.py +++ b/src/transformers/models/switch_transformers/router.py @@ -67,7 +67,7 @@ def _jax_one_hot(tensor, num_classes, axis=-1, dtype=torch.bool): mask = (tensor >= 0) & (tensor < num_classes) out[mask, tensor[mask]] = 1 - return out + return out.to(tensor.device) @dataclass @@ -250,7 +250,8 @@ def _compute_router_probabilities( distrib_upper_bound = 1.0 + self.jitter_noise uniform_distrib = ( - torch.rand(token_inputs.shape) * (distrib_lower_bound - distrib_upper_bound) + torch.rand(token_inputs.shape, device=token_inputs.device) + * (distrib_lower_bound - distrib_upper_bound) ) + distrib_upper_bound # Multiply the token inputs by the uniform distribution - adding some noise diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index b6c6f2ceef506..4989b919bce00 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -19,7 +19,7 @@ import unittest from transformers import SwitchTransformersConfig, is_torch_available -from transformers.testing_utils import require_torch, slow, torch_device +from transformers.testing_utils import require_tokenizers, require_torch, slow, torch_device from ...generation.test_generation_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -69,6 +69,7 @@ def __init__( decoder_start_token_id=0, scope=None, decoder_layers=None, + sparse_step=1, ): self.parent = parent @@ -93,6 +94,7 @@ def __init__( self.decoder_start_token_id = decoder_start_token_id self.scope = None self.decoder_layers = decoder_layers + self.sparse_step = sparse_step def get_large_model_config(self): return SwitchTransformersConfig.from_pretrained("switch_transformers-base") @@ -156,6 +158,7 @@ def get_config(self): bos_token_id=self.pad_token_id, pad_token_id=self.pad_token_id, decoder_start_token_id=self.decoder_start_token_id, + sparse_step=self.sparse_step, ) def check_prepare_lm_labels_via_shift_left( @@ -990,7 +993,7 @@ def test_equivalency_token_chose_masked_router(self): self.assertAlmostEqual(output.router_z_loss.item(), 0.4789799, places=5) self.assertTrue(torch.allclose(output.dispatch_mask, expected_dispatch_mask)) - self.assertTrue(torch.allclose(output.combine_array, expected_combine_array)) + self.assertTrue(torch.allclose(output.combine_array, expected_combine_array, atol=1e-4)) def test_equivalency_experts_chose_masked_router(self): r""" @@ -1064,4 +1067,31 @@ def test_equivalency_experts_chose_masked_router(self): ] ) - self.assertTrue(torch.allclose(output.combine_array, expected_combined_array)) + self.assertTrue(torch.allclose(output.combine_array, expected_combined_array, atol=1e-4)) + + +@require_torch +@require_tokenizers +class SwitchTransformerModelIntegrationTests(unittest.TestCase): + def test_small_logits(self): + pass + + def test_large_logits(self): + pass + + def test_small_logits_bf16(self): + pass + + def test_small_batch_generate(self): + pass + + def test_large_batch_generate(self): + pass + + @slow + def test_summarization(self): + pass + + @slow + def test_translation_en_to_de(self): + pass From a6a7d5741a37cd792b094ba9f6e06a47d3ee4d5f Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 20 Oct 2022 17:26:16 +0200 Subject: [PATCH 021/102] v1 conversion --- .../configuration_switch_transformers.py | 4 +- ...rmers_original_tf_checkpoint_to_pytorch.py | 104 ++++++++++++++++-- .../modeling_switch_transformers.py | 11 +- .../models/switch_transformers/router.py | 4 +- 4 files changed, 108 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index 11253c3580666..83b7f905e75cc 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -106,12 +106,12 @@ class SwitchTransformersConfig(PretrainedConfig): def __init__( self, vocab_size=32128, - d_model=512, + d_model=768, d_kv=64, d_ff=2048, num_layers=12, num_sparse_encoder_layers=6, - num_decoder_layers=12, + num_decoder_layers=6, num_sparse_decoder_layers=6, num_heads=8, num_experts=8, diff --git a/src/transformers/models/switch_transformers/convert_switch_transformers_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/switch_transformers/convert_switch_transformers_original_tf_checkpoint_to_pytorch.py index 509d622daea58..39aeaf9e41ed0 100644 --- a/src/transformers/models/switch_transformers/convert_switch_transformers_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/switch_transformers/convert_switch_transformers_original_tf_checkpoint_to_pytorch.py @@ -16,33 +16,119 @@ import argparse +import os +import regex as re +from flax.serialization import msgpack_restore from transformers import SwitchTransformersConfig, SwitchTransformersForConditionalGeneration +from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model from transformers.utils import logging +from transformers.utils.hub import get_file_from_repo logging.set_verbosity_info() +from flax.traverse_util import flatten_dict, unflatten_dict -def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): + +# should not include what is already done by the `from_pt` argument +MOE_LAYER_NAME_MAPPING = { + "/attention/": "/0/SelfAttention/", + "/self_attention/": "/0/SelfAttention/", + "/encoder_decoder_attention/": "/1/EncDecAttention/", + "value": "v", + "query": "q", + "key": "k", + "out": "o", + "pre_self_attention_layer_norm": "0/layer_norm", + "pre_cross_attention_layer_norm": "1/layer_norm", + "pre_attention_layer_norm": "1/layer_norm", + "token_embedder": "shared", + "encoder_norm": "final_layer_norm", + "decoder_norm": "final_layer_norm", + "relpos_bias/rel_embedding": "TO_RENAME", + "router/router_weights/w/": "router/classifier/", + "roer/roer_weights/w/": "router/classifier/", +} + +FLAX_MODELS = { + "base-8": "https://huggingface.co/ybelkada/switch-c-base-8", + "base-16": "https://huggingface.co/ybelkada/switch-c-base-16", + "base-32": "https://huggingface.co/ybelkada/switch-c-base-32", + "base-64": "https://huggingface.co/ybelkada/switch-c-base-64", + "base-128": "https://huggingface.co/ybelkada/switch-c-base-128", + "base-256": "https://huggingface.co/ybelkada/switch-c-base-256", + "large-128": "https://huggingface.co/ybelkada/switch-c-large-128", + "xxl-128": "https://huggingface.co/ybelkada/switch-c-xxl-128", + "beast-2048": "https://huggingface.co/ybelkada/switch-c-2048", +} + + +def rename_keys(s_dict): + # 1. in HF T5, we have block.{x}.layer.{y}. which corresponds to layer.{x} in + # the original model + keys = list(s_dict.keys()) + for key in keys: + layer_to_block_of_layer = r".*/layers_(\d+)" + + if re.match(layer_to_block_of_layer, key): + new_key = re.sub(r"layers_(\d+)", r"block/\1/layer", key) + s_dict[new_key] = s_dict.pop(key) + key = new_key + layer_to_block_of_layer = r"(encoder|decoder)\/" + + if re.match(layer_to_block_of_layer, key): + groups = re.match(layer_to_block_of_layer, key).groups() + if groups[0] == "encoder": + new_key = re.sub(r"/mlp/", r"/1/mlp/", key) + new_key = re.sub(r"/pre_mlp_layer_norm/", r"/0/layer_norm/", new_key) + + elif groups[0] == "decoder": + new_key = re.sub(r"/mlp/", r"/2/mlp/", key) + new_key = re.sub(r"/pre_mlp_layer_norm/", r"/1/layer_norm/", new_key) + + # 2. Convert other classic mappings + for old_key, temp_key in MOE_LAYER_NAME_MAPPING.items(): + if old_key in new_key: + new_key = new_key.replace(old_key, temp_key) + + print(f"{key} -> {new_key}") + s_dict[new_key] = s_dict.pop(key) + # 3. Take extra care of the EXPERTS layer + + return s_dict + + +def convert_flax_checkpoint_to_pytorch(flax_checkpoint_path, config_file, pytorch_dump_path): # Initialise PyTorch model - config = SwitchTransformersConfig.from_json_file(config_file) - print(f"Building PyTorch model from configuration: {config}") - model = SwitchTransformersForConditionalGeneration(config) - # Load weights from tf checkpoint - # load_tf_weights_in_switch_transformers(model, config, tf_checkpoint_path) + print(f"Loading flax weights from : {flax_checkpoint_path}") + # get_file_from_repo(flax_checkpoint_path, "flax_params.flax") + with open(os.path.join(flax_checkpoint_path, "flax_params.flax"), "rb") as f: + params = msgpack_restore(f.read()) + + config = SwitchTransformersConfig.from_pretrained(config_file) + pt_model = SwitchTransformersForConditionalGeneration(config) + + params = flatten_dict(params, sep="/") + params = rename_keys(params) + breakpoint() + params = unflatten_dict(params, sep="/") + + load_flax_weights_in_pytorch_model(pt_model, params) + + # Post process the experts # Save pytorch-model print(f"Save PyTorch model to {pytorch_dump_path}") - model.save_pretrained(pytorch_dump_path) + pt_model.save_pretrained(pytorch_dump_path) if __name__ == "__main__": parser = argparse.ArgumentParser() # Required parameters parser.add_argument( - "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + "--flax_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." ) parser.add_argument( "--config_file", @@ -58,4 +144,4 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." ) args = parser.parse_args() - convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) + convert_flax_checkpoint_to_pytorch(args.flax_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 6636da65ceaac..1dc1087d2da76 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -130,7 +130,7 @@ def forward(self, hidden_states): # TODO: Change it here to adapt it from the paper, the FF layer contains experts -# an expert is a FF layer with multiple sub-FF layers inside. +# an expert is a FF layer with multiple sub-FF layers inside.s # This class should also contain a router class # check flaxformer/architecture/moe/router.py : https://github.com/google/flaxformer/blob/main/flaxformer/architectures/moe/routing.py class SwitchTransformersLayerFF(nn.Module): @@ -729,7 +729,14 @@ def __init__(self, config, embed_tokens=None): self.block = nn.ModuleList() for i in range(config.num_layers): - is_sparse = (i % sparse_step == 0) if sparse_step > 0 else False + # is_sparse = (i % sparse_step == 1) if sparse_step > 0 else False + if self.is_decoder: + even = 1 + else: + even = 0 + + is_sparse = (i % sparse_step == even) if sparse_step > 0 else False + self.block.append( SwitchTransformersBlock(config, has_relative_attention_bias=bool(i == 0), is_sparse=is_sparse) ) diff --git a/src/transformers/models/switch_transformers/router.py b/src/transformers/models/switch_transformers/router.py index 6c4a96b41f074..24578e59b5d86 100644 --- a/src/transformers/models/switch_transformers/router.py +++ b/src/transformers/models/switch_transformers/router.py @@ -209,7 +209,7 @@ class Router(nn.Module): def __init__(self, config, **kwargs): super().__init__() self.num_experts = config.num_experts - self.router_weights = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias) + self.classifier = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias) self.jitter_noise = config.router_jitter_noise self.ignore_padding_tokens = config.router_ignore_padding_tokens self.dtype = getattr(torch, config.router_dtype) @@ -258,7 +258,7 @@ def _compute_router_probabilities( token_inputs *= uniform_distrib # Shape: [num_groups, tokens_per_group, num_experts] - router_logits = self.router_weights(token_inputs) + router_logits = self.classifier(token_inputs) router_probabilities = torch.nn.Softmax(dim=-1)(router_logits) From 2e2be49c6b19e137c5a4742808de2b5ec7aa0566 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 20 Oct 2022 17:04:27 +0000 Subject: [PATCH 022/102] update --- ...rs_original_flax_checkpoint_to_pytorch.py} | 43 +++++++++++++------ .../modeling_switch_transformers.py | 7 +-- 2 files changed, 30 insertions(+), 20 deletions(-) rename src/transformers/models/switch_transformers/{convert_switch_transformers_original_tf_checkpoint_to_pytorch.py => convert_switch_transformers_original_flax_checkpoint_to_pytorch.py} (79%) diff --git a/src/transformers/models/switch_transformers/convert_switch_transformers_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py similarity index 79% rename from src/transformers/models/switch_transformers/convert_switch_transformers_original_tf_checkpoint_to_pytorch.py rename to src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py index 39aeaf9e41ed0..f28ba6ead4c8d 100644 --- a/src/transformers/models/switch_transformers/convert_switch_transformers_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py @@ -46,9 +46,12 @@ "token_embedder": "shared", "encoder_norm": "final_layer_norm", "decoder_norm": "final_layer_norm", - "relpos_bias/rel_embedding": "TO_RENAME", + "relpos_bias/rel_embedding": "block/0/layer/0/SelfAttention/relative_attention_bias/weight", "router/router_weights/w/": "router/classifier/", "roer/roer_weights/w/": "router/classifier/", + + + } FLAX_MODELS = { @@ -70,32 +73,43 @@ def rename_keys(s_dict): keys = list(s_dict.keys()) for key in keys: layer_to_block_of_layer = r".*/layers_(\d+)" - + new_key = key if re.match(layer_to_block_of_layer, key): - new_key = re.sub(r"layers_(\d+)", r"block/\1/layer", key) - s_dict[new_key] = s_dict.pop(key) - key = new_key + new_key = re.sub(r"layers_(\d+)", r"block/\1/layer", new_key) + # s_dict[new_key] = s_dict.pop(key) + layer_to_block_of_layer = r"(encoder|decoder)\/" if re.match(layer_to_block_of_layer, key): - groups = re.match(layer_to_block_of_layer, key).groups() + groups = re.match(layer_to_block_of_layer, new_key).groups() if groups[0] == "encoder": - new_key = re.sub(r"/mlp/", r"/1/mlp/", key) + new_key = re.sub(r"/mlp/", r"/1/mlp/", new_key) new_key = re.sub(r"/pre_mlp_layer_norm/", r"/0/layer_norm/", new_key) elif groups[0] == "decoder": - new_key = re.sub(r"/mlp/", r"/2/mlp/", key) + new_key = re.sub(r"/mlp/", r"/2/mlp/", new_key) new_key = re.sub(r"/pre_mlp_layer_norm/", r"/1/layer_norm/", new_key) # 2. Convert other classic mappings for old_key, temp_key in MOE_LAYER_NAME_MAPPING.items(): if old_key in new_key: new_key = new_key.replace(old_key, temp_key) + print(f"{key} -> {new_key}") - s_dict[new_key] = s_dict.pop(key) - # 3. Take extra care of the EXPERTS layer + s_dict[new_key] = s_dict.pop(key) + + # 3. Take extra care of the EXPERTS layer + for key in list(s_dict.keys()): + if "expert" in key: + + num_experts = s_dict[key].shape[0] + expert_weihts = s_dict[key] + for idx in range(num_experts): + s_dict[key.replace("expert/", f"experts/expert_{idx}/")] = expert_weihts[idx] + s_dict.pop(key) + return s_dict @@ -103,16 +117,17 @@ def convert_flax_checkpoint_to_pytorch(flax_checkpoint_path, config_file, pytorc # Initialise PyTorch model print(f"Loading flax weights from : {flax_checkpoint_path}") - # get_file_from_repo(flax_checkpoint_path, "flax_params.flax") - with open(os.path.join(flax_checkpoint_path, "flax_params.flax"), "rb") as f: + path = get_file_from_repo(flax_checkpoint_path, "flax_params.flax", use_auth_token = "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") # get_file_from_repo(config_file, "flax_params.flax") # get_file_from_repo(config_file, "flax_params.flax", use_auth_token = "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") + + config_file = get_file_from_repo(flax_checkpoint_path, "config.json", use_auth_token = "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") + with open(os.path.join(path), "rb") as f: params = msgpack_restore(f.read()) - config = SwitchTransformersConfig.from_pretrained(config_file) + config = SwitchTransformersConfig.from_pretrained(config_file, relative_attention_num_buckets=12) pt_model = SwitchTransformersForConditionalGeneration(config) params = flatten_dict(params, sep="/") params = rename_keys(params) - breakpoint() params = unflatten_dict(params, sep="/") load_flax_weights_in_pytorch_model(pt_model, params) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 1dc1087d2da76..0772b57c76276 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -729,13 +729,8 @@ def __init__(self, config, embed_tokens=None): self.block = nn.ModuleList() for i in range(config.num_layers): - # is_sparse = (i % sparse_step == 1) if sparse_step > 0 else False - if self.is_decoder: - even = 1 - else: - even = 0 - is_sparse = (i % sparse_step == even) if sparse_step > 0 else False + is_sparse = (i % sparse_step == 1) if sparse_step > 0 else False self.block.append( SwitchTransformersBlock(config, has_relative_attention_bias=bool(i == 0), is_sparse=is_sparse) From 6ede6080d30ba63e27a99a2755ee3c80fb244290 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 20 Oct 2022 17:09:36 +0000 Subject: [PATCH 023/102] hardcode hack --- .../models/switch_transformers/modeling_switch_transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 0772b57c76276..c4861f6bb5d9f 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -221,7 +221,7 @@ def __init__(self, config: SwitchTransformersConfig, has_relative_attention_bias self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) if self.has_relative_attention_bias: - self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, 32 ) #self.n_heads) self.pruned_heads = set() self.gradient_checkpointing = False From b9cac05df1d415e06c7aa1e75d5f480395a22b34 Mon Sep 17 00:00:00 2001 From: ybelkada Date: Mon, 24 Oct 2022 10:22:51 +0000 Subject: [PATCH 024/102] all keys match --- ...ers_original_flax_checkpoint_to_pytorch.py | 94 +++++++------------ .../modeling_switch_transformers.py | 6 +- 2 files changed, 37 insertions(+), 63 deletions(-) diff --git a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py index f28ba6ead4c8d..a85a52ee9b76a 100644 --- a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py +++ b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 The SwitchTransformers authors and HuggingFace Inc. team. +# Copyright 2022 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,26 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Convert SwitchTransformers checkpoint.""" +"""Convert SwitchTransformersX checkpoints from the original repository to JAX/FLAX model.""" import argparse -import os +import re -import regex as re -from flax.serialization import msgpack_restore -from transformers import SwitchTransformersConfig, SwitchTransformersForConditionalGeneration +from t5x import checkpoints +from transformers import SwitchTransformersForConditionalGeneration, SwitchTransformersConfig from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model -from transformers.utils import logging -from transformers.utils.hub import get_file_from_repo - - -logging.set_verbosity_info() from flax.traverse_util import flatten_dict, unflatten_dict - -# should not include what is already done by the `from_pt` argument MOE_LAYER_NAME_MAPPING = { "/attention/": "/0/SelfAttention/", "/self_attention/": "/0/SelfAttention/", @@ -49,21 +41,6 @@ "relpos_bias/rel_embedding": "block/0/layer/0/SelfAttention/relative_attention_bias/weight", "router/router_weights/w/": "router/classifier/", "roer/roer_weights/w/": "router/classifier/", - - - -} - -FLAX_MODELS = { - "base-8": "https://huggingface.co/ybelkada/switch-c-base-8", - "base-16": "https://huggingface.co/ybelkada/switch-c-base-16", - "base-32": "https://huggingface.co/ybelkada/switch-c-base-32", - "base-64": "https://huggingface.co/ybelkada/switch-c-base-64", - "base-128": "https://huggingface.co/ybelkada/switch-c-base-128", - "base-256": "https://huggingface.co/ybelkada/switch-c-base-256", - "large-128": "https://huggingface.co/ybelkada/switch-c-large-128", - "xxl-128": "https://huggingface.co/ybelkada/switch-c-xxl-128", - "beast-2048": "https://huggingface.co/ybelkada/switch-c-2048", } @@ -97,66 +74,65 @@ def rename_keys(s_dict): print(f"{key} -> {new_key}") - s_dict[new_key] = s_dict.pop(key) + s_dict[new_key] = s_dict.pop(key) + + s_dict["encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"] = s_dict["encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"].T + s_dict["decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"] = s_dict["decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"].T # 3. Take extra care of the EXPERTS layer for key in list(s_dict.keys()): - if "expert" in key: - + if "expert" in key: num_experts = s_dict[key].shape[0] expert_weihts = s_dict[key] for idx in range(num_experts): s_dict[key.replace("expert/", f"experts/expert_{idx}/")] = expert_weihts[idx] s_dict.pop(key) - return s_dict - -def convert_flax_checkpoint_to_pytorch(flax_checkpoint_path, config_file, pytorch_dump_path): - # Initialise PyTorch model - - print(f"Loading flax weights from : {flax_checkpoint_path}") - path = get_file_from_repo(flax_checkpoint_path, "flax_params.flax", use_auth_token = "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") # get_file_from_repo(config_file, "flax_params.flax") # get_file_from_repo(config_file, "flax_params.flax", use_auth_token = "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") - - config_file = get_file_from_repo(flax_checkpoint_path, "config.json", use_auth_token = "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") - with open(os.path.join(path), "rb") as f: - params = msgpack_restore(f.read()) - - config = SwitchTransformersConfig.from_pretrained(config_file, relative_attention_num_buckets=12) - pt_model = SwitchTransformersForConditionalGeneration(config) - - params = flatten_dict(params, sep="/") +def convert_flax_checkpoint_to_pytorch(flax_params, pt_model): + # Flatten Flax param dict, rename it and unflatten it + params = flatten_dict(flax_params, sep="/") params = rename_keys(params) params = unflatten_dict(params, sep="/") + # Load the flax params in the PT model load_flax_weights_in_pytorch_model(pt_model, params) + return pt_model + + - # Post process the experts +def convert_switch_transformersx_checkpoint_to_flax( + switch_transformersx_checkpoint_path, config_name, pytorch_dump_path +): + config = SwitchTransformersConfig.from_pretrained(config_name) + pt_model = SwitchTransformersForConditionalGeneration(config=config) + flax_params = checkpoints.load_t5x_checkpoint(switch_transformersx_checkpoint_path) + + pt_model = convert_flax_checkpoint_to_pytorch(flax_params['target'], pt_model) - # Save pytorch-model print(f"Save PyTorch model to {pytorch_dump_path}") pt_model.save_pretrained(pytorch_dump_path) + if __name__ == "__main__": parser = argparse.ArgumentParser() # Required parameters parser.add_argument( - "--flax_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." - ) - parser.add_argument( - "--config_file", + "--switch_t5x_checkpoint_path", default=None, type=str, required=True, - help=( - "The config json file corresponding to the pre-trained SwitchTransformers model. \nThis specifies the" - " model architecture." - ), + help="Path the TX5 checkpoint.", + ) + parser.add_argument( + "--config_name", default=None, type=str, required=True, help="Config name of SwitchTransformers model." ) parser.add_argument( - "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model." ) args = parser.parse_args() - convert_flax_checkpoint_to_pytorch(args.flax_checkpoint_path, args.config_file, args.pytorch_dump_path) + convert_switch_transformersx_checkpoint_to_flax( + args.switch_t5x_checkpoint_path, args.config_name, args.pytorch_dump_folder_path + ) \ No newline at end of file diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index c4861f6bb5d9f..05f1ded6b2458 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -142,12 +142,10 @@ def __init__(self, config: SwitchTransformersConfig, is_sparse=False): self.mlp = SwitchTransformersDenseActDense(config) else: self.mlp = SwitchTransformersSparseMLP(config) - self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) def forward(self, hidden_states): - forwarded_states = self.layer_norm(hidden_states) - forwarded_states = self.mlp(forwarded_states) + forwarded_states = self.mlp(hidden_states) hidden_states = hidden_states + self.dropout(forwarded_states) return hidden_states @@ -221,7 +219,7 @@ def __init__(self, config: SwitchTransformersConfig, has_relative_attention_bias self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) if self.has_relative_attention_bias: - self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, 32 ) #self.n_heads) + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) self.pruned_heads = set() self.gradient_checkpointing = False From 6276ce7158aa596d62a905dae45fda6f06dbdb0a Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 24 Oct 2022 10:28:14 +0000 Subject: [PATCH 025/102] add gin conversion, without additional libraries --- ...ers_original_flax_checkpoint_to_pytorch.py | 93 +++++++++++++------ 1 file changed, 63 insertions(+), 30 deletions(-) diff --git a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py index f28ba6ead4c8d..a41f6068caea5 100644 --- a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py +++ b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py @@ -25,12 +25,26 @@ from transformers.utils import logging from transformers.utils.hub import get_file_from_repo +from t5x import checkpoints logging.set_verbosity_info() from flax.traverse_util import flatten_dict, unflatten_dict +MODEL_MAPPING = { + "switch_base_8":["https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin"], + "switch_base_8":["https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin"], + "switch_base_8":["https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin"], + "switch_base_8":["https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin"], + "switch_base_8":["https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin"], + "switch_base_8":["https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin"], + "switch_base_8":["https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin"], + "switch_base_8":["https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin"], + "switch_base_8":["https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin"], + "switch_base_8":["https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin"], +} + # should not include what is already done by the `from_pt` argument MOE_LAYER_NAME_MAPPING = { "/attention/": "/0/SelfAttention/", @@ -49,24 +63,10 @@ "relpos_bias/rel_embedding": "block/0/layer/0/SelfAttention/relative_attention_bias/weight", "router/router_weights/w/": "router/classifier/", "roer/roer_weights/w/": "router/classifier/", - - -} -FLAX_MODELS = { - "base-8": "https://huggingface.co/ybelkada/switch-c-base-8", - "base-16": "https://huggingface.co/ybelkada/switch-c-base-16", - "base-32": "https://huggingface.co/ybelkada/switch-c-base-32", - "base-64": "https://huggingface.co/ybelkada/switch-c-base-64", - "base-128": "https://huggingface.co/ybelkada/switch-c-base-128", - "base-256": "https://huggingface.co/ybelkada/switch-c-base-256", - "large-128": "https://huggingface.co/ybelkada/switch-c-large-128", - "xxl-128": "https://huggingface.co/ybelkada/switch-c-xxl-128", - "beast-2048": "https://huggingface.co/ybelkada/switch-c-2048", } - def rename_keys(s_dict): # 1. in HF T5, we have block.{x}.layer.{y}. which corresponds to layer.{x} in # the original model @@ -94,36 +94,68 @@ def rename_keys(s_dict): for old_key, temp_key in MOE_LAYER_NAME_MAPPING.items(): if old_key in new_key: new_key = new_key.replace(old_key, temp_key) - + print(f"{key} -> {new_key}") - s_dict[new_key] = s_dict.pop(key) + s_dict[new_key] = s_dict.pop(key) + - # 3. Take extra care of the EXPERTS layer for key in list(s_dict.keys()): - if "expert" in key: - + if "expert" in key: + num_experts = s_dict[key].shape[0] expert_weihts = s_dict[key] for idx in range(num_experts): s_dict[key.replace("expert/", f"experts/expert_{idx}/")] = expert_weihts[idx] s_dict.pop(key) - + return s_dict +GIN_TO_CONFIG_MAPPING = { + "NUM_ENCODER_LAYERS":"num_layers", + "NUM_DECODER_LAYERS":"num_decoder_layers", + "NUM_HEADS":"num_heads", + "HEAD_DIM":"d_kv", + "EMBED_DIM":"d_model", + "MLP_DIM":"d_ff", + "NUM_EXPERTS":"num_experts", + "NUM_SELECTED_EXPERTS":"num_selected_experts", + "NUM_ENCODER_SPARSE_LAYERS":"num_sparse_encoder_layers", + "NUM_DECODER_SPARSE_LAYERS":"num_sparse_decoder_layers", + "EVAL_EXPERT_CAPACITY_FACTOR":"expert_capacity", + "dense.MlpBlock.activations":"feed_forward_proj", -def convert_flax_checkpoint_to_pytorch(flax_checkpoint_path, config_file, pytorch_dump_path): +} + +def convert_gin_to_config(gin_file): + # Convert a google style config to the hugging face fromat + import regex as re + with open(gin_file, "r") as f: + raw_gin = f.read() + + regex_match = re.findall(r"(.*) = ([0-9.]*)", raw_gin) + args = {} + for param, value in regex_match: + if param in GIN_TO_CONFIG_MAPPING and value != "": + args[GIN_TO_CONFIG_MAPPING[param]] = float(value) + + activation = re.findall(r"activations = \(\'(.*)\',\)", raw_gin) + args[GIN_TO_CONFIG_MAPPING[activation]] = str(activation) + config = SwitchTransformersConfig(**args) + return config + +def convert_flax_checkpoint_to_pytorch(flax_checkpoint_path, config_file, gin_file = None, pytorch_dump_path = "./"): # Initialise PyTorch model print(f"Loading flax weights from : {flax_checkpoint_path}") - path = get_file_from_repo(flax_checkpoint_path, "flax_params.flax", use_auth_token = "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") # get_file_from_repo(config_file, "flax_params.flax") # get_file_from_repo(config_file, "flax_params.flax", use_auth_token = "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") - - config_file = get_file_from_repo(flax_checkpoint_path, "config.json", use_auth_token = "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") - with open(os.path.join(path), "rb") as f: - params = msgpack_restore(f.read()) + t5x_model = checkpoints.load_t5x_checkpoint(flax_checkpoint_path) + + if gin_file is not None: + config = convert_gin_to_config(gin_file) + else : + config = SwitchTransformersConfig.from_pretrained(config_file, relative_attention_num_buckets=12) - config = SwitchTransformersConfig.from_pretrained(config_file, relative_attention_num_buckets=12) pt_model = SwitchTransformersForConditionalGeneration(config) params = flatten_dict(params, sep="/") @@ -132,8 +164,6 @@ def convert_flax_checkpoint_to_pytorch(flax_checkpoint_path, config_file, pytorc load_flax_weights_in_pytorch_model(pt_model, params) - # Post process the experts - # Save pytorch-model print(f"Save PyTorch model to {pytorch_dump_path}") pt_model.save_pretrained(pytorch_dump_path) @@ -152,9 +182,12 @@ def convert_flax_checkpoint_to_pytorch(flax_checkpoint_path, config_file, pytorc required=True, help=( "The config json file corresponding to the pre-trained SwitchTransformers model. \nThis specifies the" - " model architecture." + " model architecture. If not provided, a `gin_file` has to be provided." ), ) + parser.add_argument( + "--gin_file", default=None, type=str, required=True, help="Path to the gin config file. If not provided, a `config_file` has to be passed " + ) parser.add_argument( "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." ) From 0c4e54a5752657cb7d0b50d013be9d5e422b93ec Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 24 Oct 2022 14:55:10 +0000 Subject: [PATCH 026/102] update conversion sctipy --- .../configuration_switch_transformers.py | 10 ++-- ...ers_original_flax_checkpoint_to_pytorch.py | 46 +++++++++++-------- .../modeling_switch_transformers.py | 6 ++- 3 files changed, 36 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index 83b7f905e75cc..8850d22136065 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -110,11 +110,11 @@ def __init__( d_kv=64, d_ff=2048, num_layers=12, - num_sparse_encoder_layers=6, - num_decoder_layers=6, - num_sparse_decoder_layers=6, - num_heads=8, - num_experts=8, + num_sparse_encoder_layers=3, + num_decoder_layers=12, + num_sparse_decoder_layers=3, + num_heads=12, + num_experts=64, expert_capacity=1, router_type="tokens_masked", router_bias=False, diff --git a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py index 60bf1cb26ed8a..c2e7ff07de410 100644 --- a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py +++ b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py @@ -55,7 +55,7 @@ "out": "o", "pre_self_attention_layer_norm": "0/layer_norm", "pre_cross_attention_layer_norm": "1/layer_norm", - "pre_attention_layer_norm": "1/layer_norm", + "pre_attention_layer_norm": "0/layer_norm", # previously 1, but seems wrong "token_embedder": "shared", "encoder_norm": "final_layer_norm", "decoder_norm": "final_layer_norm", @@ -73,7 +73,7 @@ def rename_keys(s_dict): new_key = key if re.match(layer_to_block_of_layer, key): new_key = re.sub(r"layers_(\d+)", r"block/\1/layer", new_key) - # s_dict[new_key] = s_dict.pop(key) + layer_to_block_of_layer = r"(encoder|decoder)\/" @@ -81,11 +81,11 @@ def rename_keys(s_dict): groups = re.match(layer_to_block_of_layer, new_key).groups() if groups[0] == "encoder": new_key = re.sub(r"/mlp/", r"/1/mlp/", new_key) - new_key = re.sub(r"/pre_mlp_layer_norm/", r"/0/layer_norm/", new_key) + new_key = re.sub(r"/pre_mlp_layer_norm/", r"/1/layer_norm/", new_key) elif groups[0] == "decoder": new_key = re.sub(r"/mlp/", r"/2/mlp/", new_key) - new_key = re.sub(r"/pre_mlp_layer_norm/", r"/1/layer_norm/", new_key) + new_key = re.sub(r"/pre_mlp_layer_norm/", r"/2/layer_norm/", new_key) # 2. Convert other classic mappings for old_key, temp_key in MOE_LAYER_NAME_MAPPING.items(): @@ -107,6 +107,8 @@ def rename_keys(s_dict): expert_weihts = s_dict[key] for idx in range(num_experts): s_dict[key.replace("expert/", f"experts/expert_{idx}/")] = expert_weihts[idx] + print(f"{key} -> {key.replace('expert/', f'experts/expert_{idx}/')}") + s_dict.pop(key) return s_dict @@ -118,16 +120,14 @@ def rename_keys(s_dict): "HEAD_DIM":"d_kv", "EMBED_DIM":"d_model", "MLP_DIM":"d_ff", - "NUM_EXPERTS":"num_experts", "NUM_SELECTED_EXPERTS":"num_selected_experts", "NUM_ENCODER_SPARSE_LAYERS":"num_sparse_encoder_layers", "NUM_DECODER_SPARSE_LAYERS":"num_sparse_decoder_layers", - "EVAL_EXPERT_CAPACITY_FACTOR":"expert_capacity", "dense.MlpBlock.activations":"feed_forward_proj", } -def convert_gin_to_config(gin_file): +def convert_gin_to_config(gin_file, num_experts): # Convert a google style config to the hugging face fromat import regex as re with open(gin_file, "r") as f: @@ -137,32 +137,35 @@ def convert_gin_to_config(gin_file): args = {} for param, value in regex_match: if param in GIN_TO_CONFIG_MAPPING and value != "": - args[GIN_TO_CONFIG_MAPPING[param]] = float(value) + args[GIN_TO_CONFIG_MAPPING[param]] = float(value) if '.' in value else int(value) - activation = re.findall(r"activations = \(\'(.*)\',\)", raw_gin) - args[GIN_TO_CONFIG_MAPPING[activation]] = str(activation) + activation = re.findall(r"(.*activations) = \(\'(.*)\',\)", raw_gin)[0] + args[GIN_TO_CONFIG_MAPPING[activation[0]]] = str(activation[1]) + + args["num_experts"] = num_experts config = SwitchTransformersConfig(**args) return config -def convert_flax_checkpoint_to_pytorch(flax_checkpoint_path, config_file, gin_file = None, pytorch_dump_path = "./"): +def convert_flax_checkpoint_to_pytorch(flax_checkpoint_path, config_file, gin_file = None, pytorch_dump_path = "./", num_experts = 8): # Initialise PyTorch model print(f"Loading flax weights from : {flax_checkpoint_path}") flax_params = checkpoints.load_t5x_checkpoint(flax_checkpoint_path) if gin_file is not None: - config = convert_gin_to_config(gin_file) + config = convert_gin_to_config(gin_file, num_experts) else : - config = SwitchTransformersConfig.from_pretrained(config_file, relative_attention_num_buckets=12) + config = SwitchTransformersConfig.from_pretrained(config_file) pt_model = SwitchTransformersForConditionalGeneration(config) - params = flatten_dict(params, sep="/") - params = rename_keys(params) - params = unflatten_dict(params, sep="/") + flax_params = flax_params['target'] + flax_params = flatten_dict(flax_params, sep="/") + flax_params = rename_keys(flax_params) + flax_params = unflatten_dict(flax_params, sep="/") # Load the flax params in the PT model - load_flax_weights_in_pytorch_model(pt_model, params['target']) + load_flax_weights_in_pytorch_model(pt_model, flax_params) print(f"Save PyTorch model to {pytorch_dump_path}") pt_model.save_pretrained(pytorch_dump_path) @@ -183,15 +186,18 @@ def convert_flax_checkpoint_to_pytorch(flax_checkpoint_path, config_file, gin_fi ), ) parser.add_argument( - "--gin_file", default=None, type=str, required=True, help="Path to the gin config file. If not provided, a `config_file` has to be passed " + "--gin_file", default=None, type=str, required=False, help="Path to the gin config file. If not provided, a `config_file` has to be passed " ) parser.add_argument( - "--config_name", default=None, type=str, required=True, help="Config name of SwitchTransformers model." + "--config_name", default=None, type=str, required=False, help="Config name of SwitchTransformers model." ) parser.add_argument( "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output pytorch model." ) + parser.add_argument( + "--num_experts", default=8, type=str, required=False, help="Number of experts" + ) args = parser.parse_args() convert_flax_checkpoint_to_pytorch( - args.switch_t5x_checkpoint_path, args.config_name, args.gin_file,args.pytorch_dump_folder_path + args.switch_t5x_checkpoint_path, args.config_name, args.gin_file,args.pytorch_dump_folder_path, args.num_experts ) \ No newline at end of file diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 05f1ded6b2458..4152ae2210dc3 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -142,10 +142,13 @@ def __init__(self, config: SwitchTransformersConfig, is_sparse=False): self.mlp = SwitchTransformersDenseActDense(config) else: self.mlp = SwitchTransformersSparseMLP(config) + + self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) def forward(self, hidden_states): - forwarded_states = self.mlp(hidden_states) + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.mlp(forwarded_states) hidden_states = hidden_states + self.dropout(forwarded_states) return hidden_states @@ -187,6 +190,7 @@ def _get_router(self, config): ) def forward(self, hidden_states): + # TODO the expert capacity is poorly computed expert_indices = self.router(hidden_states, expert_capacity=self.expert_capacity) masked_indices = expert_indices.dispatch_mask From 3b0ee25ec7281188f64e306f1b9960da6bd7b1b1 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 24 Oct 2022 14:59:54 +0000 Subject: [PATCH 027/102] delete router file --- .../modeling_switch_transformers.py | 504 +++++++++++++- .../models/switch_transformers/router.py | 648 ------------------ 2 files changed, 503 insertions(+), 649 deletions(-) delete mode 100644 src/transformers/models/switch_transformers/router.py diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 4152ae2210dc3..b0516517c24c1 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -44,8 +44,12 @@ replace_return_docstrings, ) from .configuration_switch_transformers import SwitchTransformersConfig -from .router import ExpertsChooseMaskedRouter, TokensChooseMaskedRouter +from dataclasses import dataclass, replace +from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn logger = logging.get_logger(__name__) @@ -63,6 +67,504 @@ ] +RouterOutput = Any + +def _jax_one_hot(tensor, num_classes, axis=-1, dtype=torch.bool): + r""" + This function mimics the behavior of jax.nn.functional.one_hot in PyTorch. It takes a tensor of indices, the number + of desired classes and the axis to apply the one-hot encoding. If any value is outside the range [0, num_classes), + it will be set to zeros. + + Args: + tensor (`torch.Tensor`): + Input tensor + num_classes (`int`): + Number of classes to process for one hot encoding + axis (`int`, *optional*): + The lookup axis to check for one-hot encoding + dtype (`torch.dtype`, *optional*): + Output `dtype`. The one hot encoded vector will be casted to this dtype + """ + if tensor.is_floating_point(): + raise "Input tensor for one hot encoding must be an `int32` or `int64`" + + if axis >= len(tensor.shape): + raise "Axis is out of bounds" + + if axis == -1: + axis = len(tensor.shape) + elif axis < -1: + raise "Axis must be greater than -1" + else: + axis = axis + 1 + + # Get the final output shape + output_shape = list(tensor.shape) + output_shape.insert(axis, num_classes) + + # Create an empty output of zeros + out = torch.zeros(tuple(output_shape), dtype=dtype) + + # Mask out the places where it is outside the range [0, num_classes) + # kudos to GitHub copilot for this line + mask = (tensor >= 0) & (tensor < num_classes) + out[mask, tensor[mask]] = 1 + + return out.to(tensor.device) + + +@dataclass +class RouterIndices: + r""" + A dataclass wrapper to store the dispatch indices and combine weights for scatter/gather-based routing. + + Attributes: + dispatch_indices (`torch.Tensor`): + A tensor of size [`num_groups`, `tokens_per_group`, `num_selected_experts`, 2] dispatch indices indicating, + for each token, its preferred expert and its priority in that expert's buffer. + combine_weights (`torch.Tensor`): + A tensor of size [`num_groups`, `tokens_per_group`, `num_selected_experts`] combine weights used for + scaling expert outputs with the router's dispatch probability/confidence. + auxiliary_loss (`float`): + Load balancing loss for router. + router_z_loss (`float`): + Router z-loss. Encourages router logits to remain small in an effort to improve stability. + """ + dispatch_indices: torch.Tensor + combine_weights: torch.Tensor + auxiliary_loss: float + router_z_loss: float = 0.0 + + def to(self, device): + return replace( + self, dispatch_mask=self.dispatch_indices.to(device), combine_array=self.combine_weights.to(device) + ) + + +@dataclass +class RouterMask: + r""" + Dispatch and combine torch.Tensors for expert routing with masked matmuls. + + Attributes: + dispatch_mask (`torch.Tensor`): + A mask tensor of shape [num_groups, tokens_per_group, num_experts, expert_capacity] that is 1 if the token + gets routed to the corresponding expert, and 0 otherwise. + combine_array (`torch.Tensor`): + A tensor of shape [num_groups, tokens_per_group, num_experts, expert_capacity] combine torch.Tensor used + for combining expert outputs and scaling with router probability. + auxiliary_loss (`float`): + Load balancing loss for router. + router_z_loss (`float`): + Router z-loss. Encourages router logits to remain small in an effort to improve stability. + """ + dispatch_mask: torch.Tensor + combine_array: torch.Tensor + auxiliary_loss: float + router_z_loss: float = 0.0 + + def to(self, device): + return replace(self, dispatch_mask=self.dispatch_mask.to(device), combine_array=self.combine_array.to(device)) + + +# Router loss + + +def router_z_loss_func(router_logits: torch.Tensor) -> float: + r""" + Compute router z-loss implemented in PyTorch. + + The router z-loss was introduced in Designing Effective Sparse Expert Models (https://arxiv.org/abs/2202.08906). It + encourages router logits to remain small in an effort to improve stability. + + Args: + router_logits (`float`): + Input logits of shape [num_groups, tokens_per_group, num_experts] + + Returns: + Scalar router z-loss. + """ + num_groups, tokens_per_group, _ = router_logits.shape + log_z = torch.logsumexp(router_logits, dim=-1) + z_loss = log_z**2 + return torch.sum(z_loss) / (num_groups * tokens_per_group) + + +def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961). This function implements the loss function presented in + equations (4) - (6). It aims to penalize those cases where the routing between experts is unbalanced. + + Args: + router_probs (`torch.Tensor`): + Probability assigned to each expert per token. Shape: [num_groups, tokens_per_group, num_experts]. + expert_indices (`torch.Tensor`): + Indices tensor of shape [num_groups, tokens_per_group, num_selected_experts] identifying the top + num_selected_experts for a given token. + + Returns: + The auxiliary loss. + """ + num_experts = router_probs.shape[-1] + + # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. + # cast the expert indices to int64, otherwise one-hot encoding will fail + if expert_indices.dtype != torch.int64: + expert_indices = expert_indices.to(torch.int64) + expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts) + + # For a given token, determine if it was routed to a given expert. + # Shape: [num_groups, tokens_per_group, num_experts] + expert_mask = torch.max(expert_mask, axis=-2).values + + # cast to float32 otherwise mean will fail + expert_mask = expert_mask.to(torch.float32) + tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2) + + router_prob_per_group_and_expert = torch.mean(router_probs, axis=-2) + return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2) + + +class Router(nn.Module): + """ + Abstract base router class, defining router API and inner workings. + + Attributes: + router_weights (`torch.nn.Module`): + Configurable module used to compute router logits from token inputs. + jitter_noise (`float`): + Amplitude of jitter noise applied to router logits. + dtype (`torch.dtype`): + Numeric float type for returned combine torch.Tensor. All actual computations are performed in float32 of + the input for stability. + ignore_padding_tokens (`bool`): + Whether to ignore padding tokens during routing. Note that some routers (e.g. `TokensChooseMaskedRouter`) + will completely ignore padding tokens, while others (e.g. `TokensChooseScatterRouter` and + `ExpertsChooseMaskedRouter`) will simply down-weight the probability of selecting padding tokens. + """ + + def __init__(self, config, **kwargs): + super().__init__() + self.num_experts = config.num_experts + self.classifier = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias) + self.jitter_noise = config.router_jitter_noise + self.ignore_padding_tokens = config.router_ignore_padding_tokens + self.dtype = getattr(torch, config.router_dtype) + + def _compute_router_probabilities( + self, token_inputs: torch.Tensor, num_experts: int, apply_jitter: bool + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Computes router probabilities from input tokens. + + Args: + token_inputs (`torch.Tensor`): + [num_groups, tokens_per_group, hidden_dim] from which router probabilities are computed. + num_experts (`int`): + Number of experts. + apply_jitter (`bool`): + If true, apply jitter noise. + + Returns: + router_probabilities (`torch.Tensor`): + Tensor of shape [num_groups, tokens_per_group, num_experts] corresponding to the probabilities for each + token and expert. Used for routing tokens to experts. + router_logits (`torch.Tensor`): + Logits tensor of shape [num_groups, tokens_per_group, num_experts] corresponding to raw router logits. + This is used later for computing router z-loss. + """ + # For remainder of routing computation we use float32 to ensure stability. + # See the discussion of "selective precision" in + # https://arxiv.org/abs/2101.03961. + # We also store the previous dtype to cast back the output to the previous dtype + self.input_tokens_dtype = token_inputs.dtype + token_inputs = token_inputs.to(self.dtype) + + if apply_jitter and self.jitter_noise > 0: + # Get the lower and upper bound of the uniform distribution + # Adapted from: https://stackoverflow.com/questions/44328530/how-to-get-a-uniform-distribution-in-a-range-r1-r2-in-pytorch + distrib_lower_bound = 1.0 - self.jitter_noise + distrib_upper_bound = 1.0 + self.jitter_noise + + uniform_distrib = ( + torch.rand(token_inputs.shape, device=token_inputs.device) + * (distrib_lower_bound - distrib_upper_bound) + ) + distrib_upper_bound + + # Multiply the token inputs by the uniform distribution - adding some noise + token_inputs *= uniform_distrib + + # Shape: [num_groups, tokens_per_group, num_experts] + router_logits = self.classifier(token_inputs.to(self.classifier.weight.dtype)) + + # computations in the router have to be done in float16 + router_probabilities = torch.nn.Softmax(dim=-1)(router_logits.to(self.dtype)) + + return router_probabilities, router_logits + + def forward( + self, token_inputs: torch.Tensor, expert_capacity: int, apply_jitter: bool = True, **kwargs + ) -> RouterOutput: + r""" + Generic forward function for every Router class. Each Router expects to have the same input hidden states + (`token_inputs`) corresponding to the hidden states for each token, the `expert_capacity` corresponding to the + number of tokens the Router will send to each expert, some Routers can send up to few tokens to each expert. + + Each Router works as the following: it expects the hidden states for each token, gets the `router_probs` and + `router_logits` from the `router_weights`. This will assign for each token, the raw probability to be assigned + to an expert. Then each Router class will have to define its own `_compute_routing_instructions`. + + Args: + Computes dispatch and combine torch.Tensors for routing to experts. + token_inputs: [num_groups, tokens_per_group, hidden_dim] inputs to send to experts. num_experts: + Number of experts. expert_capacity: Each group will send this many tokens to each expert. apply_jitter: If + true, apply jitter noise during routing. + Returns: + Router indices or mask torch.Tensors (depending on router type). + """ + router_probs, router_logits = self._compute_router_probabilities(token_inputs, self.num_experts, apply_jitter) + + # Flax code for reference + if self.ignore_padding_tokens: + # To identify non-padding tokens, we rely on the fact that padding tokens + # in the inputs have already been masked in the default T5 architecture. + # See + # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L315 + # and + # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L603. + padding_mask = torch.Tensor((torch.sum(torch.abs(token_inputs), axis=-1) > 0)).to(token_inputs.dtype) + router_logits *= padding_mask.unsqueeze(-1) + else: + padding_mask = None + + instructions = self._compute_routing_instructions(router_probs, padding_mask, expert_capacity, **kwargs) + # We cast back the output to the previous dtype + instructions = instructions.to(self.input_tokens_dtype) + + return replace(instructions, router_z_loss=router_z_loss_func(router_logits)) + + def _compute_routing_instructions(self, router_probs, padding_mask, expert_capacity): + raise NotImplementedError( + """ + Computes masks for the top-k experts per token. This has to be implemented for each subclass of + MaskedRouter routers. + + Args: + router_probs (`torch.Tensor`): + Input router probabilities of shape [num_groups, tokens_per_group, num_experts] this + corresponds to the probabilities used to determine the routing of tokens to the experts. + padding_mask (`torch.Tensor`): + Padding mask tensor of shape [num_groups, tokens_per_group] a mask used to identify padding + tokens that should be ignored by the router. + expert_capacity (`int`): + Each group will send this many tokens to each expert. + + Returns: + Router mask arrays. + """ + ) + + +class ExpertsChooseMaskedRouter(Router): + """ + Masked matmul router using experts choose tokens assignment. + + This router uses the same mechanism as in Mixture-of-Experts with Expert Choice (https://arxiv.org/abs/2202.09368): + each expert selects its top expert_capacity tokens. An individual token may be processed by multiple experts or + none at all. + + Note: "experts choose routing" should not be used in decoder blocks because it breaks the autoregressive behavior + -- the model will learn to cheat by using future token information to improve current token predictions. + """ + + def _compute_routing_instructions( + self, router_probs: torch.Tensor, padding_mask: Optional[torch.Tensor], expert_capacity: int + ) -> RouterMask: + """Computes masks for the highest probability token per expert. + + Args: + router_probs (`torch.Tensor`): + Raw router probabilities of shape [num_groups, tokens_per_group, num_experts] used to determine the + routing of tokens to the experts. + padding_mask (`torch.Tensor`): + padding mask tensor of shape [num_groups, tokens_per_group] used to identify padding tokens that should + be down-weighted by the router. + expert_capacity (`int`): + Each group will send this many tokens to each expert. + + Returns: + Dispatch and combine arrays for routing with masked matmuls. + """ + tokens_per_group = router_probs.shape[1] + + if padding_mask is not None: + # Because experts choose tokens, we mask probabilities corresponding to + # tokens before the top-k operation. Note that, unlike for masked-based + # tokens-choose routing, the experts here may still choose to select the + # (down-weighted) padding tokens. + router_probs *= padding_mask.unsqueeze(-1) + + # vmap over group dimension. + # router_probs_t = router_probs.t() + router_probs_t = router_probs.permute(0, 2, 1) + + # Top expert_capacity router probability and corresponding token indices for + # each expert. Shapes: [num_groups, num_experts, expert_capacity]. + expert_gate, expert_index = torch.topk(router_probs_t, k=expert_capacity) + + # Convert to one-hot mask of expert indices for each token in each group. + # Shape: [num_groups, num_experts, expert_capacity, tokens_per_group]. + dispatch_mask = _jax_one_hot(expert_index, tokens_per_group, dtype=torch.int32) + + # Move axes to conform with shape expected by MoeLayer API. + # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity] + dispatch_mask = torch.moveaxis(dispatch_mask, 3, 1) + + # The combine array will be used for combining expert outputs, scaled by the + # router probabilities. Shape: [num_groups, num_experts, tokens_per_group, + # expert_capacity]. + combine_array = torch.einsum("...ec,...tec->...tec", expert_gate, dispatch_mask) + + # Each expert is choosing tokens until it reaches full capacity, so we don't + # need an auxiliary loading balancing loss for expert choice routing. + auxiliary_loss = 0.0 + + return RouterMask(dispatch_mask, combine_array, auxiliary_loss) + + +class TokensChooseMaskedRouter(Router): + """ + Masked matmul router using tokens choose top-k experts assignment. + + This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE + (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then + routed to their choice of expert until the expert's expert_capacity is reached. There is no guarantee that each + token is processed by an expert, or that each expert receives at least one token. + + Attributes: + num_selected_experts (`int`): + Maximum number of experts to which each token is routed. Tokens may be routed to fewer experts if particular + experts are oversubscribed / reach capacity. + batch_prioritized_routing (`bool`): + Whether or not to use Batch Prioritized Routing (BPR), originally introduced in V-MoE + (https://arxiv.org/abs/2106.05974). With BPR, we prioritize routing those top-k tokens with the highest router + probability, rather than simply using each tokens left-to-right ordering in the batch. This prioritization is + important because the experts have limited capacity. + """ + + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self.num_selected_experts = config.num_selected_experts + self.batch_prioritized_routing = config.batch_prioritized_routing + + def _compute_routing_instructions( + self, router_probs: torch.Tensor, padding_mask: Optional[torch.Tensor], expert_capacity: int + ) -> RouterMask: + """ + Computes masks for the top-k experts per token. + + Args: + router_probs (`torch.Tensor`): + Router raw probabilities tensor of shape [num_groups, tokens_per_group, num_experts] used to determine + the routing of tokens to the experts. + padding_mask (`torch.Tensor`): + Padding mask of shape [num_groups, tokens_per_group] used to identify padding tokens that should be + ignored by the router. + expert_capacity (`int`): + Each group will send this many tokens to each expert. + + Returns: + Dispatch and combine arrays for routing with masked matmuls. + """ + num_groups, _, num_experts = router_probs.shape + + # Top-k router probability and corresponding expert indices for each token. + # Shape: [num_groups, tokens_per_group, num_selected_experts]. + expert_gate, expert_index = torch.topk(router_probs, k=self.num_selected_experts) + + if padding_mask is not None: + # Mask applied to gate. Exclude choices corresponding to padding tokens. + gate_mask = padding_mask.unsqueeze(-1).to(expert_index.dtype) + expert_gate *= gate_mask + + # Set `expert_index` elements corresponding to padding to negative + # numbers. Negative `expert_index` elements will ultimately be dropped in + # the one_hot conversion to the `expert_mask`. + # First convert nonzero padding elements to negative values. + expert_index *= (2 * gate_mask) - 1 + # Handle zero padding elements by negatively shifting all padding. + expert_index += (gate_mask - 1).repeat(1, 1, self.num_selected_experts) + + # To correctly compute load balancing loss, we also mask out probs. + router_probs *= gate_mask + + auxiliary_loss = load_balancing_loss_func(router_probs, expert_index) + + if self.batch_prioritized_routing: + # Sort tokens according to their routing probability per group, so that + # the highest probability tokens are routed first. + permutation = torch.argsort(-expert_gate[..., 0], dim=-1) + # Shape: [num_groups, tokens_per_group, num_selected_experts] + expert_index = torch.take_along_dim(expert_index, permutation.unsqueeze(-1), dim=-2) + + # Make num_selected_experts the leading axis to ensure that top-1 choices + # have priority over top-2 choices, which have priority over top-3 choices, + # etc. + expert_index = expert_index.permute((0, 2, 1)) + # Shape: [num_groups, num_selected_experts * tokens_per_group] + expert_index = expert_index.reshape(num_groups, -1) + + # Create mask out of indices. + # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. + expert_mask = torch.nn.functional.one_hot(expert_index, num_experts) + + # Experts have a fixed capacity that we cannot exceed. A token's priority + # within the expert's buffer is given by the masked, cumulative capacity of + # its target expert. + # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. + token_priority = torch.cumsum(expert_mask, axis=1) * expert_mask - 1.0 + # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts]. + token_priority = token_priority.reshape((num_groups, self.num_selected_experts, -1, num_experts)) + # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. + token_priority = token_priority.permute((0, 2, 1, 3)) + # For each token, across all selected experts, select the only non-negative + # (unmasked) priority. Now, for group G routing to expert E, token T has + # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E + # is its targeted expert. + # Shape: [num_groups, tokens_per_group, num_experts]. + token_priority = torch.max(token_priority, axis=2).values + + if self.batch_prioritized_routing: + # Place token priorities in original ordering of tokens. + inv_permutation = torch.argsort(permutation, dim=-1) + token_priority = torch.take_along_dim(token_priority, inv_permutation.unsqueeze(-1), dim=-2) + + # Token T can only be routed to expert E if its priority is positive and + # less than the expert capacity. One-hot matrix will ignore indices outside + # the range [0, expert_capacity). + # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]. + # token_priority = token_priority * (token_priority > 0) + + # TODO can we improve the function name or use torch's? + # dispatch_mask = torch.nn.functional.one_hot(token_priority.long(), expert_capacity + 1)[..., 1:] + dispatch_mask = _jax_one_hot(token_priority.long(), expert_capacity, axis=-1) + + # The combine array will be used for combining expert outputs, scaled by the + # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, + # expert_capacity]. + # TODO can we use more understandable code here? + combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) + # combine_array = torch.einsum("...te,...te->...te", router_probs, dispatch_mask) + + # Return to default dtype now that router computation is complete. + combine_array = combine_array.to(torch.float32) + + return RouterMask(dispatch_mask, combine_array, auxiliary_loss) + + + # Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->SwitchTransformers class SwitchTransformersLayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): diff --git a/src/transformers/models/switch_transformers/router.py b/src/transformers/models/switch_transformers/router.py deleted file mode 100644 index 24578e59b5d86..0000000000000 --- a/src/transformers/models/switch_transformers/router.py +++ /dev/null @@ -1,648 +0,0 @@ -# coding=utf-8 -# Copyright 2022 Mesh TensorFlow authors, SwitchTransformers Authors and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass, replace -from typing import Any, Optional, Tuple - -import torch -import torch.nn as nn - - -# from transformers.models.switch_transformers.configuration_switch_transformers import SwitchTransformersConfig - - -# Output classes -RouterOutput = Any - - -def _jax_one_hot(tensor, num_classes, axis=-1, dtype=torch.bool): - r""" - This function mimics the behavior of jax.nn.functional.one_hot in PyTorch. It takes a tensor of indices, the number - of desired classes and the axis to apply the one-hot encoding. If any value is outside the range [0, num_classes), - it will be set to zeros. - - Args: - tensor (`torch.Tensor`): - Input tensor - num_classes (`int`): - Number of classes to process for one hot encoding - axis (`int`, *optional*): - The lookup axis to check for one-hot encoding - dtype (`torch.dtype`, *optional*): - Output `dtype`. The one hot encoded vector will be casted to this dtype - """ - if tensor.is_floating_point(): - raise "Input tensor for one hot encoding must be an `int32` or `int64`" - - if axis >= len(tensor.shape): - raise "Axis is out of bounds" - - if axis == -1: - axis = len(tensor.shape) - elif axis < -1: - raise "Axis must be greater than -1" - else: - axis = axis + 1 - - # Get the final output shape - output_shape = list(tensor.shape) - output_shape.insert(axis, num_classes) - - # Create an empty output of zeros - out = torch.zeros(tuple(output_shape), dtype=dtype) - - # Mask out the places where it is outside the range [0, num_classes) - # kudos to GitHub copilot for this line - mask = (tensor >= 0) & (tensor < num_classes) - out[mask, tensor[mask]] = 1 - - return out.to(tensor.device) - - -@dataclass -class RouterIndices: - r""" - A dataclass wrapper to store the dispatch indices and combine weights for scatter/gather-based routing. - - Attributes: - dispatch_indices (`torch.Tensor`): - A tensor of size [`num_groups`, `tokens_per_group`, `num_selected_experts`, 2] dispatch indices indicating, - for each token, its preferred expert and its priority in that expert's buffer. - combine_weights (`torch.Tensor`): - A tensor of size [`num_groups`, `tokens_per_group`, `num_selected_experts`] combine weights used for - scaling expert outputs with the router's dispatch probability/confidence. - auxiliary_loss (`float`): - Load balancing loss for router. - router_z_loss (`float`): - Router z-loss. Encourages router logits to remain small in an effort to improve stability. - """ - dispatch_indices: torch.Tensor - combine_weights: torch.Tensor - auxiliary_loss: float - router_z_loss: float = 0.0 - - def to(self, device): - return replace( - self, dispatch_mask=self.dispatch_indices.to(device), combine_array=self.combine_weights.to(device) - ) - - -@dataclass -class RouterMask: - r""" - Dispatch and combine torch.Tensors for expert routing with masked matmuls. - - Attributes: - dispatch_mask (`torch.Tensor`): - A mask tensor of shape [num_groups, tokens_per_group, num_experts, expert_capacity] that is 1 if the token - gets routed to the corresponding expert, and 0 otherwise. - combine_array (`torch.Tensor`): - A tensor of shape [num_groups, tokens_per_group, num_experts, expert_capacity] combine torch.Tensor used - for combining expert outputs and scaling with router probability. - auxiliary_loss (`float`): - Load balancing loss for router. - router_z_loss (`float`): - Router z-loss. Encourages router logits to remain small in an effort to improve stability. - """ - dispatch_mask: torch.Tensor - combine_array: torch.Tensor - auxiliary_loss: float - router_z_loss: float = 0.0 - - def to(self, device): - return replace(self, dispatch_mask=self.dispatch_mask.to(device), combine_array=self.combine_array.to(device)) - - -# Router loss - - -def router_z_loss_func(router_logits: torch.Tensor) -> float: - r""" - Compute router z-loss implemented in PyTorch. - - The router z-loss was introduced in Designing Effective Sparse Expert Models (https://arxiv.org/abs/2202.08906). It - encourages router logits to remain small in an effort to improve stability. - - Args: - router_logits (`float`): - Input logits of shape [num_groups, tokens_per_group, num_experts] - - Returns: - Scalar router z-loss. - """ - num_groups, tokens_per_group, _ = router_logits.shape - log_z = torch.logsumexp(router_logits, dim=-1) - z_loss = log_z**2 - return torch.sum(z_loss) / (num_groups * tokens_per_group) - - -def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float: - r""" - Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. - - See Switch Transformer (https://arxiv.org/abs/2101.03961). This function implements the loss function presented in - equations (4) - (6). It aims to penalize those cases where the routing between experts is unbalanced. - - Args: - router_probs (`torch.Tensor`): - Probability assigned to each expert per token. Shape: [num_groups, tokens_per_group, num_experts]. - expert_indices (`torch.Tensor`): - Indices tensor of shape [num_groups, tokens_per_group, num_selected_experts] identifying the top - num_selected_experts for a given token. - - Returns: - The auxiliary loss. - """ - num_experts = router_probs.shape[-1] - - # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. - # cast the expert indices to int64, otherwise one-hot encoding will fail - if expert_indices.dtype != torch.int64: - expert_indices = expert_indices.to(torch.int64) - expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts) - - # For a given token, determine if it was routed to a given expert. - # Shape: [num_groups, tokens_per_group, num_experts] - expert_mask = torch.max(expert_mask, axis=-2).values - - # cast to float32 otherwise mean will fail - expert_mask = expert_mask.to(torch.float32) - tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2) - - router_prob_per_group_and_expert = torch.mean(router_probs, axis=-2) - return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2) - - -# Router classes -# TODO not a big fan of 3 level of - - -class Router(nn.Module): - """ - Abstract base router class, defining router API and inner workings. - - Attributes: - router_weights (`torch.nn.Module`): - Configurable module used to compute router logits from token inputs. - jitter_noise (`float`): - Amplitude of jitter noise applied to router logits. - dtype (`torch.dtype`): - Numeric float type for returned combine torch.Tensor. All actual computations are performed in float32 of - the input for stability. - ignore_padding_tokens (`bool`): - Whether to ignore padding tokens during routing. Note that some routers (e.g. `TokensChooseMaskedRouter`) - will completely ignore padding tokens, while others (e.g. `TokensChooseScatterRouter` and - `ExpertsChooseMaskedRouter`) will simply down-weight the probability of selecting padding tokens. - """ - - def __init__(self, config, **kwargs): - super().__init__() - self.num_experts = config.num_experts - self.classifier = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias) - self.jitter_noise = config.router_jitter_noise - self.ignore_padding_tokens = config.router_ignore_padding_tokens - self.dtype = getattr(torch, config.router_dtype) - - def _compute_router_probabilities( - self, token_inputs: torch.Tensor, num_experts: int, apply_jitter: bool - ) -> Tuple[torch.Tensor, torch.Tensor]: - r""" - Computes router probabilities from input tokens. - - Args: - token_inputs (`torch.Tensor`): - [num_groups, tokens_per_group, hidden_dim] from which router probabilities are computed. - num_experts (`int`): - Number of experts. - apply_jitter (`bool`): - If true, apply jitter noise. - - Returns: - router_probabilities (`torch.Tensor`): - Tensor of shape [num_groups, tokens_per_group, num_experts] corresponding to the probabilities for each - token and expert. Used for routing tokens to experts. - router_logits (`torch.Tensor`): - Logits tensor of shape [num_groups, tokens_per_group, num_experts] corresponding to raw router logits. - This is used later for computing router z-loss. - """ - # For remainder of routing computation we use float32 to ensure stability. - # See the discussion of "selective precision" in - # https://arxiv.org/abs/2101.03961. - # We also store the previous dtype to cast back the output to the previous dtype - self.input_tokens_dtype = token_inputs.dtype - token_inputs = token_inputs.to(self.dtype) - - if apply_jitter and self.jitter_noise > 0: - # Get the lower and upper bound of the uniform distribution - # Adapted from: https://stackoverflow.com/questions/44328530/how-to-get-a-uniform-distribution-in-a-range-r1-r2-in-pytorch - distrib_lower_bound = 1.0 - self.jitter_noise - distrib_upper_bound = 1.0 + self.jitter_noise - - uniform_distrib = ( - torch.rand(token_inputs.shape, device=token_inputs.device) - * (distrib_lower_bound - distrib_upper_bound) - ) + distrib_upper_bound - - # Multiply the token inputs by the uniform distribution - adding some noise - token_inputs *= uniform_distrib - - # Shape: [num_groups, tokens_per_group, num_experts] - router_logits = self.classifier(token_inputs) - - router_probabilities = torch.nn.Softmax(dim=-1)(router_logits) - - return router_probabilities, router_logits - - def forward( - self, token_inputs: torch.Tensor, expert_capacity: int, apply_jitter: bool = True, **kwargs - ) -> RouterOutput: - r""" - Generic forward function for every Router class. Each Router expects to have the same input hidden states - (`token_inputs`) corresponding to the hidden states for each token, the `expert_capacity` corresponding to the - number of tokens the Router will send to each expert, some Routers can send up to few tokens to each expert. - - Each Router works as the following: it expects the hidden states for each token, gets the `router_probs` and - `router_logits` from the `router_weights`. This will assign for each token, the raw probability to be assigned - to an expert. Then each Router class will have to define its own `_compute_routing_instructions`. - - Args: - Computes dispatch and combine torch.Tensors for routing to experts. - token_inputs: [num_groups, tokens_per_group, hidden_dim] inputs to send to experts. num_experts: - Number of experts. expert_capacity: Each group will send this many tokens to each expert. apply_jitter: If - true, apply jitter noise during routing. - Returns: - Router indices or mask torch.Tensors (depending on router type). - """ - router_probs, router_logits = self._compute_router_probabilities(token_inputs, self.num_experts, apply_jitter) - - # Flax code for reference - if self.ignore_padding_tokens: - # To identify non-padding tokens, we rely on the fact that padding tokens - # in the inputs have already been masked in the default T5 architecture. - # See - # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L315 - # and - # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L603. - padding_mask = torch.Tensor((torch.sum(torch.abs(token_inputs), axis=-1) > 0)).to(token_inputs.dtype) - router_logits *= padding_mask.unsqueeze(-1) - else: - padding_mask = None - - instructions = self._compute_routing_instructions(router_probs, padding_mask, expert_capacity, **kwargs) - # We cast back the output to the previous dtype - instructions = instructions.to(self.input_tokens_dtype) - - return replace(instructions, router_z_loss=router_z_loss_func(router_logits)) - - def _compute_routing_instructions(self, router_probs, padding_mask, expert_capacity): - raise NotImplementedError( - """ - Computes masks for the top-k experts per token. This has to be implemented for each subclass of - MaskedRouter routers. - - Args: - router_probs (`torch.Tensor`): - Input router probabilities of shape [num_groups, tokens_per_group, num_experts] this - corresponds to the probabilities used to determine the routing of tokens to the experts. - padding_mask (`torch.Tensor`): - Padding mask tensor of shape [num_groups, tokens_per_group] a mask used to identify padding - tokens that should be ignored by the router. - expert_capacity (`int`): - Each group will send this many tokens to each expert. - - Returns: - Router mask arrays. - """ - ) - - -class ExpertsChooseMaskedRouter(Router): - """ - Masked matmul router using experts choose tokens assignment. - - This router uses the same mechanism as in Mixture-of-Experts with Expert Choice (https://arxiv.org/abs/2202.09368): - each expert selects its top expert_capacity tokens. An individual token may be processed by multiple experts or - none at all. - - Note: "experts choose routing" should not be used in decoder blocks because it breaks the autoregressive behavior - -- the model will learn to cheat by using future token information to improve current token predictions. - """ - - def _compute_routing_instructions( - self, router_probs: torch.Tensor, padding_mask: Optional[torch.Tensor], expert_capacity: int - ) -> RouterMask: - """Computes masks for the highest probability token per expert. - - Args: - router_probs (`torch.Tensor`): - Raw router probabilities of shape [num_groups, tokens_per_group, num_experts] used to determine the - routing of tokens to the experts. - padding_mask (`torch.Tensor`): - padding mask tensor of shape [num_groups, tokens_per_group] used to identify padding tokens that should - be down-weighted by the router. - expert_capacity (`int`): - Each group will send this many tokens to each expert. - - Returns: - Dispatch and combine arrays for routing with masked matmuls. - """ - tokens_per_group = router_probs.shape[1] - - if padding_mask is not None: - # Because experts choose tokens, we mask probabilities corresponding to - # tokens before the top-k operation. Note that, unlike for masked-based - # tokens-choose routing, the experts here may still choose to select the - # (down-weighted) padding tokens. - router_probs *= padding_mask.unsqueeze(-1) - - # vmap over group dimension. - # router_probs_t = router_probs.t() - router_probs_t = router_probs.permute(0, 2, 1) - - # Top expert_capacity router probability and corresponding token indices for - # each expert. Shapes: [num_groups, num_experts, expert_capacity]. - expert_gate, expert_index = torch.topk(router_probs_t, k=expert_capacity) - - # Convert to one-hot mask of expert indices for each token in each group. - # Shape: [num_groups, num_experts, expert_capacity, tokens_per_group]. - dispatch_mask = _jax_one_hot(expert_index, tokens_per_group, dtype=torch.int32) - - # Move axes to conform with shape expected by MoeLayer API. - # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity] - dispatch_mask = torch.moveaxis(dispatch_mask, 3, 1) - - # The combine array will be used for combining expert outputs, scaled by the - # router probabilities. Shape: [num_groups, num_experts, tokens_per_group, - # expert_capacity]. - combine_array = torch.einsum("...ec,...tec->...tec", expert_gate, dispatch_mask) - - # Each expert is choosing tokens until it reaches full capacity, so we don't - # need an auxiliary loading balancing loss for expert choice routing. - auxiliary_loss = 0.0 - - return RouterMask(dispatch_mask, combine_array, auxiliary_loss) - - -class TokensChooseMaskedRouter(Router): - """ - Masked matmul router using tokens choose top-k experts assignment. - - This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE - (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then - routed to their choice of expert until the expert's expert_capacity is reached. There is no guarantee that each - token is processed by an expert, or that each expert receives at least one token. - - Attributes: - num_selected_experts (`int`): - Maximum number of experts to which each token is routed. Tokens may be routed to fewer experts if particular - experts are oversubscribed / reach capacity. - batch_prioritized_routing (`bool`): - Whether or not to use Batch Prioritized Routing (BPR), originally introduced in V-MoE - (https://arxiv.org/abs/2106.05974). With BPR, we prioritize routing those top-k tokens with the highest router - probability, rather than simply using each tokens left-to-right ordering in the batch. This prioritization is - important because the experts have limited capacity. - """ - - def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) - self.num_selected_experts = config.num_selected_experts - self.batch_prioritized_routing = config.batch_prioritized_routing - - def _compute_routing_instructions( - self, router_probs: torch.Tensor, padding_mask: Optional[torch.Tensor], expert_capacity: int - ) -> RouterMask: - """ - Computes masks for the top-k experts per token. - - Args: - router_probs (`torch.Tensor`): - Router raw probabilities tensor of shape [num_groups, tokens_per_group, num_experts] used to determine - the routing of tokens to the experts. - padding_mask (`torch.Tensor`): - Padding mask of shape [num_groups, tokens_per_group] used to identify padding tokens that should be - ignored by the router. - expert_capacity (`int`): - Each group will send this many tokens to each expert. - - Returns: - Dispatch and combine arrays for routing with masked matmuls. - """ - num_groups, _, num_experts = router_probs.shape - - # Top-k router probability and corresponding expert indices for each token. - # Shape: [num_groups, tokens_per_group, num_selected_experts]. - expert_gate, expert_index = torch.topk(router_probs, k=self.num_selected_experts) - - if padding_mask is not None: - # Mask applied to gate. Exclude choices corresponding to padding tokens. - gate_mask = padding_mask.unsqueeze(-1).to(expert_index.dtype) - expert_gate *= gate_mask - - # Set `expert_index` elements corresponding to padding to negative - # numbers. Negative `expert_index` elements will ultimately be dropped in - # the one_hot conversion to the `expert_mask`. - # First convert nonzero padding elements to negative values. - expert_index *= (2 * gate_mask) - 1 - # Handle zero padding elements by negatively shifting all padding. - expert_index += (gate_mask - 1).repeat(1, 1, self.num_selected_experts) - - # To correctly compute load balancing loss, we also mask out probs. - router_probs *= gate_mask - - auxiliary_loss = load_balancing_loss_func(router_probs, expert_index) - - if self.batch_prioritized_routing: - # Sort tokens according to their routing probability per group, so that - # the highest probability tokens are routed first. - permutation = torch.argsort(-expert_gate[..., 0], dim=-1) - # Shape: [num_groups, tokens_per_group, num_selected_experts] - expert_index = torch.take_along_dim(expert_index, permutation.unsqueeze(-1), dim=-2) - - # Make num_selected_experts the leading axis to ensure that top-1 choices - # have priority over top-2 choices, which have priority over top-3 choices, - # etc. - expert_index = expert_index.permute((0, 2, 1)) - # Shape: [num_groups, num_selected_experts * tokens_per_group] - expert_index = expert_index.reshape(num_groups, -1) - - # Create mask out of indices. - # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. - expert_mask = torch.nn.functional.one_hot(expert_index, num_experts) - - # Experts have a fixed capacity that we cannot exceed. A token's priority - # within the expert's buffer is given by the masked, cumulative capacity of - # its target expert. - # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. - token_priority = torch.cumsum(expert_mask, axis=1) * expert_mask - 1.0 - # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts]. - token_priority = token_priority.reshape((num_groups, self.num_selected_experts, -1, num_experts)) - # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. - token_priority = token_priority.permute((0, 2, 1, 3)) - # For each token, across all selected experts, select the only non-negative - # (unmasked) priority. Now, for group G routing to expert E, token T has - # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E - # is its targeted expert. - # Shape: [num_groups, tokens_per_group, num_experts]. - token_priority = torch.max(token_priority, axis=2).values - - if self.batch_prioritized_routing: - # Place token priorities in original ordering of tokens. - inv_permutation = torch.argsort(permutation, dim=-1) - token_priority = torch.take_along_dim(token_priority, inv_permutation.unsqueeze(-1), dim=-2) - - # Token T can only be routed to expert E if its priority is positive and - # less than the expert capacity. One-hot matrix will ignore indices outside - # the range [0, expert_capacity). - # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]. - # token_priority = token_priority * (token_priority > 0) - - # TODO can we improve the function name or use torch's? - # dispatch_mask = torch.nn.functional.one_hot(token_priority.long(), expert_capacity + 1)[..., 1:] - dispatch_mask = _jax_one_hot(token_priority.long(), expert_capacity, axis=-1) - - # The combine array will be used for combining expert outputs, scaled by the - # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, - # expert_capacity]. - # TODO can we use more understandable code here? - combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) - # combine_array = torch.einsum("...te,...te->...te", router_probs, dispatch_mask) - - # Return to default dtype now that router computation is complete. - combine_array = combine_array.to(torch.float32) - - return RouterMask(dispatch_mask, combine_array, auxiliary_loss) - - -# num_groups = 2 -# tokens_per_group = 4 -# hidden_dim = 3 -# num_experts = 2 -# expert_capacity = 2 # Total capacity = 2*2*1 = 4 < num_tokens -# jitter_noise = 0.0 - -# input_tokens = torch.Tensor( -# [[[0.6433916 , 0.18188512, 0.02240455], -# [0.563781 , 0.5526401 , 0.0958724 ], -# [0.34253013, 0.03644359, 0.08744538], -# [0.7909105 , 0.35205448, 0.53364205]], - -# [[0.02900076, 0.4168595 , 0.5802449 ], -# [0.91486526, 0.27414513, 0.14991808], -# [0.9383501 , 0.5209162 , 0.51207185], -# [0.90618336, 0.7309413 , 0.95533276]]] -# ) - -# config = SwitchTransformersConfig( -# num_experts=num_experts, -# hidden_size=hidden_dim, -# router_jitter_noise=jitter_noise, -# expert_capacity=expert_capacity, -# batch_prioritized_routing=False, -# ) -# # model = TokensChooseMaskedRouter(config) -# model = ExpertsChooseMaskedRouter(config) - -# model.router_weights.weight = torch.nn.Parameter( -# torch.Tensor([[-0.00107201, 0.01544739], -# [-0.0087319 , 0.01314363], -# [ 0.03530733, 0.03709853]]).t() -# ) - -# model(input_tokens, expert_capacity=expert_capacity) - - -# hidden_dim = 4 -# num_experts = 2 -# num_selected_experts = 1 # Switch routing case -# expert_capacity = 1 # Total capacity = 2*2*1 = 4 < num_tokens -# jitter_noise = 0.0 - -# input_tokens = torch.Tensor( -# [ -# [ -# [0.6433916, 0.18188512, 0.02240455, 0.563781], -# [0.5526401, 0.0958724, 0.34253013, 0.03644359], -# [0.08744538, 0.7909105, 0.35205448, 0.53364205], -# ], -# [ -# [0.02900076, 0.4168595, 0.5802449, 0.91486526], -# [0.27414513, 0.14991808, 0.9383501, 0.5209162], -# [0.51207185, 0.90618336, 0.7309413, 0.95533276], -# ], -# ] -# ) - -# config = SwitchTransformersConfig( -# num_experts=num_experts, -# hidden_size=hidden_dim, -# num_selected_experts=num_selected_experts, -# router_jitter_noise=jitter_noise, -# expert_capacity=expert_capacity, -# batch_prioritized_routing=False, -# ) -# model = TokensChooseMaskedRouter(config) - -# model.router_weights.weight = torch.nn.Parameter( -# torch.Tensor( -# [ -# [0.02008116, 0.00620062], -# [-0.00811031, -0.00031623], -# [-0.03542127, 0.02703803], -# [0.02335377, -0.02971946], -# ], -# ).t() -# ) - -# output = model(input_tokens, expert_capacity=expert_capacity) - - -# num_groups = 2 -# tokens_per_group = 4 -# hidden_dim = 3 -# num_experts = 3 -# num_selected_experts = 1 -# expert_capacity = 2 -# jitter_noise = 0.0 - -# input_tokens = torch.Tensor( -# [[[0.6433916 , 0.18188512, 0.02240455], -# [0.563781 , 0.5526401 , 0.0958724 ], -# [0.34253013, 0.03644359, 0.08744538], -# [0.7909105 , 0.35205448, 0.53364205]], - -# [[0.02900076, 0.4168595 , 0.5802449 ], -# [0.91486526, 0.27414513, 0.14991808], -# [0.9383501 , 0.5209162 , 0.51207185], -# [0.90618336, 0.7309413 , 0.95533276]]] -# ) - -# config = SwitchTransformersConfig( -# num_experts=num_experts, -# hidden_size=hidden_dim, -# num_selected_experts=num_selected_experts, -# router_jitter_noise=jitter_noise, -# expert_capacity=expert_capacity, -# batch_prioritized_routing=False, -# ) -# model = TokensChooseScatterRouter(config) - -# model.router_weights.weight = torch.nn.Parameter( -# torch.Tensor( -# [[ 0.02736656, -0.00253537, 0.04682618], -# [ 0.00928149, 0.04933621, -0.00275501], -# [ 0.00751786, 0.04295348, -0.00503795]], -# ).t() -# ) - -# output = model(input_tokens, expert_capacity=expert_capacity) From 5bd7a6237c985ddd89e8c83faf962e2fe85fb441 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 24 Oct 2022 15:10:26 +0000 Subject: [PATCH 028/102] update tests wrt router deletion --- .../test_modeling_switch_transformers.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 4989b919bce00..ed721c042b417 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -36,8 +36,6 @@ ) from transformers.models.switch_transformers.modeling_switch_transformers import ( SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, - ) - from transformers.models.switch_transformers.router import ( ExpertsChooseMaskedRouter, TokensChooseMaskedRouter, load_balancing_loss_func, @@ -962,7 +960,7 @@ def test_equivalency_token_chose_masked_router(self): ) model = TokensChooseMaskedRouter(config) - model.router_weights.weight = torch.nn.Parameter( + model.classifier.weight = torch.nn.Parameter( torch.Tensor( [ [0.02008116, 0.00620062], @@ -1032,7 +1030,7 @@ def test_equivalency_experts_chose_masked_router(self): model = ExpertsChooseMaskedRouter(config) - model.router_weights.weight = torch.nn.Parameter( + model.classifier.weight = torch.nn.Parameter( torch.Tensor([[-0.00107201, 0.01544739], [-0.0087319, 0.01314363], [0.03530733, 0.03709853]]).t() ) From 751cfdc6e42caae595937a4a2ed4cd3735cd6703 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 24 Oct 2022 16:31:17 +0000 Subject: [PATCH 029/102] fix router issues --- ...ers_original_flax_checkpoint_to_pytorch.py | 110 +++++++++++------- .../modeling_switch_transformers.py | 23 ++-- 2 files changed, 82 insertions(+), 51 deletions(-) diff --git a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py index c2e7ff07de410..a9a89eeee1ea0 100644 --- a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py +++ b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py @@ -19,12 +19,11 @@ import re from t5x import checkpoints -from transformers import SwitchTransformersForConditionalGeneration, SwitchTransformersConfig +from transformers import SwitchTransformersConfig, SwitchTransformersForConditionalGeneration from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model from transformers.utils import logging from transformers.utils.hub import get_file_from_repo -from t5x import checkpoints logging.set_verbosity_info() @@ -32,16 +31,36 @@ MODEL_MAPPING = { - "switch_base_8":["https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin"], - "switch_base_8":["https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin"], - "switch_base_8":["https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin"], - "switch_base_8":["https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin"], - "switch_base_8":["https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin"], - "switch_base_8":["https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin"], - "switch_base_8":["https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin"], - "switch_base_8":["https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin"], - "switch_base_8":["https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin"], - "switch_base_8":["https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin"], + "switch_base_8": [ + "https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin" + ], + "switch_base_16": [ + "https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin" + ], + "switch_base_32": [ + "https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin" + ], + "switch_base_64": [ + "https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin" + ], + "switch_base_128": [ + "https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin" + ], + "switch_base_256": [ + "https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin" + ], + "switch_base_512": [ + "https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin" + ], + "switch_base_1024": [ + "https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin" + ], + "switch_base_2048": [ + "https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin" + ], + "switch_base_8": [ + "https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin" + ], } # should not include what is already done by the `from_pt` argument @@ -55,7 +74,7 @@ "out": "o", "pre_self_attention_layer_norm": "0/layer_norm", "pre_cross_attention_layer_norm": "1/layer_norm", - "pre_attention_layer_norm": "0/layer_norm", # previously 1, but seems wrong + "pre_attention_layer_norm": "0/layer_norm", # previously 1, but seems wrong "token_embedder": "shared", "encoder_norm": "final_layer_norm", "decoder_norm": "final_layer_norm", @@ -64,6 +83,7 @@ "roer/roer_weights/w/": "router/classifier/", } + def rename_keys(s_dict): # 1. in HF T5, we have block.{x}.layer.{y}. which corresponds to layer.{x} in # the original model @@ -74,7 +94,6 @@ def rename_keys(s_dict): if re.match(layer_to_block_of_layer, key): new_key = re.sub(r"layers_(\d+)", r"block/\1/layer", new_key) - layer_to_block_of_layer = r"(encoder|decoder)\/" if re.match(layer_to_block_of_layer, key): @@ -92,12 +111,15 @@ def rename_keys(s_dict): if old_key in new_key: new_key = new_key.replace(old_key, temp_key) - print(f"{key} -> {new_key}") s_dict[new_key] = s_dict.pop(key) - s_dict["encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"] = s_dict["encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"].T - s_dict["decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"] = s_dict["decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"].T + s_dict["encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"] = s_dict[ + "encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" + ].T + s_dict["decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"] = s_dict[ + "decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" + ].T # 3. Take extra care of the EXPERTS layer for key in list(s_dict.keys()): @@ -113,23 +135,25 @@ def rename_keys(s_dict): return s_dict -GIN_TO_CONFIG_MAPPING = { - "NUM_ENCODER_LAYERS":"num_layers", - "NUM_DECODER_LAYERS":"num_decoder_layers", - "NUM_HEADS":"num_heads", - "HEAD_DIM":"d_kv", - "EMBED_DIM":"d_model", - "MLP_DIM":"d_ff", - "NUM_SELECTED_EXPERTS":"num_selected_experts", - "NUM_ENCODER_SPARSE_LAYERS":"num_sparse_encoder_layers", - "NUM_DECODER_SPARSE_LAYERS":"num_sparse_decoder_layers", - "dense.MlpBlock.activations":"feed_forward_proj", +GIN_TO_CONFIG_MAPPING = { + "NUM_ENCODER_LAYERS": "num_layers", + "NUM_DECODER_LAYERS": "num_decoder_layers", + "NUM_HEADS": "num_heads", + "HEAD_DIM": "d_kv", + "EMBED_DIM": "d_model", + "MLP_DIM": "d_ff", + "NUM_SELECTED_EXPERTS": "num_selected_experts", + "NUM_ENCODER_SPARSE_LAYERS": "num_sparse_encoder_layers", + "NUM_DECODER_SPARSE_LAYERS": "num_sparse_decoder_layers", + "dense.MlpBlock.activations": "feed_forward_proj", } + def convert_gin_to_config(gin_file, num_experts): # Convert a google style config to the hugging face fromat import regex as re + with open(gin_file, "r") as f: raw_gin = f.read() @@ -137,7 +161,7 @@ def convert_gin_to_config(gin_file, num_experts): args = {} for param, value in regex_match: if param in GIN_TO_CONFIG_MAPPING and value != "": - args[GIN_TO_CONFIG_MAPPING[param]] = float(value) if '.' in value else int(value) + args[GIN_TO_CONFIG_MAPPING[param]] = float(value) if "." in value else int(value) activation = re.findall(r"(.*activations) = \(\'(.*)\',\)", raw_gin)[0] args[GIN_TO_CONFIG_MAPPING[activation[0]]] = str(activation[1]) @@ -146,7 +170,10 @@ def convert_gin_to_config(gin_file, num_experts): config = SwitchTransformersConfig(**args) return config -def convert_flax_checkpoint_to_pytorch(flax_checkpoint_path, config_file, gin_file = None, pytorch_dump_path = "./", num_experts = 8): + +def convert_flax_checkpoint_to_pytorch( + flax_checkpoint_path, config_file, gin_file=None, pytorch_dump_path="./", num_experts=8 +): # Initialise PyTorch model print(f"Loading flax weights from : {flax_checkpoint_path}") @@ -154,12 +181,12 @@ def convert_flax_checkpoint_to_pytorch(flax_checkpoint_path, config_file, gin_fi if gin_file is not None: config = convert_gin_to_config(gin_file, num_experts) - else : + else: config = SwitchTransformersConfig.from_pretrained(config_file) pt_model = SwitchTransformersForConditionalGeneration(config) - flax_params = flax_params['target'] + flax_params = flax_params["target"] flax_params = flatten_dict(flax_params, sep="/") flax_params = rename_keys(flax_params) flax_params = unflatten_dict(flax_params, sep="/") @@ -171,7 +198,6 @@ def convert_flax_checkpoint_to_pytorch(flax_checkpoint_path, config_file, gin_fi pt_model.save_pretrained(pytorch_dump_path) - if __name__ == "__main__": parser = argparse.ArgumentParser() # Required parameters @@ -186,7 +212,11 @@ def convert_flax_checkpoint_to_pytorch(flax_checkpoint_path, config_file, gin_fi ), ) parser.add_argument( - "--gin_file", default=None, type=str, required=False, help="Path to the gin config file. If not provided, a `config_file` has to be passed " + "--gin_file", + default=None, + type=str, + required=False, + help="Path to the gin config file. If not provided, a `config_file` has to be passed ", ) parser.add_argument( "--config_name", default=None, type=str, required=False, help="Config name of SwitchTransformers model." @@ -194,10 +224,12 @@ def convert_flax_checkpoint_to_pytorch(flax_checkpoint_path, config_file, gin_fi parser.add_argument( "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output pytorch model." ) - parser.add_argument( - "--num_experts", default=8, type=str, required=False, help="Number of experts" - ) + parser.add_argument("--num_experts", default=8, type=str, required=False, help="Number of experts") args = parser.parse_args() convert_flax_checkpoint_to_pytorch( - args.switch_t5x_checkpoint_path, args.config_name, args.gin_file,args.pytorch_dump_folder_path, args.num_experts - ) \ No newline at end of file + args.switch_t5x_checkpoint_path, + args.config_name, + args.gin_file, + args.pytorch_dump_folder_path, + args.num_experts, + ) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index b0516517c24c1..5e94c6db2d2c4 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -18,9 +18,11 @@ import copy import math import warnings -from typing import Optional, Tuple, Union +from dataclasses import dataclass, replace +from typing import Any, Optional, Tuple, Union import torch +import torch.nn as nn from torch import nn from torch.nn import CrossEntropyLoss from torch.utils.checkpoint import checkpoint @@ -45,11 +47,6 @@ ) from .configuration_switch_transformers import SwitchTransformersConfig -from dataclasses import dataclass, replace -from typing import Any, Optional, Tuple - -import torch -import torch.nn as nn logger = logging.get_logger(__name__) @@ -69,6 +66,7 @@ RouterOutput = Any + def _jax_one_hot(tensor, num_classes, axis=-1, dtype=torch.bool): r""" This function mimics the behavior of jax.nn.functional.one_hot in PyTorch. It takes a tensor of indices, the number @@ -289,7 +287,7 @@ def _compute_router_probabilities( distrib_upper_bound = 1.0 + self.jitter_noise uniform_distrib = ( - torch.rand(token_inputs.shape, device=token_inputs.device) + torch.rand(token_inputs.shape, device=token_inputs.device, dtype = self.dtype) * (distrib_lower_bound - distrib_upper_bound) ) + distrib_upper_bound @@ -297,10 +295,11 @@ def _compute_router_probabilities( token_inputs *= uniform_distrib # Shape: [num_groups, tokens_per_group, num_experts] - router_logits = self.classifier(token_inputs.to(self.classifier.weight.dtype)) + self.classifier.to(self.dtype) + router_logits = self.classifier(token_inputs) # computations in the router have to be done in float16 - router_probabilities = torch.nn.Softmax(dim=-1)(router_logits.to(self.dtype)) + router_probabilities = torch.nn.Softmax(dim=-1)(router_logits).to(self.input_tokens_dtype) return router_probabilities, router_logits @@ -564,7 +563,6 @@ def _compute_routing_instructions( return RouterMask(dispatch_mask, combine_array, auxiliary_loss) - # Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->SwitchTransformers class SwitchTransformersLayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -696,10 +694,12 @@ def forward(self, hidden_states): expert_indices = self.router(hidden_states, expert_capacity=self.expert_capacity) masked_indices = expert_indices.dispatch_mask + dispatched_tokens = int(torch.sum(masked_indices)) for idx, expert in enumerate(self.experts.values()): # 1. Get the index of the tokens that are routed to the current expert - expert_indices = torch.eq(masked_indices[:, :, idx, :], 1).squeeze(-1) + expert_indices = torch.argmax(masked_indices[:, :, idx, :], -1).bool() # 2. Update hidden states + print(f"{(masked_indices[:, :, idx, :]).sum()}/{dispatched_tokens} tokens will be dispatched to expert {idx}") hidden_states[expert_indices] = expert(hidden_states[expert_indices]) return hidden_states @@ -1233,7 +1233,6 @@ def __init__(self, config, embed_tokens=None): self.block = nn.ModuleList() for i in range(config.num_layers): - is_sparse = (i % sparse_step == 1) if sparse_step > 0 else False self.block.append( From 76d0199e9430bafe98de20b995a5bafe42be7ab2 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 25 Oct 2022 08:03:48 +0000 Subject: [PATCH 030/102] update expert code --- .../switch_transformers/modeling_switch_transformers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 5e94c6db2d2c4..70c8c27c2b3c7 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -697,9 +697,9 @@ def forward(self, hidden_states): dispatched_tokens = int(torch.sum(masked_indices)) for idx, expert in enumerate(self.experts.values()): # 1. Get the index of the tokens that are routed to the current expert - expert_indices = torch.argmax(masked_indices[:, :, idx, :], -1).bool() + expert_indices = masked_indices[:, :, idx, :].sum(dim = -1).bool() # 2. Update hidden states - print(f"{(masked_indices[:, :, idx, :]).sum()}/{dispatched_tokens} tokens will be dispatched to expert {idx}") + print(f"{expert_indices.sum()}/{dispatched_tokens} tokens will be dispatched to expert {idx}") hidden_states[expert_indices] = expert(hidden_states[expert_indices]) return hidden_states From 4fde6495118392110dd883e570b38963080d616f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 25 Oct 2022 14:08:09 +0000 Subject: [PATCH 031/102] update, logits match, code needsREFACTORING --- .../modeling_switch_transformers.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 70c8c27c2b3c7..ddbcf674dafe4 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -459,7 +459,7 @@ def __init__(self, config, **kwargs): self.batch_prioritized_routing = config.batch_prioritized_routing def _compute_routing_instructions( - self, router_probs: torch.Tensor, padding_mask: Optional[torch.Tensor], expert_capacity: int + self, router_probs: torch.Tensor, padding_mask: Optional[torch.Tensor], expert_capacity: int = None ) -> RouterMask: """ Computes masks for the top-k experts per token. @@ -546,14 +546,14 @@ def _compute_routing_instructions( # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]. # token_priority = token_priority * (token_priority > 0) - # TODO can we improve the function name or use torch's? + # TODO remove this # dispatch_mask = torch.nn.functional.one_hot(token_priority.long(), expert_capacity + 1)[..., 1:] dispatch_mask = _jax_one_hot(token_priority.long(), expert_capacity, axis=-1) # The combine array will be used for combining expert outputs, scaled by the # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, # expert_capacity]. - # TODO can we use more understandable code here? + # TODO remove this combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) # combine_array = torch.einsum("...te,...te->...te", router_probs, dispatch_mask) @@ -690,17 +690,23 @@ def _get_router(self, config): ) def forward(self, hidden_states): - # TODO the expert capacity is poorly computed - expert_indices = self.router(hidden_states, expert_capacity=self.expert_capacity) - masked_indices = expert_indices.dispatch_mask + router_mask = self.router(hidden_states, expert_capacity = 64) + masked_indices = router_mask.dispatch_mask + + probs, _ = self.router._compute_router_probabilities(hidden_states,num_experts=8,apply_jitter=False) + + # computations in the router have to be done in float16 + dispatched_tokens = int(torch.sum(masked_indices)) for idx, expert in enumerate(self.experts.values()): # 1. Get the index of the tokens that are routed to the current expert - expert_indices = masked_indices[:, :, idx, :].sum(dim = -1).bool() + token_indices = masked_indices[:, :, idx].sum(-1).bool() # 2. Update hidden states - print(f"{expert_indices.sum()}/{dispatched_tokens} tokens will be dispatched to expert {idx}") - hidden_states[expert_indices] = expert(hidden_states[expert_indices]) + print(f"{token_indices.sum()}/{dispatched_tokens} tokens will be dispatched to expert {idx}") + hidden_states[token_indices] = expert(hidden_states[token_indices]) + + hidden_states = torch.max(probs,dim=-1).values.unsqueeze(-1) * hidden_states return hidden_states From c5263769d5ced43185c5cdd15a2cf685759f8d4b Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 25 Oct 2022 16:20:52 +0000 Subject: [PATCH 032/102] Refactor code Co-authored-by: Younes Belkada --- ...ers_original_flax_checkpoint_to_pytorch.py | 4 +- .../modeling_switch_transformers.py | 345 ++++-------------- .../test_modeling_switch_transformers.py | 118 +++--- 3 files changed, 114 insertions(+), 353 deletions(-) diff --git a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py index a9a89eeee1ea0..5934e0f5c7a56 100644 --- a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py +++ b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py @@ -81,9 +81,9 @@ "relpos_bias/rel_embedding": "block/0/layer/0/SelfAttention/relative_attention_bias/weight", "router/router_weights/w/": "router/classifier/", "roer/roer_weights/w/": "router/classifier/", + "logits_dense":"lm_head" } - def rename_keys(s_dict): # 1. in HF T5, we have block.{x}.layer.{y}. which corresponds to layer.{x} in # the original model @@ -224,7 +224,7 @@ def convert_flax_checkpoint_to_pytorch( parser.add_argument( "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output pytorch model." ) - parser.add_argument("--num_experts", default=8, type=str, required=False, help="Number of experts") + parser.add_argument("--num_experts", default=8, type=int, required=False, help="Number of experts") args = parser.parse_args() convert_flax_checkpoint_to_pytorch( args.switch_t5x_checkpoint_path, diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index ddbcf674dafe4..d2a91a06a8788 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -51,7 +51,7 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "SwitchTransformersConfig" -_TOKENIZER_FOR_DOC = "SwitchTransformersTokenizer" +_TOKENIZER_FOR_DOC = "T5Tokenizer" _CHECKPOINT_FOR_DOC = "ybelkada/switch_transformers-base" #################################################### @@ -63,11 +63,7 @@ # See all SwitchTransformers models at https://huggingface.co/models?filter=switch_transformers ] - -RouterOutput = Any - - -def _jax_one_hot(tensor, num_classes, axis=-1, dtype=torch.bool): +def _one_hot(tensor, num_classes, axis=-1, dtype=torch.bool): r""" This function mimics the behavior of jax.nn.functional.one_hot in PyTorch. It takes a tensor of indices, the number of desired classes and the axis to apply the one-hot encoding. If any value is outside the range [0, num_classes), @@ -110,61 +106,6 @@ def _jax_one_hot(tensor, num_classes, axis=-1, dtype=torch.bool): return out.to(tensor.device) - -@dataclass -class RouterIndices: - r""" - A dataclass wrapper to store the dispatch indices and combine weights for scatter/gather-based routing. - - Attributes: - dispatch_indices (`torch.Tensor`): - A tensor of size [`num_groups`, `tokens_per_group`, `num_selected_experts`, 2] dispatch indices indicating, - for each token, its preferred expert and its priority in that expert's buffer. - combine_weights (`torch.Tensor`): - A tensor of size [`num_groups`, `tokens_per_group`, `num_selected_experts`] combine weights used for - scaling expert outputs with the router's dispatch probability/confidence. - auxiliary_loss (`float`): - Load balancing loss for router. - router_z_loss (`float`): - Router z-loss. Encourages router logits to remain small in an effort to improve stability. - """ - dispatch_indices: torch.Tensor - combine_weights: torch.Tensor - auxiliary_loss: float - router_z_loss: float = 0.0 - - def to(self, device): - return replace( - self, dispatch_mask=self.dispatch_indices.to(device), combine_array=self.combine_weights.to(device) - ) - - -@dataclass -class RouterMask: - r""" - Dispatch and combine torch.Tensors for expert routing with masked matmuls. - - Attributes: - dispatch_mask (`torch.Tensor`): - A mask tensor of shape [num_groups, tokens_per_group, num_experts, expert_capacity] that is 1 if the token - gets routed to the corresponding expert, and 0 otherwise. - combine_array (`torch.Tensor`): - A tensor of shape [num_groups, tokens_per_group, num_experts, expert_capacity] combine torch.Tensor used - for combining expert outputs and scaling with router probability. - auxiliary_loss (`float`): - Load balancing loss for router. - router_z_loss (`float`): - Router z-loss. Encourages router logits to remain small in an effort to improve stability. - """ - dispatch_mask: torch.Tensor - combine_array: torch.Tensor - auxiliary_loss: float - router_z_loss: float = 0.0 - - def to(self, device): - return replace(self, dispatch_mask=self.dispatch_mask.to(device), combine_array=self.combine_array.to(device)) - - # Router loss @@ -225,27 +166,32 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2) -class Router(nn.Module): + +class TokensChooseMaskedRouter(nn.Module): """ - Abstract base router class, defining router API and inner workings. + Masked matmul router using tokens choose top-k experts assignment. + + This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE + (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then + routed to their choice of expert until the expert's expert_capacity is reached. There is no guarantee that each + token is processed by an expert, or that each expert receives at least one token. Attributes: - router_weights (`torch.nn.Module`): - Configurable module used to compute router logits from token inputs. - jitter_noise (`float`): - Amplitude of jitter noise applied to router logits. - dtype (`torch.dtype`): - Numeric float type for returned combine torch.Tensor. All actual computations are performed in float32 of - the input for stability. - ignore_padding_tokens (`bool`): - Whether to ignore padding tokens during routing. Note that some routers (e.g. `TokensChooseMaskedRouter`) - will completely ignore padding tokens, while others (e.g. `TokensChooseScatterRouter` and - `ExpertsChooseMaskedRouter`) will simply down-weight the probability of selecting padding tokens. + num_selected_experts (`int`): + Maximum number of experts to which each token is routed. Tokens may be routed to fewer experts if particular + experts are oversubscribed / reach capacity. + batch_prioritized_routing (`bool`): + Whether or not to use Batch Prioritized Routing (BPR), originally introduced in V-MoE + (https://arxiv.org/abs/2106.05974). With BPR, we prioritize routing those top-k tokens with the highest router + probability, rather than simply using each tokens left-to-right ordering in the batch. This prioritization is + important because the experts have limited capacity. """ def __init__(self, config, **kwargs): super().__init__() self.num_experts = config.num_experts + self.batch_prioritized_routing = config.batch_prioritized_routing + self.num_experts = config.num_experts self.classifier = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias) self.jitter_noise = config.router_jitter_noise self.ignore_padding_tokens = config.router_ignore_padding_tokens @@ -298,14 +244,13 @@ def _compute_router_probabilities( self.classifier.to(self.dtype) router_logits = self.classifier(token_inputs) - # computations in the router have to be done in float16 + # Apply Softmax and cast back to the original `dtype` router_probabilities = torch.nn.Softmax(dim=-1)(router_logits).to(self.input_tokens_dtype) - return router_probabilities, router_logits def forward( - self, token_inputs: torch.Tensor, expert_capacity: int, apply_jitter: bool = True, **kwargs - ) -> RouterOutput: + self, token_inputs: torch.Tensor, apply_jitter: bool = True, **kwargs + ) -> Tuple: r""" Generic forward function for every Router class. Each Router expects to have the same input hidden states (`token_inputs`) corresponding to the hidden states for each token, the `expert_capacity` corresponding to the @@ -338,129 +283,17 @@ def forward( else: padding_mask = None - instructions = self._compute_routing_instructions(router_probs, padding_mask, expert_capacity, **kwargs) + expert_index, auxiliary_loss = self._compute_routing_instructions(router_probs, padding_mask, **kwargs) # We cast back the output to the previous dtype - instructions = instructions.to(self.input_tokens_dtype) - - return replace(instructions, router_z_loss=router_z_loss_func(router_logits)) - - def _compute_routing_instructions(self, router_probs, padding_mask, expert_capacity): - raise NotImplementedError( - """ - Computes masks for the top-k experts per token. This has to be implemented for each subclass of - MaskedRouter routers. - - Args: - router_probs (`torch.Tensor`): - Input router probabilities of shape [num_groups, tokens_per_group, num_experts] this - corresponds to the probabilities used to determine the routing of tokens to the experts. - padding_mask (`torch.Tensor`): - Padding mask tensor of shape [num_groups, tokens_per_group] a mask used to identify padding - tokens that should be ignored by the router. - expert_capacity (`int`): - Each group will send this many tokens to each expert. - - Returns: - Router mask arrays. - """ - ) - - -class ExpertsChooseMaskedRouter(Router): - """ - Masked matmul router using experts choose tokens assignment. - - This router uses the same mechanism as in Mixture-of-Experts with Expert Choice (https://arxiv.org/abs/2202.09368): - each expert selects its top expert_capacity tokens. An individual token may be processed by multiple experts or - none at all. + expert_index = expert_index.to(self.input_tokens_dtype) + router_probs = torch.max(router_probs,dim=-1).values.unsqueeze(-1) + router_z_loss = router_z_loss_func(router_logits) + return expert_index, auxiliary_loss, router_z_loss, router_probs - Note: "experts choose routing" should not be used in decoder blocks because it breaks the autoregressive behavior - -- the model will learn to cheat by using future token information to improve current token predictions. - """ def _compute_routing_instructions( - self, router_probs: torch.Tensor, padding_mask: Optional[torch.Tensor], expert_capacity: int - ) -> RouterMask: - """Computes masks for the highest probability token per expert. - - Args: - router_probs (`torch.Tensor`): - Raw router probabilities of shape [num_groups, tokens_per_group, num_experts] used to determine the - routing of tokens to the experts. - padding_mask (`torch.Tensor`): - padding mask tensor of shape [num_groups, tokens_per_group] used to identify padding tokens that should - be down-weighted by the router. - expert_capacity (`int`): - Each group will send this many tokens to each expert. - - Returns: - Dispatch and combine arrays for routing with masked matmuls. - """ - tokens_per_group = router_probs.shape[1] - - if padding_mask is not None: - # Because experts choose tokens, we mask probabilities corresponding to - # tokens before the top-k operation. Note that, unlike for masked-based - # tokens-choose routing, the experts here may still choose to select the - # (down-weighted) padding tokens. - router_probs *= padding_mask.unsqueeze(-1) - - # vmap over group dimension. - # router_probs_t = router_probs.t() - router_probs_t = router_probs.permute(0, 2, 1) - - # Top expert_capacity router probability and corresponding token indices for - # each expert. Shapes: [num_groups, num_experts, expert_capacity]. - expert_gate, expert_index = torch.topk(router_probs_t, k=expert_capacity) - - # Convert to one-hot mask of expert indices for each token in each group. - # Shape: [num_groups, num_experts, expert_capacity, tokens_per_group]. - dispatch_mask = _jax_one_hot(expert_index, tokens_per_group, dtype=torch.int32) - - # Move axes to conform with shape expected by MoeLayer API. - # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity] - dispatch_mask = torch.moveaxis(dispatch_mask, 3, 1) - - # The combine array will be used for combining expert outputs, scaled by the - # router probabilities. Shape: [num_groups, num_experts, tokens_per_group, - # expert_capacity]. - combine_array = torch.einsum("...ec,...tec->...tec", expert_gate, dispatch_mask) - - # Each expert is choosing tokens until it reaches full capacity, so we don't - # need an auxiliary loading balancing loss for expert choice routing. - auxiliary_loss = 0.0 - - return RouterMask(dispatch_mask, combine_array, auxiliary_loss) - - -class TokensChooseMaskedRouter(Router): - """ - Masked matmul router using tokens choose top-k experts assignment. - - This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE - (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then - routed to their choice of expert until the expert's expert_capacity is reached. There is no guarantee that each - token is processed by an expert, or that each expert receives at least one token. - - Attributes: - num_selected_experts (`int`): - Maximum number of experts to which each token is routed. Tokens may be routed to fewer experts if particular - experts are oversubscribed / reach capacity. - batch_prioritized_routing (`bool`): - Whether or not to use Batch Prioritized Routing (BPR), originally introduced in V-MoE - (https://arxiv.org/abs/2106.05974). With BPR, we prioritize routing those top-k tokens with the highest router - probability, rather than simply using each tokens left-to-right ordering in the batch. This prioritization is - important because the experts have limited capacity. - """ - - def __init__(self, config, **kwargs): - super().__init__(config, **kwargs) - self.num_selected_experts = config.num_selected_experts - self.batch_prioritized_routing = config.batch_prioritized_routing - - def _compute_routing_instructions( - self, router_probs: torch.Tensor, padding_mask: Optional[torch.Tensor], expert_capacity: int = None - ) -> RouterMask: + self, router_probs: torch.Tensor, padding_mask: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor,torch.Tensor]: """ Computes masks for the top-k experts per token. @@ -477,13 +310,12 @@ def _compute_routing_instructions( Returns: Dispatch and combine arrays for routing with masked matmuls. """ - num_groups, _, num_experts = router_probs.shape # Top-k router probability and corresponding expert indices for each token. # Shape: [num_groups, tokens_per_group, num_selected_experts]. - expert_gate, expert_index = torch.topk(router_probs, k=self.num_selected_experts) + expert_index = torch.argmax(router_probs, dim = -1) - if padding_mask is not None: + if padding_mask is not None: # TODO test or delete # Mask applied to gate. Exclude choices corresponding to padding tokens. gate_mask = padding_mask.unsqueeze(-1).to(expert_index.dtype) expert_gate *= gate_mask @@ -500,67 +332,8 @@ def _compute_routing_instructions( router_probs *= gate_mask auxiliary_loss = load_balancing_loss_func(router_probs, expert_index) - - if self.batch_prioritized_routing: - # Sort tokens according to their routing probability per group, so that - # the highest probability tokens are routed first. - permutation = torch.argsort(-expert_gate[..., 0], dim=-1) - # Shape: [num_groups, tokens_per_group, num_selected_experts] - expert_index = torch.take_along_dim(expert_index, permutation.unsqueeze(-1), dim=-2) - - # Make num_selected_experts the leading axis to ensure that top-1 choices - # have priority over top-2 choices, which have priority over top-3 choices, - # etc. - expert_index = expert_index.permute((0, 2, 1)) - # Shape: [num_groups, num_selected_experts * tokens_per_group] - expert_index = expert_index.reshape(num_groups, -1) - - # Create mask out of indices. - # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. - expert_mask = torch.nn.functional.one_hot(expert_index, num_experts) - - # Experts have a fixed capacity that we cannot exceed. A token's priority - # within the expert's buffer is given by the masked, cumulative capacity of - # its target expert. - # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. - token_priority = torch.cumsum(expert_mask, axis=1) * expert_mask - 1.0 - # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts]. - token_priority = token_priority.reshape((num_groups, self.num_selected_experts, -1, num_experts)) - # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. - token_priority = token_priority.permute((0, 2, 1, 3)) - # For each token, across all selected experts, select the only non-negative - # (unmasked) priority. Now, for group G routing to expert E, token T has - # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E - # is its targeted expert. - # Shape: [num_groups, tokens_per_group, num_experts]. - token_priority = torch.max(token_priority, axis=2).values - - if self.batch_prioritized_routing: - # Place token priorities in original ordering of tokens. - inv_permutation = torch.argsort(permutation, dim=-1) - token_priority = torch.take_along_dim(token_priority, inv_permutation.unsqueeze(-1), dim=-2) - - # Token T can only be routed to expert E if its priority is positive and - # less than the expert capacity. One-hot matrix will ignore indices outside - # the range [0, expert_capacity). - # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]. - # token_priority = token_priority * (token_priority > 0) - - # TODO remove this - # dispatch_mask = torch.nn.functional.one_hot(token_priority.long(), expert_capacity + 1)[..., 1:] - dispatch_mask = _jax_one_hot(token_priority.long(), expert_capacity, axis=-1) - - # The combine array will be used for combining expert outputs, scaled by the - # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, - # expert_capacity]. - # TODO remove this - combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) - # combine_array = torch.einsum("...te,...te->...te", router_probs, dispatch_mask) - - # Return to default dtype now that router computation is complete. - combine_array = combine_array.to(torch.float32) - - return RouterMask(dispatch_mask, combine_array, auxiliary_loss) + expert_index = torch.nn.functional.one_hot(expert_index, self.num_experts) + return expert_index, auxiliary_loss # Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->SwitchTransformers @@ -590,6 +363,7 @@ def forward(self, hidden_states): return self.weight * hidden_states +# TODO: do we need this???? ALL_LAYERNORM_LAYERS.append(SwitchTransformersLayerNorm) @@ -634,9 +408,22 @@ def forward(self, hidden_states): # This class should also contain a router class # check flaxformer/architecture/moe/router.py : https://github.com/google/flaxformer/blob/main/flaxformer/architectures/moe/routing.py class SwitchTransformersLayerFF(nn.Module): + r""" + Switch Transformers Feed Forward layer module. This is a wrapper around the Mixture of Experts + module. + + Attributes: + is_sparse (`bool`): + Whether the MLP layer is a `Sparse` layer (contains a Mixture of Experts) or not + mlp (`torch.nn.Module`): + The MLP layer of the Feed Forward layer + layer_norm (`torch.nn.Module`): + The pre-MLP layer norm. This module is equivalent to the `pre_mlp_layer_norm` in `t5x` + dropout (`torch.nn.Module`): + Post-MLP dropout layer. + """ def __init__(self, config: SwitchTransformersConfig, is_sparse=False): super().__init__() - # TODO: check the comments above self.is_sparse = is_sparse if not self.is_sparse: self.mlp = SwitchTransformersDenseActDense(config) @@ -655,12 +442,13 @@ def forward(self, hidden_states): class SwitchTransformersSparseMLP(nn.Module): r""" - Implementation of the Switch Transformers Sparse MLP module. TODO: Add a LOT of details here + Implementation of the Switch Transformers Sparse MLP module. + TODO: Add a LOT of details here """ def __init__(self, config: SwitchTransformersConfig, expert_class: nn.Module = SwitchTransformersDenseActDense): super().__init__() - # Step 1: Get the correct router + # Step 1: Get the correct router according to its class self.router = self._get_router(config) # Step 2: Get the experts @@ -681,8 +469,6 @@ def _get_router(self, config): # TODO, use a ALL_ROUTER_TYPE map instead of havind all the ifs? then just if None raise error. if config.router_type.lower() == "tokens_masked": return TokensChooseMaskedRouter(config) - elif config.router_type.lower() == "experts_masked": - return ExpertsChooseMaskedRouter(config) else: raise NotImplementedError( f"{config.router_type.lower()} not implemented ! Please chose a router in " @@ -690,23 +476,32 @@ def _get_router(self, config): ) def forward(self, hidden_states): - router_mask = self.router(hidden_states, expert_capacity = 64) - masked_indices = router_mask.dispatch_mask + r""" + Hold on, this will be slightly tricky to understand + In the correct order, a MoE layer does the following: + + 1- Gets the `router_mask` from the router. This mask will contain the indices of the + routed tokens. Also retrieve the probabilities (max prob) for each token. The probabilities are + needed in the computation of the hidden states since the probabilities will be broadcasted + to the hidden states values (they can be interpreted as a scaling factor). - probs, _ = self.router._compute_router_probabilities(hidden_states,num_experts=8,apply_jitter=False) + 2- TODO: explain @ArthurZucker - # computations in the router have to be done in float16 + """ + # Step 1: Get the router_mask from the router as wel as the probabilities + router_mask, auxiliary_loss, router_z_loss, router_probs = self.router(hidden_states) - dispatched_tokens = int(torch.sum(masked_indices)) for idx, expert in enumerate(self.experts.values()): + # 1. Get the index of the tokens that are routed to the current expert - token_indices = masked_indices[:, :, idx].sum(-1).bool() - # 2. Update hidden states - print(f"{token_indices.sum()}/{dispatched_tokens} tokens will be dispatched to expert {idx}") + # masked_indices has a shape of `batch_size`, `seq_len`, `num_experts` + token_indices = router_mask[:, :, idx].bool() + + # 2. Update only the hidden states affected by the routing hidden_states[token_indices] = expert(hidden_states[token_indices]) - hidden_states = torch.max(probs,dim=-1).values.unsqueeze(-1) * hidden_states + hidden_states = router_probs * hidden_states return hidden_states diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index ed721c042b417..5d55e696921b7 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -36,7 +36,6 @@ ) from transformers.models.switch_transformers.modeling_switch_transformers import ( SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, - ExpertsChooseMaskedRouter, TokensChooseMaskedRouter, load_balancing_loss_func, router_z_loss_func, @@ -993,86 +992,53 @@ def test_equivalency_token_chose_masked_router(self): self.assertTrue(torch.allclose(output.dispatch_mask, expected_dispatch_mask)) self.assertTrue(torch.allclose(output.combine_array, expected_combine_array, atol=1e-4)) - def test_equivalency_experts_chose_masked_router(self): +@require_torch +@require_tokenizers +class SwitchTransformerModelIntegrationTests(unittest.TestCase): + def test_small_logits(self): r""" - This test tests the equivalency between the `ExpertsChooseMaskedRouter` - originally implemented from here: TODO: provide link + Logits testing to check implementation consistency between `t5x` implementation + and `transformers` implementation of Switch-C transformers. We only check the logits + of the first batch. """ - hidden_dim = 3 - num_experts = 2 - expert_capacity = 2 # Total capacity = 2*2*1 = 4 < num_tokens - jitter_noise = 0.0 - - input_tokens = torch.Tensor( - [ - [ - [0.6433916, 0.18188512, 0.02240455], - [0.563781, 0.5526401, 0.0958724], - [0.34253013, 0.03644359, 0.08744538], - [0.7909105, 0.35205448, 0.53364205], - ], - [ - [0.02900076, 0.4168595, 0.5802449], - [0.91486526, 0.27414513, 0.14991808], - [0.9383501, 0.5209162, 0.51207185], - [0.90618336, 0.7309413, 0.95533276], - ], - ] - ) - - config = SwitchTransformersConfig( - num_experts=num_experts, - hidden_size=hidden_dim, - router_jitter_noise=jitter_noise, - expert_capacity=expert_capacity, - batch_prioritized_routing=False, - ) - - model = ExpertsChooseMaskedRouter(config) - - model.classifier.weight = torch.nn.Parameter( - torch.Tensor([[-0.00107201, 0.01544739], [-0.0087319, 0.01314363], [0.03530733, 0.03709853]]).t() - ) - - output = model(input_tokens, expert_capacity=expert_capacity) - - self.assertEqual(output.auxiliary_loss, 0.0) - self.assertAlmostEqual(output.router_z_loss.item(), 0.507016, places=5) - - expected_dispatch_mask = torch.Tensor( - [ - [[[0, 1], [0, 0]], [[0, 0], [0, 1]], [[1, 0], [0, 0]], [[0, 0], [1, 0]]], - [[[1, 0], [0, 0]], [[0, 1], [0, 0]], [[0, 0], [0, 1]], [[0, 0], [1, 0]]], - ] - ) - - self.assertTrue(torch.allclose(output.dispatch_mask, expected_dispatch_mask)) - - expected_combined_array = torch.Tensor( - [ - [ - [[0.0000, 0.4963], [0.0000, 0.0000]], - [[0.0000, 0.0000], [0.0000, 0.5054]], - [[0.4983, 0.0000], [0.0000, 0.0000]], - [[0.0000, 0.0000], [0.5054, 0.0000]], - ], - [ - [[0.4973, 0.0000], [0.0000, 0.0000]], - [[0.0000, 0.4947], [0.0000, 0.0000]], - [[0.0000, 0.0000], [0.0000, 0.5070]], - [[0.0000, 0.0000], [0.5082, 0.0000]], - ], - ] - ) + model = SwitchTransformersForConditionalGeneration.from_pretrained("HFLAY/switch_base_8", torch_dtype=torch.bfloat16).eval() + input_ids = torch.ones((32,64), dtype = torch.long) + decoder_input_ids = torch.ones((32,64), dtype = torch.long) + + EXPECTED_MEAN_LOGITS = torch.Tensor( + [-29.330458, -29.332455, -29.333147, -29.341417, -29.472025, + -29.335613, -29.47691 , -29.328053, -29.328312, -29.329872, + -29.336075, -29.331112, -29.30393 , -29.328972, -29.33514 , + -29.335201, -29.317245, -29.48052 , -29.328382, -29.4837 , + -29.489216, -29.338572, -29.331537, -29.337881, -29.497675, + -29.483559, -29.497217, -29.343832, -29.483425, -29.333313, + -29.49259 , -29.318579, -29.478128, -29.328222, -29.339464, + -29.329647, -29.339725, -29.648586, -29.312738, -29.314232, + -29.330048, -29.314402, -29.329876, -29.33895 , -29.337482, + -29.477829, -29.482548, -29.337194, -29.487375, -29.33446 , + -29.340445, -29.479067, -29.333689, -29.338657, -29.339827, + -29.33101 , -29.482433, -29.498121, -29.477905, -29.33606 , + -29.333132, -29.335573, -29.482475, -29.330212],) + + hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).logits + hf_logits = hf_logits.mean(dim=-1)[0] + + self.assertTrue(torch.testing.assert_allclose(hf_logits, EXPECTED_MEAN_LOGITS, rtol=1e-3, atol=1e-3)) + + def test_small_generate(self): + r""" + Generate test using the smalled switch-C model. + """ + model = SwitchTransformersForConditionalGeneration.from_pretrained("HFLAY/switch_base_8", torch_dtype=torch.bfloat16).eval() + tokenizer = SwitchTransformersForConditionalGeneration.from_pretrained("t5-small") - self.assertTrue(torch.allclose(output.combine_array, expected_combined_array, atol=1e-4)) + input_ids = tokenizer("summarize: Hello world", return_tensors="pt").input_ids.to(torch_device) + sequences = model.generate(input_ids) + output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0] + EXPECTED_OUTPUT = " . The best way to do it is to use a smartphone. Hello there" + self.assertisEqual(output_str, EXPECTED_OUTPUT) -@require_torch -@require_tokenizers -class SwitchTransformerModelIntegrationTests(unittest.TestCase): - def test_small_logits(self): - pass def test_large_logits(self): pass From 5673476e3fe843fd76cd41b18c07713d9cc86497 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 26 Oct 2022 09:27:51 +0000 Subject: [PATCH 033/102] add generate tests Co-authored-by: younesbelkada --- .../modeling_switch_transformers.py | 5 +++-- .../test_modeling_switch_transformers.py | 20 ++++++++++++++----- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index d2a91a06a8788..94fee28b3fbb5 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -59,7 +59,8 @@ # for the pretrained weights provided with the models #################################################### SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "ybelkada/switch_transformers-base", + "ybelkada/switch-base-8", + "ybelkada/switch-base-16", # See all SwitchTransformers models at https://huggingface.co/models?filter=switch_transformers ] @@ -363,7 +364,7 @@ def forward(self, hidden_states): return self.weight * hidden_states -# TODO: do we need this???? +# TODO: do we need this? No let's just import ALL_LAYERNORM_LAYERS. ALL_LAYERNORM_LAYERS.append(SwitchTransformersLayerNorm) diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 5d55e696921b7..94e3b59f1e822 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -1019,16 +1019,15 @@ def test_small_logits(self): -29.340445, -29.479067, -29.333689, -29.338657, -29.339827, -29.33101 , -29.482433, -29.498121, -29.477905, -29.33606 , -29.333132, -29.335573, -29.482475, -29.330212],) - + hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).logits hf_logits = hf_logits.mean(dim=-1)[0] self.assertTrue(torch.testing.assert_allclose(hf_logits, EXPECTED_MEAN_LOGITS, rtol=1e-3, atol=1e-3)) - + def test_small_generate(self): - r""" - Generate test using the smalled switch-C model. - """ + #Generate test using the smalled switch-C model. + model = SwitchTransformersForConditionalGeneration.from_pretrained("HFLAY/switch_base_8", torch_dtype=torch.bfloat16).eval() tokenizer = SwitchTransformersForConditionalGeneration.from_pretrained("t5-small") @@ -1039,6 +1038,17 @@ def test_small_generate(self): EXPECTED_OUTPUT = " . The best way to do it is to use a smartphone. Hello there" self.assertisEqual(output_str, EXPECTED_OUTPUT) + input_ids = tokenizer("The human walks into a bar and orders a ", return_tensors="pt").input_ids.to(torch_device) + sequences = model.generate(input_ids) + output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0] + self.assertisEqual(output_str, "drink.") + + input_ids = tokenizer("A walks into a bar a orders a with pinch of .", return_tensors="pt").input_ids.to(torch_device) + sequences = model.generate(input_ids) + output_str = tokenizer.batch_decode(sequences, skip_special_tokens=False)[0] + + EXPECTED_OUTPUT = " man beer a salt." + self.assertisEqual(output_str, EXPECTED_OUTPUT) def test_large_logits(self): pass From 25ec9b6287cff079fbc345d438e71d5feb5d328e Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 26 Oct 2022 11:01:11 +0000 Subject: [PATCH 034/102] add support for router loss Co-authored-by: Younes Belkada --- src/transformers/modeling_outputs.py | 196 ++++++++++++++++++ .../configuration_switch_transformers.py | 4 + .../modeling_switch_transformers.py | 79 +++++-- 3 files changed, 261 insertions(+), 18 deletions(-) diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py index 1ffc019d8492c..bebd748df53f7 100644 --- a/src/transformers/modeling_outputs.py +++ b/src/transformers/modeling_outputs.py @@ -285,6 +285,58 @@ class BaseModelOutputWithPastAndCrossAttentions(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None cross_attentions: Optional[Tuple[torch.FloatTensor]] = None +@dataclass +class MoEModelOutputWithPastAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding) as well as + Mixture of Expert's router hidden states terms, to train a MoE model. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, + num_experts)`. + + Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary loss for Mixture of Experts models. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + router_probs: Optional[Tuple[torch.FloatTensor]] = None + @dataclass class Seq2SeqModelOutput(ModelOutput): @@ -346,6 +398,78 @@ class Seq2SeqModelOutput(ModelOutput): encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None +@dataclass +class Seq2SeqMoEModelOutput(ModelOutput): + """ + Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, + num_experts)`. + + Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, + num_experts)`. + + Router logits of the encoder model, useful to compute the auxiliary loss for Mixture of Experts models. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None + @dataclass class CausalLMOutput(ModelOutput): @@ -580,6 +704,78 @@ class Seq2SeqLMOutput(ModelOutput): encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None +@dataclass +class Seq2SeqMoEOutput(ModelOutput): + """ + Base class for sequence-to-sequence language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, + num_experts)`. + + Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, + num_experts)`. + + Router logits of the encoder model, useful to compute the auxiliary loss for Mixture of Experts models. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None + + @dataclass class NextSentencePredictorOutput(ModelOutput): diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index 8850d22136065..a2a9f5038df61 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -80,6 +80,8 @@ class SwitchTransformersConfig(PretrainedConfig): "selective precision" discussion in https://arxiv.org/abs/2101.03961. batch_prioritized_routing (`bool`, *optional*, defaults to `False`): Whether to use batch prioritized routing. + add_router_probs (`bool`, *optional*, defaults to `False`): + Whether to output router probabilities to compute router auxiliary loss. num_selected_experts (`int`, *optional*, defaults to 2): Number of experts to select for each token. relative_attention_num_buckets (`int`, *optional*, defaults to 32): @@ -130,6 +132,7 @@ def __init__( initializer_factor=1.0, feed_forward_proj="relu", is_encoder_decoder=True, + add_router_probs=False, use_cache=True, pad_token_id=0, eos_token_id=1, @@ -183,6 +186,7 @@ def __init__( self.initializer_factor = initializer_factor self.feed_forward_proj = feed_forward_proj self.use_cache = use_cache + self.add_router_probs = add_router_probs act_info = self.feed_forward_proj.split("-") self.dense_act_fn = act_info[-1] diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 94fee28b3fbb5..d79a464c7294a 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -30,9 +30,9 @@ from ...activations import ACT2FN from ...modeling_outputs import ( BaseModelOutput, - BaseModelOutputWithPastAndCrossAttentions, - Seq2SeqLMOutput, - Seq2SeqModelOutput, + MoEModelOutputWithPastAndCrossAttentions, + Seq2SeqMoEOutput, + Seq2SeqMoEModelOutput, ) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer @@ -108,8 +108,6 @@ def _one_hot(tensor, num_classes, axis=-1, dtype=torch.bool): return out.to(tensor.device) # Router loss - - def router_z_loss_func(router_logits: torch.Tensor) -> float: r""" Compute router z-loss implemented in PyTorch. @@ -434,16 +432,26 @@ def __init__(self, config: SwitchTransformersConfig, is_sparse=False): self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) - def forward(self, hidden_states): + def forward(self, hidden_states, output_router_logits): forwarded_states = self.layer_norm(hidden_states) forwarded_states = self.mlp(forwarded_states) - hidden_states = hidden_states + self.dropout(forwarded_states) - return hidden_states + + if isinstance(forwarded_states, tuple): + forwarded_states, router_logits = forwarded_states + else: + router_logits = None + + output = hidden_states + self.dropout(forwarded_states) + + if output_router_logits and router_logits is not None: + output = (output, router_logits) + + return output class SwitchTransformersSparseMLP(nn.Module): r""" - Implementation of the Switch Transformers Sparse MLP module. + Implementation of the Switch Transformers Sparse MLP module. TODO: Add a LOT of details here """ @@ -503,7 +511,7 @@ def forward(self, hidden_states): hidden_states[token_indices] = expert(hidden_states[token_indices]) hidden_states = router_probs * hidden_states - return hidden_states + return hidden_states, router_probs # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->SwitchTransformers @@ -834,6 +842,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + output_router_logits=False, return_dict=True, ): @@ -906,7 +915,12 @@ def forward( attention_outputs = attention_outputs + cross_attention_outputs[2:] # Apply Feed Forward layer - hidden_states = self.layer[-1](hidden_states) + hidden_states = self.layer[-1](hidden_states, output_router_logits) + + if isinstance(hidden_states, tuple) : + hidden_states, router_probs = hidden_states + else: + router_probs = (None,) # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): @@ -920,6 +934,8 @@ def forward( else: outputs = outputs + attention_outputs + outputs = outputs + (router_probs,) + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) @@ -1070,6 +1086,7 @@ def forward( use_cache=None, output_attentions=None, output_hidden_states=None, + output_router_logits=None, return_dict=None, ): # Model parallel @@ -1142,6 +1159,7 @@ def forward( present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None + all_router_probs = () if output_router_logits else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None position_bias = None encoder_decoder_position_bias = None @@ -1210,6 +1228,7 @@ def custom_forward(*inputs): past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + output_router_logits=output_router_logits, ) # layer_outputs is a tuple with: @@ -1234,6 +1253,9 @@ def custom_forward(*inputs): if self.is_decoder: all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + if output_router_logits: + all_router_probs = all_router_probs + (layer_outputs[-1],) + # Model Parallel: If it's the last layer for that device, put things on the next device if self.model_parallel: for k, v in self.device_map.items(): @@ -1259,12 +1281,13 @@ def custom_forward(*inputs): ] if v is not None ) - return BaseModelOutputWithPastAndCrossAttentions( + return MoEModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=present_key_value_states, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, + router_probs=all_router_probs, ) @@ -1380,6 +1403,9 @@ def custom_forward(*inputs): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, + and should not be returned during inference. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -1489,7 +1515,7 @@ class PreTrainedModel self.encoder.layer[layer].attention.prune_heads(heads) @add_start_docstrings_to_model_forward(SWITCH_TRANSFORMERS_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=Seq2SeqMoEModelOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1506,8 +1532,9 @@ def forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqMoEModelOutput]: r""" Returns: @@ -1550,6 +1577,7 @@ def forward( head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, return_dict=return_dict, ) elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): @@ -1585,13 +1613,14 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, return_dict=return_dict, ) if not return_dict: return decoder_outputs + encoder_outputs - return Seq2SeqModelOutput( + return Seq2SeqMoEModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, @@ -1600,6 +1629,8 @@ def forward( encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, + encoder_router_logits=encoder_outputs.router_probs, + decoder_router_logits=decoder_outputs.router_probs ) @@ -1664,7 +1695,7 @@ def get_decoder(self): return self.decoder @add_start_docstrings_to_model_forward(SWITCH_TRANSFORMERS_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=Seq2SeqMoEOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1682,8 +1713,9 @@ def forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqMoEOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., @@ -1734,6 +1766,7 @@ def forward( head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, return_dict=return_dict, ) elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): @@ -1776,6 +1809,7 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, return_dict=return_dict, ) @@ -1797,6 +1831,11 @@ def forward( loss = None if labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-100) + # todo check in the config if router loss enables + if output_router_logits : + router_z_loss = router_z_loss_func(encoder_outputs.router_probs) + decoder_outputs.router_probs + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 @@ -1804,7 +1843,7 @@ def forward( output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs return ((loss,) + output) if loss is not None else output - return Seq2SeqLMOutput( + return Seq2SeqMoEOutput( loss=loss, logits=lm_logits, past_key_values=decoder_outputs.past_key_values, @@ -1814,6 +1853,8 @@ def forward( encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, + encoder_router_logits=encoder_outputs.router_probs, + decoder_router_logits=decoder_outputs.router_probs, ) def prepare_inputs_for_generation( @@ -1926,6 +1967,7 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: r""" @@ -1953,6 +1995,7 @@ def forward( head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, return_dict=return_dict, ) From 55ab162e5901140883483092c8977f88ac48c971 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 26 Oct 2022 14:12:15 +0200 Subject: [PATCH 035/102] fix forward error --- src/transformers/modeling_outputs.py | 36 +++++- ...ers_original_flax_checkpoint_to_pytorch.py | 3 +- .../modeling_switch_transformers.py | 102 ++++++++-------- .../test_modeling_switch_transformers.py | 110 ++++++++++++++---- 4 files changed, 175 insertions(+), 76 deletions(-) diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py index bebd748df53f7..5d3218838464d 100644 --- a/src/transformers/modeling_outputs.py +++ b/src/transformers/modeling_outputs.py @@ -285,6 +285,39 @@ class BaseModelOutputWithPastAndCrossAttentions(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + +@dataclass +class MoEModelOutput(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, + num_experts)`. + + Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary loss for Mixture of Experts models. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + router_probs: Optional[Tuple[torch.FloatTensor]] = None + + @dataclass class MoEModelOutputWithPastAndCrossAttentions(ModelOutput): """ @@ -398,6 +431,7 @@ class Seq2SeqModelOutput(ModelOutput): encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + @dataclass class Seq2SeqMoEModelOutput(ModelOutput): """ @@ -704,6 +738,7 @@ class Seq2SeqLMOutput(ModelOutput): encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + @dataclass class Seq2SeqMoEOutput(ModelOutput): """ @@ -776,7 +811,6 @@ class Seq2SeqMoEOutput(ModelOutput): encoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None - @dataclass class NextSentencePredictorOutput(ModelOutput): """ diff --git a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py index 5934e0f5c7a56..437818a43f301 100644 --- a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py +++ b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py @@ -81,9 +81,10 @@ "relpos_bias/rel_embedding": "block/0/layer/0/SelfAttention/relative_attention_bias/weight", "router/router_weights/w/": "router/classifier/", "roer/roer_weights/w/": "router/classifier/", - "logits_dense":"lm_head" + "logits_dense": "lm_head", } + def rename_keys(s_dict): # 1. in HF T5, we have block.{x}.layer.{y}. which corresponds to layer.{x} in # the original model diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index d79a464c7294a..feab42150a89d 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -29,10 +29,10 @@ from ...activations import ACT2FN from ...modeling_outputs import ( - BaseModelOutput, + MoEModelOutput, MoEModelOutputWithPastAndCrossAttentions, - Seq2SeqMoEOutput, Seq2SeqMoEModelOutput, + Seq2SeqMoEOutput, ) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer @@ -52,18 +52,19 @@ _CONFIG_FOR_DOC = "SwitchTransformersConfig" _TOKENIZER_FOR_DOC = "T5Tokenizer" -_CHECKPOINT_FOR_DOC = "ybelkada/switch_transformers-base" +_CHECKPOINT_FOR_DOC = "google/switch-base-8" #################################################### # This dict contains ids and associated url # for the pretrained weights provided with the models #################################################### SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "ybelkada/switch-base-8", - "ybelkada/switch-base-16", + "google/switch-base-8", + "google/switch-base-16", # See all SwitchTransformers models at https://huggingface.co/models?filter=switch_transformers ] + def _one_hot(tensor, num_classes, axis=-1, dtype=torch.bool): r""" This function mimics the behavior of jax.nn.functional.one_hot in PyTorch. It takes a tensor of indices, the number @@ -107,6 +108,7 @@ def _one_hot(tensor, num_classes, axis=-1, dtype=torch.bool): return out.to(tensor.device) + # Router loss def router_z_loss_func(router_logits: torch.Tensor) -> float: r""" @@ -165,7 +167,6 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2) - class TokensChooseMaskedRouter(nn.Module): """ Masked matmul router using tokens choose top-k experts assignment. @@ -232,7 +233,7 @@ def _compute_router_probabilities( distrib_upper_bound = 1.0 + self.jitter_noise uniform_distrib = ( - torch.rand(token_inputs.shape, device=token_inputs.device, dtype = self.dtype) + torch.rand(token_inputs.shape, device=token_inputs.device, dtype=self.dtype) * (distrib_lower_bound - distrib_upper_bound) ) + distrib_upper_bound @@ -247,9 +248,7 @@ def _compute_router_probabilities( router_probabilities = torch.nn.Softmax(dim=-1)(router_logits).to(self.input_tokens_dtype) return router_probabilities, router_logits - def forward( - self, token_inputs: torch.Tensor, apply_jitter: bool = True, **kwargs - ) -> Tuple: + def forward(self, token_inputs: torch.Tensor, apply_jitter: bool = True, **kwargs) -> Tuple: r""" Generic forward function for every Router class. Each Router expects to have the same input hidden states (`token_inputs`) corresponding to the hidden states for each token, the `expert_capacity` corresponding to the @@ -285,14 +284,15 @@ def forward( expert_index, auxiliary_loss = self._compute_routing_instructions(router_probs, padding_mask, **kwargs) # We cast back the output to the previous dtype expert_index = expert_index.to(self.input_tokens_dtype) - router_probs = torch.max(router_probs,dim=-1).values.unsqueeze(-1) + router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1) router_z_loss = router_z_loss_func(router_logits) return expert_index, auxiliary_loss, router_z_loss, router_probs - def _compute_routing_instructions( - self, router_probs: torch.Tensor, padding_mask: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor,torch.Tensor]: + self, + router_probs: torch.Tensor, + padding_mask: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Computes masks for the top-k experts per token. @@ -312,9 +312,9 @@ def _compute_routing_instructions( # Top-k router probability and corresponding expert indices for each token. # Shape: [num_groups, tokens_per_group, num_selected_experts]. - expert_index = torch.argmax(router_probs, dim = -1) + expert_index = torch.argmax(router_probs, dim=-1) - if padding_mask is not None: # TODO test or delete + if padding_mask is not None: # TODO test or delete # Mask applied to gate. Exclude choices corresponding to padding tokens. gate_mask = padding_mask.unsqueeze(-1).to(expert_index.dtype) expert_gate *= gate_mask @@ -362,7 +362,7 @@ def forward(self, hidden_states): return self.weight * hidden_states -# TODO: do we need this? No let's just import ALL_LAYERNORM_LAYERS. +# TODO: do we need this? No let's just import ALL_LAYERNORM_LAYERS. ALL_LAYERNORM_LAYERS.append(SwitchTransformersLayerNorm) @@ -402,28 +402,27 @@ def forward(self, hidden_states): return hidden_states -# TODO: Change it here to adapt it from the paper, the FF layer contains experts -# an expert is a FF layer with multiple sub-FF layers inside.s -# This class should also contain a router class -# check flaxformer/architecture/moe/router.py : https://github.com/google/flaxformer/blob/main/flaxformer/architectures/moe/routing.py class SwitchTransformersLayerFF(nn.Module): r""" - Switch Transformers Feed Forward layer module. This is a wrapper around the Mixture of Experts - module. - - Attributes: - is_sparse (`bool`): - Whether the MLP layer is a `Sparse` layer (contains a Mixture of Experts) or not - mlp (`torch.nn.Module`): - The MLP layer of the Feed Forward layer - layer_norm (`torch.nn.Module`): - The pre-MLP layer norm. This module is equivalent to the `pre_mlp_layer_norm` in `t5x` - dropout (`torch.nn.Module`): - Post-MLP dropout layer. + Switch Transformers Feed Forward layer module. This is a wrapper around the Mixture of Experts + module. + + Attributes: + is_sparse (`bool`): + Whether the MLP layer is a `Sparse` layer (contains a Mixture of Experts) or not + mlp (`torch.nn.Module`): + The MLP layer of the Feed Forward layer + layer_norm (`torch.nn.Module`): + The pre-MLP layer norm. This module is equivalent to the `pre_mlp_layer_norm` in `t5x` + dropout (`torch.nn.Module`): + Post-MLP dropout layer. """ + def __init__(self, config: SwitchTransformersConfig, is_sparse=False): super().__init__() self.is_sparse = is_sparse + + # Check if it is a sparse layer, if not then it is a dense layer if not self.is_sparse: self.mlp = SwitchTransformersDenseActDense(config) else: @@ -444,7 +443,7 @@ def forward(self, hidden_states, output_router_logits): output = hidden_states + self.dropout(forwarded_states) if output_router_logits and router_logits is not None: - output = (output, router_logits) + output = (output, router_logits) return output @@ -486,20 +485,20 @@ def _get_router(self, config): def forward(self, hidden_states): r""" - Hold on, this will be slightly tricky to understand - In the correct order, a MoE layer does the following: + Hold on, this will be slightly tricky to understand + In the correct order, a MoE layer does the following: - 1- Gets the `router_mask` from the router. This mask will contain the indices of the - routed tokens. Also retrieve the probabilities (max prob) for each token. The probabilities are - needed in the computation of the hidden states since the probabilities will be broadcasted - to the hidden states values (they can be interpreted as a scaling factor). + 1- Gets the `router_mask` from the router. This mask will contain the indices of the + routed tokens. Also retrieve the probabilities (max prob) for each token. The probabilities are + needed in the computation of the hidden states since the probabilities will be broadcasted + to the hidden states values (they can be interpreted as a scaling factor). - 2- TODO: explain @ArthurZucker + 2- TODO: explain @ArthurZucker """ # Step 1: Get the router_mask from the router as wel as the probabilities - router_mask, auxiliary_loss, router_z_loss, router_probs = self.router(hidden_states) + router_mask, auxiliary_loss, router_z_loss, router_probs = self.router(hidden_states) for idx, expert in enumerate(self.experts.values()): @@ -917,7 +916,7 @@ def forward( # Apply Feed Forward layer hidden_states = self.layer[-1](hidden_states, output_router_logits) - if isinstance(hidden_states, tuple) : + if isinstance(hidden_states, tuple): hidden_states, router_probs = hidden_states else: router_probs = (None,) @@ -1278,6 +1277,7 @@ def custom_forward(*inputs): all_hidden_states, all_attentions, all_cross_attentions, + all_router_probs, ] if v is not None ) @@ -1580,8 +1580,8 @@ def forward( output_router_logits=output_router_logits, return_dict=return_dict, ) - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( + elif return_dict and not isinstance(encoder_outputs, MoEModelOutput): + encoder_outputs = MoEModelOutput( last_hidden_state=encoder_outputs[0], hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, @@ -1630,7 +1630,7 @@ def forward( encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, encoder_router_logits=encoder_outputs.router_probs, - decoder_router_logits=decoder_outputs.router_probs + decoder_router_logits=decoder_outputs.router_probs, ) @@ -1769,8 +1769,8 @@ def forward( output_router_logits=output_router_logits, return_dict=return_dict, ) - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( + elif return_dict and not isinstance(encoder_outputs, MoEModelOutput): + encoder_outputs = MoEModelOutput( last_hidden_state=encoder_outputs[0], hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, @@ -1832,7 +1832,7 @@ def forward( if labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-100) # todo check in the config if router loss enables - if output_router_logits : + if output_router_logits: router_z_loss = router_z_loss_func(encoder_outputs.router_probs) decoder_outputs.router_probs @@ -1958,7 +1958,7 @@ class PreTrainedModel self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads) @add_start_docstrings_to_model_forward(SWITCH_TRANSFORMERS_ENCODER_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=MoEModelOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1969,7 +1969,7 @@ def forward( output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: + ) -> Union[Tuple[torch.FloatTensor], MoEModelOutput]: r""" Returns: diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 94e3b59f1e822..2d8b114d82a50 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -992,33 +992,90 @@ def test_equivalency_token_chose_masked_router(self): self.assertTrue(torch.allclose(output.dispatch_mask, expected_dispatch_mask)) self.assertTrue(torch.allclose(output.combine_array, expected_combine_array, atol=1e-4)) + @require_torch @require_tokenizers class SwitchTransformerModelIntegrationTests(unittest.TestCase): def test_small_logits(self): r""" - Logits testing to check implementation consistency between `t5x` implementation - and `transformers` implementation of Switch-C transformers. We only check the logits - of the first batch. + Logits testing to check implementation consistency between `t5x` implementation + and `transformers` implementation of Switch-C transformers. We only check the logits + of the first batch. """ - model = SwitchTransformersForConditionalGeneration.from_pretrained("HFLAY/switch_base_8", torch_dtype=torch.bfloat16).eval() - input_ids = torch.ones((32,64), dtype = torch.long) - decoder_input_ids = torch.ones((32,64), dtype = torch.long) + model = SwitchTransformersForConditionalGeneration.from_pretrained( + "HFLAY/switch_base_8", torch_dtype=torch.bfloat16 + ).eval() + input_ids = torch.ones((32, 64), dtype=torch.long) + decoder_input_ids = torch.ones((32, 64), dtype=torch.long) EXPECTED_MEAN_LOGITS = torch.Tensor( - [-29.330458, -29.332455, -29.333147, -29.341417, -29.472025, - -29.335613, -29.47691 , -29.328053, -29.328312, -29.329872, - -29.336075, -29.331112, -29.30393 , -29.328972, -29.33514 , - -29.335201, -29.317245, -29.48052 , -29.328382, -29.4837 , - -29.489216, -29.338572, -29.331537, -29.337881, -29.497675, - -29.483559, -29.497217, -29.343832, -29.483425, -29.333313, - -29.49259 , -29.318579, -29.478128, -29.328222, -29.339464, - -29.329647, -29.339725, -29.648586, -29.312738, -29.314232, - -29.330048, -29.314402, -29.329876, -29.33895 , -29.337482, - -29.477829, -29.482548, -29.337194, -29.487375, -29.33446 , - -29.340445, -29.479067, -29.333689, -29.338657, -29.339827, - -29.33101 , -29.482433, -29.498121, -29.477905, -29.33606 , - -29.333132, -29.335573, -29.482475, -29.330212],) + [ + -29.330458, + -29.332455, + -29.333147, + -29.341417, + -29.472025, + -29.335613, + -29.47691, + -29.328053, + -29.328312, + -29.329872, + -29.336075, + -29.331112, + -29.30393, + -29.328972, + -29.33514, + -29.335201, + -29.317245, + -29.48052, + -29.328382, + -29.4837, + -29.489216, + -29.338572, + -29.331537, + -29.337881, + -29.497675, + -29.483559, + -29.497217, + -29.343832, + -29.483425, + -29.333313, + -29.49259, + -29.318579, + -29.478128, + -29.328222, + -29.339464, + -29.329647, + -29.339725, + -29.648586, + -29.312738, + -29.314232, + -29.330048, + -29.314402, + -29.329876, + -29.33895, + -29.337482, + -29.477829, + -29.482548, + -29.337194, + -29.487375, + -29.33446, + -29.340445, + -29.479067, + -29.333689, + -29.338657, + -29.339827, + -29.33101, + -29.482433, + -29.498121, + -29.477905, + -29.33606, + -29.333132, + -29.335573, + -29.482475, + -29.330212, + ], + ) hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).logits hf_logits = hf_logits.mean(dim=-1)[0] @@ -1026,9 +1083,11 @@ def test_small_logits(self): self.assertTrue(torch.testing.assert_allclose(hf_logits, EXPECTED_MEAN_LOGITS, rtol=1e-3, atol=1e-3)) def test_small_generate(self): - #Generate test using the smalled switch-C model. + # Generate test using the smalled switch-C model. - model = SwitchTransformersForConditionalGeneration.from_pretrained("HFLAY/switch_base_8", torch_dtype=torch.bfloat16).eval() + model = SwitchTransformersForConditionalGeneration.from_pretrained( + "HFLAY/switch_base_8", torch_dtype=torch.bfloat16 + ).eval() tokenizer = SwitchTransformersForConditionalGeneration.from_pretrained("t5-small") input_ids = tokenizer("summarize: Hello world", return_tensors="pt").input_ids.to(torch_device) @@ -1038,12 +1097,17 @@ def test_small_generate(self): EXPECTED_OUTPUT = " . The best way to do it is to use a smartphone. Hello there" self.assertisEqual(output_str, EXPECTED_OUTPUT) - input_ids = tokenizer("The human walks into a bar and orders a ", return_tensors="pt").input_ids.to(torch_device) + input_ids = tokenizer( + "The human walks into a bar and orders a ", return_tensors="pt" + ).input_ids.to(torch_device) sequences = model.generate(input_ids) output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0] self.assertisEqual(output_str, "drink.") - input_ids = tokenizer("A walks into a bar a orders a with pinch of .", return_tensors="pt").input_ids.to(torch_device) + input_ids = tokenizer( + "A walks into a bar a orders a with pinch of .", + return_tensors="pt", + ).input_ids.to(torch_device) sequences = model.generate(input_ids) output_str = tokenizer.batch_decode(sequences, skip_special_tokens=False)[0] From fa47eeffabb7132e2897ef93103382cf74626bdc Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 26 Oct 2022 15:10:48 +0200 Subject: [PATCH 036/102] refactor a bit --- ...ansformers.mdx => switch_transformers.mdx} | 0 example_switch.py | 15 + src/transformers/modeling_outputs.py | 28 +- ...ers_original_flax_checkpoint_to_pytorch.py | 7 +- ...switch_transformersx_checkpoint_to_flax.py | 288 ------------------ .../modeling_switch_transformers.py | 135 ++++---- 6 files changed, 105 insertions(+), 368 deletions(-) rename docs/source/en/model_doc/{switchtransformers.mdx => switch_transformers.mdx} (100%) create mode 100644 example_switch.py delete mode 100644 src/transformers/models/switch_transformers/convert_switch_transformersx_checkpoint_to_flax.py diff --git a/docs/source/en/model_doc/switchtransformers.mdx b/docs/source/en/model_doc/switch_transformers.mdx similarity index 100% rename from docs/source/en/model_doc/switchtransformers.mdx rename to docs/source/en/model_doc/switch_transformers.mdx diff --git a/example_switch.py b/example_switch.py new file mode 100644 index 0000000000000..f5a9256c2448c --- /dev/null +++ b/example_switch.py @@ -0,0 +1,15 @@ +from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration + +tokenizer = AutoTokenizer.from_pretrained("t5-small") +text = "A walks into a bar a orders a with pinch of ." +model = SwitchTransformersForConditionalGeneration.from_pretrained("HFLAY/switch_base_8") + +input_ids = tokenizer(text, return_tensors="pt").input_ids +out = model.generate(input_ids, decoder_start_token_id=0, output_router_logits=True) +print(tokenizer.decode(out[0])) + +# Loss + +input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids +labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids +outputs = model(input_ids=input_ids, labels=labels, output_router_logits=True) \ No newline at end of file diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py index 5d3218838464d..e0d21f2b00662 100644 --- a/src/transformers/modeling_outputs.py +++ b/src/transformers/modeling_outputs.py @@ -306,10 +306,10 @@ class MoEModelOutput(ModelOutput): Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, - num_experts)`. + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary loss for Mixture of Experts models. + Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary + loss for Mixture of Experts models. """ last_hidden_state: torch.FloatTensor = None @@ -357,10 +357,10 @@ class MoEModelOutputWithPastAndCrossAttentions(ModelOutput): Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads. router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, - num_experts)`. + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary loss for Mixture of Experts models. + Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary + loss for Mixture of Experts models. """ last_hidden_state: torch.FloatTensor = None @@ -463,8 +463,7 @@ class Seq2SeqMoEModelOutput(ModelOutput): Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads. decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, - num_experts)`. + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models. cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): @@ -487,8 +486,7 @@ class Seq2SeqMoEModelOutput(ModelOutput): Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the self-attention heads. encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, - num_experts)`. + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. Router logits of the encoder model, useful to compute the auxiliary loss for Mixture of Experts models. """ @@ -768,8 +766,7 @@ class Seq2SeqMoEOutput(ModelOutput): Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads. decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, - num_experts)`. + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models. cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): @@ -792,14 +789,17 @@ class Seq2SeqMoEOutput(ModelOutput): Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the self-attention heads. encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, - num_experts)`. + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. Router logits of the encoder model, useful to compute the auxiliary loss for Mixture of Experts models. """ loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None + encoder_total_z_loss: torch.FloatTensor = None + decoder_total_z_loss: torch.FloatTensor = None + encoder_total_aux_loss: torch.FloatTensor = None + decoder_total_aux_loss: torch.FloatTensor = None past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None diff --git a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py index 437818a43f301..a6d8cac3e371f 100644 --- a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py +++ b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py @@ -18,17 +18,15 @@ import argparse import re +from flax.traverse_util import flatten_dict, unflatten_dict from t5x import checkpoints from transformers import SwitchTransformersConfig, SwitchTransformersForConditionalGeneration from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model from transformers.utils import logging -from transformers.utils.hub import get_file_from_repo logging.set_verbosity_info() -from flax.traverse_util import flatten_dict, unflatten_dict - MODEL_MAPPING = { "switch_base_8": [ @@ -58,9 +56,6 @@ "switch_base_2048": [ "https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin" ], - "switch_base_8": [ - "https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin" - ], } # should not include what is already done by the `from_pt` argument diff --git a/src/transformers/models/switch_transformers/convert_switch_transformersx_checkpoint_to_flax.py b/src/transformers/models/switch_transformers/convert_switch_transformersx_checkpoint_to_flax.py deleted file mode 100644 index 51011f4a19091..0000000000000 --- a/src/transformers/models/switch_transformers/convert_switch_transformersx_checkpoint_to_flax.py +++ /dev/null @@ -1,288 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Convert SwitchTransformersX checkpoints from the original repository to JAX/FLAX model.""" - -import argparse - -from switch_transformersx import checkpoints -from transformers import FlaxSwitchTransformersForConditionalGeneration, SwitchTransformersConfig - - -def convert_switch_transformersx_checkpoint_to_flax( - switch_transformersx_checkpoint_path, config_name, flax_dump_folder_path -): - config = SwitchTransformersConfig.from_pretrained(config_name) - flax_model = FlaxSwitchTransformersForConditionalGeneration(config=config) - switch_transformersx_model = checkpoints.load_switch_transformersx_checkpoint(switch_transformersx_checkpoint_path) - - split_mlp_wi = "wi_0" in switch_transformersx_model["target"]["encoder"]["layers_0"]["mlp"] - - # Encoder - for layer_index in range(config.num_layers): - layer_name = f"layers_{str(layer_index)}" - - # Self-Attention - switch_transformersx_attention_key = switch_transformersx_model["target"]["encoder"][layer_name]["attention"][ - "key" - ]["kernel"] - switch_transformersx_attention_out = switch_transformersx_model["target"]["encoder"][layer_name]["attention"][ - "out" - ]["kernel"] - switch_transformersx_attention_query = switch_transformersx_model["target"]["encoder"][layer_name][ - "attention" - ]["query"]["kernel"] - switch_transformersx_attention_value = switch_transformersx_model["target"]["encoder"][layer_name][ - "attention" - ]["value"]["kernel"] - - # Layer Normalization - switch_transformersx_attention_layer_norm = switch_transformersx_model["target"]["encoder"][layer_name][ - "pre_attention_layer_norm" - ]["scale"] - - if split_mlp_wi: - switch_transformersx_mlp_wi_0 = switch_transformersx_model["target"]["encoder"][layer_name]["mlp"]["wi_0"][ - "kernel" - ] - switch_transformersx_mlp_wi_1 = switch_transformersx_model["target"]["encoder"][layer_name]["mlp"]["wi_1"][ - "kernel" - ] - else: - switch_transformersx_mlp_wi = switch_transformersx_model["target"]["encoder"][layer_name]["mlp"]["wi"][ - "kernel" - ] - - switch_transformersx_mlp_wo = switch_transformersx_model["target"]["encoder"][layer_name]["mlp"]["wo"][ - "kernel" - ] - - # Layer Normalization - switch_transformersx_mlp_layer_norm = switch_transformersx_model["target"]["encoder"][layer_name][ - "pre_mlp_layer_norm" - ]["scale"] - - # Assigning - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"][ - "kernel" - ] = switch_transformersx_attention_key - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"][ - "kernel" - ] = switch_transformersx_attention_out - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"][ - "kernel" - ] = switch_transformersx_attention_query - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"][ - "kernel" - ] = switch_transformersx_attention_value - - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"][ - "weight" - ] = switch_transformersx_attention_layer_norm - - if split_mlp_wi: - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_0"][ - "kernel" - ] = switch_transformersx_mlp_wi_0 - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_1"][ - "kernel" - ] = switch_transformersx_mlp_wi_1 - else: - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi"][ - "kernel" - ] = switch_transformersx_mlp_wi - - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wo"][ - "kernel" - ] = switch_transformersx_mlp_wo - flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"][ - "weight" - ] = switch_transformersx_mlp_layer_norm - - # Only for layer 0: - switch_transformersx_encoder_rel_embedding = switch_transformersx_model["target"]["encoder"]["relpos_bias"][ - "rel_embedding" - ].T - flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ - "embedding" - ] = switch_transformersx_encoder_rel_embedding - - # Assigning - switch_transformersx_encoder_norm = switch_transformersx_model["target"]["encoder"]["encoder_norm"]["scale"] - flax_model.params["encoder"]["final_layer_norm"]["weight"] = switch_transformersx_encoder_norm - - # Decoder - for layer_index in range(config.num_decoder_layers): - layer_name = f"layers_{str(layer_index)}" - - # Self-Attention - switch_transformersx_attention_key = switch_transformersx_model["target"]["decoder"][layer_name][ - "self_attention" - ]["key"]["kernel"] - switch_transformersx_attention_out = switch_transformersx_model["target"]["decoder"][layer_name][ - "self_attention" - ]["out"]["kernel"] - switch_transformersx_attention_query = switch_transformersx_model["target"]["decoder"][layer_name][ - "self_attention" - ]["query"]["kernel"] - switch_transformersx_attention_value = switch_transformersx_model["target"]["decoder"][layer_name][ - "self_attention" - ]["value"]["kernel"] - - # Layer Normalization - switch_transformersx_pre_attention_layer_norm = switch_transformersx_model["target"]["decoder"][layer_name][ - "pre_self_attention_layer_norm" - ]["scale"] - - # Encoder-Decoder-Attention - switch_transformersx_enc_dec_attention_key = switch_transformersx_model["target"]["decoder"][layer_name][ - "encoder_decoder_attention" - ]["key"]["kernel"] - switch_transformersx_enc_dec_attention_out = switch_transformersx_model["target"]["decoder"][layer_name][ - "encoder_decoder_attention" - ]["out"]["kernel"] - switch_transformersx_enc_dec_attention_query = switch_transformersx_model["target"]["decoder"][layer_name][ - "encoder_decoder_attention" - ]["query"]["kernel"] - switch_transformersx_enc_dec_attention_value = switch_transformersx_model["target"]["decoder"][layer_name][ - "encoder_decoder_attention" - ]["value"]["kernel"] - - # Layer Normalization - switch_transformersx_cross_layer_norm = switch_transformersx_model["target"]["decoder"][layer_name][ - "pre_cross_attention_layer_norm" - ]["scale"] - - # MLP - if split_mlp_wi: - switch_transformersx_mlp_wi_0 = switch_transformersx_model["target"]["decoder"][layer_name]["mlp"]["wi_0"][ - "kernel" - ] - switch_transformersx_mlp_wi_1 = switch_transformersx_model["target"]["decoder"][layer_name]["mlp"]["wi_1"][ - "kernel" - ] - else: - switch_transformersx_mlp_wi = switch_transformersx_model["target"]["decoder"][layer_name]["mlp"]["wi"][ - "kernel" - ] - - switch_transformersx_mlp_wo = switch_transformersx_model["target"]["decoder"][layer_name]["mlp"]["wo"][ - "kernel" - ] - - # Layer Normalization - tx5_mlp_layer_norm = switch_transformersx_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"] - - # Assigning - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"][ - "kernel" - ] = switch_transformersx_attention_key - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"][ - "kernel" - ] = switch_transformersx_attention_out - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"][ - "kernel" - ] = switch_transformersx_attention_query - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"][ - "kernel" - ] = switch_transformersx_attention_value - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"][ - "weight" - ] = switch_transformersx_pre_attention_layer_norm - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["k"][ - "kernel" - ] = switch_transformersx_enc_dec_attention_key - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["o"][ - "kernel" - ] = switch_transformersx_enc_dec_attention_out - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["q"][ - "kernel" - ] = switch_transformersx_enc_dec_attention_query - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["v"][ - "kernel" - ] = switch_transformersx_enc_dec_attention_value - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"][ - "weight" - ] = switch_transformersx_cross_layer_norm - - if split_mlp_wi: - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_0"][ - "kernel" - ] = switch_transformersx_mlp_wi_0 - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_1"][ - "kernel" - ] = switch_transformersx_mlp_wi_1 - else: - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi"][ - "kernel" - ] = switch_transformersx_mlp_wi - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wo"][ - "kernel" - ] = switch_transformersx_mlp_wo - - flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["layer_norm"][ - "weight" - ] = tx5_mlp_layer_norm - - # Decoder Normalization - tx5_decoder_norm = switch_transformersx_model["target"]["decoder"]["decoder_norm"]["scale"] - flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm - - # Only for layer 0: - switch_transformersx_decoder_rel_embedding = switch_transformersx_model["target"]["decoder"]["relpos_bias"][ - "rel_embedding" - ].T - flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"][ - "embedding" - ] = switch_transformersx_decoder_rel_embedding - - # Token Embeddings - tx5_token_embeddings = switch_transformersx_model["target"]["token_embedder"]["embedding"] - flax_model.params["shared"]["embedding"] = tx5_token_embeddings - - # LM Head (only in v1.1 checkpoints) - if "logits_dense" in switch_transformersx_model["target"]["decoder"]: - flax_model.params["lm_head"]["kernel"] = switch_transformersx_model["target"]["decoder"]["logits_dense"][ - "kernel" - ] - - flax_model.save_pretrained(flax_dump_folder_path) - print("SwitchTransformersX Model was sucessfully converted!") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--switch_transformersx_checkpoint_path", - default=None, - type=str, - required=True, - help="Path the TX5 checkpoint.", - ) - parser.add_argument( - "--config_name", default=None, type=str, required=True, help="Config name of SwitchTransformers model." - ) - parser.add_argument( - "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model." - ) - args = parser.parse_args() - convert_switch_transformersx_checkpoint_to_flax( - args.switch_transformersx_checkpoint_path, args.config_name, args.flax_dump_folder_path - ) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index feab42150a89d..3a580e703f227 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -18,12 +18,11 @@ import copy import math import warnings -from dataclasses import dataclass, replace -from typing import Any, Optional, Tuple, Union +from dataclasses import dataclass +from typing import Optional, Tuple, Union import torch import torch.nn as nn -from torch import nn from torch.nn import CrossEntropyLoss from torch.utils.checkpoint import checkpoint @@ -167,6 +166,25 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2) +@dataclass +class RouterOutput: + """ + Base class for MoE Routers outputs, with expert indices, together with router probabilities. + + Args: + expert_indices (`torch.LongTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary + loss for Mixture of Experts models. + """ + + expert_indices: torch.LongTensor = None + router_probs: torch.FloatTensor = None + + class TokensChooseMaskedRouter(nn.Module): """ Masked matmul router using tokens choose top-k experts assignment. @@ -281,17 +299,16 @@ def forward(self, token_inputs: torch.Tensor, apply_jitter: bool = True, **kwarg else: padding_mask = None - expert_index, auxiliary_loss = self._compute_routing_instructions(router_probs, padding_mask, **kwargs) - # We cast back the output to the previous dtype - expert_index = expert_index.to(self.input_tokens_dtype) + expert_index = self._compute_routing_instructions(router_probs, padding_mask, **kwargs) + expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts) + router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1) - router_z_loss = router_z_loss_func(router_logits) - return expert_index, auxiliary_loss, router_z_loss, router_probs + # router_z_loss = router_z_loss_func(router_logits) + return expert_index, router_probs, router_logits def _compute_routing_instructions( self, router_probs: torch.Tensor, - padding_mask: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: """ Computes masks for the top-k experts per token. @@ -300,39 +317,13 @@ def _compute_routing_instructions( router_probs (`torch.Tensor`): Router raw probabilities tensor of shape [num_groups, tokens_per_group, num_experts] used to determine the routing of tokens to the experts. - padding_mask (`torch.Tensor`): - Padding mask of shape [num_groups, tokens_per_group] used to identify padding tokens that should be - ignored by the router. - expert_capacity (`int`): - Each group will send this many tokens to each expert. - Returns: Dispatch and combine arrays for routing with masked matmuls. """ - # Top-k router probability and corresponding expert indices for each token. # Shape: [num_groups, tokens_per_group, num_selected_experts]. expert_index = torch.argmax(router_probs, dim=-1) - - if padding_mask is not None: # TODO test or delete - # Mask applied to gate. Exclude choices corresponding to padding tokens. - gate_mask = padding_mask.unsqueeze(-1).to(expert_index.dtype) - expert_gate *= gate_mask - - # Set `expert_index` elements corresponding to padding to negative - # numbers. Negative `expert_index` elements will ultimately be dropped in - # the one_hot conversion to the `expert_mask`. - # First convert nonzero padding elements to negative values. - expert_index *= (2 * gate_mask) - 1 - # Handle zero padding elements by negatively shifting all padding. - expert_index += (gate_mask - 1).repeat(1, 1, self.num_selected_experts) - - # To correctly compute load balancing loss, we also mask out probs. - router_probs *= gate_mask - - auxiliary_loss = load_balancing_loss_func(router_probs, expert_index) - expert_index = torch.nn.functional.one_hot(expert_index, self.num_experts) - return expert_index, auxiliary_loss + return expert_index # Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->SwitchTransformers @@ -404,8 +395,7 @@ def forward(self, hidden_states): class SwitchTransformersLayerFF(nn.Module): r""" - Switch Transformers Feed Forward layer module. This is a wrapper around the Mixture of Experts - module. + Switch Transformers Feed Forward layer module. This is a wrapper around the Mixture of Experts module. Attributes: is_sparse (`bool`): @@ -436,22 +426,21 @@ def forward(self, hidden_states, output_router_logits): forwarded_states = self.mlp(forwarded_states) if isinstance(forwarded_states, tuple): - forwarded_states, router_logits = forwarded_states + forwarded_states, router_tuple = forwarded_states else: - router_logits = None + router_tuple = None output = hidden_states + self.dropout(forwarded_states) - if output_router_logits and router_logits is not None: - output = (output, router_logits) + if output_router_logits and router_tuple is not None: + output = (output, router_tuple) return output class SwitchTransformersSparseMLP(nn.Module): r""" - Implementation of the Switch Transformers Sparse MLP module. - TODO: Add a LOT of details here + Implementation of the Switch Transformers Sparse MLP module. TODO: Add a LOT of details here """ def __init__(self, config: SwitchTransformersConfig, expert_class: nn.Module = SwitchTransformersDenseActDense): @@ -485,20 +474,20 @@ def _get_router(self, config): def forward(self, hidden_states): r""" - Hold on, this will be slightly tricky to understand - In the correct order, a MoE layer does the following: + Hold on, this will be slightly tricky to understand In the correct order, a MoE layer does the following: - 1- Gets the `router_mask` from the router. This mask will contain the indices of the - routed tokens. Also retrieve the probabilities (max prob) for each token. The probabilities are - needed in the computation of the hidden states since the probabilities will be broadcasted - to the hidden states values (they can be interpreted as a scaling factor). + 1- Gets the `router_mask` from the router. This mask will contain the indices of the routed tokens. Also + retrieve the probabilities (max prob) for each token. The probabilities are needed in the computation of the + hidden states since the probabilities will be broadcasted to the hidden states values (they can be interpreted + as a scaling factor). 2- TODO: explain @ArthurZucker """ # Step 1: Get the router_mask from the router as wel as the probabilities - router_mask, auxiliary_loss, router_z_loss, router_probs = self.router(hidden_states) + router_mask, router_probs, router_logits = self.router(hidden_states) + expert_index = torch.argmax(router_mask, dim=-1) for idx, expert in enumerate(self.experts.values()): @@ -510,7 +499,7 @@ def forward(self, hidden_states): hidden_states[token_indices] = expert(hidden_states[token_indices]) hidden_states = router_probs * hidden_states - return hidden_states, router_probs + return hidden_states, (router_logits, expert_index) # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->SwitchTransformers @@ -917,9 +906,9 @@ def forward( hidden_states = self.layer[-1](hidden_states, output_router_logits) if isinstance(hidden_states, tuple): - hidden_states, router_probs = hidden_states + hidden_states, router_tuple = hidden_states else: - router_probs = (None,) + router_tuple = (None,) # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): @@ -933,7 +922,7 @@ def forward( else: outputs = outputs + attention_outputs - outputs = outputs + (router_probs,) + outputs = outputs + (router_tuple,) return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) @@ -1404,8 +1393,8 @@ def custom_forward(*inputs): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. output_router_logits (`bool`, *optional*): - Whether or not to return the logits of all the routers. They are useful for computing the router loss, - and should not be returned during inference. + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -1774,6 +1763,7 @@ def forward( last_hidden_state=encoder_outputs[0], hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + router_probs=encoder_outputs[3] if len(encoder_outputs) > 3 else None, ) hidden_states = encoder_outputs[0] @@ -1829,23 +1819,48 @@ def forward( lm_logits = self.lm_head(sequence_output) loss = None + total_encoder_z_loss = None + total_encoder_aux_loss = None + total_decoder_z_loss = None + total_decoder_aux_loss = None + if labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-100) # todo check in the config if router loss enables if output_router_logits: - router_z_loss = router_z_loss_func(encoder_outputs.router_probs) - decoder_outputs.router_probs + # Calculate the router loss (z_loss + auxiliary loss) for each router in the encoder + total_encoder_z_loss = 0 + total_encoder_aux_loss = 0 + for router_tuple in encoder_outputs.router_probs: + if router_tuple[0] is not None: + total_encoder_z_loss += router_z_loss_func(router_tuple[0]) + total_encoder_aux_loss += load_balancing_loss_func(router_tuple[0], router_tuple[1]) + + total_decoder_z_loss = 0 + total_decoder_aux_loss = 0 + for router_tuple in decoder_outputs.router_probs: + if router_tuple[0] is not None: + total_decoder_z_loss += router_z_loss_func(router_tuple[0]) + total_decoder_aux_loss += load_balancing_loss_func(router_tuple[0], router_tuple[1]) loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 if not return_dict: - output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + output = ( + (lm_logits, total_encoder_z_loss, total_encoder_aux_loss, total_decoder_z_loss, total_decoder_aux_loss) + + decoder_outputs[1:] + + encoder_outputs + ) return ((loss,) + output) if loss is not None else output return Seq2SeqMoEOutput( loss=loss, logits=lm_logits, + encoder_total_z_loss=total_encoder_z_loss, + encoder_total_aux_loss=total_encoder_aux_loss, + decoder_total_z_loss=total_decoder_z_loss, + decoder_total_aux_loss=total_decoder_aux_loss, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, From afb3d376074615ba9de6aadf3f6d21443e6718b1 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 26 Oct 2022 15:38:32 +0200 Subject: [PATCH 037/102] remove `FlaxSwitchTransformers` modules --- docs/source/en/index.mdx | 2 +- .../en/model_doc/switch_transformers.mdx | 20 - src/transformers/__init__.py | 14 - .../models/auto/modeling_flax_auto.py | 3 - .../models/switch_transformers/__init__.py | 27 - .../configuration_switch_transformers.py | 10 +- .../modeling_flax_switch_transformers.py | 1838 ----------------- src/transformers/utils/dummy_flax_objects.py | 28 - .../test_modeling_flax_switch_transformers.py | 1108 ---------- 9 files changed, 6 insertions(+), 3044 deletions(-) delete mode 100644 src/transformers/models/switch_transformers/modeling_flax_switch_transformers.py delete mode 100644 tests/models/switch_transformers/test_modeling_flax_switch_transformers.py diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 2a74d4a5c1f57..1712dd9f11b6f 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -313,7 +313,7 @@ Flax), PyTorch, and/or TensorFlow. | SqueezeBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | Swin Transformer | ❌ | ❌ | ✅ | ✅ | ❌ | | Swin Transformer V2 | ❌ | ❌ | ✅ | ❌ | ❌ | -| SwitchTransformers | ✅ | ✅ | ✅ | ❌ | ✅ | +| SwitchTransformers | ✅ | ✅ | ✅ | ❌ | ❌ | | T5 | ✅ | ✅ | ✅ | ✅ | ✅ | | TAPAS | ✅ | ❌ | ✅ | ✅ | ❌ | | Time Series Transformer | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/switch_transformers.mdx b/docs/source/en/model_doc/switch_transformers.mdx index e602041cb52eb..64979a94ba86e 100644 --- a/docs/source/en/model_doc/switch_transformers.mdx +++ b/docs/source/en/model_doc/switch_transformers.mdx @@ -65,23 +65,3 @@ The original code can be found [here](). - forward - parallelize - deparallelize - - -## FlaxSwitchTransformersModel - -[[autodoc]] FlaxSwitchTransformersModel - - __call__ - - encode - - decode - -## FlaxSwitchTransformersForConditionalGeneration - -[[autodoc]] FlaxSwitchTransformersForConditionalGeneration - - __call__ - - encode - - decode - -## FlaxSwitchTransformersEncoderModel - -[[autodoc]] FlaxSwitchTransformersEncoderModel - - __call__ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 38ceb39cb348c..8f438933c2cfa 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3074,14 +3074,6 @@ _import_structure["models.t5"].extend( ["FlaxT5EncoderModel", "FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"] ) - _import_structure["models.switch_transformers"].extend( - [ - "FlaxSwitchTransformersEncoderModel", - "FlaxSwitchTransformersForConditionalGeneration", - "FlaxSwitchTransformersModel", - "FlaxSwitchTransformersPreTrainedModel", - ] - ) _import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel") _import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"]) _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"]) @@ -5605,12 +5597,6 @@ FlaxRoFormerPreTrainedModel, ) from .models.speech_encoder_decoder import FlaxSpeechEncoderDecoderModel - from .models.switch_transformers import ( - FlaxSwitchTransformersEncoderModel, - FlaxSwitchTransformersForConditionalGeneration, - FlaxSwitchTransformersModel, - FlaxSwitchTransformersPreTrainedModel, - ) from .models.t5 import FlaxT5EncoderModel, FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index ea720ca6b0732..98c5d6fb5a104 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -49,7 +49,6 @@ ("pegasus", "FlaxPegasusModel"), ("roberta", "FlaxRobertaModel"), ("roformer", "FlaxRoFormerModel"), - ("switch_transformers", "FlaxSwitchTransformersModel"), ("t5", "FlaxT5Model"), ("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"), ("vit", "FlaxViTModel"), @@ -72,7 +71,6 @@ ("mt5", "FlaxMT5ForConditionalGeneration"), ("roberta", "FlaxRobertaForMaskedLM"), ("roformer", "FlaxRoFormerForMaskedLM"), - ("switch_transformers", "FlaxSwitchTransformersForConditionalGeneration"), ("t5", "FlaxT5ForConditionalGeneration"), ("wav2vec2", "FlaxWav2Vec2ForPreTraining"), ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"), @@ -107,7 +105,6 @@ ("mbart", "FlaxMBartForConditionalGeneration"), ("mt5", "FlaxMT5ForConditionalGeneration"), ("pegasus", "FlaxPegasusForConditionalGeneration"), - ("switch_transformers", "FlaxSwitchTransformersForConditionalGeneration"), ("t5", "FlaxT5ForConditionalGeneration"), ] ) diff --git a/src/transformers/models/switch_transformers/__init__.py b/src/transformers/models/switch_transformers/__init__.py index 5f5ee32c89f9e..e6fc32117cad9 100644 --- a/src/transformers/models/switch_transformers/__init__.py +++ b/src/transformers/models/switch_transformers/__init__.py @@ -68,20 +68,6 @@ ] -try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["modeling_flax_switch_transformers"] = [ - "FlaxSwitchTransformersEncoderModel", - "FlaxSwitchTransformersForConditionalGeneration", - "FlaxSwitchTransformersModel", - "FlaxSwitchTransformersPreTrainedModel", - ] - - if TYPE_CHECKING: from .configuration_switch_transformers import ( SWITCH_TRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -119,19 +105,6 @@ SwitchTransformersPreTrainedModel, ) - try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .modeling_flax_switch_transformers import ( - FlaxSwitchTransformersEncoderModel, - FlaxSwitchTransformersForConditionalGeneration, - FlaxSwitchTransformersModel, - FlaxSwitchTransformersPreTrainedModel, - ) - else: import sys diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index a2a9f5038df61..2848154da3f75 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -31,11 +31,11 @@ class SwitchTransformersConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`SwitchTransformersModel`] or a - [`FlaxSwitchTransformersModel`]. It is used to instantiate a SwitchTransformers model according to the specified - arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar - configuration to that of the SwitchTransformers - [ybelkada/switch_transformers-base](https://huggingface.co/ybelkada/switch_transformers-base) architecture. + This is the configuration class to store the configuration of a [`SwitchTransformersModel`]. It is used to + instantiate a SwitchTransformers model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the + SwitchTransformers [ybelkada/switch_transformers-base](https://huggingface.co/ybelkada/switch_transformers-base) + architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. diff --git a/src/transformers/models/switch_transformers/modeling_flax_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_flax_switch_transformers.py deleted file mode 100644 index c43086021bfb6..0000000000000 --- a/src/transformers/models/switch_transformers/modeling_flax_switch_transformers.py +++ /dev/null @@ -1,1838 +0,0 @@ -# coding=utf-8 -# Copyright 2022 SwitchTransformers Authors and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Flax SwitchTransformers model.""" - - -import copy -from typing import Callable, Optional, Tuple - -import numpy as np - -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze -from flax.linen import combine_masks, make_causal_mask -from flax.linen import partitioning as nn_partitioning -from flax.linen.attention import dot_product_attention_weights -from flax.traverse_util import flatten_dict, unflatten_dict -from jax.random import PRNGKey - -from ...modeling_flax_outputs import ( - FlaxBaseModelOutput, - FlaxBaseModelOutputWithPastAndCrossAttentions, - FlaxCausalLMOutputWithCrossAttentions, - FlaxSeq2SeqLMOutput, - FlaxSeq2SeqModelOutput, -) -from ...modeling_flax_utils import ( - ACT2FN, - FlaxPreTrainedModel, - append_call_sample_docstring, - append_replace_return_docstrings, - overwrite_call_docstring, -) -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_switch_transformers import SwitchTransformersConfig - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "ybelkada/switch_transformers-base" -_CONFIG_FOR_DOC = "SwitchTransformersConfig" -_TOKENIZER_FOR_DOC = "SwitchTransformersTokenizer" - -remat = nn_partitioning.remat - - -# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray: - """ - Shift input ids one token to the right. - """ - shifted_input_ids = np.zeros_like(input_ids) - shifted_input_ids[:, 1:] = input_ids[:, :-1] - shifted_input_ids[:, 0] = decoder_start_token_id - - shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) - return shifted_input_ids - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerNorm with T5->SwitchTransformers -class FlaxSwitchTransformersLayerNorm(nn.Module): - hidden_size: int - dtype: jnp.dtype = jnp.float32 - eps: float = 1e-6 - weight_init: Callable[..., np.ndarray] = jax.nn.initializers.ones - - def setup(self): - self.weight = self.param("weight", self.weight_init, (self.hidden_size,)) - - def __call__(self, hidden_states): - """ - Construct a layernorm module in the SwitchTransformers style; No bias and no subtraction of mean. - """ - # layer norm should always be calculated in float32 - variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True) - hidden_states = hidden_states / jnp.sqrt(variance + self.eps) - - return self.weight * hidden_states - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5DenseActDense with T5->SwitchTransformers -class FlaxSwitchTransformersDenseActDense(nn.Module): - config: SwitchTransformersConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) - wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) - - self.wi = nn.Dense( - self.config.d_ff, - use_bias=False, - kernel_init=jax.nn.initializers.normal(wi_init_std), - dtype=self.dtype, - ) - self.wo = nn.Dense( - self.config.d_model, - use_bias=False, - kernel_init=jax.nn.initializers.normal(wo_init_std), - dtype=self.dtype, - ) - self.dropout = nn.Dropout(self.config.dropout_rate) - self.act = ACT2FN[self.config.dense_act_fn] - - def __call__(self, hidden_states, deterministic=True): - hidden_states = self.wi(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.wo(hidden_states) - return hidden_states - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5DenseGatedActDense with T5->SwitchTransformers -class FlaxSwitchTransformersDenseGatedActDense(nn.Module): - config: SwitchTransformersConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5) - wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5) - - self.wi_0 = nn.Dense( - self.config.d_ff, - use_bias=False, - kernel_init=jax.nn.initializers.normal(wi_init_std), - dtype=self.dtype, - ) - self.wi_1 = nn.Dense( - self.config.d_ff, - use_bias=False, - kernel_init=jax.nn.initializers.normal(wi_init_std), - dtype=self.dtype, - ) - self.wo = nn.Dense( - self.config.d_model, - use_bias=False, - kernel_init=jax.nn.initializers.normal(wo_init_std), - dtype=self.dtype, - ) - self.dropout = nn.Dropout(self.config.dropout_rate) - self.act = ACT2FN[self.config.dense_act_fn] - - def __call__(self, hidden_states, deterministic): - hidden_gelu = self.act(self.wi_0(hidden_states)) - hidden_linear = self.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.wo(hidden_states) - return hidden_states - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerFF with T5->SwitchTransformers -class FlaxSwitchTransformersLayerFF(nn.Module): - config: SwitchTransformersConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - if self.config.is_gated_act: - self.DenseReluDense = FlaxSwitchTransformersDenseGatedActDense(self.config, dtype=self.dtype) - else: - self.DenseReluDense = FlaxSwitchTransformersDenseActDense(self.config, dtype=self.dtype) - - self.layer_norm = FlaxSwitchTransformersLayerNorm( - self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype - ) - self.dropout = nn.Dropout(self.config.dropout_rate) - - def __call__(self, hidden_states, deterministic=True): - forwarded_states = self.layer_norm(hidden_states) - forwarded_states = self.DenseReluDense(forwarded_states, deterministic=deterministic) - hidden_states = hidden_states + self.dropout(forwarded_states, deterministic=deterministic) - return hidden_states - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Attention with T5->SwitchTransformers -class FlaxSwitchTransformersAttention(nn.Module): - config: SwitchTransformersConfig - has_relative_attention_bias: bool = False - causal: bool = False - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.relative_attention_num_buckets = self.config.relative_attention_num_buckets - self.relative_attention_max_distance = self.config.relative_attention_max_distance - self.d_model = self.config.d_model - self.key_value_proj_dim = self.config.d_kv - self.n_heads = self.config.num_heads - self.dropout = self.config.dropout_rate - self.inner_dim = self.n_heads * self.key_value_proj_dim - - q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5) - kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) - o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5) - - self.q = nn.Dense( - self.inner_dim, - use_bias=False, - kernel_init=jax.nn.initializers.normal(q_init_std), - dtype=self.dtype, - ) - self.k = nn.Dense( - self.inner_dim, - use_bias=False, - kernel_init=jax.nn.initializers.normal(kv_init_std), - dtype=self.dtype, - ) - self.v = nn.Dense( - self.inner_dim, - use_bias=False, - kernel_init=jax.nn.initializers.normal(kv_init_std), - dtype=self.dtype, - ) - self.o = nn.Dense( - self.d_model, - use_bias=False, - kernel_init=jax.nn.initializers.normal(o_init_std), - dtype=self.dtype, - ) - - if self.has_relative_attention_bias: - self.relative_attention_bias = nn.Embed( - self.relative_attention_num_buckets, - self.n_heads, - embedding_init=jax.nn.initializers.normal(kv_init_std), - ) - - @staticmethod - def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): - """ - Adapted from Mesh Tensorflow: - https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - - Translate relative position to a bucket number for relative attention. The relative position is defined as - memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for - small absolute relative_position and larger buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. - This should allow for more graceful generalization to longer sequences than the model has been trained on - """ - relative_buckets = 0 - if bidirectional: - num_buckets //= 2 - relative_buckets += (relative_position > 0) * num_buckets - relative_position = jnp.abs(relative_position) - else: - relative_position = -jnp.clip(relative_position, a_max=0) - # now relative_position is in the range [0, inf) - - # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 - is_small = relative_position < max_exact - - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - relative_position_if_large = max_exact + ( - jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact) - ) - relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1) - - relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large) - - return relative_buckets.astype("i4") - - def compute_bias(self, query_length, key_length): - """Compute binned relative position bias""" - context_position = jnp.arange(query_length, dtype="i4")[:, None] - memory_position = jnp.arange(key_length, dtype="i4")[None, :] - - relative_position = memory_position - context_position - relative_position_bucket = self._relative_position_bucket( - relative_position, - bidirectional=(not self.causal), - num_buckets=self.relative_attention_num_buckets, - max_distance=self.relative_attention_max_distance, - ) - - values = self.relative_attention_bias(relative_position_bucket) - values = values.transpose((2, 0, 1))[None, :, :, :] - return values - - def _split_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim)) - - def _merge_heads(self, hidden_states): - return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,)) - - @nn.compact - def _concatenate_to_cache(self, key, value, query, attention_mask): - """ - This function takes projected key, value states from a single input token and concatenates the states to cached - states from previous steps. This function is slighly adapted from the official Flax repository: - https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 - """ - # detect if we're initializing by absence of existing cache data. - is_initialized = self.has_variable("cache", "cached_key") - cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) - cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) - cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) - - if is_initialized: - *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape - # update key, value caches with our new 1d spatial slices - cur_index = cache_index.value - indices = (0,) * len(batch_dims) + (cur_index, 0, 0) - key = jax.lax.dynamic_update_slice(cached_key.value, key, indices) - value = jax.lax.dynamic_update_slice(cached_value.value, value, indices) - cached_key.value = key - cached_value.value = value - num_updated_cache_vectors = query.shape[1] - cache_index.value = cache_index.value + num_updated_cache_vectors - # causal mask for cached decoder self-attention: our single query position should only attend to those key positions - # that have already been generated and cached, not the remaining zero elements. - pad_mask = jnp.broadcast_to( - jnp.arange(max_length) < cur_index + num_updated_cache_vectors, - tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), - ) - attention_mask = combine_masks(pad_mask, attention_mask) - return key, value, attention_mask - - def _create_position_bias( - self, key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift - ): - cache_is_filled = self.causal and self.has_variable("cache", "cached_key") and (not init_cache) - key_length = key_states.shape[1] - query_length = key_length if cache_is_filled else query_states.shape[1] - - if self.has_relative_attention_bias: - position_bias = self.compute_bias(query_length, key_length) - elif attention_mask is not None: - position_bias = jnp.zeros_like(attention_mask) - else: - position_bias = jnp.zeros((1, self.n_heads, query_length, key_length), dtype=self.dtype) - - # if key and values are already calculated, only the last query position bias should be taken - if cache_is_filled: - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - position_bias = jax.lax.dynamic_slice( - position_bias, - (0, 0, causal_attention_mask_shift, 0), - (1, self.n_heads, seq_length, max_decoder_length), - ) - return position_bias - - def __call__( - self, - hidden_states, - attention_mask=None, - key_value_states=None, - position_bias=None, - use_cache=False, - output_attentions=False, - deterministic=True, - init_cache=False, - ): - """ - Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). - """ - batch_size, seq_length = hidden_states.shape[:2] - - # q, k, v projections - query_states = self.q(hidden_states) # (batch_size, n_heads, seq_length, dim_per_head) - key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states) - value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states) - - # reshape to (batch_size, seq_length, n_heads, head_dim) - query_states = self._split_heads(query_states) - key_states = self._split_heads(key_states) - value_states = self._split_heads(value_states) - - # counter-act scaling in dot_product_attention_weights function - query_states *= jnp.sqrt(query_states.shape[-1]) - - # for fast decoding causal attention mask should be shifted - causal_attention_mask_shift = ( - self.variables["cache"]["cache_index"] if (self.has_variable("cache", "cached_key") and self.causal) else 0 - ) - # create causal attention_mask; attention_mask has to be defined when model is causal - if self.causal: - causal_attention_mask = make_causal_mask(attention_mask, dtype="bool") - - # fast decoding for generate requires special attention_mask - if self.has_variable("cache", "cached_key"): - max_decoder_length = self.variables["cache"]["cached_key"].shape[1] - causal_attention_mask = jax.lax.dynamic_slice( - causal_attention_mask, - (0, 0, causal_attention_mask_shift, 0), - (1, 1, seq_length, max_decoder_length), - ) - - # broadcast causal attention mask & attention mask to fit for merge - causal_attention_mask = jnp.broadcast_to( - causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:] - ) - attention_mask = jnp.broadcast_to( - jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape - ) - attention_mask = combine_masks(attention_mask, causal_attention_mask) - elif attention_mask is not None: - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - - # During fast autoregressive decoding, we feed one position at a time, - # and cache the keys and values step by step. - if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_attention_mask = self._concatenate_to_cache( - key_states, value_states, query_states, attention_mask - ) - - # replace masked positions with -10_000 - if attention_mask is not None: - mask_value = jnp.finfo(self.dtype).min - attention_mask = jax.lax.select( - attention_mask > 0, - jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, mask_value).astype(self.dtype), - ) - - if position_bias is None: - # compute position bias (only for first layer) - position_bias = self._create_position_bias( - key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift - ) - - if attention_mask is not None: - position_bias = position_bias + attention_mask - - # create dropout rng - dropout_rng = None - if not deterministic and self.dropout > 0.0: - dropout_rng = self.make_rng("dropout") - - # Softmax(QK^T) - attn_weights = dot_product_attention_weights( - query_states, - key_states, - bias=position_bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout, - broadcast_dropout=True, - deterministic=deterministic, - dtype=self.dtype, - ) - - # multiply with value states - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) - - # bring back to (batch_size, seq_length, d_model) - attn_output = self._merge_heads(attn_output) - - # apply output matrix - attn_output = self.o(attn_output) - - outputs = (attn_output, position_bias) - - if output_attentions: - outputs = outputs + (attn_weights,) - - return outputs - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerSelfAttention with T5->SwitchTransformers -class FlaxSwitchTransformersLayerSelfAttention(nn.Module): - config: SwitchTransformersConfig - has_relative_attention_bias: bool = False - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.SelfAttention = FlaxSwitchTransformersAttention( - self.config, - has_relative_attention_bias=self.has_relative_attention_bias, - causal=self.config.causal, - dtype=self.dtype, - ) - self.layer_norm = FlaxSwitchTransformersLayerNorm( - self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype - ) - self.dropout = nn.Dropout(self.config.dropout_rate) - - def __call__( - self, - hidden_states, - attention_mask=None, - position_bias=None, - output_attentions=False, - deterministic=True, - init_cache=False, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.SelfAttention( - normed_hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - output_attentions=output_attentions, - deterministic=deterministic, - init_cache=init_cache, - ) - hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) - outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them - return outputs - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerCrossAttention with T5->SwitchTransformers -class FlaxSwitchTransformersLayerCrossAttention(nn.Module): - config: SwitchTransformersConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.EncDecAttention = FlaxSwitchTransformersAttention( - self.config, has_relative_attention_bias=False, causal=False, dtype=self.dtype - ) - self.layer_norm = FlaxSwitchTransformersLayerNorm( - self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype - ) - self.dropout = nn.Dropout(self.config.dropout_rate) - - def __call__( - self, - hidden_states, - key_value_states, - attention_mask=None, - position_bias=None, - output_attentions=False, - deterministic=True, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.EncDecAttention( - normed_hidden_states, - attention_mask=attention_mask, - key_value_states=key_value_states, - position_bias=position_bias, - output_attentions=output_attentions, - ) - hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic) - outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them - return outputs - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Block with T5->SwitchTransformers -class FlaxSwitchTransformersBlock(nn.Module): - config: SwitchTransformersConfig - has_relative_attention_bias: bool = False - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.causal = self.config.causal - self.layer = ( - FlaxSwitchTransformersLayerSelfAttention( - self.config, - has_relative_attention_bias=self.has_relative_attention_bias, - name=str(0), - dtype=self.dtype, - ), - ) - feed_forward_index = 1 - if self.causal: - self.layer += (FlaxSwitchTransformersLayerCrossAttention(self.config, name=str(1), dtype=self.dtype),) - feed_forward_index += 1 - - self.layer += (FlaxSwitchTransformersLayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),) - - def __call__( - self, - hidden_states, - attention_mask=None, - position_bias=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - encoder_decoder_position_bias=None, - output_attentions=False, - return_dict=True, - deterministic=True, - init_cache=False, - ): - self_attention_outputs = self.layer[0]( - hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - output_attentions=output_attentions, - deterministic=deterministic, - init_cache=init_cache, - ) - hidden_states = self_attention_outputs[0] - attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights - - do_cross_attention = self.causal and encoder_hidden_states is not None - if do_cross_attention: - cross_attention_outputs = self.layer[1]( - hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - position_bias=encoder_decoder_position_bias, - output_attentions=output_attentions, - deterministic=deterministic, - ) - hidden_states = cross_attention_outputs[0] - - # Keep cross-attention outputs and relative position weights - attention_outputs = attention_outputs + cross_attention_outputs[1:] - - # Apply Feed Forward layer - hidden_states = self.layer[-1](hidden_states, deterministic=deterministic) - - outputs = (hidden_states,) - - outputs = outputs + attention_outputs - - # returns hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), - # (cross-attention position bias), (cross-attention weights) - return outputs - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerCollection with T5->SwitchTransformers -class FlaxSwitchTransformersLayerCollection(nn.Module): - config: SwitchTransformersConfig - has_relative_attention_bias: bool - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - self.layer = FlaxSwitchTransformersBlock( - self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype - ) - - def __call__( - self, - hidden_states, - attention_mask=None, - position_bias=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - encoder_decoder_position_bias=None, - output_attentions=False, - deterministic=True, - init_cache=False, - ): - return self.layer( - hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - output_attentions=output_attentions, - deterministic=deterministic, - init_cache=init_cache, - ) - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5BlockCollection with T5->SwitchTransformers -class FlaxSwitchTransformersBlockCollection(nn.Module): - config: SwitchTransformersConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - self.causal = self.config.causal - if self.gradient_checkpointing: - FlaxSwitchTransformersCheckpointLayer = remat( - FlaxSwitchTransformersLayerCollection, static_argnums=(6, 7, 8) - ) - self.blocks = [ - FlaxSwitchTransformersCheckpointLayer( - self.config, - has_relative_attention_bias=(i == 0), - dtype=self.dtype, - name=str(i), - ) - for i in range(self.config.num_layers) - ] - else: - self.blocks = [ - FlaxSwitchTransformersLayerCollection( - self.config, - has_relative_attention_bias=(i == 0), - dtype=self.dtype, - name=str(i), - ) - for i in range(self.config.num_layers) - ] - - def __call__( - self, - hidden_states=None, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - output_attentions: bool = False, - output_hidden_states: bool = False, - deterministic: bool = True, - init_cache: bool = False, - ): - # Prepare head mask if needed - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - all_cross_attentions = () if (output_attentions and self.causal) else None - position_bias = None - encoder_decoder_position_bias = None - - for i, layer_module in enumerate(self.blocks): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states, - attention_mask, - position_bias, - encoder_hidden_states, - encoder_attention_mask, - encoder_decoder_position_bias, - output_attentions, - deterministic, - init_cache, - ) - - hidden_states = layer_outputs[0] - - # We share the position biases between the layers - the first layer store them - # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), - # (cross-attention position bias), (cross-attention weights) - position_bias = layer_outputs[1] - - if self.causal and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[2],) - if self.causal: - all_cross_attentions = all_cross_attentions + (layer_outputs[4],) - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_attentions, - cross_attentions=all_cross_attentions, - ) - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Stack with T5->SwitchTransformers -class FlaxSwitchTransformersStack(nn.Module): - config: SwitchTransformersConfig - embed_tokens: nn.Embed - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - self.causal = self.config.causal - - self.block = FlaxSwitchTransformersBlockCollection( - self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - self.final_layer_norm = FlaxSwitchTransformersLayerNorm( - self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype - ) - self.dropout = nn.Dropout(self.config.dropout_rate) - - def __call__( - self, - input_ids=None, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - deterministic: bool = True, - init_cache: bool = False, - ): - hidden_states = self.embed_tokens(input_ids) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - - outputs = self.block( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - deterministic=deterministic, - init_cache=init_cache, - ) - - hidden_states = outputs[0] - - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic=deterministic) - - # Add last layer - all_hidden_states = None - - if output_hidden_states: - all_hidden_states = outputs.hidden_states - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - if output_hidden_states: - return ( - hidden_states, - all_hidden_states, - ) + outputs[2:] - return (hidden_states,) + outputs[1:] - - return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - -SWITCH_TRANSFORMERS_ENCODE_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. SWITCH_TRANSFORMERS is a model with relative position - embeddings so you should be able to pad the inputs on both the right and the left. - - Indices can be obtained using [`SwitchTransformersTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for detail. - - To know more on how to prepare `input_ids` for pretraining take a look a [SWITCH_TRANSFORMERS - Training](./switch_transformers#training). - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - -SWITCH_TRANSFORMERS_DECODE_INPUTS_DOCSTRING = r""" - Args: - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`SwitchTransformersTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - For training, `decoder_input_ids` should be provided. - encoder_outputs (`tuple(tuple(jnp.ndarray)`): - Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of - hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - - If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the - paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -SWITCH_TRANSFORMERS_INPUTS_DOCSTRING = r""" - Args: - input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. SWITCH_TRANSFORMERS is a model with relative position - embeddings so you should be able to pad the inputs on both the right and the left. - - Indices can be obtained using [`SwitchTransformersTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for detail. - - [What are input IDs?](../glossary#input-ids) - - To know more on how to prepare `input_ids` for pretraining take a look a [SWITCH_TRANSFORMERS - Training](./switch_transformers#training). - attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using [`SwitchTransformersTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are decoder input IDs?](../glossary#decoder-input-ids) - - SWITCH_TRANSFORMERS uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If - `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - To know more on how to prepare `decoder_input_ids` for pretraining take a look at [SWITCH_TRANSFORMERS - Training](./switch_transformers#training). - decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - encoder_outputs (`tuple(tuple(jnp.ndarray)`, *optional*): - Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at - the output of the last layer of the encoder. Used in the cross-attention of the decoder. - past_key_values (`tuple(tuple(jnp.ndarray))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -class FlaxSwitchTransformersPreTrainedModel(FlaxPreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = SwitchTransformersConfig - base_model_prefix = "transformer" - module_class: nn.Module = None - - def __init__( - self, - config: SwitchTransformersConfig, - input_shape: Tuple[int] = (1, 1), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - gradient_checkpointing: bool = False, - **kwargs - ): - module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def enable_gradient_checkpointing(self): - self._module = self.module_class( - config=self.config, - dtype=self.dtype, - gradient_checkpointing=True, - ) - - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensors - input_ids = jnp.zeros(input_shape, dtype="i4") - - attention_mask = jnp.ones_like(input_ids) - args = [input_ids, attention_mask] - if self.module_class not in [FlaxSwitchTransformersEncoderModule]: - decoder_input_ids = jnp.ones_like(input_ids) - decoder_attention_mask = jnp.ones_like(input_ids) - args.extend([decoder_input_ids, decoder_attention_mask]) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init( - rngs, - *args, - )["params"] - - if params is not None: - random_params = flatten_dict(unfreeze(random_params)) - params = flatten_dict(unfreeze(params)) - for missing_key in self._missing_keys: - params[missing_key] = random_params[missing_key] - self._missing_keys = set() - return freeze(unflatten_dict(params)) - else: - return random_params - - @add_start_docstrings_to_model_forward(SWITCH_TRANSFORMERS_INPUTS_DOCSTRING) - def __call__( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - decoder_input_ids: jnp.ndarray = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: dict = None, - dropout_rng: PRNGKey = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if decoder_input_ids is None: - raise ValueError( - "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed" - " here." - ) - - # prepare encoder inputs - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - # prepare decoder inputs - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - - # Handle any PRNG if needed - rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - ) - - def init_cache(self, batch_size, max_length, encoder_outputs): - r""" - Args: - batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized - cache. - encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): - `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: - `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) - is a sequence of hidden-states at the output of the last layer of the encoder. Used in the - cross-attention of the decoder. - """ - # init input variables to retrieve cache - decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") - decoder_attention_mask = jnp.ones_like(decoder_input_ids) - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - **kwargs, - ) - - init_variables = self.module.init( - jax.random.PRNGKey(0), - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - encoder_hidden_states=encoder_outputs[0], - init_cache=True, - method=_decoder_forward, # we only need to call the decoder to init the cache - ) - return unfreeze(init_variables["cache"]) - - @add_start_docstrings(SWITCH_TRANSFORMERS_ENCODE_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=SwitchTransformersConfig) - def encode( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: dict = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import SwitchTransformersTokenizer, FlaxSwitchTransformersForConditionalGeneration - - >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switch_transformers-base") - >>> model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switch_transformers-base") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, return_tensors="np") - >>> encoder_outputs = model.encode(**inputs) - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - def _encoder_forward(module, input_ids, attention_mask, **kwargs): - encode_module = module._get_encoder_module() - return encode_module(input_ids, attention_mask, **kwargs) - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - method=_encoder_forward, - ) - - @add_start_docstrings(SWITCH_TRANSFORMERS_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=SwitchTransformersConfig - ) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - past_key_values: dict = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: dict = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import SwitchTransformersTokenizer, FlaxSwitchTransformersForConditionalGeneration - >>> import jax.numpy as jnp - - >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switch_transformers-base") - >>> model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switch_transformers-base") - - >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, return_tensors="np") - >>> encoder_outputs = model.encode(**inputs) - - >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> logits = outputs.logits - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxSwitchTransformersAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): - decoder_module = module._get_decoder_module() - return decoder_module( - decoder_input_ids, - decoder_attention_mask, - **kwargs, - ) - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs, past = outputs - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs, past = outputs - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - -SWITCH_TRANSFORMERS_START_DOCSTRING = r""" - The SWITCH_TRANSFORMERS model was proposed in [Exploring the Limits of Transfer Learning with a Unified - Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine - Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer - pre-trained in a text-to-text denoising generative setting. - - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a Flax Linen - [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a - regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - config ([`SwitchTransformersConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. - dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. -""" - - -@add_start_docstrings( - "The bare SWITCH_TRANSFORMERS Model transformer outputting raw hidden-stateswithout any specific head on top.", - SWITCH_TRANSFORMERS_START_DOCSTRING, -) -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Module with T5->SwitchTransformers -class FlaxSwitchTransformersModule(nn.Module): - config: SwitchTransformersConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def _get_encoder_module(self): - return self.encoder - - def _get_decoder_module(self): - return self.decoder - - def setup(self): - self.shared = nn.Embed( - self.config.vocab_size, - self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0), - ) - - encoder_config = copy.deepcopy(self.config) - encoder_config.causal = False - self.encoder = FlaxSwitchTransformersStack( - encoder_config, - embed_tokens=self.shared, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - - decoder_config = copy.deepcopy(self.config) - decoder_config.causal = True - decoder_config.num_layers = self.config.num_decoder_layers - self.decoder = FlaxSwitchTransformersStack( - decoder_config, - embed_tokens=self.shared, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - - def __call__( - self, - input_ids=None, - attention_mask=None, - decoder_input_ids=None, - decoder_attention_mask=None, - encoder_outputs=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - deterministic: bool = True, - ): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # Encode if needed (training, first prediction pass) - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - # Decode - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return FlaxSeq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Model with T5->SwitchTransformers -class FlaxSwitchTransformersModel(FlaxSwitchTransformersPreTrainedModel): - module_class = FlaxSwitchTransformersModule - - -append_call_sample_docstring( - FlaxSwitchTransformersModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC -) - -FLAX_SWITCH_TRANSFORMERS_MODEL_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> from transformers import SwitchTransformersTokenizer, FlaxSwitchTransformersModel - - >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switch_transformers-base") - >>> model = FlaxSwitchTransformersModel.from_pretrained("ybelkada/switch_transformers-base") - - >>> input_ids = tokenizer( - ... "Studies have been shown that owning a dog is good for you", return_tensors="np" - ... ).input_ids - >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids - - >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for SwitchTransformersModel. - >>> # This is not needed for torch's SwitchTransformersForConditionalGeneration as it does this internally using labels arg. - >>> decoder_input_ids = model._shift_right(decoder_input_ids) - - >>> # forward pass - >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) - >>> last_hidden_states = outputs.last_hidden_state - ``` -""" - - -overwrite_call_docstring( - FlaxSwitchTransformersModel, SWITCH_TRANSFORMERS_INPUTS_DOCSTRING + FLAX_SWITCH_TRANSFORMERS_MODEL_DOCSTRING -) -append_replace_return_docstrings( - FlaxSwitchTransformersModel, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC -) - - -@add_start_docstrings( - "The bare SWITCH_TRANSFORMERS Model transformer outputting encoder's raw hidden-states without any specific head" - " on top.", - SWITCH_TRANSFORMERS_START_DOCSTRING, -) -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5EncoderModule with T5->SwitchTransformers -class FlaxSwitchTransformersEncoderModule(nn.Module): - config: SwitchTransformersConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def setup(self): - self.shared = nn.Embed( - self.config.vocab_size, - self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0), - ) - - encoder_config = copy.deepcopy(self.config) - encoder_config.is_decoder = False - encoder_config.is_encoder_decoder = False - encoder_config.causal = False - self.encoder = FlaxSwitchTransformersStack( - encoder_config, - embed_tokens=self.shared, - dtype=self.dtype, - gradient_checkpointing=self.gradient_checkpointing, - ) - - def __call__( - self, - input_ids=None, - attention_mask=None, - output_attentions=False, - output_hidden_states=False, - return_dict: bool = True, - deterministic: bool = True, - ): - - # Encode if needed (training, first prediction pass) - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - return encoder_outputs - - -class FlaxSwitchTransformersEncoderModel(FlaxSwitchTransformersPreTrainedModel): - module_class = FlaxSwitchTransformersEncoderModule - - @add_start_docstrings_to_model_forward(SWITCH_TRANSFORMERS_ENCODE_INPUTS_DOCSTRING) - def __call__( - self, - input_ids: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: dict = None, - dropout_rng: PRNGKey = None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - # prepare encoder inputs - if attention_mask is None: - attention_mask = jnp.ones_like(input_ids) - - # Handle any PRNG if needed - rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - - return self.module.apply( - {"params": params or self.params}, - input_ids=jnp.array(input_ids, dtype="i4"), - attention_mask=jnp.array(attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - ) - - -@add_start_docstrings( - """SWITCH_TRANSFORMERS Model with a `language modeling` head on top.""", SWITCH_TRANSFORMERS_START_DOCSTRING -) -# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5ForConditionalGenerationModule with T5->SwitchTransformers -class FlaxSwitchTransformersForConditionalGenerationModule(nn.Module): - config: SwitchTransformersConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - gradient_checkpointing: bool = False - - def _get_encoder_module(self): - return self.encoder - - def _get_decoder_module(self): - return self.decoder - - def setup(self): - self.model_dim = self.config.d_model - - self.shared = nn.Embed( - self.config.vocab_size, - self.config.d_model, - embedding_init=jax.nn.initializers.normal(self.config.initializer_factor), - ) - - encoder_config = copy.deepcopy(self.config) - encoder_config.causal = False - encoder_config.use_cache = False - encoder_config.is_encoder_decoder = False - self.encoder = FlaxSwitchTransformersStack( - encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - - decoder_config = copy.deepcopy(self.config) - decoder_config.causal = True - decoder_config.is_encoder_decoder = False - decoder_config.num_layers = self.config.num_decoder_layers - self.decoder = FlaxSwitchTransformersStack( - decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing - ) - - self.lm_head = nn.Dense( - self.config.vocab_size, - use_bias=False, - kernel_init=jax.nn.initializers.normal(self.config.initializer_factor), - dtype=self.dtype, - ) - - def __call__( - self, - input_ids=None, - attention_mask=None, - decoder_input_ids=None, - decoder_attention_mask=None, - encoder_outputs=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - deterministic: bool = True, - ): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # Encode - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - hidden_states = encoder_outputs[0] - - # Decode - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=hidden_states, - encoder_attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=deterministic, - ) - - sequence_output = decoder_outputs[0] - - if self.config.tie_word_embeddings: - # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 - sequence_output = sequence_output * (self.model_dim**-0.5) - - if self.config.tie_word_embeddings: - shared_embedding = self.shared.variables["params"]["embedding"] - lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) - else: - lm_logits = self.lm_head(sequence_output) - - if not return_dict: - return (lm_logits,) + decoder_outputs[1:] + encoder_outputs - - return FlaxSeq2SeqLMOutput( - logits=lm_logits, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - -class FlaxSwitchTransformersForConditionalGeneration(FlaxSwitchTransformersPreTrainedModel): - module_class = FlaxSwitchTransformersForConditionalGenerationModule - - @add_start_docstrings(SWITCH_TRANSFORMERS_DECODE_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=SwitchTransformersConfig - ) - def decode( - self, - decoder_input_ids, - encoder_outputs, - encoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_attention_mask: Optional[jnp.ndarray] = None, - past_key_values: dict = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - train: bool = False, - params: dict = None, - dropout_rng: PRNGKey = None, - ): - r""" - Returns: - - Example: - - ```python - >>> from transformers import SwitchTransformersTokenizer, FlaxSwitchTransformersForConditionalGeneration - >>> import jax.numpy as jnp - - >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switch_transformers-base") - >>> model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switch_transformers-base") - - >>> text = "summarize: My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, return_tensors="np") - >>> encoder_outputs = model.encode(**inputs) - - >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id - - >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> logits = outputs.logits - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.return_dict - - encoder_hidden_states = encoder_outputs[0] - if encoder_attention_mask is None: - batch_size, sequence_length = encoder_hidden_states.shape[:2] - encoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - - # Handle any PRNG if needed - rngs = {} - if dropout_rng is not None: - rngs["dropout"] = dropout_rng - - inputs = {"params": params or self.params} - - # if past_key_values are passed then cache is already initialized a private flag init_cache has to be - # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that - # it can be changed by FlaxSwitchTransformersAttention module - if past_key_values: - inputs["cache"] = past_key_values - mutable = ["cache"] - else: - mutable = False - - def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs): - decoder_module = module._get_decoder_module() - decoder_outputs = decoder_module( - decoder_input_ids, - decoder_attention_mask, - **kwargs, - ) - - sequence_output = decoder_outputs[0] - - if self.config.tie_word_embeddings: - # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 - sequence_output = sequence_output * (self.config.d_model**-0.5) - - if self.config.tie_word_embeddings: - shared_embedding = module.shared.variables["params"]["embedding"] - lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output) - else: - lm_logits = module.lm_head(sequence_output) - - return lm_logits, decoder_outputs - - outputs = self.module.apply( - inputs, - decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), - decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - deterministic=not train, - rngs=rngs, - mutable=mutable, - method=_decoder_forward, - ) - - if past_key_values is None: - lm_logits, decoder_outputs = outputs - else: - (lm_logits, decoder_outputs), past = outputs - - if return_dict: - outputs = FlaxCausalLMOutputWithCrossAttentions( - logits=lm_logits, - hidden_states=decoder_outputs.hidden_states, - attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - ) - else: - outputs = (lm_logits,) + decoder_outputs[1:] - - # add updated cache to model output - if past_key_values is not None and return_dict: - outputs["past_key_values"] = unfreeze(past["cache"]) - return outputs - elif past_key_values is not None and not return_dict: - outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] - - return outputs - - def prepare_inputs_for_generation( - self, - decoder_input_ids, - max_length, - attention_mask: Optional[jnp.DeviceArray] = None, - decoder_attention_mask: Optional[jnp.DeviceArray] = None, - encoder_outputs=None, - **kwargs - ): - # initializing the cache - batch_size, seq_length = decoder_input_ids.shape - - past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. - # But since the decoder uses a causal mask, those positions are masked anyways. - # Thus we can create a single static attention_mask here, which is more efficient for compilation - extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") - if decoder_attention_mask is not None: - extended_attention_mask = jax.lax.dynamic_update_slice( - extended_attention_mask, decoder_attention_mask, (0, 0) - ) - - return { - "past_key_values": past_key_values, - "encoder_outputs": encoder_outputs, - "encoder_attention_mask": attention_mask, - "decoder_attention_mask": extended_attention_mask, - } - - def update_inputs_for_generation(self, model_outputs, model_kwargs): - model_kwargs["past_key_values"] = model_outputs.past_key_values - return model_kwargs - - -FLAX_SWITCH_TRANSFORMERS_CONDITIONAL_GENERATION_DOCSTRING = """ - Returns: - - Example: - - ```python - >>> from transformers import SwitchTransformersTokenizer, FlaxSwitchTransformersForConditionalGeneration - - >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switch_transformers-base") - >>> model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switch_transformers-base") - - >>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs." - >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors="np") - - >>> # Generate Summary - >>> summary_ids = model.generate(inputs["input_ids"]).sequences - >>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)) - ``` -""" - - -overwrite_call_docstring( - FlaxSwitchTransformersForConditionalGeneration, - SWITCH_TRANSFORMERS_INPUTS_DOCSTRING + FLAX_SWITCH_TRANSFORMERS_CONDITIONAL_GENERATION_DOCSTRING, -) -append_replace_return_docstrings( - FlaxSwitchTransformersForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC -) diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index f26f9f2625138..953808dab8ad7 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -977,34 +977,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) -class FlaxSwitchTransformersEncoderModel(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - -class FlaxSwitchTransformersForConditionalGeneration(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - -class FlaxSwitchTransformersModel(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - -class FlaxSwitchTransformersPreTrainedModel(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - class FlaxT5EncoderModel(metaclass=DummyObject): _backends = ["flax"] diff --git a/tests/models/switch_transformers/test_modeling_flax_switch_transformers.py b/tests/models/switch_transformers/test_modeling_flax_switch_transformers.py deleted file mode 100644 index e714397770a36..0000000000000 --- a/tests/models/switch_transformers/test_modeling_flax_switch_transformers.py +++ /dev/null @@ -1,1108 +0,0 @@ -# coding=utf-8 -# Copyright 2022 Google SwitchTransformers Authors and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import tempfile -import unittest - -import numpy as np - -import transformers -from transformers import is_flax_available -from transformers.testing_utils import ( - is_pt_flax_cross_test, - require_flax, - require_sentencepiece, - require_tokenizers, - slow, -) - -from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin -from ...test_configuration_common import ConfigTester -from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor - - -if is_flax_available(): - import os - - # The slow tests are often failing with OOM error on GPU - # This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed - # but will be slower as stated here https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html - os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" - - import jax - import jax.numpy as jnp - import optax - from flax.core.frozen_dict import unfreeze - from flax.training.common_utils import onehot - from flax.traverse_util import flatten_dict - from transformers import ( - FLAX_MODEL_MAPPING, - BySwitchTransformersTokenizer, - SwitchTransformersConfig, - SwitchTransformersTokenizer, - ) - from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model - from transformers.models.switch_transformers.modeling_flax_switch_transformers import ( - FlaxSwitchTransformersEncoderModel, - FlaxSwitchTransformersForConditionalGeneration, - FlaxSwitchTransformersModel, - shift_tokens_right, - ) - - -class FlaxSwitchTransformersModelTester: - def __init__( - self, - parent, - vocab_size=99, - batch_size=13, - encoder_seq_length=7, - decoder_seq_length=9, - # For common tests - is_training=True, - use_attention_mask=True, - use_labels=True, - hidden_size=32, - num_hidden_layers=5, - num_attention_heads=4, - d_ff=37, - relative_attention_num_buckets=8, - dropout_rate=0.1, - initializer_factor=0.002, - eos_token_id=1, - pad_token_id=0, - decoder_start_token_id=0, - scope=None, - decoder_layers=None, - ): - - self.parent = parent - self.batch_size = batch_size - self.encoder_seq_length = encoder_seq_length - self.decoder_seq_length = decoder_seq_length - # For common tests - self.seq_length = self.decoder_seq_length - self.is_training = is_training - self.use_attention_mask = use_attention_mask - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.d_ff = d_ff - self.relative_attention_num_buckets = relative_attention_num_buckets - self.dropout_rate = dropout_rate - self.initializer_factor = initializer_factor - self.eos_token_id = eos_token_id - self.pad_token_id = pad_token_id - self.decoder_start_token_id = decoder_start_token_id - self.scope = None - self.decoder_layers = decoder_layers - - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) - decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) - - attention_mask = None - decoder_attention_mask = None - if self.use_attention_mask: - attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) - decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) - - config = SwitchTransformersConfig( - vocab_size=self.vocab_size, - d_model=self.hidden_size, - d_ff=self.d_ff, - d_kv=self.hidden_size // self.num_attention_heads, - num_layers=self.num_hidden_layers, - num_decoder_layers=self.decoder_layers, - num_heads=self.num_attention_heads, - relative_attention_num_buckets=self.relative_attention_num_buckets, - dropout_rate=self.dropout_rate, - initializer_factor=self.initializer_factor, - eos_token_id=self.eos_token_id, - bos_token_id=self.pad_token_id, - pad_token_id=self.pad_token_id, - decoder_start_token_id=self.decoder_start_token_id, - ) - - return ( - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - ) - - def create_and_check_model( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - ): - model = FlaxSwitchTransformersModel(config=config) - result = model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - ) - result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) - decoder_output = result.last_hidden_state - encoder_output = result.encoder_last_hidden_state - - self.parent.assertEqual(encoder_output.shape, (self.batch_size, self.encoder_seq_length, self.hidden_size)) - self.parent.assertEqual(decoder_output.shape, (self.batch_size, self.decoder_seq_length, self.hidden_size)) - - def check_use_cache_forward_with_attn_mask( - self, - model_class_name, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - ): - max_decoder_length = 20 - model = model_class_name(config) - - encoder_outputs = model.encode(input_ids) - - # prevent fully zero'd out attention mask - decoder_attention_mask = jnp.ones_like(decoder_attention_mask) - - decoder_attention_mask_cache = jnp.concatenate( - [ - decoder_attention_mask, - jnp.zeros((decoder_attention_mask.shape[0], max_decoder_length - decoder_attention_mask.shape[1])), - ], - axis=-1, - ) - - past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs) - - outputs_cache = model.decode( - decoder_input_ids[:, :-1], - encoder_outputs, - decoder_attention_mask=decoder_attention_mask_cache, - past_key_values=past_key_values, - ) - outputs_cache_next = model.decode( - decoder_input_ids[:, -1:], - encoder_outputs, - past_key_values=outputs_cache.past_key_values, - decoder_attention_mask=decoder_attention_mask_cache, - ) - - outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask) - - diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))) - self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}") - - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - ) = config_and_inputs - - inputs_dict = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "decoder_input_ids": decoder_input_ids, - "decoder_attention_mask": decoder_attention_mask, - } - return config, inputs_dict - - -@require_flax -class FlaxSwitchTransformersModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase): - - all_model_classes = ( - (FlaxSwitchTransformersModel, FlaxSwitchTransformersForConditionalGeneration) if is_flax_available() else () - ) - all_generative_model_classes = (FlaxSwitchTransformersForConditionalGeneration,) if is_flax_available() else () - is_encoder_decoder = True - - def setUp(self): - self.model_tester = FlaxSwitchTransformersModelTester(self) - self.config_tester = ConfigTester(self, config_class=SwitchTransformersConfig, d_model=37) - - def test_config(self): - self.config_tester.run_common_tests() - - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_model_v1_1(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - # check that gated gelu feed forward and different word embeddings work - config = config_and_inputs[0] - config.tie_word_embeddings = False - config.feed_forward_proj = "gated-gelu" - self.model_tester.create_and_check_model(config, *config_and_inputs[1:]) - - def test_use_cache_forward_with_attn_mask(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for model_class in self.all_model_classes: - self.model_tester.check_use_cache_forward_with_attn_mask(model_class, *config_and_inputs) - - def test_encode(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - model = model_class(config) - - @jax.jit - def encode_jitted(input_ids, attention_mask=None, **kwargs): - return model.encode(input_ids=input_ids, attention_mask=attention_mask) - - with self.subTest("JIT Enabled"): - jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple() - - with self.subTest("JIT Disabled"): - with jax.disable_jit(): - outputs = encode_jitted(**prepared_inputs_dict).to_tuple() - - self.assertEqual(len(outputs), len(jitted_outputs)) - for jitted_output, output in zip(jitted_outputs, outputs): - self.assertEqual(jitted_output.shape, output.shape) - - def test_decode(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - model = model_class(config) - encoder_outputs = model.encode(inputs_dict["input_ids"], inputs_dict["attention_mask"]) - - prepared_inputs_dict = { - "decoder_input_ids": inputs_dict["decoder_input_ids"], - "decoder_attention_mask": inputs_dict["decoder_attention_mask"], - "encoder_outputs": encoder_outputs, - } - - @jax.jit - def decode_jitted(decoder_input_ids, decoder_attention_mask, encoder_outputs): - return model.decode( - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - encoder_outputs=encoder_outputs, - ) - - with self.subTest("JIT Enabled"): - jitted_outputs = decode_jitted(**prepared_inputs_dict).to_tuple() - - with self.subTest("JIT Disabled"): - with jax.disable_jit(): - outputs = decode_jitted(**prepared_inputs_dict).to_tuple() - - self.assertEqual(len(outputs), len(jitted_outputs)) - for jitted_output, output in zip(jitted_outputs, outputs): - self.assertEqual(jitted_output.shape, output.shape) - - def test_shift_right(self): - decoder_start_token_id = 0 - pad_token_id = 1 - labels = np.arange(2, 102).reshape(5, 20) - labels[:2, 15:] = -100 - - decoder_input_ids = shift_tokens_right(labels, pad_token_id, decoder_start_token_id) - np_decoder_input_ids = np.array(decoder_input_ids) - - padded_slice = np_decoder_input_ids[:2, (15 + 1) :] - self.assertTrue((padded_slice == 1).all()) - - not_padded_slice = np_decoder_input_ids[2:, 1:] - rolled_labels = np.roll(labels[2:], 1)[:, 1:] - self.assertTrue((not_padded_slice == rolled_labels).all()) - self.assertTrue((np_decoder_input_ids[:, 0] == 0).all()) - - # overwrite since special base model prefix is used - def test_save_load_from_base(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - model = base_class(config) - base_params = flatten_dict(unfreeze(model.params)) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - head_model = model_class.from_pretrained(tmpdirname) - - base_param_from_head = flatten_dict(unfreeze(head_model.params)) - - for key in base_param_from_head.keys(): - max_diff = (base_params[key] - base_param_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - - # overwrite since special base model prefix is used - def test_save_load_to_base(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - model = model_class(config) - base_params_from_head = flatten_dict(unfreeze(model.params)) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname) - - base_params = flatten_dict(unfreeze(base_model.params)) - - for key in base_params_from_head.keys(): - max_diff = (base_params[key] - base_params_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - - # overwrite since special base model prefix is used - @is_pt_flax_cross_test - def test_save_load_from_base_pt(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - model = base_class(config) - base_params = flatten_dict(unfreeze(model.params)) - - # convert Flax model to PyTorch model - pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning - pt_model = pt_model_class(config).eval() - pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - # save pt model - pt_model.save_pretrained(tmpdirname) - head_model = model_class.from_pretrained(tmpdirname, from_pt=True) - - base_param_from_head = flatten_dict(unfreeze(head_model.params)) - - for key in base_param_from_head.keys(): - max_diff = (base_params[key] - base_param_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - - # overwrite since special base model prefix is used - @is_pt_flax_cross_test - def test_save_load_to_base_pt(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - model = model_class(config) - base_params_from_head = flatten_dict(unfreeze(model.params)) - - # convert Flax model to PyTorch model - pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning - pt_model = pt_model_class(config).eval() - pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname, from_pt=True) - - base_params = flatten_dict(unfreeze(base_model.params)) - - for key in base_params_from_head.keys(): - max_diff = (base_params[key] - base_params_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - - # overwrite since special base model prefix is used - @is_pt_flax_cross_test - def test_save_load_bf16_to_base_pt(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - model = model_class(config) - model.params = model.to_bf16(model.params) - base_params_from_head = flatten_dict(unfreeze(model.params)) - - # convert Flax model to PyTorch model - pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning - pt_model = pt_model_class(config).eval() - pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname, from_pt=True) - - base_params = flatten_dict(unfreeze(base_model.params)) - - for key in base_params_from_head.keys(): - max_diff = (base_params[key] - base_params_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - - -class FlaxSwitchTransformersEncoderOnlyModelTester: - def __init__( - self, - parent, - vocab_size=99, - batch_size=13, - encoder_seq_length=7, - # For common tests - is_training=True, - use_attention_mask=True, - use_labels=True, - hidden_size=32, - num_hidden_layers=5, - num_attention_heads=4, - d_ff=37, - relative_attention_num_buckets=8, - dropout_rate=0.1, - initializer_factor=0.002, - eos_token_id=1, - pad_token_id=0, - decoder_start_token_id=0, - scope=None, - ): - - self.parent = parent - self.batch_size = batch_size - self.encoder_seq_length = encoder_seq_length - # For common tests - self.seq_length = self.encoder_seq_length - self.is_training = is_training - self.use_attention_mask = use_attention_mask - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.d_ff = d_ff - self.relative_attention_num_buckets = relative_attention_num_buckets - self.dropout_rate = dropout_rate - self.initializer_factor = initializer_factor - self.eos_token_id = eos_token_id - self.pad_token_id = pad_token_id - self.decoder_start_token_id = decoder_start_token_id - self.scope = None - self.decoder_layers = 0 - - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) - - attention_mask = None - if self.use_attention_mask: - attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) - - config = SwitchTransformersConfig( - vocab_size=self.vocab_size, - d_model=self.hidden_size, - d_ff=self.d_ff, - d_kv=self.hidden_size // self.num_attention_heads, - num_layers=self.num_hidden_layers, - num_decoder_layers=self.decoder_layers, - num_heads=self.num_attention_heads, - relative_attention_num_buckets=self.relative_attention_num_buckets, - dropout_rate=self.dropout_rate, - initializer_factor=self.initializer_factor, - eos_token_id=self.eos_token_id, - bos_token_id=self.pad_token_id, - pad_token_id=self.pad_token_id, - decoder_start_token_id=self.decoder_start_token_id, - is_encoder_decoder=False, - ) - - return ( - config, - input_ids, - attention_mask, - ) - - def create_and_check_model( - self, - config, - input_ids, - attention_mask, - ): - model = FlaxSwitchTransformersEncoderModel(config=config) - result = model( - input_ids=input_ids, - attention_mask=attention_mask, - ) - result = model(input_ids=input_ids) - encoder_output = result.last_hidden_state - - self.parent.assertEqual(encoder_output.shape, (self.batch_size, self.encoder_seq_length, self.hidden_size)) - - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - attention_mask, - ) = config_and_inputs - - inputs_dict = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - return config, inputs_dict - - -@require_flax -class FlaxSwitchTransformersEncoderOnlyModelTest(FlaxModelTesterMixin, unittest.TestCase): - - all_model_classes = (FlaxSwitchTransformersEncoderModel,) if is_flax_available() else () - is_encoder_decoder = False - - def setUp(self): - self.model_tester = FlaxSwitchTransformersEncoderOnlyModelTester(self) - self.config_tester = ConfigTester(self, config_class=SwitchTransformersConfig, d_model=37) - - def test_config(self): - self.config_tester.run_common_tests() - - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_model_v1_1(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - # check that gated gelu feed forward and different word embeddings work - config = config_and_inputs[0] - config.tie_word_embeddings = False - config.feed_forward_proj = "gated-gelu" - self.model_tester.create_and_check_model(config, *config_and_inputs[1:]) - - def test_encode(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - model = model_class(config) - - @jax.jit - def encode_jitted(input_ids, attention_mask=None, **kwargs): - return model(input_ids=input_ids, attention_mask=attention_mask) - - with self.subTest("JIT Enabled"): - jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple() - - with self.subTest("JIT Disabled"): - with jax.disable_jit(): - outputs = encode_jitted(**prepared_inputs_dict).to_tuple() - - self.assertEqual(len(outputs), len(jitted_outputs)) - for jitted_output, output in zip(jitted_outputs, outputs): - self.assertEqual(jitted_output.shape, output.shape) - - # overwrite since special base model prefix is used - def test_save_load_from_base(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - model = base_class(config) - base_params = flatten_dict(unfreeze(model.params)) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - head_model = model_class.from_pretrained(tmpdirname) - - base_param_from_head = flatten_dict(unfreeze(head_model.params)) - - for key in base_param_from_head.keys(): - max_diff = (base_params[key] - base_param_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - - # overwrite since special base model prefix is used - def test_save_load_to_base(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - model = model_class(config) - base_params_from_head = flatten_dict(unfreeze(model.params)) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname) - - base_params = flatten_dict(unfreeze(base_model.params)) - - for key in base_params_from_head.keys(): - max_diff = (base_params[key] - base_params_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - - # overwrite since special base model prefix is used - @is_pt_flax_cross_test - def test_save_load_from_base_pt(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - model = base_class(config) - base_params = flatten_dict(unfreeze(model.params)) - - # convert Flax model to PyTorch model - pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning - pt_model = pt_model_class(config).eval() - pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - # save pt model - pt_model.save_pretrained(tmpdirname) - head_model = model_class.from_pretrained(tmpdirname, from_pt=True) - - base_param_from_head = flatten_dict(unfreeze(head_model.params)) - - for key in base_param_from_head.keys(): - max_diff = (base_params[key] - base_param_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - - # overwrite since special base model prefix is used - @is_pt_flax_cross_test - def test_save_load_to_base_pt(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - model = model_class(config) - base_params_from_head = flatten_dict(unfreeze(model.params)) - - # convert Flax model to PyTorch model - pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning - pt_model = pt_model_class(config).eval() - pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname, from_pt=True) - - base_params = flatten_dict(unfreeze(base_model.params)) - - for key in base_params_from_head.keys(): - max_diff = (base_params[key] - base_params_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - - # overwrite since special base model prefix is used - @is_pt_flax_cross_test - def test_save_load_bf16_to_base_pt(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] - - for model_class in self.all_model_classes: - if model_class == base_class: - continue - - model = model_class(config) - model.params = model.to_bf16(model.params) - base_params_from_head = flatten_dict(unfreeze(model.params)) - - # convert Flax model to PyTorch model - pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning - pt_model = pt_model_class(config).eval() - pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) - - # check that all base model weights are loaded correctly - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname, from_pt=True) - - base_params = flatten_dict(unfreeze(base_model.params)) - - for key in base_params_from_head.keys(): - max_diff = (base_params[key] - base_params_from_head[key]).sum().item() - self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - - -@require_sentencepiece -@require_tokenizers -@require_flax -class FlaxSwitchTransformersModelIntegrationTests(unittest.TestCase): - @slow - def test_small_integration_test(self): - """ - For comparision run: - >>> import switch_transformers # pip install switch_transformers==0.7.1 - >>> from switch_transformers.data.sentencepiece_vocabulary import SentencePieceVocabulary - - >>> path_to_mtf_small_switch_transformers_checkpoint = '' - >>> path_to_mtf_small_spm_model_path = '' - >>> switch_transformers_model = switch_transformers.models.MtfModel(model_dir=path_to_mtf_small_switch_transformers_checkpoint, batch_size=1, tpu=None) - >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) - >>> score = switch_transformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) - """ - - model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switch_transformers-base") - tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switch_transformers-base") - - input_ids = tokenizer("Hello there", return_tensors="np").input_ids - labels = tokenizer("Hi I am", return_tensors="np").input_ids - - decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id, model.config.decoder_start_token_id) - - logits = model(input_ids, decoder_input_ids=decoder_input_ids).logits - - loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean() - mtf_score = -(labels.shape[-1] * loss.item()) - - EXPECTED_SCORE = -19.0845 - self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) - - @slow - def test_small_v1_1_integration_test(self): - """ - For comparision run: - >>> import switch_transformers # pip install switch_transformers==0.7.1 - >>> from switch_transformers.data.sentencepiece_vocabulary import SentencePieceVocabulary - - >>> path_to_mtf_small_switch_transformers_v1_1_checkpoint = '' - >>> path_to_mtf_small_spm_model_path = '' - >>> switch_transformers_model = switch_transformers.models.MtfModel(model_dir=path_to_mtf_small_switch_transformers_v1_1_checkpoint, batch_size=1, tpu=None) - >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) - >>> score = switch_transformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) - """ - - model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("google/switch_transformers-v1_1-small") - tokenizer = SwitchTransformersTokenizer.from_pretrained("google/switch_transformers-v1_1-small") - - input_ids = tokenizer("Hello there", return_tensors="np").input_ids - labels = tokenizer("Hi I am", return_tensors="np").input_ids - - decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id, model.config.decoder_start_token_id) - - logits = model(input_ids, decoder_input_ids=decoder_input_ids).logits - loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean() - - mtf_score = -(labels.shape[-1] * loss.item()) - - EXPECTED_SCORE = -59.0293 - self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) - - @slow - def test_small_byswitch_transformers_integration_test(self): - """ - For comparision run: - >>> import switch_transformers # pip install switch_transformers==0.9.1 - - >>> path_to_byswitch_transformers_small_checkpoint = '' - >>> switch_transformers_model = switch_transformers.models.MtfModel(model_dir=path_to_tf_checkpoint, batch_size=1, tpu=None) - >>> vocab = switch_transformers.data.ByteVocabulary() - >>> score = switch_transformers_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) - """ - - model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained( - "google/byybelkada/switch_transformers-base" - ) - tokenizer = BySwitchTransformersTokenizer.from_pretrained("google/byybelkada/switch_transformers-base") - - input_ids = tokenizer("Hello there", return_tensors="np").input_ids - labels = tokenizer("Hi I am", return_tensors="np").input_ids - - decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id, model.config.decoder_start_token_id) - - logits = model(input_ids, decoder_input_ids=decoder_input_ids).logits - loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean() - - mtf_score = -(labels.shape[-1] * loss.item()) - - EXPECTED_SCORE = -60.7397 - self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) - - @slow - def test_small_generation(self): - model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switch_transformers-base") - model.config.max_length = 8 - model.config.num_beams = 1 - model.config.do_sample = False - tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switch_transformers-base") - - input_ids = tokenizer("summarize: Hello there", return_tensors="np").input_ids - - sequences = model.generate(input_ids).sequences - - output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0] - self.assertTrue(output_str == "Hello there!") - - @slow - def test_summarization(self): - model = FlaxSwitchTransformersForConditionalGeneration.from_pretrained("switch_transformers-base") - tok = SwitchTransformersTokenizer.from_pretrained("switch_transformers-base") - - FRANCE_ARTICLE = ( # @noqa - "Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings" - " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane." - ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."' - ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s' - " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video" - " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French" - " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a" - " phone at the wreckage site. The two publications described the supposed video, but did not post it on" - " their websites. The publications said that they watched the video, which was found by a source close to" - " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported." - ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the' - " cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the" - ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,' - " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said" - " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman" - " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the" - ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,' - ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be' - " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by" - " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so" - " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could" - ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin' - ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match' - ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered' - ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something' - " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the" - ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline' - " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the" - " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the" - ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of' - ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school' - " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in" - " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent" - " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and" - " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%" - ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was' - " sharing the information and documents -- including training and medical records -- with public" - " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the" - " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the" - " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash" - " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late" - " Tuesday that no visible human remains were left at the site but recovery teams would keep searching." - " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all" - " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested." - " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said." - " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew" - " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with" - " the flight school during his training were among several developments as investigators continued to" - " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa" - " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his" - ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in' - " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at" - " some point before his aviation career and underwent psychotherapy before he got his pilot's license." - " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the" - " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to" - " lose his pilot's license, a European government official briefed on the investigation told CNN on" - ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being' - " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that" - " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would" - " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had" - " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded" - " he had psychological issues, the European government official said. But no matter what details emerge" - " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic" - ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact' - " that maybe they weren't going to keep doing their job and they're upset about that and so they're" - ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to' - " also take that rage and turn it outward on 149 other people who had nothing to do with the person's" - ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight' - " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura" - " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine" - " Amiel and Anna-Maja Rappard contributed to this report." - ) - SHORTER_ARTICLE = ( - "(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on" - " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The" - " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based." - " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its" - ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East' - ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the' - " situation in Palestinian territories, paving the way for possible war crimes investigations against" - " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and" - " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the" - " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a" - ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the' - ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an' - ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge' - " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the" - ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine' - " acquires all the rights as well as responsibilities that come with being a State Party to the Statute." - ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights' - ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should' - " immediately end their pressure, and countries that support universal acceptance of the court's treaty" - ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the' - " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's" - ' decision to join a treaty to which over 100 countries around the world are members." In January, when' - " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an" - ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"' - " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a" - ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in' - ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We' - ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"' - " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the" - ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the' - " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou" - ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war' - " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry" - " will include alleged war crimes committed since June. The International Criminal Court was set up in" - " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder" - " and Faith Karimi contributed to this report." - ) - IRAN_ARTICLE = ( - "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran" - " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively" - " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger." - " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli" - " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a" - " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since" - " the announcement of the new framework will likely result in more heat than light. It will not be helped" - " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ." - " The most misleading assertion, despite universal rejection by experts, is that the negotiations'" - " objective at the outset was the total elimination of any nuclear program in Iran. That is the position" - " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it" - " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has" - " always been to structure an agreement or series of agreements so that Iran could not covertly develop a" - " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded" - " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by" - " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another" - " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite" - " sharp accusations by some in the United States and its allies, Iran denies having such a program, and" - " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's" - " continued cooperation with International Atomic Energy Agency inspections is further evidence on this" - " point, and we'll know even more about Iran's program in the coming months and years because of the deal." - " In fact, the inspections provisions that are part of this agreement are designed to protect against any" - " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that" - " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter" - " warning that a deal might be killed by Congress or a future president). This of course is not the case." - " The talks were between Iran and the five permanent members of the U.N. Security Council (United States," - " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has" - " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement" - " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran" - " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement" - " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the" - " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased" - " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes" - " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear" - " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going" - " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such" - " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the" - ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not' - " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New" - " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement" - " with Iran will not be so balanced. The restrictions and obligations in the final framework agreement" - " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove" - " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally" - " some insist that any agreement must address Iranian missile programs, human rights violations or support" - " for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are" - " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in" - " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it" - " affects the security of our negotiating partners and allies, including Israel. Those judgments should be" - " fact-based, not based on questionable assertions or dubious assumptions." - ) - ARTICLE_SUBWAY = ( - "New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A" - " year later, she got married again in Westchester County, but to a different man and without divorcing" - " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos" - ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married' - " once more, this time in the Bronx. In an application for a marriage license, she stated it was her" - ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false' - ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage' - " license application, according to court documents. Prosecutors said the marriages were part of an" - " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to" - " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was" - " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New" - " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total," - " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All" - " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be" - " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors" - " said the immigration scam involved some of her husbands, who filed for permanent residence status" - " shortly after the marriages. Any divorces happened only after such filings were approved. It was" - " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District" - " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's" - ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,' - " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his" - " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces" - " up to four years in prison. Her next court appearance is scheduled for May 18." - ) - - expected_summaries = [ - 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a' - " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one" - " magazine says . all 150 on board were killed when germanwings flight 9525 crashed .", - "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a" - " preliminary examination into the situation in the occupied Palestinian territory . as members of the" - " court, Palestinians may be subject to counter-charges as well .", - "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:" - " the debate that has already begun since the announcement of the new framework will likely result in more" - " heat than light . he says the new framework would reduce Iran's low-enriched uranium stockpile and cut" - " centrifuges . miller: if it had been, there would have been no Iranian team at the table .", - "prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two" - ' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10' - " times, with nine of her marriages occurring between 1999 and 2002 .", - ] - - dct = tok( - ["summarize: " + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]], - padding="max_length", - truncation=True, - return_tensors="np", - ) - self.assertEqual(512, dct["input_ids"].shape[1]) - - hypotheses_batch = model.generate( - **dct, - num_beams=4, - length_penalty=2.0, - max_length=142, - min_length=56, - do_sample=False, - early_stopping=True, - ).sequences - - decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False) - self.assertListEqual( - expected_summaries, - decoded, - ) From ccaaf61d70aa52d2c7aa3e82453dea71a9d3657f Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 26 Oct 2022 15:19:55 +0000 Subject: [PATCH 038/102] more tests pass --- .../configuration_switch_transformers.py | 4 +- .../modeling_switch_transformers.py | 70 ++----------------- .../test_modeling_switch_transformers.py | 59 +++++++--------- 3 files changed, 33 insertions(+), 100 deletions(-) diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index 2848154da3f75..35710d01c4e71 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -23,7 +23,7 @@ logger = logging.get_logger(__name__) SWITCH_TRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "ybelkada/switch_transformers-base": ( + "HFLAY/switch_base_8": ( "https://huggingface.co/ybelkada/switch_transformers-base/resolve/main/config.json" ), } @@ -116,7 +116,7 @@ def __init__( num_decoder_layers=12, num_sparse_decoder_layers=3, num_heads=12, - num_experts=64, + num_experts=8, expert_capacity=1, router_type="tokens_masked", router_bias=False, diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 3a580e703f227..20e770bcd2b06 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -299,7 +299,7 @@ def forward(self, token_inputs: torch.Tensor, apply_jitter: bool = True, **kwarg else: padding_mask = None - expert_index = self._compute_routing_instructions(router_probs, padding_mask, **kwargs) + expert_index = self._compute_routing_instructions(router_probs, **kwargs) expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts) router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1) @@ -935,7 +935,6 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel): config_class = SwitchTransformersConfig base_model_prefix = "transformer" - is_parallelizable = True supports_gradient_checkpointing = True _no_split_modules = ["SwitchTransformersBlock"] @@ -1050,8 +1049,7 @@ def __init__(self, config, embed_tokens=None): # Initialize weights and apply final processing self.post_init() - # Model parallel - self.model_parallel = False + self.device_map = None self.gradient_checkpointing = False @@ -1077,10 +1075,6 @@ def forward( output_router_logits=None, return_dict=None, ): - # Model parallel - if self.model_parallel: - torch.cuda.set_device(self.first_device) - self.embed_tokens = self.embed_tokens.to(self.first_device) use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1157,24 +1151,7 @@ def forward( for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) - if position_bias is not None: - position_bias = position_bias.to(hidden_states.device) - if encoder_hidden_states is not None: - encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) - if encoder_extended_attention_mask is not None: - encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) - if encoder_decoder_position_bias is not None: - encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) - if layer_head_mask is not None: - layer_head_mask = layer_head_mask.to(hidden_states.device) - if cross_attn_layer_head_mask is not None: - cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -1244,12 +1221,6 @@ def custom_forward(*inputs): if output_router_logits: all_router_probs = all_router_probs + (layer_outputs[-1],) - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) - hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) @@ -1478,7 +1449,6 @@ def __init__(self, config: SwitchTransformersConfig): self.post_init() # Model parallel - self.model_parallel = False self.device_map = None def get_input_embeddings(self): @@ -1574,21 +1544,11 @@ def forward( last_hidden_state=encoder_outputs[0], hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + router_probs=encoder_outputs[3] if len(encoder_outputs) > 3 else None, ) hidden_states = encoder_outputs[0] - # Set device for model parallelism - if self.model_parallel: - torch.cuda.set_device(self.decoder.first_device) - hidden_states = hidden_states.to(self.decoder.first_device) - if decoder_input_ids is not None: - decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) - if attention_mask is not None: - attention_mask = attention_mask.to(self.decoder.first_device) - if decoder_attention_mask is not None: - decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) - # Decode decoder_outputs = self.decoder( input_ids=decoder_input_ids, @@ -1660,7 +1620,6 @@ def __init__(self, config: SwitchTransformersConfig): self.post_init() # Model parallel - self.model_parallel = False self.device_map = None def get_input_embeddings(self): @@ -1768,24 +1727,10 @@ def forward( hidden_states = encoder_outputs[0] - if self.model_parallel: - torch.cuda.set_device(self.decoder.first_device) - if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: # get decoder inputs from shifting lm labels to the right decoder_input_ids = self._shift_right(labels) - # Set device for model parallelism - if self.model_parallel: - torch.cuda.set_device(self.decoder.first_device) - hidden_states = hidden_states.to(self.decoder.first_device) - if decoder_input_ids is not None: - decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) - if attention_mask is not None: - attention_mask = attention_mask.to(self.decoder.first_device) - if decoder_attention_mask is not None: - decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) - # Decode decoder_outputs = self.decoder( input_ids=decoder_input_ids, @@ -1805,12 +1750,6 @@ def forward( sequence_output = decoder_outputs[0] - # Set device for model parallelism - if self.model_parallel: - torch.cuda.set_device(self.encoder.first_device) - self.lm_head = self.lm_head.to(self.encoder.first_device) - sequence_output = sequence_output.to(self.lm_head.weight.device) - if self.config.tie_word_embeddings: # Rescale output before projecting on vocab # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 @@ -1951,7 +1890,6 @@ def __init__(self, config: SwitchTransformersConfig): self.post_init() # Model parallel - self.model_parallel = False self.device_map = None def get_input_embeddings(self): diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 2d8b114d82a50..76cc5bd38d998 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -30,6 +30,7 @@ import torch from transformers import ( + AutoTokenizer, SwitchTransformersEncoderModel, SwitchTransformersForConditionalGeneration, SwitchTransformersModel, @@ -94,7 +95,7 @@ def __init__( self.sparse_step = sparse_step def get_large_model_config(self): - return SwitchTransformersConfig.from_pretrained("switch_transformers-base") + return SwitchTransformersConfig.from_pretrained("HFLAY/switch_base_8") def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) @@ -528,7 +529,7 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, unitt fx_compatible = False test_pruning = False test_resize_embeddings = True - test_model_parallel = True + test_model_parallel = False is_encoder_decoder = True # The small SWITCH_TRANSFORMERS model needs higher percentages for CPU/MP tests model_split_percents = [0.8, 0.9] @@ -970,7 +971,7 @@ def test_equivalency_token_chose_masked_router(self): ).t() ) - output = model(input_tokens, expert_capacity=expert_capacity) + expert_index, _, router_logits = model(input_tokens) expected_dispatch_mask = torch.Tensor( [ @@ -978,19 +979,15 @@ def test_equivalency_token_chose_masked_router(self): [[[True], [False]], [[False], [True]], [[False], [False]]], ] ) + - expected_combine_array = torch.Tensor( - [ - [[[0.5090], [0.0000]], [[0.0000], [0.5031]], [[0.0000], [0.0000]]], - [[[0.5024], [0.0000]], [[0.0000], [0.5071]], [[0.0000], [0.0000]]], - ] - ) + router_z_loss = router_z_loss_func(router_logits) + auxiliary_loss = load_balancing_loss_func(router_logits, torch.argmax(expert_index, dim=-1)) - self.assertAlmostEqual(output.auxiliary_loss.item(), 1.000308, places=5) - self.assertAlmostEqual(output.router_z_loss.item(), 0.4789799, places=5) + self.assertAlmostEqual(auxiliary_loss.item(), 1.000308, places=5) + self.assertAlmostEqual(router_z_loss.item(), 0.4789799, places=5) - self.assertTrue(torch.allclose(output.dispatch_mask, expected_dispatch_mask)) - self.assertTrue(torch.allclose(output.combine_array, expected_combine_array, atol=1e-4)) + self.assertTrue(torch.allclose(expert_index.bool().unsqueeze(-1), expected_dispatch_mask)) @require_torch @@ -1088,21 +1085,22 @@ def test_small_generate(self): model = SwitchTransformersForConditionalGeneration.from_pretrained( "HFLAY/switch_base_8", torch_dtype=torch.bfloat16 ).eval() - tokenizer = SwitchTransformersForConditionalGeneration.from_pretrained("t5-small") + tokenizer = AutoTokenizer.from_pretrained("t5-small") + input_ids = tokenizer("summarize: Hello world", return_tensors="pt").input_ids.to(torch_device) sequences = model.generate(input_ids) output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0] EXPECTED_OUTPUT = " . The best way to do it is to use a smartphone. Hello there" - self.assertisEqual(output_str, EXPECTED_OUTPUT) + self.assertEqual(output_str, EXPECTED_OUTPUT) input_ids = tokenizer( "The human walks into a bar and orders a ", return_tensors="pt" ).input_ids.to(torch_device) sequences = model.generate(input_ids) output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0] - self.assertisEqual(output_str, "drink.") + self.assertEqual(output_str, "drink.") input_ids = tokenizer( "A walks into a bar a orders a with pinch of .", @@ -1112,24 +1110,21 @@ def test_small_generate(self): output_str = tokenizer.batch_decode(sequences, skip_special_tokens=False)[0] EXPECTED_OUTPUT = " man beer a salt." - self.assertisEqual(output_str, EXPECTED_OUTPUT) - - def test_large_logits(self): - pass - - def test_small_logits_bf16(self): - pass + self.assertEqual(output_str, EXPECTED_OUTPUT) def test_small_batch_generate(self): - pass + BATCH_SIZE = 4 + model = SwitchTransformersForConditionalGeneration.from_pretrained( + "HFLAY/switch_base_8", torch_dtype=torch.bfloat16 + ).eval() + tokenizer = AutoTokenizer.from_pretrained("t5-small") - def test_large_batch_generate(self): - pass + inputs = ["A walks into a bar a orders a with pinch of ." ] * BATCH_SIZE + encoded_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt") - @slow - def test_summarization(self): - pass + sequences = model.generate(**encoded_input) + batch_output = tokenizer.batch_decode(sequences, skip_special_tokens=False) + + for i in range(0, BATCH_SIZE, 2): + self.assertEqual(batch_output[i], batch_output[i+1]) - @slow - def test_translation_en_to_de(self): - pass From 805fe69bbd3db2c2341742e11a16b36b9874ff66 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 27 Oct 2022 10:54:47 +0000 Subject: [PATCH 039/102] Update code Co-authored-by: Younes Belkada --- src/transformers/modeling_outputs.py | 8 +- .../modeling_switch_transformers.py | 56 ++++++----- .../test_modeling_switch_transformers.py | 95 ++++--------------- 3 files changed, 47 insertions(+), 112 deletions(-) diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py index e0d21f2b00662..5325cc550c781 100644 --- a/src/transformers/modeling_outputs.py +++ b/src/transformers/modeling_outputs.py @@ -796,10 +796,10 @@ class Seq2SeqMoEOutput(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None - encoder_total_z_loss: torch.FloatTensor = None - decoder_total_z_loss: torch.FloatTensor = None - encoder_total_aux_loss: torch.FloatTensor = None - decoder_total_aux_loss: torch.FloatTensor = None + encoder_z_loss: torch.FloatTensor = None + decoder_z_loss: torch.FloatTensor = None + encoder_aux_loss: torch.FloatTensor = None + decoder_aux_loss: torch.FloatTensor = None past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 20e770bcd2b06..9c2e7334c45aa 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -259,11 +259,11 @@ def _compute_router_probabilities( token_inputs *= uniform_distrib # Shape: [num_groups, tokens_per_group, num_experts] - self.classifier.to(self.dtype) + self.classifier = self.classifier.to(self.dtype) router_logits = self.classifier(token_inputs) # Apply Softmax and cast back to the original `dtype` - router_probabilities = torch.nn.Softmax(dim=-1)(router_logits).to(self.input_tokens_dtype) + router_probabilities = nn.functional.softmax(router_logits, dim=-1, dtype = self.dtype).to(self.input_tokens_dtype) return router_probabilities, router_logits def forward(self, token_inputs: torch.Tensor, apply_jitter: bool = True, **kwargs) -> Tuple: @@ -489,6 +489,7 @@ def forward(self, hidden_states): router_mask, router_probs, router_logits = self.router(hidden_states) expert_index = torch.argmax(router_mask, dim=-1) + # next_states = torch.zeros_like(hidden_states) for idx, expert in enumerate(self.experts.values()): # 1. Get the index of the tokens that are routed to the current expert @@ -1527,6 +1528,9 @@ def forward( warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) decoder_head_mask = head_mask + if output_router_logits and self.config.num_sparse_encoder_layers == 0 and self.config.num_sparse_encoder_layers == 0: + raise ValueError("You asked to return `output_router_logits` but the transformer in dense, and does\ + not contain any sparse MLP Layers. Set `output_router_logits = False` and restart") # Encode if needed (training, first prediction pass) if encoder_outputs is None: encoder_outputs = self.encoder( @@ -1758,48 +1762,42 @@ def forward( lm_logits = self.lm_head(sequence_output) loss = None - total_encoder_z_loss = None - total_encoder_aux_loss = None - total_decoder_z_loss = None - total_decoder_aux_loss = None + encoder_z_loss = None + encoder_aux_loss = None + decoder_z_loss = None + decoder_aux_loss = None if labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-100) # todo check in the config if router loss enables + if output_router_logits: - # Calculate the router loss (z_loss + auxiliary loss) for each router in the encoder - total_encoder_z_loss = 0 - total_encoder_aux_loss = 0 - for router_tuple in encoder_outputs.router_probs: - if router_tuple[0] is not None: - total_encoder_z_loss += router_z_loss_func(router_tuple[0]) - total_encoder_aux_loss += load_balancing_loss_func(router_tuple[0], router_tuple[1]) - - total_decoder_z_loss = 0 - total_decoder_aux_loss = 0 - for router_tuple in decoder_outputs.router_probs: - if router_tuple[0] is not None: - total_decoder_z_loss += router_z_loss_func(router_tuple[0]) - total_decoder_aux_loss += load_balancing_loss_func(router_tuple[0], router_tuple[1]) + # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder + encoder_router_logits, encoder_expert_indexes = encoder_outputs.router_probs + encoder_z_loss = router_z_loss_func(encoder_router_logits) + encoder_aux_loss = load_balancing_loss_func(encoder_router_logits,encoder_expert_indexes) + + decoder_router_logits, decoder_expert_indexes = decoder_outputs.router_probs + decoder_z_loss = router_z_loss_func(decoder_router_logits) + decoder_aux_loss = load_balancing_loss_func(decoder_router_logits,decoder_expert_indexes) loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + if not return_dict: - output = ( - (lm_logits, total_encoder_z_loss, total_encoder_aux_loss, total_decoder_z_loss, total_decoder_aux_loss) - + decoder_outputs[1:] - + encoder_outputs - ) + output = (decoder_outputs[1:], encoder_outputs) + if output_router_logits: # only return the loss if they are not None + output = (lm_logits, encoder_z_loss, encoder_aux_loss, decoder_z_loss, decoder_aux_loss) + output return ((loss,) + output) if loss is not None else output return Seq2SeqMoEOutput( loss=loss, logits=lm_logits, - encoder_total_z_loss=total_encoder_z_loss, - encoder_total_aux_loss=total_encoder_aux_loss, - decoder_total_z_loss=total_decoder_z_loss, - decoder_total_aux_loss=total_decoder_aux_loss, + encoder_z_loss=encoder_z_loss, + encoder_aux_loss=encoder_aux_loss, + decoder_z_loss=decoder_z_loss, + decoder_aux_loss=decoder_aux_loss, past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 76cc5bd38d998..4aa976f6ab0ac 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -999,85 +999,29 @@ def test_small_logits(self): and `transformers` implementation of Switch-C transformers. We only check the logits of the first batch. """ - model = SwitchTransformersForConditionalGeneration.from_pretrained( + model = SwitchTransformersModel.from_pretrained( "HFLAY/switch_base_8", torch_dtype=torch.bfloat16 ).eval() input_ids = torch.ones((32, 64), dtype=torch.long) decoder_input_ids = torch.ones((32, 64), dtype=torch.long) + # fmt: off EXPECTED_MEAN_LOGITS = torch.Tensor( [ - -29.330458, - -29.332455, - -29.333147, - -29.341417, - -29.472025, - -29.335613, - -29.47691, - -29.328053, - -29.328312, - -29.329872, - -29.336075, - -29.331112, - -29.30393, - -29.328972, - -29.33514, - -29.335201, - -29.317245, - -29.48052, - -29.328382, - -29.4837, - -29.489216, - -29.338572, - -29.331537, - -29.337881, - -29.497675, - -29.483559, - -29.497217, - -29.343832, - -29.483425, - -29.333313, - -29.49259, - -29.318579, - -29.478128, - -29.328222, - -29.339464, - -29.329647, - -29.339725, - -29.648586, - -29.312738, - -29.314232, - -29.330048, - -29.314402, - -29.329876, - -29.33895, - -29.337482, - -29.477829, - -29.482548, - -29.337194, - -29.487375, - -29.33446, - -29.340445, - -29.479067, - -29.333689, - -29.338657, - -29.339827, - -29.33101, - -29.482433, - -29.498121, - -29.477905, - -29.33606, - -29.333132, - -29.335573, - -29.482475, - -29.330212, - ], - ) + -0.204102, -0.193359, 0.523438, -0.296875, 0.108887, + 0.0211182, 0.605469, -0.100586, -0.0551758, 0.296875, + 0.0090332, 0.174805, 0.139648, -0.170898, -0.0981445, + 0.0245361, 0.0373535, 0.050293, -0.212891, 0.129883, + 0.390625, -0.203125, -0.122559, -0.180664, 0.0437012, + -0.349609, -0.0250244, -0.104004, -0.15918, -0.133789 + ] + ).to(torch.bfloat16) + # fmt: on - hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).logits - hf_logits = hf_logits.mean(dim=-1)[0] + hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state + hf_logits = hf_logits[0,0,:30] - self.assertTrue(torch.testing.assert_allclose(hf_logits, EXPECTED_MEAN_LOGITS, rtol=1e-3, atol=1e-3)) + torch.testing.assert_allclose(hf_logits, EXPECTED_MEAN_LOGITS, rtol = 6e-3,atol=9e-3) def test_small_generate(self): # Generate test using the smalled switch-C model. @@ -1086,14 +1030,7 @@ def test_small_generate(self): "HFLAY/switch_base_8", torch_dtype=torch.bfloat16 ).eval() tokenizer = AutoTokenizer.from_pretrained("t5-small") - - - input_ids = tokenizer("summarize: Hello world", return_tensors="pt").input_ids.to(torch_device) - sequences = model.generate(input_ids) - output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0] - - EXPECTED_OUTPUT = " . The best way to do it is to use a smartphone. Hello there" - self.assertEqual(output_str, EXPECTED_OUTPUT) + model = model.to(torch_device) input_ids = tokenizer( "The human walks into a bar and orders a ", return_tensors="pt" @@ -1109,7 +1046,7 @@ def test_small_generate(self): sequences = model.generate(input_ids) output_str = tokenizer.batch_decode(sequences, skip_special_tokens=False)[0] - EXPECTED_OUTPUT = " man beer a salt." + EXPECTED_OUTPUT = " man beer a salt." self.assertEqual(output_str, EXPECTED_OUTPUT) def test_small_batch_generate(self): From 955a811f8a23a5284486ba2f5fb51373d71c8b23 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 27 Oct 2022 10:55:45 +0000 Subject: [PATCH 040/102] fixup --- .../configuration_switch_transformers.py | 4 +-- .../modeling_switch_transformers.py | 23 +++++++++------ .../test_modeling_switch_transformers.py | 28 +++++++++---------- 3 files changed, 29 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index 35710d01c4e71..a25b9ed4182a0 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -23,9 +23,7 @@ logger = logging.get_logger(__name__) SWITCH_TRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "HFLAY/switch_base_8": ( - "https://huggingface.co/ybelkada/switch_transformers-base/resolve/main/config.json" - ), + "HFLAY/switch_base_8": "https://huggingface.co/ybelkada/switch_transformers-base/resolve/main/config.json", } diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 9c2e7334c45aa..4ba7dfc0725a0 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -263,7 +263,9 @@ def _compute_router_probabilities( router_logits = self.classifier(token_inputs) # Apply Softmax and cast back to the original `dtype` - router_probabilities = nn.functional.softmax(router_logits, dim=-1, dtype = self.dtype).to(self.input_tokens_dtype) + router_probabilities = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to( + self.input_tokens_dtype + ) return router_probabilities, router_logits def forward(self, token_inputs: torch.Tensor, apply_jitter: bool = True, **kwargs) -> Tuple: @@ -1528,9 +1530,15 @@ def forward( warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) decoder_head_mask = head_mask - if output_router_logits and self.config.num_sparse_encoder_layers == 0 and self.config.num_sparse_encoder_layers == 0: - raise ValueError("You asked to return `output_router_logits` but the transformer in dense, and does\ - not contain any sparse MLP Layers. Set `output_router_logits = False` and restart") + if ( + output_router_logits + and self.config.num_sparse_encoder_layers == 0 + and self.config.num_sparse_encoder_layers == 0 + ): + raise ValueError( + "You asked to return `output_router_logits` but the transformer in dense, and does " + " not contain any sparse MLP Layers. Set `output_router_logits = False` and restart" + ) # Encode if needed (training, first prediction pass) if encoder_outputs is None: encoder_outputs = self.encoder( @@ -1775,19 +1783,18 @@ def forward( # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder encoder_router_logits, encoder_expert_indexes = encoder_outputs.router_probs encoder_z_loss = router_z_loss_func(encoder_router_logits) - encoder_aux_loss = load_balancing_loss_func(encoder_router_logits,encoder_expert_indexes) + encoder_aux_loss = load_balancing_loss_func(encoder_router_logits, encoder_expert_indexes) decoder_router_logits, decoder_expert_indexes = decoder_outputs.router_probs decoder_z_loss = router_z_loss_func(decoder_router_logits) - decoder_aux_loss = load_balancing_loss_func(decoder_router_logits,decoder_expert_indexes) + decoder_aux_loss = load_balancing_loss_func(decoder_router_logits, decoder_expert_indexes) loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 - if not return_dict: output = (decoder_outputs[1:], encoder_outputs) - if output_router_logits: # only return the loss if they are not None + if output_router_logits: # only return the loss if they are not None output = (lm_logits, encoder_z_loss, encoder_aux_loss, decoder_z_loss, decoder_aux_loss) + output return ((loss,) + output) if loss is not None else output diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 4aa976f6ab0ac..930d9cf8127a7 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -979,7 +979,6 @@ def test_equivalency_token_chose_masked_router(self): [[[True], [False]], [[False], [True]], [[False], [False]]], ] ) - router_z_loss = router_z_loss_func(router_logits) auxiliary_loss = load_balancing_loss_func(router_logits, torch.argmax(expert_index, dim=-1)) @@ -999,29 +998,27 @@ def test_small_logits(self): and `transformers` implementation of Switch-C transformers. We only check the logits of the first batch. """ - model = SwitchTransformersModel.from_pretrained( - "HFLAY/switch_base_8", torch_dtype=torch.bfloat16 - ).eval() + model = SwitchTransformersModel.from_pretrained("HFLAY/switch_base_8", torch_dtype=torch.bfloat16).eval() input_ids = torch.ones((32, 64), dtype=torch.long) decoder_input_ids = torch.ones((32, 64), dtype=torch.long) # fmt: off EXPECTED_MEAN_LOGITS = torch.Tensor( [ - -0.204102, -0.193359, 0.523438, -0.296875, 0.108887, - 0.0211182, 0.605469, -0.100586, -0.0551758, 0.296875, - 0.0090332, 0.174805, 0.139648, -0.170898, -0.0981445, - 0.0245361, 0.0373535, 0.050293, -0.212891, 0.129883, - 0.390625, -0.203125, -0.122559, -0.180664, 0.0437012, - -0.349609, -0.0250244, -0.104004, -0.15918, -0.133789 + -0.204102, -0.193359, 0.523438, -0.296875, 0.108887, + 0.0211182, 0.605469, -0.100586, -0.0551758, 0.296875, + 0.0090332, 0.174805, 0.139648, -0.170898, -0.0981445, + 0.0245361, 0.0373535, 0.050293, -0.212891, 0.129883, + 0.390625, -0.203125, -0.122559, -0.180664, 0.0437012, + -0.349609, -0.0250244, -0.104004, -0.15918, -0.133789 ] ).to(torch.bfloat16) # fmt: on hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state - hf_logits = hf_logits[0,0,:30] + hf_logits = hf_logits[0, 0, :30] - torch.testing.assert_allclose(hf_logits, EXPECTED_MEAN_LOGITS, rtol = 6e-3,atol=9e-3) + torch.testing.assert_allclose(hf_logits, EXPECTED_MEAN_LOGITS, rtol=6e-3, atol=9e-3) def test_small_generate(self): # Generate test using the smalled switch-C model. @@ -1056,12 +1053,13 @@ def test_small_batch_generate(self): ).eval() tokenizer = AutoTokenizer.from_pretrained("t5-small") - inputs = ["A walks into a bar a orders a with pinch of ." ] * BATCH_SIZE + inputs = [ + "A walks into a bar a orders a with pinch of ." + ] * BATCH_SIZE encoded_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt") sequences = model.generate(**encoded_input) batch_output = tokenizer.batch_decode(sequences, skip_special_tokens=False) for i in range(0, BATCH_SIZE, 2): - self.assertEqual(batch_output[i], batch_output[i+1]) - + self.assertEqual(batch_output[i], batch_output[i + 1]) From 01527226246b8f18d46729c27e87d6f851706c50 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 27 Oct 2022 15:31:10 +0200 Subject: [PATCH 041/102] fix tests --- .../modeling_switch_transformers.py | 22 +++++++++++++++---- .../test_modeling_switch_transformers.py | 13 ++++------- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 4ba7dfc0725a0..46b05e377311a 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -129,6 +129,7 @@ def router_z_loss_func(router_logits: torch.Tensor) -> float: return torch.sum(z_loss) / (num_groups * tokens_per_group) +# aux loss function def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float: r""" Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. @@ -152,6 +153,10 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T # cast the expert indices to int64, otherwise one-hot encoding will fail if expert_indices.dtype != torch.int64: expert_indices = expert_indices.to(torch.int64) + + if len(expert_indices.shape) == 2: + expert_indices = expert_indices.unsqueeze(2) + expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts) # For a given token, determine if it was routed to a given expert. @@ -1587,11 +1592,11 @@ def forward( decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, + decoder_router_logits=decoder_outputs.router_probs, encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, encoder_router_logits=encoder_outputs.router_probs, - decoder_router_logits=decoder_outputs.router_probs, ) @@ -1793,11 +1798,20 @@ def forward( # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 if not return_dict: - output = (decoder_outputs[1:], encoder_outputs) + output = (lm_logits,) if output_router_logits: # only return the loss if they are not None - output = (lm_logits, encoder_z_loss, encoder_aux_loss, decoder_z_loss, decoder_aux_loss) + output - return ((loss,) + output) if loss is not None else output + output += ( + encoder_z_loss, + encoder_aux_loss, + decoder_z_loss, + decoder_aux_loss, + *decoder_outputs[1:], + *encoder_outputs, + ) + else: + output += (*decoder_outputs[1:], *encoder_outputs) + return ((loss,) + output) if loss is not None else output return Seq2SeqMoEOutput( loss=loss, logits=lm_logits, diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 930d9cf8127a7..c2a657284328f 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -972,23 +972,18 @@ def test_equivalency_token_chose_masked_router(self): ) expert_index, _, router_logits = model(input_tokens) - - expected_dispatch_mask = torch.Tensor( - [ - [[[True], [False]], [[False], [True]], [[False], [False]]], - [[[True], [False]], [[False], [True]], [[False], [False]]], - ] - ) + router_probs = torch.softmax(router_logits, dim=-1) router_z_loss = router_z_loss_func(router_logits) - auxiliary_loss = load_balancing_loss_func(router_logits, torch.argmax(expert_index, dim=-1)) + auxiliary_loss = load_balancing_loss_func(router_probs, torch.argmax(expert_index, dim=-1)) self.assertAlmostEqual(auxiliary_loss.item(), 1.000308, places=5) self.assertAlmostEqual(router_z_loss.item(), 0.4789799, places=5) - self.assertTrue(torch.allclose(expert_index.bool().unsqueeze(-1), expected_dispatch_mask)) + # self.assertTrue(torch.allclose(expert_index.bool().unsqueeze(-1), expected_dispatch_mask)) +@slow @require_torch @require_tokenizers class SwitchTransformerModelIntegrationTests(unittest.TestCase): From 5bcf84f9c3cb84936c38f45bc19a54469d074891 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 27 Oct 2022 15:47:39 +0200 Subject: [PATCH 042/102] fix doc --- docs/source/en/model_doc/switch_transformers.mdx | 6 ------ 1 file changed, 6 deletions(-) diff --git a/docs/source/en/model_doc/switch_transformers.mdx b/docs/source/en/model_doc/switch_transformers.mdx index 64979a94ba86e..3603d25b9188e 100644 --- a/docs/source/en/model_doc/switch_transformers.mdx +++ b/docs/source/en/model_doc/switch_transformers.mdx @@ -49,19 +49,13 @@ The original code can be found [here](). [[autodoc]] SwitchTransformersModel - forward - - parallelize - - deparallelize ## SwitchTransformersForConditionalGeneration [[autodoc]] SwitchTransformersForConditionalGeneration - forward - - parallelize - - deparallelize ## SwitchTransformersEncoderModel [[autodoc]] SwitchTransformersEncoderModel - forward - - parallelize - - deparallelize From aeff41c7f3013b40ea79d5d48fa12dd3645e6a4f Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 27 Oct 2022 15:57:31 +0200 Subject: [PATCH 043/102] fix doc + tokenization --- docs/source/en/_toctree.yml | 2 ++ .../tokenization_switch_transformers.py | 30 +++++++++++-------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index f8866e04d75e0..44b26d73af688 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -343,6 +343,8 @@ title: Splinter - local: model_doc/squeezebert title: SqueezeBERT + - local: model_doc/switch_transformers + title: SwitchTransformers - local: model_doc/t5 title: T5 - local: model_doc/t5v1.1 diff --git a/src/transformers/models/switch_transformers/tokenization_switch_transformers.py b/src/transformers/models/switch_transformers/tokenization_switch_transformers.py index a2652b5008435..547ec81af9aa3 100644 --- a/src/transformers/models/switch_transformers/tokenization_switch_transformers.py +++ b/src/transformers/models/switch_transformers/tokenization_switch_transformers.py @@ -33,24 +33,30 @@ PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "ybelkada/switch_transformers-base": ( - "https://huggingface.co/ybelkada/switch_transformers-base/resolve/main/spiece.model" - ), - "switch_transformers-base": "https://huggingface.co/switch_transformers-base/resolve/main/spiece.model", - "switch_transformers-large": "https://huggingface.co/switch_transformers-large/resolve/main/spiece.model", - "switch_transformers-3b": "https://huggingface.co/switch_transformers-3b/resolve/main/spiece.model", - "switch_transformers-11b": "https://huggingface.co/switch_transformers-11b/resolve/main/spiece.model", + "google/switch-base-8": "https://huggingface.co/google/switch-base-8/resolve/main/spiece.model", + "google/switch-base-16": "https://huggingface.co/switch-base-16/resolve/main/spiece.model", + "google/switch-base-32": "https://huggingface.co/google/switch-base-32/resolve/main/spiece.model", + "google/switch-base-64": "https://huggingface.co/google/switch-base-64/resolve/main/spiece.model", + "google/switch-base-128": "https://huggingface.co/google/switch-base-128/resolve/main/spiece.model", + "google/switch-base-256": "https://huggingface.co/google/switch-base-256/resolve/main/spiece.model", + "google/switch-large-128": "https://huggingface.co/google/switch-large-128/resolve/main/spiece.model", + "google/switch-xxl-128": "https://huggingface.co/google/switch-xxl-128/resolve/main/spiece.model", + "google/switch-c-2048": "https://huggingface.co/google/switch-c-2048/resolve/main/spiece.model", } } # TODO(PVP) - this should be removed in Transformers v5 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "ybelkada/switch_transformers-base": 512, - "switch_transformers-base": 512, - "switch_transformers-large": 512, - "switch_transformers-3b": 512, - "switch_transformers-11b": 512, + "google/switch-base-8": 512, + "google/switch-base-16": 512, + "google/switch-base-32": 512, + "google/switch-base-64": 512, + "google/switch-base-128": 512, + "google/switch-base-256": 512, + "google/switch-large-128": 512, + "google/switch-xxl-128": 512, + "google/switch-c-2048": 512, } From 58b9426d7abb264d8007b05db90dffea78a331f3 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 27 Oct 2022 16:16:53 +0200 Subject: [PATCH 044/102] fix tokenizer test --- src/transformers/convert_slow_tokenizer.py | 1 + .../configuration_switch_transformers.py | 5 +- .../tokenization_switch_transformers.py | 48 ++++++++++------- .../tokenization_switch_transformers_fast.py | 53 ++++++++++++------- .../test_modeling_switch_transformers.py | 2 +- .../test_tokenization_switch_transformers.py | 6 +-- 6 files changed, 71 insertions(+), 44 deletions(-) diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py index ce52ba3b3beba..6520414e3821e 100644 --- a/src/transformers/convert_slow_tokenizer.py +++ b/src/transformers/convert_slow_tokenizer.py @@ -1127,6 +1127,7 @@ def converted(self) -> Tokenizer: "RoFormerTokenizer": RoFormerConverter, "SqueezeBertTokenizer": BertConverter, "T5Tokenizer": T5Converter, + "SwitchTransformersTokenizer": T5Converter, "XLMRobertaTokenizer": XLMRobertaConverter, "XLNetTokenizer": XLNetConverter, "SplinterTokenizer": SplinterConverter, diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index a25b9ed4182a0..02d99d099322a 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -23,7 +23,7 @@ logger = logging.get_logger(__name__) SWITCH_TRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "HFLAY/switch_base_8": "https://huggingface.co/ybelkada/switch_transformers-base/resolve/main/config.json", + "HFLAY/switch_base_8": "https://huggingface.co/HFLAY/switch_base_8/blob/main/config.json", } @@ -32,8 +32,7 @@ class SwitchTransformersConfig(PretrainedConfig): This is the configuration class to store the configuration of a [`SwitchTransformersModel`]. It is used to instantiate a SwitchTransformers model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the - SwitchTransformers [ybelkada/switch_transformers-base](https://huggingface.co/ybelkada/switch_transformers-base) - architecture. + SwitchTransformers [google/switch-base-8](https://huggingface.co/google/switch-base-8) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. diff --git a/src/transformers/models/switch_transformers/tokenization_switch_transformers.py b/src/transformers/models/switch_transformers/tokenization_switch_transformers.py index 547ec81af9aa3..09d9a3c8d57e8 100644 --- a/src/transformers/models/switch_transformers/tokenization_switch_transformers.py +++ b/src/transformers/models/switch_transformers/tokenization_switch_transformers.py @@ -31,32 +31,44 @@ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} +# PRETRAINED_VOCAB_FILES_MAP = { +# "vocab_file": { +# "google/switch-base-8": "https://huggingface.co/google/switch-base-8/resolve/main/spiece.model", +# "google/switch-base-16": "https://huggingface.co/switch-base-16/resolve/main/spiece.model", +# "google/switch-base-32": "https://huggingface.co/google/switch-base-32/resolve/main/spiece.model", +# "google/switch-base-64": "https://huggingface.co/google/switch-base-64/resolve/main/spiece.model", +# "google/switch-base-128": "https://huggingface.co/google/switch-base-128/resolve/main/spiece.model", +# "google/switch-base-256": "https://huggingface.co/google/switch-base-256/resolve/main/spiece.model", +# "google/switch-large-128": "https://huggingface.co/google/switch-large-128/resolve/main/spiece.model", +# "google/switch-xxl-128": "https://huggingface.co/google/switch-xxl-128/resolve/main/spiece.model", +# "google/switch-c-2048": "https://huggingface.co/google/switch-c-2048/resolve/main/spiece.model", +# } +# } + + +# # TODO(PVP) - this should be removed in Transformers v5 +# PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { +# "google/switch-base-8": 512, +# "google/switch-base-16": 512, +# "google/switch-base-32": 512, +# "google/switch-base-64": 512, +# "google/switch-base-128": 512, +# "google/switch-base-256": 512, +# "google/switch-large-128": 512, +# "google/switch-xxl-128": 512, +# "google/switch-c-2048": 512, +# } + PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "google/switch-base-8": "https://huggingface.co/google/switch-base-8/resolve/main/spiece.model", - "google/switch-base-16": "https://huggingface.co/switch-base-16/resolve/main/spiece.model", - "google/switch-base-32": "https://huggingface.co/google/switch-base-32/resolve/main/spiece.model", - "google/switch-base-64": "https://huggingface.co/google/switch-base-64/resolve/main/spiece.model", - "google/switch-base-128": "https://huggingface.co/google/switch-base-128/resolve/main/spiece.model", - "google/switch-base-256": "https://huggingface.co/google/switch-base-256/resolve/main/spiece.model", - "google/switch-large-128": "https://huggingface.co/google/switch-large-128/resolve/main/spiece.model", - "google/switch-xxl-128": "https://huggingface.co/google/switch-xxl-128/resolve/main/spiece.model", - "google/switch-c-2048": "https://huggingface.co/google/switch-c-2048/resolve/main/spiece.model", + "HFLAY/switch_base_8": "https://huggingface.co/HFLAY/switch_base_8/blob/main/spiece.model", } } # TODO(PVP) - this should be removed in Transformers v5 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "google/switch-base-8": 512, - "google/switch-base-16": 512, - "google/switch-base-32": 512, - "google/switch-base-64": 512, - "google/switch-base-128": 512, - "google/switch-base-256": 512, - "google/switch-large-128": 512, - "google/switch-xxl-128": 512, - "google/switch-c-2048": 512, + "HFLAY/switch_base_8": 512, } diff --git a/src/transformers/models/switch_transformers/tokenization_switch_transformers_fast.py b/src/transformers/models/switch_transformers/tokenization_switch_transformers_fast.py index 758498ac2f632..c1a9ca9d44070 100644 --- a/src/transformers/models/switch_transformers/tokenization_switch_transformers_fast.py +++ b/src/transformers/models/switch_transformers/tokenization_switch_transformers_fast.py @@ -34,35 +34,50 @@ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"} +# PRETRAINED_VOCAB_FILES_MAP = { +# "vocab_file": { +# "ybelkada/switch_transformers-base": ( +# "https://huggingface.co/ybelkada/switch_transformers-base/resolve/main/spiece.model" +# ), +# "switch_transformers-base": "https://huggingface.co/switch_transformers-base/resolve/main/spiece.model", +# "switch_transformers-large": "https://huggingface.co/switch_transformers-large/resolve/main/spiece.model", +# "switch_transformers-3b": "https://huggingface.co/switch_transformers-3b/resolve/main/spiece.model", +# "switch_transformers-11b": "https://huggingface.co/switch_transformers-11b/resolve/main/spiece.model", +# }, +# "tokenizer_file": { +# "ybelkada/switch_transformers-base": ( +# "https://huggingface.co/ybelkada/switch_transformers-base/resolve/main/tokenizer.json" +# ), +# "switch_transformers-base": "https://huggingface.co/switch_transformers-base/resolve/main/tokenizer.json", +# "switch_transformers-large": "https://huggingface.co/switch_transformers-large/resolve/main/tokenizer.json", +# "switch_transformers-3b": "https://huggingface.co/switch_transformers-3b/resolve/main/tokenizer.json", +# "switch_transformers-11b": "https://huggingface.co/switch_transformers-11b/resolve/main/tokenizer.json", +# }, +# } + + +# # TODO(PVP) - this should be removed in Transformers v5 +# PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { +# "ybelkada/switch_transformers-base": 512, +# "switch_transformers-base": 512, +# "switch_transformers-large": 512, +# "switch_transformers-3b": 512, +# "switch_transformers-11b": 512, +# } + PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "ybelkada/switch_transformers-base": ( - "https://huggingface.co/ybelkada/switch_transformers-base/resolve/main/spiece.model" - ), - "switch_transformers-base": "https://huggingface.co/switch_transformers-base/resolve/main/spiece.model", - "switch_transformers-large": "https://huggingface.co/switch_transformers-large/resolve/main/spiece.model", - "switch_transformers-3b": "https://huggingface.co/switch_transformers-3b/resolve/main/spiece.model", - "switch_transformers-11b": "https://huggingface.co/switch_transformers-11b/resolve/main/spiece.model", + "HFLAY/switch_base_8": "https://huggingface.co/HFLAY/switch_base_8/blob/main/spiece.model", }, "tokenizer_file": { - "ybelkada/switch_transformers-base": ( - "https://huggingface.co/ybelkada/switch_transformers-base/resolve/main/tokenizer.json" - ), - "switch_transformers-base": "https://huggingface.co/switch_transformers-base/resolve/main/tokenizer.json", - "switch_transformers-large": "https://huggingface.co/switch_transformers-large/resolve/main/tokenizer.json", - "switch_transformers-3b": "https://huggingface.co/switch_transformers-3b/resolve/main/tokenizer.json", - "switch_transformers-11b": "https://huggingface.co/switch_transformers-11b/resolve/main/tokenizer.json", + "HFLAY/switch_base_8": "https://huggingface.co/HFLAY/switch_base_8/blob/main/tokenizer.json", }, } # TODO(PVP) - this should be removed in Transformers v5 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "ybelkada/switch_transformers-base": 512, - "switch_transformers-base": 512, - "switch_transformers-large": 512, - "switch_transformers-3b": 512, - "switch_transformers-11b": 512, + "HFLAY/switch_base_8": 512, } diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index c2a657284328f..377fdbe3fe7be 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -727,7 +727,7 @@ def __init__( self.is_training = is_training def get_large_model_config(self): - return SwitchTransformersConfig.from_pretrained("switch_transformers-base") + return SwitchTransformersConfig.from_pretrained("switch_base_8") def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) diff --git a/tests/models/switch_transformers/test_tokenization_switch_transformers.py b/tests/models/switch_transformers/test_tokenization_switch_transformers.py index 3a9fd41a9f708..6ca18416a7d35 100644 --- a/tests/models/switch_transformers/test_tokenization_switch_transformers.py +++ b/tests/models/switch_transformers/test_tokenization_switch_transformers.py @@ -143,11 +143,11 @@ def test_full_tokenizer(self): @cached_property def switch_transformers_base_tokenizer(self): - return SwitchTransformersTokenizer.from_pretrained("switch_transformers-base") + return SwitchTransformersTokenizer.from_pretrained("HFLAY/switch_base_8") @cached_property def switch_transformers_base_tokenizer_fast(self): - return SwitchTransformersTokenizerFast.from_pretrained("switch_transformers-base") + return SwitchTransformersTokenizerFast.from_pretrained("HFLAY/switch_base_8") def get_tokenizer(self, **kwargs) -> SwitchTransformersTokenizer: return self.tokenizer_class.from_pretrained(self.tmpdirname, pad_token=None, **kwargs) @@ -382,6 +382,6 @@ def test_tokenizer_integration(self): self.tokenizer_integration_test_util( expected_encoding=expected_encoding, - model_name="switch_transformers-base", + model_name="switch_base_8", revision="5a7ff2d8f5117c194c7e32ec1ccbf04642cca99b", ) From ba8bf87b05b7acc693c143d6b93b50e7c652be9e Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 27 Oct 2022 17:26:59 +0200 Subject: [PATCH 045/102] fix test --- .../switch_transformers/test_modeling_switch_transformers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 377fdbe3fe7be..4badfdcdd2023 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -866,6 +866,7 @@ def test_defaulting_to_symmetry(self): assert len(model.decoder.block) == len(model.encoder.block) == 2 +@require_torch class SwitchTransformerRouterTest(unittest.TestCase): r""" Switch Transformers has different blocks from classic transformer based models. From 41dadd187f43b5ed38294f4ec948a15e6e652e54 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 28 Oct 2022 09:37:21 +0200 Subject: [PATCH 046/102] fix loss output --- .../modeling_switch_transformers.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 46b05e377311a..035983dbe508c 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1786,11 +1786,15 @@ def forward( if output_router_logits: # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder - encoder_router_logits, encoder_expert_indexes = encoder_outputs.router_probs + encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits( + encoder_outputs.router_probs + ) encoder_z_loss = router_z_loss_func(encoder_router_logits) encoder_aux_loss = load_balancing_loss_func(encoder_router_logits, encoder_expert_indexes) - decoder_router_logits, decoder_expert_indexes = decoder_outputs.router_probs + decoder_router_logits, decoder_expert_indexes = self._unpack_router_logits( + decoder_outputs.router_probs + ) decoder_z_loss = router_z_loss_func(decoder_router_logits) decoder_aux_loss = load_balancing_loss_func(decoder_router_logits, decoder_expert_indexes) @@ -1830,6 +1834,16 @@ def forward( decoder_router_logits=decoder_outputs.router_probs, ) + def _unpack_router_logits(self, router_outputs): + total_router_logits = [] + total_expert_indexes = [] + for router_output in router_outputs: + if router_output[0] is not None: + router_logits, expert_indexes = router_output + total_router_logits.append(router_logits) + total_expert_indexes.append(expert_indexes) + return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1) + def prepare_inputs_for_generation( self, input_ids, From e8bff00b21200eecf4b391a83408b4e0742f177d Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 28 Oct 2022 07:43:58 +0000 Subject: [PATCH 047/102] update code for backward pass --- .../modeling_switch_transformers.py | 110 +++++------------- 1 file changed, 26 insertions(+), 84 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 46b05e377311a..b101af3b864b0 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -60,54 +60,16 @@ SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST = [ "google/switch-base-8", "google/switch-base-16", + "google/switch-base-32", + "google/switch-base-64", + "google/switch-base-128", + "google/switch-base-256", + "google/switch-large-128", + "google/switch-xxl-128", + "google/switch-c-2048", # See all SwitchTransformers models at https://huggingface.co/models?filter=switch_transformers ] - -def _one_hot(tensor, num_classes, axis=-1, dtype=torch.bool): - r""" - This function mimics the behavior of jax.nn.functional.one_hot in PyTorch. It takes a tensor of indices, the number - of desired classes and the axis to apply the one-hot encoding. If any value is outside the range [0, num_classes), - it will be set to zeros. - - Args: - tensor (`torch.Tensor`): - Input tensor - num_classes (`int`): - Number of classes to process for one hot encoding - axis (`int`, *optional*): - The lookup axis to check for one-hot encoding - dtype (`torch.dtype`, *optional*): - Output `dtype`. The one hot encoded vector will be casted to this dtype - """ - if tensor.is_floating_point(): - raise "Input tensor for one hot encoding must be an `int32` or `int64`" - - if axis >= len(tensor.shape): - raise "Axis is out of bounds" - - if axis == -1: - axis = len(tensor.shape) - elif axis < -1: - raise "Axis must be greater than -1" - else: - axis = axis + 1 - - # Get the final output shape - output_shape = list(tensor.shape) - output_shape.insert(axis, num_classes) - - # Create an empty output of zeros - out = torch.zeros(tuple(output_shape), dtype=dtype) - - # Mask out the places where it is outside the range [0, num_classes) - # kudos to GitHub copilot for this line - mask = (tensor >= 0) & (tensor < num_classes) - out[mask, tensor[mask]] = 1 - - return out.to(tensor.device) - - # Router loss def router_z_loss_func(router_logits: torch.Tensor) -> float: r""" @@ -190,24 +152,24 @@ class RouterOutput: router_probs: torch.FloatTensor = None -class TokensChooseMaskedRouter(nn.Module): +class SwitchTransformersTop1Router(nn.Module): """ - Masked matmul router using tokens choose top-k experts assignment. + Masked matmul router using tokens choose top-1 experts assignment. This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then routed to their choice of expert until the expert's expert_capacity is reached. There is no guarantee that each token is processed by an expert, or that each expert receives at least one token. - Attributes: - num_selected_experts (`int`): - Maximum number of experts to which each token is routed. Tokens may be routed to fewer experts if particular - experts are oversubscribed / reach capacity. - batch_prioritized_routing (`bool`): - Whether or not to use Batch Prioritized Routing (BPR), originally introduced in V-MoE - (https://arxiv.org/abs/2106.05974). With BPR, we prioritize routing those top-k tokens with the highest router - probability, rather than simply using each tokens left-to-right ordering in the batch. This prioritization is - important because the experts have limited capacity. + Parameters: + num_selected_experts (`int`): + Maximum number of experts to which each token is routed. Tokens may be routed to fewer experts if particular + experts are oversubscribed / reach capacity. + batch_prioritized_routing (`bool`): + Whether or not to use Batch Prioritized Routing (BPR), originally introduced in V-MoE + (https://arxiv.org/abs/2106.05974). With BPR, we prioritize routing those top-k tokens with the highest router + probability, rather than simply using each tokens left-to-right ordering in the batch. This prioritization is + important because the experts have limited capacity. """ def __init__(self, config, **kwargs): @@ -286,8 +248,9 @@ def forward(self, token_inputs: torch.Tensor, apply_jitter: bool = True, **kwarg Args: Computes dispatch and combine torch.Tensors for routing to experts. token_inputs: [num_groups, tokens_per_group, hidden_dim] inputs to send to experts. num_experts: - Number of experts. expert_capacity: Each group will send this many tokens to each expert. apply_jitter: If - true, apply jitter noise during routing. + Number of experts. expert_capacity: Each group will send this many tokens to each expert. + apply_jitter: + If true, apply jitter noise during routing. Returns: Router indices or mask torch.Tensors (depending on router type). """ @@ -306,32 +269,12 @@ def forward(self, token_inputs: torch.Tensor, apply_jitter: bool = True, **kwarg else: padding_mask = None - expert_index = self._compute_routing_instructions(router_probs, **kwargs) + expert_index = torch.argmax(router_probs, dim=-1) expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts) router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1) - # router_z_loss = router_z_loss_func(router_logits) return expert_index, router_probs, router_logits - def _compute_routing_instructions( - self, - router_probs: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Computes masks for the top-k experts per token. - - Args: - router_probs (`torch.Tensor`): - Router raw probabilities tensor of shape [num_groups, tokens_per_group, num_experts] used to determine - the routing of tokens to the experts. - Returns: - Dispatch and combine arrays for routing with masked matmuls. - """ - # Top-k router probability and corresponding expert indices for each token. - # Shape: [num_groups, tokens_per_group, num_selected_experts]. - expert_index = torch.argmax(router_probs, dim=-1) - return expert_index - # Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->SwitchTransformers class SwitchTransformersLayerNorm(nn.Module): @@ -470,9 +413,8 @@ def _get_router(self, config): In total the list of supported Routers are the following: """ - # TODO, use a ALL_ROUTER_TYPE map instead of havind all the ifs? then just if None raise error. if config.router_type.lower() == "tokens_masked": - return TokensChooseMaskedRouter(config) + return SwitchTransformersTop1Router(config) else: raise NotImplementedError( f"{config.router_type.lower()} not implemented ! Please chose a router in " @@ -496,7 +438,7 @@ def forward(self, hidden_states): router_mask, router_probs, router_logits = self.router(hidden_states) expert_index = torch.argmax(router_mask, dim=-1) - # next_states = torch.zeros_like(hidden_states) + next_states = hidden_states.clone() for idx, expert in enumerate(self.experts.values()): # 1. Get the index of the tokens that are routed to the current expert @@ -504,9 +446,9 @@ def forward(self, hidden_states): token_indices = router_mask[:, :, idx].bool() # 2. Update only the hidden states affected by the routing - hidden_states[token_indices] = expert(hidden_states[token_indices]) + next_states[token_indices] = expert(hidden_states[token_indices]) - hidden_states = router_probs * hidden_states + hidden_states = router_probs * next_states return hidden_states, (router_logits, expert_index) From 84a6447ec82ba8dc6bef46a004488616f2aeb424 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 28 Oct 2022 10:07:12 +0200 Subject: [PATCH 048/102] add loss support --- .../configuration_switch_transformers.py | 9 ++++++ .../modeling_switch_transformers.py | 28 ++++++++++++------- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index 02d99d099322a..722f6d8c67678 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -89,6 +89,10 @@ class SwitchTransformersConfig(PretrainedConfig): The ratio for all dropout layers. layer_norm_eps (`float`, *optional*, defaults to 1e-6): The epsilon used by the layer normalization layers. + router_z_loss_coef (`float`, *optional*, defaults to 0.001): + The z loss factor for the total loss. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. initializer_factor (`float`, *optional*, defaults to 1): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). @@ -125,6 +129,8 @@ def __init__( relative_attention_num_buckets=32, relative_attention_max_distance=128, dropout_rate=0.1, + router_z_loss_coef=0.001, + router_aux_loss_coef=0.001, layer_norm_epsilon=1e-6, initializer_factor=1.0, feed_forward_proj="relu", @@ -185,6 +191,9 @@ def __init__( self.use_cache = use_cache self.add_router_probs = add_router_probs + self.router_z_loss_coef = router_z_loss_coef + self.router_aux_loss_coef = router_aux_loss_coef + act_info = self.feed_forward_proj.split("-") self.dense_act_fn = act_info[-1] self.is_gated_act = act_info[0] == "gated" diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index c6667f88d0011..3a09abbf8ab82 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -70,6 +70,7 @@ # See all SwitchTransformers models at https://huggingface.co/models?filter=switch_transformers ] + # Router loss def router_z_loss_func(router_logits: torch.Tensor) -> float: r""" @@ -163,13 +164,13 @@ class SwitchTransformersTop1Router(nn.Module): Parameters: num_selected_experts (`int`): - Maximum number of experts to which each token is routed. Tokens may be routed to fewer experts if particular - experts are oversubscribed / reach capacity. + Maximum number of experts to which each token is routed. Tokens may be routed to fewer experts if + particular experts are oversubscribed / reach capacity. batch_prioritized_routing (`bool`): Whether or not to use Batch Prioritized Routing (BPR), originally introduced in V-MoE - (https://arxiv.org/abs/2106.05974). With BPR, we prioritize routing those top-k tokens with the highest router - probability, rather than simply using each tokens left-to-right ordering in the batch. This prioritization is - important because the experts have limited capacity. + (https://arxiv.org/abs/2106.05974). With BPR, we prioritize routing those top-k tokens with the highest + router probability, rather than simply using each tokens left-to-right ordering in the batch. This + prioritization is important because the experts have limited capacity. """ def __init__(self, config, **kwargs): @@ -780,7 +781,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, - output_router_logits=False, + output_router_logits=True, return_dict=True, ): @@ -1022,7 +1023,7 @@ def forward( use_cache=None, output_attentions=None, output_hidden_states=None, - output_router_logits=None, + output_router_logits=True, return_dict=None, ): use_cache = use_cache if use_cache is not None else self.config.use_cache @@ -1575,6 +1576,9 @@ def __init__(self, config: SwitchTransformersConfig): self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + self.router_z_loss_coef = config.router_z_loss_coef + self.router_aux_loss_coef = config.router_aux_loss_coef + # Initialize weights and apply final processing self.post_init() @@ -1620,7 +1624,7 @@ def forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, + output_router_logits: Optional[bool] = True, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqMoEOutput]: r""" @@ -1741,7 +1745,11 @@ def forward( decoder_aux_loss = load_balancing_loss_func(decoder_router_logits, decoder_expert_indexes) loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) - # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if output_router_logits and labels is not None: + z_loss = self.router_z_loss_coef * (encoder_z_loss + decoder_z_loss) + aux_loss = self.router_aux_loss_coef * (encoder_aux_loss + decoder_aux_loss) + loss = loss + z_loss + aux_loss if not return_dict: output = (lm_logits,) @@ -1895,7 +1903,7 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, + output_router_logits: Optional[bool] = True, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.FloatTensor], MoEModelOutput]: r""" From 5890744990d8e47c1c2ce2ab45f43d74542b3ec6 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 28 Oct 2022 08:33:31 +0000 Subject: [PATCH 049/102] update documentation --- README.md | 2 +- README_es.md | 2 +- README_ko.md | 2 +- README_zh-hans.md | 2 +- README_zh-hant.md | 2 +- docs/source/en/index.mdx | 2 +- .../en/model_doc/switch_transformers.mdx | 32 +- .../models/switch_transformers/router_flax.py | 731 ------------------ .../tokenization_switch_transformers.py | 363 --------- .../tokenization_switch_transformers_fast.py | 260 ------- .../test_tokenization_switch_transformers.py | 387 ---------- 11 files changed, 24 insertions(+), 1761 deletions(-) delete mode 100644 src/transformers/models/switch_transformers/router_flax.py delete mode 100644 src/transformers/models/switch_transformers/tokenization_switch_transformers.py delete mode 100644 src/transformers/models/switch_transformers/tokenization_switch_transformers_fast.py delete mode 100644 tests/models/switch_transformers/test_tokenization_switch_transformers.py diff --git a/README.md b/README.md index cee378c477e56..1136feca5ded3 100644 --- a/README.md +++ b/README.md @@ -373,7 +373,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. 1. **[Swin Transformer V2](https://huggingface.co/docs/transformers/model_doc/swinv2)** (from Microsoft) released with the paper [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. -1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switch_transformers)** (from ) released with the paper []() by . +1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switch_transformers)** (from Google) released with the paper [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) by William Fedus, Barret Zoph, Noam Shazeer. 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. diff --git a/README_es.md b/README_es.md index ca38c4c082a7b..1911bec625d90 100644 --- a/README_es.md +++ b/README_es.md @@ -373,7 +373,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. 1. **[Swin Transformer V2](https://huggingface.co/docs/transformers/model_doc/swinv2)** (from Microsoft) released with the paper [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. -1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switch_transformers)** (from ) released with the paper []() by . +1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switch_transformers)** (from Google) released with the paper [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) by William Fedus, Barret Zoph, Noam Shazeer. 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. diff --git a/README_ko.md b/README_ko.md index e1693faa9e660..488ad758e48ff 100644 --- a/README_ko.md +++ b/README_ko.md @@ -323,7 +323,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. 1. **[Swin Transformer V2](https://huggingface.co/docs/transformers/model_doc/swinv2)** (from Microsoft) released with the paper [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. -1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switch_transformers)** (from ) released with the paper []() by . +1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switch_transformers)** (from Google) released with the paper [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) by William Fedus, Barret Zoph, Noam Shazeer. 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. diff --git a/README_zh-hans.md b/README_zh-hans.md index 923a018be2be8..be873b6627cdc 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -347,7 +347,7 @@ conda install -c huggingface transformers 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (来自 Berkeley) 伴随论文 [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) 由 Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer 发布。 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (来自 Microsoft) 伴随论文 [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) 由 Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo 发布。 1. **[Swin Transformer V2](https://huggingface.co/docs/transformers/model_doc/swinv2)** (来自 Microsoft) 伴随论文 [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) 由 Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo 发布。 -1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switch_transformers)** (from ) released with the paper []() by . +1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switch_transformers)** (from Google) released with the paper [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) by William Fedus, Barret Zoph, Noam Shazeer. 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (来自 Google AI) 伴随论文 [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) 由 Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu 发布。 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (来自 Google AI) 伴随论文 [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) 由 Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu 发布。 1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (来自 Google AI) 伴随论文 [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) 由 Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 6459d15509068..c5c104a1c7e75 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -359,7 +359,7 @@ conda install -c huggingface transformers 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. 1. **[Swin Transformer V2](https://huggingface.co/docs/transformers/model_doc/swinv2)** (from Microsoft) released with the paper [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. -1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switch_transformers)** (from ) released with the paper []() by . +1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switch_transformers)** (from Google) released with the paper [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) by William Fedus, Barret Zoph, Noam Shazeer. 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released with the paper [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 1712dd9f11b6f..863c3ac281044 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -162,7 +162,7 @@ The documentation is organized into five sections: 1. **[SqueezeBERT](model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. 1. **[Swin Transformer V2](model_doc/swinv2)** (from Microsoft) released with the paper [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. -1. **[SwitchTransformers](model_doc/switch_transformers)** (from ) released with the paper []() by . +1. **[SwitchTransformers](model_doc/switch_transformers)** (from Google) released with the paper [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) by William Fedus, Barret Zoph, Noam Shazeer. 1. **[T5](model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. diff --git a/docs/source/en/model_doc/switch_transformers.mdx b/docs/source/en/model_doc/switch_transformers.mdx index 3603d25b9188e..ad60313f0d6e8 100644 --- a/docs/source/en/model_doc/switch_transformers.mdx +++ b/docs/source/en/model_doc/switch_transformers.mdx @@ -14,36 +14,40 @@ specific language governing permissions and limitations under the License. ## Overview -The SwitchTransformers model was proposed in []() by . - +The SwitchTransformers model was proposed in [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) by William Fedus, Barret Zoph, Noam Shazeer. + +The Switch Transformer model uses a sparse T5 encoder-decoder architure, where the MLP are replace by a Mixture of Expert (MOE). A routing mecanism (top 1 in this case) associates each token to one of the expert, where each expert is a dense MLP. While switch tranformers have a lot more weights than their equivalent dense models, the sparsity allows for better scaling. +During a forward pass, only a fraction of the weights are used. The routing mecanism allows the model to select relavant weights on the fly which increases the model capacity. #TODO add the intuition about moving the loss curve. + The abstract from the paper is the following: -** +*In deep learning, models typically reuse the same parameters for all inputs. Mixture of Experts (MoE) defies this and instead selects different parameters for each incoming example. The result is a sparsely-activated model -- with outrageous numbers of parameters -- but a constant computational cost. However, despite several notable successes of MoE, widespread adoption has been hindered by complexity, communication costs and training instability -- we address these with the Switch Transformer. We simplify the MoE routing algorithm and design intuitive improved models with reduced communication and computational costs. Our proposed training techniques help wrangle the instabilities and we show large sparse models may be trained, for the first time, with lower precision (bfloat16) formats. We design models based off T5-Base and T5-Large to obtain up to 7x increases in pre-training speed with the same computational resources. These improvements extend into multilingual settings where we measure gains over the mT5-Base version across all 101 languages. Finally, we advance the current scale of language models by pre-training up to trillion parameter models on the "Colossal Clean Crawled Corpus" and achieve a 4x speedup over the T5-XXL model.* Tips: - +- SwitchTransformers uses the T5Tokenizer, which can be loaded directly from each model's repository. +- The released weights are pretrained on English [Masked Language Modeling](What is is MLM blog or doc) task, and should be finetuned. +- The routers -This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). -The original code can be found [here](). +This model was contributed by [Younes Belkada](https://huggingface.co/ybelkada) and [Arthur Zucker](https://huggingface.co/ArtZucker) . +The original code can be found [here](https://github.com/google/flaxformer/tree/main/flaxformer/architectures/moe). ## SwitchTransformersConfig [[autodoc]] SwitchTransformersConfig -## SwitchTransformersTokenizer +## SwitchTransformersTop1Router -[[autodoc]] SwitchTransformersTokenizer - - build_inputs_with_special_tokens - - get_special_tokens_mask - - create_token_type_ids_from_sequences - - save_vocabulary +[[autodoc]] SwitchTransformersTop1Router + - _compute_router_probabilities + - forward -## SwitchTransformersTokenizerFast +## SwitchTransformersSparseMLP -[[autodoc]] SwitchTransformersTokenizerFast +[[autodoc]] SwitchTransformersSparseMLP + - forward ## SwitchTransformersModel diff --git a/src/transformers/models/switch_transformers/router_flax.py b/src/transformers/models/switch_transformers/router_flax.py deleted file mode 100644 index 2a7480bb5ba35..0000000000000 --- a/src/transformers/models/switch_transformers/router_flax.py +++ /dev/null @@ -1,731 +0,0 @@ -# Copyright 2022 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Mixture of Experts routing mechanisms.""" - -from typing import Any, Iterable, Optional, Sequence, Tuple, Union - -import flax -import jax -import jax.numpy as jnp -from flax import linen as nn -from flax.linen import partitioning as flax_partitioning - - -# from flaxformer.components import dense -# from flaxformer.types import Array -# from flaxformer.types import DType -# from flaxformer.types import Initializer - -RouterOutput = Any -Array = Any -DType = Any -Initializer = Any - -# Switch Transformer (https://arxiv.org/abs/2101.03961) suggests using -# nn.initializers.variance_scaling(0.1, "fan_in", "truncated_normal") -# scaling throughout MoE models, but we find slightly better results adopting -# typical normally-distributed scaling for the router specifically. -default_kernel_init = nn.initializers.normal(stddev=2e-2) -default_bias_init = nn.initializers.zeros - - -@flax.struct.dataclass -class RouterIndices: - """Dispatch indices and combine weights for scatter/gather-based routing. - - Attributes: - dispatch_indices: [num_groups, tokens_per_group, - num_selected_experts, 2] dispatch indices indicating, for each token, its preferred expert and its priority in - that expert's buffer. - combine_weights: [num_groups, tokens_per_group, num_selected_experts] - combine weights used for scaling expert outputs with the router's dispatch probability/confidence. - auxiliary_loss: Load balancing loss for router. - router_z_loss: Router z-loss. Encourages router logits to remain small in an - effort to improve stability. - """ - - dispatch_indices: Array - combine_weights: Array - auxiliary_loss: float - router_z_loss: float = 0.0 - - -@flax.struct.dataclass -class RouterMask: - """Dispatch and combine arrays for expert routing with masked matmuls. - - Attributes: - dispatch_mask: [num_groups, tokens_per_group, num_experts, - expert_capacity] dispatch array that is 1 if the token gets routed to the corresponding expert, and 0 - otherwise. - combine_array: [num_groups, tokens_per_group, num_experts, - expert_capacity] combine array used for combining expert outputs and scaling with router probability. - auxiliary_loss: Load balancing loss for router. - router_z_loss: Router z-loss. Encourages router logits to remain small in an - effort to improve stability. - """ - - dispatch_mask: Array - combine_array: Array - auxiliary_loss: float - router_z_loss: float = 0.0 - - -def _favor_one_hot_slices() -> bool: - """Returns true iff running on TPUs.""" - return jax.default_backend() == "tpu" or jax.devices()[0].platform == "tpu" - - -def _take_along_axis(array: Array, indices: Array, axis: int) -> Array: - """Takes values from the input array by matching 1D index and data slices. - - This function serves the same purpose as jax.numpy.take_along_axis, except that it uses one-hot matrix - multiplications under the hood on TPUs: (1) On TPUs, we use one-hot matrix multiplications to select elements from - the - array; this is particularly helpful for avoiding erroneous all-gather ops when running under pjit. - (2) Otherwise, we fall back to jax.numpy.take_along_axis. - - Notes: - - To simplify matters in case (1), we only support slices along the second or last dimensions. - - We may wish to revisit (1) for very large arrays. - - Args: - array: Source array. - indices: Indices to take along each 1D slice of array. - axis: Axis along which to take 1D slices. - - Returns: - The indexed result. - """ - if array.ndim != indices.ndim: - raise ValueError( - f"indices and array must have the same number of dimensions; {indices.ndim} vs. {array.ndim}." - ) - - if ( - axis != -1 and axis != array.ndim - 1 and axis != 1 and axis != -array.ndim + 1 # Not last dimension - ): # Not second dimension - raise ValueError( - "Only slices along the second or last dimension are supported; " - f"array.ndim = {array.ndim}, while axis = {axis}." - ) - - if _favor_one_hot_slices(): - one_hot_length = array.shape[axis] - one_hot_indices = jax.nn.one_hot(indices, one_hot_length, axis=axis) - - if axis == -1 or array.ndim == 1: - # Take i elements from last dimension (s). - # We must use HIGHEST precision to accurately reproduce indexing - # operations with matrix multiplications. - result = jnp.einsum("...s,...is->...i", array, one_hot_indices, precision=jax.lax.Precision.HIGHEST) - else: - # Take i elements from second dimension (s). We assume here that we always - # want to slice along the second dimension. - # We must use HIGHEST precision to accurately reproduce indexing - # operations with matrix multiplications. - result = jnp.einsum("ns...,nis...->ni...", array, one_hot_indices, precision=jax.lax.Precision.HIGHEST) - return jax.lax.convert_element_type(result, array.dtype) - else: - return jnp.take_along_axis(array, indices, axis=axis) - - -def _top_k(array: Array, k: int) -> Tuple[Array, Array]: - """Returns top k values and their indices along the last axis of the array. - - This function serves the same purpose as jax.lax.top_k, but in a more XLA friendly manner for TPUs: (1) On TPUs, we - use one-hot matrix multiplications to select the top k values. - This convoluted way of obtaining the top k values is generally faster on TPUs, and, for pjit in particular, - avoids adding extra all-gather ops during backpropagation. - (2) Otherwise, we fall back to jax.lax.top_k (and its underlying scatter op). - - Args: - array: Source array. - k: Number of top values to select. - - Returns: - - Top k values - - Associated top k indices. - """ - if _favor_one_hot_slices(): - top_k_indices = jax.lax.top_k(array, k)[-1] - top_k_values = _take_along_axis(array, top_k_indices, axis=-1) - return top_k_values, top_k_indices - else: - return jax.lax.top_k(array, k) - - -class RouterWeights(nn.Module): - """Router module converting token inputs to router logits. - - Attributes: - use_bias: Whether or not to use the bias term in computing the logits. - dtype: Numerical float type for router logit computation. - kernel_init: Initialization scheme for kernel. - bias_init: Initialization scheme for bias. - precision: XLA precision for array computations. - axis: Axes along which to apply the dense router weights transformation. - Defaults to final axis (typically the "hidden dimension"). - kernel_axis_names: Logical axis names to use for kernel sharding. - reshape_kernel: Whether to reshape the kernel parameter to 2D for Adafactor. - """ - - use_bias: bool = True - dtype: DType = jnp.bfloat16 - kernel_init: Initializer = default_kernel_init - bias_init: Initializer = default_bias_init - precision: jax.lax.Precision = jax.lax.Precision.DEFAULT - axis: Union[Iterable[int], int] = -1 - kernel_axis_names: Sequence[str] = ("embed", "unmodeled") - reshape_kernel: bool = True - - @nn.compact - def __call__(self, token_inputs: Array, num_experts: int) -> Array: - """Applies RouterWeights module. - - Args: - token_inputs: Flattened batch of tokens with shape [num_groups, - group_size, hidden_dim]. - num_experts: Number of experts. - - Returns: - Router logits with shape [num_groups, group_size, num_experts]. - """ - # Flax code for reference - return nn.Dense( - features=num_experts, - axis=self.axis, - use_bias=self.use_bias, - dtype=self.dtype, - kernel_init=self.kernel_init, - bias_init=self.bias_init, - precision=self.precision, - kernel_axis_names=self.kernel_axis_names, - reshape_kernel=self.reshape_kernel, - name="w", - )(token_inputs) - # pass - - -class Router(nn.Module): - """Abstract base router class, defining router API and inner workings. - - Attributes: - router_weights: Configurable module used to compute router logits from token - inputs. - jitter_noise: Amplitude of jitter noise applied to router logits. - dtype: Numeric float type for returned combine array. All actual - computations are performed in float32 of the input for stability. - ignore_padding_tokens: Whether to ignore padding tokens during routing. Note - that some routers (e.g. TokensChooseMaskedRouter) will completely ignore padding tokens, while others (e.g. - TokensChooseScatterRouter and ExpertsChooseMaskedRouter) will simply down-weight the probability of selecting - padding tokens. - """ - - router_weights: RouterWeights - jitter_noise: float - dtype: jnp.dtype - ignore_padding_tokens: bool - - def __call__( - self, token_inputs: Array, num_experts: int, expert_capacity: int, apply_jitter: bool = True - ) -> RouterOutput: - """Computes dispatch and combine arrays for routing to experts. - - Args: - token_inputs: [num_groups, tokens_per_group, hidden_dim] inputs to - send to experts. - num_experts: Number of experts. - expert_capacity: Each group will send this many tokens to each expert. - apply_jitter: If true, apply jitter noise during routing. - - Returns: - Router indices or mask arrays (depending on router type). - """ - token_inputs = flax_partitioning.with_sharding_constraint(token_inputs, ("batch", "length", "embed")) - router_probs, router_logits = self._compute_router_probabilities(token_inputs, num_experts, apply_jitter) - router_probs = flax_partitioning.with_sharding_constraint(router_probs, ("batch", "length", "unmodeled")) - router_logits = flax_partitioning.with_sharding_constraint(router_logits, ("batch", "length", "unmodeled")) - - if self.ignore_padding_tokens: - # To identify non-padding tokens, we rely on the fact that padding tokens - # in the inputs have already been masked in the default T5 architecture. - # See - # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L315 - # and - # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L603. - padding_mask = jnp.array((jnp.sum(jnp.abs(token_inputs), axis=-1) > 0), dtype=token_inputs.dtype) - router_logits *= jnp.expand_dims(padding_mask, axis=-1) - else: - padding_mask = None - - instructions = self._compute_routing_instructions(router_probs, padding_mask, expert_capacity) - - return instructions.replace(router_z_loss=_router_z_loss(router_logits)) - - def _compute_router_probabilities( - self, token_inputs: Array, num_experts: int, apply_jitter: bool - ) -> Tuple[Array, Array]: - """Computes router probabilities from input tokens. - - Args: - token_inputs: [num_groups, tokens_per_group, hidden_dim] from which - router probabilities are computed. - num_experts: Number of experts. - apply_jitter: If true, apply jitter noise. - - Returns: - - [num_groups, tokens_per_group, num_experts] probabilities for each token and expert. Used for - routing tokens to experts. - - [num_groups, tokens_per_group, num_experts] raw router logits. Used for computing router z-loss. - """ - # For remainder of routing computation we use float32 to ensure stability. - # See the discussion of "selective precision" in - # https://arxiv.org/abs/2101.03961. - token_inputs = jax.lax.convert_element_type(token_inputs, jnp.float32) - - if apply_jitter and self.jitter_noise > 0: - token_inputs *= jax.random.uniform( - self.make_rng("jitter"), - token_inputs.shape, - token_inputs.dtype, - minval=1.0 - self.jitter_noise, - maxval=1.0 + self.jitter_noise, - ) - - # Shape: [num_groups, tokens_per_group, num_experts] - router_logits = self.router_weights(token_inputs, num_experts) - - router_probabilities = jax.nn.softmax(router_logits, axis=-1) - - return router_probabilities, router_logits - - def _compute_routing_instructions( - self, router_probs: Array, padding_mask: Optional[Array], expert_capacity: int - ) -> RouterOutput: - """Computes instructions for routing inputs to experts.""" - raise NotImplementedError("Router is an abstract class that should be subclassed.") - - -class ScatterRouter(Router): - """Abstract base router class for scatter dispatch routers. - - ScatterRouter(s) return RouterIndices containing dispatch indices and combine weights for sending token inputs (via - scatter) and receiving outputs (via gather) to and from experts. - - Scatter-based routing is generally faster than masked matmul routing on CPUs and GPUs. - """ - - def _compute_routing_instructions( - self, router_probs: Array, padding_mask: Optional[Array], expert_capacity: int - ) -> RouterIndices: - """Computes instructions for routing inputs to experts. - - Args: - router_probs: [num_groups, tokens_per_group, num_experts] - probabilities used to determine the routing of tokens to the experts. - padding_mask: [num_groups, tokens_per_group] padding logit mask - used to identify padding tokens that should be ignored by the router. - expert_capacity: Each group will send this many tokens to each expert. - - Returns: - Router indices containing dispatch indices and combine weights. - """ - raise NotImplementedError("ScatterRouter is an abstract class that should be subclassed.") - - -class MaskedRouter(Router): - """Abstract base router class for masked matmul dispatch routers. - - MaskedRouter(s) return RouterMask(s) containing a dispatch mask and combine array for sending and receiving (via - masked matmuls) inputs and outputs to and from experts. - - Routing using masked matmuls is generally faster than scatter-based routing on TPUs. - """ - - def _compute_routing_instructions( - self, router_probs: Array, padding_mask: Optional[Array], expert_capacity: int - ) -> RouterMask: - """Computes masks for the top-k experts per token. - - Args: - router_probs: [num_groups, tokens_per_group, num_experts] - probabilities used to determine the routing of tokens to the experts. - padding_mask: [num_groups, tokens_per_group] padding logit mask - used to identify padding tokens that should be ignored by the router. - expert_capacity: Each group will send this many tokens to each expert. - - Returns: - Router mask arrays. - """ - raise NotImplementedError("MaskedRouter is an abstract class that should be subclassed.") - - -class TokensChooseScatterRouter(ScatterRouter): - """Scatter router using tokens choose top-k experts assignment. - - This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE - (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then - routed to their choice of expert until the expert's expert_capacity is reached. There is no guarantee that each - token is processed by an expert, or that each expert receives at least one token. - - Attributes: - num_selected_experts: Maximum number of experts to which each token is - routed. Tokens may be routed to fewer experts if particular experts are oversubscribed / reach capacity. - batch_prioritized_routing: Whether or not to use Batch Prioritized Routing - (BPR), originally introduced in V-MoE (https://arxiv.org/abs/2106.05974). - With BPR, we prioritize routing those top-k tokens with the highest router probability, rather than simply - using each tokens left-to-right ordering in the batch. This prioritization is important because the expert's - have limited capacity. - """ - - num_selected_experts: int - batch_prioritized_routing: bool - - def _compute_routing_instructions( - self, router_probs: Array, padding_mask: Optional[Array], expert_capacity: int - ) -> RouterIndices: - """Computes dispatch indices and combine weights for the top-k experts. - - Args: - router_probs: [num_groups, tokens_per_group, num_experts] - probabilities used to determine the routing of tokens to the experts. - padding_mask: [num_groups, tokens_per_group] padding logit mask - used to identify padding tokens that should be down-weighted by the router. - expert_capacity: Each group will send this many tokens to each expert. - - Returns: - Dispatch indices and combine weights for scatter/gather-based routing. - """ - num_groups, tokens_per_group, num_experts = router_probs.shape - - if padding_mask is not None: - # Because `expert_indices` are directly used for scatter-based routing, we - # mask probabilities corresponding to tokens before the top-k operation. - # Note that, unlike for mask-based tokens-choose routing, the - # (down-weighted) padding tokens may still be selected. - router_probs *= jnp.expand_dims(padding_mask, axis=-1) - - # Top-k router probability and corresponding expert indices for each token. - # Shape: [num_groups, tokens_per_group, num_selected_experts]. - combine_weights, expert_indices = _top_k(router_probs, k=self.num_selected_experts) - - auxiliary_loss = _load_balancing_loss(router_probs, expert_indices) - - if self.batch_prioritized_routing: - # Sort tokens according to their routing probability per token group, so - # that the highest probability tokens are routed first. - token_ordering = jnp.argsort(-combine_weights[..., 0], axis=-1) - expert_indices = _take_along_axis(expert_indices, jnp.expand_dims(token_ordering, axis=-1), axis=-2) - - # Identify each token's preferred expert. - # Make num_selected_experts the leading axis to ensure that top-1 choices - # have priority over top-2 choices, which have priority over top-3 - # choices... - preferred_experts = jnp.swapaxes(expert_indices, 1, 2) - # Shape: [num_groups, num_selected_experts * tokens_per_group] - preferred_experts = preferred_experts.reshape(num_groups, -1) - - # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. - expert_mask = jax.nn.one_hot(preferred_experts, num_experts, dtype=jnp.int32) - - # Experts have a fixed capacity that we cannot exceed. A token's priority - # within the expert's buffer is given by the masked, cumulative capacity of - # its target expert. - # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. - token_priority = jnp.cumsum(expert_mask, axis=1) * expert_mask - 1.0 - # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts]. - token_priority = token_priority.reshape((num_groups, self.num_selected_experts, -1, num_experts)) - # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. - token_priority = jnp.swapaxes(token_priority, 1, 2) - # For each token, across all experts, select the only non-negative - # (unmasked) priority. Shape: [num_groups, tokens_per_group, - # num_selected_experts]. - token_priority = jnp.max(token_priority, axis=-1) - - # Return to original index shape. - preferred_experts = preferred_experts.reshape(num_groups, self.num_selected_experts, tokens_per_group) - # Shape: [num_groups, tokens_per_group, num_selected_experts] - preferred_experts = jnp.swapaxes(preferred_experts, 1, 2) - - if self.batch_prioritized_routing: - # Place tokens in their original ordering. - inverse_token_ordering = jnp.argsort(token_ordering, axis=-1) - preferred_experts = _take_along_axis( - preferred_experts, jnp.expand_dims(inverse_token_ordering, axis=-1), axis=-2 - ) - token_priority = _take_along_axis( - token_priority, jnp.expand_dims(inverse_token_ordering, axis=-1), axis=-2 - ) - - # Mask out tokens that overflow the maximum expert capacities. - # Shape: [num_groups, tokens_per_group, num_selected_experts]. - combine_weights *= token_priority < expert_capacity - - # Expert index and priority within the expert capacity buffer. - # Shape: [num_groups, tokens_per_group, num_selected_experts, 2]. - dispatch_indices = jnp.stack([preferred_experts, token_priority], axis=-1) - - # Return to default dtype now that router computation is complete. - combine_weights = jax.lax.convert_element_type(combine_weights, self.dtype) - dispatch_indices = jax.lax.convert_element_type(dispatch_indices, jnp.int32) - - return RouterIndices(dispatch_indices, combine_weights, auxiliary_loss) - - -class ExpertsChooseMaskedRouter(MaskedRouter): - """Masked matmul router using experts choose tokens assignment. - - This router uses the same mechanism as in Mixture-of-Experts with Expert Choice (https://arxiv.org/abs/2202.09368): - each expert selects its top expert_capacity tokens. An individual token may be processed by multiple experts or - none at all. - - Note: "experts choose routing" should not be used in decoder blocks because it breaks the autoregressive behavior - -- the model will learn to cheat by using future token information to improve current token predictions. - """ - - def _compute_routing_instructions( - self, router_probs: Array, padding_mask: Optional[Array], expert_capacity: int - ) -> RouterMask: - """Computes masks for the highest probability token per expert. - - Args: - router_probs: [num_groups, tokens_per_group, num_experts] - probabilities used to determine the routing of tokens to the experts. - padding_mask: [num_groups, tokens_per_group] padding logit mask - used to identify padding tokens that should be down-weighted by the router. - expert_capacity: Each group will send this many tokens to each expert. - - Returns: - Dispatch and combine arrays for routing with masked matmuls. - """ - tokens_per_group = router_probs.shape[1] - - if padding_mask is not None: - # Because experts choose tokens, we mask probabilities corresponding to - # tokens before the top-k operation. Note that, unlike for masked-based - # tokens-choose routing, the experts here may still choose to select the - # (down-weighted) padding tokens. - router_probs *= jnp.expand_dims(padding_mask, axis=-1) - - # vmap over group dimension. - router_probs_t = jax.vmap(lambda m: m.transpose())(router_probs) - - # Top expert_capacity router probability and corresponding token indices for - # each expert. Shapes: [num_groups, num_experts, expert_capacity]. - expert_gate, expert_index = _top_k(router_probs_t, k=expert_capacity) - - # Convert to one-hot mask of expert indices for each token in each group. - # Shape: [num_groups, num_experts, expert_capacity, tokens_per_group]. - dispatch_mask = jax.nn.one_hot(expert_index, tokens_per_group, dtype=jnp.int32) - - # Move axes to conform with shape expected by MoeLayer API. - # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity] - dispatch_mask = jnp.moveaxis(dispatch_mask, 3, 1) - - # The combine array will be used for combining expert outputs, scaled by the - # router probabilities. Shape: [num_groups, num_experts, tokens_per_group, - # expert_capacity]. - combine_array = jnp.einsum( - "...ec,...tec->...tec", expert_gate, dispatch_mask, precision=jax.lax.Precision.DEFAULT - ) - - # Return to default dtype now that router computation is complete. - combine_array = jax.lax.convert_element_type(combine_array, self.dtype) - - # Each expert is choosing tokens until it reaches full capacity, so we don't - # need an auxiliary loading balancing loss for expert choice routing. - auxiliary_loss = 0.0 - - return RouterMask(dispatch_mask, combine_array, auxiliary_loss) - - -class TokensChooseMaskedRouter(MaskedRouter): - """ - Masked matmul router using tokens choose top-k experts assignment. - - This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE - (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then - routed to their choice of expert until the expert's expert_capacity is reached. There is no guarantee that each - token is processed by an expert, or that each expert receives at least one token. - - Attributes: - num_selected_experts: Maximum number of experts to which each token is - routed. Tokens may be routed to fewer experts if particular experts are oversubscribed / reach capacity. - batch_prioritized_routing: Whether or not to use Batch Prioritized Routing - (BPR), originally introduced in V-MoE (https://arxiv.org/abs/2106.05974). With BPR, we prioritize routing those - top-k tokens with the highest router probability, rather than simply using each tokens left-to-right ordering - in the batch. This prioritization is important because the experts have limited capacity. - """ - - num_selected_experts: int - batch_prioritized_routing: bool - - def _compute_routing_instructions( - self, router_probs: Array, padding_mask: Optional[Array], expert_capacity: int - ) -> RouterMask: - """ - Computes masks for the top-k experts per token. - - Args: - router_probs: [num_groups, tokens_per_group, num_experts] - probabilities used to determine the routing of tokens to the experts. - padding_mask: [num_groups, tokens_per_group] padding logit mask - used to identify padding tokens that should be ignored by the router. - expert_capacity: Each group will send this many tokens to each expert. - - Returns: - Dispatch and combine arrays for routing with masked matmuls. - """ - num_groups, _, num_experts = router_probs.shape - - # Top-k router probability and corresponding expert indices for each token. - # Shape: [num_groups, tokens_per_group, num_selected_experts]. - expert_gate, expert_index = _top_k(router_probs, k=self.num_selected_experts) - - if padding_mask is not None: - # Mask applied to gate. Exclude choices corresponding to padding tokens. - gate_mask = jnp.expand_dims(padding_mask, axis=-1) - expert_gate *= gate_mask - - # Set `expert_index` elements corresponding to padding to negative - # numbers. Negative `expert_index` elements will ultimately be dropped in - # the one_hot conversion to the `expert_mask`. - # First convert nonzero padding elements to negative values. - expert_index *= 2 * gate_mask - 1.0 - # Handle zero padding elements by negatively shifting all padding. - expert_index += jnp.repeat(gate_mask - 1.0, self.num_selected_experts, axis=-1) - - # To correctly compute load balancing loss, we also mask out probs. - router_probs *= gate_mask - - auxiliary_loss = _load_balancing_loss(router_probs, expert_index) - - if self.batch_prioritized_routing: - # Sort tokens according to their routing probability per group, so that - # the highest probability tokens are routed first. - permutation = jnp.argsort(-expert_gate[..., 0], axis=-1) - # Shape: [num_groups, tokens_per_group, num_selected_experts] - expert_index = _take_along_axis(expert_index, jnp.expand_dims(permutation, axis=-1), axis=-2) - - # Make num_selected_experts the leading axis to ensure that top-1 choices - # have priority over top-2 choices, which have priority over top-3 choices, - # etc. - expert_index = jnp.swapaxes(expert_index, 1, 2) - # Shape: [num_groups, num_selected_experts * tokens_per_group] - expert_index = expert_index.reshape(num_groups, -1) - - # Create mask out of indices. - # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. - expert_mask = jax.nn.one_hot(expert_index, num_experts, dtype=jnp.int32) - - # Experts have a fixed capacity that we cannot exceed. A token's priority - # within the expert's buffer is given by the masked, cumulative capacity of - # its target expert. - # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. - token_priority = jnp.cumsum(expert_mask, axis=1) * expert_mask - 1.0 - # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts]. - token_priority = token_priority.reshape((num_groups, self.num_selected_experts, -1, num_experts)) - # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. - token_priority = jnp.swapaxes(token_priority, 1, 2) - # For each token, across all selected experts, select the only non-negative - # (unmasked) priority. Now, for group G routing to expert E, token T has - # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E - # is its targeted expert. - # Shape: [num_groups, tokens_per_group, num_experts]. - token_priority = jnp.max(token_priority, axis=2) - - if self.batch_prioritized_routing: - # Place token priorities in original ordering of tokens. - inv_permutation = jnp.argsort(permutation, axis=-1) - token_priority = _take_along_axis(token_priority, jnp.expand_dims(inv_permutation, axis=-1), axis=-2) - - # Token T can only be routed to expert E if its priority is positive and - # less than the expert capacity. One-hot matrix will ignore indices outside - # the range [0, expert_capacity). - # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]. - dispatch_mask = jax.nn.one_hot(token_priority, expert_capacity, dtype=jnp.bool_) - - # The combine array will be used for combining expert outputs, scaled by the - # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, - # expert_capacity]. - combine_array = jnp.einsum( - "...te,...tec->...tec", router_probs, dispatch_mask, precision=jax.lax.Precision.DEFAULT - ) - - # Return to default dtype now that router computation is complete. - combine_array = jax.lax.convert_element_type(combine_array, self.dtype) - - return RouterMask(dispatch_mask, combine_array, auxiliary_loss) - - -def _load_balancing_loss(router_probs: Array, expert_indices: Array) -> float: - """Computes auxiliary load balancing loss as in Switch Transformer. - - See Switch Transformer (https://arxiv.org/abs/2101.03961). This function implements the loss function presented in - equations (4) - (6). It aims to penalize those cases where the routing between experts is unbalanced. - - Args: - router_probs: Probability assigned to each expert per token. Shape: - [num_groups, tokens_per_group, num_experts]. - expert_indices: [num_groups, tokens_per_group, num_selected_experts] - indices identifying the top num_selected_experts for a given token. - - Returns: - The auxiliary loss. - """ - num_experts = router_probs.shape[-1] - - # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. - expert_mask = jax.nn.one_hot(expert_indices, num_experts, dtype=jnp.int32) - # For a given token, determine if it was routed to a given expert. - # Shape: [num_groups, tokens_per_group, num_experts] - expert_mask = jnp.max(expert_mask, axis=-2) - - tokens_per_group_and_expert = jnp.mean(expert_mask, dtype=jnp.float32, axis=-2) - router_prob_per_group_and_expert = jnp.mean(router_probs, dtype=jnp.float32, axis=-2) - return ( - jnp.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert, dtype=jnp.float32) * num_experts**2 - ) - - -def _router_z_loss(router_logits: Array) -> float: - """Compute router z-loss. - - The router z-loss was introduced in Designing Effective Sparse Expert Models (https://arxiv.org/abs/2202.08906). - It encourages router logits to remain small in an effort to improve stability. - - Args: - router_logits: [num_groups, tokens_per_group, num_experts] router - logits. - - Returns: - Scalar router z-loss. - """ - num_groups, tokens_per_group, _ = router_logits.shape - log_z = jax.nn.logsumexp(router_logits, axis=-1) - z_loss = log_z**2 - return jnp.sum(z_loss, dtype=jnp.float32) / (num_groups * tokens_per_group) - - -# num_tokens = 5 -# num_experts = 2 -# num_selected_experts = 1 -# rng = jax.random.PRNGKey(0) - -# router_probs = jax.random.uniform(rng, (num_tokens, num_experts), minval=0, maxval=1) -# expert_indices = jax.random.randint(rng, (num_tokens, num_selected_experts), minval=0, maxval=2) - -# loss = _load_balancing_loss(router_probs, expert_indices) diff --git a/src/transformers/models/switch_transformers/tokenization_switch_transformers.py b/src/transformers/models/switch_transformers/tokenization_switch_transformers.py deleted file mode 100644 index 09d9a3c8d57e8..0000000000000 --- a/src/transformers/models/switch_transformers/tokenization_switch_transformers.py +++ /dev/null @@ -1,363 +0,0 @@ -# coding=utf-8 -# Copyright 2022 SwitchTransformers Authors and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Tokenization class for model SwitchTransformers.""" - - -import os -import re -import warnings -from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple - -import sentencepiece as spm - -from ...tokenization_utils import PreTrainedTokenizer -from ...utils import logging - - -logger = logging.get_logger(__name__) - -VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} - -# PRETRAINED_VOCAB_FILES_MAP = { -# "vocab_file": { -# "google/switch-base-8": "https://huggingface.co/google/switch-base-8/resolve/main/spiece.model", -# "google/switch-base-16": "https://huggingface.co/switch-base-16/resolve/main/spiece.model", -# "google/switch-base-32": "https://huggingface.co/google/switch-base-32/resolve/main/spiece.model", -# "google/switch-base-64": "https://huggingface.co/google/switch-base-64/resolve/main/spiece.model", -# "google/switch-base-128": "https://huggingface.co/google/switch-base-128/resolve/main/spiece.model", -# "google/switch-base-256": "https://huggingface.co/google/switch-base-256/resolve/main/spiece.model", -# "google/switch-large-128": "https://huggingface.co/google/switch-large-128/resolve/main/spiece.model", -# "google/switch-xxl-128": "https://huggingface.co/google/switch-xxl-128/resolve/main/spiece.model", -# "google/switch-c-2048": "https://huggingface.co/google/switch-c-2048/resolve/main/spiece.model", -# } -# } - - -# # TODO(PVP) - this should be removed in Transformers v5 -# PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { -# "google/switch-base-8": 512, -# "google/switch-base-16": 512, -# "google/switch-base-32": 512, -# "google/switch-base-64": 512, -# "google/switch-base-128": 512, -# "google/switch-base-256": 512, -# "google/switch-large-128": 512, -# "google/switch-xxl-128": 512, -# "google/switch-c-2048": 512, -# } - -PRETRAINED_VOCAB_FILES_MAP = { - "vocab_file": { - "HFLAY/switch_base_8": "https://huggingface.co/HFLAY/switch_base_8/blob/main/spiece.model", - } -} - - -# TODO(PVP) - this should be removed in Transformers v5 -PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "HFLAY/switch_base_8": 512, -} - - -class SwitchTransformersTokenizer(PreTrainedTokenizer): - """ - Construct a SwitchTransformers tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece). - - This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to - this superclass for more information regarding those methods. - - Args: - vocab_file (`str`): - [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that - contains the vocabulary necessary to instantiate a tokenizer. - eos_token (`str`, *optional*, defaults to `""`): - The end of sequence token. - - - - When building a sequence using special tokens, this is not the token that is used for the end of sequence. - The token used is the `sep_token`. - - - - unk_token (`str`, *optional*, defaults to `""`): - The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this - token instead. - pad_token (`str`, *optional*, defaults to `""`): - The token used for padding, for example when batching sequences of different lengths. - extra_ids (`int`, *optional*, defaults to 100): - Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are - accessible as "" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are - indexed from the end of the vocabulary up to beginning ("" is the last token in the vocabulary - like in SwitchTransformers preprocessing see - [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/switch_transformers/data/preprocessors.py#L2117)). - additional_special_tokens (`List[str]`, *optional*): - Additional special tokens used by the tokenizer. - sp_model_kwargs (`dict`, *optional*): - Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for - SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, - to set: - - - `enable_sampling`: Enable subword regularization. - - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. - - - `nbest_size = {0,1}`: No sampling is performed. - - `nbest_size > 1`: samples from the nbest_size results. - - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) - using forward-filtering-and-backward-sampling algorithm. - - - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for - BPE-dropout. - - Attributes: - sp_model (`SentencePieceProcessor`): - The *SentencePiece* processor that is used for every conversion (string, tokens and IDs). - """ - - vocab_files_names = VOCAB_FILES_NAMES - pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP - max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - model_input_names = ["input_ids", "attention_mask"] - - def __init__( - self, - vocab_file, - eos_token="", - unk_token="", - pad_token="", - extra_ids=100, - additional_special_tokens=None, - sp_model_kwargs: Optional[Dict[str, Any]] = None, - **kwargs - ) -> None: - # Add extra_ids to the special token list - if extra_ids > 0 and additional_special_tokens is None: - additional_special_tokens = [f"" for i in range(extra_ids)] - elif extra_ids > 0 and additional_special_tokens is not None: - # Check that we have the right number of extra_id special tokens - extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens))) - if extra_tokens != extra_ids: - raise ValueError( - f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are" - " provided to SwitchTransformersTokenizer. In this case the additional_special_tokens must" - " include the extra_ids tokens" - ) - - self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs - - super().__init__( - eos_token=eos_token, - unk_token=unk_token, - pad_token=pad_token, - extra_ids=extra_ids, - additional_special_tokens=additional_special_tokens, - sp_model_kwargs=self.sp_model_kwargs, - **kwargs, - ) - - self.vocab_file = vocab_file - self._extra_ids = extra_ids - - self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) - self.sp_model.Load(vocab_file) - - @staticmethod - def _eventually_correct_switch_transformers_max_length( - pretrained_model_name_or_path, max_model_length, init_max_model_length - ): - if pretrained_model_name_or_path in SwitchTransformersTokenizer.max_model_input_sizes: - deprecated_max_model_length = SwitchTransformersTokenizer.max_model_input_sizes[ - pretrained_model_name_or_path - ] - if init_max_model_length is not None and init_max_model_length != max_model_length: - return init_max_model_length - elif init_max_model_length is None: - warnings.warn( - "This tokenizer was incorrectly instantiated with a model max length of" - f" {deprecated_max_model_length} which will be corrected in Transformers v5.\nFor now, this" - " behavior is kept to avoid breaking backwards compatibility when padding/encoding with" - " `truncation is True`.\n- Be aware that you SHOULD NOT rely on" - f" {pretrained_model_name_or_path} automatically truncating your input to" - f" {deprecated_max_model_length} when padding/encoding.\n- If you want to encode/pad to sequences" - f" longer than {deprecated_max_model_length} you can either instantiate this tokenizer with" - " `model_max_length` or pass `max_length` when encoding/padding.\n- To avoid this warning, please" - " instantiate this tokenizer with `model_max_length` set to your preferred value.", - FutureWarning, - ) - - return max_model_length - - @property - def vocab_size(self): - return self.sp_model.get_piece_size() + self._extra_ids - - def get_vocab(self): - vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} - vocab.update(self.added_tokens_encoder) - return vocab - - def get_special_tokens_mask( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False - ) -> List[int]: - """ - Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding - special tokens using the tokenizer `prepare_for_model` method. - - Args: - token_ids_0 (`List[int]`): - List of IDs. - token_ids_1 (`List[int]`, *optional*): - Optional second list of IDs for sequence pairs. - already_has_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not the token list is already formatted with special tokens for the model. - - Returns: - `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. - """ - if already_has_special_tokens: - return super().get_special_tokens_mask( - token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True - ) - - # normal case: some special tokens - if token_ids_1 is None: - return ([0] * len(token_ids_0)) + [1] - return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] - - def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]: - """Do not add eos again if user already added it.""" - if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id: - warnings.warn( - f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated" - " eos tokens being added." - ) - return token_ids - else: - return token_ids + [self.eos_token_id] - - def create_token_type_ids_from_sequences( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: - """ - Create a mask from the two sequences passed to be used in a sequence-pair classification task. - SwitchTransformers does not make use of token type ids, therefore a list of zeros is returned. - - Args: - token_ids_0 (`List[int]`): - List of IDs. - token_ids_1 (`List[int]`, *optional*): - Optional second list of IDs for sequence pairs. - - Returns: - `List[int]`: List of zeros. - """ - eos = [self.eos_token_id] - - if token_ids_1 is None: - return len(token_ids_0 + eos) * [0] - return len(token_ids_0 + eos + token_ids_1 + eos) * [0] - - def build_inputs_with_special_tokens( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: - """ - Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and - adding special tokens. A sequence has the following format: - - - single sequence: `X ` - - pair of sequences: `A B ` - - Args: - token_ids_0 (`List[int]`): - List of IDs to which the special tokens will be added. - token_ids_1 (`List[int]`, *optional*): - Optional second list of IDs for sequence pairs. - - Returns: - `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. - """ - token_ids_0 = self._add_eos_if_not_present(token_ids_0) - if token_ids_1 is None: - return token_ids_0 - else: - token_ids_1 = self._add_eos_if_not_present(token_ids_1) - return token_ids_0 + token_ids_1 - - def __getstate__(self): - state = self.__dict__.copy() - state["sp_model"] = None - return state - - def __setstate__(self, d): - self.__dict__ = d - - # for backward compatibility - if not hasattr(self, "sp_model_kwargs"): - self.sp_model_kwargs = {} - - self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) - self.sp_model.Load(self.vocab_file) - - def _tokenize(self, text: str) -> List[str]: - """Take as input a string and return a list of strings (tokens) for words/sub-words""" - return self.sp_model.encode(text, out_type=str) - - def _convert_token_to_id(self, token): - """Converts a token (str) in an id using the vocab.""" - if token.startswith("", token) - num = int(match.group(1)) - return self.vocab_size - num - 1 - return self.sp_model.piece_to_id(token) - - def _convert_id_to_token(self, index): - """Converts an index (integer) in a token (str) using the vocab.""" - if index < self.sp_model.get_piece_size(): - token = self.sp_model.IdToPiece(index) - else: - token = f"" - return token - - def convert_tokens_to_string(self, tokens): - """Converts a sequence of tokens (string) in a single string.""" - current_sub_tokens = [] - out_string = "" - for token in tokens: - # make sure that special tokens are not decoded using sentencepiece model - if token in self.all_special_tokens: - out_string += self.sp_model.decode_pieces(current_sub_tokens) + token + " " - current_sub_tokens = [] - else: - current_sub_tokens.append(token) - out_string += self.sp_model.decode_pieces(current_sub_tokens) - return out_string.strip() - - def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: - if not os.path.isdir(save_directory): - logger.error(f"Vocabulary path ({save_directory}) should be a directory") - return - out_vocab_file = os.path.join( - save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] - ) - - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): - copyfile(self.vocab_file, out_vocab_file) - elif not os.path.isfile(self.vocab_file): - with open(out_vocab_file, "wb") as fi: - content_spiece_model = self.sp_model.serialized_model_proto() - fi.write(content_spiece_model) - - return (out_vocab_file,) diff --git a/src/transformers/models/switch_transformers/tokenization_switch_transformers_fast.py b/src/transformers/models/switch_transformers/tokenization_switch_transformers_fast.py deleted file mode 100644 index c1a9ca9d44070..0000000000000 --- a/src/transformers/models/switch_transformers/tokenization_switch_transformers_fast.py +++ /dev/null @@ -1,260 +0,0 @@ -# coding=utf-8 -# Copyright 2022 SwitchTransformers Authors and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Tokenization class for model SwitchTransformers.""" - - -import os -import warnings -from shutil import copyfile -from typing import List, Optional, Tuple - -from ...tokenization_utils_fast import PreTrainedTokenizerFast -from ...utils import is_sentencepiece_available, logging - - -if is_sentencepiece_available(): - from .tokenization_switch_transformers import SwitchTransformersTokenizer -else: - SwitchTransformersTokenizer = None - - -logger = logging.get_logger(__name__) - -VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"} - -# PRETRAINED_VOCAB_FILES_MAP = { -# "vocab_file": { -# "ybelkada/switch_transformers-base": ( -# "https://huggingface.co/ybelkada/switch_transformers-base/resolve/main/spiece.model" -# ), -# "switch_transformers-base": "https://huggingface.co/switch_transformers-base/resolve/main/spiece.model", -# "switch_transformers-large": "https://huggingface.co/switch_transformers-large/resolve/main/spiece.model", -# "switch_transformers-3b": "https://huggingface.co/switch_transformers-3b/resolve/main/spiece.model", -# "switch_transformers-11b": "https://huggingface.co/switch_transformers-11b/resolve/main/spiece.model", -# }, -# "tokenizer_file": { -# "ybelkada/switch_transformers-base": ( -# "https://huggingface.co/ybelkada/switch_transformers-base/resolve/main/tokenizer.json" -# ), -# "switch_transformers-base": "https://huggingface.co/switch_transformers-base/resolve/main/tokenizer.json", -# "switch_transformers-large": "https://huggingface.co/switch_transformers-large/resolve/main/tokenizer.json", -# "switch_transformers-3b": "https://huggingface.co/switch_transformers-3b/resolve/main/tokenizer.json", -# "switch_transformers-11b": "https://huggingface.co/switch_transformers-11b/resolve/main/tokenizer.json", -# }, -# } - - -# # TODO(PVP) - this should be removed in Transformers v5 -# PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { -# "ybelkada/switch_transformers-base": 512, -# "switch_transformers-base": 512, -# "switch_transformers-large": 512, -# "switch_transformers-3b": 512, -# "switch_transformers-11b": 512, -# } - -PRETRAINED_VOCAB_FILES_MAP = { - "vocab_file": { - "HFLAY/switch_base_8": "https://huggingface.co/HFLAY/switch_base_8/blob/main/spiece.model", - }, - "tokenizer_file": { - "HFLAY/switch_base_8": "https://huggingface.co/HFLAY/switch_base_8/blob/main/tokenizer.json", - }, -} - - -# TODO(PVP) - this should be removed in Transformers v5 -PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "HFLAY/switch_base_8": 512, -} - - -class SwitchTransformersTokenizerFast(PreTrainedTokenizerFast): - """ - Construct a "fast" SwitchTransformers tokenizer (backed by HuggingFace's *tokenizers* library). Based on - [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models). - - This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should - refer to this superclass for more information regarding those methods. - - Args: - vocab_file (`str`): - [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that - contains the vocabulary necessary to instantiate a tokenizer. - eos_token (`str`, *optional*, defaults to `""`): - The end of sequence token. - - - - When building a sequence using special tokens, this is not the token that is used for the end of sequence. - The token used is the `sep_token`. - - - - unk_token (`str`, *optional*, defaults to `""`): - The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this - token instead. - pad_token (`str`, *optional*, defaults to `""`): - The token used for padding, for example when batching sequences of different lengths. - extra_ids (`int`, *optional*, defaults to 100): - Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are - accessible as "" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are - indexed from the end of the vocabulary up to beginning ("" is the last token in the vocabulary - like in SwitchTransformers preprocessing see - [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/switch_transformers/data/preprocessors.py#L2117)). - additional_special_tokens (`List[str]`, *optional*): - Additional special tokens used by the tokenizer. - """ - - vocab_files_names = VOCAB_FILES_NAMES - pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP - max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - model_input_names = ["input_ids", "attention_mask"] - slow_tokenizer_class = SwitchTransformersTokenizer - - prefix_tokens: List[int] = [] - - def __init__( - self, - vocab_file=None, - tokenizer_file=None, - eos_token="", - unk_token="", - pad_token="", - extra_ids=100, - additional_special_tokens=None, - **kwargs - ): - # Add extra_ids to the special token list - if extra_ids > 0 and additional_special_tokens is None: - additional_special_tokens = [f"" for i in range(extra_ids)] - elif extra_ids > 0 and additional_special_tokens is not None: - # Check that we have the right number of extra special tokens - extra_tokens = len(set(filter(lambda x: bool("extra_id_" in str(x)), additional_special_tokens))) - if extra_tokens != extra_ids: - raise ValueError( - f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are" - " provided to SwitchTransformersTokenizer. In this case the additional_special_tokens must" - " include the extra_ids tokens" - ) - - super().__init__( - vocab_file, - tokenizer_file=tokenizer_file, - eos_token=eos_token, - unk_token=unk_token, - pad_token=pad_token, - extra_ids=extra_ids, - additional_special_tokens=additional_special_tokens, - **kwargs, - ) - - self.vocab_file = vocab_file - self.can_save_slow_tokenizer = False if not self.vocab_file else True - self._extra_ids = extra_ids - - @staticmethod - def _eventually_correct_switch_transformers_max_length( - pretrained_model_name_or_path, max_model_length, init_max_model_length - ): - if pretrained_model_name_or_path in SwitchTransformersTokenizerFast.max_model_input_sizes: - deprecated_max_model_length = SwitchTransformersTokenizerFast.max_model_input_sizes[ - pretrained_model_name_or_path - ] - if init_max_model_length is not None and init_max_model_length != max_model_length: - return init_max_model_length - elif init_max_model_length is None: - warnings.warn( - "This tokenizer was incorrectly instantiated with a model max length of" - f" {deprecated_max_model_length} which will be corrected in Transformers v5.\nFor now, this" - " behavior is kept to avoid breaking backwards compatibility when padding/encoding with" - " `truncation is True`.\n- Be aware that you SHOULD NOT rely on" - f" {pretrained_model_name_or_path} automatically truncating your input to" - f" {deprecated_max_model_length} when padding/encoding.\n- If you want to encode/pad to sequences" - f" longer than {deprecated_max_model_length} you can either instantiate this tokenizer with" - " `model_max_length` or pass `max_length` when encoding/padding.\n- To avoid this warning, please" - " instantiate this tokenizer with `model_max_length` set to your preferred value.", - FutureWarning, - ) - - return max_model_length - - def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: - if not self.can_save_slow_tokenizer: - raise ValueError( - "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " - "tokenizer." - ) - - if not os.path.isdir(save_directory): - logger.error(f"Vocabulary path ({save_directory}) should be a directory") - return - out_vocab_file = os.path.join( - save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] - ) - - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): - copyfile(self.vocab_file, out_vocab_file) - logger.info(f"Copy vocab file to {out_vocab_file}") - - return (out_vocab_file,) - - def build_inputs_with_special_tokens( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: - """ - Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and - adding special tokens. A sequence has the following format: - - - single sequence: `X ` - - pair of sequences: `A B ` - - Args: - token_ids_0 (`List[int]`): - List of IDs to which the special tokens will be added. - token_ids_1 (`List[int]`, *optional*): - Optional second list of IDs for sequence pairs. - - Returns: - `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. - """ - token_ids_0 = token_ids_0 + [self.eos_token_id] - if token_ids_1 is None: - return self.prefix_tokens + token_ids_0 - else: - token_ids_1 = token_ids_1 + [self.eos_token_id] - return self.prefix_tokens + token_ids_0 + token_ids_1 - - def create_token_type_ids_from_sequences( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: - """ - Create a mask from the two sequences passed to be used in a sequence-pair classification task. - SwitchTransformers does not make use of token type ids, therefore a list of zeros is returned. - - Args: - token_ids_0 (`List[int]`): - List of IDs. - token_ids_1 (`List[int]`, *optional*): - Optional second list of IDs for sequence pairs. - - Returns: - `List[int]`: List of zeros. - """ - eos = [self.eos_token_id] - - if token_ids_1 is None: - return len(token_ids_0 + eos) * [0] - return len(token_ids_0 + eos + token_ids_1 + eos) * [0] diff --git a/tests/models/switch_transformers/test_tokenization_switch_transformers.py b/tests/models/switch_transformers/test_tokenization_switch_transformers.py deleted file mode 100644 index 6ca18416a7d35..0000000000000 --- a/tests/models/switch_transformers/test_tokenization_switch_transformers.py +++ /dev/null @@ -1,387 +0,0 @@ -# coding=utf-8 -# Copyright 2022 Google SwitchTransformers Authors and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -import os -import tempfile -import unittest - -from transformers import ( - SPIECE_UNDERLINE, - AddedToken, - BatchEncoding, - SwitchTransformersTokenizer, - SwitchTransformersTokenizerFast, -) -from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow -from transformers.utils import cached_property, is_tf_available, is_torch_available - -from ...test_tokenization_common import TokenizerTesterMixin - - -SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model") - -if is_torch_available(): - FRAMEWORK = "pt" -elif is_tf_available(): - FRAMEWORK = "tf" -else: - FRAMEWORK = "jax" - - -@require_sentencepiece -@require_tokenizers -class SwitchTransformersTokenizationTest(TokenizerTesterMixin, unittest.TestCase): - - tokenizer_class = SwitchTransformersTokenizer - rust_tokenizer_class = SwitchTransformersTokenizerFast - test_rust_tokenizer = True - test_sentencepiece = True - - def setUp(self): - super().setUp() - - # We have a SentencePiece fixture for testing - tokenizer = SwitchTransformersTokenizer(SAMPLE_VOCAB) - tokenizer.save_pretrained(self.tmpdirname) - - def test_convert_token_and_id(self): - """Test ``_convert_token_to_id`` and ``_convert_id_to_token``.""" - token = "" - token_id = 1 - - self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id) - self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token) - - def test_get_vocab(self): - vocab_keys = list(self.get_tokenizer().get_vocab().keys()) - - self.assertEqual(vocab_keys[0], "") - self.assertEqual(vocab_keys[1], "") - self.assertEqual(vocab_keys[-1], "") - self.assertEqual(len(vocab_keys), 1_101) - - def test_vocab_size(self): - self.assertEqual(self.get_tokenizer().vocab_size, 1_100) - - def test_full_tokenizer(self): - tokenizer = SwitchTransformersTokenizer(SAMPLE_VOCAB) - - tokens = tokenizer.tokenize("This is a test") - self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"]) - - self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382]) - - tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.") - self.assertListEqual( - tokens, - [ - SPIECE_UNDERLINE + "I", - SPIECE_UNDERLINE + "was", - SPIECE_UNDERLINE + "b", - "or", - "n", - SPIECE_UNDERLINE + "in", - SPIECE_UNDERLINE + "", - "9", - "2", - "0", - "0", - "0", - ",", - SPIECE_UNDERLINE + "and", - SPIECE_UNDERLINE + "this", - SPIECE_UNDERLINE + "is", - SPIECE_UNDERLINE + "f", - "al", - "s", - "é", - ".", - ], - ) - ids = tokenizer.convert_tokens_to_ids(tokens) - self.assertListEqual(ids, [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4]) - - back_tokens = tokenizer.convert_ids_to_tokens(ids) - self.assertListEqual( - back_tokens, - [ - SPIECE_UNDERLINE + "I", - SPIECE_UNDERLINE + "was", - SPIECE_UNDERLINE + "b", - "or", - "n", - SPIECE_UNDERLINE + "in", - SPIECE_UNDERLINE + "", - "", - "2", - "0", - "0", - "0", - ",", - SPIECE_UNDERLINE + "and", - SPIECE_UNDERLINE + "this", - SPIECE_UNDERLINE + "is", - SPIECE_UNDERLINE + "f", - "al", - "s", - "", - ".", - ], - ) - - @cached_property - def switch_transformers_base_tokenizer(self): - return SwitchTransformersTokenizer.from_pretrained("HFLAY/switch_base_8") - - @cached_property - def switch_transformers_base_tokenizer_fast(self): - return SwitchTransformersTokenizerFast.from_pretrained("HFLAY/switch_base_8") - - def get_tokenizer(self, **kwargs) -> SwitchTransformersTokenizer: - return self.tokenizer_class.from_pretrained(self.tmpdirname, pad_token=None, **kwargs) - - def get_rust_tokenizer(self, **kwargs) -> SwitchTransformersTokenizerFast: - return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, pad_token=None, **kwargs) - - def test_rust_and_python_full_tokenizers(self): - if not self.test_rust_tokenizer: - return - - tokenizer = self.get_tokenizer() - rust_tokenizer = self.get_rust_tokenizer() - - sequence = "I was born in 92000, and this is falsé." - - tokens = tokenizer.tokenize(sequence) - rust_tokens = rust_tokenizer.tokenize(sequence) - self.assertListEqual(tokens, rust_tokens) - - ids = tokenizer.encode(sequence, add_special_tokens=False) - rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False) - self.assertListEqual(ids, rust_ids) - - rust_tokenizer = self.get_rust_tokenizer() - ids = tokenizer.encode(sequence) - rust_ids = rust_tokenizer.encode(sequence) - self.assertListEqual(ids, rust_ids) - - def test_eos_treatment(self): - tokenizer = self.switch_transformers_base_tokenizer - batch_with_eos_added = tokenizer(["hi", "I went to the gym", ""]) - batch_without_eos_added = tokenizer(["hi", "I went to the gym", ""]) - self.assertListEqual(batch_with_eos_added["input_ids"], batch_without_eos_added["input_ids"]) - - def test_prepare_batch(self): - tokenizer = self.switch_transformers_base_tokenizer - src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] - expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id] - batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK) - self.assertIsInstance(batch, BatchEncoding) - - if FRAMEWORK != "jax": - result = list(batch.input_ids.numpy()[0]) - else: - result = list(batch.input_ids.tolist()[0]) - - self.assertListEqual(expected_src_tokens, result) - - self.assertEqual((2, 9), batch.input_ids.shape) - self.assertEqual((2, 9), batch.attention_mask.shape) - - def test_empty_target_text(self): - tokenizer = self.switch_transformers_base_tokenizer - src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."] - batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK) - # check if input_ids are returned and no decoder_input_ids - self.assertIn("input_ids", batch) - self.assertIn("attention_mask", batch) - self.assertNotIn("decoder_input_ids", batch) - self.assertNotIn("decoder_attention_mask", batch) - - def test_max_length(self): - tokenizer = self.switch_transformers_base_tokenizer - tgt_text = [ - "Summary of the text.", - "Another summary.", - ] - targets = tokenizer( - text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK - ) - self.assertEqual(32, targets["input_ids"].shape[1]) - - def test_outputs_not_longer_than_maxlen(self): - tokenizer = self.switch_transformers_base_tokenizer - - batch = tokenizer( - ["I am a small frog" * 1000, "I am a small frog"], padding=True, truncation=True, return_tensors=FRAMEWORK - ) - self.assertIsInstance(batch, BatchEncoding) - # Since SwitchTransformers does NOT have a max input length, - # this test should be changed to the following in Transformers v5: - # self.assertEqual(batch.input_ids.shape, (2, 8001)) - self.assertEqual(batch.input_ids.shape, (2, 512)) - - def test_eos_in_input(self): - tokenizer = self.switch_transformers_base_tokenizer - src_text = ["A long paragraph for summarization. "] - tgt_text = ["Summary of the text. "] - expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, 1] - expected_tgt_tokens = [20698, 13, 8, 1499, 5, 1] - - batch = tokenizer(src_text, text_target=tgt_text) - - self.assertEqual(expected_src_tokens, batch["input_ids"][0]) - self.assertEqual(expected_tgt_tokens, batch["labels"][0]) - - def test_token_type_ids(self): - src_text_1 = ["A first paragraph for summarization."] - src_text_2 = ["A second paragraph for summarization."] - - fast_token_type_ids = self.switch_transformers_base_tokenizer_fast( - src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True - ).token_type_ids - slow_token_type_ids = self.switch_transformers_base_tokenizer( - src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True - ).token_type_ids - - self.assertEqual(slow_token_type_ids, fast_token_type_ids) - self.assertEqual(len(slow_token_type_ids[0]), 18) - - def test_fast_and_slow_same_result(self): - src_text = " Today is nice day " - tgt_ids = [0, 1960, 19, 2, 1245, 239, 1] - tgt_text = " Today is nice day" - - fast_ids = self.switch_transformers_base_tokenizer_fast(src_text, add_special_tokens=False).input_ids - slow_ids = self.switch_transformers_base_tokenizer(src_text, add_special_tokens=False).input_ids - self.assertEqual(tgt_ids, fast_ids) - self.assertEqual(tgt_ids, slow_ids) - - fast_text = self.switch_transformers_base_tokenizer_fast.decode(fast_ids) - slow_text = self.switch_transformers_base_tokenizer.decode(fast_ids) - self.assertEqual(tgt_text, fast_text) - self.assertEqual(tgt_text, slow_text) - - def test_special_tokens_initialization(self): - for tokenizer, pretrained_name, kwargs in self.tokenizers_list: - with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): - - added_tokens = [f"" for i in range(100)] + [AddedToken("", lstrip=True)] - - tokenizer_r = self.rust_tokenizer_class.from_pretrained( - pretrained_name, additional_special_tokens=added_tokens, **kwargs - ) - tokenizer_cr = self.rust_tokenizer_class.from_pretrained( - pretrained_name, additional_special_tokens=added_tokens, **kwargs, from_slow=True - ) - tokenizer_p = self.tokenizer_class.from_pretrained( - pretrained_name, additional_special_tokens=added_tokens, **kwargs - ) - - p_output = tokenizer_p.encode("Hey this is a token") - r_output = tokenizer_r.encode("Hey this is a token") - cr_output = tokenizer_cr.encode("Hey this is a token") - - special_token_id = tokenizer_r.encode("", add_special_tokens=False)[0] - - self.assertEqual(p_output, r_output) - self.assertEqual(cr_output, r_output) - self.assertTrue(special_token_id in p_output) - self.assertTrue(special_token_id in r_output) - self.assertTrue(special_token_id in cr_output) - - def test_special_tokens_initialization_with_non_empty_additional_special_tokens(self): - tokenizer_list = [] - if self.test_slow_tokenizer: - tokenizer_list.append((self.tokenizer_class, self.get_tokenizer())) - - if self.test_rust_tokenizer: - tokenizer_list.append((self.rust_tokenizer_class, self.get_rust_tokenizer())) - - for tokenizer_class, tokenizer_utils in tokenizer_list: - - with tempfile.TemporaryDirectory() as tmp_dir: - tokenizer_utils.save_pretrained(tmp_dir) - - with open(os.path.join(tmp_dir, "special_tokens_map.json"), encoding="utf-8") as json_file: - special_tokens_map = json.load(json_file) - - with open(os.path.join(tmp_dir, "tokenizer_config.json"), encoding="utf-8") as json_file: - tokenizer_config = json.load(json_file) - - added_tokens_extra_ids = [f"" for i in range(100)] - - special_tokens_map["additional_special_tokens"] = added_tokens_extra_ids + [ - "an_additional_special_token" - ] - tokenizer_config["additional_special_tokens"] = added_tokens_extra_ids + [ - "an_additional_special_token" - ] - - with open(os.path.join(tmp_dir, "special_tokens_map.json"), "w", encoding="utf-8") as outfile: - json.dump(special_tokens_map, outfile) - with open(os.path.join(tmp_dir, "tokenizer_config.json"), "w", encoding="utf-8") as outfile: - json.dump(tokenizer_config, outfile) - - # the following checks allow us to verify that our test works as expected, i.e. that the tokenizer takes - # into account the new value of additional_special_tokens given in the "tokenizer_config.json" and - # "special_tokens_map.json" files - tokenizer_without_change_in_init = tokenizer_class.from_pretrained( - tmp_dir, - ) - self.assertIn( - "an_additional_special_token", tokenizer_without_change_in_init.additional_special_tokens - ) - # self.assertIn("an_additional_special_token",tokenizer_without_change_in_init.get_vocab()) # BySwitchTransformersTokenization no vocab - self.assertEqual( - ["an_additional_special_token"], - tokenizer_without_change_in_init.convert_ids_to_tokens( - tokenizer_without_change_in_init.convert_tokens_to_ids(["an_additional_special_token"]) - ), - ) - - # Now we test that we can change the value of additional_special_tokens in the from_pretrained - new_added_tokens = added_tokens_extra_ids + [AddedToken("a_new_additional_special_token", lstrip=True)] - tokenizer = tokenizer_class.from_pretrained( - tmp_dir, - additional_special_tokens=new_added_tokens, - ) - - self.assertIn("a_new_additional_special_token", tokenizer.additional_special_tokens) - self.assertEqual( - ["a_new_additional_special_token"], - tokenizer.convert_ids_to_tokens( - tokenizer.convert_tokens_to_ids(["a_new_additional_special_token"]) - ), - ) - - # overwritten from `test_tokenization_common` since SwitchTransformers has no max length - def test_pretrained_model_lists(self): - # We should have at least one default checkpoint for each tokenizer - # We should specify the max input length as well (used in some part to list the pretrained checkpoints) - self.assertGreaterEqual(len(self.tokenizer_class.pretrained_vocab_files_map), 1) - self.assertGreaterEqual(len(list(self.tokenizer_class.pretrained_vocab_files_map.values())[0]), 1) - - @slow - def test_tokenizer_integration(self): - # fmt: off - expected_encoding = {'input_ids': [[31220, 7, 41, 14034, 801, 38, 3, 102, 63, 17, 127, 524, 18, 7031, 2032, 277, 11, 3, 102, 63, 17, 127, 524, 18, 2026, 17, 10761, 18, 7041, 61, 795, 879, 18, 19681, 4648, 7, 41, 12920, 382, 6, 350, 6383, 4949, 6, 2158, 12920, 382, 9, 6, 3, 4, 11160, 6, 2043, 17153, 279, 49, 17, 6, 3, 4, 434, 9688, 11439, 21, 6869, 10509, 17725, 41, 567, 9138, 61, 11, 6869, 10509, 11946, 41, 18207, 517, 61, 28, 147, 3538, 1220, 7140, 10761, 2250, 16, 910, 1220, 8024, 11, 1659, 1413, 32, 883, 2020, 344, 2215, 226, 6, 12901, 382, 127, 524, 11, 4738, 7, 127, 15390, 5, 1], [272, 24203, 19, 876, 12, 554, 18, 9719, 1659, 2647, 26352, 6497, 7, 45, 73, 9339, 400, 26, 1499, 57, 22801, 10760, 30, 321, 646, 11, 269, 2625, 16, 66, 7500, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [37, 1704, 4216, 3, 20400, 4418, 7, 147, 8, 19743, 1782, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]} # noqa: E501 - # fmt: on - - self.tokenizer_integration_test_util( - expected_encoding=expected_encoding, - model_name="switch_base_8", - revision="5a7ff2d8f5117c194c7e32ec1ccbf04642cca99b", - ) From 7c0fa4b584ee029394b6e759bb659aca65e2a706 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 28 Oct 2022 08:40:43 +0000 Subject: [PATCH 050/102] fix documentation, clean tokenizer --- docs/source/en/index.mdx | 2 +- src/transformers/__init__.py | 4 --- .../models/switch_transformers/__init__.py | 32 ------------------- .../modeling_switch_transformers.py | 29 +++++++++-------- .../utils/dummy_sentencepiece_objects.py | 7 ---- .../utils/dummy_tokenizers_objects.py | 7 ---- 6 files changed, 16 insertions(+), 65 deletions(-) diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 863c3ac281044..2376672dc0f30 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -313,7 +313,7 @@ Flax), PyTorch, and/or TensorFlow. | SqueezeBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | Swin Transformer | ❌ | ❌ | ✅ | ✅ | ❌ | | Swin Transformer V2 | ❌ | ❌ | ✅ | ❌ | ❌ | -| SwitchTransformers | ✅ | ✅ | ✅ | ❌ | ❌ | +| SwitchTransformers | ❌ | ❌ | ✅ | ❌ | ❌ | | T5 | ✅ | ✅ | ✅ | ✅ | ✅ | | TAPAS | ✅ | ❌ | ✅ | ✅ | ❌ | | Time Series Transformer | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 8f438933c2cfa..bcebe961fa5fa 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -544,7 +544,6 @@ _import_structure["models.rembert"].append("RemBertTokenizer") _import_structure["models.speech_to_text"].append("Speech2TextTokenizer") _import_structure["models.t5"].append("T5Tokenizer") - _import_structure["models.switch_transformers"].append("SwitchTransformersTokenizer") _import_structure["models.xglm"].append("XGLMTokenizer") _import_structure["models.xlm_prophetnet"].append("XLMProphetNetTokenizer") _import_structure["models.xlm_roberta"].append("XLMRobertaTokenizer") @@ -613,7 +612,6 @@ _import_structure["models.roformer"].append("RoFormerTokenizerFast") _import_structure["models.splinter"].append("SplinterTokenizerFast") _import_structure["models.squeezebert"].append("SqueezeBertTokenizerFast") - _import_structure["models.switch_transformers"].append("SwitchTransformersTokenizerFast") _import_structure["models.t5"].append("T5TokenizerFast") _import_structure["models.xglm"].append("XGLMTokenizerFast") _import_structure["models.xlm_roberta"].append("XLMRobertaTokenizerFast") @@ -3554,7 +3552,6 @@ from .models.reformer import ReformerTokenizer from .models.rembert import RemBertTokenizer from .models.speech_to_text import Speech2TextTokenizer - from .models.switch_transformers import SwitchTransformersTokenizer from .models.t5 import T5Tokenizer from .models.xglm import XGLMTokenizer from .models.xlm_prophetnet import XLMProphetNetTokenizer @@ -3617,7 +3614,6 @@ from .models.roformer import RoFormerTokenizerFast from .models.splinter import SplinterTokenizerFast from .models.squeezebert import SqueezeBertTokenizerFast - from .models.switch_transformers import SwitchTransformersTokenizerFast from .models.t5 import T5TokenizerFast from .models.xglm import XGLMTokenizerFast from .models.xlm_roberta import XLMRobertaTokenizerFast diff --git a/src/transformers/models/switch_transformers/__init__.py b/src/transformers/models/switch_transformers/__init__.py index e6fc32117cad9..1c5acf82b29ab 100644 --- a/src/transformers/models/switch_transformers/__init__.py +++ b/src/transformers/models/switch_transformers/__init__.py @@ -37,22 +37,6 @@ ] } -try: - if not is_sentencepiece_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["tokenization_switch_transformers"] = ["SwitchTransformersTokenizer"] - -try: - if not is_tokenizers_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["tokenization_switch_transformers_fast"] = ["SwitchTransformersTokenizerFast"] - try: if not is_torch_available(): raise OptionalDependencyNotAvailable() @@ -75,22 +59,6 @@ SwitchTransformersOnnxConfig, ) - try: - if not is_sentencepiece_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .tokenization_switch_transformers import SwitchTransformersTokenizer - - try: - if not is_tokenizers_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .tokenization_switch_transformers_fast import SwitchTransformersTokenizerFast - try: if not is_torch_available(): raise OptionalDependencyNotAvailable() diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index c6667f88d0011..87061dc8eada1 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -70,6 +70,7 @@ # See all SwitchTransformers models at https://huggingface.co/models?filter=switch_transformers ] + # Router loss def router_z_loss_func(router_logits: torch.Tensor) -> float: r""" @@ -163,13 +164,13 @@ class SwitchTransformersTop1Router(nn.Module): Parameters: num_selected_experts (`int`): - Maximum number of experts to which each token is routed. Tokens may be routed to fewer experts if particular - experts are oversubscribed / reach capacity. + Maximum number of experts to which each token is routed. Tokens may be routed to fewer experts if + particular experts are oversubscribed / reach capacity. batch_prioritized_routing (`bool`): Whether or not to use Batch Prioritized Routing (BPR), originally introduced in V-MoE - (https://arxiv.org/abs/2106.05974). With BPR, we prioritize routing those top-k tokens with the highest router - probability, rather than simply using each tokens left-to-right ordering in the batch. This prioritization is - important because the experts have limited capacity. + (https://arxiv.org/abs/2106.05974). With BPR, we prioritize routing those top-k tokens with the highest + router probability, rather than simply using each tokens left-to-right ordering in the batch. This + prioritization is important because the experts have limited capacity. """ def __init__(self, config, **kwargs): @@ -1228,7 +1229,7 @@ def custom_forward(*inputs): Indices of input sequence tokens in the vocabulary. SWITCH_TRANSFORMERS is a model with relative position embeddings so you should be able to pad the inputs on both the right and the left. - Indices can be obtained using [`SwitchTransformersTokenizer`]. See [`PreTrainedTokenizer.encode`] and + Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for detail. [What are input IDs?](../glossary#input-ids) @@ -1245,7 +1246,7 @@ def custom_forward(*inputs): decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): Indices of decoder input sequence tokens in the vocabulary. - Indices can be obtained using [`SwitchTransformersTokenizer`]. See [`PreTrainedTokenizer.encode`] and + Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids) @@ -1326,7 +1327,7 @@ def custom_forward(*inputs): Indices of input sequence tokens in the vocabulary. SWITCH_TRANSFORMERS is a model with relative position embeddings so you should be able to pad the inputs on both the right and the left. - Indices can be obtained using [`SwitchTransformersTokenizer`]. See [`PreTrainedTokenizer.encode`] and + Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for detail. To know more on how to prepare `input_ids` for pretraining take a look a [SWITCH_TRANSFORMERS @@ -1450,9 +1451,9 @@ def forward( Example: ```python - >>> from transformers import SwitchTransformersTokenizer, SwitchTransformersModel + >>> from transformers import T5Tokenizer, SwitchTransformersModel - >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switch_transformers-base") + >>> tokenizer = T5Tokenizer.from_pretrained("ybelkada/switch_transformers-base") >>> model = SwitchTransformersModel.from_pretrained("ybelkada/switch_transformers-base") >>> input_ids = tokenizer( @@ -1634,9 +1635,9 @@ def forward( Examples: ```python - >>> from transformers import SwitchTransformersTokenizer, SwitchTransformersForConditionalGeneration + >>> from transformers import T5Tokenizer, SwitchTransformersForConditionalGeneration - >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switch_transformers-base") + >>> tokenizer = T5Tokenizer.from_pretrained("ybelkada/switch_transformers-base") >>> model = SwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switch_transformers-base") >>> # training @@ -1904,9 +1905,9 @@ def forward( Example: ```python - >>> from transformers import SwitchTransformersTokenizer, SwitchTransformersEncoderModel + >>> from transformers import T5Tokenizer, SwitchTransformersEncoderModel - >>> tokenizer = SwitchTransformersTokenizer.from_pretrained("ybelkada/switch_transformers-base") + >>> tokenizer = T5Tokenizer.from_pretrained("ybelkada/switch_transformers-base") >>> model = SwitchTransformersEncoderModel.from_pretrained("ybelkada/switch_transformers-base") >>> input_ids = tokenizer( ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" diff --git a/src/transformers/utils/dummy_sentencepiece_objects.py b/src/transformers/utils/dummy_sentencepiece_objects.py index c73567d7fdc30..69f0bdcb7b1aa 100644 --- a/src/transformers/utils/dummy_sentencepiece_objects.py +++ b/src/transformers/utils/dummy_sentencepiece_objects.py @@ -157,13 +157,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["sentencepiece"]) -class SwitchTransformersTokenizer(metaclass=DummyObject): - _backends = ["sentencepiece"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["sentencepiece"]) - - class T5Tokenizer(metaclass=DummyObject): _backends = ["sentencepiece"] diff --git a/src/transformers/utils/dummy_tokenizers_objects.py b/src/transformers/utils/dummy_tokenizers_objects.py index 48695eefde1d4..8a24d9bea6b2c 100644 --- a/src/transformers/utils/dummy_tokenizers_objects.py +++ b/src/transformers/utils/dummy_tokenizers_objects.py @@ -360,13 +360,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tokenizers"]) -class SwitchTransformersTokenizerFast(metaclass=DummyObject): - _backends = ["tokenizers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["tokenizers"]) - - class T5TokenizerFast(metaclass=DummyObject): _backends = ["tokenizers"] From 57acea7bd109e1a03ebb66e7babafc21edba898f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 28 Oct 2022 08:44:17 +0000 Subject: [PATCH 051/102] more doc fix, cleanup example_switch --- example_switch.py | 15 --------------- src/transformers/convert_slow_tokenizer.py | 1 - src/transformers/models/auto/tokenization_auto.py | 7 ------- utils/documentation_tests.txt | 1 + 4 files changed, 1 insertion(+), 23 deletions(-) delete mode 100644 example_switch.py diff --git a/example_switch.py b/example_switch.py deleted file mode 100644 index f5a9256c2448c..0000000000000 --- a/example_switch.py +++ /dev/null @@ -1,15 +0,0 @@ -from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration - -tokenizer = AutoTokenizer.from_pretrained("t5-small") -text = "A walks into a bar a orders a with pinch of ." -model = SwitchTransformersForConditionalGeneration.from_pretrained("HFLAY/switch_base_8") - -input_ids = tokenizer(text, return_tensors="pt").input_ids -out = model.generate(input_ids, decoder_start_token_id=0, output_router_logits=True) -print(tokenizer.decode(out[0])) - -# Loss - -input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids -labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids -outputs = model(input_ids=input_ids, labels=labels, output_router_logits=True) \ No newline at end of file diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py index 6520414e3821e..ce52ba3b3beba 100644 --- a/src/transformers/convert_slow_tokenizer.py +++ b/src/transformers/convert_slow_tokenizer.py @@ -1127,7 +1127,6 @@ def converted(self) -> Tokenizer: "RoFormerTokenizer": RoFormerConverter, "SqueezeBertTokenizer": BertConverter, "T5Tokenizer": T5Converter, - "SwitchTransformersTokenizer": T5Converter, "XLMRobertaTokenizer": XLMRobertaConverter, "XLNetTokenizer": XLNetConverter, "SplinterTokenizer": SplinterConverter, diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 13194b116dd4c..e29a5b19ddc07 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -241,13 +241,6 @@ "squeezebert", ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None), ), - ( - "switch_transformers", - ( - "SwitchTransformersTokenizer" if is_sentencepiece_available() else None, - "SwitchTransformersTokenizerFast" if is_tokenizers_available() else None, - ), - ), ( "t5", ( diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index 7f28f750725c4..a50841f4a8751 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -5,6 +5,7 @@ docs/source/en/autoclass_tutorial.mdx docs/source/en/task_summary.mdx docs/source/en/model_doc/markuplm.mdx docs/source/en/model_doc/speech_to_text.mdx +docs/source/en/model_doc/switch_transformers.mdx docs/source/en/model_doc/t5.mdx docs/source/en/model_doc/t5v1.1.mdx docs/source/en/model_doc/byt5.mdx From 2e3d7b10b1fb6e508e49e07e2eb3413ae37958e4 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 28 Oct 2022 11:09:55 +0200 Subject: [PATCH 052/102] fix failing test --- .../test_modeling_switch_transformers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 4badfdcdd2023..a6dcf8f8f3118 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -37,7 +37,7 @@ ) from transformers.models.switch_transformers.modeling_switch_transformers import ( SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, - TokensChooseMaskedRouter, + SwitchTransformersTop1Router, load_balancing_loss_func, router_z_loss_func, ) @@ -927,7 +927,7 @@ def test_equivalency_router_z_loss(self): def test_equivalency_token_chose_masked_router(self): r""" - This test tests the equivalency between the `TokensChooseMaskedRouter` + This test tests the equivalency between the `SwitchTransformersTop1Router` originally implemented from here: TODO: provide link """ hidden_dim = 4 @@ -959,7 +959,7 @@ def test_equivalency_token_chose_masked_router(self): expert_capacity=expert_capacity, batch_prioritized_routing=False, ) - model = TokensChooseMaskedRouter(config) + model = SwitchTransformersTop1Router(config) model.classifier.weight = torch.nn.Parameter( torch.Tensor( From d21f9e059697e143f3c0993d42e4d408794bbe6f Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 28 Oct 2022 11:31:36 +0200 Subject: [PATCH 053/102] fix test --- src/transformers/models/switch_transformers/__init__.py | 2 ++ .../switch_transformers/test_modeling_switch_transformers.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/switch_transformers/__init__.py b/src/transformers/models/switch_transformers/__init__.py index 1c5acf82b29ab..0119a9caa9087 100644 --- a/src/transformers/models/switch_transformers/__init__.py +++ b/src/transformers/models/switch_transformers/__init__.py @@ -49,6 +49,7 @@ "SwitchTransformersForConditionalGeneration", "SwitchTransformersModel", "SwitchTransformersPreTrainedModel", + "SwitchTransformersTop1Router", ] @@ -71,6 +72,7 @@ SwitchTransformersForConditionalGeneration, SwitchTransformersModel, SwitchTransformersPreTrainedModel, + SwitchTransformersTop1Router, ) diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index a6dcf8f8f3118..96810374e98a4 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -34,10 +34,10 @@ SwitchTransformersEncoderModel, SwitchTransformersForConditionalGeneration, SwitchTransformersModel, + SwitchTransformersTop1Router, ) from transformers.models.switch_transformers.modeling_switch_transformers import ( SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, - SwitchTransformersTop1Router, load_balancing_loss_func, router_z_loss_func, ) From de6017284dd5052c01a2522749087c20ddcd1443 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 28 Oct 2022 11:38:22 +0200 Subject: [PATCH 054/102] fix test --- src/transformers/__init__.py | 2 ++ .../switch_transformers/modeling_switch_transformers.py | 4 ++-- src/transformers/utils/dummy_pt_objects.py | 7 +++++++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index bcebe961fa5fa..ffdc691441f2d 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1956,6 +1956,7 @@ "SwitchTransformersForConditionalGeneration", "SwitchTransformersModel", "SwitchTransformersPreTrainedModel", + "SwitchTransformersTop1Router", ] ) _import_structure["models.trajectory_transformer"].extend( @@ -4682,6 +4683,7 @@ SwitchTransformersForConditionalGeneration, SwitchTransformersModel, SwitchTransformersPreTrainedModel, + SwitchTransformersTop1Router, ) from .models.t5 import ( T5_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index ba990ae169c77..7366f094ecc7e 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1736,13 +1736,13 @@ def forward( encoder_outputs.router_probs ) encoder_z_loss = router_z_loss_func(encoder_router_logits) - encoder_aux_loss = load_balancing_loss_func(encoder_router_logits, encoder_expert_indexes) + encoder_aux_loss = abs(load_balancing_loss_func(encoder_router_logits, encoder_expert_indexes)) decoder_router_logits, decoder_expert_indexes = self._unpack_router_logits( decoder_outputs.router_probs ) decoder_z_loss = router_z_loss_func(decoder_router_logits) - decoder_aux_loss = load_balancing_loss_func(decoder_router_logits, decoder_expert_indexes) + decoder_aux_loss = abs(load_balancing_loss_func(decoder_router_logits, decoder_expert_indexes)) loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index e66cd1c34d4ba..57a07d78de4b4 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4886,6 +4886,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class SwitchTransformersTop1Router(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + T5_PRETRAINED_MODEL_ARCHIVE_LIST = None From bae848ae4922156be34d8bea4b62eef2ae3ab9a7 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 28 Oct 2022 11:54:56 +0200 Subject: [PATCH 055/102] fix loss issue --- .../switch_transformers/modeling_switch_transformers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 7366f094ecc7e..18a9c41738cc4 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1736,13 +1736,15 @@ def forward( encoder_outputs.router_probs ) encoder_z_loss = router_z_loss_func(encoder_router_logits) - encoder_aux_loss = abs(load_balancing_loss_func(encoder_router_logits, encoder_expert_indexes)) + encoder_router_probs = nn.Softmax(dim=-1)(encoder_router_logits) + encoder_aux_loss = load_balancing_loss_func(encoder_router_probs, encoder_expert_indexes) decoder_router_logits, decoder_expert_indexes = self._unpack_router_logits( decoder_outputs.router_probs ) decoder_z_loss = router_z_loss_func(decoder_router_logits) - decoder_aux_loss = abs(load_balancing_loss_func(decoder_router_logits, decoder_expert_indexes)) + decoder_router_probs = nn.Softmax(dim=-1)(decoder_router_logits) + decoder_aux_loss = load_balancing_loss_func(decoder_router_probs, decoder_expert_indexes) loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) From 1f6b91a03e69297d6f4c9fbfd6092eee16702da8 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 28 Oct 2022 11:08:21 +0000 Subject: [PATCH 056/102] move layer --- .../en/model_doc/switch_transformers.mdx | 3 +- .../modeling_switch_transformers.py | 93 ++++++++++--------- 2 files changed, 49 insertions(+), 47 deletions(-) diff --git a/docs/source/en/model_doc/switch_transformers.mdx b/docs/source/en/model_doc/switch_transformers.mdx index ad60313f0d6e8..ca2e7d425a395 100644 --- a/docs/source/en/model_doc/switch_transformers.mdx +++ b/docs/source/en/model_doc/switch_transformers.mdx @@ -27,8 +27,7 @@ The abstract from the paper is the following: Tips: - SwitchTransformers uses the T5Tokenizer, which can be loaded directly from each model's repository. -- The released weights are pretrained on English [Masked Language Modeling](What is is MLM blog or doc) task, and should be finetuned. -- The routers +- The released weights are pretrained on English [Masked Language Modeling](https://moon-ci-docs.huggingface.co/docs/transformers/pr_19323/en/glossary#general-terms) task, and should be finetuned. This model was contributed by [Younes Belkada](https://huggingface.co/ybelkada) and [Arthur Zucker](https://huggingface.co/ArtZucker) . The original code can be found [here](https://github.com/google/flaxformer/tree/main/flaxformer/architectures/moe). diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index ba990ae169c77..6e14512f1ed6c 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -344,51 +344,6 @@ def forward(self, hidden_states): return hidden_states -class SwitchTransformersLayerFF(nn.Module): - r""" - Switch Transformers Feed Forward layer module. This is a wrapper around the Mixture of Experts module. - - Attributes: - is_sparse (`bool`): - Whether the MLP layer is a `Sparse` layer (contains a Mixture of Experts) or not - mlp (`torch.nn.Module`): - The MLP layer of the Feed Forward layer - layer_norm (`torch.nn.Module`): - The pre-MLP layer norm. This module is equivalent to the `pre_mlp_layer_norm` in `t5x` - dropout (`torch.nn.Module`): - Post-MLP dropout layer. - """ - - def __init__(self, config: SwitchTransformersConfig, is_sparse=False): - super().__init__() - self.is_sparse = is_sparse - - # Check if it is a sparse layer, if not then it is a dense layer - if not self.is_sparse: - self.mlp = SwitchTransformersDenseActDense(config) - else: - self.mlp = SwitchTransformersSparseMLP(config) - - self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward(self, hidden_states, output_router_logits): - forwarded_states = self.layer_norm(hidden_states) - forwarded_states = self.mlp(forwarded_states) - - if isinstance(forwarded_states, tuple): - forwarded_states, router_tuple = forwarded_states - else: - router_tuple = None - - output = hidden_states + self.dropout(forwarded_states) - - if output_router_logits and router_tuple is not None: - output = (output, router_tuple) - - return output - - class SwitchTransformersSparseMLP(nn.Module): r""" Implementation of the Switch Transformers Sparse MLP module. TODO: Add a LOT of details here @@ -439,6 +394,9 @@ def forward(self, hidden_states): router_mask, router_probs, router_logits = self.router(hidden_states) expert_index = torch.argmax(router_mask, dim=-1) + # The routers introduced might not always map all the tokens, to a router, which means that some hidden states + # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the seleced ones. + next_states = hidden_states.clone() for idx, expert in enumerate(self.experts.values()): @@ -453,6 +411,51 @@ def forward(self, hidden_states): return hidden_states, (router_logits, expert_index) +class SwitchTransformersLayerFF(nn.Module): + r""" + Switch Transformers Feed Forward layer module. This is a wrapper around the Mixture of Experts module. + + Attributes: + is_sparse (`bool`): + Whether the MLP layer is a `Sparse` layer (contains a Mixture of Experts) or not + mlp (`torch.nn.Module`): + The MLP layer of the Feed Forward layer + layer_norm (`torch.nn.Module`): + The pre-MLP layer norm. This module is equivalent to the `pre_mlp_layer_norm` in `t5x` + dropout (`torch.nn.Module`): + Post-MLP dropout layer. + """ + + def __init__(self, config: SwitchTransformersConfig, is_sparse=False): + super().__init__() + self.is_sparse = is_sparse + + # Check if it is a sparse layer, if not then it is a dense layer + if not self.is_sparse: + self.mlp = SwitchTransformersDenseActDense(config) + else: + self.mlp = SwitchTransformersSparseMLP(config) + + self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states, output_router_logits): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.mlp(forwarded_states) + + if isinstance(forwarded_states, tuple): + forwarded_states, router_tuple = forwarded_states + else: + router_tuple = None + + output = hidden_states + self.dropout(forwarded_states) + + if output_router_logits and router_tuple is not None: + output = (output, router_tuple) + + return output + + # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->SwitchTransformers class SwitchTransformersAttention(nn.Module): def __init__(self, config: SwitchTransformersConfig, has_relative_attention_bias=False): From cdb7768849d80af17ac9464efbd152355062ff68 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 28 Oct 2022 14:03:28 +0000 Subject: [PATCH 057/102] update doc and fix router capacity usage --- .../configuration_switch_transformers.py | 9 -- .../modeling_switch_transformers.py | 144 ++++++------------ .../test_modeling_switch_transformers.py | 41 +++-- 3 files changed, 74 insertions(+), 120 deletions(-) diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index 722f6d8c67678..b1abb13576efc 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -75,12 +75,8 @@ class SwitchTransformersConfig(PretrainedConfig): router_dtype (`str`, *optional*, default to `float32`): The `dtype` used for the routers. It is preferable to keep the `dtype` to `float32` as specified in the "selective precision" discussion in https://arxiv.org/abs/2101.03961. - batch_prioritized_routing (`bool`, *optional*, defaults to `False`): - Whether to use batch prioritized routing. add_router_probs (`bool`, *optional*, defaults to `False`): Whether to output router probabilities to compute router auxiliary loss. - num_selected_experts (`int`, *optional*, defaults to 2): - Number of experts to select for each token. relative_attention_num_buckets (`int`, *optional*, defaults to 32): The number of buckets to use for each attention layer. relative_attention_max_distance (`int`, *optional*, defaults to 128): @@ -123,9 +119,7 @@ def __init__( router_bias=False, router_jitter_noise=0.01, router_dtype="float32", - num_selected_experts=2, router_ignore_padding_tokens=False, - batch_prioritized_routing=False, relative_attention_num_buckets=32, relative_attention_max_distance=128, dropout_rate=0.1, @@ -180,9 +174,6 @@ def __init__( self.router_ignore_padding_tokens = router_ignore_padding_tokens self.relative_attention_num_buckets = relative_attention_num_buckets self.relative_attention_max_distance = relative_attention_max_distance - self.batch_prioritized_routing = batch_prioritized_routing - - self.num_selected_experts = num_selected_experts self.dropout_rate = dropout_rate self.layer_norm_epsilon = layer_norm_epsilon diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index a258dab12884f..80a2501fc1c52 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 Mesh TensorFlow authors, SwitchTransformers Authors and HuggingFace Inc. team. +# Copyright 2022 SwitchTransformers Authors and HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -71,17 +71,16 @@ ] -# Router loss def router_z_loss_func(router_logits: torch.Tensor) -> float: r""" - Compute router z-loss implemented in PyTorch. + Compute the router z-loss implemented in PyTorch. - The router z-loss was introduced in Designing Effective Sparse Expert Models (https://arxiv.org/abs/2202.08906). It + The router z-loss was introduced in [Designing Effective Sparse Expert Models](https://arxiv.org/abs/2202.08906). It encourages router logits to remain small in an effort to improve stability. Args: router_logits (`float`): - Input logits of shape [num_groups, tokens_per_group, num_experts] + Input logits of shape [batch_size, sequence_length, num_experts] Returns: Scalar router z-loss. @@ -92,27 +91,25 @@ def router_z_loss_func(router_logits: torch.Tensor) -> float: return torch.sum(z_loss) / (num_groups * tokens_per_group) -# aux loss function def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float: r""" Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. - See Switch Transformer (https://arxiv.org/abs/2101.03961). This function implements the loss function presented in - equations (4) - (6). It aims to penalize those cases where the routing between experts is unbalanced. + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the + loss function presented in equations (4) - (6) of the paper. + It aims at penalizing cases where the routing between experts is too unbalanced. Args: router_probs (`torch.Tensor`): - Probability assigned to each expert per token. Shape: [num_groups, tokens_per_group, num_experts]. + Probability assigned to each expert per token. Shape: [batch_size, seqeunce_length, num_experts]. expert_indices (`torch.Tensor`): - Indices tensor of shape [num_groups, tokens_per_group, num_selected_experts] identifying the top - num_selected_experts for a given token. + Indices tensor of shape [batch_size, seqeunce_length] identifying the selected expert for a given token. Returns: The auxiliary loss. """ num_experts = router_probs.shape[-1] - # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. # cast the expert indices to int64, otherwise one-hot encoding will fail if expert_indices.dtype != torch.int64: expert_indices = expert_indices.to(torch.int64) @@ -123,7 +120,6 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts) # For a given token, determine if it was routed to a given expert. - # Shape: [num_groups, tokens_per_group, num_experts] expert_mask = torch.max(expert_mask, axis=-2).values # cast to float32 otherwise mean will fail @@ -134,112 +130,75 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2) -@dataclass -class RouterOutput: - """ - Base class for MoE Routers outputs, with expert indices, together with router probabilities. - - Args: - expert_indices (`torch.LongTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - - Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary - loss for Mixture of Experts models. - """ - - expert_indices: torch.LongTensor = None - router_probs: torch.FloatTensor = None - class SwitchTransformersTop1Router(nn.Module): """ - Masked matmul router using tokens choose top-1 experts assignment. + Router using tokens choose top-1 experts assignment. This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then - routed to their choice of expert until the expert's expert_capacity is reached. There is no guarantee that each - token is processed by an expert, or that each expert receives at least one token. + routed to their choice of expert until the expert's expert_capacity is reached. **There is no guarantee that each + token is processed by an expert**, or that each expert receives at least one token. - Parameters: - num_selected_experts (`int`): - Maximum number of experts to which each token is routed. Tokens may be routed to fewer experts if - particular experts are oversubscribed / reach capacity. - batch_prioritized_routing (`bool`): - Whether or not to use Batch Prioritized Routing (BPR), originally introduced in V-MoE - (https://arxiv.org/abs/2106.05974). With BPR, we prioritize routing those top-k tokens with the highest - router probability, rather than simply using each tokens left-to-right ordering in the batch. This - prioritization is important because the experts have limited capacity. """ - def __init__(self, config, **kwargs): + def __init__(self, config: SwitchTransformersConfig): super().__init__() self.num_experts = config.num_experts - self.batch_prioritized_routing = config.batch_prioritized_routing - self.num_experts = config.num_experts + self.expert_capacity = config.expert_capacity self.classifier = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias) self.jitter_noise = config.router_jitter_noise self.ignore_padding_tokens = config.router_ignore_padding_tokens self.dtype = getattr(torch, config.router_dtype) - def _compute_router_probabilities( - self, token_inputs: torch.Tensor, num_experts: int, apply_jitter: bool - ) -> Tuple[torch.Tensor, torch.Tensor]: + def _compute_router_probabilities(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: r""" - Computes router probabilities from input tokens. + Computes router probabilities from input hidden states. Args: - token_inputs (`torch.Tensor`): - [num_groups, tokens_per_group, hidden_dim] from which router probabilities are computed. - num_experts (`int`): - Number of experts. - apply_jitter (`bool`): - If true, apply jitter noise. - + hidden_states (`torch.Tensor`): + (batch_size, sequence_length, hidden_dim) from which router probabilities are computed. Returns: router_probabilities (`torch.Tensor`): - Tensor of shape [num_groups, tokens_per_group, num_experts] corresponding to the probabilities for each + Tensor of shape (batch_size, sequence_length, num_experts) corresponding to the probabilities for each token and expert. Used for routing tokens to experts. router_logits (`torch.Tensor`): - Logits tensor of shape [num_groups, tokens_per_group, num_experts] corresponding to raw router logits. + Logits tensor of shape (batch_size, sequence_length, num_experts) corresponding to raw router logits. This is used later for computing router z-loss. """ - # For remainder of routing computation we use float32 to ensure stability. - # See the discussion of "selective precision" in + # float32 is used to ensure stability. See the discussion of "selective precision" in # https://arxiv.org/abs/2101.03961. # We also store the previous dtype to cast back the output to the previous dtype - self.input_tokens_dtype = token_inputs.dtype - token_inputs = token_inputs.to(self.dtype) + self.input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(self.dtype) - if apply_jitter and self.jitter_noise > 0: + if self.jitter_noise > 0: # Get the lower and upper bound of the uniform distribution # Adapted from: https://stackoverflow.com/questions/44328530/how-to-get-a-uniform-distribution-in-a-range-r1-r2-in-pytorch distrib_lower_bound = 1.0 - self.jitter_noise distrib_upper_bound = 1.0 + self.jitter_noise - uniform_distrib = ( - torch.rand(token_inputs.shape, device=token_inputs.device, dtype=self.dtype) - * (distrib_lower_bound - distrib_upper_bound) - ) + distrib_upper_bound + uniform_distrib = torch.rand(hidden_states.shape, device=hidden_states.device, dtype=self.dtype) + uniform_distrib = uniform_distrib * (distrib_lower_bound - distrib_upper_bound) + uniform_distrib = uniform_distrib + distrib_upper_bound # Multiply the token inputs by the uniform distribution - adding some noise - token_inputs *= uniform_distrib + hidden_states *= uniform_distrib # Shape: [num_groups, tokens_per_group, num_experts] self.classifier = self.classifier.to(self.dtype) - router_logits = self.classifier(token_inputs) + router_logits = self.classifier(hidden_states) # Apply Softmax and cast back to the original `dtype` router_probabilities = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to( - self.input_tokens_dtype + self.input_dtype ) return router_probabilities, router_logits - def forward(self, token_inputs: torch.Tensor, apply_jitter: bool = True, **kwargs) -> Tuple: + def forward(self, hidden_states: torch.Tensor) -> Tuple: r""" Generic forward function for every Router class. Each Router expects to have the same input hidden states - (`token_inputs`) corresponding to the hidden states for each token, the `expert_capacity` corresponding to the + (`hidden_states`) corresponding to the hidden states for each token, the `expert_capacity` corresponding to the number of tokens the Router will send to each expert, some Routers can send up to few tokens to each expert. Each Router works as the following: it expects the hidden states for each token, gets the `router_probs` and @@ -247,17 +206,14 @@ def forward(self, token_inputs: torch.Tensor, apply_jitter: bool = True, **kwarg to an expert. Then each Router class will have to define its own `_compute_routing_instructions`. Args: - Computes dispatch and combine torch.Tensors for routing to experts. - token_inputs: [num_groups, tokens_per_group, hidden_dim] inputs to send to experts. num_experts: - Number of experts. expert_capacity: Each group will send this many tokens to each expert. - apply_jitter: - If true, apply jitter noise during routing. + hidden_states (`torch.Tensor`) : + [num_groups, tokens_per_group, hidden_dim] inputs to send to experts. Returns: Router indices or mask torch.Tensors (depending on router type). """ - router_probs, router_logits = self._compute_router_probabilities(token_inputs, self.num_experts, apply_jitter) + router_probs, router_logits = self._compute_router_probabilities(hidden_states) - # Flax code for reference + # Flax code for reference TODO check what happens with padded inputs here if self.ignore_padding_tokens: # To identify non-padding tokens, we rely on the fact that padding tokens # in the inputs have already been masked in the default T5 architecture. @@ -265,7 +221,7 @@ def forward(self, token_inputs: torch.Tensor, apply_jitter: bool = True, **kwarg # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L315 # and # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L603. - padding_mask = torch.Tensor((torch.sum(torch.abs(token_inputs), axis=-1) > 0)).to(token_inputs.dtype) + padding_mask = torch.Tensor((torch.sum(torch.abs(hidden_states), axis=-1) > 0)).to(hidden_states.dtype) router_logits *= padding_mask.unsqueeze(-1) else: padding_mask = None @@ -273,6 +229,12 @@ def forward(self, token_inputs: torch.Tensor, apply_jitter: bool = True, **kwarg expert_index = torch.argmax(router_probs, dim=-1) expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts) + # Mask tokens outside expert capacity. Sum over the sequence + token_priority = torch.cumsum(expert_index, dim =-2) + # mask if the token routed to to the expert will overflow + expert_capacity_mask = (token_priority <= self.expert_capacity) + expert_index = expert_index * expert_capacity_mask + router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1) return expert_index, router_probs, router_logits @@ -303,8 +265,6 @@ def forward(self, hidden_states): return self.weight * hidden_states - -# TODO: do we need this? No let's just import ALL_LAYERNORM_LAYERS. ALL_LAYERNORM_LAYERS.append(SwitchTransformersLayerNorm) @@ -346,7 +306,7 @@ def forward(self, hidden_states): class SwitchTransformersSparseMLP(nn.Module): r""" - Implementation of the Switch Transformers Sparse MLP module. TODO: Add a LOT of details here + Implementation of the Switch Transformers Sparse MLP module. """ def __init__(self, config: SwitchTransformersConfig, expert_class: nn.Module = SwitchTransformersDenseActDense): @@ -359,8 +319,6 @@ def __init__(self, config: SwitchTransformersConfig, expert_class: nn.Module = S for idx in range(config.num_experts): self.experts[f"expert_{idx}"] = expert_class(config) - self.expert_capacity = config.expert_capacity - def _get_router(self, config): r""" For now two types of Router are supported: @@ -386,8 +344,7 @@ def forward(self, hidden_states): hidden states since the probabilities will be broadcasted to the hidden states values (they can be interpreted as a scaling factor). - 2- TODO: explain @ArthurZucker - + 2- Dispatch the tokens to the experts. """ # Step 1: Get the router_mask from the router as wel as the probabilities @@ -415,15 +372,12 @@ class SwitchTransformersLayerFF(nn.Module): r""" Switch Transformers Feed Forward layer module. This is a wrapper around the Mixture of Experts module. - Attributes: + Parameters: + config : ([`SwitchTransformersConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. is_sparse (`bool`): Whether the MLP layer is a `Sparse` layer (contains a Mixture of Experts) or not - mlp (`torch.nn.Module`): - The MLP layer of the Feed Forward layer - layer_norm (`torch.nn.Module`): - The pre-MLP layer norm. This module is equivalent to the `pre_mlp_layer_norm` in `t5x` - dropout (`torch.nn.Module`): - Post-MLP dropout layer. """ def __init__(self, config: SwitchTransformersConfig, is_sparse=False): diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 96810374e98a4..eb6e31f726032 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -875,7 +875,13 @@ class SwitchTransformerRouterTest(unittest.TestCase): Original implementation of the routers here: """ - + config = SwitchTransformersConfig( + num_experts=2, + hidden_size=8, + d_ff=16, + router_jitter_noise=0, + expert_capacity=4, + ) def test_equivalency_balancy_loss(self): r""" This test checks if the balancy loss is correctly implemented @@ -930,11 +936,6 @@ def test_equivalency_token_chose_masked_router(self): This test tests the equivalency between the `SwitchTransformersTop1Router` originally implemented from here: TODO: provide link """ - hidden_dim = 4 - num_experts = 2 - num_selected_experts = 1 # Switch routing case - expert_capacity = 1 # Total capacity = 2*2*1 = 4 < num_tokens - jitter_noise = 0.0 input_tokens = torch.Tensor( [ @@ -951,15 +952,8 @@ def test_equivalency_token_chose_masked_router(self): ] ) - config = SwitchTransformersConfig( - num_experts=num_experts, - hidden_size=hidden_dim, - num_selected_experts=num_selected_experts, - router_jitter_noise=jitter_noise, - expert_capacity=expert_capacity, - batch_prioritized_routing=False, - ) - model = SwitchTransformersTop1Router(config) + + model = SwitchTransformersTop1Router(self.config) model.classifier.weight = torch.nn.Parameter( torch.Tensor( @@ -983,6 +977,21 @@ def test_equivalency_token_chose_masked_router(self): # self.assertTrue(torch.allclose(expert_index.bool().unsqueeze(-1), expected_dispatch_mask)) + def test_max_routing_capacity(self): + model = SwitchTransformersTop1Router(self.config) + seq_len = 128 + batch_size = 4 + hidden_states = torch.stack(batch_size*[torch.rand((seq_len,self.config.hidden_size))]) + + router_probs, router_logits = model._compute_router_probabilities(hidden_states) + expert_index = torch.argmax(router_probs, dim=-1) + expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.config.num_experts) + + token_priority = torch.cumsum(expert_index, dim =-2) + expert_capacity_mask = (token_priority <= self.config.expert_capacity) + expert_index = expert_index * expert_capacity_mask + + assert(torch.sum(expert_index) <= batch_size * self.config.num_experts * self.config.expert_capacity) @slow @require_torch @@ -994,7 +1003,7 @@ def test_small_logits(self): and `transformers` implementation of Switch-C transformers. We only check the logits of the first batch. """ - model = SwitchTransformersModel.from_pretrained("HFLAY/switch_base_8", torch_dtype=torch.bfloat16).eval() + model = SwitchTransformersModel.from_pretrained("HFLAY/switch_base_8",expert_capacity = 65, torch_dtype=torch.bfloat16).eval() input_ids = torch.ones((32, 64), dtype=torch.long) decoder_input_ids = torch.ones((32, 64), dtype=torch.long) From aac7137cc71158f24d33ca1115ad51f8a5cf9e5b Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 28 Oct 2022 14:04:08 +0000 Subject: [PATCH 058/102] fixup --- .../modeling_switch_transformers.py | 23 +++++++--------- .../test_modeling_switch_transformers.py | 27 ++++++++++--------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 80a2501fc1c52..5748855025e5e 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -18,7 +18,6 @@ import copy import math import warnings -from dataclasses import dataclass from typing import Optional, Tuple, Union import torch @@ -75,8 +74,8 @@ def router_z_loss_func(router_logits: torch.Tensor) -> float: r""" Compute the router z-loss implemented in PyTorch. - The router z-loss was introduced in [Designing Effective Sparse Expert Models](https://arxiv.org/abs/2202.08906). It - encourages router logits to remain small in an effort to improve stability. + The router z-loss was introduced in [Designing Effective Sparse Expert Models](https://arxiv.org/abs/2202.08906). + It encourages router logits to remain small in an effort to improve stability. Args: router_logits (`float`): @@ -95,9 +94,9 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T r""" Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. - See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the - loss function presented in equations (4) - (6) of the paper. - It aims at penalizing cases where the routing between experts is too unbalanced. + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. Args: router_probs (`torch.Tensor`): @@ -130,7 +129,6 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2) - class SwitchTransformersTop1Router(nn.Module): """ Router using tokens choose top-1 experts assignment. @@ -190,9 +188,7 @@ def _compute_router_probabilities(self, hidden_states: torch.Tensor) -> Tuple[to router_logits = self.classifier(hidden_states) # Apply Softmax and cast back to the original `dtype` - router_probabilities = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to( - self.input_dtype - ) + router_probabilities = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to(self.input_dtype) return router_probabilities, router_logits def forward(self, hidden_states: torch.Tensor) -> Tuple: @@ -230,10 +226,10 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple: expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts) # Mask tokens outside expert capacity. Sum over the sequence - token_priority = torch.cumsum(expert_index, dim =-2) + token_priority = torch.cumsum(expert_index, dim=-2) # mask if the token routed to to the expert will overflow - expert_capacity_mask = (token_priority <= self.expert_capacity) - expert_index = expert_index * expert_capacity_mask + expert_capacity_mask = token_priority <= self.expert_capacity + expert_index = expert_index * expert_capacity_mask router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1) return expert_index, router_probs, router_logits @@ -265,6 +261,7 @@ def forward(self, hidden_states): return self.weight * hidden_states + ALL_LAYERNORM_LAYERS.append(SwitchTransformersLayerNorm) diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index eb6e31f726032..f614e98649d33 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -876,12 +876,13 @@ class SwitchTransformerRouterTest(unittest.TestCase): """ config = SwitchTransformersConfig( - num_experts=2, - hidden_size=8, - d_ff=16, - router_jitter_noise=0, - expert_capacity=4, + num_experts=2, + hidden_size=8, + d_ff=16, + router_jitter_noise=0, + expert_capacity=4, ) + def test_equivalency_balancy_loss(self): r""" This test checks if the balancy loss is correctly implemented @@ -952,7 +953,6 @@ def test_equivalency_token_chose_masked_router(self): ] ) - model = SwitchTransformersTop1Router(self.config) model.classifier.weight = torch.nn.Parameter( @@ -981,17 +981,18 @@ def test_max_routing_capacity(self): model = SwitchTransformersTop1Router(self.config) seq_len = 128 batch_size = 4 - hidden_states = torch.stack(batch_size*[torch.rand((seq_len,self.config.hidden_size))]) + hidden_states = torch.stack(batch_size * [torch.rand((seq_len, self.config.hidden_size))]) router_probs, router_logits = model._compute_router_probabilities(hidden_states) expert_index = torch.argmax(router_probs, dim=-1) expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.config.num_experts) - token_priority = torch.cumsum(expert_index, dim =-2) - expert_capacity_mask = (token_priority <= self.config.expert_capacity) - expert_index = expert_index * expert_capacity_mask + token_priority = torch.cumsum(expert_index, dim=-2) + expert_capacity_mask = token_priority <= self.config.expert_capacity + expert_index = expert_index * expert_capacity_mask + + assert torch.sum(expert_index) <= batch_size * self.config.num_experts * self.config.expert_capacity - assert(torch.sum(expert_index) <= batch_size * self.config.num_experts * self.config.expert_capacity) @slow @require_torch @@ -1003,7 +1004,9 @@ def test_small_logits(self): and `transformers` implementation of Switch-C transformers. We only check the logits of the first batch. """ - model = SwitchTransformersModel.from_pretrained("HFLAY/switch_base_8",expert_capacity = 65, torch_dtype=torch.bfloat16).eval() + model = SwitchTransformersModel.from_pretrained( + "HFLAY/switch_base_8", expert_capacity=65, torch_dtype=torch.bfloat16 + ).eval() input_ids = torch.ones((32, 64), dtype=torch.long) decoder_input_ids = torch.ones((32, 64), dtype=torch.long) From 7129478e2bd0a9e64c0704ed183327c71461fb90 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 28 Oct 2022 14:12:30 +0000 Subject: [PATCH 059/102] add sparse mlp index for documentation on hub --- src/transformers/__init__.py | 2 ++ src/transformers/models/switch_transformers/__init__.py | 2 ++ .../configuration_switch_transformers.py | 4 ++-- .../switch_transformers/modeling_switch_transformers.py | 4 ++-- src/transformers/utils/dummy_pt_objects.py | 7 +++++++ 5 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index ffdc691441f2d..beff34639cf81 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1957,6 +1957,7 @@ "SwitchTransformersModel", "SwitchTransformersPreTrainedModel", "SwitchTransformersTop1Router", + "SwitchTransformersSparseMLP", ] ) _import_structure["models.trajectory_transformer"].extend( @@ -4683,6 +4684,7 @@ SwitchTransformersForConditionalGeneration, SwitchTransformersModel, SwitchTransformersPreTrainedModel, + SwitchTransformersSparseMLP, SwitchTransformersTop1Router, ) from .models.t5 import ( diff --git a/src/transformers/models/switch_transformers/__init__.py b/src/transformers/models/switch_transformers/__init__.py index 0119a9caa9087..9352b14d9feee 100644 --- a/src/transformers/models/switch_transformers/__init__.py +++ b/src/transformers/models/switch_transformers/__init__.py @@ -50,6 +50,7 @@ "SwitchTransformersModel", "SwitchTransformersPreTrainedModel", "SwitchTransformersTop1Router", + "SwitchTransformersSparseMLP", ] @@ -72,6 +73,7 @@ SwitchTransformersForConditionalGeneration, SwitchTransformersModel, SwitchTransformersPreTrainedModel, + SwitchTransformersSparseMLP, SwitchTransformersTop1Router, ) diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index b1abb13576efc..b6f001da3600c 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -49,7 +49,7 @@ class SwitchTransformersConfig(PretrainedConfig): num_heads`. d_ff (`int`, *optional*, defaults to 2048): Size of the intermediate feed forward layer in each `SwitchTransformersBlock`. - expert_capacity (`int`, *optional*, defaults to 1): + expert_capacity (`int`, *optional*, defaults to 64): Number of tokens that can be stored in each expert. If set to 1, the model will behave like a regular Transformer. num_layers (`int`, *optional*, defaults to 12): @@ -114,7 +114,7 @@ def __init__( num_sparse_decoder_layers=3, num_heads=12, num_experts=8, - expert_capacity=1, + expert_capacity=64, router_type="tokens_masked", router_bias=False, router_jitter_noise=0.01, diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 5748855025e5e..a3d9f462753c0 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -225,7 +225,7 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple: expert_index = torch.argmax(router_probs, dim=-1) expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts) - # Mask tokens outside expert capacity. Sum over the sequence + # Mask tokens outside expert capacity. Sum over each sequence token_priority = torch.cumsum(expert_index, dim=-2) # mask if the token routed to to the expert will overflow expert_capacity_mask = token_priority <= self.expert_capacity @@ -839,7 +839,7 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel): """ config_class = SwitchTransformersConfig - base_model_prefix = "transformer" + base_model_prefix = "switch_transformers" supports_gradient_checkpointing = True _no_split_modules = ["SwitchTransformersBlock"] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 57a07d78de4b4..5955ad46465e4 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4893,6 +4893,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class SwitchTransformersSparseML(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + T5_PRETRAINED_MODEL_ARCHIVE_LIST = None From 56dd559fe40d520516d0d6522c34d648844e78e4 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 28 Oct 2022 14:13:06 +0000 Subject: [PATCH 060/102] fixup --- src/transformers/utils/dummy_pt_objects.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 5955ad46465e4..16ae2b5f640bd 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4886,14 +4886,14 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class SwitchTransformersTop1Router(metaclass=DummyObject): +class SwitchTransformersSparseMLP(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class SwitchTransformersSparseML(metaclass=DummyObject): +class SwitchTransformersTop1Router(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): From 00186a89773352f41a7c18ca267e63b1f6c54d2e Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 28 Oct 2022 14:41:55 +0000 Subject: [PATCH 061/102] test sparse mix architecture --- .../test_modeling_switch_transformers.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index f614e98649d33..1f0e99ccddb41 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -65,9 +65,10 @@ def __init__( eos_token_id=1, pad_token_id=0, decoder_start_token_id=0, - scope=None, decoder_layers=None, sparse_step=1, + num_sparse_decoder_layers=2, + num_sparse_encoder_layers=0, ): self.parent = parent @@ -93,6 +94,8 @@ def __init__( self.scope = None self.decoder_layers = decoder_layers self.sparse_step = sparse_step + self.num_sparse_decoder_layers=num_sparse_decoder_layers + self.num_sparse_encoder_layers=num_sparse_encoder_layers def get_large_model_config(self): return SwitchTransformersConfig.from_pretrained("HFLAY/switch_base_8") @@ -157,6 +160,8 @@ def get_config(self): pad_token_id=self.pad_token_id, decoder_start_token_id=self.decoder_start_token_id, sparse_step=self.sparse_step, + num_sparse_encoder_layers=self.num_sparse_encoder_layers, + num_sparse_decoder_layers=self.num_sparse_decoder_layers ) def check_prepare_lm_labels_via_shift_left( @@ -1005,7 +1010,7 @@ def test_small_logits(self): of the first batch. """ model = SwitchTransformersModel.from_pretrained( - "HFLAY/switch_base_8", expert_capacity=65, torch_dtype=torch.bfloat16 + "HFLAY/switch_base_8", torch_dtype=torch.bfloat16 ).eval() input_ids = torch.ones((32, 64), dtype=torch.long) decoder_input_ids = torch.ones((32, 64), dtype=torch.long) From 9d62a07a876d0c4b3f1ae408c0a353e367ac0861 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Fri, 28 Oct 2022 16:56:38 +0200 Subject: [PATCH 062/102] Apply suggestions from code review --- .../switch_transformers/modeling_switch_transformers.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index a3d9f462753c0..edd36309f795f 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -205,7 +205,7 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple: hidden_states (`torch.Tensor`) : [num_groups, tokens_per_group, hidden_dim] inputs to send to experts. Returns: - Router indices or mask torch.Tensors (depending on router type). + Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`] Tuple containing the expert index, the router probs and the router logits. The router probabilities and logits are required to compute the loss. """ router_probs, router_logits = self._compute_router_probabilities(hidden_states) @@ -829,7 +829,7 @@ def forward( outputs = outputs + (router_tuple,) - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights), (router_tuple) class SwitchTransformersPreTrainedModel(PreTrainedModel): @@ -1309,6 +1309,9 @@ def custom_forward(*inputs): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ From 7827d08acacbf5842205e5456fc500d5ba0728d7 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Fri, 28 Oct 2022 16:57:02 +0200 Subject: [PATCH 063/102] Update docs/source/en/model_doc/switch_transformers.mdx --- docs/source/en/model_doc/switch_transformers.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/switch_transformers.mdx b/docs/source/en/model_doc/switch_transformers.mdx index ca2e7d425a395..0a4d4e5ec930c 100644 --- a/docs/source/en/model_doc/switch_transformers.mdx +++ b/docs/source/en/model_doc/switch_transformers.mdx @@ -17,7 +17,7 @@ specific language governing permissions and limitations under the License. The SwitchTransformers model was proposed in [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) by William Fedus, Barret Zoph, Noam Shazeer. The Switch Transformer model uses a sparse T5 encoder-decoder architure, where the MLP are replace by a Mixture of Expert (MOE). A routing mecanism (top 1 in this case) associates each token to one of the expert, where each expert is a dense MLP. While switch tranformers have a lot more weights than their equivalent dense models, the sparsity allows for better scaling. -During a forward pass, only a fraction of the weights are used. The routing mecanism allows the model to select relavant weights on the fly which increases the model capacity. #TODO add the intuition about moving the loss curve. +During a forward pass, only a fraction of the weights are used. The routing mecanism allows the model to select relevant weights on the fly which increases the model capacity without increasing the number of operations. The abstract from the paper is the following: From a2f725d2b46529de2f1f91ae86eea07acc5d8e41 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 28 Oct 2022 14:58:55 +0000 Subject: [PATCH 064/102] fixup on update --- .../modeling_switch_transformers.py | 3 ++- .../test_modeling_switch_transformers.py | 10 ++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index edd36309f795f..248aab2498569 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -205,7 +205,8 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple: hidden_states (`torch.Tensor`) : [num_groups, tokens_per_group, hidden_dim] inputs to send to experts. Returns: - Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`] Tuple containing the expert index, the router probs and the router logits. The router probabilities and logits are required to compute the loss. + Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`] Tuple containing the expert index, the router probs + and the router logits. The router probabilities and logits are required to compute the loss. """ router_probs, router_logits = self._compute_router_probabilities(hidden_states) diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 1f0e99ccddb41..99ecac224b860 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -94,8 +94,8 @@ def __init__( self.scope = None self.decoder_layers = decoder_layers self.sparse_step = sparse_step - self.num_sparse_decoder_layers=num_sparse_decoder_layers - self.num_sparse_encoder_layers=num_sparse_encoder_layers + self.num_sparse_decoder_layers = num_sparse_decoder_layers + self.num_sparse_encoder_layers = num_sparse_encoder_layers def get_large_model_config(self): return SwitchTransformersConfig.from_pretrained("HFLAY/switch_base_8") @@ -161,7 +161,7 @@ def get_config(self): decoder_start_token_id=self.decoder_start_token_id, sparse_step=self.sparse_step, num_sparse_encoder_layers=self.num_sparse_encoder_layers, - num_sparse_decoder_layers=self.num_sparse_decoder_layers + num_sparse_decoder_layers=self.num_sparse_decoder_layers, ) def check_prepare_lm_labels_via_shift_left( @@ -1009,9 +1009,7 @@ def test_small_logits(self): and `transformers` implementation of Switch-C transformers. We only check the logits of the first batch. """ - model = SwitchTransformersModel.from_pretrained( - "HFLAY/switch_base_8", torch_dtype=torch.bfloat16 - ).eval() + model = SwitchTransformersModel.from_pretrained("HFLAY/switch_base_8", torch_dtype=torch.bfloat16).eval() input_ids = torch.ones((32, 64), dtype=torch.long) decoder_input_ids = torch.ones((32, 64), dtype=torch.long) From 20b076ea5733c7e3ddafe3d654e0c7e969c85026 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 28 Oct 2022 17:04:00 +0200 Subject: [PATCH 065/102] fix tests --- .../test_modeling_switch_transformers.py | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 1f0e99ccddb41..16658425c3612 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -68,7 +68,7 @@ def __init__( decoder_layers=None, sparse_step=1, num_sparse_decoder_layers=2, - num_sparse_encoder_layers=0, + num_sparse_encoder_layers=2, ): self.parent = parent @@ -249,7 +249,7 @@ def create_and_check_with_lm_head( decoder_attention_mask=decoder_attention_mask, labels=lm_labels, ) - self.parent.assertEqual(len(outputs), 4) + self.parent.assertEqual(len(outputs), 10) self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size)) self.parent.assertEqual(outputs["loss"].size(), ()) @@ -264,9 +264,9 @@ def create_and_check_decoder_model_past( ): model = SwitchTransformersModel(config=config).get_decoder().to(torch_device).eval() # first forward pass - outputs = model(input_ids, use_cache=True) - outputs_use_cache_conf = model(input_ids) - outputs_no_past = model(input_ids, use_cache=False) + outputs = model(input_ids, use_cache=True, output_router_logits=False) + outputs_use_cache_conf = model(input_ids, output_router_logits=False) + outputs_no_past = model(input_ids, use_cache=False, output_router_logits=False) self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) @@ -279,8 +279,8 @@ def create_and_check_decoder_model_past( # append to next input_ids and next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) - output_from_no_past = model(next_input_ids)["last_hidden_state"] - output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + output_from_no_past = model(next_input_ids, output_router_logits=False)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values, output_router_logits=False)["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() @@ -310,7 +310,7 @@ def create_and_check_decoder_model_attention_mask_past( attn_mask[:, half_seq_length:] = 0 # first forward pass - output, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple() + output, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True, output_router_logits=False).to_tuple() # create hypothetical next token and extent to next_input_ids next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) @@ -328,8 +328,8 @@ def create_and_check_decoder_model_attention_mask_past( ) # get two different outputs - output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[ + output_from_no_past = model(next_input_ids, attention_mask=attn_mask, output_router_logits=False)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask, output_router_logits=False)[ "last_hidden_state" ] @@ -352,7 +352,7 @@ def create_and_check_decoder_model_past_large_inputs( ): model = SwitchTransformersModel(config=config).get_decoder().to(torch_device).eval() # first forward pass - outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + outputs = model(input_ids, attention_mask=attention_mask, use_cache=True, output_router_logits=False) output, past_key_values = outputs.to_tuple() @@ -364,8 +364,8 @@ def create_and_check_decoder_model_past_large_inputs( next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) next_attention_mask = torch.cat([attention_mask, next_mask], dim=-1) - output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"] - output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[ + output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask, output_router_logits=False)["last_hidden_state"] + output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values, output_router_logits=False)[ "last_hidden_state" ] @@ -517,6 +517,7 @@ def prepare_config_and_inputs_for_common(self): "decoder_input_ids": decoder_input_ids, "decoder_attention_mask": decoder_attention_mask, "use_cache": False, + "output_router_logits": False, } return config, inputs_dict @@ -852,6 +853,7 @@ def build_model_and_check_forward_pass(self, **kwargs): decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, labels=lm_labels, + output_router_logits=False, ) # outputs = model(*inputs) assert len(outputs) == 4 From b19e3924ca583a20cc0e0eb510992d7434fd1a8f Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 28 Oct 2022 17:12:46 +0200 Subject: [PATCH 066/102] fix another test --- .../modeling_switch_transformers.py | 16 +++++++++++++ .../test_modeling_switch_transformers.py | 23 ++++++++++++++----- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 248aab2498569..c54674de55906 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -901,6 +901,22 @@ def _init_weights(self, module): module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + elif isinstance(module, SwitchTransformersSparseMLP): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.router.classifier.weight.data.normal_( + mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5) + ) + for idx in range(self.config.num_experts): + module.experts[f"expert_{idx}"].wi.weight.data.normal_( + mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5) + ) + module.experts[f"expert_{idx}"].wo.weight.data.normal_( + mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5) + ) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (SwitchTransformersAttention, SwitchTransformersStack)): diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index bc0be53fac661..99673e53900c0 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -280,7 +280,9 @@ def create_and_check_decoder_model_past( next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) output_from_no_past = model(next_input_ids, output_router_logits=False)["last_hidden_state"] - output_from_past = model(next_tokens, past_key_values=past_key_values, output_router_logits=False)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values, output_router_logits=False)[ + "last_hidden_state" + ] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() @@ -310,7 +312,9 @@ def create_and_check_decoder_model_attention_mask_past( attn_mask[:, half_seq_length:] = 0 # first forward pass - output, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True, output_router_logits=False).to_tuple() + output, past_key_values = model( + input_ids, attention_mask=attn_mask, use_cache=True, output_router_logits=False + ).to_tuple() # create hypothetical next token and extent to next_input_ids next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) @@ -328,10 +332,12 @@ def create_and_check_decoder_model_attention_mask_past( ) # get two different outputs - output_from_no_past = model(next_input_ids, attention_mask=attn_mask, output_router_logits=False)["last_hidden_state"] - output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask, output_router_logits=False)[ + output_from_no_past = model(next_input_ids, attention_mask=attn_mask, output_router_logits=False)[ "last_hidden_state" ] + output_from_past = model( + next_tokens, past_key_values=past_key_values, attention_mask=attn_mask, output_router_logits=False + )["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() @@ -364,10 +370,15 @@ def create_and_check_decoder_model_past_large_inputs( next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) next_attention_mask = torch.cat([attention_mask, next_mask], dim=-1) - output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask, output_router_logits=False)["last_hidden_state"] - output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values, output_router_logits=False)[ + output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask, output_router_logits=False)[ "last_hidden_state" ] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + past_key_values=past_key_values, + output_router_logits=False, + )["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() From 26f5387e8499f437fe973091d0f8d0174f20c3e9 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 28 Oct 2022 17:20:40 +0200 Subject: [PATCH 067/102] attempt fix --- .../switch_transformers/modeling_switch_transformers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index c54674de55906..1433dbf2eab69 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1118,6 +1118,10 @@ def custom_forward(*inputs): output_router_logits=output_router_logits, ) + if output_router_logits: + router_probs = layer_outputs[-1] + layer_outputs = layer_outputs[:-1] + # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) if use_cache is False: @@ -1141,7 +1145,7 @@ def custom_forward(*inputs): all_cross_attentions = all_cross_attentions + (layer_outputs[5],) if output_router_logits: - all_router_probs = all_router_probs + (layer_outputs[-1],) + all_router_probs = all_router_probs + (router_probs,) hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) From 44446888a8da279c8bc0e29fcaab6fffed83f42e Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 28 Oct 2022 17:33:20 +0200 Subject: [PATCH 068/102] Update src/transformers/models/switch_transformers/configuration_switch_transformers.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- .../switch_transformers/configuration_switch_transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index b6f001da3600c..c57e99f06fd5b 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -23,7 +23,7 @@ logger = logging.get_logger(__name__) SWITCH_TRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "HFLAY/switch_base_8": "https://huggingface.co/HFLAY/switch_base_8/blob/main/config.json", + "google/switch-base-8": "https://huggingface.co/google/switch-base-8/blob/main/config.json", } From 44b8a819fb58ede253b7dfe43f77df1988f11593 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 28 Oct 2022 17:33:58 +0200 Subject: [PATCH 069/102] Update src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- ...ers_original_flax_checkpoint_to_pytorch.py | 29 ------------------- 1 file changed, 29 deletions(-) diff --git a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py index a6d8cac3e371f..353b00cc974ad 100644 --- a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py +++ b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py @@ -28,35 +28,6 @@ logging.set_verbosity_info() -MODEL_MAPPING = { - "switch_base_8": [ - "https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin" - ], - "switch_base_16": [ - "https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin" - ], - "switch_base_32": [ - "https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin" - ], - "switch_base_64": [ - "https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin" - ], - "switch_base_128": [ - "https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin" - ], - "switch_base_256": [ - "https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin" - ], - "switch_base_512": [ - "https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin" - ], - "switch_base_1024": [ - "https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin" - ], - "switch_base_2048": [ - "https://github.com/google/flaxformer/tree/main/flaxformer/t5x/configs/moe/models/switch_base.gin" - ], -} # should not include what is already done by the `from_pt` argument MOE_LAYER_NAME_MAPPING = { From 32903e32371c44d1b8d95dc32de681086bf35fe1 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 28 Oct 2022 18:52:26 +0200 Subject: [PATCH 070/102] try --- .../configuration_switch_transformers.py | 2 +- ...ransformers_original_flax_checkpoint_to_pytorch.py | 1 - .../modeling_switch_transformers.py | 11 ++++------- .../test_modeling_switch_transformers.py | 3 +++ 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index c57e99f06fd5b..f1c6e364342a2 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -23,7 +23,7 @@ logger = logging.get_logger(__name__) SWITCH_TRANSFORMERS_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "google/switch-base-8": "https://huggingface.co/google/switch-base-8/blob/main/config.json", + "google/switch-base-8": "https://huggingface.co/google/switch-base-8/blob/main/config.json", } diff --git a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py index 353b00cc974ad..f0f419752706f 100644 --- a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py +++ b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py @@ -28,7 +28,6 @@ logging.set_verbosity_info() - # should not include what is already done by the `from_pt` argument MOE_LAYER_NAME_MAPPING = { "/attention/": "/0/SelfAttention/", diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 1433dbf2eab69..81b030c6e6c1d 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -824,11 +824,9 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + outputs = outputs + (present_key_value_state,) + attention_outputs + (router_tuple,) else: - outputs = outputs + attention_outputs - - outputs = outputs + (router_tuple,) + outputs = outputs + attention_outputs + (router_tuple,) return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights), (router_tuple) @@ -1118,9 +1116,8 @@ def custom_forward(*inputs): output_router_logits=output_router_logits, ) - if output_router_logits: - router_probs = layer_outputs[-1] - layer_outputs = layer_outputs[:-1] + router_probs = layer_outputs[-1] + layer_outputs = layer_outputs[:-1] # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 99673e53900c0..f2ac208ecf4e8 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -69,6 +69,7 @@ def __init__( sparse_step=1, num_sparse_decoder_layers=2, num_sparse_encoder_layers=2, + expert_capacity=100, ): self.parent = parent @@ -96,6 +97,7 @@ def __init__( self.sparse_step = sparse_step self.num_sparse_decoder_layers = num_sparse_decoder_layers self.num_sparse_encoder_layers = num_sparse_encoder_layers + self.expert_capacity = expert_capacity def get_large_model_config(self): return SwitchTransformersConfig.from_pretrained("HFLAY/switch_base_8") @@ -141,6 +143,7 @@ def get_pipeline_config(self): bos_token_id=self.pad_token_id, pad_token_id=self.pad_token_id, decoder_start_token_id=self.decoder_start_token_id, + expert_capacity=self.expert_capacity, ) def get_config(self): From bcff9e4a1490cd23d91fcf8af4fc900a18965d2b Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 28 Oct 2022 19:43:17 +0200 Subject: [PATCH 071/102] all tests pass --- .../modeling_switch_transformers.py | 12 +- .../test_modeling_switch_transformers.py | 104 +++++++++++++++++- 2 files changed, 106 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 81b030c6e6c1d..fd5d1e600674b 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -905,16 +905,10 @@ def _init_weights(self, module): d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.router.classifier.weight.data.normal_( - mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5) - ) + module.router.classifier.weight.data.normal_(mean=0.0, std=factor * 1) for idx in range(self.config.num_experts): - module.experts[f"expert_{idx}"].wi.weight.data.normal_( - mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5) - ) - module.experts[f"expert_{idx}"].wo.weight.data.normal_( - mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5) - ) + module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (SwitchTransformersAttention, SwitchTransformersStack)): diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index f2ac208ecf4e8..6cac096600302 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -36,6 +36,7 @@ SwitchTransformersModel, SwitchTransformersTop1Router, ) + from transformers.generation_utils import BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput from transformers.models.switch_transformers.modeling_switch_transformers import ( SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, load_balancing_loss_func, @@ -393,6 +394,7 @@ def create_and_check_decoder_model_past_large_inputs( # test that outputs are equal for slice self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + @slow def create_and_check_generate_with_past_key_values( self, config, @@ -402,7 +404,12 @@ def create_and_check_generate_with_past_key_values( decoder_attention_mask, lm_labels, ): - model = SwitchTransformersForConditionalGeneration(config=config).to(torch_device).eval() + r""" + This test does not pass for small models due to precision errors. It is therefore only run for slightly larger models. + """ + model = ( + SwitchTransformersForConditionalGeneration.from_pretrained("HFLAY/switch_base_8").to(torch_device).eval() + ) torch.manual_seed(0) output_without_past_cache = model.generate( input_ids[:1], num_beams=2, max_length=5, do_sample=True, use_cache=False @@ -595,6 +602,101 @@ def test_decoder_model_past_with_attn_mask(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + @slow + def test_beam_sample_generate_dict_output(self): + r""" + This test needs to be overriden with a larger model since it fails for very small models due to precision issues. + """ + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + # disable cache + config.use_cache = False + + # It is important set set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated which could lead to flaky circle ci + # failures if the top `num_return_sequences` beams are all shorter than the longest beam + config.eos_token_id = None + config.forced_eos_token_id = None + + model = model_class.from_pretrained("HFLAY/switch_base_8").to(torch_device).eval() + logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) + + num_return_sequences = 2 + if model.config.is_encoder_decoder: + max_length = 4 + beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs( + input_ids.shape[0] * num_return_sequences, max_length + ) + beam_kwargs["num_return_sequences"] = num_return_sequences + + output_beam_sample, output_generate = self._beam_sample_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + num_return_sequences=num_return_sequences, + beam_scorer=beam_scorer, + beam_kwargs=beam_kwargs, + logits_warper=logits_warper, + logits_warper_kwargs=logits_warper_kwargs, + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + if model.config.is_encoder_decoder: + self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput) + self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput) + else: + self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput) + self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput) + + self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist()) + + @slow + def test_beam_sample_generate(self): + r""" + This test needs to be overriden with a larger model since it fails for very small models due to precision issues. + """ + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + # It is important set set the eos_token_id to None to ensure that no sequences + # shorter than `max_length` can be generated which could lead to flaky circle ci + # failures if the top `num_return_sequences` beams are all shorter than the longest beam + config.eos_token_id = None + config.forced_eos_token_id = None + + logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) + + model = model_class.from_pretrained("HFLAY/switch_base_8").to(torch_device).eval() + + # check `generate()` and `beam_search()` are equal + # change `num_return_sequences = 2` but not for `beam_scorer` + num_return_sequences = 2 + if model.config.is_encoder_decoder: + max_length = 4 + beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs( + input_ids.shape[0] * num_return_sequences, max_length + ) + beam_kwargs["num_return_sequences"] = num_return_sequences + + output_generate, output_beam_sample = self._beam_sample_generate( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + num_return_sequences=num_return_sequences, + beam_scorer=beam_scorer, + beam_kwargs=beam_kwargs, + logits_warper=logits_warper, + logits_warper_kwargs=logits_warper_kwargs, + ) + + self.assertListEqual(output_generate.tolist(), output_beam_sample.tolist()) + def test_decoder_model_past_with_3d_attn_mask(self): ( config, From da4800077a19f842871ddfbc387b7bef7f91d25d Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 28 Oct 2022 19:53:38 +0200 Subject: [PATCH 072/102] fix jitter noise --- .../switch_transformers/test_modeling_switch_transformers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 6cac096600302..21d4e7dc8c3c8 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -71,6 +71,7 @@ def __init__( num_sparse_decoder_layers=2, num_sparse_encoder_layers=2, expert_capacity=100, + router_jitter_noise=0.0, ): self.parent = parent @@ -99,6 +100,7 @@ def __init__( self.num_sparse_decoder_layers = num_sparse_decoder_layers self.num_sparse_encoder_layers = num_sparse_encoder_layers self.expert_capacity = expert_capacity + self.router_jitter_noise = router_jitter_noise def get_large_model_config(self): return SwitchTransformersConfig.from_pretrained("HFLAY/switch_base_8") @@ -145,6 +147,7 @@ def get_pipeline_config(self): pad_token_id=self.pad_token_id, decoder_start_token_id=self.decoder_start_token_id, expert_capacity=self.expert_capacity, + router_jitter_noise=self.router_jitter_noise, ) def get_config(self): From fe9c6b95d22404930d7054797c7413eb612c4306 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 31 Oct 2022 10:33:46 +0100 Subject: [PATCH 073/102] Apply suggestions from code review --- .../en/model_doc/switch_transformers.mdx | 2 +- src/transformers/modeling_outputs.py | 8 +++---- .../configuration_switch_transformers.py | 3 +-- .../modeling_switch_transformers.py | 21 ++++++++----------- 4 files changed, 15 insertions(+), 19 deletions(-) diff --git a/docs/source/en/model_doc/switch_transformers.mdx b/docs/source/en/model_doc/switch_transformers.mdx index 0a4d4e5ec930c..5a7a45f3e6400 100644 --- a/docs/source/en/model_doc/switch_transformers.mdx +++ b/docs/source/en/model_doc/switch_transformers.mdx @@ -16,7 +16,7 @@ specific language governing permissions and limitations under the License. The SwitchTransformers model was proposed in [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) by William Fedus, Barret Zoph, Noam Shazeer. -The Switch Transformer model uses a sparse T5 encoder-decoder architure, where the MLP are replace by a Mixture of Expert (MOE). A routing mecanism (top 1 in this case) associates each token to one of the expert, where each expert is a dense MLP. While switch tranformers have a lot more weights than their equivalent dense models, the sparsity allows for better scaling. +The Switch Transformer model uses a sparse T5 encoder-decoder architure, where the MLP are replace by a Mixture of Experts (MoE). A routing mecanism (top 1 in this case) associates each token to one of the expert, where each expert is a dense MLP. While switch tranformers have a lot more weights than their equivalent dense models, the sparsity allows better scaling and better finetuning performance at scale. During a forward pass, only a fraction of the weights are used. The routing mecanism allows the model to select relevant weights on the fly which increases the model capacity without increasing the number of operations. diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py index 5325cc550c781..0b430df244482 100644 --- a/src/transformers/modeling_outputs.py +++ b/src/transformers/modeling_outputs.py @@ -309,7 +309,7 @@ class MoEModelOutput(ModelOutput): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary - loss for Mixture of Experts models. + loss and the z_loss for Mixture of Experts models. """ last_hidden_state: torch.FloatTensor = None @@ -360,7 +360,7 @@ class MoEModelOutputWithPastAndCrossAttentions(ModelOutput): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary - loss for Mixture of Experts models. + loss and the z_loss for Mixture of Experts models. """ last_hidden_state: torch.FloatTensor = None @@ -488,7 +488,7 @@ class Seq2SeqMoEModelOutput(ModelOutput): encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - Router logits of the encoder model, useful to compute the auxiliary loss for Mixture of Experts models. + Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse modules. """ last_hidden_state: torch.FloatTensor = None @@ -791,7 +791,7 @@ class Seq2SeqMoEOutput(ModelOutput): encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - Router logits of the encoder model, useful to compute the auxiliary loss for Mixture of Experts models. + Router logits of the encoder model, useful to compute the auxiliary loss and z_loss for Mixture of Experts models. """ loss: Optional[torch.FloatTensor] = None diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index f1c6e364342a2..6ab55e750a1e6 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -40,8 +40,7 @@ class SwitchTransformersConfig(PretrainedConfig): Arguments: vocab_size (`int`, *optional*, defaults to 32128): Vocabulary size of the SwitchTransformers model. Defines the number of different tokens that can be - represented by the `inputs_ids` passed when calling [`SwitchTransformersModel`] or - [`FlaxSwitchTransformersModel`]. + represented by the `inputs_ids` passed when calling [`SwitchTransformersModel`]. d_model (`int`, *optional*, defaults to 512): Size of the encoder layers and the pooler layer. d_kv (`int`, *optional*, defaults to 64): diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index fd5d1e600674b..94cb3f90d1ff3 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -322,7 +322,6 @@ def _get_router(self, config): For now two types of Router are supported: - Masked Routers - Scatter Routers - In total the list of supported Routers are the following: """ if config.router_type.lower() == "tokens_masked": @@ -342,7 +341,8 @@ def forward(self, hidden_states): hidden states since the probabilities will be broadcasted to the hidden states values (they can be interpreted as a scaling factor). - 2- Dispatch the tokens to the experts. + 2- Dispatch the tokens to the experts. We do a classic for loop over the experts and assign for each expert + the corresponding hidden state """ # Step 1: Get the router_mask from the router as wel as the probabilities @@ -1170,10 +1170,7 @@ def custom_forward(*inputs): SWITCH_TRANSFORMERS_START_DOCSTRING = r""" - The SWITCH_TRANSFORMERS model was proposed in [Exploring the Limits of Transfer Learning with a Unified - Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine - Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer - pre-trained in a text-to-text denoising generative setting. + The SWITCH_TRANSFORMERS model was proposed in [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) by [William Fedus](https://arxiv.org/search/cs?searchtype=author&query=Fedus%2C+W), [Barret Zoph](https://arxiv.org/search/cs?searchtype=author&query=Zoph%2C+B), and [Noam Shazeer](https://arxiv.org/search/cs?searchtype=author&query=Shazeer%2C+N). It's an encoder-decoder T5-like model with sparse Feed Forward that stands for Mixture of Experts (MoE) architecture. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -1422,8 +1419,8 @@ def forward( ```python >>> from transformers import T5Tokenizer, SwitchTransformersModel - >>> tokenizer = T5Tokenizer.from_pretrained("ybelkada/switch_transformers-base") - >>> model = SwitchTransformersModel.from_pretrained("ybelkada/switch_transformers-base") + >>> tokenizer = T5Tokenizer.from_pretrained("google/switch-base-8") + >>> model = SwitchTransformersModel.from_pretrained("google/switch-base-8") >>> input_ids = tokenizer( ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" @@ -1609,8 +1606,8 @@ def forward( ```python >>> from transformers import T5Tokenizer, SwitchTransformersForConditionalGeneration - >>> tokenizer = T5Tokenizer.from_pretrained("ybelkada/switch_transformers-base") - >>> model = SwitchTransformersForConditionalGeneration.from_pretrained("ybelkada/switch_transformers-base") + >>> tokenizer = T5Tokenizer.from_pretrained("google/switch-base-8") + >>> model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8") >>> # training >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids @@ -1885,8 +1882,8 @@ def forward( ```python >>> from transformers import T5Tokenizer, SwitchTransformersEncoderModel - >>> tokenizer = T5Tokenizer.from_pretrained("ybelkada/switch_transformers-base") - >>> model = SwitchTransformersEncoderModel.from_pretrained("ybelkada/switch_transformers-base") + >>> tokenizer = T5Tokenizer.from_pretrained("google/switch-base-8") + >>> model = SwitchTransformersEncoderModel.from_pretrained("google/switch-base-8") >>> input_ids = tokenizer( ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" ... ).input_ids # Batch size 1 From 0f8139e25ddcd4204fcd80960b7c7bbde3a02171 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 31 Oct 2022 11:32:29 +0100 Subject: [PATCH 074/102] doc tests pass --- src/transformers/modeling_outputs.py | 6 ++-- .../modeling_switch_transformers.py | 31 +++++++++++-------- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py index 0b430df244482..606476e964561 100644 --- a/src/transformers/modeling_outputs.py +++ b/src/transformers/modeling_outputs.py @@ -488,7 +488,8 @@ class Seq2SeqMoEModelOutput(ModelOutput): encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse modules. + Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse + modules. """ last_hidden_state: torch.FloatTensor = None @@ -791,7 +792,8 @@ class Seq2SeqMoEOutput(ModelOutput): encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - Router logits of the encoder model, useful to compute the auxiliary loss and z_loss for Mixture of Experts models. + Router logits of the encoder model, useful to compute the auxiliary loss and z_loss for Mixture of Experts + models. """ loss: Optional[torch.FloatTensor] = None diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 94cb3f90d1ff3..b9231fc682a85 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -50,14 +50,14 @@ _CONFIG_FOR_DOC = "SwitchTransformersConfig" _TOKENIZER_FOR_DOC = "T5Tokenizer" -_CHECKPOINT_FOR_DOC = "google/switch-base-8" +_CHECKPOINT_FOR_DOC = "HFLAY/switch_base_8" #################################################### # This dict contains ids and associated url # for the pretrained weights provided with the models #################################################### SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "google/switch-base-8", + "HFLAY/switch_base_8", "google/switch-base-16", "google/switch-base-32", "google/switch-base-64", @@ -341,8 +341,8 @@ def forward(self, hidden_states): hidden states since the probabilities will be broadcasted to the hidden states values (they can be interpreted as a scaling factor). - 2- Dispatch the tokens to the experts. We do a classic for loop over the experts and assign for each expert - the corresponding hidden state + 2- Dispatch the tokens to the experts. We do a classic for loop over the experts and assign for each expert the + corresponding hidden state """ # Step 1: Get the router_mask from the router as wel as the probabilities @@ -1170,7 +1170,12 @@ def custom_forward(*inputs): SWITCH_TRANSFORMERS_START_DOCSTRING = r""" - The SWITCH_TRANSFORMERS model was proposed in [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) by [William Fedus](https://arxiv.org/search/cs?searchtype=author&query=Fedus%2C+W), [Barret Zoph](https://arxiv.org/search/cs?searchtype=author&query=Zoph%2C+B), and [Noam Shazeer](https://arxiv.org/search/cs?searchtype=author&query=Shazeer%2C+N). It's an encoder-decoder T5-like model with sparse Feed Forward that stands for Mixture of Experts (MoE) architecture. + The SWITCH_TRANSFORMERS model was proposed in [Switch Transformers: Scaling to Trillion Parameter Models with + Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) by [William + Fedus](https://arxiv.org/search/cs?searchtype=author&query=Fedus%2C+W), [Barret + Zoph](https://arxiv.org/search/cs?searchtype=author&query=Zoph%2C+B), and [Noam + Shazeer](https://arxiv.org/search/cs?searchtype=author&query=Shazeer%2C+N). It's an encoder-decoder T5-like model + with sparse Feed Forward that stands for Mixture of Experts (MoE) architecture. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -1419,8 +1424,8 @@ def forward( ```python >>> from transformers import T5Tokenizer, SwitchTransformersModel - >>> tokenizer = T5Tokenizer.from_pretrained("google/switch-base-8") - >>> model = SwitchTransformersModel.from_pretrained("google/switch-base-8") + >>> tokenizer = T5Tokenizer.from_pretrained("HFLAY/switch_base_8") + >>> model = SwitchTransformersModel.from_pretrained("HFLAY/switch_base_8") >>> input_ids = tokenizer( ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" @@ -1606,8 +1611,8 @@ def forward( ```python >>> from transformers import T5Tokenizer, SwitchTransformersForConditionalGeneration - >>> tokenizer = T5Tokenizer.from_pretrained("google/switch-base-8") - >>> model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8") + >>> tokenizer = T5Tokenizer.from_pretrained("HFLAY/switch_base_8") + >>> model = SwitchTransformersForConditionalGeneration.from_pretrained("HFLAY/switch_base_8") >>> # training >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids @@ -1621,8 +1626,8 @@ def forward( ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" ... ).input_ids # Batch size 1 >>> outputs = model.generate(input_ids) - >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) - >>> # studies have shown that owning a dog is good for you. + >>> # . To, let’s say you have a dog. To summarize: + >>> # Since the model has been trained on MLM, this will output gibberish ```""" use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1882,8 +1887,8 @@ def forward( ```python >>> from transformers import T5Tokenizer, SwitchTransformersEncoderModel - >>> tokenizer = T5Tokenizer.from_pretrained("google/switch-base-8") - >>> model = SwitchTransformersEncoderModel.from_pretrained("google/switch-base-8") + >>> tokenizer = T5Tokenizer.from_pretrained("HFLAY/switch_base_8") + >>> model = SwitchTransformersEncoderModel.from_pretrained("HFLAY/switch_base_8") >>> input_ids = tokenizer( ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" ... ).input_ids # Batch size 1 From 7ad488dceaf1099479a657aa8321c841c69137e7 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 2 Nov 2022 16:37:08 +0100 Subject: [PATCH 075/102] Update src/transformers/models/switch_transformers/modeling_switch_transformers.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- .../switch_transformers/modeling_switch_transformers.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index b9231fc682a85..91313f174c8d3 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -336,13 +336,12 @@ def forward(self, hidden_states): r""" Hold on, this will be slightly tricky to understand In the correct order, a MoE layer does the following: - 1- Gets the `router_mask` from the router. This mask will contain the indices of the routed tokens. Also - retrieve the probabilities (max prob) for each token. The probabilities are needed in the computation of the - hidden states since the probabilities will be broadcasted to the hidden states values (they can be interpreted + 1- Gets the `router_mask` from the router. The shape of the mask is `(batch_size, sequence_length, num_expert)` and corresponds to the argmax of the `router_probs`. The probabilities are needed in the computation of the + hidden states : they are broadcasted to the hidden states values (can be interpreted as a scaling factor). - 2- Dispatch the tokens to the experts. We do a classic for loop over the experts and assign for each expert the - corresponding hidden state + 2- Dispatch the tokens to its associated experts. We do a classic for loop over the experts and assign for each expert the + corresponding hidden states. """ # Step 1: Get the router_mask from the router as wel as the probabilities From deb2b475e5eb39049dca51e2387e3a4ea1ea1d6f Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 2 Nov 2022 16:37:14 +0100 Subject: [PATCH 076/102] Update src/transformers/models/switch_transformers/modeling_switch_transformers.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- .../switch_transformers/modeling_switch_transformers.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 91313f174c8d3..8884713ad35ff 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -354,11 +354,7 @@ def forward(self, hidden_states): next_states = hidden_states.clone() for idx, expert in enumerate(self.experts.values()): - # 1. Get the index of the tokens that are routed to the current expert - # masked_indices has a shape of `batch_size`, `seq_len`, `num_experts` token_indices = router_mask[:, :, idx].bool() - - # 2. Update only the hidden states affected by the routing next_states[token_indices] = expert(hidden_states[token_indices]) hidden_states = router_probs * next_states From 88e68b5ef8b35f693cb25a042276e1390e1eb1a4 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 2 Nov 2022 16:41:24 +0100 Subject: [PATCH 077/102] remove assert --- .../modeling_switch_transformers.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 8884713ad35ff..83cf128928520 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -336,12 +336,12 @@ def forward(self, hidden_states): r""" Hold on, this will be slightly tricky to understand In the correct order, a MoE layer does the following: - 1- Gets the `router_mask` from the router. The shape of the mask is `(batch_size, sequence_length, num_expert)` and corresponds to the argmax of the `router_probs`. The probabilities are needed in the computation of the - hidden states : they are broadcasted to the hidden states values (can be interpreted - as a scaling factor). + 1- Gets the `router_mask` from the router. The shape of the mask is `(batch_size, sequence_length, num_expert)` + and corresponds to the argmax of the `router_probs`. The probabilities are needed in the computation of the + hidden states : they are broadcasted to the hidden states values (can be interpreted as a scaling factor). - 2- Dispatch the tokens to its associated experts. We do a classic for loop over the experts and assign for each expert the - corresponding hidden states. + 2- Dispatch the tokens to its associated experts. We do a classic for loop over the experts and assign for each + expert the corresponding hidden states. """ # Step 1: Get the router_mask from the router as wel as the probabilities @@ -1811,8 +1811,16 @@ def _reorder_cache(self, past, beam_idx): layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), ) - assert reordered_layer_past_states[0].shape == layer_past_states[0].shape - assert len(reordered_layer_past_states) == len(layer_past_states) + if reordered_layer_past_states[0].shape != layer_past_states[0].shape: + raise ValueError( + "expected reordered_layer_past_states to have the same shape than layer_past_states" + f"but got {reordered_layer_past_states[0].shape} and {layer_past_states[0].shape}" + ) + if len(reordered_layer_past_states) != len(layer_past_states): + raise ValueError( + "expected layer_past_states to have the same length as reordered_layer_past_states" + f"got {len(layer_past_states)} and {len(reordered_layer_past_states)}" + ) reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) return reordered_decoder_past From 16e23c4bd4456ffbdf122f9fce160086514e5864 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 2 Nov 2022 16:45:11 +0100 Subject: [PATCH 078/102] change config order --- .../configuration_switch_transformers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index 6ab55e750a1e6..f1c91ac8f80f8 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -69,13 +69,11 @@ class SwitchTransformersConfig(PretrainedConfig): Whether to add a bias to the router. router_jitter_noise (`float`, *optional*, defaults to 0.1): Amount of noise to add to the router. - router_ignore_padding_tokens (`bool`, *optional*, defaults to `False`): - Whether to ignore padding tokens when routing. router_dtype (`str`, *optional*, default to `float32`): The `dtype` used for the routers. It is preferable to keep the `dtype` to `float32` as specified in the "selective precision" discussion in https://arxiv.org/abs/2101.03961. - add_router_probs (`bool`, *optional*, defaults to `False`): - Whether to output router probabilities to compute router auxiliary loss. + router_ignore_padding_tokens (`bool`, *optional*, defaults to `False`): + Whether to ignore padding tokens when routing. relative_attention_num_buckets (`int`, *optional*, defaults to 32): The number of buckets to use for each attention layer. relative_attention_max_distance (`int`, *optional*, defaults to 128): @@ -94,6 +92,8 @@ class SwitchTransformersConfig(PretrainedConfig): feed_forward_proj (`string`, *optional*, defaults to `"relu"`): Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. SwitchTransformersv1.1 uses the `"gated-gelu"` feed forward projection. Original SwitchTransformers uses `"relu"`. + add_router_probs (`bool`, *optional*, defaults to `False`): + Whether to output router probabilities to compute router auxiliary loss. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). """ @@ -107,13 +107,13 @@ def __init__( d_model=768, d_kv=64, d_ff=2048, + expert_capacity=64, num_layers=12, num_sparse_encoder_layers=3, num_decoder_layers=12, num_sparse_decoder_layers=3, num_heads=12, num_experts=8, - expert_capacity=64, router_type="tokens_masked", router_bias=False, router_jitter_noise=0.01, @@ -122,9 +122,9 @@ def __init__( relative_attention_num_buckets=32, relative_attention_max_distance=128, dropout_rate=0.1, + layer_norm_epsilon=1e-6, router_z_loss_coef=0.001, router_aux_loss_coef=0.001, - layer_norm_epsilon=1e-6, initializer_factor=1.0, feed_forward_proj="relu", is_encoder_decoder=True, From 1231e2ba3a80a7c00df3e16db656ecbdeaa169c9 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 2 Nov 2022 16:53:59 +0100 Subject: [PATCH 079/102] fix readme japanese --- README_ja.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README_ja.md b/README_ja.md index eed7d204f8368..a397130662acb 100644 --- a/README_ja.md +++ b/README_ja.md @@ -410,6 +410,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. 1. **[Swin Transformer V2](https://huggingface.co/docs/transformers/model_doc/swinv2)** (from Microsoft) released with the paper [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. +1. **[SwitchTransformers](https://huggingface.co/docs/transformers/main/model_doc/switch_transformers)** (from Google) released with the paper [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) by William Fedus, Barret Zoph, Noam Shazeer. 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[Table Transformer](https://huggingface.co/docs/transformers/model_doc/table-transformer)** (from Microsoft Research) released with the paper [PubTables-1M: Towards Comprehensive Table Extraction From Unstructured Documents](https://arxiv.org/abs/2110.00061) by Brandon Smock, Rohith Pesala, Robin Abraham. From b4360d2334cc8f585bd75a3c2a866c006fa73a0c Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 3 Nov 2022 23:45:40 +0100 Subject: [PATCH 080/102] Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- docs/source/en/model_doc/switch_transformers.mdx | 4 ++-- .../configuration_switch_transformers.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/source/en/model_doc/switch_transformers.mdx b/docs/source/en/model_doc/switch_transformers.mdx index 5a7a45f3e6400..348c831a0e985 100644 --- a/docs/source/en/model_doc/switch_transformers.mdx +++ b/docs/source/en/model_doc/switch_transformers.mdx @@ -16,7 +16,7 @@ specific language governing permissions and limitations under the License. The SwitchTransformers model was proposed in [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961) by William Fedus, Barret Zoph, Noam Shazeer. -The Switch Transformer model uses a sparse T5 encoder-decoder architure, where the MLP are replace by a Mixture of Experts (MoE). A routing mecanism (top 1 in this case) associates each token to one of the expert, where each expert is a dense MLP. While switch tranformers have a lot more weights than their equivalent dense models, the sparsity allows better scaling and better finetuning performance at scale. +The Switch Transformer model uses a sparse T5 encoder-decoder architecure, where the MLP are replaced by a Mixture of Experts (MoE). A routing mechanism (top 1 in this case) associates each token to one of the expert, where each expert is a dense MLP. While switch transformers have a lot more weights than their equivalent dense models, the sparsity allows better scaling and better finetuning performance at scale. During a forward pass, only a fraction of the weights are used. The routing mecanism allows the model to select relevant weights on the fly which increases the model capacity without increasing the number of operations. @@ -26,7 +26,7 @@ The abstract from the paper is the following: Tips: -- SwitchTransformers uses the T5Tokenizer, which can be loaded directly from each model's repository. +- SwitchTransformers uses the [`T5Tokenizer`], which can be loaded directly from each model's repository. - The released weights are pretrained on English [Masked Language Modeling](https://moon-ci-docs.huggingface.co/docs/transformers/pr_19323/en/glossary#general-terms) task, and should be finetuned. This model was contributed by [Younes Belkada](https://huggingface.co/ybelkada) and [Arthur Zucker](https://huggingface.co/ArtZucker) . diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index f1c91ac8f80f8..625b51a2241ac 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -63,15 +63,15 @@ class SwitchTransformersConfig(PretrainedConfig): Number of attention heads for each attention layer in the Transformer encoder. num_experts (`int`, *optional*, defaults to 8): Number of experts for each SwitchTransformer layer. - router_type (`str`, *optional*, defaults to `tokens_masked`): - Router type - choice between `tokens_masked` and `tokens_scatter`, `experts_masked`. + router_type (`str`, *optional*, defaults to `"tokens_masked"`): + Router type - choose between `"tokens_masked", `"tokens_scatter"` and `"experts_masked"`. router_bias (`bool`, *optional*, defaults to `True`): Whether to add a bias to the router. router_jitter_noise (`float`, *optional*, defaults to 0.1): Amount of noise to add to the router. - router_dtype (`str`, *optional*, default to `float32`): - The `dtype` used for the routers. It is preferable to keep the `dtype` to `float32` as specified in the - "selective precision" discussion in https://arxiv.org/abs/2101.03961. + router_dtype (`str`, *optional*, default to `"float32"`): + The `dtype` used for the routers. It is preferable to keep the `dtype` to `"float32"` as specified in the + *selective precision* discussion in [the paper](https://arxiv.org/abs/2101.03961). router_ignore_padding_tokens (`bool`, *optional*, defaults to `False`): Whether to ignore padding tokens when routing. relative_attention_num_buckets (`int`, *optional*, defaults to 32): From 0240906ec6148c9996d535d127626819e8e2e2ca Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 3 Nov 2022 23:50:09 +0100 Subject: [PATCH 081/102] remove parallelizable tests + add one liners --- .../test_modeling_switch_transformers.py | 30 +++---------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 21d4e7dc8c3c8..01e8886b2ea23 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -553,9 +553,6 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, unitt (SwitchTransformersModel, SwitchTransformersForConditionalGeneration) if is_torch_available() else () ) all_generative_model_classes = (SwitchTransformersForConditionalGeneration,) if is_torch_available() else () - all_parallelizable_model_classes = ( - (SwitchTransformersModel, SwitchTransformersForConditionalGeneration) if is_torch_available() else () - ) fx_compatible = False test_pruning = False test_resize_embeddings = True @@ -877,18 +874,9 @@ def prepare_config_and_inputs(self): is_encoder_decoder=self.is_encoder_decoder, ) - return ( - config, - input_ids, - attention_mask, - ) + return config, input_ids, attention_mask - def create_and_check_model( - self, - config, - input_ids, - attention_mask, - ): + def create_and_check_model(self, config, input_ids, attention_mask): model = SwitchTransformersEncoderModel(config=config) model.to(torch_device) model.eval() @@ -901,23 +889,14 @@ def create_and_check_model( self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size)) - def create_and_check_model_fp16_forward( - self, - config, - input_ids, - attention_mask, - ): + def create_and_check_model_fp16_forward(self, config, input_ids, attention_mask): model = SwitchTransformersEncoderModel(config=config).to(torch_device).half().eval() output = model(input_ids, attention_mask=attention_mask)["last_hidden_state"] self.parent.assertFalse(torch.isnan(output).any().item()) def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - attention_mask, - ) = config_and_inputs + config, input_ids, attention_mask = config_and_inputs inputs_dict = { "input_ids": input_ids, @@ -931,7 +910,6 @@ class SwitchTransformersEncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase test_pruning = False test_resize_embeddings = False test_model_parallel = True - all_parallelizable_model_classes = (SwitchTransformersEncoderModel,) if is_torch_available() else () def setUp(self): self.model_tester = SwitchTransformersEncoderOnlyModelTester(self) From 09788717502af1cbd8dea7562c16e63d82db6644 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 3 Nov 2022 23:52:57 +0100 Subject: [PATCH 082/102] remove ONNX config --- .../configuration_switch_transformers.py | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index 625b51a2241ac..95560cceeeb8a 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -13,10 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Switch Transformers model configuration""" -from typing import Mapping - from ...configuration_utils import PretrainedConfig -from ...onnx import OnnxSeq2SeqConfigWithPast from ...utils import logging @@ -205,28 +202,3 @@ def __init__( is_encoder_decoder=is_encoder_decoder, **kwargs, ) - - -class SwitchTransformersOnnxConfig(OnnxSeq2SeqConfigWithPast): - @property - def inputs(self) -> Mapping[str, Mapping[int, str]]: - common_inputs = { - "input_ids": {0: "batch", 1: "encoder_sequence"}, - "attention_mask": {0: "batch", 1: "encoder_sequence"}, - } - if self.use_past: - common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence" - common_inputs["decoder_input_ids"] = {0: "batch"} - common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} - else: - common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} - common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} - - if self.use_past: - self.fill_with_past_key_values_(common_inputs, direction="inputs") - - return common_inputs - - @property - def default_onnx_opset(self) -> int: - return 13 From ccc28a9f1652aaedf9590842b1d624cc697acd58 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 3 Nov 2022 23:57:02 +0100 Subject: [PATCH 083/102] fix nits - add `T5Tokenizer` in auto mapping - remove `Switch Transformers` from ONNX supported models --- docs/source/en/serialization.mdx | 1 - src/transformers/models/auto/tokenization_auto.py | 7 +++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index 7a6650e2e79d4..1cbc1237f286b 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -96,7 +96,6 @@ Ready-made configurations include the following architectures: - SegFormer - SqueezeBERT - Swin Transformer -- SwitchTransformers - T5 - Table Transformer - Vision Encoder decoder diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 46e57ac58bd4c..76886bd493ae8 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -249,6 +249,13 @@ "squeezebert", ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None), ), + ( + "switch_transformers", + ( + "T5Tokenizer" if is_sentencepiece_available() else None, + "T5TokenizerFast" if is_tokenizers_available() else None, + ), + ), ( "t5", ( From ef1fa19beb6687ffb913bf560b8eb7b7508936ab Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 3 Nov 2022 23:59:34 +0100 Subject: [PATCH 084/102] remove `_get_router` --- .../modeling_switch_transformers.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 83cf128928520..6ab5c1c11297a 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -310,28 +310,13 @@ class SwitchTransformersSparseMLP(nn.Module): def __init__(self, config: SwitchTransformersConfig, expert_class: nn.Module = SwitchTransformersDenseActDense): super().__init__() # Step 1: Get the correct router according to its class - self.router = self._get_router(config) + self.router = SwitchTransformersTop1Router(config) # Step 2: Get the experts self.experts = nn.ModuleDict() for idx in range(config.num_experts): self.experts[f"expert_{idx}"] = expert_class(config) - def _get_router(self, config): - r""" - For now two types of Router are supported: - - Masked Routers - - Scatter Routers - - """ - if config.router_type.lower() == "tokens_masked": - return SwitchTransformersTop1Router(config) - else: - raise NotImplementedError( - f"{config.router_type.lower()} not implemented ! Please chose a router in " - "`{'tokens_masked','experts_masked'}`" - ) - def forward(self, hidden_states): r""" Hold on, this will be slightly tricky to understand In the correct order, a MoE layer does the following: From e2dc2b6a417335fdf14c9b4ed5322a415ba4aa74 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 4 Nov 2022 00:02:02 +0100 Subject: [PATCH 085/102] remove asserts --- .../modeling_switch_transformers.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 6ab5c1c11297a..ca29fdcbf59b9 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -898,10 +898,11 @@ def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id pad_token_id = self.config.pad_token_id - assert decoder_start_token_id is not None, ( - "self.model.config.decoder_start_token_id has to be defined. In SwitchTransformers it is usually set to" - " the pad_token_id. See SwitchTransformers docs for more information" - ) + if decoder_start_token_id is None: + raise ValueError( + "self.model.config.decoder_start_token_id has to be defined. In SwitchTransformers it is usually set" + " to the pad_token_id. See SwitchTransformers docs for more information" + ) # shift inputs to the right if is_torch_fx_proxy(input_ids): @@ -913,7 +914,8 @@ def _shift_right(self, input_ids): shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() shifted_input_ids[..., 0] = decoder_start_token_id - assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") # replace possible -100 values in labels by `pad_token_id` shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) From a2d786b4b3f24409bb596ee1868d2ec9cb3698f6 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 4 Nov 2022 00:03:49 +0100 Subject: [PATCH 086/102] add check in test for `router_dtype` --- .../switch_transformers/configuration_switch_transformers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index 95560cceeeb8a..20392cea4de54 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -162,6 +162,8 @@ def __init__( self.expert_capacity = expert_capacity self.router_bias = router_bias self.router_jitter_noise = router_jitter_noise + if router_dtype not in ["float32", "float16", "bfloat16"]: + raise ValueError(f"`router_dtype` must be one of 'float32', 'float16' or 'bfloat16', got {router_dtype}") self.router_dtype = router_dtype if router_dtype not in ["float16", "float32", "bfloat16"]: From ab67f48078be72c78200604c2023f4590c32e3e8 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 4 Nov 2022 00:22:44 +0100 Subject: [PATCH 087/102] add `SwitchTransformersConfig` in `run_pipeline_test` --- tests/pipelines/test_pipelines_summarization.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines_summarization.py b/tests/pipelines/test_pipelines_summarization.py index 50e8315a5f1e0..6a7cda8bec8f5 100644 --- a/tests/pipelines/test_pipelines_summarization.py +++ b/tests/pipelines/test_pipelines_summarization.py @@ -20,6 +20,7 @@ LEDConfig, LongT5Config, SummarizationPipeline, + SwitchTransformersConfig, T5Config, pipeline, ) @@ -54,7 +55,7 @@ def run_pipeline_test(self, summarizer, _): ) self.assertEqual(outputs, [{"summary_text": ANY(str)}]) - if not isinstance(model.config, (T5Config, LongT5Config, LEDConfig)): + if not isinstance(model.config, (SwitchTransformersConfig, T5Config, LongT5Config, LEDConfig)): # LED, T5, LongT5 can handle it. # Too long. with self.assertRaises(Exception): From 7c3e5aabf6cd72b84e24af34d377fc1310cdca8f Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 4 Nov 2022 15:35:44 +0100 Subject: [PATCH 088/102] Update tests/pipelines/test_pipelines_summarization.py --- tests/pipelines/test_pipelines_summarization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines_summarization.py b/tests/pipelines/test_pipelines_summarization.py index 6a7cda8bec8f5..c4c646cee96bd 100644 --- a/tests/pipelines/test_pipelines_summarization.py +++ b/tests/pipelines/test_pipelines_summarization.py @@ -56,7 +56,7 @@ def run_pipeline_test(self, summarizer, _): self.assertEqual(outputs, [{"summary_text": ANY(str)}]) if not isinstance(model.config, (SwitchTransformersConfig, T5Config, LongT5Config, LEDConfig)): - # LED, T5, LongT5 can handle it. + # Switch Transformers, LED, T5, LongT5 can handle it. # Too long. with self.assertRaises(Exception): outputs = summarizer("This " * 1000) From 1326126caa66a200b7dc94051f4b976a90fe460c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 4 Nov 2022 16:48:48 +0000 Subject: [PATCH 089/102] add huge model conversion script --- .../switch_transformers/convert_big_switch.py | 168 ++++++++++++++++++ ...ers_original_flax_checkpoint_to_pytorch.py | 14 +- 2 files changed, 176 insertions(+), 6 deletions(-) create mode 100644 src/transformers/models/switch_transformers/convert_big_switch.py diff --git a/src/transformers/models/switch_transformers/convert_big_switch.py b/src/transformers/models/switch_transformers/convert_big_switch.py new file mode 100644 index 0000000000000..16d114f9c3fc2 --- /dev/null +++ b/src/transformers/models/switch_transformers/convert_big_switch.py @@ -0,0 +1,168 @@ +from typing import Dict, Union +from sqlalchemy import false +import torch +from transformers.utils.hub import convert_file_size_to_int +from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME +from transformers.modeling_utils import dtype_byte_size +import os +from transformers.models.switch_transformers.convert_switch_transformers_original_flax_checkpoint_to_pytorch import rename_keys +from flax.traverse_util import flatten_dict, unflatten_dict +from tensorflow.io import gfile +import tensorstore as ts +from flax import serialization +import json +import argparse + + +def rename_base_flax_keys(flax_key_tuple, flax_tensor): + """ + Post renaming of basic JAX keys to pytorch. + """ + if flax_key_tuple[-1] == "kernel" and flax_tensor.ndim == 3: + # expert layer + flax_key_tuple = flax_key_tuple[:-1] + ("weight",) + flax_tensor = torch.permute(flax_tensor, ( 0, 2, 1)) + elif flax_key_tuple[-1] == "kernel" and ".".join(flax_key_tuple): + # linear layer + flax_key_tuple = flax_key_tuple[:-1] + ("weight",) + flax_tensor = flax_tensor.T + elif flax_key_tuple[-1] in ["scale", "embedding"]: + flax_key_tuple = flax_key_tuple[:-1] + ("weight",) + + return flax_key_tuple, flax_tensor + + +def get_key_and_tensorstore_dict(layer, checkpoint_info, switch_checkpoint_path): + if "metadata" in layer : + split_layer = layer.split("metadata") + curr_real_layer_name = "".join(split_layer[0])[:-1] + split_layer = [tuple(("metadata"+ split_layer[1]).split("/"))] + elif "kvstore" in layer : + split_layer = layer.split("kvstore") + curr_real_layer_name = "".join(split_layer[0])[:-1] + split_layer = [tuple(("kvstore"+ split_layer[1]).split("/"))] + + else: + split_layer = layer.split("/") + curr_real_layer_name = "/".join(split_layer[:-1]) + split_layer[-1] = (split_layer[-1],) + + if "kvstore/path" in layer: + content = f"{switch_checkpoint_path}/{checkpoint_info[layer]}" + elif "kvstore/driver" in layer: + content = "file" + else : + content = checkpoint_info[layer] + + return curr_real_layer_name, split_layer, content + +def rename_and_save_block(current_block, save_path): + current_block = rename_keys(current_block) + new_current_block = {} + for k,v in current_block.items(): + new_current_block[k.replace("/",".")] = v + current_block = new_current_block + torch.save(current_block, save_path) + + +def shard_on_the_fly(switch_checkpoint_path, dump_path, max_shard_size, dtype, weights_name: str = WEIGHTS_NAME): + max_shard_size = convert_file_size_to_int(max_shard_size) + sharded_state_dicts = [] + current_block = {} + current_block_size = 0 + total_size = 0 + with gfile.GFile(switch_checkpoint_path+"/checkpoint",'rb') as fp: + checkpoint_info = serialization.msgpack_restore(fp.read())["optimizer"]["target"] + checkpoint_info = flatten_dict(checkpoint_info, sep="/") + + all_layers = {} + for layer in checkpoint_info.keys(): + curr_real_layer_name, split_layer, content = get_key_and_tensorstore_dict(layer, checkpoint_info, switch_checkpoint_path) + if curr_real_layer_name in all_layers : + all_layers[curr_real_layer_name][split_layer[-1]] = content + else : + all_layers[curr_real_layer_name] = {split_layer[-1]: content} + + for key in all_layers.keys(): + # open tensorstore file + raw_weights = ts.open(unflatten_dict(all_layers[key])).result().read().result() + raw_weights = torch.tensor(raw_weights) + weight_size = raw_weights.numel() * dtype_byte_size(raw_weights.dtype) + + # use the renaming pattern from the small conversion scripts + key, raw_weights = rename_base_flax_keys(tuple(key.split("/")), raw_weights) + key = "/".join(key) + + + # If this weight is going to tip up over the maximal size, we split. + if current_block_size + weight_size > max_shard_size: + save_path = os.path.join(dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts)+1:05d}-of-???.bin")) + rename_and_save_block(current_block, save_path) + sharded_state_dicts.append(current_block.keys()) + del current_block + current_block = {} + current_block_size = 0 + + current_block[key] = raw_weights.to(getattr(torch,dtype)) + current_block_size += weight_size + total_size += weight_size + + # Add the last block + save_path = os.path.join(dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts)+1:05d}-of-???.bin")) + rename_and_save_block(current_block, save_path) + sharded_state_dicts.append(current_block.keys()) + + # If we only have one shard, we return it + if len(sharded_state_dicts) == 1: + return {weights_name: sharded_state_dicts[0]}, None + + # Otherwise, let's build the index + weight_map = {} + shards = {} + for idx, shard in enumerate(sharded_state_dicts): + shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin") # len(sharded_state_dicts):05d} + temp_filename = os.path.join(dump_path,weights_name.replace(".bin", f"-{idx+1:05d}-of-???.bin")) + os.rename(temp_filename, os.path.join(dump_path,shard_file)) + shards[shard_file] = shard + for key in shard: + weight_map[key] = shard_file + + # Add the metadata + metadata = {"total_size": total_size} + index = {"metadata": metadata, "weight_map": weight_map} + + with open(os.path.join(dump_path,WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + + return metadata, index + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument("--switch_t5x_checkpoint_path",default="/home/younes_huggingface_co/convert_switch/switch-base-8/checkpoint_500100",type=str,required=False,help=("Path to a directory containing a folder per layer. Follows the original Google format."),) + parser.add_argument("--max_shard_size", default="10GB", required=False, help="Max shard size") + parser.add_argument("--dtype", default="bfloat16", type=str, required=False, help="dtype of the saved model") + parser.add_argument("--pytorch_dump_folder_path", default="/home/arthur_huggingface_co/transformers/switch_converted", type=str, required=False, help="Path to the output pytorch model.") + args = parser.parse_args() + shard_on_the_fly( + args.switch_t5x_checkpoint_path, + args.pytorch_dump_folder_path, + args.max_shard_size, + args.dtype, + ) + + +def sanity_check(): + from transformers import SwitchTransformersForConditionalGeneration, SwitchTransformersConfig, T5Tokenizer + config = SwitchTransformersConfig.from_pretrained("google/switch-base-8") + config.save_pretrained("/home/arthur_huggingface_co/transformers/switch_converted") + model = SwitchTransformersForConditionalGeneration.from_pretrained("/home/arthur_huggingface_co/transformers/switch_converted", device_map = "auto") + + tokenizer = T5Tokenizer.from_pretrained("t5-small") + text = "A walks into a bar a orders a with pinch of ." + + + input_ids = tokenizer(text, return_tensors="pt").input_ids + out = model.generate(input_ids, decoder_start_token_id=0) + print(tokenizer.decode(out[0])) \ No newline at end of file diff --git a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py index f0f419752706f..45cd63e474333 100644 --- a/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py +++ b/src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py @@ -80,12 +80,14 @@ def rename_keys(s_dict): print(f"{key} -> {new_key}") s_dict[new_key] = s_dict.pop(key) - s_dict["encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"] = s_dict[ - "encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" - ].T - s_dict["decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"] = s_dict[ - "decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" - ].T + if "encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" in s_dict: + s_dict["encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"] = s_dict[ + "encoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" + ].T + if "decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" in s_dict: + s_dict["decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight"] = s_dict[ + "decoder/block/0/layer/0/SelfAttention/relative_attention_bias/weight" + ].T # 3. Take extra care of the EXPERTS layer for key in list(s_dict.keys()): From c7dec49630e30239667975211ee680f7f1b6a768 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 4 Nov 2022 17:09:30 +0000 Subject: [PATCH 090/102] fix slow tests - add better casting for `Linear8bitLt` - remove `torchscript` tests --- .../modeling_switch_transformers.py | 16 ++++++++++++++-- .../test_modeling_switch_transformers.py | 18 ++++++++++-------- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index ca29fdcbf59b9..f548f36e012a2 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -184,13 +184,21 @@ def _compute_router_probabilities(self, hidden_states: torch.Tensor) -> Tuple[to hidden_states *= uniform_distrib # Shape: [num_groups, tokens_per_group, num_experts] - self.classifier = self.classifier.to(self.dtype) + self._cast_classifier() router_logits = self.classifier(hidden_states) # Apply Softmax and cast back to the original `dtype` router_probabilities = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to(self.input_dtype) return router_probabilities, router_logits + def _cast_classifier(self): + r""" + `bitsandbytes` `Linear8bitLt` layers does not support manual casting Therefore we need to check if they are an + instance of the `Linear8bitLt` class by checking special attributes. + """ + if not (hasattr(self.classifier, "SCB") or hasattr(self.classifier, "CB")): + self.classifier = self.classifier.to(self.dtype) + def forward(self, hidden_states: torch.Tensor) -> Tuple: r""" Generic forward function for every Router class. Each Router expects to have the same input hidden states @@ -926,7 +934,11 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel): def __init__(self, config, embed_tokens=None): super().__init__(config) - self.embed_tokens = embed_tokens + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + self.is_decoder = config.is_decoder sparse_step = config.decoder_sparse_step if self.is_decoder else config.encoder_sparse_step diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 01e8886b2ea23..67411663c2cba 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -103,7 +103,7 @@ def __init__( self.router_jitter_noise = router_jitter_noise def get_large_model_config(self): - return SwitchTransformersConfig.from_pretrained("HFLAY/switch_base_8") + return SwitchTransformersConfig.from_pretrained("google/switch-base-8") def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) @@ -411,7 +411,7 @@ def create_and_check_generate_with_past_key_values( This test does not pass for small models due to precision errors. It is therefore only run for slightly larger models. """ model = ( - SwitchTransformersForConditionalGeneration.from_pretrained("HFLAY/switch_base_8").to(torch_device).eval() + SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8").to(torch_device).eval() ) torch.manual_seed(0) output_without_past_cache = model.generate( @@ -558,6 +558,7 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, unitt test_resize_embeddings = True test_model_parallel = False is_encoder_decoder = True + test_torchscript = False # The small SWITCH_TRANSFORMERS model needs higher percentages for CPU/MP tests model_split_percents = [0.8, 0.9] @@ -619,7 +620,7 @@ def test_beam_sample_generate_dict_output(self): config.eos_token_id = None config.forced_eos_token_id = None - model = model_class.from_pretrained("HFLAY/switch_base_8").to(torch_device).eval() + model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval() logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) num_return_sequences = 2 @@ -671,7 +672,7 @@ def test_beam_sample_generate(self): logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) - model = model_class.from_pretrained("HFLAY/switch_base_8").to(torch_device).eval() + model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval() # check `generate()` and `beam_search()` are equal # change `num_return_sequences = 2` but not for `beam_scorer` @@ -909,7 +910,8 @@ class SwitchTransformersEncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase all_model_classes = (SwitchTransformersEncoderModel,) if is_torch_available() else () test_pruning = False test_resize_embeddings = False - test_model_parallel = True + test_model_parallel = False + test_torchscript = False def setUp(self): self.model_tester = SwitchTransformersEncoderOnlyModelTester(self) @@ -1108,7 +1110,7 @@ def test_small_logits(self): and `transformers` implementation of Switch-C transformers. We only check the logits of the first batch. """ - model = SwitchTransformersModel.from_pretrained("HFLAY/switch_base_8", torch_dtype=torch.bfloat16).eval() + model = SwitchTransformersModel.from_pretrained("google/switch-base-8", torch_dtype=torch.bfloat16).eval() input_ids = torch.ones((32, 64), dtype=torch.long) decoder_input_ids = torch.ones((32, 64), dtype=torch.long) @@ -1134,7 +1136,7 @@ def test_small_generate(self): # Generate test using the smalled switch-C model. model = SwitchTransformersForConditionalGeneration.from_pretrained( - "HFLAY/switch_base_8", torch_dtype=torch.bfloat16 + "google/switch-base-8", torch_dtype=torch.bfloat16 ).eval() tokenizer = AutoTokenizer.from_pretrained("t5-small") model = model.to(torch_device) @@ -1159,7 +1161,7 @@ def test_small_generate(self): def test_small_batch_generate(self): BATCH_SIZE = 4 model = SwitchTransformersForConditionalGeneration.from_pretrained( - "HFLAY/switch_base_8", torch_dtype=torch.bfloat16 + "google/switch-base-8", torch_dtype=torch.bfloat16 ).eval() tokenizer = AutoTokenizer.from_pretrained("t5-small") From edc772391c3cd6c8be9d507fe3dffbdb67af408e Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 4 Nov 2022 17:14:24 +0000 Subject: [PATCH 091/102] add make dir --- .../models/switch_transformers/convert_big_switch.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/switch_transformers/convert_big_switch.py b/src/transformers/models/switch_transformers/convert_big_switch.py index 16d114f9c3fc2..118bf3e98d924 100644 --- a/src/transformers/models/switch_transformers/convert_big_switch.py +++ b/src/transformers/models/switch_transformers/convert_big_switch.py @@ -71,6 +71,8 @@ def shard_on_the_fly(switch_checkpoint_path, dump_path, max_shard_size, dtype, current_block = {} current_block_size = 0 total_size = 0 + + os.makedirs(dump_path,exist_ok=True) with gfile.GFile(switch_checkpoint_path+"/checkpoint",'rb') as fp: checkpoint_info = serialization.msgpack_restore(fp.read())["optimizer"]["target"] checkpoint_info = flatten_dict(checkpoint_info, sep="/") @@ -140,10 +142,10 @@ def shard_on_the_fly(switch_checkpoint_path, dump_path, max_shard_size, dtype, if __name__ == "__main__": parser = argparse.ArgumentParser() # Required parameters - parser.add_argument("--switch_t5x_checkpoint_path",default="/home/younes_huggingface_co/convert_switch/switch-base-8/checkpoint_500100",type=str,required=False,help=("Path to a directory containing a folder per layer. Follows the original Google format."),) + parser.add_argument("--switch_t5x_checkpoint_path",default="/mnt/disks/disk_switch/original_checkpoints/switch-xxl-128/checkpoint_634600",type=str,required=False,help=("Path to a directory containing a folder per layer. Follows the original Google format."),) parser.add_argument("--max_shard_size", default="10GB", required=False, help="Max shard size") parser.add_argument("--dtype", default="bfloat16", type=str, required=False, help="dtype of the saved model") - parser.add_argument("--pytorch_dump_folder_path", default="/home/arthur_huggingface_co/transformers/switch_converted", type=str, required=False, help="Path to the output pytorch model.") + parser.add_argument("--pytorch_dump_folder_path", default="/mnt/disks/disk_switch/original_checkpoints/switch-xxl-128-converted", type=str, required=False, help="Path to the output pytorch model.") args = parser.parse_args() shard_on_the_fly( args.switch_t5x_checkpoint_path, From d3a7795412368982f76d4f63d04ed0e9e8bfce96 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 4 Nov 2022 17:15:26 +0000 Subject: [PATCH 092/102] style on new script --- .../switch_transformers/convert_big_switch.py | 129 +++++++++++------- 1 file changed, 76 insertions(+), 53 deletions(-) diff --git a/src/transformers/models/switch_transformers/convert_big_switch.py b/src/transformers/models/switch_transformers/convert_big_switch.py index 118bf3e98d924..aa44f9a2190d9 100644 --- a/src/transformers/models/switch_transformers/convert_big_switch.py +++ b/src/transformers/models/switch_transformers/convert_big_switch.py @@ -1,111 +1,117 @@ -from typing import Dict, Union -from sqlalchemy import false -import torch -from transformers.utils.hub import convert_file_size_to_int -from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME -from transformers.modeling_utils import dtype_byte_size +import argparse +import json import os -from transformers.models.switch_transformers.convert_switch_transformers_original_flax_checkpoint_to_pytorch import rename_keys -from flax.traverse_util import flatten_dict, unflatten_dict + +import torch from tensorflow.io import gfile + import tensorstore as ts from flax import serialization -import json -import argparse +from flax.traverse_util import flatten_dict, unflatten_dict +from transformers.modeling_utils import dtype_byte_size +from transformers.models.switch_transformers.convert_switch_transformers_original_flax_checkpoint_to_pytorch import ( + rename_keys, +) +from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME +from transformers.utils.hub import convert_file_size_to_int def rename_base_flax_keys(flax_key_tuple, flax_tensor): """ - Post renaming of basic JAX keys to pytorch. + Post renaming of basic JAX keys to pytorch. """ if flax_key_tuple[-1] == "kernel" and flax_tensor.ndim == 3: # expert layer flax_key_tuple = flax_key_tuple[:-1] + ("weight",) - flax_tensor = torch.permute(flax_tensor, ( 0, 2, 1)) + flax_tensor = torch.permute(flax_tensor, (0, 2, 1)) elif flax_key_tuple[-1] == "kernel" and ".".join(flax_key_tuple): # linear layer flax_key_tuple = flax_key_tuple[:-1] + ("weight",) - flax_tensor = flax_tensor.T + flax_tensor = flax_tensor.T elif flax_key_tuple[-1] in ["scale", "embedding"]: flax_key_tuple = flax_key_tuple[:-1] + ("weight",) - + return flax_key_tuple, flax_tensor def get_key_and_tensorstore_dict(layer, checkpoint_info, switch_checkpoint_path): - if "metadata" in layer : + if "metadata" in layer: split_layer = layer.split("metadata") curr_real_layer_name = "".join(split_layer[0])[:-1] - split_layer = [tuple(("metadata"+ split_layer[1]).split("/"))] - elif "kvstore" in layer : + split_layer = [tuple(("metadata" + split_layer[1]).split("/"))] + elif "kvstore" in layer: split_layer = layer.split("kvstore") curr_real_layer_name = "".join(split_layer[0])[:-1] - split_layer = [tuple(("kvstore"+ split_layer[1]).split("/"))] - + split_layer = [tuple(("kvstore" + split_layer[1]).split("/"))] + else: split_layer = layer.split("/") curr_real_layer_name = "/".join(split_layer[:-1]) split_layer[-1] = (split_layer[-1],) - - if "kvstore/path" in layer: + + if "kvstore/path" in layer: content = f"{switch_checkpoint_path}/{checkpoint_info[layer]}" - elif "kvstore/driver" in layer: + elif "kvstore/driver" in layer: content = "file" - else : + else: content = checkpoint_info[layer] - + return curr_real_layer_name, split_layer, content + def rename_and_save_block(current_block, save_path): current_block = rename_keys(current_block) - new_current_block = {} - for k,v in current_block.items(): - new_current_block[k.replace("/",".")] = v + new_current_block = {} + for k, v in current_block.items(): + new_current_block[k.replace("/", ".")] = v current_block = new_current_block torch.save(current_block, save_path) - -def shard_on_the_fly(switch_checkpoint_path, dump_path, max_shard_size, dtype, weights_name: str = WEIGHTS_NAME): + +def shard_on_the_fly(switch_checkpoint_path, dump_path, max_shard_size, dtype, weights_name: str = WEIGHTS_NAME): max_shard_size = convert_file_size_to_int(max_shard_size) sharded_state_dicts = [] current_block = {} current_block_size = 0 total_size = 0 - - os.makedirs(dump_path,exist_ok=True) - with gfile.GFile(switch_checkpoint_path+"/checkpoint",'rb') as fp: + + os.makedirs(dump_path, exist_ok=True) + with gfile.GFile(switch_checkpoint_path + "/checkpoint", "rb") as fp: checkpoint_info = serialization.msgpack_restore(fp.read())["optimizer"]["target"] checkpoint_info = flatten_dict(checkpoint_info, sep="/") all_layers = {} for layer in checkpoint_info.keys(): - curr_real_layer_name, split_layer, content = get_key_and_tensorstore_dict(layer, checkpoint_info, switch_checkpoint_path) - if curr_real_layer_name in all_layers : + curr_real_layer_name, split_layer, content = get_key_and_tensorstore_dict( + layer, checkpoint_info, switch_checkpoint_path + ) + if curr_real_layer_name in all_layers: all_layers[curr_real_layer_name][split_layer[-1]] = content - else : + else: all_layers[curr_real_layer_name] = {split_layer[-1]: content} - + for key in all_layers.keys(): # open tensorstore file - raw_weights = ts.open(unflatten_dict(all_layers[key])).result().read().result() - raw_weights = torch.tensor(raw_weights) + raw_weights = ts.open(unflatten_dict(all_layers[key])).result().read().result() + raw_weights = torch.tensor(raw_weights) weight_size = raw_weights.numel() * dtype_byte_size(raw_weights.dtype) # use the renaming pattern from the small conversion scripts key, raw_weights = rename_base_flax_keys(tuple(key.split("/")), raw_weights) key = "/".join(key) - # If this weight is going to tip up over the maximal size, we split. if current_block_size + weight_size > max_shard_size: - save_path = os.path.join(dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts)+1:05d}-of-???.bin")) + save_path = os.path.join( + dump_path, weights_name.replace(".bin", f"-{len(sharded_state_dicts)+1:05d}-of-???.bin") + ) rename_and_save_block(current_block, save_path) sharded_state_dicts.append(current_block.keys()) del current_block current_block = {} current_block_size = 0 - current_block[key] = raw_weights.to(getattr(torch,dtype)) + current_block[key] = raw_weights.to(getattr(torch, dtype)) current_block_size += weight_size total_size += weight_size @@ -122,9 +128,11 @@ def shard_on_the_fly(switch_checkpoint_path, dump_path, max_shard_size, dtype, weight_map = {} shards = {} for idx, shard in enumerate(sharded_state_dicts): - shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin") # len(sharded_state_dicts):05d} - temp_filename = os.path.join(dump_path,weights_name.replace(".bin", f"-{idx+1:05d}-of-???.bin")) - os.rename(temp_filename, os.path.join(dump_path,shard_file)) + shard_file = weights_name.replace( + ".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin" + ) # len(sharded_state_dicts):05d} + temp_filename = os.path.join(dump_path, weights_name.replace(".bin", f"-{idx+1:05d}-of-???.bin")) + os.rename(temp_filename, os.path.join(dump_path, shard_file)) shards[shard_file] = shard for key in shard: weight_map[key] = shard_file @@ -132,20 +140,33 @@ def shard_on_the_fly(switch_checkpoint_path, dump_path, max_shard_size, dtype, # Add the metadata metadata = {"total_size": total_size} index = {"metadata": metadata, "weight_map": weight_map} - - with open(os.path.join(dump_path,WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f: + + with open(os.path.join(dump_path, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f: content = json.dumps(index, indent=2, sort_keys=True) + "\n" f.write(content) - + return metadata, index + if __name__ == "__main__": parser = argparse.ArgumentParser() # Required parameters - parser.add_argument("--switch_t5x_checkpoint_path",default="/mnt/disks/disk_switch/original_checkpoints/switch-xxl-128/checkpoint_634600",type=str,required=False,help=("Path to a directory containing a folder per layer. Follows the original Google format."),) + parser.add_argument( + "--switch_t5x_checkpoint_path", + default="/mnt/disks/disk_switch/original_checkpoints/switch-xxl-128/checkpoint_634600", + type=str, + required=False, + help="Path to a directory containing a folder per layer. Follows the original Google format.", + ) parser.add_argument("--max_shard_size", default="10GB", required=False, help="Max shard size") parser.add_argument("--dtype", default="bfloat16", type=str, required=False, help="dtype of the saved model") - parser.add_argument("--pytorch_dump_folder_path", default="/mnt/disks/disk_switch/original_checkpoints/switch-xxl-128-converted", type=str, required=False, help="Path to the output pytorch model.") + parser.add_argument( + "--pytorch_dump_folder_path", + default="/mnt/disks/disk_switch/original_checkpoints/switch-xxl-128-converted", + type=str, + required=False, + help="Path to the output pytorch model.", + ) args = parser.parse_args() shard_on_the_fly( args.switch_t5x_checkpoint_path, @@ -156,15 +177,17 @@ def shard_on_the_fly(switch_checkpoint_path, dump_path, max_shard_size, dtype, def sanity_check(): - from transformers import SwitchTransformersForConditionalGeneration, SwitchTransformersConfig, T5Tokenizer + from transformers import SwitchTransformersConfig, SwitchTransformersForConditionalGeneration, T5Tokenizer + config = SwitchTransformersConfig.from_pretrained("google/switch-base-8") config.save_pretrained("/home/arthur_huggingface_co/transformers/switch_converted") - model = SwitchTransformersForConditionalGeneration.from_pretrained("/home/arthur_huggingface_co/transformers/switch_converted", device_map = "auto") + model = SwitchTransformersForConditionalGeneration.from_pretrained( + "/home/arthur_huggingface_co/transformers/switch_converted", device_map="auto" + ) tokenizer = T5Tokenizer.from_pretrained("t5-small") text = "A walks into a bar a orders a with pinch of ." - input_ids = tokenizer(text, return_tensors="pt").input_ids out = model.generate(input_ids, decoder_start_token_id=0) - print(tokenizer.decode(out[0])) \ No newline at end of file + print(tokenizer.decode(out[0])) From 437eef7acbce56c1d70774c22474775ceda075a8 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 4 Nov 2022 17:56:19 +0000 Subject: [PATCH 093/102] fix nits - doctest - remove `_keys_to_ignore_on_load_unexpected` --- .../modeling_switch_transformers.py | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index f548f36e012a2..4f083623d61bd 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -50,14 +50,14 @@ _CONFIG_FOR_DOC = "SwitchTransformersConfig" _TOKENIZER_FOR_DOC = "T5Tokenizer" -_CHECKPOINT_FOR_DOC = "HFLAY/switch_base_8" +_CHECKPOINT_FOR_DOC = "google/switch-base-8" #################################################### # This dict contains ids and associated url # for the pretrained weights provided with the models #################################################### SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "HFLAY/switch_base_8", + "google/switch-base-8", "google/switch-base-16", "google/switch-base-32", "google/switch-base-64", @@ -1342,9 +1342,6 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel): r"encoder.embed_tokens.weight", r"decoder.embed_tokens.weight", ] - _keys_to_ignore_on_load_unexpected = [ - r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", - ] def __init__(self, config: SwitchTransformersConfig): super().__init__(config) @@ -1418,8 +1415,8 @@ def forward( ```python >>> from transformers import T5Tokenizer, SwitchTransformersModel - >>> tokenizer = T5Tokenizer.from_pretrained("HFLAY/switch_base_8") - >>> model = SwitchTransformersModel.from_pretrained("HFLAY/switch_base_8") + >>> tokenizer = T5Tokenizer.from_pretrained("google/switch-base-8") + >>> model = SwitchTransformersModel.from_pretrained("google/switch-base-8") >>> input_ids = tokenizer( ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" @@ -1517,9 +1514,6 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod r"decoder.embed_tokens.weight", r"lm_head.weight", ] - _keys_to_ignore_on_load_unexpected = [ - r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", - ] def __init__(self, config: SwitchTransformersConfig): super().__init__(config) @@ -1605,8 +1599,8 @@ def forward( ```python >>> from transformers import T5Tokenizer, SwitchTransformersForConditionalGeneration - >>> tokenizer = T5Tokenizer.from_pretrained("HFLAY/switch_base_8") - >>> model = SwitchTransformersForConditionalGeneration.from_pretrained("HFLAY/switch_base_8") + >>> tokenizer = T5Tokenizer.from_pretrained("google/switch-base-8") + >>> model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8") >>> # training >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids @@ -1889,8 +1883,8 @@ def forward( ```python >>> from transformers import T5Tokenizer, SwitchTransformersEncoderModel - >>> tokenizer = T5Tokenizer.from_pretrained("HFLAY/switch_base_8") - >>> model = SwitchTransformersEncoderModel.from_pretrained("HFLAY/switch_base_8") + >>> tokenizer = T5Tokenizer.from_pretrained("google/switch-base-8") + >>> model = SwitchTransformersEncoderModel.from_pretrained("google/switch-base-8") >>> input_ids = tokenizer( ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" ... ).input_ids # Batch size 1 From f82f0cf7bb434b5c113bf70f7264875cc289b18e Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 7 Nov 2022 11:42:45 +0100 Subject: [PATCH 094/102] Update src/transformers/models/switch_transformers/configuration_switch_transformers.py --- .../switch_transformers/configuration_switch_transformers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index 20392cea4de54..9c31ca6fb0e8f 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -166,8 +166,6 @@ def __init__( raise ValueError(f"`router_dtype` must be one of 'float32', 'float16' or 'bfloat16', got {router_dtype}") self.router_dtype = router_dtype - if router_dtype not in ["float16", "float32", "bfloat16"]: - raise ValueError("""Please select a correct `router_dtype` from ["float16", "float32", "bfloat16"].""") self.router_ignore_padding_tokens = router_ignore_padding_tokens self.relative_attention_num_buckets = relative_attention_num_buckets From cbb4c7747e0f84fc93c668d71f8ac3e4bfe2f6bd Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 7 Nov 2022 18:25:43 +0100 Subject: [PATCH 095/102] add google as authors --- .../switch_transformers/configuration_switch_transformers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index 9c31ca6fb0e8f..a0e47ed39245c 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2020, The SwitchTransformers Authors and HuggingFace Inc. +# Copyright 2020, Google and HuggingFace Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -166,7 +166,6 @@ def __init__( raise ValueError(f"`router_dtype` must be one of 'float32', 'float16' or 'bfloat16', got {router_dtype}") self.router_dtype = router_dtype - self.router_ignore_padding_tokens = router_ignore_padding_tokens self.relative_attention_num_buckets = relative_attention_num_buckets self.relative_attention_max_distance = relative_attention_max_distance From 2f5806938d95a2ff819e240e18f0b50dd456edcb Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 7 Nov 2022 18:26:36 +0100 Subject: [PATCH 096/102] fix year --- .../switch_transformers/configuration_switch_transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/switch_transformers/configuration_switch_transformers.py b/src/transformers/models/switch_transformers/configuration_switch_transformers.py index a0e47ed39245c..0d84d7ee33ffa 100644 --- a/src/transformers/models/switch_transformers/configuration_switch_transformers.py +++ b/src/transformers/models/switch_transformers/configuration_switch_transformers.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2020, Google and HuggingFace Inc. +# Copyright 2022, Google and HuggingFace Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 2dd9cff9b8c524788a346bcc0f1259f025cf7312 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 7 Nov 2022 18:31:10 +0100 Subject: [PATCH 097/102] remove last `assert` statements --- .../switch_transformers/modeling_switch_transformers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 4f083623d61bd..748f33aea441f 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1005,7 +1005,8 @@ def forward( raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") if inputs_embeds is None: - assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" + if self.embed_tokens is None: + raise ValueError("You have to initialize the model with valid token embeddings") inputs_embeds = self.embed_tokens(input_ids) batch_size, seq_length = input_shape @@ -1014,7 +1015,8 @@ def forward( mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length if use_cache is True: - assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder" + if not self.is_decoder: + raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") if attention_mask is None: attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) From 16e7ff53971730ba39608c59881697f7c68e4afa Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 7 Nov 2022 18:41:49 +0100 Subject: [PATCH 098/102] standardize vertical spaces --- .../switch_transformers/modeling_switch_transformers.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 748f33aea441f..dbba70e5bfc7a 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1340,10 +1340,7 @@ def custom_forward(*inputs): SWITCH_TRANSFORMERS_START_DOCSTRING, ) class SwitchTransformersModel(SwitchTransformersPreTrainedModel): - _keys_to_ignore_on_load_missing = [ - r"encoder.embed_tokens.weight", - r"decoder.embed_tokens.weight", - ] + _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight", r"decoder.embed_tokens.weight"] def __init__(self, config: SwitchTransformersConfig): super().__init__(config) @@ -1827,9 +1824,7 @@ def _reorder_cache(self, past, beam_idx): SWITCH_TRANSFORMERS_START_DOCSTRING, ) class SwitchTransformersEncoderModel(SwitchTransformersPreTrainedModel): - authorized_missing_keys = [ - r"encoder.embed_tokens.weight", - ] + authorized_missing_keys = [r"encoder.embed_tokens.weight"] def __init__(self, config: SwitchTransformersConfig): super().__init__(config) From 1208b86bee9878c9a65c132476499ba4483749a8 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 15 Nov 2022 10:13:16 +0000 Subject: [PATCH 099/102] fix failing import --- .../switch_transformers/test_modeling_switch_transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 67411663c2cba..45be7a3bc1ea5 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -21,7 +21,7 @@ from transformers import SwitchTransformersConfig, is_torch_available from transformers.testing_utils import require_tokenizers, require_torch, slow, torch_device -from ...generation.test_generation_utils import GenerationTesterMixin +from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, ids_tensor From fc4921dad4e620e3b1e6390abc3d5ed61a10bf36 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 15 Nov 2022 10:19:25 +0000 Subject: [PATCH 100/102] fix another failing test --- .../switch_transformers/test_modeling_switch_transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 45be7a3bc1ea5..0b90fc371367b 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -36,7 +36,7 @@ SwitchTransformersModel, SwitchTransformersTop1Router, ) - from transformers.generation_utils import BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput + from transformers.generation import BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput from transformers.models.switch_transformers.modeling_switch_transformers import ( SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, load_balancing_loss_func, From 21010f70dfed5f95bbbab4bf162148b22c52273e Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 15 Nov 2022 10:43:57 +0000 Subject: [PATCH 101/102] =?UTF-8?q?Remove=20strange=20=C3=A0uthorized=5Fke?= =?UTF-8?q?ys`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../models/switch_transformers/modeling_switch_transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index dbba70e5bfc7a..65f53f17c7a0f 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1824,7 +1824,7 @@ def _reorder_cache(self, past, beam_idx): SWITCH_TRANSFORMERS_START_DOCSTRING, ) class SwitchTransformersEncoderModel(SwitchTransformersPreTrainedModel): - authorized_missing_keys = [r"encoder.embed_tokens.weight"] + _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"] def __init__(self, config: SwitchTransformersConfig): super().__init__(config) From 0188acd4fa4c2ec1f11f755720e8652b9552054e Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 15 Nov 2022 10:57:59 +0000 Subject: [PATCH 102/102] removing todo and padding that is never used --- .../modeling_switch_transformers.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 65f53f17c7a0f..455f5aef0e1ed 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -218,19 +218,6 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple: """ router_probs, router_logits = self._compute_router_probabilities(hidden_states) - # Flax code for reference TODO check what happens with padded inputs here - if self.ignore_padding_tokens: - # To identify non-padding tokens, we rely on the fact that padding tokens - # in the inputs have already been masked in the default T5 architecture. - # See - # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L315 - # and - # https://github.com/google/flaxformer/blob/9712a16/flaxformer/architectures/t5/t5_architecture.py#L603. - padding_mask = torch.Tensor((torch.sum(torch.abs(hidden_states), axis=-1) > 0)).to(hidden_states.dtype) - router_logits *= padding_mask.unsqueeze(-1) - else: - padding_mask = None - expert_index = torch.argmax(router_probs, dim=-1) expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts)