
Add Wav2Vec2 & Hubert ForSequenceClassification (#13153)
* Add hubert classifier + tests

* Add hubert classifier + tests

* Dummies for all classification tests

* Wav2Vec2 classifier + ER test

* Fix hubert integration tests

* Add hubert IC

* Pass tests for all classification tasks on Hubert

* Pass all tests + copies

* Move models to the SUPERB org
anton-l committed Aug 27, 2021
1 parent 2bef343 commit b6f332e
Showing 16 changed files with 823 additions and 36 deletions.
8 changes: 8 additions & 0 deletions docs/source/model_doc/hubert.rst
@@ -64,6 +64,14 @@ HubertForCTC
.. autoclass:: transformers.HubertForCTC
:members: forward


HubertForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.HubertForSequenceClassification
:members: forward


TFHubertModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

8 changes: 8 additions & 0 deletions docs/source/model_doc/wav2vec2.rst
@@ -96,6 +96,14 @@ Wav2Vec2ForCTC
.. autoclass:: transformers.Wav2Vec2ForCTC
:members: forward


Wav2Vec2ForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Wav2Vec2ForSequenceClassification
:members: forward


Wav2Vec2ForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

4 changes: 4 additions & 0 deletions src/transformers/__init__.py
@@ -818,6 +818,7 @@
[
"HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"HubertForCTC",
"HubertForSequenceClassification",
"HubertModel",
"HubertPreTrainedModel",
]
@@ -1128,6 +1129,7 @@
"Wav2Vec2ForCTC",
"Wav2Vec2ForMaskedLM",
"Wav2Vec2ForPreTraining",
"Wav2Vec2ForSequenceClassification",
"Wav2Vec2Model",
"Wav2Vec2PreTrainedModel",
]
@@ -2424,6 +2426,7 @@
from .models.hubert import (
HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
HubertForCTC,
HubertForSequenceClassification,
HubertModel,
HubertPreTrainedModel,
)
@@ -2681,6 +2684,7 @@
Wav2Vec2ForCTC,
Wav2Vec2ForMaskedLM,
Wav2Vec2ForPreTraining,
Wav2Vec2ForSequenceClassification,
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)
2 changes: 2 additions & 0 deletions src/transformers/models/hubert/__init__.py
@@ -28,6 +28,7 @@
_import_structure["modeling_hubert"] = [
"HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"HubertForCTC",
"HubertForSequenceClassification",
"HubertModel",
"HubertPreTrainedModel",
]
@@ -48,6 +49,7 @@
from .modeling_hubert import (
HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
HubertForCTC,
HubertForSequenceClassification,
HubertModel,
HubertPreTrainedModel,
)
9 changes: 9 additions & 0 deletions src/transformers/models/hubert/configuration_hubert.py
@@ -115,6 +115,11 @@ class HubertConfig(PretrainedConfig):
Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses
mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
instance of :class:`~transformers.HubertForCTC`.
use_weighted_layer_sum (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
instance of :class:`~transformers.HubertForSequenceClassification`.
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of a slower backward pass.
@@ -165,6 +170,8 @@ def __init__(
mask_feature_length=10,
ctc_loss_reduction="sum",
ctc_zero_infinity=False,
use_weighted_layer_sum=False,
classifier_proj_size=256,
gradient_checkpointing=False,
pad_token_id=0,
bos_token_id=1,
@@ -197,6 +204,8 @@
self.vocab_size = vocab_size
self.do_stable_layer_norm = do_stable_layer_norm
self.gradient_checkpointing = gradient_checkpointing
self.use_weighted_layer_sum = use_weighted_layer_sum
self.classifier_proj_size = classifier_proj_size

if (
(len(self.conv_stride) != self.num_feat_extract_layers)
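The two new options can be exercised directly when building a classifier from scratch. A minimal sketch, assuming an illustrative 12-label keyword-spotting setup (only use_weighted_layer_sum and classifier_proj_size are introduced by this commit):

from transformers import HubertConfig, HubertForSequenceClassification

# Assumed 12-class setup; the new options enable a learned weighted sum over
# all hidden layers and a 256-dim projection before mean-pooling.
config = HubertConfig(
    num_labels=12,
    use_weighted_layer_sum=True,
    classifier_proj_size=256,
)
model = HubertForSequenceClassification(config)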
@@ -0,0 +1,69 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Hubert checkpoint."""


import argparse

import torch

from transformers import HubertConfig, HubertForSequenceClassification, Wav2Vec2FeatureExtractor, logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

SUPPORTED_MODELS = ["UtteranceLevel"]


@torch.no_grad()
def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path):
"""
Copy/paste/tweak the model's weights to the transformers design.
"""
checkpoint = torch.load(checkpoint_path, map_location="cpu")
if checkpoint["Config"]["downstream_expert"]["modelrc"]["select"] not in SUPPORTED_MODELS:
raise NotImplementedError(f"The supported s3prl models are {SUPPORTED_MODELS}")

downstream_dict = checkpoint["Downstream"]

hf_config = HubertConfig.from_pretrained(config_path)
hf_model = HubertForSequenceClassification.from_pretrained(base_model_name, config=hf_config)
hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
base_model_name, return_attention_mask=True, do_normalize=False
)

if hf_config.use_weighted_layer_sum:
hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"]

hf_model.projector.weight.data = downstream_dict["projector.weight"]
hf_model.projector.bias.data = downstream_dict["projector.bias"]
hf_model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"]
hf_model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"]

hf_feature_extractor.save_pretrained(model_dump_path)
hf_model.save_pretrained(model_dump_path)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--base_model_name", default=None, type=str, help="Name of the huggingface pretrained base model."
)
parser.add_argument("--config_path", default=None, type=str, help="Path to the huggingface classifier config.")
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the s3prl checkpoint.")
parser.add_argument("--model_dump_path", default=None, type=str, help="Path to the final converted model.")
args = parser.parse_args()
convert_s3prl_checkpoint(args.base_model_name, args.config_path, args.checkpoint_path, args.model_dump_path)
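A hedged usage sketch for the converter above: the config and checkpoint paths are placeholders, and facebook/hubert-base-ls960 stands in for whichever pretrained base model the s3prl downstream head was trained on.

from transformers import HubertForSequenceClassification, Wav2Vec2FeatureExtractor

convert_s3prl_checkpoint(
    base_model_name="facebook/hubert-base-ls960",    # assumed pretrained base model
    config_path="./ks_classifier_config.json",       # hypothetical config with num_labels/id2label
    checkpoint_path="./s3prl_utterance_level.ckpt",  # hypothetical s3prl UtteranceLevel checkpoint
    model_dump_path="./hubert-superb-ks",
)

# The dump directory can then be reloaded like any other checkpoint.
model = HubertForSequenceClassification.from_pretrained("./hubert-superb-ks")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("./hubert-superb-ks")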
155 changes: 141 additions & 14 deletions src/transformers/models/hubert/modeling_hubert.py
@@ -20,12 +20,13 @@
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.deepspeed import is_deepspeed_zero3_enabled

from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_outputs import BaseModelOutput, CausalLMOutput
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_hubert import HubertConfig
@@ -735,6 +736,18 @@ def _conv_out_length(input_length, kernel_size, stride):

return input_lengths

def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
batch_size = attention_mask.shape[0]

attention_mask = torch.zeros(
(batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
)
# these two operations make sure that all values before the output length indices are attended to
attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
return attention_mask
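# A toy sketch (values assumed; torch is imported at the top of this module) of
# the flip/cumsum/flip trick above: with 5 feature frames and an output length
# of 3, marking index output_length - 1 and accumulating from the right marks
# exactly the first three frames as attended.
toy_mask = torch.zeros(1, 5)
toy_mask[0, 2] = 1  # output_length - 1 == 2
toy_mask = toy_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
# toy_mask -> tensor([[ True,  True,  True, False, False]])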


HUBERT_START_DOCSTRING = r"""
Hubert was proposed in `HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units
@@ -904,19 +917,8 @@ def forward(
extract_features = extract_features.transpose(1, 2)

if attention_mask is not None:
# compute real output lengths according to convolution formula
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

attention_mask = torch.zeros(
extract_features.shape[:2], dtype=extract_features.dtype, device=extract_features.device
)

# these two operations makes sure that all values
# before the output lengths indices are attended to
attention_mask[
(torch.arange(attention_mask.shape[0], device=extract_features.device), output_lengths - 1)
] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
# compute reduced attention_mask corresponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

hidden_states = self.feature_projection(extract_features)
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
@@ -1070,3 +1072,128 @@ def forward(
return CausalLMOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
)


@add_start_docstrings(
"""
Hubert Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
SUPERB Keyword Spotting.
""",
HUBERT_START_DOCSTRING,
)
class HubertForSequenceClassification(HubertPreTrainedModel):
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Hubert, wav2vec2->hubert
def __init__(self, config):
super().__init__(config)

self.hubert = HubertModel(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)

self.init_weights()

# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor with wav2vec2->hubert
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature extractor so that its parameters
will not be updated during training.
"""
self.hubert.feature_extractor._freeze_parameters()

# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->hubert
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.hubert.parameters():
param.requires_grad = False

@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`. If :obj:`config.num_labels == 1`, a regression loss is computed (Mean-Square loss);
if :obj:`config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
Returns:
Example::
>>> import torch
>>> from transformers import Wav2Vec2FeatureExtractor, HubertForSequenceClassification
>>> from datasets import load_dataset
>>> processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ks")
>>> model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks")
>>> ds = load_dataset("anton-l/superb_dummy", "ks", split="test")
>>> input_values = processor(ds["speech"][4], return_tensors="pt").input_values # Batch size 1
>>> logits = model(input_values).logits
>>> predicted_class_ids = torch.argmax(logits, dim=-1)
>>> # compute loss
>>> target_label = "down"
>>> labels = torch.tensor([model.config.label2id[target_label]])
>>> loss = model(input_values, labels=labels).loss
"""

return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

outputs = self.hubert(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

if self.config.use_weighted_layer_sum:
hidden_states = outputs[1]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
hidden_states = outputs[0]

hidden_states = self.projector(hidden_states)
if attention_mask is None:
pooled_output = hidden_states.mean(dim=1)
else:
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
hidden_states[~padding_mask] = 0.0
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

logits = self.classifier(pooled_output)

loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output

return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
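A short end-to-end sketch of the classification head added above; the checkpoint name is the one from the docstring example, while the audio arrays are random stand-ins, so the predicted labels are only illustrative.

import torch
from transformers import HubertForSequenceClassification, Wav2Vec2FeatureExtractor

model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ks")

# Two utterances of different lengths; padding plus the returned attention_mask
# exercises the masked mean-pooling branch of the forward pass.
speech = [torch.randn(16000).numpy(), torch.randn(12000).numpy()]
inputs = feature_extractor(speech, sampling_rate=16000, padding=True, return_tensors="pt")

model.freeze_base_model()  # optional: leave only the projector and classifier trainable
logits = model(**inputs).logits
predicted_ids = logits.argmax(dim=-1)
print([model.config.id2label[i.item()] for i in predicted_ids])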
2 changes: 2 additions & 0 deletions src/transformers/models/wav2vec2/__init__.py
@@ -33,6 +33,7 @@
"Wav2Vec2ForCTC",
"Wav2Vec2ForMaskedLM",
"Wav2Vec2ForPreTraining",
"Wav2Vec2ForSequenceClassification",
"Wav2Vec2Model",
"Wav2Vec2PreTrainedModel",
]
@@ -66,6 +67,7 @@
Wav2Vec2ForCTC,
Wav2Vec2ForMaskedLM,
Wav2Vec2ForPreTraining,
Wav2Vec2ForSequenceClassification,
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)
