Add BloomForSequenceClassification and BloomForTokenClassification classes #17639

Merged
10 changes: 10 additions & 0 deletions docs/source/en/model_doc/bloom.mdx
@@ -44,4 +44,14 @@ Several smaller versions of the models have been trained on the same dataset. BL
## BloomForCausalLM

[[autodoc]] BloomForCausalLM
- forward

## BloomForSequenceClassification

[[autodoc]] BloomForSequenceClassification
- forward

## BloomForTokenClassification

[[autodoc]] BloomForTokenClassification
- forward
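
As a quick sanity check of the documented classes, a minimal usage sketch (not part of the diff; the checkpoint `bigscience/bloom-1b3` is taken from the archive list below, and the `num_labels` values are arbitrary — the classification heads are newly initialized on top of the pretrained body):

```python
from transformers import AutoTokenizer, BloomForSequenceClassification, BloomForTokenClassification

checkpoint = "bigscience/bloom-1b3"  # any BLOOM checkpoint should work the same way
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

# One label per sequence, pooled from the last non-padding token.
seq_model = BloomForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
seq_logits = seq_model(**inputs).logits  # shape: (batch_size, num_labels)

# One label per token, e.g. for NER.
tok_model = BloomForTokenClassification.from_pretrained(checkpoint, num_labels=5)
tok_logits = tok_model(**inputs).logits  # shape: (batch_size, sequence_length, num_labels)
```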
4 changes: 4 additions & 0 deletions src/transformers/__init__.py
@@ -866,6 +866,8 @@
"BloomForCausalLM",
"BloomModel",
"BloomPreTrainedModel",
"BloomForSequenceClassification",
"BloomForTokenClassification",
]
)
_import_structure["models.blenderbot"].extend(
@@ -3397,6 +3399,8 @@
from .models.bloom import (
BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST,
BloomForCausalLM,
BloomForSequenceClassification,
BloomForTokenClassification,
BloomModel,
BloomPreTrainedModel,
)
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -450,6 +450,7 @@
("bert", "BertForSequenceClassification"),
("big_bird", "BigBirdForSequenceClassification"),
("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
("bloom", "BloomForSequenceClassification"),
("camembert", "CamembertForSequenceClassification"),
("canine", "CanineForSequenceClassification"),
("convbert", "ConvBertForSequenceClassification"),
@@ -554,6 +555,7 @@
("albert", "AlbertForTokenClassification"),
("bert", "BertForTokenClassification"),
("big_bird", "BigBirdForTokenClassification"),
("bloom", "BloomForTokenClassification"),
("camembert", "CamembertForTokenClassification"),
("canine", "CanineForTokenClassification"),
("convbert", "ConvBertForTokenClassification"),
4 changes: 4 additions & 0 deletions src/transformers/models/bloom/__init__.py
@@ -46,6 +46,8 @@
"BloomForCausalLM",
"BloomModel",
"BloomPreTrainedModel",
"BloomForSequenceClassification",
"BloomForTokenClassification",
]

if TYPE_CHECKING:
@@ -68,6 +70,8 @@
from .modeling_bloom import (
BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST,
BloomForCausalLM,
BloomForSequenceClassification,
BloomForTokenClassification,
BloomModel,
BloomPreTrainedModel,
)
237 changes: 231 additions & 6 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -15,15 +15,20 @@
"""PyTorch BLOOM model."""

import math
from typing import Tuple
from typing import Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss

from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_bloom import BloomConfig
@@ -42,7 +47,7 @@
"bigscience/bloom-1b3",
"bigscience/bloom-2b5",
"bigscience/bloom-6b3",
"bigscience/bloom-176b",
"bigscience/bloom",
]


@@ -726,7 +731,7 @@ def forward(
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -902,7 +907,7 @@ def forward(
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@@ -959,3 +964,223 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)


@add_start_docstrings(
"""
The Bloom Model transformer with a sequence classification head on top (linear layer).

[`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-1) do.

Since it does classification on the last token, it needs to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
""",
BLOOM_START_DOCSTRING,
)
class BloomForSequenceClassification(BloomPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.transformer = BloomModel(config)
self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)

if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]

if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)

pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"

if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output

return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
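
For illustration only (not part of the diff), a toy sketch of the last-token pooling described in the class docstring above, assuming a hypothetical `pad_token_id` of 3 and made-up tensors:

```python
import torch

pad_token_id = 3
input_ids = torch.tensor([[5, 8, 9, 3, 3],   # two padding tokens -> last real token at index 2
                          [7, 4, 6, 2, 3]])  # one padding token  -> last real token at index 3
logits = torch.randn(2, 5, 4)                # (batch_size, sequence_length, num_labels)

# Index of the last non-padding token in each row, then gather its logits.
sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1            # tensor([2, 3])
pooled_logits = logits[torch.arange(input_ids.shape[0]), sequence_lengths]  # (batch_size, num_labels)
```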


@add_start_docstrings(
"""
Bloom Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
BLOOM_START_DOCSTRING,
)
class BloomForTokenClassification(BloomPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels

self.transformer = BloomModel(config)
if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
classifier_dropout = config.classifier_dropout
elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

hidden_states = transformer_outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.classifier(hidden_states)

loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

if not return_dict:
output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output

return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
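
Again for illustration only (not part of the diff), a toy sketch of how the token-classification loss above flattens logits and labels before `CrossEntropyLoss`; the shapes are made up:

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, num_labels = 2, 4, 5
logits = torch.randn(batch_size, seq_len, num_labels)         # per-token scores from the classifier head
labels = torch.randint(0, num_labels, (batch_size, seq_len))  # one class id per token

# Same flattening as in BloomForTokenClassification.forward:
# (batch_size * seq_len, num_labels) vs (batch_size * seq_len,)
loss = CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
```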
14 changes: 14 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
@@ -976,6 +976,20 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class BloomForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class BloomForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class BloomModel(metaclass=DummyObject):
_backends = ["torch"]
