
TypeError: T5EncoderModel.forward() got an unexpected keyword argument 'token_type_ids' #2588

Open
atasoglu opened this issue Apr 11, 2024 · 2 comments


@atasoglu

Hi,

I am trying to use boun-tabi-LMG/TURNA, a Turkish T5 model, with sentence-transformers as it has been specifically pre-trained for Turkish.

When I run the code snippet below, I get the TypeError shown after it.

from sentence_transformers import models, SentenceTransformer
t5_model = models.Transformer("boun-tabi-LMG/TURNA")
pooling_model = models.Pooling(t5_model.get_word_embedding_dimension(), pooling_mode="mean")
model = SentenceTransformer(modules=[t5_model, pooling_model])
model.encode(["Merhaba dünya!"])

Out:

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-4-e5957cf51002> in <cell line: 5>()
      3 pooling_model = models.Pooling(t5_model.get_word_embedding_dimension(), pooling_mode="mean")
      4 model = SentenceTransformer(modules=[t5_model, pooling_model])
----> 5 model.encode(["Merhaba dünya!"])

/usr/local/lib/python3.10/dist-packages/sentence_transformers/SentenceTransformer.py in encode(self, sentences, prompt_name, prompt, batch_size, show_progress_bar, output_value, precision, convert_to_numpy, convert_to_tensor, device, normalize_embeddings)
    355 
    356             with torch.no_grad():
--> 357                 out_features = self.forward(features)
    358 
    359                 if output_value == "token_embeddings":

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py in forward(self, input)
    215     def forward(self, input):
    216         for module in self:
--> 217             input = module(input)
    218         return input
    219 

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

/usr/local/lib/python3.10/dist-packages/sentence_transformers/models/Transformer.py in forward(self, features)
     96             trans_features["token_type_ids"] = features["token_type_ids"]
     97 
---> 98         output_states = self.auto_model(**trans_features, return_dict=False)
     99         output_tokens = output_states[0]
    100 

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

TypeError: T5EncoderModel.forward() got an unexpected keyword argument 'token_type_ids'

Thank you in advance for your assistance and guidance!

@tomaarsen
Collaborator

Hello!

I believe this is a configuration issue on the side of boun-tabi-LMG/TURNA. Their tokenizer returns token_type_ids when it really should not, as the model does not use them. Sentence Transformers assumes that if the tokenizer returns token_type_ids, the model requires them, so they are passed to the model.

See e.g. the following script:

from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("boun-tabi-LMG/TURNA")
tokenizer = AutoTokenizer.from_pretrained("boun-tabi-LMG/TURNA")

inputs = tokenizer("Merhaba dünya!", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)

This also fails with:

TypeError: T5Model.forward() got an unexpected keyword argument 'token_type_ids'
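
For a quick check in plain transformers, you can drop the offending key from the tokenizer output before the forward pass. The sketch below assumes you only need the encoder (which is what Sentence Transformers uses for T5 models):

from transformers import AutoTokenizer, T5EncoderModel

model = T5EncoderModel.from_pretrained("boun-tabi-LMG/TURNA")
tokenizer = AutoTokenizer.from_pretrained("boun-tabi-LMG/TURNA")

inputs = tokenizer("Merhaba dünya!", return_tensors="pt")
# Drop the key the model does not accept; the tokenizer output is a
# dict-like BatchEncoding, so pop works.
inputs.pop("token_type_ids", None)
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)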

I suspect this is because the configured tokenizer class here is PreTrainedTokenizerFast, and not e.g. T5TokenizerFast. The former seems to assume that the model has token_type_ids as one of the model inputs: https://github.com/huggingface/transformers/blob/0bd58f1ce0573c0e3269de4215a17d318add49b9/src/transformers/tokenization_utils_base.py#L1561
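
A quick way to confirm the diagnosis is to compare model_input_names: the base tokenizer class defaults to including token_type_ids, while T5TokenizerFast overrides it. Expected outputs are shown as comments:

from transformers import AutoTokenizer, T5TokenizerFast

tokenizer = AutoTokenizer.from_pretrained("boun-tabi-LMG/TURNA")
print(type(tokenizer).__name__)     # expected: PreTrainedTokenizerFast
print(tokenizer.model_input_names)  # expected: ['input_ids', 'token_type_ids', 'attention_mask']

print(T5TokenizerFast.model_input_names)  # ['input_ids', 'attention_mask']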

So, the patch is as follows:

from sentence_transformers import models, SentenceTransformer

t5_model = models.Transformer("boun-tabi-LMG/TURNA")
pooling_model = models.Pooling(t5_model.get_word_embedding_dimension(), pooling_mode="mean")
model = SentenceTransformer(modules=[t5_model, pooling_model])
# Remove token_type_ids from the tokenizer's model input names, as the model does not use it
model.tokenizer.model_input_names.remove("token_type_ids")

embeddings = model.encode(["Merhaba dünya!"])
print(embeddings.shape)
(1, 1024)
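
As a quick smoke test that the patched pipeline behaves normally downstream (the sentence pair here is arbitrary):

from sentence_transformers import util

embeddings = model.encode(["Merhaba dünya!", "Selam dünya!"])
# Cosine similarity between the two sentence embeddings.
print(util.cos_sim(embeddings[0], embeddings[1]))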

And now you can use the model or finetune it as normal. Hope this helps.
You can also open a discussion at https://huggingface.co/boun-tabi-LMG/TURNA to point out that the model_input_names for their tokenizer may be misconfigured, or that they might want to change the tokenizer class (e.g. T5TokenizerFast has the correct model_input_names).
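
Until the upstream configuration is fixed, the same override can also be applied at load time: the tokenizer_args of models.Transformer are forwarded to AutoTokenizer.from_pretrained, which accepts model_input_names as an initialization argument. A sketch of that variant:

from sentence_transformers import models, SentenceTransformer

# Override the misconfigured model_input_names when the tokenizer is loaded,
# instead of patching it after the fact.
t5_model = models.Transformer(
    "boun-tabi-LMG/TURNA",
    tokenizer_args={"model_input_names": ["input_ids", "attention_mask"]},
)
pooling_model = models.Pooling(t5_model.get_word_embedding_dimension(), pooling_mode="mean")
model = SentenceTransformer(modules=[t5_model, pooling_model])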

  • Tom Aarsen

@atasoglu
Author

It worked! Thank you very much for your detailed answer and thoughtful advice on the tokenizer!
