TypeError: T5EncoderModel.forward() got an unexpected keyword argument 'token_type_ids' #2588
Comments
Hello! I believe this is a configuration issue on the model's side. See e.g. the following script:
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained("boun-tabi-LMG/TURNA")
tokenizer = AutoTokenizer.from_pretrained("boun-tabi-LMG/TURNA")
inputs = tokenizer("Merhaba dünya!", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)
This also raises a TypeError about the unexpected keyword argument 'token_type_ids'.
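One way to confirm a mismatch like this is to compare the keys the tokenizer returns with the parameters that the model's forward method accepts. The sketch below is only an illustration: `unexpected_keys`, `fake_forward`, and `fake_inputs` are hypothetical names, and the stand-in function takes the place of the real model so that the example runs without downloading anything.

```python
import inspect


def unexpected_keys(forward_fn, inputs):
    """Return the keys of `inputs` that `forward_fn` does not accept as keyword arguments."""
    params = inspect.signature(forward_fn).parameters
    # A forward that takes **kwargs accepts any key, so nothing is unexpected.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return []
    return [key for key in inputs if key not in params]


# Hypothetical stand-in for an encoder's forward(); the real signature has more parameters.
def fake_forward(input_ids=None, attention_mask=None):
    return len(input_ids)


# Hypothetical tokenizer output that includes token_type_ids.
fake_inputs = {
    "input_ids": [5, 6, 7],
    "attention_mask": [1, 1, 1],
    "token_type_ids": [0, 0, 0],
}

extra = unexpected_keys(fake_forward, fake_inputs)
print(extra)  # ['token_type_ids']
```

With the real objects, calling `unexpected_keys(model.forward, inputs)` should report the same key, assuming the tokenizer output can be iterated like a dict.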
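I suspect this is because of the tokenizer class configured for this model: it lists token_type_ids among its model input names, which the T5 model does not accept. So, the patch is as follows:
from sentence_transformers import models, SentenceTransformer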
t5_model = models.Transformer("boun-tabi-LMG/TURNA")
pooling_model = models.Pooling(t5_model.get_word_embedding_dimension(), pooling_mode="mean")
model = SentenceTransformer(modules=[t5_model, pooling_model])
# Remove token_type_ids from the tokenizer's model input names, as the model does not use it
model.tokenizer.model_input_names.remove("token_type_ids")
embeddings = model.encode(["Merhaba dünya!"])
print(embeddings.shape)
And now you can use the model or finetune it as normal. Hope this helps.
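Note that `list.remove` raises a `ValueError` when the item is absent, so the patch above would fail if the tokenizer stopped listing `token_type_ids` (for example, if the model's configuration were later fixed). A slightly more defensive variant could look like the sketch below; `drop_input_name` is a hypothetical helper, not part of sentence-transformers or transformers, and the plain list stands in for `model.tokenizer.model_input_names`.

```python
def drop_input_name(model_input_names, name="token_type_ids"):
    """Remove `name` from the list in place if present; return True if it was removed."""
    if name in model_input_names:
        model_input_names.remove(name)
        return True
    return False


# Hypothetical stand-in for model.tokenizer.model_input_names.
names = ["input_ids", "token_type_ids", "attention_mask"]

first = drop_input_name(names)   # removes the entry
second = drop_input_name(names)  # nothing left to remove, but no exception
print(names, first, second)
```

In the patch above, this would replace the `remove` line with `drop_input_name(model.tokenizer.model_input_names)`.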
It worked! Thank you very much for your detailed answer and thoughtful advice on the tokenizer!
Hi,
I am trying to use boun-tabi-LMG/TURNA, a Turkish T5 model, with sentence-transformers as it has been specifically pre-trained for Turkish.
While trying the code snippet below, I encountered the TypeError shared below.
Out:
TypeError: T5EncoderModel.forward() got an unexpected keyword argument 'token_type_ids'
Thank you in advance for your assistance and guidance!