
NLLB tokenizer #18126

Merged: 8 commits merged into main on Jul 18, 2022

Conversation

@LysandreJik (Member) commented Jul 13, 2022

Adds the NLLB tokenizer. To run it:

>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

>>> model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

>>> translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang="eng_Latn", tgt_lang='ron_Latn')
>>> translator("UN Chief says there is no military solution in Syria")
[{'translation_text': 'Şeful ONU spune că nu există o soluţie militară în Siria'}]

Closes #18043

LysandreJik and others added 3 commits July 14, 2022 02:07
@LysandreJik LysandreJik marked this pull request as ready for review July 14, 2022 07:46
@LysandreJik (Member, Author)

All models are now public, feel free to try them out @stefan-it. The generation seems good; I have not tried fine-tuning yet.

@HuggingFaceDocBuilderDev commented Jul 14, 2022

The documentation is not available anymore as the PR was closed or merged.

@vmarsel commented Jul 14, 2022

I don't know of a better place to post (issue?), so I'll do it here :)

@LysandreJik Thank you so much for adding support for the NLLB dense models! I checked out this branch and tried all of them, and they work great!

The README contains the following note:
"This implementation contains dense models available in release. Let us know via GitHub if you want to see MoE models as well."

So it would be really great if you could add MoE models! I tried to figure out the original repo, but it turned out to be unexpectedly difficult. I couldn't get MoE to run. So if you add MoE models, I'm sure it will make a lot of people happier, at least me :)

@TonyMas commented Jul 14, 2022

@LysandreJik Thanks a lot for your prompt work! I tried using the NLLB model from Hugging Face and noticed one problem:

max_length is not set in config.json for any of the NLLB models, so the default value of max_length (20) is used.

max_length (`int`, *optional*, defaults to 20):

As a result, your example code cannot generate more than 20 tokens. It is possible to set max_length higher when calling the translation method, but it would be great to have a meaningful default as well.

For comparison, both the M2M and MBart50 models have max_length set to 200 in their config.json files.
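
For reference, max_length can be passed at call time as a workaround until a default lands in config.json (a minimal sketch; 200 here simply mirrors the M2M/MBart50 defaults and is not an official recommendation):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
translator = pipeline("translation", model=model, tokenizer=tokenizer, src_lang="eng_Latn", tgt_lang="ron_Latn")

# Override the low default (20) for this call only; 200 is an illustrative value.
print(translator("UN Chief says there is no military solution in Syria", max_length=200))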

@TroyZuroske

> @LysandreJik Thanks a lot for your prompt work! I tried using the NLLB model from Hugging Face and noticed one problem:
>
> max_length is not set in config.json for any of the NLLB models, so the default value of max_length (20) is used.
>
> max_length (`int`, *optional*, defaults to 20):
>
> As a result, your example code cannot generate more than 20 tokens. It is possible to set max_length higher when calling the translation method, but it would be great to have a meaningful default as well.
> For comparison, both the M2M and MBart50 models have max_length set to 200 in their config.json files.

How is the default max_length determined per model? Or is it documented in their white papers? With this PR, I have started evaluating the extremely large model (facebook/nllb-200-3.3B) against GCP translation, and so far it is doing really well despite the length of the text I give it, but I want to give it the best chance to perform, so knowing the ideal max_length would help.

@TonyMas commented Jul 15, 2022

> > @LysandreJik Thanks a lot for your prompt work! I tried using the NLLB model from Hugging Face and noticed one problem:
> > max_length is not set in config.json for any of the NLLB models, so the default value of max_length (20) is used.
> >
> > max_length (`int`, *optional*, defaults to 20):
> >
> > As a result, your example code cannot generate more than 20 tokens. It is possible to set max_length higher when calling the translation method, but it would be great to have a meaningful default as well.
> > For comparison, both the M2M and MBart50 models have max_length set to 200 in their config.json files.
>
> How is the default max_length determined per model? Or is it documented in their white papers? With this PR, I have started evaluating the extremely large model (facebook/nllb-200-3.3B) against GCP translation, and so far it is doing really well despite the length of the text I give it, but I want to give it the best chance to perform, so knowing the ideal max_length would help.

I think the usual default for max_length is to be equal to the maximum input length. The translation pipeline in transformers warns when the input length exceeds 90% of max_length:

def check_inputs(self, input_length: int, min_length: int, max_length: int):
    if input_length > 0.9 * max_length:
        logger.warning(
            f"Your input_length: {input_length} is bigger than 0.9 * max_length: {max_length}. You might consider "
            "increasing your max_length manually, e.g. translator('...', max_length=400)"
        )
    return True

@LysandreJik LysandreJik requested a review from sgugger July 18, 2022 06:41
@sgugger (Collaborator) left a comment


Thanks for adding this model! Would it make sense to add a default model for NLLB in the auto mappings?

Review comments (outdated, resolved) on:
README.md
docs/source/en/model_doc/nllb.mdx
src/transformers/models/nllb/tokenization_nllb.py
src/transformers/models/nllb/tokenization_nllb_fast.py
tests/models/nllb/test_tokenization_nllb.py (3 comments)
LysandreJik and others added 4 commits July 18, 2022 05:47
Co-authored-by: Stefan Schweter <stefan@schweter.it>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
@LysandreJik LysandreJik merged commit c1c79b0 into main Jul 18, 2022
@LysandreJik LysandreJik deleted the nllb branch July 18, 2022 12:12
viclzhu pushed a commit to viclzhu/transformers that referenced this pull request Jul 18, 2022
* NLLB tokenizer

* Apply suggestions from code review - Thanks Stefan!

Co-authored-by: Stefan Schweter <stefan@schweter.it>

* Final touches

* Style :)

* Update docs/source/en/model_doc/nllb.mdx

Co-authored-by: Stefan Schweter <stefan@schweter.it>

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* PR reviews

* Auto models

Co-authored-by: Stefan Schweter <stefan@schweter.it>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
@vince62s

As mentioned in #19943: where did you guys see that the "Langtoken" is added AFTER the tokens?
In the NLLB paper, it says the "Langtoken" is placed BEFORE the tokens (mBART does the opposite).

@stefan-it (Collaborator) commented Mar 18, 2023

I've just seen this example, where the lang token is prepended:

https://github.com/facebookresearch/fairseq/blob/nllb/fairseq/data/multilingual/multilingual_data_manager.py#L78-L101

from the original code base 🤔

@vince62s

Right. Also, I am wondering why they use "</s>", which is "eos", as the start token of the source sequence (in fact, the same for the target sequence). I would have expected:

SRC = LangTok + tokens
TGT = BOS + LangTok + tokens + EOS

It seems they use EOS instead of BOS, and that they put an EOS at the start of the SRC.
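
One quick way to check where the language code and EOS tokens actually land is to inspect the tokenizer output directly (a minimal sketch; the exact ordering depends on the tokenizer version installed):

from transformers import AutoTokenizer

# src_lang selects which language code token the tokenizer inserts.
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang="eng_Latn")

ids = tokenizer("UN Chief says there is no military solution in Syria")["input_ids"]
# Printing the tokens shows whether "eng_Latn" and "</s>" come before or after the subword pieces.
print(tokenizer.convert_ids_to_tokens(ids))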

Successfully merging this pull request may close these issues: Add Support for "No Language Left Behind" (NLLB)