Add WhisperModel to transformers #19166

Merged
167 commits merged into huggingface:main on Oct 5, 2022

Conversation

ArthurZucker
Collaborator

What does this PR do?

Adds Whisper to transformers

forced_decoder_tokens += f"<|{task}|>"

forced_decoder_tokens += "<|notimestamps|>" if no_timestamps else None
forced_decoder_ids = self.tokenizer.encode(forced_decoder_tokens, **kwargs)
Contributor

Be careful here with how whitespace is handled; I think the convert_tokens_to_ids function is better.

Collaborator Author

convert_tokens_to_ids didn't work and only outputs the first token, I might be missing something
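
(For context, a minimal sketch of the difference between the two calls, assuming the WhisperTokenizer added in this PR; illustrative only, not the PR's code. One likely source of the discrepancy is that convert_tokens_to_ids does not tokenize, so it needs already-split tokens.)

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large")
prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"

# encode() tokenizes the string first, so every special token gets its own id
ids_from_encode = tokenizer.encode(prompt, add_special_tokens=False)

# convert_tokens_to_ids() maps already-split tokens to ids; passing the concatenated
# string treats it as a single (unknown) token, so split it into tokens first
tokens = ["<|startoftranscript|>", "<|en|>", "<|transcribe|>", "<|notimestamps|>"]
ids_from_tokens = tokenizer.convert_tokens_to_ids(tokens)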

from .english_normalizer import EnglishTextNormalizer


VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "tokenizer_file": "tokenizer.json", "merges_file": "merges.txt"}
Contributor

@ArthurZucker could you add a "normalizer_file": "normalizer.json" here?

Collaborator

You can then "get it back" at init if you have your tokenizer accept normalizer_file=None in the init.
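
(A minimal sketch of that pattern, with assumed names; not the PR's exact code.)

import json

from transformers import PreTrainedTokenizer

VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "tokenizer_file": "tokenizer.json",
    "merges_file": "merges.txt",
    "normalizer_file": "normalizer.json",
}

class WhisperTokenizer(PreTrainedTokenizer):
    def __init__(self, vocab_file, merges_file, normalizer_file=None, **kwargs):
        super().__init__(**kwargs)
        # the normalizer constants are optional: load them only if the checkpoint
        # ships a normalizer.json file, otherwise fall back to None
        if normalizer_file is not None:
            with open(normalizer_file, encoding="utf-8") as f:
                self.english_spelling_normalizer = json.load(f)
        else:
            self.english_spelling_normalizer = None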

Collaborator Author

Awesome, thanks! I wasn't really sure where to actually put it.

Collaborator Author

Done ✅ Let me just update the code, style and push everything

Collaborator

Like it a lot this way!

@ArthurZucker
Collaborator Author

ArthurZucker commented Oct 4, 2022

Okay, so here is a simple example:

>>> import datasets
>>> from datasets import load_dataset
>>> from transformers import WhisperForConditionalGeneration, WhisperProcessor

>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-large")

>>> ds = load_dataset("common_voice", "ja", split="test", streaming=True)
>>> ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
>>> ds_iter = iter(ds)
>>> input_speech = next(ds_iter)["audio"]["array"]
>>> inputs = processor(input_speech, return_tensors="pt")

>>> predicted_ids = model.generate(**inputs)
>>> processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
'i borrowed a phone from kimura san'

>>> forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
>>> predicted_ids = model.generate(**inputs, forced_decoder_ids=forced_decoder_ids)
>>> processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)[0]
"木村さんに電話を貸してもらいました"

>>> forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
>>> predicted_ids = model.generate(**inputs, forced_decoder_ids=forced_decoder_ids)
>>> processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)[0]
' Kimura san ni denwa wo kaite moraimashita'

generation_idx = input_ids.shape[-1]
current_token = self.force_token_map.get(generation_idx, None)
if current_token is not None:
    scores[:, current_token] = np.inf
Collaborator

Should this be -inf ?

Suggested change
scores[:, current_token] = np.inf
scores[:, current_token] = -np.inf

Collaborator Author

No, since we want to be sure that they are picked by the argmax.

Collaborator

OK, so should the docstring be updated, or am I missing something else?

Collaborator Author

Yep, sorry I didn't update it
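
(For context, a minimal sketch of the forcing logic discussed above, mirroring the snippet quoted earlier; the class name is hypothetical and this is not the PR's exact implementation.)

import numpy as np

class ForceTokensLogitsProcessor:
    def __init__(self, force_token_map):
        # maps generation index -> token id that must be produced at that step
        self.force_token_map = dict(force_token_map)

    def __call__(self, input_ids, scores):
        generation_idx = input_ids.shape[-1]
        current_token = self.force_token_map.get(generation_idx, None)
        if current_token is not None:
            # +inf (rather than -inf) makes the forced token win the argmax in greedy decoding
            scores[:, current_token] = np.inf
        return scores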

Comment on lines 1116 to 1123
suppress_tokens (`List[int]`, *optional*, defaults to `model.config.suppress_tokens`):
A list of tokens that will be suppressed at generation. The `SupressTokens` logit processor will set
their log probs to `-inf` so that they are not sampled.
begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens`
logit processor will set their log probs to `-inf` so that they are not sampled.
forced_decoder_ids (`List[int]`, *optional*, defaults to `model.config.forced_decoder_ids`):
A list of tokens that will be forced as beginning tokens.
Collaborator

I think these might need to be set to their respective model config values below if they're None e.g.
suppress_tokens = suppress_tokens if suppress_tokens is not None else model.config.suppress_tokens

(caveat: generation code is all new to me and I'm quite likely missing something)

Nevermind: I realised it's set in _get_logits_processor :)
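
(For readers following along, a minimal sketch of the suppress-tokens behaviour described in the docstring above; the class and argument names here are illustrative, not necessarily the ones in generation_logits_process.py.)

import torch

class SuppressTokensLogitsProcessor:
    """Sets the scores of the given token ids to -inf so they can never be generated."""

    def __init__(self, suppress_tokens):
        self.suppress_tokens = list(suppress_tokens)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores[:, self.suppress_tokens] = -float("inf")
        return scores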

Collaborator

@sgugger left a comment

Very nice addition for the normalizer; as said before, I like it a lot that way!
Left a few last nits on the docstrings.

Path to the vocabulary file.
merges_file (`str`):
Path to the merges file.
errors (`str`, *optional*, defaults to `"replace"`):
Collaborator

You're missing arguments here (normalizer_file, task and language)

Collaborator Author

Gonna remove them, they're useless here.

add_prefix_space (`bool`, *optional*, defaults to `False`):
Whether or not to add an initial space to the input. This allows treating the leading word just like any
other word. (The GPT2 tokenizer detects the beginning of words by the preceding space.)
"""
Collaborator

Missing add_bos_token here.

Collaborator Author

Seems like it is almost never documented, see tokenization_utils_base or tokenization_gpt2
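
(Aside, for illustration: what add_prefix_space changes, assuming the Whisper tokenizer behaves like the GPT2 byte-level BPE tokenizer it is based on; the checkpoint name is only an example.)

from transformers import WhisperTokenizer

tok_default = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
tok_prefixed = WhisperTokenizer.from_pretrained("openai/whisper-tiny", add_prefix_space=True)

# without the flag, the first word has no leading space and is encoded differently
print(tok_default.tokenize("Hello world"))
# with the flag, a space is prepended, so "Hello" is tokenized like a mid-sentence word
print(tok_prefixed.tokenize("Hello world"))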

ArthurZucker and others added 4 commits October 5, 2022 14:52
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

@slow
def test_large_generation_multilingual(self):
Contributor

very nice test!

@patrickvonplaten
Contributor

A few final things:

  • Add 2 tests for batched generation
  • Make sure the tokenizer has a pad_token_id => it should be identical to the eos_token_id since there is no official one. We don't want to trigger a warning every time we run generation in batch
  • Also make sure that config.pad_token_id is correctly set.

cc @sanchit-gandhi we have to remember this when doing fine-tuning experiments! Whisper has pad_token_id == eos_token_id, which means that during training we need to make sure our general training scripts don't replace the eos_token_id with -100 and thus ignore it in the loss. Instead we should only replace the "not-first" pad_token_id occurrences with -100 (we have the same for GPT2, BTW).
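
(A minimal sketch of that masking rule, assuming a padded torch tensor of labels; the function and field names are hypothetical, not from an existing transformers training script.)

import torch

def mask_padding_keep_first_eos(labels: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Replace padding with -100 for the loss, but keep the first pad/eos position in each
    row: with pad_token_id == eos_token_id it is the real end-of-sequence token."""
    labels = labels.clone()
    is_pad = labels == pad_token_id
    # the first pad-like position in each row is the genuine eos token
    first_pad = is_pad & (is_pad.int().cumsum(dim=-1) == 1)
    labels[is_pad & ~first_pad] = -100
    return labels

# e.g. labels = mask_padding_keep_first_eos(batch["labels"], tokenizer.pad_token_id)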

@ArtyomZemlyak

Hello!
I apologize for interrupting the development process, but I'm following this thread because I'm really looking forward to Whisper in HF, and I also see fine-tuning being mentioned here. It would be very cool if you could provide good fine-tuning code and examples!

I'm already trying to fine-tune it in different ways myself, but so far the model only gets worse.

In any case, thanks for your work and good luck! ❤️

@patrickvonplaten
Contributor

patrickvonplaten commented Oct 5, 2022

Hey @ArtyomZemlyak,

This is a major focus of ours right now! We've already done some experiments; you can check them out here:
https://openreview.net/forum?id=9OL2fIfDLK (we've fine-tuned Whisper on a bunch of open-source datasets)

We hope to have a well-functioning fine-tuning script ready by early next week (we plan on doing a blog post + Google Colab).

EXPECTED_LOGITS = torch.tensor(
[
[50257, 50362, 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262, 3504, 6097, 11, 290, 356, 389, 9675, 284],
[50257, 50362, 5414, 318, 1770, 13, 2264, 346, 353, 338, 5642, 1342, 3499, 621, 465, 2300, 13, 50256, 50256, 50256],
Contributor

nice!

Contributor

@patrickvonplaten left a comment

@ArthurZucker feel free to merge whenever :-)

@patrickvonplaten
Contributor

Merging to unblock the TF PR.

@patrickvonplaten merged commit 45e1403 into huggingface:main on Oct 5, 2022
@ArthurZucker
Collaborator Author

Awesome, sorry for the delay!

ajsanjoaquin pushed a commit to ajsanjoaquin/transformers that referenced this pull request Oct 12, 2022
* simplify loop

* add feature extractor

* add model

* start conversion

* add dropout

* initial commit of test files

* conversion for all models

* update processor for correct padding

* update feature extraction

* update integration test logits match

* fmt: off for the logits

* on the fly mel bank

* small nit

* update test

* update tokenizer

* nit feature extraction

* update

* update tokenizer test

* adds logit processor and update tokenizer to get suppress tokens

* style

* clean convert

* revert to original modeling tf utils

* Update

* update

* nit

* clean convert file

* update tests and nits

* quality

* slow generation test

* ffn_dim to allow customization

* update readme

* add to toctreee

* start fixing integration tests

* update tests and code

* fix feature extractor

* fix config tests common

* update code to fix tests

* fix feature extractor

* nit feature extraction

* update test for new feature extractor

* style

* add abstract

* large logits with custom decoder input ids

* wrap around is_torch_available

* fix feature extractor

* correct logits for whisper small.en

* nit

* fix encoder_attention_mask

* some fixes

* remove unnecessary inputs

* nits

* add normalizer file

* update test tokenization

* fix attention mask not defined

* Add model to README

* Fix doc tests

* fix generate

* remove useless encoder attention mask

* update test modeling whisper

* update config to add second non suppress tokens

* nits on feature extractor

* nit for test tokenizers

* update tests

* update tests

* update tokenization test

* fixup

* invalidated hf token. Clean convert openai to whisper

* fix logit tests

* fixup

* clean merge

* revert toc_tree changes

* remove useless LogitProcessor

* Update whisper .mdx

* update config file doc

* update configuration docstring

* update test tokenization

* update test tokenization

* update tokenization whisper
Added copied from where needed

* update feature extraction

* nit test name

* style

* quality

* remove get suppress tokens and update non_speech tokens global variables

* Update src/transformers/models/whisper/feature_extraction_whisper.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* clean modeling whisper and test
Removed the attention mask arguments that are deprecated

* fix large test

* Add multilingual audio test, and translate test

* style

* fix large multilingual test

* nits

* Update docs/source/en/model_doc/whisper.mdx

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* add copied from for attention layer

* remove attention masks in doc

* add english normalizer

* update tokenization test

* remove copied from in whisper attention : no bias in k_proj only

* wrap around dependencies in english normalizer

* style

* correct import generation logits

* for now, wrap feature extractor with torch

* Update src/transformers/models/whisper/convert_openai_whisper_to_tfms.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/whisper/configuration_whisper.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update docs/source/en/model_doc/whisper.mdx

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* remove torch dependencies for feature extraction and style

* fixup

* nit

* update logits

* style

* nit

* nits and fix final tests

* add `is_more_itertools_available` to utils

* quality

* add begin suppress tokens, suppress tokens to generate args and config

* clean supressTokensLogitProcessor in generation logits

* Nit naming

* add supressTokensAtBegin

* update tests, suppress tokens to None or correct values

* nit and style

* update RAG to fit test and generate_logit

* add copy-pasted statement on english normalizer

* add arguments to config_common_kwargs

* Update src/transformers/generation_utils.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/generation_logits_process.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/whisper/configuration_whisper.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* revert changes based on reviews

* update doc and nits

* more nits

* last nits

* update test configuration common

* add BART name in decoder attention mask documentation

* Update src/transformers/models/whisper/modeling_whisper.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* style

* nit

* nit

* add english.json file to git

* nits on documentation

* nit

* nits

* last styling

* add main toctree file

* remove sentence piece dependency

* clean init file

* fix tokenizer that has no dependencies on sentencepiece

* update whisper init file, nit

* remove english.json file

* add get decoder prompt id

* revert changes and add forced logit processor

* nit

* clean normalizer

* remove protected

* update

* Update src/transformers/models/whisper/configuration_whisper.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* update based on review

* Update src/transformers/models/whisper/configuration_whisper.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add batched tests

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: NielsRogge <niels.rogge1@gmail.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Oct 18, 2022 (same commit message as above)