Supporting seq2seq models for bitsandbytes integration (huggingface#18579)

* Supporting seq2seq models for `bitsandbytes` integration

- the `bitsandbytes` integration now supports seq2seq models
- as an additional check, verify whether a model has tied weights

* small modification

  - tie the weights before looking for tied weights!
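
For reference, a minimal usage sketch of what this change enables: loading a seq2seq checkpoint directly in 8-bit. This snippet is not part of the commit; it assumes a CUDA GPU with bitsandbytes and accelerate installed, and picks t5-small only because the new test below uses it.

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_id = "t5-small"  # illustrative checkpoint; any seq2seq architecture should work
tokenizer = AutoTokenizer.from_pretrained(model_id)
# load_in_8bit quantizes the linear layers via bitsandbytes; the lm_head
# (and anything tied to it) is kept in full precision
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")

inputs = tokenizer("translate English to German: Hello, my dog is cute", return_tensors="pt").to(0)
print(tokenizer.decode(model.generate(**inputs)[0], skip_special_tokens=True))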
younesbelkada authored and amyeroberts committed Aug 17, 2022
1 parent b93957a commit b881653
Showing 2 changed files with 33 additions and 3 deletions.
14 changes: 13 additions & 1 deletion src/transformers/utils/bitsandbytes.py
@@ -1,3 +1,5 @@
from copy import deepcopy

from transformers.utils import is_accelerate_available, is_bitsandbytes_available


@@ -9,6 +11,7 @@

if is_accelerate_available():
from accelerate import init_empty_weights
from accelerate.utils import find_tied_parameters


def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None):
@@ -132,8 +135,17 @@ def get_key_to_not_convert(model):
model (`torch.nn.Module`):
Input model
"""
# Create a copy of the model and tie the weights, then
# check if it contains tied weights
tied_model = deepcopy(model)  # this has 0 cost since it is done inside the `init_empty_weights` context manager
tied_model.tie_weights()
has_tied_params = len(find_tied_parameters(tied_model)) > 0

# Check if it is a base model
is_base_model = not hasattr(model, model.base_model_prefix)

# Ignore this for base models (BertModel, GPT2Model, etc.)
if not hasattr(model, model.base_model_prefix):
if (not has_tied_params) and is_base_model:
return ""

# otherwise they have an attached head
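
As a side note, here is a self-contained sketch of the tied-parameter check added above. It is not the library code itself: the helper name has_tied_params is made up, and it assumes model is a transformers PreTrainedModel so that tie_weights() exists.

from copy import deepcopy

from accelerate.utils import find_tied_parameters


def has_tied_params(model):
    # Tie the weights on a copy first, otherwise find_tied_parameters would see nothing;
    # under the init_empty_weights context manager the deepcopy is essentially free
    tied_model = deepcopy(model)
    tied_model.tie_weights()
    return len(find_tied_parameters(tied_model)) > 0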
22 changes: 20 additions & 2 deletions tests/mixed_int8/test_mixed_int8.py
@@ -15,7 +15,14 @@
import gc
import unittest

from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, pipeline
from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoTokenizer,
pipeline,
)
from transformers.testing_utils import (
is_torch_available,
require_accelerate,
@@ -106,12 +113,21 @@ def setUp(self):
super().setUp()
# model_name
self.model_name = "bigscience/bloom-560m"
# Models and tokenizer
self.seq_to_seq_name = "t5-small"

# Different types of model

self.base_model = AutoModel.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
# Sequence classification model
self.sequence_model = AutoModelForSequenceClassification.from_pretrained(
self.model_name, load_in_8bit=True, device_map="auto"
)
# CausalLM model
self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
# Seq2seq model
self.seq_to_seq_model = AutoModelForSeq2SeqLM.from_pretrained(
self.seq_to_seq_name, load_in_8bit=True, device_map="auto"
)

def tearDown(self):
r"""
@@ -121,6 +137,7 @@ def tearDown(self):
del self.base_model
del self.sequence_model
del self.model_8bit
del self.seq_to_seq_model

gc.collect()
torch.cuda.empty_cache()
@@ -138,6 +155,7 @@ def test_correct_head_class(self):
# Other heads should be nn.Parameter
self.assertTrue(self.model_8bit.lm_head.weight.__class__ == torch.nn.Parameter)
self.assertTrue(self.sequence_model.score.weight.__class__ == torch.nn.Parameter)
self.assertTrue(self.seq_to_seq_model.lm_head.weight.__class__ == torch.nn.Parameter)


class MixedInt8TestPipeline(BaseMixedInt8Test):
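
Finally, a hedged sketch of how the resulting split could be checked by hand, in the spirit of test_correct_head_class above. It is not part of the commit; it assumes a CUDA GPU with bitsandbytes installed, and the decoder module path is T5-specific and purely illustrative.

import torch
import bitsandbytes as bnb
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", load_in_8bit=True, device_map="auto")

# The (tied) lm_head should remain a plain torch.nn.Parameter ...
assert model.lm_head.weight.__class__ == torch.nn.Parameter
# ... while inner linear layers carry int8 weights (module path is T5-specific)
assert model.decoder.block[0].layer[0].SelfAttention.q.weight.__class__ == bnb.nn.Int8Params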
