From 4784d0e2b52c1424eb318cb5d30c1cb5fa5c32fb Mon Sep 17 00:00:00 2001
From: Jason Phang
Date: Sat, 23 Jul 2022 14:38:52 -0700
Subject: [PATCH 01/25] PegasusX Initial commit

---
 docs/source/en/model_doc/pegasus_x.mdx        | 101 +
 src/transformers/__init__.py                  | 44 +
 src/transformers/models/__init__.py           | 1 +
 .../models/auto/configuration_auto.py         | 3 +
 src/transformers/models/auto/modeling_auto.py | 8 +
 .../models/auto/modeling_flax_auto.py         | 9 +
 src/transformers/models/pegasus_x/__init__.py | 128 ++
 .../pegasus_x/configuration_pegasus_x.py      | 165 ++
 .../pegasus_x/modeling_flax_pegasus_x.py      | 1749 ++++++++++++++++
 .../models/pegasus_x/modeling_pegasus_x.py    | 1796 +++++++++++++++++
 .../pegasus_x/tokenization_pegasus_x.py       | 54 +
 .../pegasus_x/tokenization_pegasus_x_fast.py  | 56 +
 tests/models/pegasus_x/__init__.py            | 0
 .../pegasus_x/test_modeling_flax_pegasus_x.py | 346 ++++
 .../pegasus_x/test_modeling_pegasus_x.py      | 603 ++++++
 utils/check_repo.py                           | 6 +
 16 files changed, 5069 insertions(+)
 create mode 100644 docs/source/en/model_doc/pegasus_x.mdx
 create mode 100644 src/transformers/models/pegasus_x/__init__.py
 create mode 100644 src/transformers/models/pegasus_x/configuration_pegasus_x.py
 create mode 100644 src/transformers/models/pegasus_x/modeling_flax_pegasus_x.py
 create mode 100755 src/transformers/models/pegasus_x/modeling_pegasus_x.py
 create mode 100644 src/transformers/models/pegasus_x/tokenization_pegasus_x.py
 create mode 100644 src/transformers/models/pegasus_x/tokenization_pegasus_x_fast.py
 create mode 100644 tests/models/pegasus_x/__init__.py
 create mode 100644 tests/models/pegasus_x/test_modeling_flax_pegasus_x.py
 create mode 100644 tests/models/pegasus_x/test_modeling_pegasus_x.py

diff --git a/docs/source/en/model_doc/pegasus_x.mdx b/docs/source/en/model_doc/pegasus_x.mdx
new file mode 100644
index 0000000000000..853e3b1ae15c7
--- /dev/null
+++ b/docs/source/en/model_doc/pegasus_x.mdx
@@ -0,0 +1,101 @@
+
+
+# PEGASUS-X
+
+## Overview
+
+The PEGASUS-X model was proposed in [Investigating Efficiently Extending Transformers for Long Input Summarization](https://arxiv.org/abs/2208.04347) by Jason Phang, Yao Zhao and Peter J. Liu.
+
+The abstract from the paper is the following:
+
+**
+
+Tips:
+
+
+
+This model was contributed by [INSERT YOUR HF USERNAME HERE](). The original code can be found [here]().
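+
+A minimal usage sketch for long-input summarization with the classes added in this patch is shown below. The checkpoint
+name `pegasus-x-base` mirrors the placeholder used in this commit's configuration file and is an assumption, not a
+released checkpoint.
+
+```python
+from transformers import PegasusXForConditionalGeneration, PegasusXTokenizer
+
+# Hypothetical checkpoint name, taken from this patch's config archive map.
+tokenizer = PegasusXTokenizer.from_pretrained("pegasus-x-base")
+model = PegasusXForConditionalGeneration.from_pretrained("pegasus-x-base")
+
+article = "PEGASUS-X extends PEGASUS to long input sequences for abstractive summarization."
+inputs = tokenizer(article, return_tensors="pt")
+
+# Generate a short summary and decode it back to text.
+summary_ids = model.generate(**inputs, max_length=32)
+print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
+```
+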
+ +## PegasusXConfig + +[[autodoc]] PegasusXConfig + + +## PegasusXTokenizer + +[[autodoc]] PegasusXTokenizer + - build_inputs_with_special_tokens + - get_special_tokens_mask + - create_token_type_ids_from_sequences + - save_vocabulary + + +## PegasusXTokenizerFast + +[[autodoc]] PegasusXTokenizerFast + + +## PegasusXModel + +[[autodoc]] PegasusXModel + - forward + + +## PegasusXForConditionalGeneration + +[[autodoc]] PegasusXForConditionalGeneration + - forward + + +## PegasusXForSequenceClassification + +[[autodoc]] PegasusXForSequenceClassification + - forward + + +## PegasusXForQuestionAnswering + +[[autodoc]] PegasusXForQuestionAnswering + - forward + + +## PegasusXForCausalLM + +[[autodoc]] PegasusXForCausalLM + - forward + + +## FlaxPegasusXModel + +[[autodoc]] FlaxPegasusXModel + - call + + +## FlaxPegasusXForSequenceClassification + +[[autodoc]] FlaxPegasusXForSequenceClassification + - call + + +## FlaxPegasusXForQuestionAnswering + +[[autodoc]] FlaxPegasusXForQuestionAnswering + - call + + +## FlaxPegasusXForConditionalGeneration + +[[autodoc]] FlaxPegasusXForConditionalGeneration + - call + + diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 3281d266a2f3c..3c766301f1e59 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -119,6 +119,7 @@ ], "models": [], # Models + "models.pegasus_x": ["PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusXConfig", "PegasusXTokenizer"], "models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"], "models.auto": [ "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", @@ -517,6 +518,7 @@ ] else: # Fast tokenizers structure + _import_structure["models.pegasus_x"].append("PegasusXTokenizerFast") _import_structure["models.albert"].append("AlbertTokenizerFast") _import_structure["models.bart"].append("BartTokenizerFast") _import_structure["models.barthez"].append("BarthezTokenizerFast") @@ -766,6 +768,18 @@ _import_structure["modeling_utils"] = ["PreTrainedModel"] # PyTorch models structure + + _import_structure["models.pegasus_x"].extend( + [ + "PEGASUS_X_PRETRAINED_MODEL_ARCHIVE_LIST", + "PegasusXForCausalLM", + "PegasusXForConditionalGeneration", + "PegasusXForQuestionAnswering", + "PegasusXForSequenceClassification", + "PegasusXModel", + "PegasusXPreTrainedModel", + ] + ) _import_structure["models.albert"].extend( [ "ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -2685,6 +2699,16 @@ # Flax models structure + _import_structure["models.pegasus_x"].extend( + [ + "FlaxPegasusXForConditionalGeneration", + "FlaxPegasusXForQuestionAnswering", + "FlaxPegasusXForSequenceClassification", + "FlaxPegasusXModel", + "FlaxPegasusXPreTrainedModel", + ] + ) + _import_structure["models.bart"].extend( [ "FlaxBartDecoderPreTrainedModel", @@ -2937,6 +2961,7 @@ load_tf2_weights_in_pytorch_model, ) from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig + from .models.pegasus_x import PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusXConfig, PegasusXTokenizer from .models.auto import ( ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, @@ -3295,6 +3320,7 @@ from .utils.dummy_tokenizers_objects import * else: # Fast tokenizers imports + from .models.pegasus_x import PegasusXTokenizerFast from .models.albert import AlbertTokenizerFast from .models.bart import BartTokenizerFast from .models.barthez import BarthezTokenizerFast @@ -3498,6 +3524,16 @@ from .modeling_utils import PreTrainedModel # PyTorch model imports + + from .models.pegasus_x import ( + PEGASUS_X_PRETRAINED_MODEL_ARCHIVE_LIST, + 
PegasusXForConditionalGeneration, + PegasusXForCausalLM, + PegasusXForQuestionAnswering, + PegasusXForSequenceClassification, + PegasusXModel, + PegasusXPreTrainedModel, + ) from .models.albert import ( ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST, AlbertForMaskedLM, @@ -5028,6 +5064,14 @@ from .modeling_flax_utils import FlaxPreTrainedModel # Flax model imports + + from .models.pegasus_x import ( + FlaxPegasusXForConditionalGeneration, + FlaxPegasusXForQuestionAnswering, + FlaxPegasusXForSequenceClassification, + FlaxPegasusXModel, + FlaxPegasusXPreTrainedModel, + ) from .models.albert import ( FlaxAlbertForMaskedLM, FlaxAlbertForMultipleChoice, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index fdf315b2257d8..05a62c38a44c5 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -17,6 +17,7 @@ # limitations under the License. from . import ( + pegasus_x, albert, auto, bart, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index c9e6156a3843d..7c79442a75cad 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -29,6 +29,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( [ # Add configs here + ("pegasus_x", "PegasusXConfig"), ("albert", "AlbertConfig"), ("bart", "BartConfig"), ("beit", "BeitConfig"), @@ -157,6 +158,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict( [ # Add archive maps here) + ("pegasus_x", "PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("albert", "ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -270,6 +272,7 @@ MODEL_NAMES_MAPPING = OrderedDict( [ # Add full (and cased) model names here + ("pegasus_x", "PegasusX"), ("albert", "ALBERT"), ("bart", "BART"), ("barthez", "BARThez"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 0e026cb48d0c0..f30195f8b8694 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -28,6 +28,7 @@ MODEL_MAPPING_NAMES = OrderedDict( [ # Base model mapping + ("pegasus_x", "PegasusXModel"), ("albert", "AlbertModel"), ("bart", "BartModel"), ("beit", "BeitModel"), @@ -204,6 +205,8 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( [ # Model with LM heads mapping + + ("pegasus_x", "PegasusXForConditionalGeneration"), ("albert", "AlbertForMaskedLM"), ("bart", "BartForConditionalGeneration"), ("bert", "BertForMaskedLM"), @@ -268,6 +271,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Causal LM mapping + ("pegasus_x", "PegasusXForCausalLM"), ("bart", "BartForCausalLM"), ("bert", "BertLMHeadModel"), ("bert-generation", "BertGenerationDecoder"), @@ -451,6 +455,8 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Seq2Seq Causal LM mapping + + ("pegasus_x", "PegasusXForConditionalGeneration"), ("bart", "BartForConditionalGeneration"), ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), ("blenderbot", "BlenderbotForConditionalGeneration"), @@ -483,6 +489,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Sequence Classification mapping + ("pegasus_x", "PegasusXForSequenceClassification"), ("albert", "AlbertForSequenceClassification"), ("bart", "BartForSequenceClassification"), ("bert", "BertForSequenceClassification"), @@ -541,6 +548,7 @@ 
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ # Model for Question Answering mapping + ("pegasus_x", "PegasusXForQuestionAnswering"), ("albert", "AlbertForQuestionAnswering"), ("bart", "BartForQuestionAnswering"), ("bert", "BertForQuestionAnswering"), diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 98c5d6fb5a104..624147a329aab 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -28,6 +28,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict( [ # Base model mapping + ("pegasus_x", "FlaxPegasusXModel"), ("albert", "FlaxAlbertModel"), ("bart", "FlaxBartModel"), ("beit", "FlaxBeitModel"), @@ -80,6 +81,8 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( [ # Model for Masked LM mapping + + ("pegasus_x", "FlaxPegasusXForConditionalGeneration"), ("albert", "FlaxAlbertForMaskedLM"), ("bart", "FlaxBartForConditionalGeneration"), ("bert", "FlaxBertForMaskedLM"), @@ -96,6 +99,8 @@ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Seq2Seq Causal LM mapping + + ("pegasus_x", "FlaxPegasusXForConditionalGeneration"), ("bart", "FlaxBartForConditionalGeneration"), ("blenderbot", "FlaxBlenderbotForConditionalGeneration"), ("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"), @@ -142,6 +147,8 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Sequence Classification mapping + + ("pegasus_x", "FlaxPegasusXForSequenceClassification"), ("albert", "FlaxAlbertForSequenceClassification"), ("bart", "FlaxBartForSequenceClassification"), ("bert", "FlaxBertForSequenceClassification"), @@ -158,6 +165,8 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ # Model for Question Answering mapping + + ("pegasus_x", "FlaxPegasusXForQuestionAnswering"), ("albert", "FlaxAlbertForQuestionAnswering"), ("bart", "FlaxBartForQuestionAnswering"), ("bert", "FlaxBertForQuestionAnswering"), diff --git a/src/transformers/models/pegasus_x/__init__.py b/src/transformers/models/pegasus_x/__init__.py new file mode 100644 index 0000000000000..bbb812d77790b --- /dev/null +++ b/src/transformers/models/pegasus_x/__init__.py @@ -0,0 +1,128 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from typing import TYPE_CHECKING
+
+from ...utils import (
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    is_flax_available,
+    is_tokenizers_available,
+    is_torch_available,
+)
+
+
+_import_structure = {
+    "configuration_pegasus_x": ["PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusXConfig"],
+    "tokenization_pegasus_x": ["PegasusXTokenizer"],
+}
+
+try:
+    if not is_tokenizers_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["tokenization_pegasus_x_fast"] = ["PegasusXTokenizerFast"]
+
+try:
+    if not is_torch_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["modeling_pegasus_x"] = [
+        "PEGASUS_X_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "PegasusXForConditionalGeneration",
+        "PegasusXForQuestionAnswering",
+        "PegasusXForSequenceClassification",
+        "PegasusXForCausalLM",
+        "PegasusXModel",
+        "PegasusXPreTrainedModel",
+    ]
+
+try:
+    if not is_flax_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["modeling_flax_pegasus_x"] = [
+        "FlaxPegasusXForConditionalGeneration",
+        "FlaxPegasusXForQuestionAnswering",
+        "FlaxPegasusXForSequenceClassification",
+        "FlaxPegasusXModel",
+        "FlaxPegasusXPreTrainedModel",
+    ]
+
+
+if TYPE_CHECKING:
+    from .configuration_pegasus_x import PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusXConfig
+    from .tokenization_pegasus_x import PegasusXTokenizer
+
+    try:
+        if not is_tokenizers_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .tokenization_pegasus_x_fast import PegasusXTokenizerFast
+
+    try:
+        if not is_torch_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .modeling_pegasus_x import (
+            PEGASUS_X_PRETRAINED_MODEL_ARCHIVE_LIST,
+            PegasusXForConditionalGeneration,
+            PegasusXForCausalLM,
+            PegasusXForQuestionAnswering,
+            PegasusXForSequenceClassification,
+            PegasusXModel,
+            PegasusXPreTrainedModel,
+        )
+
+    try:
+        if not is_flax_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .modeling_flax_pegasus_x import (
+            FlaxPegasusXForConditionalGeneration,
+            FlaxPegasusXForQuestionAnswering,
+            FlaxPegasusXForSequenceClassification,
+            FlaxPegasusXModel,
+            FlaxPegasusXPreTrainedModel,
+        )
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/pegasus_x/configuration_pegasus_x.py b/src/transformers/models/pegasus_x/configuration_pegasus_x.py
new file mode 100644
index 0000000000000..b6554127389b7
--- /dev/null
+++ b/src/transformers/models/pegasus_x/configuration_pegasus_x.py
@@ -0,0 +1,165 @@
+# coding=utf-8
+# Copyright 2022 Google and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PEGASUS-X model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+    "pegasus-x-base": "https://huggingface.co/pegasus-x-base/resolve/main/config.json",
+    # See all PEGASUS-X models at https://huggingface.co/models?filter=pegasus_x
+}
+
+
+class PegasusXConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`~PegasusXModel`].
+    It is used to instantiate a PEGASUS-X model according to the specified arguments, defining the model
+    architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
+    the PEGASUS-X [pegasus-x-base](https://huggingface.co/pegasus-x-base) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used
+    to control the model outputs. Read the documentation from [`PretrainedConfig`]
+    for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50265):
+            Vocabulary size of the PEGASUS-X model. Defines the number of different tokens that can be represented
+            by the `input_ids` passed when calling [`~PegasusXModel`] or
+            [`~FlaxPegasusXModel`].
+        d_model (`int`, *optional*, defaults to 1024):
+            Dimension of the layers and the pooler layer.
+        encoder_layers (`int`, *optional*, defaults to 12):
+            Number of encoder layers.
+        decoder_layers (`int`, *optional*, defaults to 12):
+            Number of decoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimension of the "intermediate" (often named feed-forward) layer in decoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimension of the "intermediate" (often named feed-forward) layer in encoder.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string,
+            `"gelu"`, `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        classifier_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for classifier.
+        max_position_embeddings (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model might ever be used with. Typically set this to something
+            large just in case (e.g., 512 or 1024 or 2048).
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        encoder_layerdrop: (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](see
+            https://arxiv.org/abs/1909.11556) for more details.
+ decoder_layerdrop: (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the decoder. See the [LayerDrop paper](see + https://arxiv.org/abs/1909.11556) for more details. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + Example: + + ```python + >>> from transformers import PegasusXModel, PegasusXConfig + + >>> # Initializing a BrandNewBERT pegasus-x-base style configuration + >>> configuration = PegasusXConfig() + + >>> # Initializing a model from the pegasus-x-base style configuration + >>> model = PegasusXModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` +""" + model_type = "pegasus_x" + keys_to_ignore_at_inference = ["past_key_values"] + + attribute_map = { + "num_attention_heads": "encoder_attention_heads", + "hidden_size": "d_model" + } + + def __init__( + self, + vocab_size=50265, + max_position_embeddings=1024, + encoder_layers=12, + encoder_ffn_dim=4096, + encoder_attention_heads=16, + decoder_layers=12, + decoder_ffn_dim=4096, + decoder_attention_heads=16, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + use_cache=True, + is_encoder_decoder=True, + activation_function="gelu", + d_model=1024, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + decoder_start_token_id=2, + classifier_dropout=0.0, + scale_embedding=False, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + **kwargs + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + **kwargs + ) + + \ No newline at end of file diff --git a/src/transformers/models/pegasus_x/modeling_flax_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_flax_pegasus_x.py new file mode 100644 index 0000000000000..048fd382a9c73 --- /dev/null +++ b/src/transformers/models/pegasus_x/modeling_flax_pegasus_x.py @@ -0,0 +1,1749 @@ +# coding=utf-8 +# Copyright 2022 Google and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" Flax BrandNewBERT model. """ + + +import math +import random +from functools import partial +from typing import Callable, Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, unfreeze, freeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax +from jax.random import PRNGKey + +from ...utils import add_start_docstrings, replace_return_docstrings +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxSeq2SeqLMOutput, + FlaxSeq2SeqModelOutput, + FlaxSeq2SeqQuestionAnsweringModelOutput, + FlaxSeq2SeqSequenceClassifierOutput, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import logging +from .configuration_pegasus_x import PegasusXConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "pegasus-x-base" +_CONFIG_FOR_DOC = "PegasusXConfig" +_TOKENIZER_FOR_DOC = "PegasusXTokenizer" + +PEGASUS_X_START_DOCSTRING = r""" + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the + generic methods the library implements for all its model (such as downloading or saving, resizing the input + embeddings, pruning heads etc.) + + This model is also a Flax Linen [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a regular Flax + Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`~PegasusXConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the + model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on + GPUs) and `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + [`~FlaxPreTrainedModel.to_fp16`] and [`~FlaxPreTrainedModel.to_bf16`]. +""" + +PEGASUS_X_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. 
Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`~PegasusXTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for + details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`~PegasusXTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for + details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to + the right for denoising pre-training following the paper. + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will + also be used by default. + + If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. + position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`. + decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +PEGASUS_X_ENCODE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`~PegasusXTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for + details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. 
Selected in the range `[0, config.max_position_embeddings - 1]`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +PEGASUS_X_DECODE_INPUTS_DOCSTRING = r""" + Args: + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`~PegasusXTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for + details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to + the right for denoising pre-training following the paper. + encoder_outputs (`tuple(tuple(jnp.ndarray)`): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: + `attentions`) `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, + *optional*) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. + encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will + also be used by default. + + If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. + decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. + past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: + """ + Shift input ids one token to the right. 
+ """ + shifted_input_ids = jnp.roll(input_ids, 1, axis=-1) + shifted_input_ids = shifted_input_ids.at[(..., 0)].set(decoder_start_token_id) + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) + + return shifted_input_ids + + + +class FlaxPegasusXAttention(nn.Module): + config: PegasusXConfig + embed_dim: int + num_heads: int + dropout: float = 0.0 + causal: bool = False + bias: bool = True + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self) -> None: + self.head_dim = self.embed_dim // self.num_heads + assert ( + self.head_dim * self.num_heads == self.embed_dim + ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." + + dense = partial( + nn.Dense, + self.embed_dim, + use_bias=self.bias, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + + self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() + self.out_proj = dense() + + self.dropout_layer = nn.Dropout(rate=self.dropout) + + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slighly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. 
+ pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states: jnp.ndarray, + key_value_states: Optional[jnp.ndarray] = None, + attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states) + value_states = self.v_proj(key_value_states) + else: + # self_attention + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # handle cache prepare causal attention mask + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # Convert the boolean attention mask to an attention bias. 
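+        # The mask at this point is boolean-like (1 = attend, 0 = masked). lax.select below turns it into an
+        # additive bias: 0.0 where attention is allowed and -inf where it is masked, so masked positions get
+        # effectively zero weight after the softmax inside dot_product_attention_weights.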
+ if attention_mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = self._merge_heads(attn_output) + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class FlaxPegasusXEncoderLayer(nn.Module): + config: PegasusXConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.embed_dim = self.config.d_model + self.self_attn = FlaxPegasusXAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.encoder_attention_heads, + dropout=self.config.attention_dropout, + dtype=self.dtype + ) + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) + self.fc1 = nn.Dense( + self.config.encoder_ffn_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + self.fc2 = nn.Dense( + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) + + def __call__( + self, + hidden_states: jnp.ndarray, + attention_mask: jnp.ndarray, + output_attentions: bool = True, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + residual = hidden_states + hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask) + + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class FlaxPegasusXEncoderLayerCollection(nn.Module): + config: PegasusXConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxPegasusXEncoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.encoder_layers) + ] + self.layerdrop = self.config.encoder_layerdrop + + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for encoder_layer in self.layers: + if 
output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not deterministic and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions, + deterministic, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states, all_hidden_states, all_attentions) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +class FlaxPegasusXDecoderLayer(nn.Module): + config: PegasusXConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.embed_dim = self.config.d_model + self.self_attn = FlaxPegasusXAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.decoder_attention_heads, + dropout=self.config.attention_dropout, + causal=True, + dtype=self.dtype, + ) + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) + + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) + self.encoder_attn = FlaxPegasusXAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.decoder_attention_heads, + dropout=self.config.attention_dropout, + dtype=self.dtype, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) + self.fc1 = nn.Dense( + self.config.decoder_ffn_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + self.fc2 = nn.Dense( + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) + + def __call__( + self, + hidden_states: jnp.ndarray, + attention_mask: jnp.ndarray, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_attentions: bool = True, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = 
self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +class FlaxPegasusXDecoderLayerCollection(nn.Module): + config: PegasusXConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxPegasusXDecoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.decoder_layers) + ] + self.layerdrop = self.config.decoder_layerdrop + + def __call__( + self, + hidden_states, + attention_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not deterministic and (dropout_probability < self.layerdrop): + layer_outputs = (None, None, None) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + output_attentions=output_attentions, + deterministic=deterministic, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions] + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +class FlaxPegasusXClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + config: PegasusXConfig + inner_dim: int + num_classes: int + pooler_dropout: float + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dense = nn.Dense( + self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) + ) + self.dropout = nn.Dropout(rate=self.pooler_dropout) + self.out_proj = nn.Dense( + self.num_classes, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + + def __call__(self, hidden_states: jnp.ndarray, deterministic: bool): + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.dense(hidden_states) + hidden_states = jnp.tanh(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class FlaxPegasusXEncoder(nn.Module): + config: PegasusXConfig + 
dtype: jnp.dtype = jnp.float32 # the dtype of the computation + embed_tokens: Optional[nn.Embed] = None + + def setup(self): + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + + embed_dim = self.config.d_model + self.padding_idx = self.config.pad_token_id + self.max_source_positions = self.config.max_position_embeddings + self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0 + + if self.embed_tokens is None: + self.embed_tokens = nn.Embed( + self.config.vocab_size, + embed_dim, + embedding_init=jax.nn.initializers.normal(self.config.init_std), + ) + + # PegasusX is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + self.embed_positions = nn.Embed( + self.config.max_position_embeddings + self.offset, + embed_dim, + embedding_init=jax.nn.initializers.normal(self.config.init_std), + ) + self.layers = FlaxPegasusXEncoderLayerCollection(self.config, self.dtype) + self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + input_shape = input_ids.shape + input_ids = input_ids.reshape(-1, input_shape[-1]) + + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + embed_pos = self.embed_positions(position_ids + self.offset) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + + outputs = self.layers( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return outputs + + return FlaxBaseModelOutput( + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class FlaxPegasusXDecoder(nn.Module): + config: PegasusXConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + embed_tokens: Optional[nn.Embed] = None + + def setup(self): + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + + embed_dim = self.config.d_model + self.padding_idx = self.config.pad_token_id + self.max_target_positions = self.config.max_position_embeddings + self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0 + + if self.embed_tokens is None: + self.embed_tokens = nn.Embed( + self.config.vocab_size, + embed_dim, + embedding_init=jax.nn.initializers.normal(self.config.init_std), + ) + + # PegasusX is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + self.embed_positions = nn.Embed( + self.config.max_position_embeddings + self.offset, + embed_dim, + embedding_init=jax.nn.initializers.normal(self.config.init_std), + ) + + self.layers = FlaxPegasusXDecoderLayerCollection(self.config, self.dtype) + self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + input_shape = input_ids.shape + input_ids = input_ids.reshape(-1, input_shape[-1]) + + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + # embed positions + positions = self.embed_positions(position_ids + self.offset) + + hidden_states = inputs_embeds + positions + hidden_states = self.layernorm_embedding(hidden_states) + + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + + outputs = self.layers( + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return outputs + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +class FlaxPegasusXModule(nn.Module): + config: PegasusXConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.init_std), + ) + + self.encoder = FlaxPegasusXEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) + self.decoder = FlaxPegasusXDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared) + + def _get_encoder_module(self): + return self.encoder + + def _get_decoder_module(self): + return self.decoder + + def __call__( + self, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return FlaxSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + 
encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class FlaxPegasusXPreTrainedModel(FlaxPreTrainedModel): + config_class = PegasusXConfig + base_model_prefix: str = "model" + module_class: nn.Module = None + + def __init__( + self, + config: PegasusXConfig, + input_shape: Tuple[int] = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + # make sure initialization pass will work for FlaxPegasusXForSequenceClassificationModule + input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id) + attention_mask = jnp.ones_like(input_ids) + decoder_input_ids = input_ids + decoder_attention_mask = jnp.ones_like(input_ids) + + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length, encoder_outputs): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): + `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, + *optional*: `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of hidden-states at the output of the last layer of the + encoder. Used in the cross-attention of the decoder. 
+ """ + # init input variables to retrieve cache + decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + decoder_position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape + ) + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + init_cache=True, + method=_decoder_forward, # we only need to call the decoder to init the cache + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings(PEGASUS_X_ENCODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=PegasusXConfig) + def encode( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import PegasusXTokenizer, FlaxPegasusXForConditionalGeneration + + >>> model = FlaxPegasusXForConditionalGeneration.from_pretrained('pegasus-x-base') + >>> tokenizer = PegasusXTokenizer.from_pretrained('pegasus-x-base') + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors='np') + >>> encoder_outputs = model.encode(**inputs) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs): + encode_module = module._get_encoder_module() + return encode_module(input_ids, attention_mask, position_ids, **kwargs) + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + method=_encoder_forward, + ) + + @add_start_docstrings(PEGASUS_X_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=PegasusXConfig) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> import jax.numpy as jnp + >>> from transformers import PegasusXTokenizer, FlaxPegasusXForConditionalGeneration + + >>> model = FlaxPegasusXForConditionalGeneration.from_pretrained('pegasus-x-base') + >>> tokenizer = PegasusXTokenizer.from_pretrained('pegasus-x-base') + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors='np') + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> last_decoder_hidden_states = outputs.last_hidden_state + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxPegasusXAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past = outputs + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past = outputs + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + decoder_input_ids: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = 
None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + # prepare decoder inputs + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right( + input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id + ) + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + if decoder_position_ids is None: + batch_size, sequence_length = decoder_input_ids.shape + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + +@add_start_docstrings( + "The bare PegasusX Model transformer outputting raw hidden-states without any specific head on top.", + PEGASUS_X_START_DOCSTRING, +) +class FlaxPegasusXModel(FlaxPegasusXPreTrainedModel): + config: PegasusXConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + module_class = FlaxPegasusXModule + + +append_call_sample_docstring( + FlaxPegasusXModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC +) + + +class FlaxPegasusXForConditionalGenerationModule(nn.Module): + config: PegasusXConfig + dtype: jnp.dtype = jnp.float32 + bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros + + def setup(self): + self.model = FlaxPegasusXModule(config=self.config, dtype=self.dtype) + self.lm_head = nn.Dense( + self.model.shared.num_embeddings, + use_bias=False, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) + + def _get_encoder_module(self): + return self.model.encoder + + def _get_decoder_module(self): + return self.model.decoder + + def __call__( + self, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + position_ids=position_ids, + 
decoder_position_ids=decoder_position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_embedding = self.model.variables["params"]["shared"]["embedding"] + lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + lm_logits = self.lm_head(hidden_states) + + lm_logits += self.final_logits_bias.astype(self.dtype) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return output + + return FlaxSeq2SeqLMOutput( + logits=lm_logits, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@add_start_docstrings( + "The PEGASUS_X Model with a language modeling head. Can be used for summarization.", PEGASUS_X_START_DOCSTRING +) +class FlaxPegasusXForConditionalGeneration(FlaxPegasusXPreTrainedModel): + module_class = FlaxPegasusXForConditionalGenerationModule + dtype: jnp.dtype = jnp.float32 + + @add_start_docstrings(PEGASUS_X_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=PegasusXConfig) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + deterministic: bool = True, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> import jax.numpy as jnp + >>> from transformers import PegasusXTokenizer, FlaxPegasusXForConditionalGeneration + + >>> model = FlaxPegasusXForConditionalGeneration.from_pretrained('pegasus-x-base') + >>> tokenizer = PegasusXTokenizer.from_pretrained('pegasus-x-base') + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors='np') + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxPegasusXAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + outputs = decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_embedding = module.model.variables["params"]["shared"]["embedding"] + lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + lm_logits = module.lm_head(hidden_states) + + lm_logits += module.final_logits_bias.astype(self.dtype) + return lm_logits, outputs + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + if past_key_values is None: + lm_logits, decoder_outputs = outputs + else: + (lm_logits, decoder_outputs), past = outputs + + if return_dict: + outputs = FlaxCausalLMOutputWithCrossAttentions( + logits=lm_logits, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + ) + else: + outputs = (lm_logits,) + decoder_outputs[1:] + + # add 
updated cache to model output
+        if past_key_values is not None and return_dict:
+            outputs["past_key_values"] = unfreeze(past["cache"])
+            return outputs
+        elif past_key_values is not None and not return_dict:
+            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+        return outputs
+
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        max_length,
+        attention_mask: Optional[jnp.DeviceArray] = None,
+        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        encoder_outputs=None,
+        **kwargs
+    ):
+        # initializing the cache
+        batch_size, seq_length = decoder_input_ids.shape
+
+        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
+        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+        # But since the decoder uses a causal mask, those positions are masked anyway.
+        # Thus we can create a single static attention_mask here, which is more efficient for compilation
+        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+        if decoder_attention_mask is not None:
+            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
+            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
+        else:
+            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+
+        return {
+            "past_key_values": past_key_values,
+            "encoder_outputs": encoder_outputs,
+            "encoder_attention_mask": attention_mask,
+            "decoder_attention_mask": extended_attention_mask,
+            "decoder_position_ids": position_ids,
+        }
+
+    def update_inputs_for_generation(self, model_outputs, model_kwargs):
+        model_kwargs["past_key_values"] = model_outputs.past_key_values
+        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
+        return model_kwargs
+
+
+FLAX_PEGASUS_X_CONDITIONAL_GENERATION_DOCSTRING = """
+    Returns:
+
+    Summarization example:
+
+    ```python
+    >>> from transformers import PegasusXTokenizer, FlaxPegasusXForConditionalGeneration
+
+    >>> model = FlaxPegasusXForConditionalGeneration.from_pretrained('pegasus-x-base')
+    >>> tokenizer = PegasusXTokenizer.from_pretrained('pegasus-x-base')
+
+    >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
+    >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='np')
+
+    >>> # Generate Summary
+    >>> summary_ids = model.generate(inputs['input_ids']).sequences
+    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
+    ```
+
+    Mask filling example:
+
+    ```python
+    >>> import jax
+    >>> from transformers import PegasusXTokenizer, FlaxPegasusXForConditionalGeneration
+
+    >>> model = FlaxPegasusXForConditionalGeneration.from_pretrained('pegasus-x-base')
+    >>> tokenizer = PegasusXTokenizer.from_pretrained('pegasus-x-base')
+
+    >>> TXT = "My friends are <mask> but they eat too many carbs."
+    >>> input_ids = tokenizer([TXT], return_tensors='np')['input_ids']
+
+    >>> logits = model(input_ids).logits
+    >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
+    >>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
+    >>> values, predictions = jax.lax.top_k(probs, k=1)
+
+    >>> tokenizer.decode(predictions).split()
+    ```
+"""
+
+overwrite_call_docstring(
+    FlaxPegasusXForConditionalGeneration, PEGASUS_X_INPUTS_DOCSTRING + FLAX_PEGASUS_X_CONDITIONAL_GENERATION_DOCSTRING
+)
+append_replace_return_docstrings(
+    FlaxPegasusXForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
+)
+
+
+class FlaxPegasusXForSequenceClassificationModule(nn.Module):
+    config: PegasusXConfig
+    dtype: jnp.dtype = jnp.float32
+    num_labels: Optional[int] = None
+
+    def setup(self):
+        self.model = FlaxPegasusXModule(config=self.config, dtype=self.dtype)
+        self.classification_head = FlaxPegasusXClassificationHead(
+            config=self.config,
+            inner_dim=self.config.d_model,
+            num_classes=self.num_labels if self.num_labels is not None else self.config.num_labels,
+            pooler_dropout=self.config.classifier_dropout,
+        )
+
+    def _get_encoder_module(self):
+        return self.model.encoder
+
+    def _get_decoder_module(self):
+        return self.model.decoder
+
+    def __call__(
+        self,
+        input_ids,
+        attention_mask,
+        decoder_input_ids,
+        decoder_attention_mask,
+        position_ids,
+        decoder_position_ids,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+        deterministic: bool = True,
+    ):
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            position_ids=position_ids,
+            decoder_position_ids=decoder_position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=deterministic,
+        )
+
+        hidden_states = outputs[0]  # last hidden state
+
+        eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0)
+
+        # The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation
+        if type(eos_mask) != jax.interpreters.partial_eval.DynamicJaxprTracer:
+            if len(jnp.unique(eos_mask.sum(1))) > 1:
+                raise ValueError("All examples must have the same number of <eos> tokens.")
+
+            if any(eos_mask.sum(1) == 0):
+                raise ValueError("There are missing <eos> tokens in input_ids")
+
+            # Ensure to keep 1 only for the last <eos> token for each example
+            eos_mask_noised = eos_mask + jnp.arange(eos_mask.shape[1]) * 1e-6
+            eos_mask = jnp.where(eos_mask_noised == eos_mask_noised.max(1).reshape(-1, 1), 1, 0)
+
+        sentence_representation = jnp.einsum("ijk, ij -> ijk", hidden_states, eos_mask).sum(1)
+        logits = self.classification_head(sentence_representation, deterministic=deterministic)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return output
+
+        return FlaxSeq2SeqSequenceClassifierOutput(
+            logits=logits,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    PegasusX model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
+    tasks.
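+
+    The pooled output used by the classification head is the hidden state at the position of the last `eos` token in
+    each input sequence.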
+ """, + PEGASUS_X_START_DOCSTRING, +) +class FlaxPegasusXForSequenceClassification(FlaxPegasusXPreTrainedModel): + module_class = FlaxPegasusXForSequenceClassificationModule + dtype = jnp.float32 + + +append_call_sample_docstring( + FlaxPegasusXForSequenceClassification, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxSeq2SeqSequenceClassifierOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxPegasusXForQuestionAnsweringModule(nn.Module): + config: PegasusXConfig + dtype: jnp.dtype = jnp.float32 + num_labels = 2 + + def setup(self): + self.model = FlaxPegasusXModule(config=self.config, dtype=self.dtype) + self.qa_outputs = nn.Dense( + self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) + ) + + def _get_encoder_module(self): + return self.model.encoder + + def _get_decoder_module(self): + return self.model.decoder + + def __call__( + self, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + position_ids=position_ids, + decoder_position_ids=decoder_position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = jnp.split(logits, logits.shape[-1], axis=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return output + + return FlaxSeq2SeqQuestionAnsweringModelOutput( + start_logits=start_logits, + end_logits=end_logits, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@add_start_docstrings( + """ + PEGASUS_X Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + PEGASUS_X_START_DOCSTRING, +) +class FlaxPegasusXForQuestionAnswering(FlaxPegasusXPreTrainedModel): + module_class = FlaxPegasusXForQuestionAnsweringModule + dtype = jnp.float32 + + +append_call_sample_docstring( + FlaxPegasusXForQuestionAnswering, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxSeq2SeqQuestionAnsweringModelOutput, + _CONFIG_FOR_DOC, +) + diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py new file mode 100755 index 0000000000000..8d3d8b98b56d7 --- /dev/null +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -0,0 +1,1796 @@ +# coding=utf-8 +# Copyright 2022 Google The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch PegasusX model. """
+
+
+import math
+import copy
+import random
+from typing import Optional, Tuple, List, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...utils import (
+    add_code_sample_docstrings,
+    add_end_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
+)
+from ...modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPastAndCrossAttentions,
+    Seq2SeqLMOutput,
+    Seq2SeqModelOutput,
+    Seq2SeqQuestionAnsweringModelOutput,
+    Seq2SeqSequenceClassifierOutput,
+    CausalLMOutputWithCrossAttentions
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import logging
+from .configuration_pegasus_x import PegasusXConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "pegasus-x-base"
+_CONFIG_FOR_DOC = "PegasusXConfig"
+_TOKENIZER_FOR_DOC = "PegasusXTokenizer"
+
+
+PEGASUS_X_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "pegasus-x-base",
+    # See all PegasusX models at https://huggingface.co/models?filter=pegasus_x
+]
+
+
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+    """
+    Shift input ids one token to the right.
+    """
+    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+    shifted_input_ids[:, 0] = decoder_start_token_id
+
+    assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
+    # replace possible -100 values in labels by `pad_token_id`
+    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+    return shifted_input_ids
+
+
+def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+    """
+    Make causal mask used for bi-directional self-attention.
+    """
+    bsz, tgt_len = input_ids_shape
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
+    mask_cond = torch.arange(mask.size(-1))
+    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+    mask = mask.to(dtype)
+
+    if past_key_values_length > 0:
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+def _expand_mask(
+    mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
+):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    """
+    bsz, src_len = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+    inverted_mask = 1.0 - expanded_mask
+
+    return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+
+
+class PegasusXLearnedPositionalEmbedding(nn.Embedding):
+    """
+    This module learns positional embeddings up to a fixed maximum size.
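+
+    Positions are indexed starting from `past_key_values_length`, so that incremental decoding with a cache keeps the
+    position ids aligned with the tokens that have already been processed.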
+ """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + super().__init__(num_embeddings, embedding_dim) + + def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + bsz, seq_len = input_ids_shape[:2] + positions = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(positions) + + +class PegasusXAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})." + self.scaling = self.head_dim ** -0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class PegasusXEncoderLayer(nn.Module): + def __init__(self, config: PegasusXConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = PegasusXAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape *(seq_len, batch, embed_dim)* + attention_mask (`torch.FloatTensor`): attention mask of size + *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + *(config.encoder_attention_heads,)*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
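+
+            When running in `float16`, the layer output is clamped to a finite range to avoid `inf`/`nan` values
+            caused by overflow.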
+ """ + residual = hidden_states + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class PegasusXDecoderLayer(nn.Module): + def __init__(self, config: PegasusXConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = PegasusXAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = PegasusXAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape *(seq_len, batch, embed_dim)* + attention_mask (`torch.FloatTensor`): attention mask of size + *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): cross attention input to the layer of shape *(seq_len, batch, embed_dim)* + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + *(encoder_attention_heads,)*. + cross_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size *(decoder_attention_heads,)*. 
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class PegasusXClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class PegasusXPreTrainedModel(PreTrainedModel): + config_class = PegasusXConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + std = 
self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (PegasusXDecoder, PegasusXEncoder)): + module.gradient_checkpointing = value + + +PEGASUS_X_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config ([`~PegasusXConfig`]): + Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model + weights. +""" + +PEGASUS_X_GENERATION_EXAMPLE = r""" + Summarization example: + + ```python + >>> from transformers import PegasusXTokenizer, PegasusXForConditionalGeneration + + >>> model = PegasusXForConditionalGeneration.from_pretrained('pegasus-x-base') + >>> tokenizer = PegasusXTokenizer.from_pretrained('pegasus-x-base') + + >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs." + >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt') + + >>> # Generate Summary + >>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5) + >>> print(tokenizer.decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)) + ``` +""" + +PEGASUS_X_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`~PegasusXTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for + details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Provide for translation and summarization training. By default, the model will create this tensor by + shifting the `input_ids` to the right, following the paper. + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will + also be used by default. + + If you want to change padding behavior, you should read [`modeling_pegasus_x._prepare_decoder_attention_mask`] and + modify to your needs. 
See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: + `attentions`) `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, + *optional*) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors + of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` + (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` + instead of all ``decoder_input_ids``` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` + have to be input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` + takes the value of `inputs_embeds`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up + decoding (see `past_key_values`). 
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +PEGASUS_X_STANDALONE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`ProphetNetTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for + details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class PegasusXEncoder(PegasusXPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`PegasusXEncoderLayer`]. + + Args: + config: PegasusXConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: PegasusXConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + + self.embed_positions = PegasusXLearnedPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + ) + self.layers = nn.ModuleList([PegasusXEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids=None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`~PegasusXTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] + for details. 
+ + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert `input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + embed_pos = self.embed_positions(input_shape) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 
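+        # Iterate over the encoder layers: during training, LayerDrop below may skip a layer entirely, and when
+        # gradient checkpointing is enabled the layer activations are recomputed in the backward pass to save memory.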
+ + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class PegasusXDecoder(PegasusXPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`PegasusXDecoderLayer`] + + Args: + config: PegasusXConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: PegasusXConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + + self.embed_positions = PegasusXLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + self.layers = nn.ModuleList([PegasusXDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length + ).to(self.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + 
encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`~PegasusXTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] + for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 + tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional + tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential + decoding. + + If `past_key_values` are used, the user can optionally input only the last + `decoder_input_ids` (those that don't have their past key value states given to this model) of + shape `(batch_size, 1)` instead of all ``decoder_input_ids``` of shape `(batch_size, + sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. 
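Editor's note: `_prepare_decoder_attention_mask`, defined earlier in this decoder class, combines a causal mask with the expanded padding mask, and the `past_key_values` described here shift the causal window. The snippet below is an illustrative sketch of the causal part only, with `past_len` standing in for `past_key_values_length`; function name and construction are mine, not taken from the patch.

```python
import torch


def causal_bias(tgt_len: int, past_len: int, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Additive causal bias of shape [tgt_len, past_len + tgt_len]: 0 where a query may
    attend (all cached past positions plus its own and earlier new positions), a large
    negative value elsewhere."""
    full_len = past_len + tgt_len
    bias = torch.full((tgt_len, full_len), torch.finfo(dtype).min, dtype=dtype)
    # query i (offset by past_len) may see key columns 0 .. past_len + i
    allowed = torch.arange(full_len)[None, :] <= (torch.arange(tgt_len)[:, None] + past_len)
    return bias.masked_fill(allowed, 0.0)


print(causal_bias(3, 0))  # lower-triangular pattern: 0 on/below the diagonal, -inf-like above
print(causal_bias(1, 4))  # single cached-decoding step: one row that may see all 5 positions
```

The padding bias (see the earlier sketch) is simply added on top of this causal bias before it reaches the decoder layers.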
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + + # embed positions + positions = self.embed_positions(input_shape, past_key_values_length) + + hidden_states = inputs_embeds + positions + hidden_states = self.layernorm_embedding(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == ( + len(self.layers) + ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning("`use_cache = True` is incompatible with gradient checkpointing`. 
Setting `use_cache = False`...") + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, + ) + else: + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + "The bare BrandNewBERT Model outputting raw hidden-states without any specific head on top.", + PEGASUS_X_START_DOCSTRING, +) +class PegasusXModel(PegasusXPreTrainedModel): + def __init__(self, config: PegasusXConfig): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + + self.encoder = PegasusXEncoder(config, self.shared) + self.decoder = PegasusXDecoder(config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(PEGASUS_X_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Seq2SeqModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if 
output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if not return_dict:
+            return decoder_outputs + encoder_outputs
+
+        return Seq2SeqModelOutput(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    "The PegasusX Model with a language modeling head.
Can be used for summarization.", PEGASUS_X_START_DOCSTRING +) +class PegasusXForConditionalGeneration(PegasusXPreTrainedModel): + base_model_prefix = "model" + _keys_to_ignore_on_load_missing = [ + r"final_logits_bias", + r"encoder\.version", + r"decoder\.version", + r"lm_head\.weight", + ] + + def __init__(self, config: PegasusXConfig): + super().__init__(config) + self.model = PegasusXModel(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens) + self._resize_final_logits_bias(new_num_tokens) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @add_start_docstrings_to_model_forward(PEGASUS_X_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + @add_end_docstrings(PEGASUS_X_GENERATION_EXAMPLE) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Conditional generation example: + + ```python + >>> from transformers import PegasusXTokenizer, PegasusXForConditionalGeneration + >>> tokenizer = PegasusXTokenizer.from_pretrained('pegasus-x-base') + >>> TXT = "My friends are but they eat too many carbs." 
+ + >>> model = PegasusXForConditionalGeneration.from_pretrained('pegasus-x-base') + >>> input_ids = tokenizer([TXT], return_tensors='pt')['input_ids'] + >>> logits = model(input_ids).logits + + >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() + >>> probs = logits[0, masked_index].softmax(dim=0) + >>> values, predictions = probs.topk(5) + + >>> tokenizer.decode(predictions).split() + ``` +""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs + ): + # cut decoder_input_ids if past is used + if past is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + } + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + +@add_start_docstrings( + """ + PegasusX model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE + tasks. 
+ """, + PEGASUS_X_START_DOCSTRING, +) +class PegasusXForSequenceClassification(PegasusXPreTrainedModel): + def __init__(self, config: PegasusXConfig, **kwargs): + super().__init__(config, **kwargs) + self.model = PegasusXModel(config) + self.classification_head = PegasusXClassificationHead( + config.d_model, + config.d_model, + config.num_labels, + config.classifier_dropout, + ) + self.model._init_weights(self.classification_head.dense) + self.model._init_weights(self.classification_head.out_proj) + + @add_start_docstrings_to_model_forward(PEGASUS_X_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Seq2SeqSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + if input_ids is None and inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}" + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] # last hidden state + + eos_mask = input_ids.eq(self.config.eos_token_id) + + if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: + raise ValueError("All examples must have the same number of tokens.") + sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ + :, -1, : + ] + logits = self.classification_head(sentence_representation) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.config.num_labels == 1: + self.config.problem_type = "regression" + elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.config.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqSequenceClassifierOutput( + 
loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@add_start_docstrings( + """ + BrandNewBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + PEGASUS_X_START_DOCSTRING, +) +class PegasusXForQuestionAnswering(PegasusXPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + config.num_labels = 2 + self.num_labels = config.num_labels + + self.model = PegasusXModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + self.model._init_weights(self.qa_outputs) + + @add_start_docstrings_to_model_forward(PEGASUS_X_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Seq2SeqQuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + start_positions=None, + end_positions=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence + are not taken into account for computing the loss. 
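Editor's note: the `start_positions`/`end_positions` labels documented above feed a per-token span head. The toy walk-through below (all sizes are made up) sketches how a single `nn.Linear(hidden, 2)` yields separate start and end logits and how out-of-range label positions are clamped to an ignored index; it is a sketch of the mechanism, not the model's actual forward pass.

```python
import torch
import torch.nn as nn

batch, seq_len, hidden = 2, 6, 16
qa_outputs = nn.Linear(hidden, 2)                 # one score for "start", one for "end", per token
sequence_output = torch.randn(batch, seq_len, hidden)

logits = qa_outputs(sequence_output)              # [2, 6, 2]
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)           # [2, 6]
end_logits = end_logits.squeeze(-1)               # [2, 6]

# labels outside the sequence are clamped to seq_len, which is then used as ignore_index
start_positions = torch.tensor([1, 99]).clamp(0, seq_len)        # tensor([1, 6])
loss = nn.CrossEntropyLoss(ignore_index=seq_len)(start_logits, start_positions)
print(start_logits.shape, end_logits.shape, float(loss))
```

So an impossible label (here `99`) contributes nothing to the loss; only in-range span positions are scored.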
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if start_positions is not None and end_positions is not None: + use_cache = False + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = ( + start_logits, + end_logits, + ) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return Seq2SeqQuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + +class PegasusXDecoderWrapper(PegasusXPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the [`EncoderDecoderModel`] framework. 
+ """ + + def __init__(self, config): + super().__init__(config) + self.decoder = PegasusXDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +class PegasusXForCausalLM(PegasusXPreTrainedModel): + def __init__(self, config): + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + super().__init__(config) + self.model = PegasusXDecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`~PegasusXTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] + for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used + in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+
+        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
+            decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape `(batch_size, 1)`
+            instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
+            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up
+            decoding (see `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+            returned tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+            for more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import PegasusXTokenizer, PegasusXForCausalLM
+
+        >>> tokenizer = PegasusXTokenizer.from_pretrained('pegasus-x-base')
+        >>> model = PegasusXForCausalLM.from_pretrained('pegasus-x-base', add_cross_attention=False)
+        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + ``` +""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + if past: + input_ids = input_ids[:, -1:] + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed + "attention_mask": attention_mask, + "past_key_values": past, + "use_cache": use_cache, + } + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past diff --git a/src/transformers/models/pegasus_x/tokenization_pegasus_x.py b/src/transformers/models/pegasus_x/tokenization_pegasus_x.py new file mode 100644 index 0000000000000..09a7cfdf7bf93 --- /dev/null +++ b/src/transformers/models/pegasus_x/tokenization_pegasus_x.py @@ -0,0 +1,54 @@ +# coding=utf-8 +# Copyright 2022 Google and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
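Editor's note on the `_reorder_cache` staticmethod defined just above: during beam search every cached key/value tensor must be gathered along the batch dimension with the surviving beam indices, otherwise subsequent steps would reuse the wrong beam's cache. A self-contained sketch of that gather, with made-up sizes:

```python
import torch

num_layers, num_heads, seq_len, head_dim = 2, 4, 5, 8
num_beams = 3

# fake cache: one (self-attention key, self-attention value) pair per layer
past = tuple(
    (
        torch.randn(num_beams, num_heads, seq_len, head_dim),
        torch.randn(num_beams, num_heads, seq_len, head_dim),
    )
    for _ in range(num_layers)
)

# beam search decided that new beams 0, 1, 2 continue old beams 2, 0, 0
beam_idx = torch.tensor([2, 0, 0])

reordered = tuple(
    tuple(state.index_select(0, beam_idx) for state in layer_past)
    for layer_past in past
)

assert torch.equal(reordered[0][0][1], past[0][0][0])  # new beam 1 reuses old beam 0's cache
```

`prepare_inputs_for_generation` above plays the complementary role: once a cache exists, it feeds the model only the newest decoder token, since all earlier positions are already represented in `past_key_values`.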
+"""Tokenization classes for BrandNewBERT.""" +from ...utils import logging +from ..bert.tokenization_bert import BertTokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "pegasus-x-base": "https://huggingface.co/pegasus-x-base/resolve/main/vocab.txt", + } +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "pegasus-x-base": 512, +} + + +PRETRAINED_INIT_CONFIGURATION = { + "pegasus-x-base": {"do_lower_case": False}, +} + + +class PegasusXTokenizer(BertTokenizer): + r""" + Construct a BrandNewBERT tokenizer. + + [`~PegasusXTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end + tokenization: punctuation splitting and wordpiece. + + Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning + parameters. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION diff --git a/src/transformers/models/pegasus_x/tokenization_pegasus_x_fast.py b/src/transformers/models/pegasus_x/tokenization_pegasus_x_fast.py new file mode 100644 index 0000000000000..099149ed0f6c5 --- /dev/null +++ b/src/transformers/models/pegasus_x/tokenization_pegasus_x_fast.py @@ -0,0 +1,56 @@ +# coding=utf-8 +# Copyright 2022 Google and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for BrandNewBERT.""" +from ...utils import logging +from ..bert.tokenization_bert_fast import BertTokenizerFast +from .tokenization_pegasus_x import PegasusXTokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "pegasus-x-base": "https://huggingface.co/pegasus-x-base/resolve/main/vocab.txt", + } +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "pegasus-x-base": 512, +} + + +PRETRAINED_INIT_CONFIGURATION = { + "pegasus-x-base": {"do_lower_case": False}, +} + + +class PegasusXTokenizerFast(BertTokenizerFast): + r""" + Construct a "fast" BrandNewBERT tokenizer (backed by HuggingFace's *tokenizers* library). + + [`~PegasusXTokenizerFast`] is identical to [`BertTokenizerFast`] and runs + end-to-end tokenization: punctuation splitting and wordpiece. + + Refer to superclass [`BertTokenizerFast`] for usage examples and documentation concerning + parameters. 
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + slow_tokenizer_class = PegasusXTokenizer diff --git a/tests/models/pegasus_x/__init__.py b/tests/models/pegasus_x/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/pegasus_x/test_modeling_flax_pegasus_x.py b/tests/models/pegasus_x/test_modeling_flax_pegasus_x.py new file mode 100644 index 0000000000000..e187716b56fa0 --- /dev/null +++ b/tests/models/pegasus_x/test_modeling_flax_pegasus_x.py @@ -0,0 +1,346 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +from transformers import ( + is_flax_available, + PegasusXConfig, + PegasusXTokenizer, +) +from transformers.testing_utils import require_sentencepiece, require_flax, require_tokenizers, slow + +from ...test_configuration_common import ConfigTester +from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor + + +if is_flax_available(): + import numpy as np + import jax.numpy as jnp + from transformers import ( + FlaxPegasusXForConditionalGeneration, + FlaxPegasusXForQuestionAnswering, + FlaxPegasusXForSequenceClassification, + FlaxPegasusXModel, + ) + + +@require_flax +class FlaxPegasusXModelTester: + config_cls = PegasusXConfig + config_updates = {} + hidden_act = "gelu" + + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_labels=False, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=20, + eos_token_id=2, + pad_token_id=1, + bos_token_id=0, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + + def prepare_config_and_inputs_for_common(self): + input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size).clip(3, self.vocab_size) + eos_tensor = np.expand_dims(np.array([self.eos_token_id] * self.batch_size), 1) + input_ids = np.concatenate([input_ids, eos_tensor], axis=1) + + decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + config = self.config_cls( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + 
encoder_layers=self.num_hidden_layers, + decoder_layers=self.num_hidden_layers, + encoder_attention_heads=self.num_attention_heads, + decoder_attention_heads=self.num_attention_heads, + encoder_ffn_dim=self.intermediate_size, + decoder_ffn_dim=self.intermediate_size, + dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + eos_token_ids=[2], + bos_token_id=self.bos_token_id, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.pad_token_id, + **self.config_updates, + ) + inputs_dict = prepare_pegasus_x_inputs_dict(config, input_ids, decoder_input_ids) + return config, inputs_dict + + def check_use_cache_forward(self, model_class_name, config, inputs_dict): + max_decoder_length = 20 + model = model_class_name(config) + + encoder_outputs = model.encode(inputs_dict["input_ids"]) + + decoder_input_ids, decoder_attention_mask = ( + inputs_dict["decoder_input_ids"], + inputs_dict["decoder_attention_mask"], + ) + + past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs) + decoder_attention_mask = jnp.ones((decoder_input_ids.shape[0], max_decoder_length), dtype="i4") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(decoder_input_ids.shape[-1] - 1)[None, :], + (decoder_input_ids.shape[0], decoder_input_ids.shape[-1] - 1), + ) + outputs_cache = model.decode( + decoder_input_ids[:, :-1], + encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + past_key_values=past_key_values, + decoder_position_ids=decoder_position_ids, + ) + + decoder_position_ids = jnp.array(decoder_input_ids.shape[0] * [[decoder_input_ids.shape[-1] - 1]], dtype="i4") + outputs_cache_next = model.decode( + decoder_input_ids[:, -1:], + encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + past_key_values=outputs_cache.past_key_values, + decoder_position_ids=decoder_position_ids, + ) + + outputs = model.decode(decoder_input_ids, encoder_outputs) + + diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))) + self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}") + + def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict): + max_decoder_length = 20 + model = model_class_name(config) + + encoder_outputs = model.encode(inputs_dict["input_ids"]) + + decoder_input_ids, decoder_attention_mask = ( + inputs_dict["decoder_input_ids"], + inputs_dict["decoder_attention_mask"], + ) + + decoder_attention_mask_cache = jnp.concatenate( + [ + decoder_attention_mask, + jnp.zeros((decoder_attention_mask.shape[0], max_decoder_length - decoder_attention_mask.shape[1])), + ], + axis=-1, + ) + + past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs) + decoder_position_ids = jnp.broadcast_to( + jnp.arange(decoder_input_ids.shape[-1] - 1)[None, :], + (decoder_input_ids.shape[0], decoder_input_ids.shape[-1] - 1), + ) + + outputs_cache = model.decode( + decoder_input_ids[:, :-1], + encoder_outputs, + decoder_attention_mask=decoder_attention_mask_cache, + past_key_values=past_key_values, + decoder_position_ids=decoder_position_ids, + ) + decoder_position_ids = jnp.array(decoder_input_ids.shape[0] * [[decoder_input_ids.shape[-1] - 1]], dtype="i4") + outputs_cache_next = model.decode( + decoder_input_ids[:, -1:], + encoder_outputs, + past_key_values=outputs_cache.past_key_values, + decoder_attention_mask=decoder_attention_mask_cache, + 
decoder_position_ids=decoder_position_ids, + ) + + outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask) + + diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))) + self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}") + + +def prepare_pegasus_x_inputs_dict( + config, + input_ids, + decoder_input_ids, + attention_mask=None, + decoder_attention_mask=None, +): + if attention_mask is None: + attention_mask = np.not_equal(input_ids, config.pad_token_id).astype(np.int8) + if decoder_attention_mask is None: + decoder_attention_mask = np.concatenate([np.ones(decoder_input_ids[:, :1].shape, dtype=np.int8), np.not_equal(decoder_input_ids[:, 1:], config.pad_token_id).astype(np.int8)], axis=-1) + return { + "input_ids": input_ids, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + } + + +@require_flax +class FlaxPegasusXModelTest(FlaxModelTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + FlaxPegasusXForConditionalGeneration, + FlaxPegasusXForQuestionAnswering, + FlaxPegasusXForSequenceClassification, + FlaxPegasusXModel, + ) if is_flax_available() + else () + ) + all_generative_model_classes = (FlaxPegasusXForConditionalGeneration,) if is_flax_available() else () + is_encoder_decoder = True + test_pruning = False + test_head_masking = False + test_onnx = False + + def setUp(self): + self.model_tester = FlaxPegasusXModelTester(self) + self.config_tester = ConfigTester(self, config_class=PegasusXConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_use_cache_forward(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + self.model_tester.check_use_cache_forward(model_class, config, inputs_dict) + + def test_use_cache_forward_with_attn_mask(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + self.model_tester.check_use_cache_forward_with_attn_mask(model_class, config, inputs_dict) + + +def _assert_tensors_equal(a, b, atol=1e-12, prefix=""): + """If tensors not close, or a and b arent both tensors, raise a nice Assertion error.""" + if a is None and b is None: + return True + try: + if _assert_tensors_equal(a, b, atol=atol): + return True + raise + except Exception: + if len(prefix) > 0: + prefix = f"{prefix}: " + raise AssertionError(f"{prefix}{a} != {b}") + + +def _long_tensor(tok_lst): + return np.array(tok_lst, dtype=np.int32) + + +TOLERANCE = 1e-4 + + +@slow +@require_sentencepiece +@require_tokenizers +@require_flax +class FlaxPegasusXModelIntegrationTest(unittest.TestCase): + def test_inference_no_head(self): + model = FlaxPegasusXModel.from_pretrained('pegasus-x-base') + # change to intended input here + input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) + decoder_input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) + inputs_dict = prepare_pegasus_x_inputs_dict(model.config, input_ids, decoder_input_ids) + output = model(**inputs_dict)[0] + expected_shape = (1, 11, 1024) + self.assertEqual(output.shape, expected_shape) + # change to expected output here + expected_slice = np.array( + [[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]], + ) + _assert_tensors_equal(output[:, :3, :3], expected_slice, atol=TOLERANCE) + + def 
test_inference_with_head(self): + model = FlaxPegasusXForConditionalGeneration.from_pretrained('pegasus-x-base') + # change to intended input here + input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) + decoder_input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) + inputs_dict = prepare_pegasus_x_inputs_dict(model.config, input_ids, decoder_input_ids) + output = model(**inputs_dict)[0] + expected_shape = (1, 11, 1024) + self.assertEqual(output.shape, expected_shape) + # change to expected output here + expected_slice = np.array( + [[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]], + ) + _assert_tensors_equal(output[:, :3, :3], expected_slice, atol=TOLERANCE) + + def test_seq_to_seq_generation(self): + hf = FlaxPegasusXForConditionalGeneration.from_pretrained('pegasus-x-base') + tok = PegasusXTokenizer.from_pretrained('pegasus-x-base') + + batch_input = [ + # string 1, + # string 2, + # string 3, + # string 4, + ] + + # The below article tests that we don't add any hypotheses outside of the top n_beams + dct = tok.batch_encode_plus( + batch_input, + max_length=512, + padding="max_length", + truncation_strategy="only_first", + truncation=True, + return_tensors="np", + ) + + hypotheses_batch = hf.generate( + input_ids=dct["input_ids"], + attention_mask=dct["attention_mask"], + num_beams=2, + ) + + EXPECTED = [ + # here expected 1, + # here expected 2, + # here expected 3, + # here expected 4, + ] + + generated = tok.batch_decode( + hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True + ) + assert generated == EXPECTED diff --git a/tests/models/pegasus_x/test_modeling_pegasus_x.py b/tests/models/pegasus_x/test_modeling_pegasus_x.py new file mode 100644 index 0000000000000..1666db5449e44 --- /dev/null +++ b/tests/models/pegasus_x/test_modeling_pegasus_x.py @@ -0,0 +1,603 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch BrandNewBERT model. 
""" + + +import copy +import tempfile +import unittest + +from transformers import is_torch_available +from transformers.utils import cached_property +from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device + +from ...test_configuration_common import ConfigTester +from ...generation.test_generation_utils import GenerationTesterMixin +from ...test_modeling_common import ModelTesterMixin, ids_tensor + + +if is_torch_available(): + import torch + + from transformers import ( + PegasusXConfig, + PegasusXForConditionalGeneration, + PegasusXForQuestionAnswering, + PegasusXForCausalLM, + PegasusXForSequenceClassification, + PegasusXModel, + PegasusXTokenizer, + ) + from transformers.models.pegasus_x.modeling_pegasus_x import ( + PegasusXDecoder, + PegasusXEncoder, + ) + + +def prepare_pegasus_x_inputs_dict( + config, + input_ids, + decoder_input_ids, + attention_mask=None, + decoder_attention_mask=None, +): + if attention_mask is None: + attention_mask = input_ids.ne(config.pad_token_id) + if decoder_attention_mask is None: + decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) + return { + "input_ids": input_ids, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": attention_mask, + } + + +@require_torch +class PegasusXModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_labels=False, + vocab_size=99, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=4, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=20, + eos_token_id=2, + pad_token_id=1, + bos_token_id=0, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp( + 3, + ) + input_ids[:, -1] = self.eos_token_id # Eos Token + + decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + config = PegasusXConfig( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + encoder_layers=self.num_hidden_layers, + decoder_layers=self.num_hidden_layers, + encoder_attention_heads=self.num_attention_heads, + decoder_attention_heads=self.num_attention_heads, + encoder_ffn_dim=self.intermediate_size, + decoder_ffn_dim=self.intermediate_size, + dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + pad_token_id=self.pad_token_id, + ) + inputs_dict = prepare_pegasus_x_inputs_dict(config, input_ids, decoder_input_ids) + return config, inputs_dict + + def prepare_config_and_inputs_for_common(self): + 
config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict): + model = PegasusXModel(config=config).get_decoder().to(torch_device).eval() + input_ids = inputs_dict["input_ids"] + attention_mask = inputs_dict["attention_mask"] + + # first forward pass + outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + + output, past_key_values = outputs.to_tuple() + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_attn_mask = ids_tensor((self.batch_size, 3), 2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1) + + output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"] + output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2)) + + def check_encoder_decoder_model_standalone(self, config, inputs_dict): + model = PegasusXModel(config=config).to(torch_device).eval() + outputs = model(**inputs_dict) + + encoder_last_hidden_state = outputs.encoder_last_hidden_state + last_hidden_state = outputs.last_hidden_state + + with tempfile.TemporaryDirectory() as tmpdirname: + encoder = model.get_encoder() + encoder.save_pretrained(tmpdirname) + encoder = PegasusXEncoder.from_pretrained(tmpdirname).to(torch_device) + + encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[ + 0 + ] + + self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3) + + with tempfile.TemporaryDirectory() as tmpdirname: + decoder = model.get_decoder() + decoder.save_pretrained(tmpdirname) + decoder = PegasusXDecoder.from_pretrained(tmpdirname).to(torch_device) + + last_hidden_state_2 = decoder( + input_ids=inputs_dict["decoder_input_ids"], + attention_mask=inputs_dict["decoder_attention_mask"], + encoder_hidden_states=encoder_last_hidden_state, + encoder_attention_mask=inputs_dict["attention_mask"], + )[0] + + self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3) + + +@require_torch +class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = ( + (PegasusXModel, PegasusXForConditionalGeneration, PegasusXForSequenceClassification, PegasusXForQuestionAnswering) + if is_torch_available() + else () + ) + all_generative_model_classes = (PegasusXForConditionalGeneration,) if is_torch_available() else () + is_encoder_decoder = True + test_pruning = False + test_head_masking = False + test_missing_keys = False + + def setUp(self): + self.model_tester = PegasusXModelTester(self) + self.config_tester = ConfigTester(self, config_class=PegasusXConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def 
test_save_load_strict(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs() + for model_class in self.all_model_classes: + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) + self.assertEqual(info["missing_keys"], []) + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + def test_encoder_decoder_model_standalone(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() + self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs) + + # PegasusXForSequenceClassification does not support inputs_embeds + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in (PegasusXModel, PegasusXForConditionalGeneration, PegasusXForQuestionAnswering): + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + + if not self.is_encoder_decoder: + input_ids = inputs["input_ids"] + del inputs["input_ids"] + else: + encoder_input_ids = inputs["input_ids"] + decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) + del inputs["input_ids"] + inputs.pop("decoder_input_ids", None) + + wte = model.get_input_embeddings() + if not self.is_encoder_decoder: + inputs["inputs_embeds"] = wte(input_ids) + else: + inputs["inputs_embeds"] = wte(encoder_input_ids) + inputs["decoder_inputs_embeds"] = wte(decoder_input_ids) + + with torch.no_grad(): + model(**inputs)[0] + + def test_generate_fp16(self): + config, input_dict = self.model_tester.prepare_config_and_inputs() + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + model = PegasusXForConditionalGeneration(config).eval().to(torch_device) + if torch_device == "cuda": + model.half() + model.generate(input_ids, attention_mask=attention_mask) + model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3) + + +def assert_tensors_close(a, b, atol=1e-12, prefix=""): + """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error.""" + if a is None and b is None: + return True + try: + if torch.allclose(a, b, atol=atol): + return True + raise + except Exception: + pct_different = (torch.gt((a - b).abs(), atol)).float().mean().item() + if a.numel() > 100: + msg = f"tensor values are {pct_different:.1%} percent different." 
+ else: + msg = f"{a} != {b}" + if prefix: + msg = prefix + ": " + msg + raise AssertionError(msg) + + +def _long_tensor(tok_lst): + return torch.tensor(tok_lst, dtype=torch.long, device=torch_device) + + +TOLERANCE = 1e-4 + + +@require_torch +@require_sentencepiece +@require_tokenizers +@slow +class PegasusXModelIntegrationTests(unittest.TestCase): + @cached_property + def default_tokenizer(self): + return PegasusXTokenizer.from_pretrained('pegasus-x-base') + + def test_inference_no_head(self): + model = PegasusXModel.from_pretrained('pegasus-x-base').to(torch_device) + input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) + decoder_input_ids = _long_tensor([[2, 0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588]]) + inputs_dict = prepare_pegasus_x_inputs_dict(model.config, input_ids, decoder_input_ids) + with torch.no_grad(): + output = model(**inputs_dict)[0] + expected_shape = torch.Size((1, 11, 1024)) + self.assertEqual(output.shape, expected_shape) + # change to expected output here + expected_slice = torch.tensor( + [[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]], device=torch_device + ) + self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE)) + + def test_inference_head(self): + model = PegasusXForConditionalGeneration.from_pretrained('pegasus-x-base').to(torch_device) + + # change to intended input + input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) + decoder_input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) + inputs_dict = prepare_pegasus_x_inputs_dict(model.config, input_ids, decoder_input_ids) + with torch.no_grad(): + output = model(**inputs_dict)[0] + expected_shape = torch.Size((1, 11, model.config.vocab_size)) + self.assertEqual(output.shape, expected_shape) + # change to expected output here + expected_slice = torch.tensor( + [[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]], device=torch_device + ) + self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE)) + + def test_seq_to_seq_generation(self): + hf = PegasusXForConditionalGeneration.from_pretrained('pegasus-x-base').to(torch_device) + tok = PegasusXTokenizer.from_pretrained('pegasus-x-base') + + batch_input = [ + # string 1, + # string 2, + # string 3, + # string 4, + ] + + # The below article tests that we don't add any hypotheses outside of the top n_beams + dct = tok.batch_encode_plus( + batch_input, + max_length=512, + padding="max_length", + truncation_strategy="only_first", + truncation=True, + return_tensors="pt", + ) + + hypotheses_batch = hf.generate( + input_ids=dct["input_ids"].to(torch_device), + attention_mask=dct["attention_mask"].to(torch_device), + num_beams=2, + ) + + EXPECTED = [ + # here expected 1, + # here expected 2, + # here expected 3, + # here expected 4, + ] + + generated = tok.batch_decode( + hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True + ) + assert generated == EXPECTED + + +class PegasusXStandaloneDecoderModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + d_model=16, + decoder_seq_length=7, + is_training=True, + is_decoder=True, + use_attention_mask=True, + use_cache=False, + use_labels=True, + decoder_start_token_id=2, + decoder_ffn_dim=32, + decoder_layers=4, + encoder_attention_heads=4, + decoder_attention_heads=4, + max_position_embeddings=30, + is_encoder_decoder=False, + 
pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.decoder_seq_length = decoder_seq_length + # For common tests + self.seq_length = self.decoder_seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + + self.vocab_size = vocab_size + self.d_model = d_model + self.hidden_size = d_model + self.num_hidden_layers = decoder_layers + self.decoder_layers = decoder_layers + self.decoder_ffn_dim = decoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.decoder_attention_heads = decoder_attention_heads + self.num_attention_heads = decoder_attention_heads + self.eos_token_id = eos_token_id + self.bos_token_id = bos_token_id + self.pad_token_id = pad_token_id + self.decoder_start_token_id = decoder_start_token_id + self.use_cache = use_cache + self.max_position_embeddings = max_position_embeddings + self.is_encoder_decoder = is_encoder_decoder + + self.scope = None + self.decoder_key_length = decoder_seq_length + self.base_model_out_len = 2 + self.decoder_attention_idx = 1 + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) + + lm_labels = None + if self.use_labels: + lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + config = PegasusXConfig( + vocab_size=self.vocab_size, + d_model=self.d_model, + decoder_layers=self.decoder_layers, + decoder_ffn_dim=self.decoder_ffn_dim, + encoder_attention_heads=self.encoder_attention_heads, + decoder_attention_heads=self.decoder_attention_heads, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + use_cache=self.use_cache, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + max_position_embeddings=self.max_position_embeddings, + is_encoder_decoder=self.is_encoder_decoder, + ) + + return ( + config, + input_ids, + attention_mask, + lm_labels, + ) + + def create_and_check_decoder_model_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + config.use_cache = True + model = PegasusXDecoder(config=config).to(torch_device).eval() + # first forward pass + outputs = model(input_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + past_key_values = outputs["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + + output_from_no_past = model(next_input_ids)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3) + + 
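
Editor's note: the past-key-values check above follows a standard cached-vs-uncached pattern — run a prefix once with `use_cache=True`, decode the next token on top of the returned cache, and compare against a full uncached forward pass. A minimal self-contained sketch of that pattern is below; it is illustrative only (any decoder that returns `past_key_values`, such as the `PegasusXDecoder` used in this tester, would do), and the helper name is not part of this PR.

```python
import torch


def check_cached_decoding_matches_full_forward(model, input_ids, atol=1e-3):
    # Sketch only: `model` is assumed to be a decoder-style module that returns
    # `past_key_values` when called with use_cache=True.
    model.eval()
    with torch.no_grad():
        # run everything except the last token once, keeping the cache
        prefix = model(input_ids[:, :-1], use_cache=True)
        # decode just the final token on top of the cached keys/values
        step = model(input_ids[:, -1:], past_key_values=prefix.past_key_values)
        # reference: one uncached pass over the full sequence
        full = model(input_ids)
    # the cached single step must reproduce the last position of the full pass
    return torch.allclose(
        step.last_hidden_state[:, -1], full.last_hidden_state[:, -1], atol=atol
    )
```

The tests above additionally slice a random hidden dimension before comparing, which keeps the tolerance check cheap while still catching cache-handling bugs.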
def create_and_check_decoder_model_attention_mask_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + model = PegasusXDecoder(config=config).to(torch_device).eval() + + # create attention mask + attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + half_seq_length = input_ids.shape[-1] // 2 + attn_mask[:, half_seq_length:] = 0 + + # first forward pass + past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) + input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens + + # append to next input_ids and attn_mask + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], + dim=1, + ) + + # get two different outputs + output_from_no_past = model(next_input_ids)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + lm_labels, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class PegasusXStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (PegasusXDecoder, PegasusXForCausalLM) if is_torch_available() else () + all_generative_model_classes = (PegasusXForCausalLM,) if is_torch_available() else () + test_pruning = False + is_encoder_decoder = False + + def setUp( + self, + ): + self.model_tester = PegasusXStandaloneDecoderModelTester(self, is_training=False) + self.config_tester = ConfigTester(self, config_class=PegasusXConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_decoder_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) + + def test_decoder_model_attn_mask_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + + def test_retain_grad_hidden_states_attentions(self): + # decoder cannot keep gradients + return diff --git a/utils/check_repo.py b/utils/check_repo.py index 254467113d6cb..acea3eefb2d96 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -46,6 +46,9 @@ # Being in this list is an exception and should **not** be the rule. IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [ # models to ignore for not tested +"PegasusXEncoder", # Building part of bigger (tested) model. 
+ "PegasusXDecoder", # Building part of bigger (tested) model. + "PegasusXDecoderWrapper", # Building part of bigger (tested) model. "OPTDecoder", # Building part of bigger (tested) model. "DecisionTransformerGPT2Model", # Building part of bigger (tested) model. "SegformerDecodeHead", # Building part of bigger (tested) model. @@ -125,6 +128,9 @@ # should **not** be the rule. IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [ # models to ignore for model xxx mapping +"PegasusXEncoder", + "PegasusXDecoder", + "PegasusXDecoderWrapper", "DPTForDepthEstimation", "DecisionTransformerGPT2Model", "GLPNForDepthEstimation", From 5e54b096f0f0a98b71fd53464346a2ea33d24a5b Mon Sep 17 00:00:00 2001 From: Jason Phang Date: Sat, 23 Jul 2022 14:47:49 -0700 Subject: [PATCH 02/25] rename --- docs/source/en/model_doc/pegasus_x.mdx | 4 +- src/transformers/__init__.py | 44 ++++ src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 3 + src/transformers/models/auto/modeling_auto.py | 8 + .../models/auto/modeling_flax_auto.py | 9 + .../pegasus_x/configuration_pegasus_x.py | 12 +- .../pegasus_x/modeling_flax_pegasus_x.py | 2 +- .../models/pegasus_x/modeling_pegasus_x.py | 10 +- .../pegasus_x/tokenization_pegasus_x.py | 227 ++++++++++++++++-- .../pegasus_x/tokenization_pegasus_x_fast.py | 97 ++++++-- .../pegasus_x/test_modeling_pegasus_x.py | 2 +- utils/check_repo.py | 6 + 13 files changed, 375 insertions(+), 50 deletions(-) diff --git a/docs/source/en/model_doc/pegasus_x.mdx b/docs/source/en/model_doc/pegasus_x.mdx index 853e3b1ae15c7..73969d3d13d45 100644 --- a/docs/source/en/model_doc/pegasus_x.mdx +++ b/docs/source/en/model_doc/pegasus_x.mdx @@ -10,11 +10,11 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# BrandNewBERT +# PEGASUSX ## Overview -The BrandNewBERT model was proposed in []() by . +The PEGASUSX model was proposed in []() by . 
The abstract from the paper is the following: diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 3c766301f1e59..c4c2dffac463b 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -120,6 +120,7 @@ "models": [], # Models "models.pegasus_x": ["PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusXConfig", "PegasusXTokenizer"], + "models.pegasus_x": ["PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusXConfig", "PegasusXTokenizer"], "models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"], "models.auto": [ "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", @@ -519,6 +520,7 @@ else: # Fast tokenizers structure _import_structure["models.pegasus_x"].append("PegasusXTokenizerFast") + _import_structure["models.pegasus_x"].append("PegasusXTokenizerFast") _import_structure["models.albert"].append("AlbertTokenizerFast") _import_structure["models.bart"].append("BartTokenizerFast") _import_structure["models.barthez"].append("BarthezTokenizerFast") @@ -769,6 +771,18 @@ # PyTorch models structure + _import_structure["models.pegasus_x"].extend( + [ + "PEGASUS_X_PRETRAINED_MODEL_ARCHIVE_LIST", + "PegasusXForCausalLM", + "PegasusXForConditionalGeneration", + "PegasusXForQuestionAnswering", + "PegasusXForSequenceClassification", + "PegasusXModel", + "PegasusXPreTrainedModel", + ] + ) + _import_structure["models.pegasus_x"].extend( [ "PEGASUS_X_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -2709,6 +2723,16 @@ ] ) + _import_structure["models.pegasus_x"].extend( + [ + "FlaxPegasusXForConditionalGeneration", + "FlaxPegasusXForQuestionAnswering", + "FlaxPegasusXForSequenceClassification", + "FlaxPegasusXModel", + "FlaxPegasusXPreTrainedModel", + ] + ) + _import_structure["models.bart"].extend( [ "FlaxBartDecoderPreTrainedModel", @@ -2962,6 +2986,7 @@ ) from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig from .models.pegasus_x import PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusXConfig, PegasusXTokenizer + from .models.pegasus_x import PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusXConfig, PegasusXTokenizer from .models.auto import ( ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, @@ -3321,6 +3346,7 @@ else: # Fast tokenizers imports from .models.pegasus_x import PegasusXTokenizerFast + from .models.pegasus_x import PegasusXTokenizerFast from .models.albert import AlbertTokenizerFast from .models.bart import BartTokenizerFast from .models.barthez import BarthezTokenizerFast @@ -3525,6 +3551,16 @@ # PyTorch model imports + from .models.pegasus_x import ( + PEGASUS_X_PRETRAINED_MODEL_ARCHIVE_LIST, + PegasusXForConditionalGeneration, + PegasusXForCausalLM, + PegasusXForQuestionAnswering, + PegasusXForSequenceClassification, + PegasusXModel, + PegasusXPreTrainedModel, + ) + from .models.pegasus_x import ( PEGASUS_X_PRETRAINED_MODEL_ARCHIVE_LIST, PegasusXForConditionalGeneration, @@ -5065,6 +5101,14 @@ # Flax model imports + from .models.pegasus_x import ( + FlaxPegasusXForConditionalGeneration, + FlaxPegasusXForQuestionAnswering, + FlaxPegasusXForSequenceClassification, + FlaxPegasusXModel, + FlaxPegasusXPreTrainedModel, + ) + from .models.pegasus_x import ( FlaxPegasusXForConditionalGeneration, FlaxPegasusXForQuestionAnswering, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 05a62c38a44c5..0598b2ed828d5 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -17,6 +17,7 @@ # limitations under the License. from . 
import ( + pegasus_x, pegasus_x, albert, auto, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 7c79442a75cad..f7173ac7cb515 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -30,6 +30,7 @@ [ # Add configs here ("pegasus_x", "PegasusXConfig"), + ("pegasus_x", "PegasusXConfig"), ("albert", "AlbertConfig"), ("bart", "BartConfig"), ("beit", "BeitConfig"), @@ -159,6 +160,7 @@ [ # Add archive maps here) ("pegasus_x", "PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("pegasus_x", "PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("albert", "ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -273,6 +275,7 @@ [ # Add full (and cased) model names here ("pegasus_x", "PegasusX"), + ("pegasus_x", "PegasusX"), ("albert", "ALBERT"), ("bart", "BART"), ("barthez", "BARThez"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index f30195f8b8694..87ccd40ebdbc8 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -29,6 +29,7 @@ [ # Base model mapping ("pegasus_x", "PegasusXModel"), + ("pegasus_x", "PegasusXModel"), ("albert", "AlbertModel"), ("bart", "BartModel"), ("beit", "BeitModel"), @@ -206,6 +207,8 @@ [ # Model with LM heads mapping + ("pegasus_x", "PegasusXForConditionalGeneration"), + ("pegasus_x", "PegasusXForConditionalGeneration"), ("albert", "AlbertForMaskedLM"), ("bart", "BartForConditionalGeneration"), @@ -272,6 +275,7 @@ [ # Model for Causal LM mapping ("pegasus_x", "PegasusXForCausalLM"), + ("pegasus_x", "PegasusXForCausalLM"), ("bart", "BartForCausalLM"), ("bert", "BertLMHeadModel"), ("bert-generation", "BertGenerationDecoder"), @@ -456,6 +460,8 @@ [ # Model for Seq2Seq Causal LM mapping + ("pegasus_x", "PegasusXForConditionalGeneration"), + ("pegasus_x", "PegasusXForConditionalGeneration"), ("bart", "BartForConditionalGeneration"), ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), @@ -490,6 +496,7 @@ [ # Model for Sequence Classification mapping ("pegasus_x", "PegasusXForSequenceClassification"), + ("pegasus_x", "PegasusXForSequenceClassification"), ("albert", "AlbertForSequenceClassification"), ("bart", "BartForSequenceClassification"), ("bert", "BertForSequenceClassification"), @@ -549,6 +556,7 @@ [ # Model for Question Answering mapping ("pegasus_x", "PegasusXForQuestionAnswering"), + ("pegasus_x", "PegasusXForQuestionAnswering"), ("albert", "AlbertForQuestionAnswering"), ("bart", "BartForQuestionAnswering"), ("bert", "BertForQuestionAnswering"), diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 624147a329aab..9db6b59b90b98 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -29,6 +29,7 @@ [ # Base model mapping ("pegasus_x", "FlaxPegasusXModel"), + ("pegasus_x", "FlaxPegasusXModel"), ("albert", "FlaxAlbertModel"), ("bart", "FlaxBartModel"), ("beit", "FlaxBeitModel"), @@ -82,6 +83,8 @@ [ # Model for Masked LM mapping + ("pegasus_x", "FlaxPegasusXForConditionalGeneration"), + ("pegasus_x", "FlaxPegasusXForConditionalGeneration"), ("albert", "FlaxAlbertForMaskedLM"), ("bart", "FlaxBartForConditionalGeneration"), @@ -100,6 +103,8 @@ [ # Model for Seq2Seq Causal LM mapping + ("pegasus_x", 
"FlaxPegasusXForConditionalGeneration"), + ("pegasus_x", "FlaxPegasusXForConditionalGeneration"), ("bart", "FlaxBartForConditionalGeneration"), ("blenderbot", "FlaxBlenderbotForConditionalGeneration"), @@ -148,6 +153,8 @@ [ # Model for Sequence Classification mapping + ("pegasus_x", "FlaxPegasusXForSequenceClassification"), + ("pegasus_x", "FlaxPegasusXForSequenceClassification"), ("albert", "FlaxAlbertForSequenceClassification"), ("bart", "FlaxBartForSequenceClassification"), @@ -166,6 +173,8 @@ [ # Model for Question Answering mapping + ("pegasus_x", "FlaxPegasusXForQuestionAnswering"), + ("pegasus_x", "FlaxPegasusXForQuestionAnswering"), ("albert", "FlaxAlbertForQuestionAnswering"), ("bart", "FlaxBartForQuestionAnswering"), diff --git a/src/transformers/models/pegasus_x/configuration_pegasus_x.py b/src/transformers/models/pegasus_x/configuration_pegasus_x.py index b6554127389b7..6087789db5f0d 100644 --- a/src/transformers/models/pegasus_x/configuration_pegasus_x.py +++ b/src/transformers/models/pegasus_x/configuration_pegasus_x.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" BrandNewBERT model configuration """ +""" PEGASUSX model configuration """ from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -22,16 +22,16 @@ PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP = { "pegasus-x-base": "https://huggingface.co/pegasus-x-base/resolve/main/config.json", - # See all BrandNewBERT models at https://huggingface.co/models?filter=pegasus_x + # See all PEGASUSX models at https://huggingface.co/models?filter=pegasus_x } class PegasusXConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`~PegasusXModel`]. - It is used to instantiate an BrandNewBERT model according to the specified arguments, defining the model + It is used to instantiate an PEGASUSX model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of - the BrandNewBERT [pegasus-x-base](https://huggingface.co/pegasus-x-base) architecture. + the PEGASUSX [pegasus-x-base](https://huggingface.co/pegasus-x-base) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] @@ -40,7 +40,7 @@ class PegasusXConfig(PretrainedConfig): Args: vocab_size (`int`, *optional*, defaults to 50265): - Vocabulary size of the BrandNewBERT model. Defines the number of different tokens that can be represented by the + Vocabulary size of the PEGASUSX model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`~PegasusXModel`] or [`~TFPegasusXModel`]. 
d_model (`int`, *optional*, defaults to 1024): @@ -86,7 +86,7 @@ class PegasusXConfig(PretrainedConfig): ```python >>> from transformers import PegasusXModel, PegasusXConfig - >>> # Initializing a BrandNewBERT pegasus-x-base style configuration + >>> # Initializing a PEGASUSX pegasus-x-base style configuration >>> configuration = PegasusXConfig() >>> # Initializing a model from the pegasus-x-base style configuration diff --git a/src/transformers/models/pegasus_x/modeling_flax_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_flax_pegasus_x.py index 048fd382a9c73..f632801675618 100644 --- a/src/transformers/models/pegasus_x/modeling_flax_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_flax_pegasus_x.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Flax BrandNewBERT model. """ +""" Flax PEGASUSX model. """ import math diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 8d3d8b98b56d7..a78f4e3200e4d 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch BrandNewBERT model. """ +""" PyTorch PEGASUSX model. """ import math @@ -55,7 +55,7 @@ PEGASUS_X_PRETRAINED_MODEL_ARCHIVE_LIST = [ "pegasus-x-base", - # See all BrandNewBERT models at https://huggingface.co/models?filter=pegasus_x + # See all PEGASUSX models at https://huggingface.co/models?filter=pegasus_x ] @@ -1069,7 +1069,7 @@ def custom_forward(*inputs): @add_start_docstrings( - "The bare BrandNewBERT Model outputting raw hidden-states without any specific head on top.", + "The bare PEGASUSX Model outputting raw hidden-states without any specific head on top.", PEGASUS_X_START_DOCSTRING, ) class PegasusXModel(PegasusXPreTrainedModel): @@ -1181,7 +1181,7 @@ def forward( @add_start_docstrings( - "The BrandNewBERT Model with a language modeling head. Can be used for summarization.", PEGASUS_X_START_DOCSTRING + "The PEGASUSX Model with a language modeling head. Can be used for summarization.", PEGASUS_X_START_DOCSTRING ) class PegasusXForConditionalGeneration(PegasusXPreTrainedModel): base_model_prefix = "model" @@ -1479,7 +1479,7 @@ def forward( @add_start_docstrings( """ - BrandNewBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + PEGASUSX Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). """, PEGASUS_X_START_DOCSTRING, diff --git a/src/transformers/models/pegasus_x/tokenization_pegasus_x.py b/src/transformers/models/pegasus_x/tokenization_pegasus_x.py index 09a7cfdf7bf93..296cd6cf83cd9 100644 --- a/src/transformers/models/pegasus_x/tokenization_pegasus_x.py +++ b/src/transformers/models/pegasus_x/tokenization_pegasus_x.py @@ -12,9 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Tokenization classes for BrandNewBERT.""" +"""Tokenization classes for PEGASUSX.""" +from typing import List, Optional + +from tokenizers import ByteLevelBPETokenizer + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...utils import logging -from ..bert.tokenization_bert import BertTokenizer logger = logging.get_logger(__name__) @@ -24,31 +29,223 @@ PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { "pegasus-x-base": "https://huggingface.co/pegasus-x-base/resolve/main/vocab.txt", - } + }, } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "pegasus-x-base": 512, + "pegasus-x-base": 1024, } +class PegasusXTokenizer(PreTrainedTokenizer): + """ + Construct a PEGASUSX tokenizer. Based on byte-level Byte-Pair-Encoding. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + """ -PRETRAINED_INIT_CONFIGURATION = { - "pegasus-x-base": {"do_lower_case": False}, -} + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + unk_token="<|endoftext|>", + bos_token="<|endoftext|>", + eos_token="<|endoftext|>", + **kwargs + ): + bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token + super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs) + + """ Initialisation """ + + @property + def vocab_size(self): + """ Returns vocab size """ + + def get_vocab(self): + """ Returns vocab as a dict """ + + def _tokenize(self, text): + """ Returns a tokenized string. """ + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + + def convert_tokens_to_string(self, tokens): + """ Converts a sequence of tokens (string) in a single string. """ + + def save_vocabulary(self, save_directory): + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. + A PEGASUSX sequence has the following format: + - single sequence: ` X ` + - pair of sequences: ` A B ` -class PegasusXTokenizer(BertTokenizer): - r""" - Construct a BrandNewBERT tokenizer. + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. - [`~PegasusXTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end - tokenization: punctuation splitting and wordpiece. + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 
+ """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + sep + token_ids_1 + sep - Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning - parameters. + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. + PEGASUSX does not make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): + add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space) + if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()): + text = " " + text + return (text, kwargs) + +class PegasusXTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" PEGASUSX tokenizer (backed by HuggingFace's *tokenizers* library). + + Args: + vocab_file (`str`): + Path to the vocabulary file. 
""" vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + merges_file, + unk_token="<|endoftext|>", + bos_token="<|endoftext|>", + eos_token="<|endoftext|>", + add_prefix_space=False, + trim_offsets=True, + **kwargs + ): + super().__init__( + ByteLevelBPETokenizer( + vocab_file=vocab_file, + merges_file=merges_file, + add_prefix_space=add_prefix_space, + trim_offsets=trim_offsets, + ), + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + **kwargs, + ) + self.add_prefix_space = add_prefix_space + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + if token_ids_1 is None: + return output + + return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id] + + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. + PEGASUSX does not make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + diff --git a/src/transformers/models/pegasus_x/tokenization_pegasus_x_fast.py b/src/transformers/models/pegasus_x/tokenization_pegasus_x_fast.py index 099149ed0f6c5..b84388b75d8df 100644 --- a/src/transformers/models/pegasus_x/tokenization_pegasus_x_fast.py +++ b/src/transformers/models/pegasus_x/tokenization_pegasus_x_fast.py @@ -12,45 +12,102 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tokenization classes for BrandNewBERT.""" +"""Tokenization classes for PEGASUSX.""" +from typing import List, Optional + +from tokenizers import ByteLevelBPETokenizer + +from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...utils import logging -from ..bert.tokenization_bert_fast import BertTokenizerFast from .tokenization_pegasus_x import PegasusXTokenizer logger = logging.get_logger(__name__) -VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"} PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { "pegasus-x-base": "https://huggingface.co/pegasus-x-base/resolve/main/vocab.txt", - } + }, + "tokenizer_file": { + "pegasus-x-base": "https://huggingface.co/pegasus-x-base/resolve/main/tokenizer.json", + }, } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "pegasus-x-base": 512, -} - - -PRETRAINED_INIT_CONFIGURATION = { - "pegasus-x-base": {"do_lower_case": False}, + "pegasus-x-base": 1024, } +class PegasusXTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" PEGASUSX tokenizer (backed by HuggingFace's *tokenizers* library). 
-class PegasusXTokenizerFast(BertTokenizerFast): - r""" - Construct a "fast" BrandNewBERT tokenizer (backed by HuggingFace's *tokenizers* library). - - [`~PegasusXTokenizerFast`] is identical to [`BertTokenizerFast`] and runs - end-to-end tokenization: punctuation splitting and wordpiece. - - Refer to superclass [`BertTokenizerFast`] for usage examples and documentation concerning - parameters. + Args: + vocab_file (`str`): + Path to the vocabulary file. """ vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION slow_tokenizer_class = PegasusXTokenizer + + def __init__( + self, + vocab_file, + merges_file, + unk_token="<|endoftext|>", + bos_token="<|endoftext|>", + eos_token="<|endoftext|>", + add_prefix_space=False, + trim_offsets=True, + **kwargs + ): + super().__init__( + ByteLevelBPETokenizer( + vocab_file=vocab_file, + merges_file=merges_file, + add_prefix_space=add_prefix_space, + trim_offsets=trim_offsets, + ), + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + **kwargs, + ) + self.add_prefix_space = add_prefix_space + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + if token_ids_1 is None: + return output + + return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id] + + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. + PEGASUSX does not make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + + diff --git a/tests/models/pegasus_x/test_modeling_pegasus_x.py b/tests/models/pegasus_x/test_modeling_pegasus_x.py index 1666db5449e44..6cf276679c4e7 100644 --- a/tests/models/pegasus_x/test_modeling_pegasus_x.py +++ b/tests/models/pegasus_x/test_modeling_pegasus_x.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Testing suite for the PyTorch BrandNewBERT model. """ +""" Testing suite for the PyTorch PEGASUSX model. """ import copy diff --git a/utils/check_repo.py b/utils/check_repo.py index acea3eefb2d96..92d6cf18edd50 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -46,6 +46,9 @@ # Being in this list is an exception and should **not** be the rule. IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [ # models to ignore for not tested +"PegasusXEncoder", # Building part of bigger (tested) model. + "PegasusXDecoder", # Building part of bigger (tested) model. + "PegasusXDecoderWrapper", # Building part of bigger (tested) model. "PegasusXEncoder", # Building part of bigger (tested) model. "PegasusXDecoder", # Building part of bigger (tested) model. "PegasusXDecoderWrapper", # Building part of bigger (tested) model. 
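
Editor's note: the slow and fast tokenizer templates added in this commit share two helpers — `build_inputs_with_special_tokens`, which wraps sequences in BOS/EOS, and `create_token_type_ids_from_sequences`, which always returns zeros because PEGASUS-X does not use token type ids. The sketch below illustrates those two behaviors with made-up token ids; the real special tokens will come from the final PEGASUS-X tokenizer, not from this GPT-2-style template.

```python
# Illustration only: hypothetical ids standing in for bos/eos/cls/sep.
BOS, EOS, CLS, SEP = 0, 1, 0, 1


def build_inputs_with_special_tokens(ids_a, ids_b=None):
    # single sequence: <bos> A <eos>; pair: <bos> A <eos> <eos> B <eos>
    out = [BOS] + ids_a + [EOS]
    if ids_b is None:
        return out
    return out + [EOS] + ids_b + [EOS]


def create_token_type_ids_from_sequences(ids_a, ids_b=None):
    # no token type ids: the mask is all zeros over the full special-token layout
    if ids_b is None:
        return [0] * len([CLS] + ids_a + [SEP])
    return [0] * len([CLS] + ids_a + [SEP, SEP] + ids_b + [SEP])


assert build_inputs_with_special_tokens([5, 6]) == [0, 5, 6, 1]
assert build_inputs_with_special_tokens([5, 6], [7]) == [0, 5, 6, 1, 1, 7, 1]
assert create_token_type_ids_from_sequences([5, 6], [7]) == [0] * 7
```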
@@ -128,6 +131,9 @@ # should **not** be the rule. IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [ # models to ignore for model xxx mapping +"PegasusXEncoder", + "PegasusXDecoder", + "PegasusXDecoderWrapper", "PegasusXEncoder", "PegasusXDecoder", "PegasusXDecoderWrapper", From 417688a945b65ead20309e584f8312e875ae86de Mon Sep 17 00:00:00 2001 From: Jason Phang Date: Sun, 24 Jul 2022 20:40:10 -0700 Subject: [PATCH 03/25] pegasus X implementation --- src/transformers/__init__.py | 52 +- .../pegasus_x/configuration_pegasus_x.py | 82 +- .../models/pegasus_x/modeling_pegasus_x.py | 1447 ++++++++--------- .../pegasus_x/tokenization_pegasus_x.py | 416 ++--- .../pegasus_x/tokenization_pegasus_x_fast.py | 205 ++- 5 files changed, 1115 insertions(+), 1087 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index c4c2dffac463b..30ab238bcdb61 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -120,7 +120,6 @@ "models": [], # Models "models.pegasus_x": ["PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusXConfig", "PegasusXTokenizer"], - "models.pegasus_x": ["PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusXConfig", "PegasusXTokenizer"], "models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"], "models.auto": [ "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", @@ -497,6 +496,7 @@ _import_structure["models.mluke"].append("MLukeTokenizer") _import_structure["models.mt5"].append("MT5Tokenizer") _import_structure["models.pegasus"].append("PegasusTokenizer") + _import_structure["models.pegasus_x"].append("PegasusXTokenizer") _import_structure["models.plbart"].append("PLBartTokenizer") _import_structure["models.reformer"].append("ReformerTokenizer") _import_structure["models.rembert"].append("RemBertTokenizer") @@ -520,7 +520,6 @@ else: # Fast tokenizers structure _import_structure["models.pegasus_x"].append("PegasusXTokenizerFast") - _import_structure["models.pegasus_x"].append("PegasusXTokenizerFast") _import_structure["models.albert"].append("AlbertTokenizerFast") _import_structure["models.bart"].append("BartTokenizerFast") _import_structure["models.barthez"].append("BarthezTokenizerFast") @@ -774,22 +773,7 @@ _import_structure["models.pegasus_x"].extend( [ "PEGASUS_X_PRETRAINED_MODEL_ARCHIVE_LIST", - "PegasusXForCausalLM", "PegasusXForConditionalGeneration", - "PegasusXForQuestionAnswering", - "PegasusXForSequenceClassification", - "PegasusXModel", - "PegasusXPreTrainedModel", - ] - ) - - _import_structure["models.pegasus_x"].extend( - [ - "PEGASUS_X_PRETRAINED_MODEL_ARCHIVE_LIST", - "PegasusXForCausalLM", - "PegasusXForConditionalGeneration", - "PegasusXForQuestionAnswering", - "PegasusXForSequenceClassification", "PegasusXModel", "PegasusXPreTrainedModel", ] @@ -2716,18 +2700,6 @@ _import_structure["models.pegasus_x"].extend( [ "FlaxPegasusXForConditionalGeneration", - "FlaxPegasusXForQuestionAnswering", - "FlaxPegasusXForSequenceClassification", - "FlaxPegasusXModel", - "FlaxPegasusXPreTrainedModel", - ] - ) - - _import_structure["models.pegasus_x"].extend( - [ - "FlaxPegasusXForConditionalGeneration", - "FlaxPegasusXForQuestionAnswering", - "FlaxPegasusXForSequenceClassification", "FlaxPegasusXModel", "FlaxPegasusXPreTrainedModel", ] @@ -2986,7 +2958,6 @@ ) from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig from .models.pegasus_x import PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusXConfig, PegasusXTokenizer - from .models.pegasus_x import PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusXConfig, 
PegasusXTokenizer from .models.auto import ( ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, @@ -3346,7 +3317,6 @@ else: # Fast tokenizers imports from .models.pegasus_x import PegasusXTokenizerFast - from .models.pegasus_x import PegasusXTokenizerFast from .models.albert import AlbertTokenizerFast from .models.bart import BartTokenizerFast from .models.barthez import BarthezTokenizerFast @@ -3555,21 +3525,10 @@ PEGASUS_X_PRETRAINED_MODEL_ARCHIVE_LIST, PegasusXForConditionalGeneration, PegasusXForCausalLM, - PegasusXForQuestionAnswering, - PegasusXForSequenceClassification, PegasusXModel, PegasusXPreTrainedModel, ) - from .models.pegasus_x import ( - PEGASUS_X_PRETRAINED_MODEL_ARCHIVE_LIST, - PegasusXForConditionalGeneration, - PegasusXForCausalLM, - PegasusXForQuestionAnswering, - PegasusXForSequenceClassification, - PegasusXModel, - PegasusXPreTrainedModel, - ) from .models.albert import ( ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST, AlbertForMaskedLM, @@ -5103,19 +5062,10 @@ from .models.pegasus_x import ( FlaxPegasusXForConditionalGeneration, - FlaxPegasusXForQuestionAnswering, - FlaxPegasusXForSequenceClassification, FlaxPegasusXModel, FlaxPegasusXPreTrainedModel, ) - from .models.pegasus_x import ( - FlaxPegasusXForConditionalGeneration, - FlaxPegasusXForQuestionAnswering, - FlaxPegasusXForSequenceClassification, - FlaxPegasusXModel, - FlaxPegasusXPreTrainedModel, - ) from .models.albert import ( FlaxAlbertForMaskedLM, FlaxAlbertForMultipleChoice, diff --git a/src/transformers/models/pegasus_x/configuration_pegasus_x.py b/src/transformers/models/pegasus_x/configuration_pegasus_x.py index 6087789db5f0d..e40fa2c25f808 100644 --- a/src/transformers/models/pegasus_x/configuration_pegasus_x.py +++ b/src/transformers/models/pegasus_x/configuration_pegasus_x.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 Google and The HuggingFace Inc. team. All rights reserved. +# Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PEGASUSX model configuration """ +""" PEGASUS-X model configuration""" from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -21,28 +21,26 @@ logger = logging.get_logger(__name__) PEGASUS_X_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "pegasus-x-base": "https://huggingface.co/pegasus-x-base/resolve/main/config.json", - # See all PEGASUSX models at https://huggingface.co/models?filter=pegasus_x + "google/pegasus-x-base": "https://huggingface.co/google/pegasus-x-base/resolve/main/config.json", + # See all PEGASUS-X models at https://huggingface.co/models?filter=pegasus-x } class PegasusXConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`~PegasusXModel`]. - It is used to instantiate an PEGASUSX model according to the specified arguments, defining the model - architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of - the PEGASUSX [pegasus-x-base](https://huggingface.co/pegasus-x-base) architecture. + This is the configuration class to store the configuration of a [`PegasusXModel`]. It is used to instantiate an + PEGASUS model according to the specified arguments, defining the model architecture. 
Instantiating a configuration + with the defaults will yield a similar configuration to that of the PEGASUS-X + [google/pegasus-large](https://huggingface.co/google/pegasus-large) architecture. - Configuration objects inherit from [`PretrainedConfig`] and can be used - to control the model outputs. Read the documentation from [`PretrainedConfig`] - for more information. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. Args: vocab_size (`int`, *optional*, defaults to 50265): - Vocabulary size of the PEGASUSX model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`~PegasusXModel`] or - [`~TFPegasusXModel`]. + Vocabulary size of the PEGASUS-X model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`PegasusXModel`]. d_model (`int`, *optional*, defaults to 1024): Dimension of the layers and the pooler layer. encoder_layers (`int`, *optional*, defaults to 12): @@ -80,29 +78,34 @@ class PegasusXConfig(PretrainedConfig): The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) for more details. use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). - Example: + Whether or not the model should return the last key/values attentions (not used by all models) + forced_eos_token_id (`int`, *optional*, defaults to 1): + The id of the token to force as the last generated token when `max_length` is reached. Usually set to + `eos_token_id`. + num_global_tokens (`int`, *optional*, defaults to 32): + Number of global tokens to use for the encoder + block_size=256 (`int`, *optional*, defaults to 256): + Block size for encoder localattention + stagger_local_block (`bool`, *optional*, defaults to `True`): + Whether to stagger every other local attention by half a block + + Example: ```python >>> from transformers import PegasusXModel, PegasusXConfig - >>> # Initializing a PEGASUSX pegasus-x-base style configuration + >>> # Initializing a PEGASUS google/pegasus-large style configuration >>> configuration = PegasusXConfig() - >>> # Initializing a model from the pegasus-x-base style configuration + >>> # Initializing a model from the google/pegasus-large style configuration >>> model = PegasusXModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config - ``` -""" + ```""" model_type = "pegasus_x" keys_to_ignore_at_inference = ["past_key_values"] - - attribute_map = { - "num_attention_heads": "encoder_attention_heads", - "hidden_size": "d_model" - } + attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} def __init__( self, @@ -124,12 +127,15 @@ def __init__( attention_dropout=0.0, activation_dropout=0.0, init_std=0.02, - decoder_start_token_id=2, + decoder_start_token_id=0, classifier_dropout=0.0, scale_embedding=False, - pad_token_id=1, - bos_token_id=0, - eos_token_id=2, + pad_token_id=0, + eos_token_id=1, + forced_eos_token_id=1, + num_global_tokens=32, + block_size=256, + stagger_local_blocks=True, **kwargs ): self.vocab_size = vocab_size @@ -152,14 +158,24 @@ def __init__( self.use_cache = use_cache self.num_hidden_layers = encoder_layers self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True - + + self.num_global_tokens = 
num_global_tokens + self.block_size = block_size + self.stagger_local_blocks = stagger_local_blocks + super().__init__( pad_token_id=pad_token_id, - bos_token_id=bos_token_id, eos_token_id=eos_token_id, is_encoder_decoder=is_encoder_decoder, decoder_start_token_id=decoder_start_token_id, - **kwargs + forced_eos_token_id=forced_eos_token_id, + **kwargs, ) - \ No newline at end of file + @property + def num_attention_heads(self) -> int: + return self.encoder_attention_heads + + @property + def hidden_size(self) -> int: + return self.d_model diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index a78f4e3200e4d..0521f6b3dc7e2 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 Google The HuggingFace Inc. team. All rights reserved. +# Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,53 +12,70 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch PEGASUSX model. """ +""" PyTorch PEGASUS model.""" - -import math import copy +import math import random -from typing import Optional, Tuple, List, Union +import dataclasses +from typing import List, Optional, Tuple, Union +import numpy as np import torch +import torch.utils.checkpoint from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...utils import ( - add_code_sample_docstrings, - add_end_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, - Seq2SeqQuestionAnsweringModelOutput, - Seq2SeqSequenceClassifierOutput, - CausalLMOutputWithCrossAttentions ) from ...modeling_utils import PreTrainedModel -from ...utils import logging +from ...utils import ( + add_end_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) from .configuration_pegasus_x import PegasusXConfig logger = logging.get_logger(__name__) -_CHECKPOINT_FOR_DOC = "pegasus-x-base" -_CONFIG_FOR_DOC = "PegasusXConfig" -_TOKENIZER_FOR_DOC = "PegasusXTokenizer" +_CHECKPOINT_FOR_DOC = "google/pegasus-large" +_CONFIG_FOR_DOC = "PegasusConfig" +_TOKENIZER_FOR_DOC = "PegasusTokenizer" PEGASUS_X_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "pegasus-x-base", - # See all PEGASUSX models at https://huggingface.co/models?filter=pegasus_x + "google/pegasus-x-large", + # See all PEGASUS models at https://huggingface.co/models?filter=pegasus-x ] +@dataclasses.dataclass +class DimensionInfo: + """Wrapper for dimension info.""" + B: int # batch size + T: int # token length + K: int # block size + H: int # num heads + D: int # hidden dim + F: int # dim per head + N: int # num blocks + G: int # global length + P: int # padded token seq length + + # Note: Compared to the original Flax implementation, we will pad the token representations to + # a multiple of block size at the start of the encoder layers, so T=P always. 
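
Editor's note: the `DimensionInfo` dataclass above just names the sizes used throughout the encoder's block-local attention. The sketch below shows how those sizes relate; `block_size` and `num_global_tokens` are the configuration fields added earlier in this commit, while the concrete numbers are purely illustrative and not defaults.

```python
import math

# Illustrative sizes only
B, T, D, H = 2, 1000, 1024, 16   # batch, token length, hidden dim, attention heads
K = 256                          # block_size
G = 32                           # num_global_tokens
F = D // H                       # dim per head

# Token representations are padded to a multiple of the block size at the start of
# the encoder layers (per the note above, T == P after padding), then split into
# N local attention blocks; the G global tokens attend across all of them.
P = math.ceil(T / K) * K         # padded token length -> 1024
N = P // K                       # number of local blocks -> 4

print(B, P, N, G, F)             # 2 1024 4 32 64
```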
+ + +# Copied from transformers.models.bart.modeling_bart.shift_tokens_right def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): """ Shift input ids one token to the right. @@ -67,13 +84,15 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() shifted_input_ids[:, 0] = decoder_start_token_id - assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") # replace possible -100 values in labels by `pad_token_id` shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) return shifted_input_ids +# Copied from transformers.models.bart.modeling_bart._make_causal_mask def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): """ Make causal mask used for bi-directional self-attention. @@ -89,9 +108,8 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) -def _expand_mask( - mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None -): +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ @@ -102,18 +120,36 @@ def _expand_mask( inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) -class PegasusXLearnedPositionalEmbedding(nn.Embedding): - """ - This module learns positional embeddings up to a fixed maximum size. - """ +# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->PegasusX +class PegasusXSinusoidalPositionalEmbedding(nn.Embedding): + """This module produces sinusoidal positional embeddings of any length.""" - def __init__(self, num_embeddings: int, embedding_dim: int): - super().__init__(num_embeddings, embedding_dim) + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None: + super().__init__(num_positions, embedding_dim) + self.weight = self._init_weight(self.weight) - def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): + @staticmethod + def _init_weight(out: nn.Parameter) -> nn.Parameter: + """ + Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in + the 2nd half of the vector. 
[dim // 2:] + """ + n_pos, dim = out.shape + position_enc = np.array( + [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] + ) + out.requires_grad = False # set early to avoid an error in pytorch-1.8+ + sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 + out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) + out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) + out.detach_() + return out + + @torch.no_grad() + def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: """`input_ids_shape` is expected to be [bsz x seqlen].""" bsz, seq_len = input_ids_shape[:2] positions = torch.arange( @@ -122,6 +158,7 @@ def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): return super().forward(positions) +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PegasusX class PegasusXAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -138,10 +175,13 @@ def __init__( self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads - assert ( - self.head_dim * num_heads == self.embed_dim - ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})." - self.scaling = self.head_dim ** -0.5 + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -158,7 +198,6 @@ def forward( key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -166,7 +205,8 @@ def forward( # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, embed_dim = hidden_states.size() + + bsz, tgt_len, _ = hidden_states.size() # get query proj query_states = self.q_proj(hidden_states) * self.scaling @@ -210,7 +250,8 @@ def forward( if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" ) if attention_mask is not None: @@ -223,18 +264,10 @@ def forward( attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - if output_attentions: - # this operation is a bit akward, but it's required to + # this operation is a bit awkward, but it's required to # make sure that attn_weights keeps its gradient. 
- # In order to do so, attn_weights have to reshaped + # In order to do so, attn_weights have to be reshaped # twice and have to be reused in the following attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) @@ -247,84 +280,354 @@ def forward( if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" ) attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, attn_weights_reshaped, past_key_value +class PegasusXGlobalLocalAttention(nn.Module): + """Global + Local attention. For use with Encoder only.""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + block_size: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.block_size = block_size + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + token_hidden_states: torch.Tensor, + global_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + assert token_hidden_states.shape[1] % self.block_size == 0, "Sequence length should be multiple of block size" + dim = DimensionInfo( + B=token_hidden_states.shape[0], + T=token_hidden_states.shape[1], + K=self.block_size, + H=self.num_heads, + D=token_hidden_states.shape[2], + F=self.head_dim, + N=token_hidden_states.shape[1] // self.block_size, + G=global_hidden_states.shape[1], + P=token_hidden_states.shape[1], + ) + + # [B, H, P, F] + local_q = self._shape(self.q_proj(token_hidden_states) * self.scaling, seq_len=dim.P, bsz=dim.B) + local_k = self._shape(self.k_proj(token_hidden_states), seq_len=dim.P, bsz=dim.B) + local_v = self._shape(self.v_proj(token_hidden_states), seq_len=dim.P, bsz=dim.B) + + # [B, H, G, F] + global_q = self._shape(self.q_proj(global_hidden_states) * self.scaling, seq_len=dim.G, bsz=dim.B) + global_k = self._shape(self.k_proj(global_hidden_states), seq_len=dim.G, bsz=dim.B) + global_v = self._shape(self.v_proj(global_hidden_states), seq_len=dim.G, bsz=dim.B) + + global_attn_output, global_attn_probs = self.compute_global_attention_representations( + global_q=global_q, global_k=global_k, global_v=global_v, + local_k=local_k, local_v=local_v, + mask=attention_mask, + dim=dim, + ) + local_attn_output, local_attn_probs = self.compute_local_attention_representations( + global_k=global_k, global_v=global_v, + local_q=local_q, local_k=local_k, local_v=local_v, + mask=attention_mask, + dim=dim, + ) + + # [B, G, D] + global_attn_output = global_attn_output.transpose(1, 2).contiguous().view(dim.B, dim.G, dim.D) + # [B, G, D] + global_attn_output = self.out_proj(global_attn_output) + # [B, N, K, H, F] + local_attn_output = local_attn_output.permute(0, 2, 3, 1, 4).contiguous() + # [B, P, D] + local_attn_output = local_attn_output.view(dim.B, dim.P, dim.D) + # [B, P, D] + local_attn_output = self.out_proj(local_attn_output) + + if output_attentions: + attn_probs = {"global": global_attn_probs, "local": local_attn_probs} + else: + attn_probs = None + + return local_attn_output, global_attn_output, attn_probs + + def compute_global_attention_representations( + self, + global_q, + global_k, + global_v, + local_k, + local_v, + mask, + dim: DimensionInfo): + """Compute attention representations for global tokens. + + Global tokens will attend to both global tokens as well as all input + sequence tokens. Because the input sequence tokens are arranged in blocks + for local attention, we unblock them and compute attention. 
+ + Args: + global_q: [B, H, G, F] query vectors from global tokens + global_k: [B, H, G, F] key vectors from global tokens + global_v: [B, H, G, F] value vectors from global tokens + local_k: [B, H, P, F] key vectors from local tokens + local_v: [B, H, P, F] value vectors from local tokens + mask: [B, P] attention mask + dim: DimensionInfo wrapper for dimensions + + Returns: + output of shape `[batch_sizes, length, features]`. + where length will be padded to a multiple of block_size + """ + # [B, H, G+P, F] + global_and_local_k = torch.cat([global_k, local_k], dim=2) + # [B, H, G+P, F] + global_and_local_v = torch.cat([global_v, local_v], dim=2) + + # [B, G+P] + extended_mask = nn.functional.pad(mask, pad=(dim.G, 0), value=0) + + # [B, H, G, G+P] + attn_weights = torch.einsum("BHGF,BHXF->BHGX", global_q, global_and_local_k) + attn_weights += attn_weights + extended_mask[:, None, None, :] + attn_probs = nn.functional.softmax(attn_weights, dim=-1) + attn_probs = nn.functional.dropout(attn_probs, p=self.dropout, training=self.training) + + # [B, H, G, F] + attn_output = torch.einsum("BHGX,BHXF->BHGF", attn_probs, global_and_local_v) + return attn_output, attn_probs + + def compute_local_attention_representations( + self, + global_k, + global_v, + local_q, + local_k, + local_v, + mask, + dim: DimensionInfo): + """Compute attention representations for local tokens. + + Local tokens will attend to both global tokens as well as all other tokens + within the same local block. Hence, we need to tile and + concatenate the global tokens to every local block + + Args: + global_k: [B, H, G, F] key vectors from global tokens + global_v: [B, H, G, F] value vectors from global tokens + local_q: [B, H, P, F] query vectors from local tokens + local_k: [B, H, P, F] key vectors from local tokens + local_v: [B, H, P, F] value vectors from local tokens + mask: [B, P] attention mask + dim: DimensionInfo wrapper for dimensions + + Returns: + output of shape `[batch_sizes, length, features]`. 
+ where length will be padded to a multiple of block_size + """ + # [B, H, N, K, F] + blocked_local_q = local_q.view(dim.B, dim.H, dim.N, dim.K, dim.F) + # [B, H, N, K, F] + blocked_local_k = local_k.view(dim.B, dim.H, dim.N, dim.K, dim.F) + # [B, H, N, K, F] + blocked_local_v = local_v.view(dim.B, dim.H, dim.N, dim.K, dim.F) + + # [B, N, G+K] + extended_mask = nn.functional.pad(mask.view(dim.B, dim.N, dim.K), pad=(dim.G, 0), value=0) + + # [B, H, N, K, G] + blocked_local2global = torch.einsum("BHNKF,BHGF->BHNKG", blocked_local_q, global_k) + # [B, H, N, K, K] + blocked_local2local = torch.einsum("BHNKF,BHNXF->BHNKX", blocked_local_q, blocked_local_k) + + # [B, H, N, K, G+K] + attn_weights = torch.cat([blocked_local2global, blocked_local2local], dim=4) + attn_weights = attn_weights + extended_mask[:, None, :, None, :] + attn_probs = nn.functional.softmax(attn_weights, dim=-1) + attn_probs = nn.functional.dropout(attn_probs, p=self.dropout, training=self.training) + + # [B, H, N, K, G] + local2global_attn_probs = attn_probs[:, :, :, :, :dim.G] + # [B, H, N, K, K] + local2local_attn_probs = attn_probs[:, :, :, :, dim.G:] + + # [B, H, N, K, F] + local2global_attn_output = torch.einsum("BHNKG,BHGF->BHNKF", local2global_attn_probs, global_v) + # [B, H, N, K, F] + local2local_attn_output = torch.einsum("BHNKX,BHNXF->BHNKF", local2local_attn_probs, blocked_local_v) + # [B, H, N, K, F] + attn_output = local2global_attn_output + local2local_attn_output + return attn_output, attn_probs + + class PegasusXEncoderLayer(nn.Module): - def __init__(self, config: PegasusXConfig): + def __init__(self, stagger_blocks_this_layer: bool, config: PegasusXConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = PegasusXAttention( + self.self_attn = PegasusXGlobalLocalAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, + block_size=config.block_size, dropout=config.attention_dropout, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.global_self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + self.stagger_blocks_this_layer = stagger_blocks_this_layer + self.block_size = config.block_size + if stagger_blocks_this_layer: + assert config.block_size % 2 == 0, "Block size must be an even number" def forward( self, hidden_states: torch.Tensor, + global_hidden_states: torch.Tensor, attention_mask: torch.Tensor, - layer_head_mask: torch.Tensor, output_attentions: bool = False, - ): + ) -> torch.Tensor: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape *(seq_len, batch, embed_dim)* + global_hidden_states (`torch.FloatTensor`): global token hidden states + *(seq_len, num_global_tokens, embed_dim)* attention_mask (`torch.FloatTensor`): attention mask of size *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. - layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size - *(config.encoder_attention_heads,)*. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. 
""" residual = hidden_states - hidden_states, attn_weights, _ = self.self_attn( - hidden_states=hidden_states, + global_residual = global_hidden_states + + hidden_states = self.self_attn_layer_norm(hidden_states) + global_hidden_states = self.global_self_attn_layer_norm(global_hidden_states) + + if self.stagger_blocks_this_layer: + # Pad the blocks to simulate staggering + hidden_states, attention_mask = self.pad_local_tokens( + hidden_states=hidden_states, + attention_mask=attention_mask, + block_size=self.block_size + ) + + hidden_states, global_hidden_states, attn_weights = self.self_attn( + token_hidden_states=hidden_states, + global_hidden_states=global_hidden_states, attention_mask=attention_mask, - layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) + + if self.stagger_blocks_this_layer: + # Undo the padding + hidden_states = self.unpad_local_tokens( + padded_hidden_states=hidden_states, + block_size=self.block_size + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) + + global_hidden_states = nn.functional.dropout(global_hidden_states, p=self.dropout, training=self.training) + global_hidden_states = global_residual + global_hidden_states residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()): - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + global_residual = global_hidden_states + global_hidden_states = self.final_layer_norm(global_hidden_states) + global_hidden_states = self.activation_fn(self.fc1(global_hidden_states)) + global_hidden_states = nn.functional.dropout(global_hidden_states, p=self.activation_dropout, training=self.training) + global_hidden_states = self.fc2(global_hidden_states) + global_hidden_states = nn.functional.dropout(global_hidden_states, p=self.dropout, training=self.training) + global_hidden_states = global_residual + global_hidden_states - outputs = (hidden_states,) + if hidden_states.dtype == torch.float16: + if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + if torch.isinf(global_hidden_states).any() or torch.isnan(global_hidden_states).any(): + clamp_value = torch.finfo(global_hidden_states.dtype).max - 1000 + global_hidden_states = torch.clamp(global_hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states, global_hidden_states) if output_attentions: outputs += (attn_weights,) return outputs + + @classmethod + def pad_local_tokens(cls, hidden_states, attention_mask, block_size): + assert hidden_states.dim() == 3 + pad_size = block_size // 2 + padded_hidden_states = torch.nn.functional.pad( + hidden_states, pad=(0, 0, pad_size, pad_size), + ) + padded_mask = torch.nn.functional.pad( + 
attention_mask, pad=(pad_size, pad_size), + ) + return padded_hidden_states, padded_mask + + @classmethod + def unpad_local_tokens(cls, padded_hidden_states, block_size): + assert padded_hidden_states.dim() == 3 + pad_size = block_size // 2 + return padded_hidden_states[:, pad_size:-pad_size, :] +# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Pegasus class PegasusXDecoderLayer(nn.Module): def __init__(self, config: PegasusXConfig): super().__init__() @@ -358,30 +661,26 @@ def forward( attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - cross_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, - ): + ) -> torch.Tensor: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape *(seq_len, batch, embed_dim)* attention_mask (`torch.FloatTensor`): attention mask of size *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. - encoder_hidden_states (`torch.FloatTensor`): cross attention input to the layer of shape *(seq_len, batch, embed_dim)* + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape *(seq_len, batch, embed_dim)* encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. - layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size - *(encoder_attention_heads,)*. - cross_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of - size *(decoder_attention_heads,)*. past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. 
""" residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 @@ -391,18 +690,17 @@ def forward( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, - layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None @@ -410,25 +708,23 @@ def forward( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - layer_head_mask=cross_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) # add cross-attn to positions 3,4 of present_key_value tuple present_key_value = present_key_value + cross_attn_present_key_value # Fully Connected residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) outputs = (hidden_states,) @@ -441,30 +737,6 @@ def forward( return outputs -class PegasusXClassificationHead(nn.Module): - """Head for sentence-level classification tasks.""" - - def __init__( - self, - input_dim: int, - inner_dim: int, - num_classes: int, - pooler_dropout: float, - ): - super().__init__() - self.dense = nn.Linear(input_dim, inner_dim) - self.dropout = nn.Dropout(p=pooler_dropout) - self.out_proj = nn.Linear(inner_dim, num_classes) - - def forward(self, hidden_states: torch.Tensor): - hidden_states = self.dropout(hidden_states) - hidden_states = self.dense(hidden_states) - hidden_states = torch.tanh(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.out_proj(hidden_states) - return hidden_states - - class PegasusXPreTrainedModel(PreTrainedModel): config_class = PegasusXConfig base_model_prefix = "model" @@ -476,31 +748,32 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() + elif isinstance(module, PegasusXSinusoidalPositionalEmbedding): + pass elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - + def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (PegasusXDecoder, PegasusXEncoder)): module.gradient_checkpointing = value PEGASUS_X_START_DOCSTRING = r""" - This model inherits from 
[`PreTrainedModel`]. Check the superclass documentation for the generic - methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, - pruning heads etc.) + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) - subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to - general usage and behavior. + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. Parameters: - config ([`~PegasusXConfig`]): - Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model - weights. + config ([`PegasusXConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ PEGASUS_X_GENERATION_EXAMPLE = r""" @@ -509,15 +782,20 @@ def _set_gradient_checkpointing(self, module, value=False): ```python >>> from transformers import PegasusXTokenizer, PegasusXForConditionalGeneration - >>> model = PegasusXForConditionalGeneration.from_pretrained('pegasus-x-base') - >>> tokenizer = PegasusXTokenizer.from_pretrained('pegasus-x-base') + >>> model = PegasusXForConditionalGeneration.from_pretrained("google/pegasus-x-base") + >>> tokenizer = PegasusXTokenizer.from_pretrained("google/pegasus-x-base") - >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt') + >>> ARTICLE_TO_SUMMARIZE = ( + ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." + ... ) + >>> inputs = tokenizer(ARTICLE_TO_SUMMARIZE, max_length=1024, return_tensors="pt") >>> # Generate Summary - >>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5) - >>> print(tokenizer.decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)) + >>> summary_ids = model.generate(inputs["input_ids"]) + >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "California's largest electricity provider has turned off power to hundreds of thousands of customers." ``` """ @@ -527,9 +805,8 @@ def _set_gradient_checkpointing(self, module, value=False): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. - Indices can be obtained using [`~PegasusXTokenizer`]. See - [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for - details. + Indices can be obtained using [`PegasusXTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -540,90 +817,49 @@ def _set_gradient_checkpointing(self, module, value=False): [What are attention masks?](../glossary#attention-mask) decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Provide for translation and summarization training. By default, the model will create this tensor by - shifting the `input_ids` to the right, following the paper. - decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will - also be used by default. - - If you want to change padding behavior, you should read [`modeling_pegasus_x._prepare_decoder_attention_mask`] and - modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. + Indices of decoder input sequence tokens in the vocabulary. - decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + Indices can be obtained using [`PegasusXTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. + [What are decoder input IDs?](../glossary#decoder-input-ids) - cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. + PEGASUS-X uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): - Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: - `attentions`) `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, - *optional*) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the - cross-attention of the decoder. + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. 
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors - of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of - shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` - (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` - instead of all ``decoder_input_ids``` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated - vectors than the model's internal embedding lookup matrix. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape + `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you + can choose to directly pass an embedded representation. This is useful if you want more control over how to + convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded - representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` - have to be input (see `past_key_values`). This is useful if you want more control over how to convert + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` - takes the value of `inputs_embeds`. + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up - decoding (see `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. 
- output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -PEGASUS_X_STANDALONE_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`ProphetNetTokenizer`]. See - [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for - details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -661,22 +897,57 @@ def __init__(self, config: PegasusXConfig, embed_tokens: Optional[nn.Embedding] else: self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) - self.embed_positions = PegasusXLearnedPositionalEmbedding( + self.embed_global = nn.Embedding(config.num_global_tokens, embed_dim) + self.embed_positions = PegasusXSinusoidalPositionalEmbedding( config.max_position_embeddings, embed_dim, + self.padding_idx, ) - self.layers = nn.ModuleList([PegasusXEncoderLayer(config) for _ in range(config.encoder_layers)]) - self.layernorm_embedding = nn.LayerNorm(embed_dim) + self.layers = nn.ModuleList([ + PegasusXEncoderLayer( + stagger_blocks_this_layer=i%2 == 1 and config.stagger_local_blocks, + config=config) + for i in range(config.encoder_layers) + ]) + self.layer_norm = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings matrix of the model if `new_num_position_embeddings != + config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embeddings. If position embeddings are learned, increasing the size will add + newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If + position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will + add correct vectors at the end following the position encoding algorithm, whereas reducing the size + will remove vectors from the end. 
+ """ + logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...") + self.config.max_position_embeddings = new_num_position_embeddings + + self.embed_positions = PegasusXSinusoidalPositionalEmbedding( + self.config.max_position_embeddings, + self.config.d_model, + self.padding_idx, + ) + self.embed_positions.to(self.device) + + def get_position_embeddings(self) -> nn.Embedding: + """ + Returns the position embeddings matrix + """ + return self.embed_positions + def forward( self, input_ids=None, attention_mask=None, - head_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, @@ -688,9 +959,8 @@ def forward( Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. - Indices can be obtained using [`~PegasusXTokenizer`]. See - [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] - for details. + Indices can be obtained using [`PegasusTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -700,16 +970,11 @@ def forward( - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded - representation. This is useful if you want more control over how to convert `input_ids` indices - into associated vectors than the model's internal embedding lookup matrix. + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. 
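A side note on the sinusoidal position table behind `resize_position_embeddings`: because every row is computed deterministically from the position index, rebuilding a longer table leaves all previously used positions unchanged. A standalone sketch of that construction (illustrative only; it mirrors `_init_weight` above rather than calling the patch's class):

```python
import numpy as np
import torch


def sinusoidal_table(n_pos: int, dim: int) -> torch.Tensor:
    # sin features fill the first half of each vector, cos features the second half
    # (not interleaved), matching the layout described in _init_weight
    position_enc = np.array(
        [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
    )
    out = torch.zeros(n_pos, dim)
    sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
    out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
    out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
    return out


# Extending the maximum length only appends new rows; existing positions are identical.
short, longer = sinusoidal_table(16, 8), sinusoidal_table(64, 8)
assert torch.allclose(short, longer[:16])
```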
@@ -742,23 +1007,34 @@ def forward( embed_pos = self.embed_positions(input_shape) hidden_states = inputs_embeds + embed_pos - hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + batch_size, seq_len, _ = hidden_states.shape + + # Setup mask + if attention_mask is None: + attention_mask = torch.ones(*input_shape, dtype=inputs_embeds.dtype, device=inputs_embeds.device) + mask_min_value = torch.finfo(hidden_states.dtype).min + inverted_mask = 1.0 - attention_mask + attention_mask = inverted_mask.masked_fill( + inverted_mask.to(torch.bool), + mask_min_value, + ) + + # padding to block_size + if seq_len % self.config.block_size != 0: + pad_len = self.config.block_size - seq_len % self.config.block_size + hidden_states = nn.functional.pad(hidden_states, pad=(0, 0, 0, pad_len), value=0) + attention_mask = nn.functional.pad(attention_mask, pad=(0, pad_len), value=mask_min_value) + + # Global tokens + global_hidden_states = self.embed_global( + torch.arange(self.config.num_global_tokens, device=hidden_states.device)[None].expand(batch_size, -1) + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - assert head_mask.size()[0] == ( - len(self.layers) - ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." - for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) @@ -778,24 +1054,30 @@ def custom_forward(*inputs): layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(encoder_layer), hidden_states, + global_hidden_states, attention_mask, - (head_mask[idx] if head_mask is not None else None), ) else: layer_outputs = encoder_layer( hidden_states, + global_hidden_states, attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), output_attentions=output_attentions, ) hidden_states = layer_outputs[0] + global_hidden_states = layer_outputs[1] if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) + all_attentions = all_attentions + (layer_outputs[2],) + + # Undo padding-to-block-size + hidden_states = hidden_states[:, :seq_len] + + hidden_states = self.layer_norm(hidden_states) if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) + encoder_states = encoder_states + ((hidden_states, global_hidden_states),) if not return_dict: return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) @@ -806,7 +1088,7 @@ def custom_forward(*inputs): class PegasusXDecoder(PegasusXPreTrainedModel): """ - Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`PegasusXDecoderLayer`] + Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a [`PegasusDecoderLayer`] Args: config: PegasusXConfig @@ -826,12 +1108,13 @@ def __init__(self, config: PegasusXConfig, embed_tokens: Optional[nn.Embedding] else: self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) - self.embed_positions = PegasusXLearnedPositionalEmbedding( + self.embed_positions = PegasusXSinusoidalPositionalEmbedding( config.max_position_embeddings, config.d_model, + self.padding_idx, ) self.layers = nn.ModuleList([PegasusXDecoderLayer(config) for _ in range(config.decoder_layers)]) - self.layernorm_embedding = nn.LayerNorm(config.d_model) + self.layer_norm = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -843,6 +1126,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -850,7 +1134,7 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) + ).to(inputs_embeds.device) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -861,14 +1145,41 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em return combined_attention_mask + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings matrix of the model if `new_num_position_embeddings != + config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embeddings. If position embeddings are learned, increasing the size will add + newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If + position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will + add correct vectors at the end following the position encoding algorithm, whereas reducing the size + will remove vectors from the end. + """ + logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...") + self.config.max_position_embeddings = new_num_position_embeddings + + self.embed_positions = PegasusXSinusoidalPositionalEmbedding( + self.config.max_position_embeddings, + self.config.d_model, + self.padding_idx, + ) + self.embed_positions.to(self.device) + + def get_position_embeddings(self) -> nn.Embedding: + """ + Returns the position embeddings matrix + """ + return self.embed_positions + def forward( self, input_ids=None, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, - head_mask=None, - cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, @@ -882,9 +1193,8 @@ def forward( Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. - Indices can be obtained using [`~PegasusXTokenizer`]. See - [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] - for details. + Indices can be obtained using [`PegasusTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
[What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -905,32 +1215,22 @@ def forward( - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 - tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional - tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the - cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential - decoding. - - If `past_key_values` are used, the user can optionally input only the last - `decoder_input_ids` (those that don't have their past key value states given to this model) of - shape `(batch_size, 1)` instead of all ``decoder_input_ids``` of shape `(batch_size, - sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices - into associated vectors than the model's internal embedding lookup matrix. + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of + shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing + `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more + control over how to convert `input_ids` indices into associated vectors than the model's internal + embedding lookup matrix. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. 
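For the decoder self-attention, the additive mask passed into the layers is the sum of a causal mask and the expanded padding mask described above. A minimal standalone sketch under assumed shapes (the helper names here are illustrative, not part of the patch):

```python
import torch


def causal_mask(tgt_len: int, dtype=torch.float32) -> torch.Tensor:
    # 0 where a query may attend (key index <= query index), large negative elsewhere
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
    allowed = torch.arange(tgt_len)[None, :] <= torch.arange(tgt_len)[:, None]
    return mask.masked_fill(allowed, 0.0)


def expand_padding_mask(mask_2d: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # [bsz, src_len] with 1 = keep, 0 = pad  ->  additive mask of shape [bsz, 1, 1, src_len]
    inverted = 1.0 - mask_2d[:, None, None, :].to(dtype)
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)


bsz, tgt_len = 2, 5
padding = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
combined = causal_mask(tgt_len)[None, None] + expand_padding_mask(padding)
print(combined.shape)  # torch.Size([2, 1, 5, 5]); fully masked entries may sum to -inf,
                       # which softmax treats as zero probability
```

With `past_key_values`, only the newest query position is fed in, so the causal part reduces to a single row over `past_len + 1` keys.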
@@ -964,7 +1264,9 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: @@ -975,7 +1277,6 @@ def forward( positions = self.embed_positions(input_shape, past_key_values_length) hidden_states = inputs_embeds + positions - hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -985,12 +1286,6 @@ def forward( all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None next_decoder_cache = () if use_cache else None - # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired - for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): - if attn_mask is not None: - assert attn_mask.size()[0] == ( - len(self.layers) - ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: @@ -1004,7 +1299,9 @@ def forward( if self.gradient_checkpointing and self.training: if use_cache: - logger.warning("`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`...") + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False def create_custom_forward(module): @@ -1020,8 +1317,6 @@ def custom_forward(*inputs): attention_mask, encoder_hidden_states, encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, ) else: @@ -1031,8 +1326,6 @@ def custom_forward(*inputs): attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -1048,6 +1341,8 @@ def custom_forward(*inputs): if encoder_hidden_states is not None: all_cross_attentions += (layer_outputs[2],) + hidden_states = self.layer_norm(hidden_states) + # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) @@ -1069,7 +1364,7 @@ def custom_forward(*inputs): @add_start_docstrings( - "The bare PEGASUSX Model outputting raw hidden-states without any specific head on top.", + "The bare PEGASUS-X Model outputting raw hidden-states without any specific head on top.", PEGASUS_X_START_DOCSTRING, ) class PegasusXModel(PegasusXPreTrainedModel): @@ -1099,31 +1394,66 @@ def get_encoder(self): def get_decoder(self): return self.decoder + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings matrix of the model if `new_num_position_embeddings != + config.max_position_embeddings`. 
+ + Arguments: + new_num_position_embeddings (`int`): + The number of new position embeddings. If position embeddings are learned, increasing the size will add + newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If + position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will + add correct vectors at the end following the position encoding algorithm, whereas reducing the size + will remove vectors from the end. + """ + self.config.max_position_embeddings = new_num_position_embeddings + self.encoder.resize_position_embeddings(new_num_position_embeddings) + self.decoder.resize_position_embeddings(new_num_position_embeddings) + + def get_position_embeddings(self) -> Tuple[nn.Embedding]: + """ + Returns the position embeddings matrix + """ + return (self.encoder.get_position_embeddings(), self.decoder.get_position_embeddings()) + @add_start_docstrings_to_model_forward(PEGASUS_X_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=Seq2SeqModelOutput, - config_class=_CONFIG_FOR_DOC, - ) + @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids=None, - attention_mask=None, - decoder_input_ids=None, - decoder_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - encoder_outputs=None, - past_key_values=None, - inputs_embeds=None, - decoder_inputs_embeds=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import PegasusTokenizer, PegasusModel + + >>> tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-large") + >>> model = PegasusModel.from_pretrained("google/pegasus-large") + + >>> inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt") + >>> decoder_inputs = tokenizer("Studies show that", return_tensors="pt") + >>> outputs = model(input_ids=inputs.input_ids, decoder_input_ids=decoder_inputs.input_ids) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 4, 1024] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1135,7 +1465,6 @@ def forward( encoder_outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, - head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1155,8 +1484,6 @@ def forward( attention_mask=decoder_attention_mask, encoder_hidden_states=encoder_outputs[0], 
encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, @@ -1181,15 +1508,16 @@ def forward( @add_start_docstrings( - "The PEGASUSX Model with a language modeling head. Can be used for summarization.", PEGASUS_X_START_DOCSTRING + "The PEGASUS-X for conditional generation (e.g. summarization).", PEGASUS_X_START_DOCSTRING ) class PegasusXForConditionalGeneration(PegasusXPreTrainedModel): base_model_prefix = "model" _keys_to_ignore_on_load_missing = [ r"final_logits_bias", - r"encoder\.version", - r"decoder\.version", - r"lm_head\.weight", + r"encoder.version", + r"decoder.version", + r"lm_head.weight", + r"embed_positions.weight", ] def __init__(self, config: PegasusXConfig): @@ -1227,53 +1555,57 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings matrix of the model if `new_num_position_embeddings != + config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embeddings. If position embeddings are learned, increasing the size will add + newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If + position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will + add correct vectors at the end following the position encoding algorithm, whereas reducing the size + will remove vectors from the end. + """ + self.config.max_position_embeddings = new_num_position_embeddings + self.model.encoder.resize_position_embeddings(new_num_position_embeddings) + self.model.decoder.resize_position_embeddings(new_num_position_embeddings) + + def get_position_embeddings(self) -> Tuple[nn.Embedding]: + """ + Returns the position embeddings matrix + """ + return (self.model.encoder.get_position_embeddings(), self.model.decoder.get_position_embeddings()) + @add_start_docstrings_to_model_forward(PEGASUS_X_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) @add_end_docstrings(PEGASUS_X_GENERATION_EXAMPLE) def forward( self, - input_ids=None, - attention_mask=None, - decoder_input_ids=None, - decoder_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - encoder_outputs=None, - past_key_values=None, - inputs_embeds=None, - decoder_inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. 
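PegasusXForConditionalGeneration puts a language-modeling head on top of PegasusXModel for summarization. A usage sketch, assuming the `google/pegasus-x-base` checkpoint referenced in the tokenizer files later in this patch is actually available:

```python
from transformers import PegasusXForConditionalGeneration, PegasusXTokenizer

tokenizer = PegasusXTokenizer.from_pretrained("google/pegasus-x-base")
model = PegasusXForConditionalGeneration.from_pretrained("google/pegasus-x-base")

article = "PEGASUS-X extends PEGASUS to long inputs with staggered block-local attention and global tokens."
inputs = tokenizer(article, return_tensors="pt")

# generate() drives prepare_inputs_for_generation() defined below to reuse the decoder cache.
summary_ids = model.generate(inputs.input_ids, num_beams=4, max_length=64)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
```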
Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. Returns: - Conditional generation example: - - ```python - >>> from transformers import PegasusXTokenizer, PegasusXForConditionalGeneration - >>> tokenizer = PegasusXTokenizer.from_pretrained('pegasus-x-base') - >>> TXT = "My friends are but they eat too many carbs." - - >>> model = PegasusXForConditionalGeneration.from_pretrained('pegasus-x-base') - >>> input_ids = tokenizer([TXT], return_tensors='pt')['input_ids'] - >>> logits = model(input_ids).logits - - >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() - >>> probs = logits[0, masked_index].softmax(dim=0) - >>> values, predictions = probs.topk(5) - - >>> tokenizer.decode(predictions).split() - ``` -""" + """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: @@ -1281,7 +1613,9 @@ def forward( logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") use_cache = False if decoder_input_ids is None: - decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) outputs = self.model( input_ids, @@ -1289,9 +1623,6 @@ def forward( decoder_input_ids=decoder_input_ids, encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1328,9 +1659,6 @@ def prepare_inputs_for_generation( decoder_input_ids, past=None, attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, use_cache=None, encoder_outputs=None, **kwargs @@ -1345,252 +1673,24 @@ def prepare_inputs_for_generation( "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + @staticmethod def _reorder_cache(past, beam_idx): reordered_past = () for layer_past in past: - reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) - return reordered_past - - -@add_start_docstrings( - """ - PegasusX model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE - tasks. 
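When only `labels` are passed, the decoder inputs are derived by shifting the labels one position to the right (see `prepare_decoder_input_ids_from_labels` above). The exact helper is not shown in this hunk; a minimal sketch of the usual seq2seq behaviour, with a hypothetical name:

```python
import torch


def shift_tokens_right_sketch(labels: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:
    """Shift labels one token to the right and prepend the decoder start token."""
    shifted = labels.new_zeros(labels.shape)
    shifted[:, 1:] = labels[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    # -100 is the loss ignore index; it must not reach the embedding layer.
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted


labels = torch.tensor([[4, 5, 6, 1, -100]])
print(shift_tokens_right_sketch(labels, pad_token_id=0, decoder_start_token_id=0))
# tensor([[0, 4, 5, 6, 1]])
```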
- """, - PEGASUS_X_START_DOCSTRING, -) -class PegasusXForSequenceClassification(PegasusXPreTrainedModel): - def __init__(self, config: PegasusXConfig, **kwargs): - super().__init__(config, **kwargs) - self.model = PegasusXModel(config) - self.classification_head = PegasusXClassificationHead( - config.d_model, - config.d_model, - config.num_labels, - config.classifier_dropout, - ) - self.model._init_weights(self.classification_head.dense) - self.model._init_weights(self.classification_head.out_proj) - - @add_start_docstrings_to_model_forward(PEGASUS_X_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=Seq2SeqSequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids=None, - attention_mask=None, - decoder_input_ids=None, - decoder_attention_mask=None, - encoder_outputs=None, - inputs_embeds=None, - decoder_inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if labels is not None: - use_cache = False - - if input_ids is None and inputs_embeds is not None: - raise NotImplementedError( - f"Passing input embeddings is currently not supported for {self.__class__.__name__}" + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], ) + return reordered_past - outputs = self.model( - input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - encoder_outputs=encoder_outputs, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = outputs[0] # last hidden state - - eos_mask = input_ids.eq(self.config.eos_token_id) - - if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: - raise ValueError("All examples must have the same number of tokens.") - sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ - :, -1, : - ] - logits = self.classification_head(sentence_representation) - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.config.num_labels == 1: - self.config.problem_type = "regression" - elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.config.num_labels == 1: - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct 
= BCEWithLogitsLoss() - loss = loss_fct(logits, labels) - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return Seq2SeqSequenceClassifierOutput( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - -@add_start_docstrings( - """ - PEGASUSX Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear - layer on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - PEGASUS_X_START_DOCSTRING, -) -class PegasusXForQuestionAnswering(PegasusXPreTrainedModel): - def __init__(self, config): - super().__init__(config) - - config.num_labels = 2 - self.num_labels = config.num_labels - - self.model = PegasusXModel(config) - self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) - - self.model._init_weights(self.qa_outputs) - - @add_start_docstrings_to_model_forward(PEGASUS_X_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=Seq2SeqQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids=None, - attention_mask=None, - decoder_input_ids=None, - decoder_attention_mask=None, - encoder_outputs=None, - start_positions=None, - end_positions=None, - inputs_embeds=None, - decoder_inputs_embeds=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence - are not taken into account for computing the loss. 
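For reference, the sequence-classification head above pools the decoder hidden state at the last EOS token of each example and then picks the loss from `config.problem_type`. A self-contained sketch of those two steps with toy shapes:

```python
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

batch, seq_len, hidden, num_labels, eos_token_id = 2, 6, 8, 3, 1
hidden_states = torch.randn(batch, seq_len, hidden)
input_ids = torch.randint(5, 100, (batch, seq_len))
input_ids[:, -1] = eos_token_id  # every example must contain the same number of EOS tokens

# Pool the hidden state at the final EOS token of each sequence.
eos_mask = input_ids.eq(eos_token_id)
sentence_representation = hidden_states[eos_mask, :].view(batch, -1, hidden)[:, -1, :]
logits = torch.nn.Linear(hidden, num_labels)(sentence_representation)

# Loss selection mirrors the problem_type logic above.
labels = torch.tensor([0, 2])
if num_labels == 1:
    loss = MSELoss()(logits.squeeze(), labels.float().squeeze())
elif labels.dtype in (torch.long, torch.int):
    loss = CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
else:
    loss = BCEWithLogitsLoss()(logits, labels.float())
print(loss)
```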
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if start_positions is not None and end_positions is not None: - use_cache = False - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - encoder_outputs=encoder_outputs, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - total_loss = None - if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions = start_positions.clamp(0, ignored_index) - end_positions = end_positions.clamp(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - - if not return_dict: - output = ( - start_logits, - end_logits, - ) + outputs[1:] - return ((total_loss,) + output) if total_loss is not None else output - - return Seq2SeqQuestionAnsweringModelOutput( - loss=total_loss, - start_logits=start_logits, - end_logits=end_logits, - past_key_values=outputs.past_key_values, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) +# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Pegasus class PegasusXDecoderWrapper(PegasusXPreTrainedModel): """ This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is @@ -1603,194 +1703,3 @@ def __init__(self, config): def forward(self, *args, **kwargs): return self.decoder(*args, **kwargs) - - -class PegasusXForCausalLM(PegasusXPreTrainedModel): - def __init__(self, config): - config = copy.deepcopy(config) - config.is_decoder = True - config.is_encoder_decoder = False - super().__init__(config) - self.model = PegasusXDecoderWrapper(config) - - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.decoder.embed_tokens - - def set_input_embeddings(self, value): - self.model.decoder.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model.decoder = decoder - - def get_decoder(self): - return self.model.decoder - - @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) - def forward( - self, - 
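The question-answering head above projects every hidden state to two scores, splits them into start/end logits, and clamps out-of-range gold positions to an index the loss then ignores. A minimal standalone sketch:

```python
import torch
from torch import nn
from torch.nn import CrossEntropyLoss

batch, seq_len, hidden = 2, 10, 8
sequence_output = torch.randn(batch, seq_len, hidden)
qa_outputs = nn.Linear(hidden, 2)

logits = qa_outputs(sequence_output)          # [batch, seq_len, 2]
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)       # [batch, seq_len]
end_logits = end_logits.squeeze(-1)

start_positions = torch.tensor([3, 15])       # 15 is deliberately out of range
end_positions = torch.tensor([5, 15])

ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)

loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
total_loss = (loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)) / 2
print(total_loss)
```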
input_ids=None, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - head_mask=None, - cross_attn_head_mask=None, - past_key_values=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`~PegasusXTokenizer`]. See - [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] - for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention - if the model is configured as a decoder. - encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used - in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up - decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` - (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` - instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are - ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up - decoding (see `past_key_values`). - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - - Returns: - - Example: - - ```python - >>> from transformers import PegasusXTokenizer, PegasusXForCausalLM - - >>> tokenizer = PegasusXTokenizer.from_pretrained('facebook/bart-large') - >>> model = PegasusXForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False) - >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") - >>> outputs = model(**inputs) - - >>> logits = outputs.logits - ``` -""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model.decoder( - input_ids=input_ids, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - head_mask=head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - logits = self.lm_head(outputs[0]) - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs): - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = input_ids.new_ones(input_ids.shape) - - if past: - input_ids = input_ids[:, -1:] - # first step, decoder_cached_states are empty - return { - "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed - "attention_mask": attention_mask, - "past_key_values": past, - "use_cache": use_cache, - } - - @staticmethod - def _reorder_cache(past, beam_idx): - reordered_past = () - for layer_past in past: - reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) - return reordered_past diff --git a/src/transformers/models/pegasus_x/tokenization_pegasus_x.py b/src/transformers/models/pegasus_x/tokenization_pegasus_x.py index 296cd6cf83cd9..d735e4035331d 100644 --- a/src/transformers/models/pegasus_x/tokenization_pegasus_x.py +++ b/src/transformers/models/pegasus_x/tokenization_pegasus_x.py @@ -12,38 +12,91 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
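During generation, once a cache exists only the newest token needs to be fed to the decoder; `prepare_inputs_for_generation` above trims `input_ids` accordingly while keeping the full attention mask. A sketch of that contract (the helper name is hypothetical):

```python
import torch


def prepare_inputs_sketch(input_ids: torch.Tensor, past=None, attention_mask=None):
    # The attention mask still covers the whole sequence; only the token ids shrink.
    if attention_mask is None:
        attention_mask = input_ids.new_ones(input_ids.shape)
    if past is not None:
        # Earlier positions are already represented in the cached key/value states.
        input_ids = input_ids[:, -1:]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}


ids = torch.tensor([[2, 17, 23, 42]])
print(prepare_inputs_sketch(ids)["input_ids"].shape)                 # torch.Size([1, 4]) – first step
print(prepare_inputs_sketch(ids, past=object())["input_ids"].shape)  # torch.Size([1, 1]) – later steps
```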
# See the License for the specific language governing permissions and # limitations under the License. -"""Tokenization classes for PEGASUSX.""" -from typing import List, Optional +"""Tokenization classes for PEGASUS-X.""" +import os +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple -from tokenizers import ByteLevelBPETokenizer +import sentencepiece as spm -from ...tokenization_utils import AddedToken, PreTrainedTokenizer -from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...tokenization_utils import PreTrainedTokenizer from ...utils import logging -logger = logging.get_logger(__name__) +SPIECE_UNDERLINE = "▁" -VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} PRETRAINED_VOCAB_FILES_MAP = { - "vocab_file": { - "pegasus-x-base": "https://huggingface.co/pegasus-x-base/resolve/main/vocab.txt", - }, + "vocab_file": {"google/pegasus-x-base": "https://huggingface.co/google/pegasus-x-base/resolve/main/spiece.model"} } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "pegasus-x-base": 1024, + "google/pegasus-x-base": 512, } + +logger = logging.get_logger(__name__) + + class PegasusXTokenizer(PreTrainedTokenizer): - """ - Construct a PEGASUSX tokenizer. Based on byte-level Byte-Pair-Encoding. + r""" + Construct a PEGASUS-X tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. Args: vocab_file (`str`): - Path to the vocabulary file. + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + mask_token (`str`, *optional*, defaults to `""`): + The token used for masking single token values. This is the token used when training this model with masked + language modeling (MLM). This is the token that the PEGASUS-X encoder will try to predict during pretraining. + It corresponds to *[MASK2]* in [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive + Summarization](https://arxiv.org/pdf/1912.08777.pdf). + mask_token_sent (`str`, *optional*, defaults to `""`): + The token used for masking whole target sentences. This is the token used when training this model with gap + sentences generation (GSG). This is the sentence that the PEGASUS-X decoder will try to predict during + pretraining. It corresponds to *[MASK1]* in [PEGASUS: Pre-training with Extracted Gap-sentences for + Abstractive Summarization](https://arxiv.org/pdf/1912.08777.pdf). + additional_special_tokens (`List[str]`, *optional*): + Additional special tokens used by the tokenizer. 
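The slow tokenizer introduced here is backed by a SentencePiece model (`spiece.model`) rather than byte-level BPE. The underlying sentencepiece calls it relies on, shown as a standalone sketch (the model path is a placeholder):

```python
import sentencepiece as spm

# "spiece.model" is a placeholder path to a trained SentencePiece model file.
sp_model = spm.SentencePieceProcessor()
sp_model.Load("spiece.model")

pieces = sp_model.encode("Studies have shown that owning a dog is good for you.", out_type=str)
ids = [sp_model.piece_to_id(piece) for piece in pieces]
text = sp_model.decode_pieces(pieces)
```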
If no additional_special_tokens are provided and + are used as additional special tokens corresponding to the [original PEGASUS + tokenizer](https://github.com/google-research/pegasus/blob/939830367bcf411193d2b5eca2f2f90f3f9260ca/pegasus/ops/pretrain_parsing_ops.cc#L66) + that uses the tokens 2 - 104 only for pretraining + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. """ + vocab_files_names = VOCAB_FILES_NAMES vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP @@ -51,201 +104,196 @@ class PegasusXTokenizer(PreTrainedTokenizer): model_input_names = ["input_ids", "attention_mask"] def __init__( - self, - vocab_file, - unk_token="<|endoftext|>", - bos_token="<|endoftext|>", - eos_token="<|endoftext|>", - **kwargs - ): - bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token - eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token - unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token - super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs) - - """ Initialisation """ - - @property - def vocab_size(self): - """ Returns vocab size """ + self, + vocab_file, + pad_token="", + eos_token="", + unk_token="", + mask_token="", + mask_token_sent="", + additional_special_tokens=None, + offset=103, # entries 2 - 104 are only used for pretraining + sp_model_kwargs: Optional[Dict[str, Any]] = None, + **kwargs + ) -> None: + self.offset = offset + if additional_special_tokens is not None: + if not isinstance(additional_special_tokens, list): + raise TypeError( + f"additional_special_tokens should be of type {type(list)}, but is" + f" {type(additional_special_tokens)}" + ) + + additional_special_tokens_extended = ( + ([mask_token_sent] + additional_special_tokens) + if mask_token_sent not in additional_special_tokens and mask_token_sent is not None + else additional_special_tokens + ) + # fill additional tokens with ..., in case not all additional tokens are already taken + additional_special_tokens_extended += [ + f"" for i in range(len(additional_special_tokens_extended), self.offset - 1) + ] + + if len(set(additional_special_tokens_extended)) != len(additional_special_tokens_extended): + raise ValueError( + "Please make sure that the provided additional_special_tokens do not contain an incorrectly" + f" shifted list of tokens. Found {additional_special_tokens_extended}." 
+ ) + additional_special_tokens = additional_special_tokens_extended + else: + additional_special_tokens = [mask_token_sent] if mask_token_sent is not None else [] + additional_special_tokens += [f"" for i in range(2, self.offset)] + + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs - def get_vocab(self): - """ Returns vocab as a dict """ + super().__init__( + eos_token=eos_token, + unk_token=unk_token, + mask_token=mask_token, + pad_token=pad_token, + mask_token_sent=mask_token_sent, + offset=offset, + additional_special_tokens=additional_special_tokens, + sp_model_kwargs=self.sp_model_kwargs, + **kwargs, + ) + self.mask_token_sent = mask_token_sent + self.vocab_file = vocab_file + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(vocab_file) + + # add special tokens to encoder dict + self.encoder: Dict[int, str] = { + 0: self.pad_token, + 1: self.eos_token, + } + + if self.mask_token_sent is not None: + self.encoder.update( + { + 2: self.mask_token_sent, + 3: self.mask_token, + } + ) - def _tokenize(self, text): - """ Returns a tokenized string. """ + if self.offset > 0: + # entries 2-104 are only used for pretraining and called , , unk_2, ...unk_102 + # mask_token_sent is already added to list -> so start at 1 + self.encoder.update({i + 3: additional_special_tokens[i] for i in range(1, self.offset - 1)}) - def _convert_token_to_id(self, token): - """ Converts a token (str) in an id using the vocab. """ + self.decoder: Dict[str, int] = {v: k for k, v in self.encoder.items()} - def _convert_id_to_token(self, index): - """Converts an index (integer) in a token (str) using the vocab.""" + @property + def vocab_size(self) -> int: + return len(self.sp_model) + self.offset + + def get_vocab(self) -> Dict[str, int]: + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + def _tokenize(self, text: str) -> List[str]: + """Take as input a string and return a list of strings (tokens) for words/sub-words""" + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token: str) -> int: + """Converts a token (str) to an id using the vocab.""" + if token in self.decoder: + return self.decoder[token] + elif token in self.added_tokens_decoder: + return self.added_tokens_decoder[token] + sp_id = self.sp_model.piece_to_id(token) + return sp_id + self.offset + + def _convert_id_to_token(self, index: int) -> str: + """Converts an index (integer) to a token (str) using the vocab.""" + if index in self.encoder: + return self.encoder[index] + elif index in self.added_tokens_encoder: + return self.added_tokens_encoder[index] + else: + token = self.sp_model.IdToPiece(index - self.offset) + return token def convert_tokens_to_string(self, tokens): - """ Converts a sequence of tokens (string) in a single string. """ - - def save_vocabulary(self, save_directory): - """ - Save the vocabulary and special tokens file to a directory. - - Args: - save_directory (`str`): - The directory in which to save the vocabulary. 
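The id space reserves the first `offset + 2` positions for special tokens and shifts every SentencePiece id up by `offset`, mirroring `_convert_token_to_id` / `_convert_id_to_token` above. The literal token strings below (`<pad>`, `</s>`, `<mask_1>`, `<mask_2>`, `<unk_2>` …) follow the original PEGASUS convention and are assumed here:

```python
# Sketch of the id layout, assuming offset=103 and the PEGASUS token convention:
#   0 -> <pad>, 1 -> </s>, 2 -> <mask_1>, 3 -> <mask_2>, 4..104 -> <unk_2>..<unk_102>
# Every regular SentencePiece piece id is shifted up by OFFSET.
OFFSET = 103
SPECIAL = {0: "<pad>", 1: "</s>", 2: "<mask_1>", 3: "<mask_2>"}
SPECIAL.update({i + 3: f"<unk_{i + 1}>" for i in range(1, OFFSET - 1)})


def token_to_id(token: str, sp_piece_to_id) -> int:
    reverse = {v: k for k, v in SPECIAL.items()}
    if token in reverse:
        return reverse[token]
    return sp_piece_to_id(token) + OFFSET


def id_to_token(index: int, sp_id_to_piece) -> str:
    if index in SPECIAL:
        return SPECIAL[index]
    return sp_id_to_piece(index - OFFSET)
```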
+ """Converts a sequence of tokens (string) in a single string.""" + out_string = self.sp_model.decode_pieces(tokens) + return out_string - Returns: - `Tuple(str)`: Paths to the files saved. - """ - - def build_inputs_with_special_tokens( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: - """ - Build model inputs from a sequence or a pair of sequence for sequence classification tasks - by concatenating and adding special tokens. - A PEGASUSX sequence has the following format: - - - single sequence: ` X ` - - pair of sequences: ` A B ` + def num_special_tokens_to_add(self, pair=False): + """Just EOS""" + return 1 - Args: - token_ids_0 (`List[int]`): - List of IDs to which the special tokens will be added. - token_ids_1 (`List[int]`, *optional*): - Optional second list of IDs for sequence pairs. + def _special_token_mask(self, seq): + all_special_ids = set(self.all_special_ids) # call it once instead of inside list comp + all_special_ids.remove(self.unk_token_id) # is only sometimes special - Returns: - `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. - """ - if token_ids_1 is None: - return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] - cls = [self.cls_token_id] - sep = [self.sep_token_id] - return cls + token_ids_0 + sep + sep + token_ids_1 + sep + return [1 if x in all_special_ids else 0 for x in seq] def get_special_tokens_mask( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False ) -> List[int]: - """ - Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding - special tokens using the tokenizer `prepare_for_model` method. - - Args: - token_ids_0 (`List[int]`): - List of IDs. - token_ids_1 (`List[int]`, *optional*): - Optional second list of IDs for sequence pairs. - already_has_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not the token list is already formatted with special tokens for the model. - - Returns: - `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. - """ + """Get list where entries are [1] if a token is [eos] or [pad] else 0.""" if already_has_special_tokens: - return super().get_special_tokens_mask( - token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True - ) - - if token_ids_1 is None: - return [1] + ([0] * len(token_ids_0)) + [1] - return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] - - def create_token_type_ids_from_sequences( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: - """ - Create a mask from the two sequences passed to be used in a sequence-pair classification task. - PEGASUSX does not make use of token type ids, therefore a list of zeros is returned. - - Args: - token_ids_0 (`List[int]`): - List of IDs. - token_ids_1 (`List[int]`, *optional*): - Optional second list of IDs for sequence pairs. + return self._special_token_mask(token_ids_0) + elif token_ids_1 is None: + return self._special_token_mask(token_ids_0) + [1] + else: + return self._special_token_mask(token_ids_0 + token_ids_1) + [1] - Returns: - `List[int]`: List of zeros. 
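Unlike BERT-style tokenizers there is no CLS/SEP/BOS here: a PEGASUS-X sequence is just the token ids followed by a single EOS, and only that EOS counts as special. An illustration (token ids are arbitrary, `eos_token_id` is 1 as in the layout above):

```python
eos_token_id = 1
token_ids = [4839, 210, 97]

input_ids = token_ids + [eos_token_id]            # build_inputs_with_special_tokens
special_tokens_mask = [0] * len(token_ids) + [1]  # get_special_tokens_mask

print(input_ids)            # [4839, 210, 97, 1]
print(special_tokens_mask)  # [0, 0, 0, 1]
```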
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]: """ - sep = [self.sep_token_id] - cls = [self.cls_token_id] - - if token_ids_1 is None: - return len(cls + token_ids_0 + sep) * [0] - return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] - - def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): - add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space) - if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()): - text = " " + text - return (text, kwargs) - -class PegasusXTokenizerFast(PreTrainedTokenizerFast): - """ - Construct a "fast" PEGASUSX tokenizer (backed by HuggingFace's *tokenizers* library). - - Args: - vocab_file (`str`): - Path to the vocabulary file. - """ - - vocab_files_names = VOCAB_FILES_NAMES - pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP - max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - model_input_names = ["input_ids", "attention_mask"] - - def __init__( - self, - vocab_file, - merges_file, - unk_token="<|endoftext|>", - bos_token="<|endoftext|>", - eos_token="<|endoftext|>", - add_prefix_space=False, - trim_offsets=True, - **kwargs - ): - super().__init__( - ByteLevelBPETokenizer( - vocab_file=vocab_file, - merges_file=merges_file, - add_prefix_space=add_prefix_space, - trim_offsets=trim_offsets, - ), - bos_token=bos_token, - eos_token=eos_token, - unk_token=unk_token, - **kwargs, - ) - self.add_prefix_space = add_prefix_space - - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): - output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id] - if token_ids_1 is None: - return output + Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating + and adding special tokens. A PEGASUS-X sequence has the following format, where `X` represents the sequence: - return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id] + - single sequence: `X ` + - pair of sequences: `A B ` (not intended use) - - def create_token_type_ids_from_sequences( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: - """ - Create a mask from the two sequences passed to be used in a sequence-pair classification task. - PEGASUSX does not make use of token type ids, therefore a list of zeros is returned. + BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a + separator. Args: token_ids_0 (`List[int]`): - List of IDs. + List of IDs to which the special tokens will be added. token_ids_1 (`List[int]`, *optional*): Optional second list of IDs for sequence pairs. Returns: - `List[int]`: List of zeros. + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 
""" - sep = [self.sep_token_id] - cls = [self.cls_token_id] - if token_ids_1 is None: - return len(cls + token_ids_0 + sep) * [0] - return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + return token_ids_0 + [self.eos_token_id] + # We don't expect to process pairs, but leave the pair logic for API consistency + return token_ids_0 + token_ids_1 + [self.eos_token_id] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + return (out_vocab_file,) diff --git a/src/transformers/models/pegasus_x/tokenization_pegasus_x_fast.py b/src/transformers/models/pegasus_x/tokenization_pegasus_x_fast.py index b84388b75d8df..4827561b976e9 100644 --- a/src/transformers/models/pegasus_x/tokenization_pegasus_x_fast.py +++ b/src/transformers/models/pegasus_x/tokenization_pegasus_x_fast.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 Google and The HuggingFace Inc. team. All rights reserved. +# Copyright 2020 Google and The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,102 +12,207 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tokenization classes for PEGASUSX.""" -from typing import List, Optional +""" Tokenization class for model PEGASUS-X.""" -from tokenizers import ByteLevelBPETokenizer + +import os +from shutil import copyfile +from typing import List, Optional, Tuple from ...tokenization_utils_fast import PreTrainedTokenizerFast -from ...utils import logging -from .tokenization_pegasus_x import PegasusXTokenizer +from ...utils import is_sentencepiece_available, logging + + +if is_sentencepiece_available(): + from .tokenization_pegasus_x import PegasusXTokenizer +else: + PegasusTokenizer = None logger = logging.get_logger(__name__) -VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"} + +SPIECE_UNDERLINE = "▁" + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"} PRETRAINED_VOCAB_FILES_MAP = { - "vocab_file": { - "pegasus-x-base": "https://huggingface.co/pegasus-x-base/resolve/main/vocab.txt", - }, + "vocab_file": {"google/pegasus-x-base": "https://huggingface.co/google/pegasus-x-base/resolve/main/spiece.model"}, "tokenizer_file": { - "pegasus-x-base": "https://huggingface.co/pegasus-x-base/resolve/main/tokenizer.json", + "google/pegasus-x-base": "https://huggingface.co/google/pegasus-x-base/resolve/main/tokenizer.json" }, } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "pegasus-x-base": 1024, + "google/pegasus-x-base": 512, } + class PegasusXTokenizerFast(PreTrainedTokenizerFast): - """ - Construct a "fast" PEGASUSX tokenizer (backed by HuggingFace's *tokenizers* library). 
+ r""" + Construct a "fast" PEGASUS tokenizer (backed by HuggingFace's *tokenizers* library). Based on + [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models). + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. Args: vocab_file (`str`): - Path to the vocabulary file. + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + mask_token (`str`, *optional*, defaults to `""`): + The token used for masking single token values. This is the token used when training this model with masked + language modeling (MLM). This is the token that the PEGASUS-X encoder will try to predict during pretraining. + It corresponds to *[MASK2]* in [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive + Summarization](https://arxiv.org/pdf/1912.08777.pdf). + mask_token_sent (`str`, *optional*, defaults to `""`): + The token used for masking whole target sentences. This is the token used when training this model with gap + sentences generation (GSG). This is the sentence that the PEGASUS-X decoder will try to predict during + pretraining. It corresponds to *[MASK1]* in [PEGASUS: Pre-training with Extracted Gap-sentences for + Abstractive Summarization](https://arxiv.org/pdf/1912.08777.pdf). + additional_special_tokens (`List[str]`, *optional*): + Additional special tokens used by the tokenizer. 
If no additional_special_tokens are provided and + are used as additional special tokens corresponding to the [original PEGASUS + tokenizer](https://github.com/google-research/pegasus/blob/939830367bcf411193d2b5eca2f2f90f3f9260ca/pegasus/ops/pretrain_parsing_ops.cc#L66) + that uses the tokens 2 - 104 only for pretraining """ - vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES slow_tokenizer_class = PegasusXTokenizer + model_input_names = ["input_ids", "attention_mask"] def __init__( - self, - vocab_file, - merges_file, - unk_token="<|endoftext|>", - bos_token="<|endoftext|>", - eos_token="<|endoftext|>", - add_prefix_space=False, - trim_offsets=True, - **kwargs + self, + vocab_file=None, + tokenizer_file=None, + pad_token="", + eos_token="", + unk_token="", + mask_token="", + mask_token_sent="", + additional_special_tokens=None, + offset=103, # entries 2 - 104 are only used for pretraining + **kwargs ): + self.offset = offset + + if additional_special_tokens is not None: + if not isinstance(additional_special_tokens, list): + raise TypeError( + f"additional_special_tokens should be of type {type(list)}, but is" + f" {type(additional_special_tokens)}" + ) + + additional_special_tokens_extended = ( + ([mask_token_sent] + additional_special_tokens) + if mask_token_sent not in additional_special_tokens and mask_token_sent is not None + else additional_special_tokens + ) + # fill additional tokens with ..., in case not all additional tokens are already taken + additional_special_tokens_extended += [ + f"" for i in range(len(additional_special_tokens_extended), self.offset - 1) + ] + + if len(set(additional_special_tokens_extended)) != len(additional_special_tokens_extended): + raise ValueError( + "Please make sure that the provided additional_special_tokens do not contain an incorrectly" + f" shifted list of tokens. Found {additional_special_tokens_extended}." 
+ ) + additional_special_tokens = additional_special_tokens_extended + else: + additional_special_tokens = [mask_token_sent] if mask_token_sent is not None else [] + additional_special_tokens += [f"" for i in range(2, self.offset)] + super().__init__( - ByteLevelBPETokenizer( - vocab_file=vocab_file, - merges_file=merges_file, - add_prefix_space=add_prefix_space, - trim_offsets=trim_offsets, - ), - bos_token=bos_token, + vocab_file, + tokenizer_file=tokenizer_file, + pad_token=pad_token, eos_token=eos_token, unk_token=unk_token, + mask_token=mask_token, + mask_token_sent=mask_token_sent, + offset=offset, + additional_special_tokens=additional_special_tokens, **kwargs, ) - self.add_prefix_space = add_prefix_space + self.vocab_file = vocab_file + self.can_save_slow_tokenizer = False if not self.vocab_file else True - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): - output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id] - if token_ids_1 is None: - return output + def _special_token_mask(self, seq): + all_special_ids = set(self.all_special_ids) # call it once instead of inside list comp + all_special_ids.remove(self.unk_token_id) # is only sometimes special - return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id] + if all_special_ids != set(range(len(self.additional_special_tokens) + 3)): + raise ValueError( + "There should be 3 special tokens: mask_token, pad_token, and eos_token +" + f" {len(self.additional_special_tokens)} additional_special_tokens, but got {all_special_ids}" + ) + return [1 if x in all_special_ids else 0 for x in seq] - def create_token_type_ids_from_sequences( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + def get_special_tokens_mask( + self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False ) -> List[int]: + """Get list where entries are [1] if a token is [eos] or [pad] else 0.""" + if already_has_special_tokens: + return self._special_token_mask(token_ids_0) + elif token_ids_1 is None: + return self._special_token_mask(token_ids_0) + [1] + else: + return self._special_token_mask(token_ids_0 + token_ids_1) + [1] + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]: """ - Create a mask from the two sequences passed to be used in a sequence-pair classification task. - PEGASUSX does not make use of token type ids, therefore a list of zeros is returned. + Build model inputs from a sequence by adding eos to the end. no bos token is added to the front. + + - single sequence: `X ` + - pair of sequences: `A B ` (not intended use) Args: token_ids_0 (`List[int]`): - List of IDs. + List of IDs to which the special tokens will be added token_ids_1 (`List[int]`, *optional*): Optional second list of IDs for sequence pairs. Returns: - `List[int]`: List of zeros. + `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens. 
""" - sep = [self.sep_token_id] - cls = [self.cls_token_id] - if token_ids_1 is None: - return len(cls + token_ids_0 + sep) * [0] - return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] - + return token_ids_0 + [self.eos_token_id] + # We don't expect to process pairs, but leave the pair logic for API consistency + return token_ids_0 + token_ids_1 + [self.eos_token_id] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " + "tokenizer." + ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + return (out_vocab_file,) From 4f54ee9bce7d146579a507cfff8d04a56fe09e22 Mon Sep 17 00:00:00 2001 From: Jason Phang Date: Tue, 26 Jul 2022 11:08:58 -0700 Subject: [PATCH 04/25] pegx update --- .../models/pegasus_x/modeling_pegasus_x.py | 86 +++++++------------ 1 file changed, 30 insertions(+), 56 deletions(-) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 0521f6b3dc7e2..b7c1acfc476a9 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -123,39 +123,31 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) -# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->PegasusX -class PegasusXSinusoidalPositionalEmbedding(nn.Embedding): +class PegasusXSinusoidalPositionalEmbedding(nn.Module): """This module produces sinusoidal positional embeddings of any length.""" - def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None: - super().__init__(num_positions, embedding_dim) - self.weight = self._init_weight(self.weight) - - @staticmethod - def _init_weight(out: nn.Parameter) -> nn.Parameter: - """ - Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in - the 2nd half of the vector. 
[dim // 2:] - """ - n_pos, dim = out.shape - position_enc = np.array( - [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] - ) - out.requires_grad = False # set early to avoid an error in pytorch-1.8+ - sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 - out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) - out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) - out.detach_() - return out + def __init__(self, embed_dim, padding_idx, max_scale: int = 10000.0): + super().__init__() + self.embed_dim = embed_dim + self.padding_idx = padding_idx + self.max_scale = max_scale @torch.no_grad() - def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: + def forward(self, input_embeds: torch.Tensor, past_key_values_length: int = 0) -> torch.Tensor: """`input_ids_shape` is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids_shape[:2] + batch_size, seq_len = input_embeds.shape[:2] positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ) - return super().forward(positions) + past_key_values_length, past_key_values_length + seq_len, + dtype=torch.long, device=input_embeds.device + )[:, None] + pe = torch.zeros((seq_len, self.embed_dim), device=input_embeds.device) + half_d_feature = self.embed_dim // 2 + div_term = torch.exp( + torch.arange(half_d_feature, device=input_embeds.device, dtype=input_embeds.dtype) + * -(np.log(float(self.max_scale)) / (half_d_feature - 1))) + pe[:, :half_d_feature] = torch.sin(positions * div_term) + pe[:, half_d_feature:] = torch.cos(positions * div_term) + return pe[None].expand(batch_size, -1, -1) # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PegasusX @@ -168,7 +160,6 @@ def __init__( num_heads: int, dropout: float = 0.0, is_decoder: bool = False, - bias: bool = True, ): super().__init__() self.embed_dim = embed_dim @@ -184,10 +175,10 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -306,7 +297,6 @@ def __init__( block_size: int, dropout: float = 0.0, is_decoder: bool = False, - bias: bool = True, ): super().__init__() self.embed_dim = embed_dim @@ -323,10 +313,10 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -899,7 +889,6 @@ def __init__(self, config: PegasusXConfig, embed_tokens: Optional[nn.Embedding] self.embed_global = nn.Embedding(config.num_global_tokens, embed_dim) self.embed_positions = PegasusXSinusoidalPositionalEmbedding( - config.max_position_embeddings, embed_dim, self.padding_idx, ) @@ -932,7 +921,6 @@ def resize_position_embeddings(self, new_num_position_embeddings: int): self.config.max_position_embeddings = new_num_position_embeddings self.embed_positions = PegasusXSinusoidalPositionalEmbedding( - self.config.max_position_embeddings, self.config.d_model, self.padding_idx, ) @@ -1004,7 +992,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - embed_pos = self.embed_positions(input_shape) + embed_pos = self.embed_positions(inputs_embeds) hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -1109,7 +1097,6 @@ def __init__(self, config: PegasusXConfig, embed_tokens: Optional[nn.Embedding] self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) self.embed_positions = PegasusXSinusoidalPositionalEmbedding( - config.max_position_embeddings, config.d_model, self.padding_idx, ) @@ -1162,7 +1149,6 @@ def resize_position_embeddings(self, new_num_position_embeddings: int): self.config.max_position_embeddings = new_num_position_embeddings self.embed_positions = PegasusXSinusoidalPositionalEmbedding( - self.config.max_position_embeddings, self.config.d_model, self.padding_idx, ) @@ -1274,7 +1260,7 @@ def forward( encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) # embed positions - positions = self.embed_positions(input_shape, past_key_values_length) + positions = self.embed_positions(inputs_embeds, past_key_values_length) hidden_states = inputs_embeds + positions @@ -1513,7 +1499,6 @@ def forward( class PegasusXForConditionalGeneration(PegasusXPreTrainedModel): base_model_prefix = "model" _keys_to_ignore_on_load_missing = [ - r"final_logits_bias", r"encoder.version", r"decoder.version", r"lm_head.weight", @@ -1523,7 +1508,6 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel): def __init__(self, config: PegasusXConfig): super().__init__(config) self.model = PegasusXModel(config) - self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) # Initialize weights and apply final processing @@ -1537,18 +1521,8 @@ def get_decoder(self): def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: new_embeddings = super().resize_token_embeddings(new_num_tokens) - self._resize_final_logits_bias(new_num_tokens) return new_embeddings - def _resize_final_logits_bias(self, new_num_tokens: int) -> None: - old_num_tokens = self.final_logits_bias.shape[-1] - if new_num_tokens <= old_num_tokens: - new_bias = self.final_logits_bias[:, :new_num_tokens] - else: - extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) - new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) - self.register_buffer("final_logits_bias", new_bias) - def get_output_embeddings(self): return self.lm_head @@ -1631,7 +1605,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, ) - 
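With the `final_logits_bias` buffer dropped above, the LM head becomes a plain bias-free projection from `d_model` to the vocabulary. Tying its weight to the shared input embedding is the usual `transformers` behaviour and is shown here only as an assumption:

```python
import torch
from torch import nn

d_model, vocab_size = 16, 96
shared = nn.Embedding(vocab_size, d_model)
lm_head = nn.Linear(d_model, vocab_size, bias=False)
lm_head.weight = shared.weight  # assumed weight tying between input and output embeddings

decoder_hidden = torch.randn(2, 5, d_model)
lm_logits = lm_head(decoder_hidden)  # [2, 5, vocab_size]; no "+ final_logits_bias" term
```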
lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias + lm_logits = self.lm_head(outputs[0]) masked_lm_loss = None if labels is not None: From 3b72f499c17c797d372dbbe31ed41fe645c93e3e Mon Sep 17 00:00:00 2001 From: Jason Phang Date: Tue, 26 Jul 2022 12:24:16 -0700 Subject: [PATCH 05/25] pegx fix --- src/transformers/models/pegasus_x/modeling_pegasus_x.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index b7c1acfc476a9..ccfca06e3a870 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -421,7 +421,7 @@ def compute_global_attention_representations( # [B, H, G, G+P] attn_weights = torch.einsum("BHGF,BHXF->BHGX", global_q, global_and_local_k) - attn_weights += attn_weights + extended_mask[:, None, None, :] + attn_weights = attn_weights + extended_mask[:, None, None, :] attn_probs = nn.functional.softmax(attn_weights, dim=-1) attn_probs = nn.functional.dropout(attn_probs, p=self.dropout, training=self.training) From 7407103bb75b39d19d729fb1e53168ea52c4dd67 Mon Sep 17 00:00:00 2001 From: Jason Phang Date: Tue, 26 Jul 2022 17:30:36 -0700 Subject: [PATCH 06/25] pegasus-x fixes --- src/transformers/models/pegasus_x/modeling_pegasus_x.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index ccfca06e3a870..66d65d2bf1a83 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -602,11 +602,13 @@ def forward( def pad_local_tokens(cls, hidden_states, attention_mask, block_size): assert hidden_states.dim() == 3 pad_size = block_size // 2 + mask_min_value = torch.finfo(hidden_states.dtype).min padded_hidden_states = torch.nn.functional.pad( hidden_states, pad=(0, 0, pad_size, pad_size), ) padded_mask = torch.nn.functional.pad( attention_mask, pad=(pad_size, pad_size), + value=mask_min_value, ) return padded_hidden_states, padded_mask @@ -894,7 +896,7 @@ def __init__(self, config: PegasusXConfig, embed_tokens: Optional[nn.Embedding] ) self.layers = nn.ModuleList([ PegasusXEncoderLayer( - stagger_blocks_this_layer=i%2 == 1 and config.stagger_local_blocks, + stagger_blocks_this_layer=i % 2 == 1 and config.stagger_local_blocks, config=config) for i in range(config.encoder_layers) ]) From 0f9396b56dfc65300d24dea275e8583502df1b15 Mon Sep 17 00:00:00 2001 From: Jason Phang Date: Sat, 6 Aug 2022 21:41:40 -0700 Subject: [PATCH 07/25] pegx updates --- .../models/pegasus_x/modeling_pegasus_x.py | 35 +++++-------------- 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 66d65d2bf1a83..beec2fdde2ce6 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -126,10 +126,9 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] class PegasusXSinusoidalPositionalEmbedding(nn.Module): """This module produces sinusoidal positional embeddings of any length.""" - def __init__(self, embed_dim, padding_idx, max_scale: int = 10000.0): + def __init__(self, embed_dim, max_scale: int = 10000.0): super().__init__() self.embed_dim = embed_dim - self.padding_idx = padding_idx 
self.max_scale = max_scale @torch.no_grad() @@ -744,8 +743,6 @@ def _init_weights(self, module): pass elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (PegasusXDecoder, PegasusXEncoder)): @@ -880,20 +877,16 @@ def __init__(self, config: PegasusXConfig, embed_tokens: Optional[nn.Embedding] self.layerdrop = config.encoder_layerdrop embed_dim = config.d_model - self.padding_idx = config.pad_token_id self.max_source_positions = config.max_position_embeddings self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 if embed_tokens is not None: self.embed_tokens = embed_tokens else: - self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim) self.embed_global = nn.Embedding(config.num_global_tokens, embed_dim) - self.embed_positions = PegasusXSinusoidalPositionalEmbedding( - embed_dim, - self.padding_idx, - ) + self.embed_positions = PegasusXSinusoidalPositionalEmbedding(embed_dim) self.layers = nn.ModuleList([ PegasusXEncoderLayer( stagger_blocks_this_layer=i % 2 == 1 and config.stagger_local_blocks, @@ -922,10 +915,7 @@ def resize_position_embeddings(self, new_num_position_embeddings: int): logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...") self.config.max_position_embeddings = new_num_position_embeddings - self.embed_positions = PegasusXSinusoidalPositionalEmbedding( - self.config.d_model, - self.padding_idx, - ) + self.embed_positions = PegasusXSinusoidalPositionalEmbedding(self.config.d_model) self.embed_positions.to(self.device) def get_position_embeddings(self) -> nn.Embedding: @@ -1089,19 +1079,15 @@ def __init__(self, config: PegasusXConfig, embed_tokens: Optional[nn.Embedding] super().__init__(config) self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop - self.padding_idx = config.pad_token_id self.max_target_positions = config.max_position_embeddings self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 if embed_tokens is not None: self.embed_tokens = embed_tokens else: - self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) - self.embed_positions = PegasusXSinusoidalPositionalEmbedding( - config.d_model, - self.padding_idx, - ) + self.embed_positions = PegasusXSinusoidalPositionalEmbedding(config.d_model) self.layers = nn.ModuleList([PegasusXDecoderLayer(config) for _ in range(config.decoder_layers)]) self.layer_norm = nn.LayerNorm(config.d_model) @@ -1150,10 +1136,7 @@ def resize_position_embeddings(self, new_num_position_embeddings: int): logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...") self.config.max_position_embeddings = new_num_position_embeddings - self.embed_positions = PegasusXSinusoidalPositionalEmbedding( - self.config.d_model, - self.padding_idx, - ) + self.embed_positions = PegasusXSinusoidalPositionalEmbedding(self.config.d_model) self.embed_positions.to(self.device) def get_position_embeddings(self) -> nn.Embedding: @@ -1359,8 +1342,8 @@ class PegasusXModel(PegasusXPreTrainedModel): def __init__(self, config: PegasusXConfig): super().__init__(config) - padding_idx, vocab_size = config.pad_token_id, config.vocab_size - self.shared = 
nn.Embedding(vocab_size, config.d_model, padding_idx) + vocab_size = config.vocab_size + self.shared = nn.Embedding(vocab_size, config.d_model) self.encoder = PegasusXEncoder(config, self.shared) self.decoder = PegasusXDecoder(config, self.shared) From f4e8956e11f610e10d72691a41033f01b179bcc7 Mon Sep 17 00:00:00 2001 From: Jason Phang Date: Sat, 6 Aug 2022 22:11:30 -0700 Subject: [PATCH 08/25] cleanup --- docs/source/en/model_doc/pegasus_x.mdx | 68 +- src/transformers/__init__.py | 9 - src/transformers/models/auto/modeling_auto.py | 19 +- .../models/auto/modeling_flax_auto.py | 18 - src/transformers/models/pegasus_x/__init__.py | 67 +- .../pegasus_x/modeling_flax_pegasus_x.py | 1749 ----------------- .../models/pegasus_x/modeling_pegasus_x.py | 4 +- .../pegasus_x/tokenization_pegasus_x.py | 299 --- .../pegasus_x/tokenization_pegasus_x_fast.py | 218 -- .../pegasus_x/test_modeling_flax_pegasus_x.py | 346 ---- 10 files changed, 12 insertions(+), 2785 deletions(-) delete mode 100644 src/transformers/models/pegasus_x/modeling_flax_pegasus_x.py delete mode 100644 src/transformers/models/pegasus_x/tokenization_pegasus_x.py delete mode 100644 src/transformers/models/pegasus_x/tokenization_pegasus_x_fast.py delete mode 100644 tests/models/pegasus_x/test_modeling_flax_pegasus_x.py diff --git a/docs/source/en/model_doc/pegasus_x.mdx b/docs/source/en/model_doc/pegasus_x.mdx index 73969d3d13d45..609083a424839 100644 --- a/docs/source/en/model_doc/pegasus_x.mdx +++ b/docs/source/en/model_doc/pegasus_x.mdx @@ -14,37 +14,25 @@ specific language governing permissions and limitations under the License. ## Overview -The PEGASUSX model was proposed in []() by . +The PEGASUS-X model was proposed in [Investigating Efficiently Extending Transformers for Long Input Summarization]() by Jason Phang, Yao Zhao and Peter J. Liu. + +PEGASUS-X (PEGASUS eXtended) extends the PEGASUS models for long input summarization through additional long input pretraining and using staggered block-local attention with global tokens in the encoder. The abstract from the paper is the following: -** +*While large pretrained Transformer models have proven highly capable at tackling natural language tasks, handling long sequence inputs continues to be a significant challenge. One such task is long input summarization, where inputs are longer than the maximum input context of most pretrained models. Through an extensive set of experiments, we investigate what model architectural changes and pretraining paradigms can most efficiently adapt a pretrained Transformer for long input summarization. We find that a staggered, block-local Transformer with global encoder tokens strikes a good balance of performance and efficiency, and that an additional pretraining phase on long sequences meaningfully improves downstream summarization performance. Based on our findings, we introduce PEGASUS-X, an extension of the PEGASUS model with additional long input pretraining to handle inputs of up to 16K tokens. PEGASUS-X achieves strong performance on long input summarization tasks comparable with much larger models while adding few additional parameters and not requiring model parallelism to train.* Tips: - +* PEGASUS-X uses the same tokenizer as PEGASUS. -This model was contributed by [INSERT YOUR HF USERNAME HERE](). The original code can be found [here](). +This model was contributed by [zphang](