[Wav2Vec2] Fix normalization for non-padded tensors #13512
Conversation
```
@@ -79,13 +79,20 @@ def __init__(
        self.do_normalize = do_normalize

    @staticmethod
    def zero_mean_unit_var_norm(input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]:
    def zero_mean_unit_var_norm(
```
The responsibility of retrieving the correct length from the attention mask should be in this method, since `input_values` and `attention_mask` are the well-known inputs to functions in `transformers`.
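A minimal sketch of what such a method could look like once it takes the attention mask directly (this is an illustration of the idea in the comment, not necessarily the exact merged implementation; the `1e-7` epsilon is an assumption):

```python
from typing import List, Optional

import numpy as np


def zero_mean_unit_var_norm(
    input_values: List[np.ndarray], attention_mask: Optional[np.ndarray] = None
) -> List[np.ndarray]:
    """Normalize each array to zero mean and unit variance, computing the
    statistics only over the non-padded portion when a mask is given."""
    if attention_mask is None:
        return [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]

    normed = []
    for vector, length in zip(input_values, attention_mask.sum(-1)):
        valid = vector[:length]  # the non-padded samples
        normed_vector = (vector - valid.mean()) / np.sqrt(valid.var() + 1e-7)
        normed_vector[length:] = 0.0  # zero out the padded tail
        normed.append(normed_vector)
    return normed
```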
```
@@ -196,19 +195,33 @@ def __call__(
            return_attention_mask=return_attention_mask,
        )

        if "attention_mask" in padded_inputs:
```
This part is removed/cleaned-up
```
@@ -172,14 +179,6 @@ def __call__(
            and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
        )

        # make sure input is in list format
```
Currently all the padding happens in pure Python, not in NumPy, so let's move the numpification further down.
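A tiny sketch of the "numpify later" idea from the comment above (the variable names here are illustrative, not taken from the actual code):

```python
import numpy as np

# Pad in pure Python first...
batch = [[0.1, 0.2, 0.3], [0.4, 0.5]]  # illustrative ragged input
max_len = max(len(seq) for seq in batch)
padded = [seq + [0.0] * (max_len - len(seq)) for seq in batch]

# ...and only convert to NumPy at the very end.
input_values = np.asarray(padded, dtype=np.float32)
print(input_values.shape)  # (2, 3)
```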
```
@@ -134,7 +134,22 @@ def _check_zero_mean_unit_variance(input_vector):
        _check_zero_mean_unit_variance(input_values[1, :1000])
        _check_zero_mean_unit_variance(input_values[2])

    def test_zero_mean_unit_variance_normalization_trunc(self):
    def test_zero_mean_unit_variance_normalization(self):
```
Add test to make sure normalization always works as expected
Great catch!
This looks good to me.
Looks good to me, but I'll delegate to @patil-suraj's and @anton-l's Wav2Vec2 knowledge.
Let me know once this is merged so that I can release a patch.
LGTM other than the small issues already pointed out, thanks for fixing it!
All slow tests now pass for Wav2Vec2 and Hubert, nice!
LGTM! Thanks for adding all those tests :)
* finalize
* Apply suggestions from code review
* finish cleaner implementation
* more tests
* small fix
* finish
* up
What does this PR do?
This PR fixes a problem with normalization when the input is a list of sequences of different lengths that has not yet been converted to NumPy arrays - see #13504.
Just noticed that this bug is actually pretty severe, as it affects all fine-tuning of the large Wav2Vec2 models :-/.
It was introduced by me in this PR: https://github.com/huggingface/transformers/pull/12804/files - I should have written more and better tests for it.
=> This means that from transformers 4.9.0 until this PR is merged, the normalization for all large Wav2Vec2 models was way off when fine-tuning the model.
@LysandreJik - do you think it might be possible to do a patched release for this?
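To illustrate the class of bug being fixed (a hypothetical sketch, not the actual `transformers` code path): normalizing over a zero-padded batch instead of per-sequence skews the statistics whenever the sequence lengths differ.

```python
import numpy as np

rng = np.random.default_rng(0)
short = rng.normal(5.0, 2.0, 400)
long = rng.normal(5.0, 2.0, 1000)

# Buggy: pad first, then normalize with batch-level statistics,
# letting the padding zeros pollute the mean and variance.
padded = np.zeros((2, 1000))
padded[0, :400] = short
padded[1, :] = long
buggy = (padded - padded.mean()) / padded.std()

# Correct: normalize each sequence over its true length only.
fixed = [(seq - seq.mean()) / seq.std() for seq in (short, long)]

print(abs(fixed[0].mean()))        # ~0, as intended
print(abs(buggy[0, :400].mean()))  # clearly nonzero: padding skewed the stats
```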