[Wav2Vec2] Fix normalization for non-padded tensors #13512

Merged
Changes from 5 commits
2 changes: 1 addition & 1 deletion src/transformers/feature_extraction_sequence_utils.py
@@ -341,7 +341,7 @@ def _truncate(

         return processed_features

-    def _get_padding_strategies(self, padding=False, max_length=None, pad_to_multiple_of=None, **kwargs):
+    def _get_padding_strategies(self, padding=False, max_length=None):
         """
         Find the correct padding strategy
         """
53 changes: 32 additions & 21 deletions src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py
@@ -79,13 +79,18 @@ def __init__(
         self.do_normalize = do_normalize

     @staticmethod
-    def zero_mean_unit_var_norm(input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]:
+    def zero_mean_unit_var_norm(
+        input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0
+    ) -> List[np.ndarray]:
         """
         Every array in the list is normalized to have zero mean and unit variance
         """
-        normed_input_values = [
-            (x - np.mean(x[:i])) / np.sqrt(np.var(x[:i]) + 1e-5) for x, i in zip(input_values, input_lengths)
-        ]
+        if attention_mask is not None:
+            attention_mask = np.array(attention_mask, np.int32)
+            normed_input_values = []
+            for vector, length in zip(input_values, attention_mask.sum(-1)):
+                # compute statistics over the valid (unpadded) region only
+                normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
+                if length < normed_slice.shape[0]:
+                    normed_slice[length:] = padding_value
+                normed_input_values.append(normed_slice)
+        else:
+            normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]

         return normed_input_values
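For intuition, here is a standalone sketch (plain NumPy, independent of the method above) of what the masked branch computes: statistics come from the first `length` valid samples, so trailing padding cannot skew the mean or variance.

import numpy as np

def masked_normalize(x: np.ndarray, length: int, padding_value: float = 0.0) -> np.ndarray:
    # statistics from the valid region only; 1e-7 guards the division
    valid = x[:length]
    normed = (x - valid.mean()) / np.sqrt(valid.var() + 1e-7)
    normed[length:] = padding_value  # re-fill the padded tail, as the method above does
    return normed

x = np.concatenate([np.random.randn(800), np.zeros(200)])  # 200 trailing pad samples
normed = masked_normalize(x, length=800)
print(abs(normed[:800].mean()) < 1e-3)    # True: valid region has ~zero mean
print(abs(normed[:800].var() - 1) < 1e-3) # True: valid region has ~unit variance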

     def __call__(

@@ -172,14 +177,6 @@ def __call__(
             and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
         )

-        # make sure input is in list format
-        if is_batched and not isinstance(raw_speech[0], np.ndarray):
-            raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
-        elif not is_batched and not isinstance(raw_speech, np.ndarray):
-            raw_speech = np.asarray(raw_speech, dtype=np.float32)
-        elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.float64:
-            raw_speech = raw_speech.astype(np.float32)
-

patrickvonplaten (Contributor Author): Currently all the padding happens in pure Python, not in NumPy, so let's move the numpification further down (see the sketch after this hunk).

         # always return batch
         if not is_batched:
             raw_speech = [raw_speech]
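A minimal illustration of the ordering the comment describes (toy values, not the library's actual padding helper): padding operates on plain Python lists, so converting to NumPy up front buys nothing, and the float32 conversion ("numpification") can happen once, after self.pad(...) returns.

import numpy as np

batch = [[0.1, 0.2, 0.3], [0.4, 0.5]]                      # raw speech as pure-Python lists
max_len = max(len(seq) for seq in batch)
padded = [seq + [0.0] * (max_len - len(seq)) for seq in batch]  # python-level padding
as_np = [np.asarray(seq, dtype=np.float32) for seq in padded]   # numpify afterwards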
@@ -196,19 +193,33 @@
             return_attention_mask=return_attention_mask,
         )

-        if "attention_mask" in padded_inputs:
-            input_lengths = padded_inputs["attention_mask"].sum(-1)
-        else:
-            padded_input_values = padded_inputs["input_values"]
-            input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])]
-
-        if isinstance(padded_inputs["input_values"][0], np.ndarray):
-            padded_inputs["input_values"] = [x.astype(np.float32) for x in padded_inputs["input_values"]]

patrickvonplaten (Contributor Author): This part is removed/cleaned-up.

+        # convert input values to correct format
+        input_values = padded_inputs["input_values"]
+        if not isinstance(input_values[0], np.ndarray):
+            padded_inputs["input_values"] = [np.asarray(array, dtype=np.float32) for array in input_values]
+        elif (
+            not isinstance(input_values, np.ndarray)
+            and isinstance(input_values[0], np.ndarray)
+            and input_values[0].dtype is np.float64
+        ):
+            padded_inputs["input_values"] = [array.astype(np.float32) for array in input_values]
+        elif isinstance(input_values, np.ndarray) and input_values.dtype is np.float64:
+            padded_inputs["input_values"] = input_values.astype(np.float32)
+
+        # convert attention_mask to correct format
+        attention_mask = padded_inputs.get("attention_mask")
+        if attention_mask is not None:
+            padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask]

         # zero-mean and unit-variance normalization
         if self.do_normalize:
+            attention_mask = (
+                np.array(attention_mask, dtype=np.bool)
+                if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
+                else None
+            )
             padded_inputs["input_values"] = self.zero_mean_unit_var_norm(
-                padded_inputs["input_values"], input_lengths=input_lengths
+                padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value
             )

         if return_tensors is not None:
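Putting the pieces together, a hedged usage sketch of the fixed call path (the constructor arguments and dummy data below are assumptions, not values from the PR): with padding enabled, statistics come from each row's unpadded region via the attention mask; with padding="do_not_pad", the mask is skipped and each sequence is normalized over its full length — the case this PR fixes.

import numpy as np
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(do_normalize=True, return_attention_mask=True)
raw_speech = [np.random.randn(800), np.random.randn(1200)]

# padded batch: normalization uses only the unpadded region of each row
padded = feature_extractor(raw_speech, padding="longest", return_tensors="np")

# non-padded batch: no attention mask needed; previously this path broke
unpadded = feature_extractor(raw_speech, padding="do_not_pad")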
44 changes: 34 additions & 10 deletions tests/test_feature_extraction_wav2vec2.py
@@ -120,21 +120,45 @@ def test_call(self):
         for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
             self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))

-    def test_zero_mean_unit_variance_normalization(self):
+    def test_zero_mean_unit_variance_normalization_np(self):
         feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
         speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
-        processed = feat_extract(speech_inputs, padding="longest", return_tensors="np")
-        input_values = processed.input_values
-
-        def _check_zero_mean_unit_variance(input_vector):
-            self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
-            self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
-
-        _check_zero_mean_unit_variance(input_values[0, :800])
-        _check_zero_mean_unit_variance(input_values[1, :1000])
-        _check_zero_mean_unit_variance(input_values[2])
+        paddings = ["longest", "do_not_pad", "max_length"]
+        max_lengths = [None, None, 1600]
+        for max_length, padding in zip(max_lengths, paddings):
+            processed = feat_extract(speech_inputs, padding=padding, max_length=max_length, return_tensors="np")
+            input_values = processed.input_values
+
+            def _check_zero_mean_unit_variance(input_vector):
+                self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
+                self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
+
+            _check_zero_mean_unit_variance(input_values[0][:800])
+            _check_zero_mean_unit_variance(input_values[1][:1000])
+            _check_zero_mean_unit_variance(input_values[2][:1200])

patrickvonplaten (Contributor Author), on the test below: Add a test to make sure normalization always works as expected.

+    def test_zero_mean_unit_variance_normalization(self):
+        feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+        lengths = range(800, 1400, 200)
+        speech_inputs = [floats_list((1, x))[0] for x in lengths]
+
+        paddings = ["longest", "do_not_pad", "max_length"]
+        max_lengths = [None, None, 1600]
+
+        for max_length, padding in zip(max_lengths, paddings):
+            processed = feat_extract(speech_inputs, max_length=max_length, padding=padding)
+            input_values = processed.input_values
+
+            def _check_zero_mean_unit_variance(input_vector):
+                self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
+                self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
+
+            _check_zero_mean_unit_variance(input_values[0][:800])
+            _check_zero_mean_unit_variance(input_values[1][:1000])
+            _check_zero_mean_unit_variance(input_values[2][:1200])
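To see why the old code could not handle non-padded batches, here is a sketch of the failure mode, reconstructed from the removed input_lengths fallback above with dummy data (that this exact traceback motivated the PR is an inference from the diff): without an attention mask, the old branch read .shape off padded_inputs["input_values"], which at that point is a plain Python list of arrays.

import numpy as np

padded_input_values = [np.random.randn(800), np.random.randn(1000)]  # a list, not an ndarray
try:
    # the removed fallback assumed an ndarray here
    input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])]
except AttributeError as err:
    print(err)  # 'list' object has no attribute 'shape'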

-    def test_zero_mean_unit_variance_normalization_trunc(self):
+    def test_zero_mean_unit_variance_normalization_trunc_np(self):
         feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
         speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
         processed = feat_extract(