
Wav2vec2Processor normalization issues on transformers 4.10.0 #13504

Closed
2 of 4 tasks
dmurillo976s opened this issue Sep 9, 2021 · 3 comments · Fixed by #13512
Comments

@dmurillo976s

dmurillo976s commented Sep 9, 2021

When fine-tuning facebook/wav2vec2-large-robust-ft-swbd-300h I noticed I couldn't reproduce past training results obtained with transformers 4.9.2 after upgrading to 4.10.0. In this new version, inputs are no longer correctly normalized to zero mean and unit variance. This happens when return_attention_mask=True and the audios in a batch have different lengths with no padding applied.

Environment info

  • transformers version: 4.10.0
  • Platform: Linux-5.4.104+-x86_64-with-Ubuntu-18.04-bionic
  • Python version: 3.7.11
  • PyTorch version (GPU?): 1.8.1+cu102 (True)
  • Tensorflow version (GPU?): 2.6.0 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help

@patrickvonplaten
@sgugger

Information

Model I am using (Bert, XLNet ...): Wav2Vec 2.0

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQUaD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

  1. Load Wav2Vec2Processor from facebook/wav2vec2-large-robust-ft-swbd-300h
  2. Call processor with batched inputs of individual different lengths

Sample code to replicate the error:

import numpy as np
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-robust-ft-swbd-300h")

sample_rate = 16000
length_1 = 10
length_2 = 20

# Generate dummy input audios of same sample rate but different lengths
input_1 = np.random.rand((sample_rate * length_1))
input_2 = np.random.rand((sample_rate * length_1))
input_3 = np.random.rand((sample_rate * length_2))
 
same_length_result = processor([input_1, input_2], sampling_rate=sample_rate)
different_length_result = processor([input_1, input_3], sampling_rate=sample_rate)

# Show normalized batched audios when using same length
print(same_length_result)
# Show normalized batched audios when using different length
print(different_length_result)

# On 4.10.0 this assert fails: the same audio is transformed differently
# depending on the lengths of the other audios in the batch
np.testing.assert_array_equal(same_length_result["input_values"][0], different_length_result["input_values"][0])

Expected behavior

The assert should pass: both processed versions of input_1 should be identical, each with a mean close to 0 and a standard deviation close to 1.
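For reference, the expected per-sample normalization can be sketched in plain NumPy (a minimal illustration of zero-mean, unit-variance normalization; the epsilon value and helper name here are assumptions for the sketch, not necessarily the exact transformers implementation):

```python
import numpy as np

def normalize(x):
    # Zero-mean, unit-variance normalization applied independently per sample.
    # A small epsilon (assumed here) guards against division by zero.
    return (x - x.mean()) / np.sqrt(x.var() + 1e-7)

audio = np.random.rand(16000 * 10)  # 10 s of dummy audio at 16 kHz
normalized = normalize(audio)
print(normalized.mean(), normalized.std())
```

Because each sample is normalized on its own, the result for a given audio should not depend on the lengths of the other audios in the batch.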

@patrickvonplaten
Contributor

Hey @dmurillo976s,

Thanks a lot for the very well explained issue! I can reproduce the problem. I'll open a PR to fix it today.

@patrickvonplaten
Contributor

@dmurillo976s - this PR: #13512 should fix the problem. Could you give it a try? :-)

@dmurillo976s
Author

Hi @patrickvonplaten,
Thank you very much! Sorry for not responding earlier. I've tried the latest patch release and everything works as it should!
