[Wav2Vec2] Fix normalization for non-padded tensors #13512

@patrickvonplaten patrickvonplaten commented Sep 10, 2021

What does this PR do?

This PR fixes a normalization bug that occurs when the input is a list of sequences of different lengths that has not been converted to NumPy arrays; see #13504.

Just noticed that this bug is pretty severe actually as it affects all large-Wav2Vec2 fine-tuning :-/.
It was introduced by me in this PR: https://github.com/huggingface/transformers/pull/12804/files - I should have written more and better tests for this.

=> This means that from transformers 4.9.0 until this PR is merged, the normalization for all large Wav2Vec2 models was way off when fine-tuning the model.

@LysandreJik - do you think it might be possible to do a patched release for this?

@@ -79,13 +79,20 @@ def __init__(
         self.do_normalize = do_normalize

     @staticmethod
-    def zero_mean_unit_var_norm(input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]:
+    def zero_mean_unit_var_norm(

Retrieving the correct length from the attention mask should be the responsibility of this method, since input_values and attention_mask are the well-known inputs to functions in transformers.
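
A minimal sketch of the idea in this comment, assuming the attention mask holds 1 for real samples and 0 for padding, and that padding sits on the right. The name `normalize_with_mask`, the `eps` value, and the use of plain Python lists are illustrative assumptions; this is not the code merged in this PR:

```python
import math
from typing import List, Optional, Sequence


def normalize_with_mask(
    input_values: Sequence[Sequence[float]],
    attention_mask: Optional[Sequence[Sequence[int]]] = None,
    padding_value: float = 0.0,
    eps: float = 1e-7,  # assumed stabilizer to avoid division by zero
) -> List[List[float]]:
    """Normalize each sequence to zero mean and unit variance.

    Statistics are computed only over the unpadded prefix, whose length is
    recovered from the attention mask (assumed right-padded, 1 = real sample).
    Each sequence is assumed to contain at least one real sample.
    """
    normed = []
    for i, vector in enumerate(input_values):
        # Without a mask, treat the whole vector as real samples.
        length = sum(attention_mask[i]) if attention_mask is not None else len(vector)
        real = vector[:length]
        mean = sum(real) / length
        var = sum((x - mean) ** 2 for x in real) / length
        scale = math.sqrt(var + eps)
        # Normalize the real samples, then restore the padding value.
        row = [(x - mean) / scale for x in real]
        row.extend([padding_value] * (len(vector) - length))
        normed.append(row)
    return normed


# Hypothetical batch: the first sequence has two padded positions.
batch = [[1.0, 2.0, 3.0, 0.0, 0.0], [4.0, 6.0, 8.0, 10.0, 12.0]]
mask = [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
normed = normalize_with_mask(batch, mask)
```

Taking the mask rather than precomputed lengths keeps the caller's interface down to the two inputs it already has.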

@@ -196,19 +195,33 @@ def __call__(
return_attention_mask=return_attention_mask,
)

if "attention_mask" in padded_inputs:

This part is removed/cleaned-up

@@ -172,14 +179,6 @@ def __call__(
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
)

# make sure input is in list format

Currently, all the padding happens in pure Python and not in NumPy, so let's move the numpification further down.
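
To illustrate the ordering described here, the sketch below pads ragged Python lists and builds the attention mask before any array conversion takes place. `pad_batch` is a hypothetical helper, not a function from transformers, and right-padding to the longest sequence is an assumption; conversion to NumPy would follow as a separate, later step:

```python
from typing import List, Sequence, Tuple


def pad_batch(
    raw_speech: Sequence[Sequence[float]],
    padding_value: float = 0.0,
) -> Tuple[List[List[float]], List[List[int]]]:
    """Right-pad every sequence to the longest one; return values and mask."""
    max_length = max(len(seq) for seq in raw_speech)
    padded, mask = [], []
    for seq in raw_speech:
        n_pad = max_length - len(seq)
        # Real samples first, then padding; the mask mirrors that layout.
        padded.append(list(seq) + [padding_value] * n_pad)
        mask.append([1] * len(seq) + [0] * n_pad)
    return padded, mask


padded, mask = pad_batch([[0.1, 0.2, 0.3], [0.5]])
```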

@@ -134,7 +134,22 @@ def _check_zero_mean_unit_variance(input_vector):
         _check_zero_mean_unit_variance(input_values[1, :1000])
         _check_zero_mean_unit_variance(input_values[2])

-    def test_zero_mean_unit_variance_normalization_trunc(self):
+    def test_zero_mean_unit_variance_normalization(self):

Add test to make sure normalization always works as expected
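
A rough sketch of the kind of check such a test could perform. The helper below is a standalone stand-in (the body of `_check_zero_mean_unit_variance` is not shown in this diff), and the tolerance of 1e-3 and the random input are assumptions:

```python
import math
import random


def check_zero_mean_unit_variance(values, tol=1e-3):
    """Assert that `values` has mean ~0 and variance ~1 within `tol`."""
    n = len(values)
    mean = sum(values) / n
    var = sum((x - mean) ** 2 for x in values) / n
    assert abs(mean) < tol, f"mean {mean} not close to 0"
    assert abs(var - 1.0) < tol, f"variance {var} not close to 1"


# Normalize only the real samples of a sequence, then append padding.
random.seed(0)
real = [random.uniform(-1.0, 1.0) for _ in range(1000)]
mean = sum(real) / len(real)
std = math.sqrt(sum((x - mean) ** 2 for x in real) / len(real) + 1e-7)
padded = [(x - mean) / std for x in real] + [0.0] * 200

# Only the unpadded slice is expected to satisfy the check.
check_zero_mean_unit_variance(padded[:1000])
```

Running the same check on the full padded vector would fail on the variance, which is why the assertions in the diff slice to the true length (for example `input_values[1, :1000]`).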

@patil-suraj patil-suraj left a comment

Great catch!

This looks good to me.

Review comment on tests/test_feature_extraction_wav2vec2.py (outdated, resolved)

@LysandreJik LysandreJik left a comment

Seems to look good but will delegate to @patil-suraj and @anton-l's w2v2 knowledge.

Let me know once this is merged so that I may release a patch.

@anton-l anton-l left a comment

LGTM other than the small issues already pointed out, thanks for fixing it!

@anton-l anton-l left a comment

All slow tests now pass for Wav2Vec and Hubert, nice!

@patil-suraj patil-suraj left a comment

LGTM! Thanks for adding all those tests :)

@patrickvonplaten patrickvonplaten merged commit d7b3b70 into huggingface:master Sep 10, 2021
@patrickvonplaten patrickvonplaten deleted the fix_normalization_non_padded branch September 10, 2021 13:27
patrickvonplaten added a commit that referenced this pull request Sep 10, 2021
* finalize

* Apply suggestions from code review

* finish cleaner implementation

* more tests

* small fix

* finish

* up

Successfully merging this pull request may close these issues.

Wav2vec2Processor normalization issues on transformers 4.10.0
4 participants