From 94e6b558cbd4e1a688ba9770b24fe79fb3fd032e Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Thu, 6 Oct 2022 16:35:23 +0000
Subject: [PATCH 1/6] update feature extractor params

---
 .../whisper/feature_extraction_whisper.py       | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/whisper/feature_extraction_whisper.py b/src/transformers/models/whisper/feature_extraction_whisper.py
index ce5de7b65afa8..d19d0ef3afd51 100644
--- a/src/transformers/models/whisper/feature_extraction_whisper.py
+++ b/src/transformers/models/whisper/feature_extraction_whisper.py
@@ -218,6 +218,7 @@ def __call__(
         return_attention_mask: Optional[bool] = None,
         padding: Optional[str] = "max_length",
         max_length: Optional[int] = None,
+        sampling_rate: Optional[int] = None,
         **kwargs
     ) -> BatchFeature:
         """
@@ -255,11 +256,25 @@
                 - `'np'`: Return Numpy `np.ndarray` objects.
             sampling_rate (`int`, *optional*):
                 The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
-                `sampling_rate` at the forward call to prevent silent errors.
+                `sampling_rate` at the forward call to prevent silent errors and to allow usage with the automatic
+                speech recognition pipeline.
             padding_value (`float`, defaults to 0.0):
                 The value that is used to fill the padding values / vectors.
         """

+        if sampling_rate is not None:
+            if sampling_rate != self.sampling_rate:
+                raise ValueError(
+                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
+                    f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled at"
+                    f" {self.sampling_rate} and not {sampling_rate}."
+                )
+        else:
+            logger.warning(
+                "It is strongly recommended to pass the `sampling_rate` argument to this function. "
+                "Failing to do so can result in silent errors that might be hard to debug."
+            )
+
         is_batched = bool(
             isinstance(raw_speech, (list, tuple))
             and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))

From d5cf3eac2d23d900f202df519ac51287940351f7 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Thu, 6 Oct 2022 16:41:28 +0000
Subject: [PATCH 2/6] update attention mask handling

---
 .../pipelines/automatic_speech_recognition.py   | 23 +++++++++++++-------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py
index c52b1002cf713..1d5546edb5e6f 100644
--- a/src/transformers/pipelines/automatic_speech_recognition.py
+++ b/src/transformers/pipelines/automatic_speech_recognition.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 from collections import defaultdict
 from typing import TYPE_CHECKING, Dict, Optional, Union

@@ -259,9 +260,9 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
             # Currently chunking is not possible at this level for `seq2seq` so
             # it's ok.
             align_to = self.model.config.inputs_to_logits_ratio
-            chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to)) * align_to
-            stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to)) * align_to
-            stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to)) * align_to
+            chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
+            stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
+            stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)

             if self.type not in {"ctc", "ctc_with_lm"}:
                 raise ValueError(
@@ -304,12 +305,18 @@ def _forward(self, model_inputs):
                 f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
             )

-            attention_mask = model_inputs.pop("attention_mask", None)
-            tokens = self.model.generate(
-                encoder_outputs=encoder(inputs, attention_mask=attention_mask),
-                attention_mask=attention_mask,
-            )
+            accepts_attention_mask = "attention_mask" in set(inspect.signature(encoder.forward).parameters.keys())
+            if accepts_attention_mask:
+                attention_mask = model_inputs.pop("attention_mask", None)
+                tokens = self.model.generate(
+                    encoder_outputs=encoder(inputs, attention_mask=attention_mask),
+                    attention_mask=attention_mask,
+                )
+            else:
+                tokens = self.model.generate(inputs)
+
             out = {"tokens": tokens}
+
         else:
             stride = model_inputs.pop("stride", None)
             input_values = model_inputs.pop("input_values")

From 6a662f57046ae1e0d435e78977d10039d7c71a8c Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Tue, 11 Oct 2022 09:19:27 +0000
Subject: [PATCH 3/6] fix doc and pipeline test

---
 .../models/whisper/modeling_whisper.py          | 21 +++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index ef23914b8ce77..adad4abb6c990 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -31,13 +31,22 @@
     Seq2SeqModelOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
 from .configuration_whisper import WhisperConfig


 logger = logging.get_logger(__name__)

 _CONFIG_FOR_DOC = "WhisperConfig"
+_CHECKPOINT_FOR_DOC = "openai/whisper-base"
+_PROCESSOR_FOR_DOC = "openai/whisper-base"
+_EXPECTED_OUTPUT_SHAPE = [1, 2, 512]


 WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -982,7 +991,15 @@ def get_decoder(self):
         return self.decoder

     @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+    @add_code_sample_docstrings(
+        processor_class=_PROCESSOR_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=Seq2SeqModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
+        modality="audio",
+    )
+    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_features=None,

From d22ba7a86d951f711417e1d761cc4cd48aab2b04 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Tue, 11 Oct 2022 09:46:51 +0000
Subject: [PATCH 4/6] add
warning when skipping test --- tests/pipelines/test_pipelines_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 0f03a42440d79..b9ee26a1ac27f 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -178,8 +178,8 @@ def __repr__(self): class PipelineTestCaseMeta(type): def __new__(mcs, name, bases, dct): def gen_test(ModelClass, checkpoint, tiny_config, tokenizer_class, feature_extractor_class): - @skipIf(tiny_config is None, "TinyConfig does not exist") - @skipIf(checkpoint is None, "checkpoint does not exist") + @skipIf(tiny_config is None, "TinyConfig does not exist, make sure that you defined a `_CONFIG_FOR_DOC` variable in the modeling file") + @skipIf(checkpoint is None, "checkpoint does not exist, make sure that you defined a `_CHECKPOINT_FOR_DOC` variable in the modeling file") def test(self): if ModelClass.__name__.endswith("ForCausalLM"): tiny_config.is_encoder_decoder = False From b6e72d561fdda2fd67e5484b4bd0e328619f0c83 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 11 Oct 2022 10:02:50 +0000 Subject: [PATCH 5/6] add whisper translation and transcription test --- ..._pipelines_automatic_speech_recognition.py | 48 +++++++++++++++++++ tests/pipelines/test_pipelines_common.py | 12 ++++- 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index d4fcbf5f78146..f73fda39e97b6 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -26,6 +26,8 @@ AutoTokenizer, Speech2TextForConditionalGeneration, Wav2Vec2ForCTC, + WhisperForConditionalGeneration, + WhisperProcessor, ) from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline from transformers.pipelines.audio_utils import chunk_bytes_iter @@ -308,6 +310,52 @@ def test_simple_s2t(self): output = asr(data) self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."}) + @slow + @require_torch + @require_torchaudio + def test_simple_whisper_asr(self): + speech_recognizer = pipeline( + task="automatic-speech-recognition", + model="openai/whisper-tiny.en", + framework="pt", + ) + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + filename = ds[0]["file"] + output = speech_recognizer(filename) + self.assertEqual(output, {"text": " Mr. 
Quilter is the apostle of the middle classes, and we are glad to"}) + + @slow + @require_torch + @require_torchaudio + def test_simple_whisper_translation(self): + speech_recognizer = pipeline( + task="automatic-speech-recognition", + model="openai/whisper-large", + framework="pt", + ) + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") + filename = ds[40]["file"] + output = speech_recognizer(filename) + self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."}) + + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large") + tokenizer = AutoTokenizer.from_pretrained("openai/whisper-large") + feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large") + + speech_recognizer_2 = AutomaticSpeechRecognitionPipeline( + model=model, tokenizer=tokenizer, feature_extractor=feature_extractor + ) + output_2 = speech_recognizer_2(filename) + self.assertEqual(output, output_2) + + processor = WhisperProcessor(feature_extractor, tokenizer) + model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="it") + speech_translator = AutomaticSpeechRecognitionPipeline( + model=model, tokenizer=tokenizer, feature_extractor=feature_extractor + ) + output_3 = speech_translator(filename) + self.assertEqual(output_3, {"text": " Un uomo ha detto allo universo, Sir, esiste."}) + @slow @require_torch @require_torchaudio diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index b9ee26a1ac27f..9449b38094988 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -178,8 +178,16 @@ def __repr__(self): class PipelineTestCaseMeta(type): def __new__(mcs, name, bases, dct): def gen_test(ModelClass, checkpoint, tiny_config, tokenizer_class, feature_extractor_class): - @skipIf(tiny_config is None, "TinyConfig does not exist, make sure that you defined a `_CONFIG_FOR_DOC` variable in the modeling file") - @skipIf(checkpoint is None, "checkpoint does not exist, make sure that you defined a `_CHECKPOINT_FOR_DOC` variable in the modeling file") + @skipIf( + tiny_config is None, + "TinyConfig does not exist, make sure that you defined a `_CONFIG_FOR_DOC` variable in the modeling" + " file", + ) + @skipIf( + checkpoint is None, + "checkpoint does not exist, make sure that you defined a `_CHECKPOINT_FOR_DOC` variable in the" + " modeling file", + ) def test(self): if ModelClass.__name__.endswith("ForCausalLM"): tiny_config.is_encoder_decoder = False From 3f88d3769e6fd13ae850154466cc537a321f10df Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 11 Oct 2022 10:10:40 +0000 Subject: [PATCH 6/6] fix build doc test --- .../models/whisper/modeling_whisper.py | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index adad4abb6c990..a6e50153375c2 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -999,7 +999,6 @@ def get_decoder(self): expected_output=_EXPECTED_OUTPUT_SHAPE, modality="audio", ) - @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_features=None, @@ -1016,26 +1015,6 @@ def forward( output_hidden_states=None, return_dict=None, ): - r""" - Returns: - - Example: - - ```python - >>> import torch - >>> from transformers import 
WhisperModel, WhisperFeatureExtractor - >>> from datasets import load_dataset - - >>> model = WhisperModel.from_pretrained("openai/whisper-base") - >>> feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base") - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt") - >>> input_features = inputs.input_features - >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id - >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state - >>> list(last_hidden_state.shape) - [1, 2, 512] - ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = (
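Taken together, the six patches let Whisper checkpoints run end to end through the automatic speech recognition pipeline. The sketch below is not part of the patch series; it mirrors the checkpoints and the `forced_decoder_ids` handling from the tests added in PATCH 5/6, and `sample.flac` is a placeholder for any speech recording.

```python
from transformers import AutoFeatureExtractor, AutoTokenizer, WhisperForConditionalGeneration, WhisperProcessor
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline

# Plain English transcription with the English-only tiny checkpoint, as in
# test_simple_whisper_asr. Passing a filename works because the pipeline
# decodes and resamples the audio itself, then forwards `sampling_rate` to the
# feature extractor, where the check added in PATCH 1/6 can catch mismatches.
transcriber = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-tiny.en",
    framework="pt",
)
print(transcriber("sample.flac")["text"])

# Forcing the decoder prompt, as in test_simple_whisper_translation: the
# multilingual checkpoint then emits Italian text for the same audio.
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
tokenizer = AutoTokenizer.from_pretrained("openai/whisper-large")
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large")
processor = WhisperProcessor(feature_extractor, tokenizer)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="it")

italian_pipeline = AutomaticSpeechRecognitionPipeline(
    model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
)
print(italian_pipeline("sample.flac")["text"])
```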