diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 00abb1f35bfe3..c5e12db375358 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -262,6 +262,8 @@ title: Pegasus - local: model_doc/phobert title: PhoBERT + - local: model_doc/plbart + title: PLBart - local: model_doc/poolformer title: PoolFormer - local: model_doc/prophetnet diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 5b7b8bb484400..53e4160b16ea9 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -215,6 +215,7 @@ Flax), PyTorch, and/or TensorFlow. | OpenAI GPT-2 | ✅ | ✅ | ✅ | ✅ | ✅ | | Pegasus | ✅ | ✅ | ✅ | ✅ | ✅ | | Perceiver | ✅ | ❌ | ✅ | ❌ | ❌ | +| PLBart | ✅ | ❌ | ✅ | ❌ | ❌ | | PoolFormer | ❌ | ❌ | ✅ | ❌ | ❌ | | ProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ | | QDQBert | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/model_doc/plbart.mdx b/docs/source/model_doc/plbart.mdx new file mode 100644 index 0000000000000..6e3e4a5b7773b --- /dev/null +++ b/docs/source/model_doc/plbart.mdx @@ -0,0 +1,112 @@ + + +# PLBart + +**DISCLAIMER:** If you see something strange, file a [Github Issue](https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title) and assign +[@gchhablani](https://www.github.com/gchhablani). + +## Overview of PLBart + +The PLBART model was proposed in [Unified Pre-training for Program Understanding and Generation](https://arxiv.org/abs/2103.06333) by Wasi Uddin Ahmad, Saikat Chakraborty, Baishakhi Ray, Kai-Wei Chang. +This is a BART-like model that can be used to perform code summarization, code generation, and code translation tasks. The pre-trained model `plbart-base` has been trained using a multilingual denoising task +on Java, Python and English. + +According to the abstract: + +*Code summarization and generation empower conversion between programming language (PL) and natural language (NL), +while code translation avails the migration of legacy code from one PL to another. This paper introduces PLBART, +a sequence-to-sequence model capable of performing a broad spectrum of program and language understanding and generation tasks. +PLBART is pre-trained on an extensive collection of Java and Python functions and associated NL text via denoising autoencoding. +Experiments on code summarization in the English language, code generation, and code translation in seven programming languages +show that PLBART outperforms or rivals state-of-the-art models. Moreover, experiments on discriminative tasks, e.g., program +repair, clone detection, and vulnerable code detection, demonstrate PLBART's effectiveness in program understanding. +Furthermore, analysis reveals that PLBART learns program syntax, style (e.g., identifier naming convention), logical flow +(e.g., if block inside an else block is equivalent to else if block) that are crucial to program semantics and thus excels +even with limited annotations.* + +This model was contributed by [gchhablani](https://huggingface.co/gchhablani). The authors' code can be found [here](https://github.com/wasiahmad/PLBART). + +### Training of PLBart + +PLBart is a multilingual encoder-decoder (sequence-to-sequence) model primarily intended for code-to-text, text-to-code, and code-to-code tasks. As the +model is multilingual, it expects sequences in a different format: a special language id token is added to both the +source and target text. The source text format is `X [eos, src_lang_code]`, where `X` is the source text. The +target text format is `[tgt_lang_code] X [eos]`. `bos` is never used.
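As an illustration of this format, here is a minimal sketch (not part of the original documentation; the `java` language code and the `uclanlp/plbart-base` checkpoint are assumed, following the examples below) that inspects the tokens produced for a source sequence:

```python
>>> from transformers import PLBartTokenizer

>>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-base", src_lang="java", tgt_lang="en_XX")
>>> inputs = tokenizer("public int add(int a, int b) { return a + b; }", return_tensors="pt")
>>> # the encoded source sequence should end with `eos` followed by the source language code (`java`)
>>> tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])[-2:]
```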
+ +However, for fine-tuning, the language token is sometimes omitted when only a single language is used. Please refer to [the paper](https://arxiv.org/abs/2103.06333) to learn more about this. + +In cases where the language code is needed, the regular [`~PLBartTokenizer.__call__`] will encode the source text format, and it should be wrapped +inside the context manager [`~PLBartTokenizer.as_target_tokenizer`] to encode the target text format. + +- Supervised training + +```python +>>> from transformers import PLBartForConditionalGeneration, PLBartTokenizer + +>>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-base", src_lang="python", tgt_lang="en_XX") +>>> example_python_phrase = "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])" +>>> expected_translation_english = "Returns the maximum value of a b c." +>>> inputs = tokenizer(example_python_phrase, return_tensors="pt") +>>> with tokenizer.as_target_tokenizer(): +... labels = tokenizer(expected_translation_english, return_tensors="pt") +>>> inputs["labels"] = labels["input_ids"] +>>> model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-base") +>>> # forward pass +>>> model(**inputs) +``` + +- Generation + + While generating the target text, set the `decoder_start_token_id` to the target language id. The following + example shows how to translate Python to English using the `uclanlp/plbart-python-en_XX` model. + +```python +>>> from transformers import PLBartForConditionalGeneration, PLBartTokenizer + +>>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-python-en_XX", src_lang="python", tgt_lang="en_XX") +>>> example_python_phrase = "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])" +>>> inputs = tokenizer(example_python_phrase, return_tensors="pt") +>>> model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-python-en_XX") +>>> translated_tokens = model.generate(**inputs, decoder_start_token_id=tokenizer.lang_code_to_id["en_XX"]) +>>> tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0] +"Returns the maximum value of a b c."
+``` + +## PLBartConfig + +[[autodoc]] PLBartConfig + +## PLBartTokenizer + +[[autodoc]] PLBartTokenizer + - as_target_tokenizer + - build_inputs_with_special_tokens + +## PLBartModel + +[[autodoc]] PLBartModel + - forward + +## PLBartForConditionalGeneration + +[[autodoc]] PLBartForConditionalGeneration + - forward + +## PLBartForSequenceClassification + +[[autodoc]] PLBartForSequenceClassification + - forward + +## PLBartForCausalLM + +[[autodoc]] PLBartForCausalLM + - forward \ No newline at end of file diff --git a/docs/source/serialization.mdx b/docs/source/serialization.mdx index 9a5a2dfe91180..aee21535aca4e 100644 --- a/docs/source/serialization.mdx +++ b/docs/source/serialization.mdx @@ -57,6 +57,7 @@ Ready-made configurations include the following architectures: - Marian - mBART - OpenAI GPT-2 +- PLBart - RoBERTa - T5 - XLM-RoBERTa diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index d0cc27ad9574c..a6249edbc6bbb 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -263,6 +263,7 @@ "models.pegasus": ["PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusConfig", "PegasusTokenizer"], "models.perceiver": ["PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PerceiverConfig", "PerceiverTokenizer"], "models.phobert": ["PhobertTokenizer"], + "models.plbart": ["PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "PLBartConfig"], "models.poolformer": ["POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PoolFormerConfig"], "models.prophetnet": ["PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ProphetNetConfig", "ProphetNetTokenizer"], "models.qdqbert": ["QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "QDQBertConfig"], @@ -410,6 +411,7 @@ _import_structure["models.mluke"].append("MLukeTokenizer") _import_structure["models.mt5"].append("MT5Tokenizer") _import_structure["models.pegasus"].append("PegasusTokenizer") + _import_structure["models.plbart"].append("PLBartTokenizer") _import_structure["models.reformer"].append("ReformerTokenizer") _import_structure["models.rembert"].append("RemBertTokenizer") _import_structure["models.speech_to_text"].append("Speech2TextTokenizer") @@ -1219,6 +1221,16 @@ "PerceiverPreTrainedModel", ] ) + _import_structure["models.plbart"].extend( + [ + "PLBART_PRETRAINED_MODEL_ARCHIVE_LIST", + "PLBartForCausalLM", + "PLBartForConditionalGeneration", + "PLBartForSequenceClassification", + "PLBartModel", + "PLBartPreTrainedModel", + ] + ) _import_structure["models.poolformer"].extend( [ "POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -2498,6 +2510,7 @@ from .models.pegasus import PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusConfig, PegasusTokenizer from .models.perceiver import PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP, PerceiverConfig, PerceiverTokenizer from .models.phobert import PhobertTokenizer + from .models.plbart import PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP, PLBartConfig from .models.poolformer import POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, PoolFormerConfig from .models.prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig, ProphetNetTokenizer from .models.qdqbert import QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, QDQBertConfig @@ -2630,6 +2643,7 @@ from .models.mluke import MLukeTokenizer from .models.mt5 import MT5Tokenizer from .models.pegasus import PegasusTokenizer + from .models.plbart import PLBartTokenizer from .models.reformer import ReformerTokenizer from .models.rembert import RemBertTokenizer from .models.speech_to_text import Speech2TextTokenizer @@ -3292,6 +3306,14 @@ PerceiverModel, PerceiverPreTrainedModel, ) + from 
.models.plbart import ( + PLBART_PRETRAINED_MODEL_ARCHIVE_LIST, + PLBartForCausalLM, + PLBartForConditionalGeneration, + PLBartForSequenceClassification, + PLBartModel, + PLBartPreTrainedModel, + ) from .models.poolformer import ( POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, PoolFormerForImageClassification, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 50d287c61efcd..14122e5a45647 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -83,6 +83,7 @@ pegasus, perceiver, phobert, + plbart, poolformer, prophetnet, qdqbert, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 1115ffb7a3663..878045158ea57 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -49,6 +49,7 @@ ("perceiver", "PerceiverConfig"), ("gptj", "GPTJConfig"), ("layoutlmv2", "LayoutLMv2Config"), + ("plbart", "PLBartConfig"), ("beit", "BeitConfig"), ("rembert", "RemBertConfig"), ("visual_bert", "VisualBertConfig"), @@ -143,6 +144,7 @@ ("perceiver", "PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("plbart", "PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -228,6 +230,7 @@ ("perceiver", "Perceiver"), ("gptj", "GPT-J"), ("beit", "BEiT"), + ("plbart", "PLBart"), ("rembert", "RemBERT"), ("layoutlmv2", "LayoutLMv2"), ("visual_bert", "VisualBert"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 8c3946e62f61c..5c2ec495eabc2 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -44,6 +44,7 @@ ("perceiver", "PerceiverModel"), ("gptj", "GPTJModel"), ("layoutlmv2", "LayoutLMv2Model"), + ("plbart", "PLBartModel"), ("beit", "BeitModel"), ("rembert", "RemBertModel"), ("visual_bert", "VisualBertModel"), @@ -163,6 +164,7 @@ # Model with LM heads mapping ("yoso", "YosoForMaskedLM"), ("nystromformer", "NystromformerForMaskedLM"), + ("plbart", "PLBartForConditionalGeneration"), ("qdqbert", "QDQBertForMaskedLM"), ("fnet", "FNetForMaskedLM"), ("gptj", "GPTJForCausalLM"), @@ -216,6 +218,7 @@ [ # Model for Causal LM mapping ("xglm", "XGLMForCausalLM"), + ("plbart", "PLBartForCausalLM"), ("qdqbert", "QDQBertLMHeadModel"), ("trocr", "TrOCRForCausalLM"), ("gptj", "GPTJForCausalLM"), @@ -361,6 +364,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Seq2Seq Causal LM mapping + ("plbart", "PLBartForConditionalGeneration"), ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), ("m2m_100", "M2M100ForConditionalGeneration"), ("led", "LEDForConditionalGeneration"), @@ -391,6 +395,7 @@ # Model for Sequence Classification mapping ("yoso", "YosoForSequenceClassification"), ("nystromformer", "NystromformerForSequenceClassification"), + ("plbart", "PLBartForSequenceClassification"), ("perceiver", "PerceiverForSequenceClassification"), ("qdqbert", "QDQBertForSequenceClassification"), ("fnet", "FNetForSequenceClassification"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 41d44c641f334..0c953f1636bf5 100644 --- 
a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -47,6 +47,7 @@ else: TOKENIZER_MAPPING_NAMES = OrderedDict( [ + ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)), ("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)), ("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)), ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/plbart/__init__.py b/src/transformers/models/plbart/__init__.py new file mode 100644 index 0000000000000..5d0ff08e12640 --- /dev/null +++ b/src/transformers/models/plbart/__init__.py @@ -0,0 +1,61 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...file_utils import _LazyModule, is_sentencepiece_available, is_tokenizers_available, is_torch_available + + +_import_structure = { + "configuration_plbart": ["PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "PLBartConfig"], +} + +if is_sentencepiece_available(): + _import_structure["tokenization_plbart"] = ["PLBartTokenizer"] + +if is_torch_available(): + _import_structure["modeling_plbart"] = [ + "PLBART_PRETRAINED_MODEL_ARCHIVE_LIST", + "PLBartForCausalLM", + "PLBartForConditionalGeneration", + "PLBartForSequenceClassification", + "PLBartModel", + "PLBartPreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_plbart import PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP, PLBartConfig + + if is_sentencepiece_available(): + from .tokenization_plbart import PLBartTokenizer + + if is_torch_available(): + from .modeling_plbart import ( + PLBART_PRETRAINED_MODEL_ARCHIVE_LIST, + PLBartForCausalLM, + PLBartForConditionalGeneration, + PLBartForSequenceClassification, + PLBartModel, + PLBartPreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/plbart/configuration_plbart.py b/src/transformers/models/plbart/configuration_plbart.py new file mode 100644 index 0000000000000..75bdd1f5dea5a --- /dev/null +++ b/src/transformers/models/plbart/configuration_plbart.py @@ -0,0 +1,192 @@ +# coding=utf-8 +# Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PLBART model configuration""" +from collections import OrderedDict +from typing import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfigWithPast +from ...utils import logging + + +logger = logging.get_logger(__name__) + +PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "uclanlp/plbart-base": "https://huggingface.co/uclanlp/plbart-base/resolve/main/config.json", + # See all PLBART models at https://huggingface.co/models?filter=plbart +} + + +class PLBartConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PLBartModel`]. It is used to instantiate a + PLBART model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the PLBART + [uclanlp/plbart-base](https://huggingface.co/uclanlp/plbart-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50005): + Vocabulary size of the PLBART model. Defines the number of different tokens that can be represented by the + `input_ids` passed when calling [`PLBartModel`]. + d_model (`int`, *optional*, defaults to 768): + Dimensionality of the layers and the pooler layer. + encoder_layers (`int`, *optional*, defaults to 6): + Number of encoder layers. + decoder_layers (`int`, *optional*, defaults to 6): + Number of decoder layers. + encoder_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_ffn_dim (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. + encoder_ffn_dim (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in encoder. + activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + classifier_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the classifier. + max_position_embeddings (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048).
+ init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + encoder_layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) + for more details. + decoder_layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) + for more details. + scale_embedding (`bool`, *optional*, defaults to `True`): + Scale embeddings by dividing by sqrt(d_model). + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + forced_eos_token_id (`int`, *optional*, defaults to 2): + The id of the token to force as the last generated token when `max_length` is reached. Usually set to + `eos_token_id`. + + Example: + + ```python + >>> from transformers import PLBartModel, PLBartConfig + + >>> # Initializing a PLBART uclanlp/plbart-base style configuration + >>> configuration = PLBartConfig() + >>> # Initializing a model from the uclanlp/plbart-base style configuration + >>> model = PLBartModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "plbart" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + + def __init__( + self, + vocab_size=50005, + max_position_embeddings=1024, + encoder_layers=6, + encoder_ffn_dim=3072, + encoder_attention_heads=12, + decoder_layers=6, + decoder_ffn_dim=3072, + decoder_attention_heads=12, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + use_cache=True, + is_encoder_decoder=True, + activation_function="gelu", + d_model=768, + dropout=0.1, + attention_dropout=0.1, + activation_dropout=0.0, + init_std=0.02, + classifier_dropout=0.0, + scale_embedding=True, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + forced_eos_token_id=2, + **kwargs + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + forced_eos_token_id=forced_eos_token_id, + **kwargs, + ) + + +class PLBartOnnxConfig(OnnxConfigWithPast): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + if self.use_past: +
return OrderedDict( + [ + ("last_hidden_state", {0: "batch", 1: "sequence"}), + ("past_keys", {0: "batch", 2: "sequence"}), + ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}), + ] + ) + else: + return OrderedDict( + [ + ("last_hidden_state", {0: "batch", 1: "sequence"}), + ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}), + ] + ) diff --git a/src/transformers/models/plbart/convert_plbart_original_checkpoint_to_torch.py b/src/transformers/models/plbart/convert_plbart_original_checkpoint_to_torch.py new file mode 100644 index 0000000000000..eac4a27d11c5a --- /dev/null +++ b/src/transformers/models/plbart/convert_plbart_original_checkpoint_to_torch.py @@ -0,0 +1,94 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import torch +from torch import nn + +from transformers import PLBartConfig, PLBartForConditionalGeneration, PLBartForSequenceClassification + + +def remove_ignore_keys_(state_dict): + ignore_keys = [ + "encoder.version", + "decoder.version", + "model.encoder.version", + "model.decoder.version", + "_float_tensor", + "decoder.output_projection.weight", + ] + for k in ignore_keys: + state_dict.pop(k, None) + + +def make_linear_from_emb(emb): + vocab_size, emb_size = emb.weight.shape + lin_layer = nn.Linear(vocab_size, emb_size, bias=False) + lin_layer.weight.data = emb.weight.data + return lin_layer + + +def convert_fairseq_plbart_checkpoint_from_disk( + checkpoint_path, hf_config_path="uclanlp/plbart-base", finetuned=False, classification=False +): + state_dict = torch.load(checkpoint_path, map_location="cpu")["model"] + remove_ignore_keys_(state_dict) + vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0] + + plbart_config = PLBartConfig.from_pretrained(hf_config_path, vocab_size=vocab_size) + + state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"] + if not classification: + model = PLBartForConditionalGeneration(plbart_config) + model.model.load_state_dict(state_dict) + if finetuned: + model.lm_head = make_linear_from_emb(model.model.shared) + + else: + classification_head = {} + for key, value in state_dict.copy().items(): + if key.startswith("classification_heads.sentence_classification_head"): + classification_head[key.replace("classification_heads.sentence_classification_head.", "")] = value + state_dict.pop(key) + model = PLBartForSequenceClassification(plbart_config) + model.model.load_state_dict(state_dict) + model.classification_head.load_state_dict(classification_head) + + return model + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument("fairseq_path", type=str, help="model.pt on local filesystem.") + parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + parser.add_argument( + "--hf_config", + default="uclanlp/plbart-base", + type=str, + help="Which huggingface architecture to use: plbart-base", + ) + 
parser.add_argument("--finetuned", action="store_true", help="whether the model is a fine-tuned checkpoint") + parser.add_argument( + "--classification", action="store_true", help="whether the model is a classification checkpoint" + ) + args = parser.parse_args() + model = convert_fairseq_plbart_checkpoint_from_disk( + args.fairseq_path, + hf_config_path=args.hf_config, + finetuned=args.finetuned, + classification=args.classification, + ) + model.save_pretrained(args.pytorch_dump_folder_path) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py new file mode 100755 index 0000000000000..b8db30cc9cb50 --- /dev/null +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -0,0 +1,1717 @@ +# coding=utf-8 +# Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch PLBART model.""" +import copy +import math +import random +from typing import Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...file_utils import ( + add_code_sample_docstrings, + add_end_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqSequenceClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from .configuration_plbart import PLBartConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "uclanlp/plbart-base" +_CONFIG_FOR_DOC = "PLBartConfig" +_TOKENIZER_FOR_DOC = "PLBartTokenizer" + + +PLBART_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "uclanlp/plbart-base", + "uclanlp/plbart-cs-java", + "uclanlp/plbart-multi_task-all", + # See all PLBART models at https://huggingface.co/models?filter=plbart +] + + +# Copied from transformers.models.mbart.modeling_mbart.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int): + """ + Shift input ids one token to the right, and wrap the last non pad token (the token) Note that MBart does not + have a single `decoder_start_token_id` in contrast to other Bart-like models. + """ + prev_output_tokens = input_ids.clone() + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." 
+ # replace possible -100 values in labels by `pad_token_id` + prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id) + + index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1) + decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze() + prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone() + prev_output_tokens[:, 0] = decoder_start_tokens + + return prev_output_tokens + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), float("-inf")) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + + +# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->PLBart +class PLBartLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # PLBart is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + bsz, seq_len = input_ids_shape[:2] + positions = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(positions + self.offset) + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PLBart +class PLBartAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. 
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->PLBart +class PLBartEncoderLayer(nn.Module): + def __init__(self, config: PLBartConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = PLBartAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->PLBart +class PLBartDecoderLayer(nn.Module): + def __init__(self, config: PLBartConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = PLBartAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = PLBartAttention( + 
self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->PLBart +class PLBartClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class PLBartPreTrainedModel(PreTrainedModel): + config_class = PLBartConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif 
isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (PLBartDecoder, PLBartEncoder)): + module.gradient_checkpointing = value + + +PLBART_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`PLBartConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +PLBART_GENERATION_EXAMPLE = r""" + Token in-filling example: + + >>> from transformers import PLBartTokenizer, PLBartForConditionalGeneration, PLBartConfig + + >>> model = PLBartForConditionalGeneration.from_pretrained('uclanlp/plbart-base') >>> tokenizer = + PLBartTokenizer.from_pretrained('uclanlp/plbart-base', src_lang='java', tgt_lang='java') >>> METHOD_TO_FILL = + "public static main (String args[0]) { data=Date(); System.out. String.format("Current Date : % tc", ));}" >>> + inputs = tokenizer([METHOD_TO_FILL], max_length=1024, return_tensors='pt') >>> # Generate Filled Code >>> + generated_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True) >>> + print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in + generated_ids]) + + Mask-filling example: + + >>> from transformers import PLBartTokenizer, PLBartForConditionalGeneration >>> tokenizer = + PLBartTokenizer.from_pretrained('uclanlp/plbart-base') >>> # en_XX is the language symbol id for English + >>> TXT = " Is 0 the Fibonacci ? en_XX" >>> model = + PLBartForConditionalGeneration.from_pretrained('uclanlp/plbart-base') >>> input_ids = tokenizer([TXT], + add_special_tokens=False, return_tensors='pt')['input_ids'] >>> logits = model(input_ids).logits >>> + masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() >>> probs = logits[0, + masked_index].softmax(dim=0) >>> values, predictions = probs.topk(5) >>> tokenizer.decode(predictions).split() +""" + +PLBART_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`PLBartTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint. + See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`PLBartTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint. + See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that + varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (: + obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*): Default behavior: + generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (: + obj:*torch.Tensor* of shape `(decoder_layers, decoder_attention_heads)`, *optional*): Mask to nullify + selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (: + obj:*tuple(tuple(torch.FloatTensor))*, *optional*, returned when `use_cache=True` is passed or when + `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple + having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional + tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + ``decoder_input_ids``` of shape `(batch_size, sequence_length)`. 
+ inputs_embeds (: + obj:*torch.FloatTensor* of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, + instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful + if you want more control over how to convert `input_ids` indices into associated vectors than the model's + internal embedding lookup matrix. + decoder_inputs_embeds (: + obj:*torch.FloatTensor* of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.bart.modeling_bart.BartEncoder with Bart->PLBart +class PLBartEncoder(PLBartPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`PLBartEncoderLayer`]. + + Args: + config: PLBartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + + self.embed_positions = PLBartLearnedPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + ) + self.layers = nn.ModuleList([PLBartEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids=None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`PLBartTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + embed_pos = self.embed_positions(input_shape) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 
+ ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +# Copied from transformers.models.bart.modeling_bart.BartDecoder with Bart->PLBart +class PLBartDecoder(PLBartPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`PLBartDecoderLayer`] + + Args: + config: PLBartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + + self.embed_positions = PLBartLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + self.layers = nn.ModuleList([PLBartDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length + ).to(self.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def 
forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`PLBartTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all ``decoder_input_ids``` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` + of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing + `input_ids` you can choose to directly pass an embedded representation. 
+                This is useful if you want more control over how to convert `input_ids` indices into associated
+                vectors than the model's internal embedding lookup matrix.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        # past_key_values_length
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask, input_shape, inputs_embeds, past_key_values_length
+        )
+
+        # expand encoder attention mask
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+
+        # embed positions
+        positions = self.embed_positions(input_shape, past_key_values_length)
+
+        hidden_states = inputs_embeds + positions
+        hidden_states = self.layernorm_embedding(hidden_states)
+
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+        next_decoder_cache = () if use_cache else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != (len(self.layers)):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}."
+ ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, + ) + else: + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + "The bare PLBART Model outputting raw hidden-states without any specific head on top.", + PLBART_START_DOCSTRING, +) +class PLBartModel(PLBartPreTrainedModel): + def __init__(self, config: PLBartConfig): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + + self.encoder = PLBartEncoder(config, self.shared) + self.decoder = PLBartDecoder(config, self.shared) + + self.init_weights() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Seq2SeqModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + 
attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # different to other models, PLBart automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id) + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + "The PLBART Model with a language modeling head. 
Can be used for code-to-text, text-to-code and code-to-code.", + PLBART_START_DOCSTRING, +) +class PLBartForConditionalGeneration(PLBartPreTrainedModel): + base_model_prefix = "model" + _keys_to_ignore_on_load_missing = [ + r"final_logits_bias", + r"encoder\.version", + r"decoder\.version", + r"lm_head\.weight", + ] + + def __init__(self, config: PLBartConfig): + super().__init__(config) + self.model = PLBartModel(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + self.init_weights() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens) + self._resize_final_logits_bias(new_num_tokens) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + @add_end_docstrings(PLBART_GENERATION_EXAMPLE) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Returns: + + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs # TODO: Check if this is needed. It is unused? + ): + # cut decoder_input_ids if past is used + if past is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id) + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + +@add_start_docstrings( + """ + PLBart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for code + classification. 
+ """, + PLBART_START_DOCSTRING, +) +class PLBartForSequenceClassification(PLBartPreTrainedModel): + def __init__(self, config: PLBartConfig, **kwargs): + super().__init__(config, **kwargs) + self.model = PLBartModel(config) + self.classification_head = PLBartClassificationHead( + config.d_model, + config.d_model, + config.num_labels, + config.classifier_dropout, + ) + self.model._init_weights(self.classification_head.dense) + self.model._init_weights(self.classification_head.out_proj) + + @add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Seq2SeqSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + if input_ids is None and inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}" + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] # last hidden state + + eos_mask = input_ids.eq(self.config.eos_token_id) + + if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: + raise ValueError("All examples must have the same number of tokens.") + sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ + :, -1, : + ] + logits = self.classification_head(sentence_representation) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.config.num_labels == 1: + self.config.problem_type = "regression" + elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.config.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + elif self.config.problem_type == 
"multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqSequenceClassifierOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->PLBart +class PLBartDecoderWrapper(PLBartPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the [`EncoderDecoderModel`] framework. + """ + + def __init__(self, config): + super().__init__(config) + self.decoder = PLBartDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->PLBart +class PLBartForCausalLM(PLBartPreTrainedModel): + def __init__(self, config): + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + super().__init__(config) + self.model = PLBartDecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`PLBartTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. 
+            encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on the padding token indices of the encoder input. This mask is
+                used in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import PLBartTokenizer, PLBartForCausalLM
+
+        >>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-base")
+        >>> model = PLBartForCausalLM.from_pretrained("uclanlp/plbart-base", add_cross_attention=False)
+        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + if past: + input_ids = input_ids[:, -1:] + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed + "attention_mask": attention_mask, + "past_key_values": past, + "use_cache": use_cache, + } + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past diff --git a/src/transformers/models/plbart/tokenization_plbart.py b/src/transformers/models/plbart/tokenization_plbart.py new file mode 100644 index 0000000000000..4c302e8b62cea --- /dev/null +++ b/src/transformers/models/plbart/tokenization_plbart.py @@ -0,0 +1,448 @@ +# coding=utf-8 +# Copyright 2022, UCLA NLP, The Facebook AI Research Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
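+
+# Note on the id layout used by the tokenizer in this module (illustrative figures, not read from
+# the shipped checkpoint): each language code is assigned sp_model_size + fairseq_offset + index,
+# and for the "base" vocabulary the mask token comes after the language codes. Taking the ids
+# quoted in the model docstring (java == 50001, en_XX == 50003) as a reference, that would imply
+# sp_model_size == 50000, python == 50002 and the mask token at 50004; the exact values depend on
+# the sentencepiece.bpe.model distributed with each checkpoint.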
+ +import os +from contextlib import contextmanager +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple + +import sentencepiece as spm + +from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + +SPIECE_UNDERLINE = "▁" + +VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "uclanlp/plbart-base": "https://huggingface.co/uclanlp/plbart-base/resolve/main/sentencepiece.bpe.model", + "uclanlp/plbart-c-cpp-defect-detection": "https://huggingface.co/uclanlp/plbart-c-cpp-defect-detection/resolve/main/sentencepiece.bpe.model", + "uclanlp/plbart-cs-java": "https://huggingface.co/uclanlp/plbart-cs-java/resolve/main/sentencepiece.bpe.model", + "uclanlp/plbart-en_XX-java": "https://huggingface.co/uclanlp/plbart-en_XX-java/resolve/main/sentencepiece.bpe.model", + "uclanlp/plbart-go-en_XX": "https://huggingface.co/uclanlp/plbart-go-en_XX/resolve/main/sentencepiece.bpe.model", + "uclanlp/plbart-java-clone-detection": "https://huggingface.co/uclanlp/plbart-java-clone-detection/resolve/main/sentencepiece.bpe.model", + "uclanlp/plbart-java-cs": "https://huggingface.co/uclanlp/plbart-java-cs/resolve/main/sentencepiece.bpe.model", + "uclanlp/plbart-java-en_XX": "https://huggingface.co/uclanlp/plbart-java-en_XX/resolve/main/sentencepiece.bpe.model", + "uclanlp/plbart-javascript-en_XX": "https://huggingface.co/uclanlp/plbart-javascript-en_XX/resolve/main/sentencepiece.bpe.model", + "uclanlp/plbart-php-en_XX": "https://huggingface.co/uclanlp/plbart-php-en_XX/resolve/main/sentencepiece.bpe.model", + "uclanlp/plbart-python-en_XX": "https://huggingface.co/uclanlp/plbart-python-en_XX/resolve/main/sentencepiece.bpe.model", + "uclanlp/plbart-refine-java-medium": "https://huggingface.co/uclanlp/plbart-refine-java-medium/resolve/main/sentencepiece.bpe.model", + "uclanlp/plbart-refine-java-small": "https://huggingface.co/uclanlp/plbart-refine-java-small/resolve/main/sentencepiece.bpe.model", + "uclanlp/plbart-ruby-en_XX": "https://huggingface.co/uclanlp/plbart-ruby-en_XX/resolve/main/sentencepiece.bpe.model", + } +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "uclanlp/plbart-base": 1024, + "uclanlp/plbart-c-cpp-defect-detection": 1024, + "uclanlp/plbart-cs-java": 1024, + "uclanlp/plbart-en_XX-java": 1024, + "uclanlp/plbart-go-en_XX": 1024, + "uclanlp/plbart-java-clone-detection": 1024, + "uclanlp/plbart-java-cs": 1024, + "uclanlp/plbart-java-en_XX": 1024, + "uclanlp/plbart-javascript-en_XX": 1024, + "uclanlp/plbart-php-en_XX": 1024, + "uclanlp/plbart-python-en_XX": 1024, + "uclanlp/plbart-refine-java-medium": 1024, + "uclanlp/plbart-refine-java-small": 1024, + "uclanlp/plbart-ruby-en_XX": 1024, +} + +FAIRSEQ_LANGUAGE_CODES = { + "base": ["java", "python", "en_XX"], + "multi": ["java", "python", "en_XX", "javascript", "php", "ruby", "go"], +} + + +class PLBartTokenizer(PreTrainedTokenizer): + """ + Construct an PLBART tokenizer. + + Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on + [SentencePiece](https://github.com/google/sentencepiece). + + The tokenization method is ` ` for source language documents, and `` + ``` for target language documents. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + src_lang (`str`, *optional*): + A string representing the source language. + tgt_lang (`str`, *optional*): + A string representing the target language. 
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The start of sequence token.
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
+            for sequence classification or for a text and a question for question answering. It is also used as the
+            last token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The cls token, which is a special token used as the first token for all tasks.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
+            this token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masking tasks.
+            This is only used in the `"base"` tokenizer type. For `"multi"` tokenizer, masking is never done for the
+            downstream tasks.
+        language_codes (`str`, *optional*, defaults to `"base"`):
+            What language codes to use. Should be one of `"base"` or `"multi"`.
+        sp_model_kwargs (`dict`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other
+            things, to set:
+
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
+                using forward-filtering-and-backward-sampling algorithm.
+
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
+
+    Examples:
+
+    ```python
+    >>> from transformers import PLBartTokenizer
+
+    >>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-python-en_XX", src_lang="python", tgt_lang="en_XX")
+    >>> example_python_phrase = "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])"
+    >>> expected_translation_english = "Returns the maximum value of a b c."
+    >>> inputs = tokenizer(example_python_phrase, return_tensors="pt")
+    >>> with tokenizer.as_target_tokenizer():
+    ...     labels = tokenizer(expected_translation_english, return_tensors="pt")
+    >>> inputs["labels"] = labels["input_ids"]
+    ```"""
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    model_input_names = ["input_ids", "attention_mask"]
+
+    prefix_tokens: List[int] = []
+    suffix_tokens: List[int] = []
+
+    def __init__(
+        self,
+        vocab_file,
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        language_codes="base",
+        tokenizer_file=None,
+        src_lang=None,
+        tgt_lang=None,
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        additional_special_tokens=None,
+        **kwargs
+    ):
+        # Mask token behaves like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            language_codes=language_codes,
+            tokenizer_file=tokenizer_file,
+            src_lang=src_lang,
+            tgt_lang=tgt_lang,
+            additional_special_tokens=additional_special_tokens,
+            sp_model_kwargs=self.sp_model_kwargs,
+            **kwargs,
+        )
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(str(vocab_file))
+        self.vocab_file = vocab_file
+        self.language_codes = language_codes
+
+        fairseq_language_codes = FAIRSEQ_LANGUAGE_CODES[self.language_codes]
+
+        # Original fairseq vocab and spm vocab must be "aligned":
+        # Vocab    |    0    |    1    |   2    |    3    |  4  |  5  |  6  |   7   |   8   |  9
+        # -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
+        # fairseq  | '<s>'   | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's'   | '▁de' | '-'
+        # spm      | '<unk>' | '<s>'   | '</s>' | ','     | '.' | '▁' | 's' | '▁de' | '-'   | '▁a'
+
+        # Mimic fairseq token-to-id alignment for the first 4 tokens
+        self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
+
+        # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
+        self.fairseq_offset = 1
+
+        self.sp_model_size = len(self.sp_model)
+        self.lang_code_to_id = {
+            code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(fairseq_language_codes)
+        }
+        self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
+
+        if self.language_codes == "base":
+            self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
+
+        self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
+        self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
+        self._additional_special_tokens = list(self.lang_code_to_id.keys())
+
+        if additional_special_tokens is not None:
+            # Only add those special tokens if they are not already there.
+ self._additional_special_tokens.extend( + [t for t in additional_special_tokens if t not in self._additional_special_tokens] + ) + + if self.language_codes == "base": + self._src_lang = src_lang + self.cur_lang_code_id = ( + self.lang_code_to_id[self._src_lang] if self._src_lang is not None else self._src_lang + ) + else: + self._src_lang = src_lang if src_lang is not None else "en_XX" + self.cur_lang_code_id = self.lang_code_to_id[self._src_lang] + + self.tgt_lang = tgt_lang + self.set_src_lang_special_tokens(self._src_lang) + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + state["sp_model_proto"] = self.sp_model.serialized_model_proto() + return state + + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.LoadFromSerializedProto(self.sp_model_proto) + + @property + def vocab_size(self): + if self.language_codes == "base": + return ( + len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1 + ) # Plus 1 for the mask token + else: + return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + + @property + def src_lang(self) -> str: + return self._src_lang + + @src_lang.setter + def src_lang(self, new_src_lang: str) -> None: + self._src_lang = new_src_lang + self.set_src_lang_special_tokens(self._src_lang) + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + prefix_ones = [1] * len(self.prefix_tokens) + suffix_ones = [1] * len(self.suffix_tokens) + if token_ids_1 is None: + return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones + return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. An PLBART sequence has the following format, where `X` represents the sequence: + + - `input_ids` (for encoder) `X [eos, src_lang_code]` + - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]` + + BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a + separator. + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. 
+ + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + self.suffix_tokens + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. PLBart does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + """ + + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def _build_translation_inputs( + self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs + ): + """Used by translation pipeline, to prepare inputs for the generate function""" + if src_lang is None or tgt_lang is None: + raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") + self.src_lang = src_lang + inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs) + tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) + inputs["forced_bos_token_id"] = tgt_lang_id + return inputs + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text: str) -> List[str]: + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + if token in self.fairseq_tokens_to_ids: + return self.fairseq_tokens_to_ids[token] + spm_id = self.sp_model.PieceToId(token) + + # Need to return unknown token if the SP model returned 0 + return spm_id + self.fairseq_offset if spm_id else self.unk_token_id + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + if index in self.fairseq_ids_to_tokens: + return self.fairseq_ids_to_tokens[index] + return self.sp_model.IdToPiece(index - self.fairseq_offset) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (strings for sub-words) in a single string.""" + out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip() + return out_string + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def prepare_seq2seq_batch( + 
self, + src_texts: List[str], + src_lang: str = "en_XX", + tgt_texts: Optional[List[str]] = None, + tgt_lang: str = "python", + **kwargs, + ) -> BatchEncoding: + self.src_lang = src_lang + self.tgt_lang = tgt_lang + return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) + + @contextmanager + def as_target_tokenizer(self): + """ + Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to + sequence-to-sequence models that need a slightly different processing for the labels. + """ + self.set_tgt_lang_special_tokens(self.tgt_lang) + yield + self.set_src_lang_special_tokens(self.src_lang) + + def set_src_lang_special_tokens(self, src_lang) -> None: + """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code].""" + self.cur_lang_code = self.lang_code_to_id[src_lang] if src_lang is not None else None + self.prefix_tokens = [] + if self.cur_lang_code is not None: + self.suffix_tokens = [self.eos_token_id, self.cur_lang_code] + else: + self.suffix_tokens = [self.eos_token_id] + + def set_tgt_lang_special_tokens(self, lang: str) -> None: + """Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code].""" + self.cur_lang_code = self.lang_code_to_id[lang] if lang is not None else None + self.prefix_tokens = [] + if self.cur_lang_code is not None: + self.suffix_tokens = [self.eos_token_id, self.cur_lang_code] + else: + self.suffix_tokens = [self.eos_token_id] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 4facca5740473..4f8049e23e017 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -2786,6 +2786,44 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +PLBART_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class PLBartForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class PLBartForConditionalGeneration(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class PLBartForSequenceClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class PLBartModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class PLBartPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/dummy_sentencepiece_objects.py b/src/transformers/utils/dummy_sentencepiece_objects.py index b94480df603c9..b358e6d26937b 100644 --- a/src/transformers/utils/dummy_sentencepiece_objects.py +++ b/src/transformers/utils/dummy_sentencepiece_objects.py @@ -108,6 +108,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["sentencepiece"]) +class PLBartTokenizer(metaclass=DummyObject): + _backends = ["sentencepiece"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["sentencepiece"]) + + class ReformerTokenizer(metaclass=DummyObject): _backends = ["sentencepiece"] diff --git a/tests/test_modeling_plbart.py b/tests/test_modeling_plbart.py new file mode 100644 index 0000000000000..df652a48e267c --- /dev/null +++ b/tests/test_modeling_plbart.py @@ -0,0 
+1,632 @@ +# coding=utf-8 +# Copyright 2022, The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch PLBART model. """ + + +import copy +import tempfile +import unittest + +from transformers import PLBartConfig, is_torch_available +from transformers.file_utils import cached_property +from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device + +from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin +from .test_modeling_common import ModelTesterMixin, ids_tensor + + +if is_torch_available(): + import torch + + from transformers import ( + AutoTokenizer, + PLBartForCausalLM, + PLBartForConditionalGeneration, + PLBartForSequenceClassification, + PLBartModel, + ) + from transformers.models.plbart.modeling_plbart import PLBartDecoder, PLBartEncoder + + +def prepare_plbart_inputs_dict( + config, + input_ids, + decoder_input_ids, + attention_mask=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, +): + if attention_mask is None: + attention_mask = input_ids.ne(config.pad_token_id) + if decoder_attention_mask is None: + decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) + if head_mask is None: + head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) + if decoder_head_mask is None: + decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) + if cross_attn_head_mask is None: + cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) + return { + "input_ids": input_ids, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + } + + +class PLBartModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_labels=False, + vocab_size=99, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=4, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=100, + eos_token_id=2, + pad_token_id=1, + bos_token_id=0, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.eos_token_id = eos_token_id + 
self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + input_ids = input_ids.clamp(3) + input_ids[:, -1] = self.eos_token_id # Eos Token + + decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + config = self.get_config() + inputs_dict = prepare_plbart_inputs_dict(config, input_ids, decoder_input_ids) + return config, inputs_dict + + def get_config(self): + return PLBartConfig( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + encoder_layers=self.num_hidden_layers, + decoder_layers=self.num_hidden_layers, + encoder_attention_heads=self.num_attention_heads, + decoder_attention_heads=self.num_attention_heads, + encoder_ffn_dim=self.intermediate_size, + decoder_ffn_dim=self.intermediate_size, + dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + pad_token_id=self.pad_token_id, + ) + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict): + model = PLBartModel(config=config).get_decoder().to(torch_device).eval() + input_ids = inputs_dict["input_ids"] + attention_mask = inputs_dict["attention_mask"] + head_mask = inputs_dict["head_mask"] + + # first forward pass + outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True) + + output, past_key_values = outputs.to_tuple() + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_attn_mask = ids_tensor((self.batch_size, 3), 2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1) + + output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"] + output_with_past_key_values = model( + next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values + ) + output_from_past = output_with_past_key_values["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def check_encoder_decoder_model_standalone(self, config, inputs_dict): + model = PLBartModel(config=config).to(torch_device).eval() + outputs = model(**inputs_dict) + + encoder_last_hidden_state = outputs.encoder_last_hidden_state + last_hidden_state = outputs.last_hidden_state + + with tempfile.TemporaryDirectory() as tmpdirname: + encoder = model.get_encoder() + encoder.save_pretrained(tmpdirname) + encoder = PLBartEncoder.from_pretrained(tmpdirname).to(torch_device) + + encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[ + 0 + ] + + self.parent.assertTrue((encoder_last_hidden_state_2 - 
encoder_last_hidden_state).abs().max().item() < 1e-3) + + with tempfile.TemporaryDirectory() as tmpdirname: + decoder = model.get_decoder() + decoder.save_pretrained(tmpdirname) + decoder = PLBartDecoder.from_pretrained(tmpdirname).to(torch_device) + + last_hidden_state_2 = decoder( + input_ids=inputs_dict["decoder_input_ids"], + attention_mask=inputs_dict["decoder_attention_mask"], + encoder_hidden_states=encoder_last_hidden_state, + encoder_attention_mask=inputs_dict["attention_mask"], + )[0] + + self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3) + + +@require_torch +class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = ( + (PLBartModel, PLBartForConditionalGeneration, PLBartForSequenceClassification) if is_torch_available() else () + ) + all_generative_model_classes = (PLBartForConditionalGeneration,) if is_torch_available() else () + is_encoder_decoder = True + test_pruning = False + test_missing_keys = False + + def setUp(self): + self.model_tester = PLBartModelTester(self) + self.config_tester = ConfigTester(self, config_class=PLBartConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_save_load_strict(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs() + for model_class in self.all_model_classes: + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) + self.assertEqual(info["missing_keys"], []) + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + def test_encoder_decoder_model_standalone(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() + self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs) + + # PLBartForSequenceClassification does not support inputs_embeds + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in (PLBartModel, PLBartForConditionalGeneration): + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + + if not self.is_encoder_decoder: + input_ids = inputs["input_ids"] + del inputs["input_ids"] + else: + encoder_input_ids = inputs["input_ids"] + decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) + del inputs["input_ids"] + inputs.pop("decoder_input_ids", None) + + wte = model.get_input_embeddings() + if not self.is_encoder_decoder: + inputs["inputs_embeds"] = wte(input_ids) + else: + inputs["inputs_embeds"] = wte(encoder_input_ids) + inputs["decoder_inputs_embeds"] = wte(decoder_input_ids) + + with torch.no_grad(): + model(**inputs)[0] + + def test_generate_fp16(self): + config, input_dict = self.model_tester.prepare_config_and_inputs() + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + model = PLBartForConditionalGeneration(config).eval().to(torch_device) + if torch_device == "cuda": + model.half() + model.generate(input_ids, attention_mask=attention_mask) + model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3) + + +def assert_tensors_close(a, b, atol=1e-12, prefix=""): + """If tensors have 
different shapes, different values or a and b are not both tensors, raise a nice Assertion error.""" + if a is None and b is None: + return True + try: + if torch.allclose(a, b, atol=atol): + return True + raise + except Exception: + pct_different = (torch.gt((a - b).abs(), atol)).float().mean().item() + if a.numel() > 100: + msg = f"tensor values are {pct_different:.1%} percent different." + else: + msg = f"{a} != {b}" + if prefix: + msg = prefix + ": " + msg + raise AssertionError(msg) + + +def _long_tensor(tok_lst): + return torch.tensor(tok_lst, dtype=torch.long, device=torch_device) + + +@require_torch +@require_sentencepiece +@require_tokenizers +class AbstractSeq2SeqIntegrationTest(unittest.TestCase): + maxDiff = 1000 # longer string compare tracebacks + checkpoint_name = None + + @classmethod + def setUpClass(cls): + cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name, use_fast=False) + return cls + + @cached_property + def model(self): + """Only load the model if needed.""" + model = PLBartForConditionalGeneration.from_pretrained(self.checkpoint_name).to(torch_device) + if "cuda" in torch_device: + model = model.half() + return model + + +@require_torch +@require_sentencepiece +@require_tokenizers +class PLBartJavaCsIntegrationTest(AbstractSeq2SeqIntegrationTest): + checkpoint_name = "uclanlp/plbart-java-cs" + src_text = [ + "public int maximum(int a, int b, int c){return Math.max(a, Math.max(b, c));}", + "public int product(int a, int b, int c){return a*b*c;}", + ] + tgt_text = [ + "public int maximum(int a, int b, int c){return Math.Max(", + "public int Product(int a, int b, int c){return a * b *", + ] + + @slow + def test_java_cs_generate_one(self): + batch = self.tokenizer( + ["public int maximum(int a, int b, int c){return Math.max(a, Math.max(b, c));}"], return_tensors="pt" + ) + batch = batch.to(torch_device) + translated_tokens = self.model.generate(**batch) + decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True) + self.assertEqual(self.tgt_text[0], decoded[0]) + # self.assertEqual(self.tgt_text[1], decoded[1]) + + @slow + def test_java_cs_generate_batch(self): + batch = self.tokenizer(self.src_text, return_tensors="pt", padding=True, truncation=True) + batch = batch.to(torch_device) + translated_tokens = self.model.generate(**batch) + decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True) + assert self.tgt_text == decoded + + def test_plbart_java_cs_config(self): + plbart_models = ["uclanlp/plbart-java-cs"] + expected = {"scale_embedding": True} + for name in plbart_models: + config = PLBartConfig.from_pretrained(name) + for k, v in expected.items(): + try: + self.assertEqual(v, getattr(config, k)) + except AssertionError as e: + e.args += (name, k) + raise + + def test_plbart_fast_forward(self): + config = PLBartConfig( + vocab_size=99, + d_model=24, + encoder_layers=2, + decoder_layers=2, + encoder_attention_heads=2, + decoder_attention_heads=2, + encoder_ffn_dim=32, + decoder_ffn_dim=32, + max_position_embeddings=48, + add_final_layer_norm=True, + ) + lm_model = PLBartForConditionalGeneration(config).to(torch_device) + context = torch.tensor( + [[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]], device=torch_device, dtype=torch.long + ) + summary = torch.tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]], device=torch_device, dtype=torch.long) + result = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary) + expected_shape = (*summary.shape, config.vocab_size) + 
self.assertEqual(result.logits.shape, expected_shape) + + +@require_torch +@require_sentencepiece +@require_tokenizers +class PLBartBaseIntegrationTest(AbstractSeq2SeqIntegrationTest): + checkpoint_name = "uclanlp/plbart-base" + src_text = ["Is 0 the first Fibonacci number ?", "Find the sum of all prime numbers ."] + tgt_text = ["0 the first Fibonacci number?", "the sum of all prime numbers.......... the the"] + + # @unittest.skip("This test is broken, still generates english") + def test_base_generate(self): + inputs = self.tokenizer([self.src_text[0]], return_tensors="pt").to(torch_device) + translated_tokens = self.model.generate( + input_ids=inputs["input_ids"].to(torch_device), + decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], + ) + decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True) + self.assertEqual(self.tgt_text[0], decoded[0]) + + @slow + def test_fill_mask(self): + inputs = self.tokenizer(["Is 0 the Fibonacci ?"], return_tensors="pt").to(torch_device) + outputs = self.model.generate( + inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], num_beams=1 + ) + prediction: str = self.tokenizer.batch_decode( + outputs, clean_up_tokenization_spaces=True, skip_special_tokens=True + )[0] + self.assertEqual(prediction, "0 0 the 0 the 0 the 0 the 0 the 0 the 0 the 0 the") + + +class PLBartStandaloneDecoderModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + d_model=16, + decoder_seq_length=7, + is_training=True, + is_decoder=True, + use_attention_mask=True, + use_cache=False, + use_labels=True, + decoder_start_token_id=2, + decoder_ffn_dim=32, + decoder_layers=4, + encoder_attention_heads=4, + decoder_attention_heads=4, + max_position_embeddings=30, + is_encoder_decoder=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.decoder_seq_length = decoder_seq_length + # For common tests + self.seq_length = self.decoder_seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + + self.vocab_size = vocab_size + self.d_model = d_model + self.hidden_size = d_model + self.num_hidden_layers = decoder_layers + self.decoder_layers = decoder_layers + self.decoder_ffn_dim = decoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.decoder_attention_heads = decoder_attention_heads + self.num_attention_heads = decoder_attention_heads + self.eos_token_id = eos_token_id + self.bos_token_id = bos_token_id + self.pad_token_id = pad_token_id + self.decoder_start_token_id = decoder_start_token_id + self.use_cache = use_cache + self.max_position_embeddings = max_position_embeddings + self.is_encoder_decoder = is_encoder_decoder + + self.scope = None + self.decoder_key_length = decoder_seq_length + self.base_model_out_len = 2 + self.decoder_attention_idx = 1 + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) + + lm_labels = None + if self.use_labels: + lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + config = PLBartConfig( + vocab_size=self.vocab_size, + d_model=self.d_model, + decoder_layers=self.decoder_layers, + decoder_ffn_dim=self.decoder_ffn_dim, + 
encoder_attention_heads=self.encoder_attention_heads, + decoder_attention_heads=self.decoder_attention_heads, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + use_cache=self.use_cache, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + max_position_embeddings=self.max_position_embeddings, + is_encoder_decoder=self.is_encoder_decoder, + ) + + return (config, input_ids, attention_mask, lm_labels) + + def create_and_check_decoder_model_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + config.use_cache = True + model = PLBartDecoder(config=config).to(torch_device).eval() + # first forward pass + outputs = model(input_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + past_key_values = outputs["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + + output_from_no_past = model(next_input_ids)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_decoder_model_attention_mask_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + model = PLBartDecoder(config=config).to(torch_device).eval() + + # create attention mask + attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + half_seq_length = input_ids.shape[-1] // 2 + attn_mask[:, half_seq_length:] = 0 + + # first forward pass + past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) + input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens + + # append to next input_ids and attn_mask + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], + dim=1, + ) + + # get two different outputs + output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] + output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test 
that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + (config, input_ids, attention_mask, lm_labels) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask} + return config, inputs_dict + + +@require_torch +class PLBartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (PLBartDecoder, PLBartForCausalLM) if is_torch_available() else () + all_generative_model_classes = (PLBartForCausalLM,) if is_torch_available() else () + test_pruning = False + is_encoder_decoder = False + + def setUp(self): + self.model_tester = PLBartStandaloneDecoderModelTester(self, is_training=False) + self.config_tester = ConfigTester(self, config_class=PLBartConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_decoder_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) + + def test_decoder_model_attn_mask_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + + def test_retain_grad_hidden_states_attentions(self): + # decoder cannot keep gradients + return diff --git a/tests/test_tokenization_plbart.py b/tests/test_tokenization_plbart.py new file mode 100644 index 0000000000000..cc16770ecd7e5 --- /dev/null +++ b/tests/test_tokenization_plbart.py @@ -0,0 +1,361 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
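+""" Testing suite for the PLBart tokenizer. """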
+
+import os
+import tempfile
+import unittest
+
+from transformers import SPIECE_UNDERLINE, BatchEncoding, PLBartTokenizer, is_torch_available
+from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch
+
+from .test_tokenization_common import TokenizerTesterMixin
+
+
+SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
+
+
+if is_torch_available():
+    from transformers.models.plbart.modeling_plbart import shift_tokens_right
+
+EN_CODE = 50003
+PYTHON_CODE = 50002
+
+
+@require_sentencepiece
+@require_tokenizers
+class PLBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+    tokenizer_class = PLBartTokenizer
+    rust_tokenizer_class = None
+    test_rust_tokenizer = False
+
+    def setUp(self):
+        super().setUp()
+
+        # We have a SentencePiece fixture for testing
+        tokenizer = PLBartTokenizer(SAMPLE_VOCAB, language_codes="base", keep_accents=True)
+        tokenizer.save_pretrained(self.tmpdirname)
+
+    def test_full_base_tokenizer(self):
+        tokenizer = PLBartTokenizer(SAMPLE_VOCAB, language_codes="base", keep_accents=True)
+
+        tokens = tokenizer.tokenize("This is a test")
+        self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
+
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(tokens),
+            [value + tokenizer.fairseq_offset for value in [285, 46, 10, 170, 382]],
+        )
+
+        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
+        self.assertListEqual(
+            tokens,
+            [
+                SPIECE_UNDERLINE + "I",
+                SPIECE_UNDERLINE + "was",
+                SPIECE_UNDERLINE + "b",
+                "or",
+                "n",
+                SPIECE_UNDERLINE + "in",
+                SPIECE_UNDERLINE + "",
+                "9",
+                "2",
+                "0",
+                "0",
+                "0",
+                ",",
+                SPIECE_UNDERLINE + "and",
+                SPIECE_UNDERLINE + "this",
+                SPIECE_UNDERLINE + "is",
+                SPIECE_UNDERLINE + "f",
+                "al",
+                "s",
+                "é",
+                ".",
+            ],
+        )
+        ids = tokenizer.convert_tokens_to_ids(tokens)
+        self.assertListEqual(
+            ids,
+            [
+                value + tokenizer.fairseq_offset
+                for value in [8, 21, 84, 55, 24, 19, 7, 2, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 2, 4]
+            ],
+        )
+
+        back_tokens = tokenizer.convert_ids_to_tokens(ids)
+        self.assertListEqual(
+            back_tokens,
+            [
+                SPIECE_UNDERLINE + "I",
+                SPIECE_UNDERLINE + "was",
+                SPIECE_UNDERLINE + "b",
+                "or",
+                "n",
+                SPIECE_UNDERLINE + "in",
+                SPIECE_UNDERLINE + "",
+                "<unk>",
+                "2",
+                "0",
+                "0",
+                "0",
+                ",",
+                SPIECE_UNDERLINE + "and",
+                SPIECE_UNDERLINE + "this",
+                SPIECE_UNDERLINE + "is",
+                SPIECE_UNDERLINE + "f",
+                "al",
+                "s",
+                "<unk>",
+                ".",
+            ],
+        )
+
+        end = tokenizer.vocab_size
+        language_tokens = [tokenizer.convert_ids_to_tokens(x) for x in range(end - 4, end)]
+
+        self.assertListEqual(language_tokens, ["java", "python", "en_XX", "<mask>"])
+
+    def test_full_multi_tokenizer(self):
+        tokenizer = PLBartTokenizer(SAMPLE_VOCAB, language_codes="multi", keep_accents=True)
+
+        tokens = tokenizer.tokenize("This is a test")
+        self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
+
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(tokens),
+            [value + tokenizer.fairseq_offset for value in [285, 46, 10, 170, 382]],
+        )
+
+        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
+        self.assertListEqual(
+            tokens,
+            [
+                SPIECE_UNDERLINE + "I",
+                SPIECE_UNDERLINE + "was",
+                SPIECE_UNDERLINE + "b",
+                "or",
+                "n",
+                SPIECE_UNDERLINE + "in",
+                SPIECE_UNDERLINE + "",
+                "9",
+                "2",
+                "0",
+                "0",
+                "0",
+                ",",
+                SPIECE_UNDERLINE + "and",
+                SPIECE_UNDERLINE + "this",
+                SPIECE_UNDERLINE + "is",
+                SPIECE_UNDERLINE + "f",
+                "al",
+                "s",
+                "é",
+                ".",
+            ],
+        )
+        
ids = tokenizer.convert_tokens_to_ids(tokens)
+        self.assertListEqual(
+            ids,
+            [
+                value + tokenizer.fairseq_offset
+                for value in [8, 21, 84, 55, 24, 19, 7, 2, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 2, 4]
+            ],
+        )
+
+        back_tokens = tokenizer.convert_ids_to_tokens(ids)
+        self.assertListEqual(
+            back_tokens,
+            [
+                SPIECE_UNDERLINE + "I",
+                SPIECE_UNDERLINE + "was",
+                SPIECE_UNDERLINE + "b",
+                "or",
+                "n",
+                SPIECE_UNDERLINE + "in",
+                SPIECE_UNDERLINE + "",
+                "<unk>",
+                "2",
+                "0",
+                "0",
+                "0",
+                ",",
+                SPIECE_UNDERLINE + "and",
+                SPIECE_UNDERLINE + "this",
+                SPIECE_UNDERLINE + "is",
+                SPIECE_UNDERLINE + "f",
+                "al",
+                "s",
+                "<unk>",
+                ".",
+            ],
+        )
+        end = tokenizer.vocab_size
+        language_tokens = [tokenizer.convert_ids_to_tokens(x) for x in range(end - 7, end)]
+
+        self.assertListEqual(language_tokens, ["java", "python", "en_XX", "javascript", "php", "ruby", "go"])
+
+
+@require_torch
+@require_sentencepiece
+@require_tokenizers
+class PLBartPythonEnIntegrationTest(unittest.TestCase):
+    checkpoint_name = "uclanlp/plbart-python-en_XX"
+    src_text = [
+        "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])",
+        "def sum(a,b,c):NEW_LINE_INDENTreturn sum([a,b,c])",
+    ]
+    tgt_text = [
+        "Returns the maximum value of a b c.",
+        "Sums the values of a b c.",
+    ]
+    expected_src_tokens = [
+        134,
+        5452,
+        33460,
+        33441,
+        33463,
+        33465,
+        33463,
+        33449,
+        988,
+        20,
+        33456,
+        19,
+        33456,
+        771,
+        39,
+        4258,
+        889,
+        3318,
+        33441,
+        33463,
+        33465,
+        33463,
+        33449,
+        2471,
+        2,
+        PYTHON_CODE,
+    ]
+
+    @classmethod
+    def setUpClass(cls):
+        cls.tokenizer: PLBartTokenizer = PLBartTokenizer.from_pretrained(
+            cls.checkpoint_name, language_codes="base", src_lang="python", tgt_lang="en_XX"
+        )
+        cls.pad_token_id = 1
+        return cls
+
+    def check_language_codes(self):
+        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["java"], 50001)
+        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["python"], 50002)
+        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_XX"], 50003)
+
+    def test_python_en_tokenizer_batch_encode_plus(self):
+        ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
+        self.assertListEqual(self.expected_src_tokens, ids)
+
+    def test_python_en_tokenizer_decode_ignores_language_codes(self):
+        self.assertIn(PYTHON_CODE, self.tokenizer.all_special_ids)
+        generated_ids = [EN_CODE, 9037, 33442, 57, 752, 153, 14, 56, 18, 9, 2]
+        result = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
+        expected_english = self.tokenizer.decode(generated_ids[1:], skip_special_tokens=True)
+        self.assertEqual(result, expected_english)
+        self.assertNotIn(self.tokenizer.eos_token, result)
+
+    def test_python_en_tokenizer_truncation(self):
+        src_text = ["def sum(a,b,c):NEW_LINE_INDENTreturn sum([a,b,c])" * 20]
+        self.assertIsInstance(src_text[0], str)
+        desired_max_length = 10
+        ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0]
+        self.assertEqual(ids[-2], 2)
+        self.assertEqual(ids[-1], PYTHON_CODE)
+        self.assertEqual(len(ids), desired_max_length)
+
+    def test_mask_token(self):
+        self.assertListEqual(self.tokenizer.convert_tokens_to_ids(["<mask>", "java"]), [50004, 50001])
+
+    def test_special_tokens_unaffected_by_save_load(self):
+        tmpdirname = tempfile.mkdtemp()
+        original_special_tokens = self.tokenizer.fairseq_tokens_to_ids
+        self.tokenizer.save_pretrained(tmpdirname)
+        new_tok = PLBartTokenizer.from_pretrained(tmpdirname)
+        self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
+
+    @require_torch
+    def 
test_batch_fairseq_parity(self): + batch = self.tokenizer(self.src_text, padding=True) + with self.tokenizer.as_target_tokenizer(): + targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt") + labels = targets["input_ids"] + batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist() + + # fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4 + self.assertEqual(batch.input_ids[1][-2:], [2, PYTHON_CODE]) + self.assertEqual(batch.decoder_input_ids[1][0], EN_CODE) + self.assertEqual(batch.decoder_input_ids[1][-1], 2) + self.assertEqual(labels[1][-2:].tolist(), [2, EN_CODE]) + + @require_torch + def test_python_en_tokenizer_prepare_batch(self): + batch = self.tokenizer( + self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt" + ) + with self.tokenizer.as_target_tokenizer(): + targets = self.tokenizer( + self.tgt_text, + padding=True, + truncation=True, + max_length=len(self.expected_src_tokens), + return_tensors="pt", + ) + labels = targets["input_ids"] + batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id) + + self.assertIsInstance(batch, BatchEncoding) + + self.assertEqual((2, 26), batch.input_ids.shape) + self.assertEqual((2, 26), batch.attention_mask.shape) + result = batch.input_ids.tolist()[0] + self.assertListEqual(self.expected_src_tokens, result) + self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS + # Test that special tokens are reset + self.assertEqual(self.tokenizer.prefix_tokens, []) + self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, PYTHON_CODE]) + + def test_seq2seq_max_length(self): + batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt") + with self.tokenizer.as_target_tokenizer(): + targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt") + labels = targets["input_ids"] + batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id) + + self.assertEqual(batch.input_ids.shape[1], 3) + self.assertEqual(batch.decoder_input_ids.shape[1], 10) + + @require_torch + def test_tokenizer_translation(self): + inputs = self.tokenizer._build_translation_inputs( + "A test", return_tensors="pt", src_lang="en_XX", tgt_lang="java" + ) + + self.assertEqual( + nested_simplify(inputs), + { + # A, test, EOS, en_XX + "input_ids": [[150, 242, 2, 50003]], + "attention_mask": [[1, 1, 1, 1]], + # java + "forced_bos_token_id": 50001, + }, + ) diff --git a/utils/check_repo.py b/utils/check_repo.py index 9ee2266ca7366..dd990913ed1db 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -45,6 +45,9 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [ # models to ignore for not tested "SegformerDecodeHead", # Building part of bigger (tested) model. + "PLBartEncoder", # Building part of bigger (tested) model. + "PLBartDecoder", # Building part of bigger (tested) model. + "PLBartDecoderWrapper", # Building part of bigger (tested) model. "BigBirdPegasusEncoder", # Building part of bigger (tested) model. "BigBirdPegasusDecoder", # Building part of bigger (tested) model. "BigBirdPegasusDecoderWrapper", # Building part of bigger (tested) model. @@ -119,6 +122,9 @@ "PerceiverForOpticalFlow", "SegformerDecodeHead", "FlaxBeitForMaskedImageModeling", + "PLBartEncoder", + "PLBartDecoder", + "PLBartDecoderWrapper", "BeitForMaskedImageModeling", "CLIPTextModel", "CLIPVisionModel",