From 57857c5aa65aaeae9237063e06bdb91133a4198e Mon Sep 17 00:00:00 2001
From: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Date: Mon, 12 Sep 2022 19:39:01 +0200
Subject: [PATCH] fix checkpoint name for wav2vec2 conformer (#18994)

* fix checkpoint name for wav2vec2 conformer

Co-authored-by: ydshieh
---
 .../configuration_wav2vec2_conformer.py               | 10 +++++-----
 .../wav2vec2_conformer/modeling_wav2vec2_conformer.py |  4 ++--
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py
index 9c5e4d205b9af7..11181f5601a197 100644
--- a/src/transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py
+++ b/src/transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py
@@ -24,8 +24,8 @@
 logger = logging.get_logger(__name__)
 
 WAV2VEC2_CONFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    "facebook/wav2vec2-conformer-large-rel-pos": (
-        "https://huggingface.co/facebook/wav2vec2-conformer-large-rel-pos/resolve/main/config.json"
+    "facebook/wav2vec2-conformer-rel-pos-large": (
+        "https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large/resolve/main/config.json"
     ),
 }
 
@@ -35,7 +35,7 @@ class Wav2Vec2ConformerConfig(PretrainedConfig):
     This is the configuration class to store the configuration of a [`Wav2Vec2ConformerModel`]. It is used to
     instantiate an Wav2Vec2Conformer model according to the specified arguments, defining the model architecture.
     Instantiating a configuration with the defaults will yield a similar configuration to that of the Wav2Vec2Conformer
-    [facebook/wav2vec2-conformer-large-rel-pos](https://huggingface.co/facebook/wav2vec2-conformer-large-rel-pos)
+    [facebook/wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large)
     architecture.
 
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
@@ -195,10 +195,10 @@ class Wav2Vec2ConformerConfig(PretrainedConfig):
     ```python
     >>> from transformers import Wav2Vec2ConformerModel, Wav2Vec2ConformerConfig
 
-    >>> # Initializing a Wav2Vec2Conformer facebook/wav2vec2-conformer-large-rel-pos style configuration
+    >>> # Initializing a Wav2Vec2Conformer facebook/wav2vec2-conformer-rel-pos-large style configuration
     >>> configuration = Wav2Vec2ConformerConfig()
 
-    >>> # Initializing a model from the facebook/wav2vec2-conformer-large-rel-pos style configuration
+    >>> # Initializing a model from the facebook/wav2vec2-conformer-rel-pos-large style configuration
     >>> model = Wav2Vec2ConformerModel(configuration)
 
     >>> # Accessing the model configuration
diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
index 4c4962b155c35c..5bee0d040c8ba4 100644
--- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
+++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
@@ -80,7 +80,7 @@
 
 
 WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "facebook/wav2vec2-conformer-large-rel-pos",
+    "facebook/wav2vec2-conformer-rel-pos-large",
     # See all Wav2Vec2Conformer models at https://huggingface.co/models?filter=wav2vec2-conformer
 ]
 
@@ -1226,7 +1226,7 @@ def _set_gradient_checkpointing(self, module, value=False):
             `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
             True`. For all models whose processor has `config.return_attention_mask == False`, such as
-            [wav2vec2_conformer-base](https://huggingface.co/facebook/wav2vec2-conformer-large-rel-pos),
+            [wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large),
             `attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For
             such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware
             that these models also yield slightly different results depending on whether `input_values` is padded or
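
For reference, a minimal usage sketch of the renamed checkpoint. This is not part of the patch: it assumes `transformers` and `torch` are installed and the Hugging Face Hub is reachable, and the random tensors below are placeholders for real 16 kHz audio.

```python
# Sketch only: load the model under its corrected checkpoint name and run
# batched inference the way the docstring above prescribes for processors
# with `config.return_attention_mask == False`: zero-pad `input_values`
# and do not pass `attention_mask`.
import torch

from transformers import Wav2Vec2ConformerModel

model = Wav2Vec2ConformerModel.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
model.eval()

# Two utterances of different lengths (random stand-ins for 16 kHz audio),
# zero-padded to a common length.
short = torch.randn(8000)   # ~0.5 s at 16 kHz
long_ = torch.randn(16000)  # ~1.0 s at 16 kHz
batch = torch.zeros(2, 16000)
batch[0, : short.numel()] = short
batch[1] = long_

with torch.no_grad():
    outputs = model(input_values=batch)  # no attention_mask, per the docstring

print(outputs.last_hidden_state.shape)  # (2, num_frames, hidden_size)
```

With the old `facebook/wav2vec2-conformer-large-rel-pos` string, `from_pretrained` would fail to resolve the repo on the Hub, which is what this rename fixes.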