New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Doctests] Fix ignore bug and add more doc tests #15911
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -465,22 +465,28 @@ def forward( | |
Examples: | ||
|
||
```python | ||
>>> from transformers import SpeechEncoderDecoderModel, Speech2Text2Processor | ||
>>> from transformers import SpeechEncoderDecoderModel, Wav2Vec2Processor | ||
>>> from datasets import load_dataset | ||
>>> import torch | ||
|
||
>>> processor = Speech2Text2Processor.from_pretrained("facebook/s2t-wav2vec2-large-en-de") | ||
>>> model = SpeechEncoderDecoderModel.from_pretrained("facebook/s2t-wav2vec2-large-en-de") | ||
>>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xls-r-300m-en-to-15") | ||
>>> model = SpeechEncoderDecoderModel.from_pretrained("facebook/wav2vec2-xls-r-300m-en-to-15") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Better example |
||
|
||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") | ||
|
||
>>> input_values = processor(ds[0]["audio"]["array"], return_tensors="pt").input_values | ||
>>> decoder_input_ids = torch.tensor([[model.config.decoder.decoder_start_token_id]]) | ||
>>> outputs = model(input_values=input_values, decoder_input_ids=decoder_input_ids) | ||
|
||
>>> # inference (generation) | ||
>>> # Inference: Translate English speech to German | ||
>>> generated = model.generate(input_values) | ||
>>> translation = processor.batch_decode(generated) | ||
>>> decoded = processor.batch_decode(generated, skip_special_tokens=True)[0] | ||
>>> decoded | ||
'Mr. Quilter ist der Apostel der Mittelschicht und wir freuen uns, sein Evangelium willkommen heißen zu können.' | ||
|
||
>>> # Training: Train model on English transcription | ||
>>> with processor.as_target_processor(): | ||
... labels = processor(ds[0]["text"], return_tensors="pt").input_ids | ||
|
||
>>> loss = model(input_values, labels=labels).loss | ||
>>> loss.backward() | ||
```""" | ||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,12 +24,7 @@ | |
from torch.nn import CrossEntropyLoss | ||
|
||
from ...activations import ACT2FN | ||
from ...file_utils import ( | ||
add_code_sample_docstrings, | ||
add_start_docstrings, | ||
add_start_docstrings_to_model_forward, | ||
replace_return_docstrings, | ||
) | ||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings | ||
from ...modeling_outputs import ( | ||
BaseModelOutput, | ||
BaseModelOutputWithPastAndCrossAttentions, | ||
|
@@ -44,8 +39,6 @@ | |
logger = logging.get_logger(__name__) | ||
|
||
_CONFIG_FOR_DOC = "Speech2TextConfig" | ||
_TOKENIZER_FOR_DOC = "Speech2TextTokenizer" | ||
_CHECKPOINT_FOR_DOC = "facebook/s2t-small-librispeech-asr" | ||
|
||
|
||
SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST = [ | ||
|
@@ -780,7 +773,7 @@ def forward( | |
attention_mask = self._get_feature_vector_attention_mask(inputs_embeds.shape[1], attention_mask) | ||
padding_mask = attention_mask.ne(1).long() | ||
else: | ||
padding_mask = torch.zeros_like(inputs_embeds, dtype=torch.long) | ||
padding_mask = torch.zeros(inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a bug actually for which an additional test is added |
||
|
||
embed_pos = self.embed_positions(padding_mask) | ||
|
||
|
@@ -1144,12 +1137,7 @@ def get_decoder(self): | |
return self.decoder | ||
|
||
@add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING) | ||
@add_code_sample_docstrings( | ||
processor_class=_TOKENIZER_FOR_DOC, | ||
checkpoint=_CHECKPOINT_FOR_DOC, | ||
output_type=Seq2SeqModelOutput, | ||
config_class=_CONFIG_FOR_DOC, | ||
) | ||
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) | ||
def forward( | ||
self, | ||
input_features=None, | ||
|
@@ -1167,6 +1155,21 @@ def forward( | |
output_hidden_states=None, | ||
return_dict=None, | ||
): | ||
r""" | ||
Returns: | ||
|
||
Example: | ||
|
||
```python | ||
patrickvonplaten marked this conversation as resolved.
Show resolved
Hide resolved
|
||
>>> import torch >>> from transformers import Speech2TextModel, Speech2TextFeatureExtractor >>> from datasets | ||
import load_dataset >>> model = Speech2TextModel.from_pretrained("facebook/s2t-small-librispeech-asr") >>> | ||
feature_extractor = Speech2TextFeatureExtractor.from_pretrained("facebook/s2t-small-librispeech-asr") >>> ds = | ||
load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") >>> input_features = | ||
feature_extractor(ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], | ||
return_tensors="pt").input_features >>> decoder_input_ids = torch.tensor([[1, 1]]) * | ||
model.config.decoder_start_token_id >>> last_hidden_state = model(input_features, | ||
decoder_input_ids=decoder_input_ids).last_hidden_state >>> list(last_hidden_state.shape) [1, 2, 256] """ | ||
|
||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | ||
output_hidden_states = ( | ||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | ||
|
@@ -1300,32 +1303,26 @@ def forward( | |
Returns: | ||
|
||
Example: | ||
|
||
```python | ||
>>> import torch | ||
>>> from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration | ||
>>> from datasets import load_dataset | ||
>>> import soundfile as sf | ||
|
||
>>> model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr") | ||
>>> processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr") | ||
|
||
|
||
>>> def map_to_array(batch): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. old API should be removed |
||
... speech, _ = sf.read(batch["file"]) | ||
... batch["speech"] = speech | ||
... return batch | ||
|
||
|
||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") | ||
>>> ds = ds.map(map_to_array) | ||
|
||
>>> input_features = processor( | ||
... ds["speech"][0], sampling_rate=16000, return_tensors="pt" | ||
>>> ).input_features # Batch size 1 | ||
... ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="pt" | ||
>>> ).input_features | ||
|
||
>>> generated_ids = model.generate(inputs=input_features) | ||
|
||
>>> transcription = processor.batch_decode(generated_ids) | ||
>>> transcription = processor.batch_decode(generated_ids)[0] | ||
>>> transcription | ||
'mister quilter is the apostle of the middle classes and we are glad to welcome his gospel' | ||
```""" | ||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,13 +35,12 @@ | |
logger = logging.get_logger(__name__) | ||
|
||
_CONFIG_FOR_DOC = "Speech2Text2Config" | ||
_TOKENIZER_FOR_DOC = "Speech2Text2Tokenizer" | ||
_CHECKPOINT_FOR_DOC = "facebook/s2t-small-librispeech-asr" | ||
_CHECKPOINT_FOR_DOC = "facebook/s2t-wav2vec2-large-en-de" | ||
|
||
|
||
SPEECH_TO_TEXT_2_PRETRAINED_MODEL_ARCHIVE_LIST = [ | ||
"facebook/s2t-small-librispeech-asr", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is not speech2text2 it's speech2text |
||
# See all Speech2Text2 models at https://huggingface.co/models?filter=speech_to_text | ||
"facebook/s2t-wav2vec2-large-en-de", | ||
# See all Speech2Text2 models at https://huggingface.co/models?filter=speech2text2 | ||
] | ||
|
||
|
||
|
@@ -865,13 +864,34 @@ def forward( | |
... Wav2Vec2Model, | ||
... Speech2Text2Config, | ||
... Wav2Vec2Config, | ||
... Wav2Vec2FeatureExtractor, | ||
... Speech2Text2Tokenizer, | ||
... ) | ||
>>> from datasets import load_dataset | ||
|
||
>>> feature_extractor = Wav2Vec2FeatureExtractor() | ||
>>> tokenizer = Speech2Text2Tokenizer.from_pretrained(_CHECKPOINT_FOR_DOC) | ||
|
||
>>> encoder = Wav2Vec2Model(Wav2Vec2Config()) | ||
>>> decoder = Speech2Text2ForCausalLM(Speech2Text2Config()) | ||
# init speech2text model | ||
# init random speech2text model | ||
|
||
>>> model = SpeechEncoderDecoderModel(encoder=encoder, decoder=decoder) | ||
>>> model.config.pad_token_id = tokenizer.pad_token_id | ||
>>> model.config.decoder_start_token_id = tokenizer.bos_token_id | ||
# pre-process inputs and labels | ||
|
||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") | ||
>>> input_values = feature_extractor( | ||
... ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="pt" | ||
>>> ).input_values # Batch size 1 | ||
>>> decoder_input_ids = tokenizer(ds[0]["text"], return_tensors="pt").input_ids | ||
# compute loss | ||
|
||
>>> loss = model(inputs=input_values, labels=decoder_input_ids).loss | ||
# backprop loss | ||
|
||
>>> loss.backward() | ||
```""" | ||
|
||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,15 @@ | ||
src/transformers/models/wav2vec2/modeling_wav2vec2.py | ||
src/transformers/models/wav2vec2/tokenization_wav2vec2.py | ||
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py | ||
src/transformers/models/hubert/modeling_hubert.py | ||
src/transformers/models/wavlm/modeling_wavlm.py | ||
src/transformers/models/unispeech/modeling_unispeech.py | ||
src/transformers/models/unispeech_sat/modeling_unispeech_sat.py | ||
src/transformers/models/sew/modeling_sew.py | ||
src/transformers/models/sew_d/modeling_sew_d.py | ||
src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py | ||
src/transformers/models/speech_to_text/modeling_speech_to_text.py | ||
src/transformers/models/speech_encoder_decoder/modeling_speech_enocder_decoder.py | ||
src/transformers/models/data2vec/modeling_data2vec_audio.py | ||
docs/source/quicktour.mdx | ||
docs/source/task_summary.mdx |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@sgugger - I've dug a bit into this.
IGNORE_RESULTS
is actually not a bool but an integer so as it was coded now all doc tests were ignored. We need to use the bitwise and here I think (tested it locally and it works again now).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, thanks for checking and fixing! I had tested it locally but I must have messed something up.