Add `BloomForSequenceClassification` and `BloomForTokenClassification` classes (#17639)

* add new bloom classes

* (feat) add bloom classification tests; make style

* style: change import in test

* add some typehints to bloom classes

* merge main into branch

* fix: input checking in bloom seq classification

* fix tests

* change model class tests

* fix few tests

- more tests should pass
- one test left

* make token classifier return hidden states

* style: make BLOOM typehints consistent

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
3 people committed Jun 14, 2022
1 parent bd43151 commit edb672a
Showing 8 changed files with 328 additions and 10 deletions.
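For orientation, here is a minimal usage sketch of the two new head classes; it is not part of the commit itself. The checkpoint name is taken from the `BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST` entries touched below, while the `num_labels` values and example text are illustrative assumptions.

```python
# Minimal sketch (not part of the diff) of the new BLOOM classification heads.
# Assumes a transformers build containing this commit; num_labels values are illustrative.
import torch
from transformers import AutoTokenizer, BloomForSequenceClassification, BloomForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-1b3")
inputs = tokenizer("BLOOM is a multilingual language model.", return_tensors="pt")

# Sequence classification: one label per sequence, pooled from the last non-padding token.
seq_model = BloomForSequenceClassification.from_pretrained("bigscience/bloom-1b3", num_labels=2)
with torch.no_grad():
    seq_logits = seq_model(**inputs).logits  # shape: (batch_size, num_labels)

# Token classification: one label per input token.
tok_model = BloomForTokenClassification.from_pretrained("bigscience/bloom-1b3", num_labels=5)
with torch.no_grad():
    tok_logits = tok_model(**inputs).logits  # shape: (batch_size, sequence_length, num_labels)
```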
10 changes: 10 additions & 0 deletions docs/source/en/model_doc/bloom.mdx
@@ -44,4 +44,14 @@ Several smaller versions of the models have been trained on the same dataset. BL
## BloomForCausalLM

[[autodoc]] BloomForCausalLM
- forward

## BloomForSequenceClassification

[[autodoc]] BloomForSequenceClassification
- forward

## BloomForTokenClassification

[[autodoc]] BloomForTokenClassification
- forward
4 changes: 4 additions & 0 deletions src/transformers/__init__.py
@@ -870,6 +870,8 @@
"BloomForCausalLM",
"BloomModel",
"BloomPreTrainedModel",
"BloomForSequenceClassification",
"BloomForTokenClassification",
]
)
_import_structure["models.blenderbot"].extend(
@@ -3417,6 +3419,8 @@
from .models.bloom import (
BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST,
BloomForCausalLM,
BloomForSequenceClassification,
BloomForTokenClassification,
BloomModel,
BloomPreTrainedModel,
)
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -453,6 +453,7 @@
("bert", "BertForSequenceClassification"),
("big_bird", "BigBirdForSequenceClassification"),
("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
("bloom", "BloomForSequenceClassification"),
("camembert", "CamembertForSequenceClassification"),
("canine", "CanineForSequenceClassification"),
("convbert", "ConvBertForSequenceClassification"),
@@ -563,6 +564,7 @@
("albert", "AlbertForTokenClassification"),
("bert", "BertForTokenClassification"),
("big_bird", "BigBirdForTokenClassification"),
("bloom", "BloomForTokenClassification"),
("camembert", "CamembertForTokenClassification"),
("canine", "CanineForTokenClassification"),
("convbert", "ConvBertForTokenClassification"),
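With the two mapping entries added above, the task-specific Auto classes can now resolve BLOOM checkpoints. A brief sketch follows; the checkpoint name and label counts are assumptions, not part of the diff.

```python
# Sketch: the new auto-mappings let the AutoModelFor* classes resolve BLOOM configs
# to the new head classes. Checkpoint name and num_labels are illustrative.
from transformers import AutoModelForSequenceClassification, AutoModelForTokenClassification

seq_model = AutoModelForSequenceClassification.from_pretrained("bigscience/bloom-1b3", num_labels=3)
tok_model = AutoModelForTokenClassification.from_pretrained("bigscience/bloom-1b3", num_labels=9)

print(type(seq_model).__name__)  # BloomForSequenceClassification
print(type(tok_model).__name__)  # BloomForTokenClassification
```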
4 changes: 4 additions & 0 deletions src/transformers/models/bloom/__init__.py
@@ -46,6 +46,8 @@
"BloomForCausalLM",
"BloomModel",
"BloomPreTrainedModel",
"BloomForSequenceClassification",
"BloomForTokenClassification",
]

if TYPE_CHECKING:
@@ -68,6 +70,8 @@
from .modeling_bloom import (
BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST,
BloomForCausalLM,
BloomForSequenceClassification,
BloomForTokenClassification,
BloomModel,
BloomPreTrainedModel,
)
237 changes: 231 additions & 6 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -15,15 +15,20 @@
"""PyTorch BLOOM model."""

import math
from typing import Tuple
from typing import Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss

from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_bloom import BloomConfig
@@ -42,7 +47,7 @@
"bigscience/bloom-1b3",
"bigscience/bloom-2b5",
"bigscience/bloom-6b3",
"bigscience/bloom-176b",
"bigscience/bloom",
]


@@ -726,7 +731,7 @@ def forward(
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -902,7 +907,7 @@ def forward(
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@@ -959,3 +964,223 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)


@add_start_docstrings(
    """
    The Bloom Model transformer with a sequence classification head on top (linear layer).

    [`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-1) do.

    Since it does classification on the last token, it needs to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes the last value in
    each row of the batch).
    """,
    BLOOM_START_DOCSTRING,
)
class BloomForSequenceClassification(BloomPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = BloomModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
            else:
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@add_start_docstrings(
    """
    Bloom Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    BLOOM_START_DOCSTRING,
)
class BloomForTokenClassification(BloomPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = BloomModel(config)
        if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
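The last-non-padding-token pooling that `BloomForSequenceClassification` describes in its docstring can be illustrated in isolation. The sketch below mirrors the indexing in its `forward` method using made-up tensors; it is an illustration, not code from the commit.

```python
# Standalone illustration of the pooling rule: when pad_token_id is set, take the logits
# of the last non-padding token in each row. Tensors here are made up for demonstration.
import torch

pad_token_id = 3
input_ids = torch.tensor(
    [
        [15, 27, 42, pad_token_id, pad_token_id],  # real length 3 -> pooled index 2
        [11, 19, 23, 31, 37],                      # real length 5 -> pooled index 4
    ]
)
logits = torch.randn(2, 5, 4)  # (batch_size, sequence_length, num_labels)

sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1  # tensor([2, 4])
pooled_logits = logits[torch.arange(input_ids.shape[0]), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([2, 4]): one logits row per sequence
```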
12 changes: 12 additions & 0 deletions src/transformers/trainer.py
@@ -558,6 +558,18 @@ def __init__(
)
self.use_apex = True

        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
        if (
            is_sagemaker_mp_enabled()
            and self.use_cuda_amp
            and args.max_grad_norm is not None
            and args.max_grad_norm > 0
        ):
            raise ValueError(
                "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
                "along 'max_grad_norm': 0 in your hyperparameters."
            )

        # Label smoothing
        if self.args.label_smoothing_factor != 0:
            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
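The new check above asks users to disable gradient clipping when combining SageMaker Model Parallelism with mixed precision. A hedged sketch of the corresponding `TrainingArguments` follows; it assumes a job launched through SageMaker with model parallelism enabled, and the other values are placeholders.

```python
# Sketch: satisfying the new Trainer check when running SageMaker Model Parallelism with fp16.
# Assumes the job is launched via SageMaker with model parallelism enabled; values other than
# max_grad_norm=0 are illustrative placeholders.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./bloom-seq-cls",
    fp16=True,
    max_grad_norm=0,  # gradient clipping is not yet supported with SageMaker MP + mixed precision
)
```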
14 changes: 14 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
@@ -986,6 +986,20 @@ def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class BloomForSequenceClassification(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class BloomForTokenClassification(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class BloomModel(metaclass=DummyObject):
    _backends = ["torch"]

