Add Wav2Vec2 & Hubert ForSequenceClassification #13153

Merged
merged 11 commits on Aug 27, 2021
8 changes: 8 additions & 0 deletions docs/source/model_doc/hubert.rst
@@ -64,6 +64,14 @@ HubertForCTC
.. autoclass:: transformers.HubertForCTC
:members: forward


HubertForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.HubertForSequenceClassification
:members: forward


TFHubertModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -818,6 +818,7 @@
[
"HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"HubertForCTC",
"HubertForSequenceClassification",
"HubertModel",
"HubertPreTrainedModel",
]
@@ -2423,6 +2424,7 @@
from .models.hubert import (
HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
HubertForCTC,
HubertForSequenceClassification,
HubertModel,
HubertPreTrainedModel,
)
2 changes: 2 additions & 0 deletions src/transformers/models/hubert/__init__.py
@@ -28,6 +28,7 @@
_import_structure["modeling_hubert"] = [
"HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"HubertForCTC",
"HubertForSequenceClassification",
"HubertModel",
"HubertPreTrainedModel",
]
@@ -48,6 +49,7 @@
from .modeling_hubert import (
HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
HubertForCTC,
HubertForSequenceClassification,
HubertModel,
HubertPreTrainedModel,
)
9 changes: 9 additions & 0 deletions src/transformers/models/hubert/configuration_hubert.py
@@ -115,6 +115,11 @@ class HubertConfig(PretrainedConfig):
Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses
mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
instance of :class:`~transformers.HubertForCTC`.
use_weighted_layer_sum (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to use a weighted average of layer outputs with learned weights. Only relevant when training an
instance of :class:`~transformers.HubertForSequenceClassification`.
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of a slower backward pass.

@@ -165,6 +170,8 @@ def __init__(
mask_feature_length=10,
ctc_loss_reduction="sum",
ctc_zero_infinity=False,
use_weighted_layer_sum=False,
classifier_proj_size=256,
gradient_checkpointing=False,
pad_token_id=0,
bos_token_id=1,
@@ -197,6 +204,8 @@
self.vocab_size = vocab_size
self.do_stable_layer_norm = do_stable_layer_norm
self.gradient_checkpointing = gradient_checkpointing
self.use_weighted_layer_sum = use_weighted_layer_sum
self.classifier_proj_size = classifier_proj_size

if (
(len(self.conv_stride) != self.num_feat_extract_layers)
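For context, a minimal sketch (not part of this PR) of how the two new config options above could be wired together when building a classifier from scratch; the label count is an assumed value for illustration only:

```python
# Hypothetical usage of the new config fields; 12 labels is an assumption, not from the PR.
from transformers import HubertConfig, HubertForSequenceClassification

config = HubertConfig(
    use_weighted_layer_sum=True,  # learn a softmax-normalized mix of all hidden states
    classifier_proj_size=256,     # width of the projection applied before mean-pooling
    num_labels=12,                # assumed number of downstream classes
)
model = HubertForSequenceClassification(config)  # randomly initialized, no pretrained weights
print(model.layer_weights.shape)  # (config.num_hidden_layers + 1,), i.e. torch.Size([13]) by default
```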
@@ -0,0 +1,77 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Hubert checkpoint."""


import argparse

import torch

from transformers import (
HubertConfig,
HubertForSequenceClassification,
Wav2Vec2CTCTokenizer,
Wav2Vec2FeatureExtractor,
Wav2Vec2Processor,
logging,
)


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

SUPPORTED_MODELS = ["UtteranceLevel"]


@torch.no_grad()
def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path):
"""
Copy/paste/tweak the model's weights to the transformers design.
"""
checkpoint = torch.load(checkpoint_path, map_location="cpu")
if checkpoint["Config"]["downstream_expert"]["modelrc"]["select"] not in SUPPORTED_MODELS:
raise NotImplementedError(f"The supported s3prl models are {SUPPORTED_MODELS}")

downstream_dict = checkpoint["Downstream"]

hf_config = HubertConfig.from_pretrained(config_path)
hf_model = HubertForSequenceClassification.from_pretrained(base_model_name, config=hf_config)
# TODO: remove the need for a tokenizer
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
hf_processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

if hf_config.use_weighted_layer_sum:
hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"]

hf_model.projector.weight.data = downstream_dict["projector.weight"]
hf_model.projector.bias.data = downstream_dict["projector.bias"]
hf_model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"]
hf_model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"]

hf_processor.save_pretrained(model_dump_path)
hf_model.save_pretrained(model_dump_path)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--base_model_name", default=None, type=str, help="Name of the huggingface pretrained base model."
)
parser.add_argument("--config_path", default=None, type=str, help="Path to the huggingface classifier config.")
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the s3prl checkpoint.")
parser.add_argument("--model_dump_path", default=None, type=str, help="Path to the final converted model.")
args = parser.parse_args()
convert_s3prl_checkpoint(args.base_model_name, args.config_path, args.checkpoint_path, args.model_dump_path)
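A hedged sketch of what loading the conversion output might look like; the directory name stands in for whatever `--model_dump_path` was passed to the script above:

```python
# Illustrative only: "./converted-hubert-classifier" is a placeholder for the dump path.
import torch
from transformers import HubertForSequenceClassification, Wav2Vec2FeatureExtractor

model = HubertForSequenceClassification.from_pretrained("./converted-hubert-classifier")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("./converted-hubert-classifier")

waveform = torch.zeros(16000)  # one second of silent 16 kHz audio as a stand-in input
inputs = feature_extractor(waveform.numpy(), sampling_rate=16000, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits
predicted_id = int(logits.argmax(dim=-1))
print(predicted_id, model.config.id2label.get(predicted_id, "unknown"))
```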
123 changes: 122 additions & 1 deletion src/transformers/models/hubert/modeling_hubert.py
@@ -20,12 +20,13 @@
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.deepspeed import is_deepspeed_zero3_enabled

from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_outputs import BaseModelOutput, CausalLMOutput
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_hubert import HubertConfig
@@ -1070,3 +1071,123 @@ def forward(
return CausalLMOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
)


@add_start_docstrings(
"""
Hubert Model with a sequence classification/regression head on top (a linear layer over the pooled output) for
tasks like SUPERB Keyword Spotting.
""",
HUBERT_START_DOCSTRING,
)
class HubertForSequenceClassification(HubertPreTrainedModel):
def __init__(self, config):
super().__init__(config)

self.hubert = HubertModel(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)

self.init_weights()

def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature extractor so that its parameters
will not be updated during training.
"""
self.hubert.feature_extractor._freeze_parameters()

def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base Hubert model so that its parameters
will not be updated during training. Only the classification head will be updated.
"""
for param in self.hubert.parameters():
param.requires_grad = False

@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`. If :obj:`config.num_labels == 1`, a regression loss is computed (Mean-Square loss).
If :obj:`config.num_labels > 1`, a classification loss is computed (Cross-Entropy).

Returns:

TODO: Usage example with Keyword Spotting
"""

return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

outputs = self.hubert(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

if self.config.use_weighted_layer_sum:
hidden_states = outputs[1]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
hidden_states = outputs[0]

hidden_states = self.projector(hidden_states)
if attention_mask is None:
pooled_output = hidden_states.mean(dim=1)
else:
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
pooled_output = hidden_states.sum(dim=1) / output_lengths.view(-1, 1)

logits = self.classifier(pooled_output)

loss = None
if labels is not None:
if self.config.problem_type is None:
if self.config.num_labels == 1:
self.config.problem_type = "regression"
elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"

if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.config.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)

if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output

return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
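Since the usage example in the docstring is still a TODO, here is a hedged end-to-end sketch with a randomly initialized model and synthetic audio (no pretrained checkpoint or dataset is assumed), exercising the single-label classification branch:

```python
# Sketch only: random weights and fake audio; 12 labels is an arbitrary assumption.
import torch
from transformers import HubertConfig, HubertForSequenceClassification

config = HubertConfig(use_weighted_layer_sum=True, num_labels=12)
model = HubertForSequenceClassification(config).eval()

input_values = torch.randn(2, 16000)  # two fake one-second utterances at 16 kHz
labels = torch.tensor([3, 7])         # integer targets -> single_label_classification path

outputs = model(input_values=input_values, labels=labels)
print(outputs.loss)          # cross-entropy loss over the pooled logits
print(outputs.logits.shape)  # torch.Size([2, 12])

# With use_weighted_layer_sum=True, forward() softmax-normalizes the learned layer
# weights before averaging the hidden states, so they always sum to 1.
print(torch.nn.functional.softmax(model.layer_weights, dim=-1).sum())
```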
9 changes: 9 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
@@ -1863,6 +1863,15 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class HubertForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class HubertModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
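As a side note, these dummies are what the top-level import resolves to when PyTorch is not installed, so users get a readable error instead of a broken import. A self-contained sketch of the pattern (the real classes are auto-generated and the library's `requires_backends` helper also prints installation instructions):

```python
# Standalone illustration of the dummy-object pattern used in dummy_pt_objects.py.
def requires_backends(obj, backends):
    name = obj if isinstance(obj, str) else obj.__class__.__name__
    raise ImportError(f"{name} requires the following backend(s): {', '.join(backends)}")

class HubertForSequenceClassification:
    def __init__(self, *args, **kwargs):
        # Fails loudly at construction time instead of at some deep torch import.
        requires_backends(self, ["torch"])
```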