diff --git a/src/transformers/models/whisper/feature_extraction_whisper.py b/src/transformers/models/whisper/feature_extraction_whisper.py
index dda53dffaafd0..0d6bbd9ed18bb 100644
--- a/src/transformers/models/whisper/feature_extraction_whisper.py
+++ b/src/transformers/models/whisper/feature_extraction_whisper.py
@@ -218,6 +218,7 @@ def __call__(
         return_attention_mask: Optional[bool] = None,
         padding: Optional[str] = "max_length",
         max_length: Optional[int] = None,
+        sampling_rate: Optional[int] = None,
         **kwargs
     ) -> BatchFeature:
         """
@@ -261,6 +262,19 @@
                 The value that is used to fill the padding values / vectors.
         """

+        if sampling_rate is not None:
+            if sampling_rate != self.sampling_rate:
+                raise ValueError(
+                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
+                    f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
+                    f" {self.sampling_rate} and not {sampling_rate}."
+                )
+        else:
+            logger.warning(
+                "It is strongly recommended to pass the `sampling_rate` argument to this function. "
+                "Failing to do so can result in silent errors that might be hard to debug."
+            )
+
         is_batched = bool(
             isinstance(raw_speech, (list, tuple))
             and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
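The new `sampling_rate` argument is validation-only: the extractor still assumes the audio matches `self.sampling_rate`, it just raises early on a mismatch and warns when the rate is omitted. A minimal sketch of the resulting behavior (the checkpoint name and silent audio are illustrative only):

```python
import numpy as np

from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")

# One second of silence at Whisper's training rate of 16 kHz (illustrative input).
audio = np.zeros(16000, dtype=np.float32)

# Matching rate: passes the new check.
inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="np")
print(inputs.input_features.shape)  # (1, 80, 3000): 80 mel bins, 30 s of frames

# Mismatched rate: now raises the new ValueError instead of failing silently.
try:
    feature_extractor(audio, sampling_rate=8000)
except ValueError as err:
    print(err)

# Omitting the rate still works but emits the new warning.
feature_extractor(audio)
```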
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index ef23914b8ce77..a6e50153375c2 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -31,13 +31,22 @@
     Seq2SeqModelOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
 from .configuration_whisper import WhisperConfig


 logger = logging.get_logger(__name__)

 _CONFIG_FOR_DOC = "WhisperConfig"
+_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
+_PROCESSOR_FOR_DOC = "openai/whisper-tiny"
+_EXPECTED_OUTPUT_SHAPE = [1, 2, 512]


 WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -982,7 +991,14 @@ def get_decoder(self):
         return self.decoder

     @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+    @add_code_sample_docstrings(
+        processor_class=_PROCESSOR_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=Seq2SeqModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
+        modality="audio",
+    )
     def forward(
         self,
         input_features=None,
@@ -999,26 +1015,6 @@ def forward(
         output_hidden_states=None,
         return_dict=None,
     ):
-        r"""
-        Returns:
-
-        Example:
-
-        ```python
-        >>> import torch
-        >>> from transformers import WhisperModel, WhisperFeatureExtractor
-        >>> from datasets import load_dataset
-
-        >>> model = WhisperModel.from_pretrained("openai/whisper-base")
-        >>> feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")
-        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
-        >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
-        >>> input_features = inputs.input_features
-        >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
-        >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
-        >>> list(last_hidden_state.shape)
-        [1, 2, 512]
-        ```"""
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
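With `add_code_sample_docstrings`, the usage example in the docs is now generated from `_CHECKPOINT_FOR_DOC` and `_EXPECTED_OUTPUT_SHAPE` instead of being written inline. For reference, a standalone equivalent of the removed inline example (it used `openai/whisper-base`, whose hidden size of 512 matches the `[1, 2, 512]` shape; `openai/whisper-tiny` has a hidden size of 384), updated here to pass the new `sampling_rate` argument:

```python
import torch
from datasets import load_dataset

from transformers import WhisperFeatureExtractor, WhisperModel

model = WhisperModel.from_pretrained("openai/whisper-base")
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")

# librispeech_asr_dummy audio is sampled at 16 kHz, matching the extractor.
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = feature_extractor(ds[0]["audio"]["array"], sampling_rate=16000, return_tensors="pt")

# Two decoder start tokens -> decoder sequence length of 2 in the output.
decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
last_hidden_state = model(inputs.input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
print(list(last_hidden_state.shape))  # [1, 2, 512]
```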
diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
index d4fcbf5f78146..f73fda39e97b6 100644
--- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py
+++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -26,6 +26,8 @@
     AutoTokenizer,
     Speech2TextForConditionalGeneration,
     Wav2Vec2ForCTC,
+    WhisperForConditionalGeneration,
+    WhisperProcessor,
 )
 from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
 from transformers.pipelines.audio_utils import chunk_bytes_iter
@@ -308,6 +310,52 @@ def test_simple_s2t(self):
         output = asr(data)
         self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})

+    @slow
+    @require_torch
+    @require_torchaudio
+    def test_simple_whisper_asr(self):
+        speech_recognizer = pipeline(
+            task="automatic-speech-recognition",
+            model="openai/whisper-tiny.en",
+            framework="pt",
+        )
+        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+        filename = ds[0]["file"]
+        output = speech_recognizer(filename)
+        self.assertEqual(output, {"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to"})
+
+    @slow
+    @require_torch
+    @require_torchaudio
+    def test_simple_whisper_translation(self):
+        speech_recognizer = pipeline(
+            task="automatic-speech-recognition",
+            model="openai/whisper-large",
+            framework="pt",
+        )
+        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
+        filename = ds[40]["file"]
+        output = speech_recognizer(filename)
+        self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})
+
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
+        tokenizer = AutoTokenizer.from_pretrained("openai/whisper-large")
+        feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large")
+
+        speech_recognizer_2 = AutomaticSpeechRecognitionPipeline(
+            model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
+        )
+        output_2 = speech_recognizer_2(filename)
+        self.assertEqual(output, output_2)
+
+        processor = WhisperProcessor(feature_extractor, tokenizer)
+        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="it")
+        speech_translator = AutomaticSpeechRecognitionPipeline(
+            model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
+        )
+        output_3 = speech_translator(filename)
+        self.assertEqual(output_3, {"text": " Un uomo ha detto allo universo, Sir, esiste."})
+
     @slow
     @require_torch
     @require_torchaudio
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 1853f8ebe4117..0449b15396e80 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -178,8 +178,16 @@ def __repr__(self):
 class PipelineTestCaseMeta(type):
     def __new__(mcs, name, bases, dct):
         def gen_test(ModelClass, checkpoint, tiny_config, tokenizer_class, feature_extractor_class):
-            @skipIf(tiny_config is None, "TinyConfig does not exist")
-            @skipIf(checkpoint is None, "checkpoint does not exist")
+            @skipIf(
+                tiny_config is None,
+                "TinyConfig does not exist, make sure that you defined a `_CONFIG_FOR_DOC` variable in the modeling"
+                " file",
+            )
+            @skipIf(
+                checkpoint is None,
+                "checkpoint does not exist, make sure that you defined a `_CHECKPOINT_FOR_DOC` variable in the"
+                " modeling file",
+            )
             def test(self):
                 if ModelClass.__name__.endswith("ForCausalLM"):
                     tiny_config.is_encoder_decoder = False
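Both new pipeline tests exercise the same pattern: Whisper's task and language are steered by forcing decoder prompt tokens, which the pipeline picks up from `model.config.forced_decoder_ids`. A minimal sketch of that pattern outside the test harness, assuming the multilingual `openai/whisper-tiny` checkpoint and a hypothetical local `audio.flac` (any file ffmpeg can decode works):

```python
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from transformers.pipelines import AutomaticSpeechRecognitionPipeline

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

# Force French output tokens regardless of the language Whisper would detect.
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="fr")

asr = AutomaticSpeechRecognitionPipeline(
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
)
print(asr("audio.flac"))  # {"text": ...}; "audio.flac" is a hypothetical input file
```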