[Wav2Vec2] Fix normalization for non-padded tensors #13512
Changes to the Wav2Vec2 feature extractor:
```diff
@@ -79,13 +79,20 @@ def __init__(
         self.do_normalize = do_normalize
 
     @staticmethod
-    def zero_mean_unit_var_norm(input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]:
+    def zero_mean_unit_var_norm(
+        input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value=0.0
+    ) -> List[np.ndarray]:
         """
         Every array in the list is normalized to have zero mean and unit variance
         """
-        normed_input_values = [
-            (x - np.mean(x[:i])) / np.sqrt(np.var(x[:i]) + 1e-5) for x, i in zip(input_values, input_lengths)
-        ]
+        if attention_mask is not None:
+            normed_input_values = [
+                (x - x[:i].mean()) / np.sqrt(x[:i].var() + 1e-7) for x, i in zip(input_values, attention_mask.sum(-1))
+            ]
+        else:
+            normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
 
         return normed_input_values
 
     def __call__(
```
> **Review comment:** Currently all the padding happens in pure Python, not NumPy, so let's move the numpification further down.

```diff
@@ -172,14 +179,6 @@ def __call__(
             and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
         )
 
-        # make sure input is in list format
-        if is_batched and not isinstance(raw_speech[0], np.ndarray):
-            raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
-        elif not is_batched and not isinstance(raw_speech, np.ndarray):
-            raw_speech = np.asarray(raw_speech, dtype=np.float32)
-        elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.float64:
-            raw_speech = raw_speech.astype(np.float32)
-
         # always return batch
         if not is_batched:
             raw_speech = [raw_speech]
```
> **Review comment:** This part is removed/cleaned up.

```diff
@@ -196,19 +195,33 @@ def __call__(
             return_attention_mask=return_attention_mask,
         )
 
-        if "attention_mask" in padded_inputs:
-            input_lengths = padded_inputs["attention_mask"].sum(-1)
-        else:
-            padded_input_values = padded_inputs["input_values"]
-            input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])]
-
-        if isinstance(padded_inputs["input_values"][0], np.ndarray):
-            padded_inputs["input_values"] = [x.astype(np.float32) for x in padded_inputs["input_values"]]
+        # convert input values to correct format
+        input_values = padded_inputs["input_values"]
+        if not isinstance(input_values[0], np.ndarray):
+            padded_inputs["input_values"] = [np.asarray(array, dtype=np.float32) for array in input_values]
+        elif (
+            not isinstance(input_values, np.ndarray)
+            and isinstance(input_values[0], np.ndarray)
+            and input_values[0].dtype is np.float64
+        ):
+            padded_inputs["input_values"] = [array.astype(np.float32) for array in input_values]
+        elif isinstance(input_values, np.ndarray) and input_values.dtype is np.float64:
+            padded_inputs["input_values"] = input_values.astype(np.float32)
+
+        # convert attention_mask to correct format
+        attention_mask = padded_inputs.get("attention_mask")
+        if attention_mask is not None:
+            padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask]
 
         # zero-mean and unit-variance normalization
         if self.do_normalize:
+            attention_mask = (
+                np.array(attention_mask, dtype=np.bool)
+                if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
+                else None
+            )
             padded_inputs["input_values"] = self.zero_mean_unit_var_norm(
-                padded_inputs["input_values"], input_lengths=input_lengths
+                padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value
             )
 
         if return_tensors is not None:
```
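The cascaded `isinstance`/dtype checks above are easy to misread, so here is a hedged standalone sketch of just the dtype-handling branches (`ensure_float32` is a hypothetical name; the real logic is inline in `__call__`, and this sketch compares dtypes with `==`):

```python
# Hypothetical helper mirroring the new dtype handling after padding:
# "input_values" may be a list of Python lists, a list of np.ndarrays,
# or one batched np.ndarray, and should end up as float32 either way.
from typing import List, Union

import numpy as np


def ensure_float32(input_values) -> Union[np.ndarray, List[np.ndarray]]:
    if not isinstance(input_values, np.ndarray) and not isinstance(input_values[0], np.ndarray):
        # list of Python lists -> list of float32 arrays
        return [np.asarray(array, dtype=np.float32) for array in input_values]
    if not isinstance(input_values, np.ndarray) and input_values[0].dtype == np.float64:
        # list of float64 arrays -> list of float32 arrays
        return [array.astype(np.float32) for array in input_values]
    if isinstance(input_values, np.ndarray) and input_values.dtype == np.float64:
        # single batched float64 array -> float32
        return input_values.astype(np.float32)
    return input_values  # already float32


print(ensure_float32([[0.1, 0.2], [0.3]])[0].dtype)              # float32
print(ensure_float32(np.zeros((2, 4), dtype=np.float64)).dtype)  # float32
```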
Changes to the Wav2Vec2 feature extraction tests:
```diff
@@ -120,7 +120,7 @@ def test_call(self):
         for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
             self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
 
-    def test_zero_mean_unit_variance_normalization(self):
+    def test_zero_mean_unit_variance_normalization_np(self):
         feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
         speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
         processed = feat_extract(speech_inputs, padding="longest", return_tensors="np")
```
> **Review comment:** Add a test to make sure normalization always works as expected.

```diff
@@ -134,7 +134,22 @@ def _check_zero_mean_unit_variance(input_vector):
         _check_zero_mean_unit_variance(input_values[1, :1000])
         _check_zero_mean_unit_variance(input_values[2])
 
-    def test_zero_mean_unit_variance_normalization_trunc(self):
+    def test_zero_mean_unit_variance_normalization(self):
+        feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+        lengths = range(800, 1400, 200)
+        speech_inputs = [floats_list((1, x))[0] for x in lengths]
+        processed = feat_extract(speech_inputs)
+        input_values = processed.input_values
+
+        def _check_zero_mean_unit_variance(input_vector):
+            self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
+            self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
+
+        for normalized_array, length in zip(input_values, lengths):
+            self.assertEqual(len(normalized_array), length)
+            _check_zero_mean_unit_variance(normalized_array)
+
+    def test_zero_mean_unit_variance_normalization_trunc_np(self):
         feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
         speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
         processed = feat_extract(
```
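As a usage-level illustration of what the new test asserts (a hedged sketch; the constructor values are typical Wav2Vec2 defaults, assumed here rather than taken from this diff): calling the extractor without padding should still return per-sequence normalized arrays of the original lengths.

```python
# Hedged usage sketch, not part of the PR: normalization should work even
# when no padding is requested.
import numpy as np

from transformers import Wav2Vec2FeatureExtractor

extractor = Wav2Vec2FeatureExtractor(
    feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True
)
lengths = (800, 1000, 1200)
speech = [np.random.randn(n).astype(np.float32) for n in lengths]

out = extractor(speech, sampling_rate=16000)  # no padding requested
for arr, n in zip(out.input_values, lengths):
    assert len(arr) == n  # original lengths are preserved
    assert abs(float(np.mean(arr))) < 1e-3
    assert abs(float(np.var(arr)) - 1.0) < 1e-3
```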
> **Review comment:** The responsibility for retrieving the correct length from the attention mask should sit in this method, since `input_values` and `attention_mask` are the well-known inputs to functions in `transformers`.
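To make that point concrete, here is a tiny runnable sketch contrasting the two call shapes (hypothetical free functions mirroring the diff, not the actual methods):

```python
# Old shape: every caller had to derive lengths itself before calling.
# New shape: the function accepts the standard attention_mask and derives
# the lengths internally, so the bookkeeping lives next to the math.
from typing import List

import numpy as np


def norm_from_lengths(values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]:
    return [
        (x - np.mean(x[:i])) / np.sqrt(np.var(x[:i]) + 1e-5)
        for x, i in zip(values, input_lengths)
    ]


def norm_from_mask(values: List[np.ndarray], attention_mask: np.ndarray) -> List[np.ndarray]:
    # length derivation is now an internal detail
    return norm_from_lengths(values, attention_mask.sum(-1))


values = [np.ones(4) * 2.0, np.arange(4, dtype=np.float64)]
mask = np.array([[1, 1, 0, 0], [1, 1, 1, 1]])
assert np.allclose(norm_from_mask(values, mask)[1], norm_from_lengths(values, [2, 4])[1])
```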