Fix whisper for pipeline #19482

Merged 7 commits on Oct 11, 2022
14 changes: 14 additions & 0 deletions src/transformers/models/whisper/feature_extraction_whisper.py
@@ -218,6 +218,7 @@ def __call__(
return_attention_mask: Optional[bool] = None,
padding: Optional[str] = "max_length",
max_length: Optional[int] = None,
sampling_rate: Optional[int] = None,
**kwargs
) -> BatchFeature:
"""
@@ -261,6 +262,19 @@ def __call__(
The value that is used to fill the padding values / vectors.
"""

if sampling_rate is not None:
if sampling_rate != self.sampling_rate:
raise ValueError(
f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
f" {self.sampling_rate} and not {sampling_rate}."
)
else:
logger.warning(
"It is strongly recommended to pass the `sampling_rate` argument to this function. "
"Failing to do so can result in silent errors that might be hard to debug."
)

is_batched = bool(
isinstance(raw_speech, (list, tuple))
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
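For reference, a minimal sketch of the new `sampling_rate` check from the caller's side (the one-second silence buffer is a made-up input; assumes Hub access to fetch the extractor config):

```python
import numpy as np

from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")

# One second of silence at Whisper's expected 16 kHz rate (placeholder input).
audio = np.zeros(16_000, dtype=np.float32)

# Matching rate: features are computed as before.
inputs = feature_extractor(audio, sampling_rate=16_000, return_tensors="pt")

# Mismatched rate: now raises a ValueError instead of failing silently.
try:
    feature_extractor(audio, sampling_rate=8_000)
except ValueError as err:
    print(err)

# Omitting sampling_rate still works but logs the new warning.
_ = feature_extractor(audio, return_tensors="pt")
```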
40 changes: 18 additions & 22 deletions src/transformers/models/whisper/modeling_whisper.py
@@ -31,13 +31,22 @@
Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_whisper import WhisperConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "WhisperConfig"
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
_PROCESSOR_FOR_DOC = "openai/whisper-tiny"
_EXPECTED_OUTPUT_SHAPE = [1, 2, 512]


WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -982,7 +991,14 @@ def get_decoder(self):
return self.decoder

@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_PROCESSOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Seq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_EXPECTED_OUTPUT_SHAPE,
modality="audio",
)
def forward(
self,
input_features=None,
@@ -999,26 +1015,6 @@ def forward(
output_hidden_states=None,
return_dict=None,
):
r"""
Returns:

Example:

```python
>>> import torch
>>> from transformers import WhisperModel, WhisperFeatureExtractor
>>> from datasets import load_dataset

>>> model = WhisperModel.from_pretrained("openai/whisper-base")
>>> feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
>>> input_features = inputs.input_features
>>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
>>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
>>> list(last_hidden_state.shape)
[1, 2, 512]
```"""

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
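The inline example deleted above is now generated automatically by `add_code_sample_docstrings`. For readers of this diff, the equivalent runnable snippet, using the same `openai/whisper-base` checkpoint and `[1, 2, 512]` expected shape as the removed docstring, is:

```python
import torch
from datasets import load_dataset

from transformers import WhisperFeatureExtractor, WhisperModel

model = WhisperModel.from_pretrained("openai/whisper-base")
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
input_features = inputs.input_features

# Two decoder start tokens -> sequence length 2 in the output.
decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
print(list(last_hidden_state.shape))  # [1, 2, 512]; whisper-base hidden size is 512
```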
48 changes: 48 additions & 0 deletions tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -26,6 +26,8 @@
AutoTokenizer,
Speech2TextForConditionalGeneration,
Wav2Vec2ForCTC,
WhisperForConditionalGeneration,
WhisperProcessor,
)
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
from transformers.pipelines.audio_utils import chunk_bytes_iter
@@ -308,6 +310,52 @@ def test_simple_s2t(self):
output = asr(data)
self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})

@slow
@require_torch
@require_torchaudio
def test_simple_whisper_asr(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="openai/whisper-tiny.en",
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
filename = ds[0]["file"]
output = speech_recognizer(filename)
self.assertEqual(output, {"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to"})

@slow
@require_torch
@require_torchaudio
def test_simple_whisper_translation(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="openai/whisper-large",
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
tokenizer = AutoTokenizer.from_pretrained("openai/whisper-large")
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large")

speech_recognizer_2 = AutomaticSpeechRecognitionPipeline(
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
)
output_2 = speech_recognizer_2(filename)
self.assertEqual(output, output_2)

processor = WhisperProcessor(feature_extractor, tokenizer)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="it")
speech_translator = AutomaticSpeechRecognitionPipeline(
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
)
output_3 = speech_translator(filename)
self.assertEqual(output_3, {"text": " Un uomo ha detto allo universo, Sir, esiste."})

@slow
@require_torch
@require_torchaudio
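A condensed sketch of the forced-decoder-ids mechanism the translation test above exercises, with `openai/whisper-tiny` standing in for the `whisper-large` checkpoint the test uses, and `sample.flac` as a hypothetical local audio file:

```python
from transformers import (
    AutoFeatureExtractor,
    AutoTokenizer,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)
from transformers.pipelines import AutomaticSpeechRecognitionPipeline

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
tokenizer = AutoTokenizer.from_pretrained("openai/whisper-tiny")
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")

# get_decoder_prompt_ids builds the language/task prompt tokens the decoder
# is forced to emit, here steering the model toward Italian transcription.
processor = WhisperProcessor(feature_extractor, tokenizer)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="it")

asr = AutomaticSpeechRecognitionPipeline(
    model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
)
print(asr("sample.flac"))  # hypothetical input path
```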
12 changes: 10 additions & 2 deletions tests/pipelines/test_pipelines_common.py
@@ -178,8 +178,16 @@ def __repr__(self):
class PipelineTestCaseMeta(type):
def __new__(mcs, name, bases, dct):
def gen_test(ModelClass, checkpoint, tiny_config, tokenizer_class, feature_extractor_class):
@skipIf(tiny_config is None, "TinyConfig does not exist")
@skipIf(checkpoint is None, "checkpoint does not exist")
@skipIf(
tiny_config is None,
"TinyConfig does not exist, make sure that you defined a `_CONFIG_FOR_DOC` variable in the modeling"
" file",
)
@skipIf(
checkpoint is None,
"checkpoint does not exist, make sure that you defined a `_CHECKPOINT_FOR_DOC` variable in the"
" modeling file",
)
def test(self):
if ModelClass.__name__.endswith("ForCausalLM"):
tiny_config.is_encoder_decoder = False
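The new skip messages point at the module-level doc variables that the pipeline test metaclass resolves checkpoints and configs from; the Whisper hunk earlier in this diff shows what a modeling file is expected to declare, e.g.:

```python
# Module-level doc variables in a modeling file
# (values taken from the modeling_whisper.py change above).
_CONFIG_FOR_DOC = "WhisperConfig"
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
_PROCESSOR_FOR_DOC = "openai/whisper-tiny"
```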