
Commit

[Wav2Vec2] Fix normalization for non-padded tensors (#13512)
* finalize

* Apply suggestions from code review

* finish cleaner implementation

* more tests

* small fix

* finish

* up
patrickvonplaten committed Sep 10, 2021
1 parent d12bbe4 commit 60eb416
Showing 5 changed files with 145 additions and 60 deletions.
2 changes: 1 addition & 1 deletion src/transformers/feature_extraction_sequence_utils.py
@@ -341,7 +341,7 @@ def _truncate(

        return processed_features

-    def _get_padding_strategies(self, padding=False, max_length=None, pad_to_multiple_of=None, **kwargs):
+    def _get_padding_strategies(self, padding=False, max_length=None):
        """
        Find the correct padding strategy
        """
src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
@@ -93,10 +93,13 @@ def _extract_fbank_features(

    @staticmethod
    def utterance_cmvn(
-        x: np.ndarray, input_length: int, normalize_means: Optional[bool] = True, normalize_vars: Optional[bool] = True
+        x: np.ndarray,
+        input_length: int,
+        normalize_means: Optional[bool] = True,
+        normalize_vars: Optional[bool] = True,
+        padding_value: float = 0.0,
    ) -> np.ndarray:
-        # make sure we normalie float32 arrays

        mean = x[:input_length].mean(axis=0)
        square_sums = (x[:input_length] ** 2).sum(axis=0)

@@ -107,15 +110,21 @@ def utterance_cmvn(
            std = np.sqrt(np.maximum(var, 1e-10))
            x = np.divide(x, std)

+        if x.shape[0] > input_length:
+            x[input_length:] = padding_value
+
+        # make sure array is in float32
+        x = x.astype(np.float32)

        return x

-    def normalize(self, input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]:
+    def normalize(
+        self, input_features: List[np.ndarray], attention_mask: Optional[np.ndarray] = None
+    ) -> List[np.ndarray]:
+        lengths = attention_mask.sum(-1) if attention_mask is not None else [x.shape[0] for x in input_features]
        return [
-            self.utterance_cmvn(x, n, self.normalize_means, self.normalize_vars)
-            for x, n in zip(input_values, input_lengths)
+            self.utterance_cmvn(x, n, self.normalize_means, self.normalize_vars, self.padding_value)
+            for x, n in zip(input_features, lengths)
        ]

    def __call__(
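For illustration, a minimal standalone numpy sketch of the masked CMVN behavior implemented above; the helper name, array sizes, and tolerances are assumptions for the example, not library code:

import numpy as np


def masked_cmvn(x: np.ndarray, input_length: int, padding_value: float = 0.0) -> np.ndarray:
    # statistics are computed over the valid frames only
    mean = x[:input_length].mean(axis=0)
    var = x[:input_length].var(axis=0)
    x = (x - mean) / np.sqrt(np.maximum(var, 1e-10))
    # padded frames are reset so they keep the configured padding value
    if x.shape[0] > input_length:
        x[input_length:] = padding_value
    return x.astype(np.float32)


features = np.random.randn(10, 24)   # 10 frames, 24 filterbank bins
features[6:] = 0.0                   # the last 4 frames are padding
normed = masked_cmvn(features, input_length=6)
assert np.allclose(normed[6:], 0.0)      # padding stays at padding_value
assert abs(normed[:6].mean()) < 1e-4     # valid frames are zero-mean per bin

The diff's utterance_cmvn follows the same pattern while additionally honoring its normalize_means and normalize_vars flags.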
@@ -197,7 +206,6 @@ def __call__(
            and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
        )

-        # make sure input is in list format
        if is_batched and not isinstance(raw_speech[0], np.ndarray):
            raw_speech = [np.asarray(speech) for speech in raw_speech]
        elif not is_batched and not isinstance(raw_speech, np.ndarray):
@@ -225,21 +233,25 @@ def __call__(
            **kwargs,
        )

-        if "attention_mask" in padded_inputs:
-            input_lengths = padded_inputs["attention_mask"].sum(-1)
-        else:
-            padded_input_values = padded_inputs["input_features"]
-            input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])]
+        # make sure list is in array format
+        input_features = padded_inputs.get("input_features")
+        if isinstance(input_features[0], list):
+            padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
+
+        attention_mask = padded_inputs.get("attention_mask")
+        if attention_mask is not None:
+            padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask]

        # Utterance-level cepstral mean and variance normalization
        if self.do_ceptral_normalize:
-            input_features = padded_inputs["input_features"]
-
-            # make sure list is in array format
-            if isinstance(input_features[0], list):
-                input_features = [np.asarray(feature, dtype=np.float32) for feature in input_features]
-
-            padded_inputs["input_features"] = self.normalize(input_features, input_lengths=input_lengths)
+            attention_mask = (
+                np.array(attention_mask, dtype=np.bool)
+                if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
+                else None
+            )
+            padded_inputs["input_features"] = self.normalize(
+                padded_inputs["input_features"], attention_mask=attention_mask
+            )

        if return_tensors is not None:
            padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
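A hedged usage sketch of the updated call path, patterned on the tests later in this commit; the checkpoint name is an assumption, parameter names follow the public transformers API, and computing fbank features may additionally require torchaudio depending on the transformers version:

import numpy as np
from transformers import Speech2TextFeatureExtractor

# assumed checkpoint; any Speech2Text config with do_ceptral_normalize=True should behave the same way
extractor = Speech2TextFeatureExtractor.from_pretrained("facebook/s2t-small-librispeech-asr")

speech = [np.random.randn(n).astype(np.float32) for n in (800, 1000, 1200)]
inputs = extractor(
    speech, sampling_rate=16000, padding="longest", return_attention_mask=True, return_tensors="np"
)

# with padding applied, only the unpadded frames of each utterance are CMVN-normalized
lengths = inputs.attention_mask.sum(-1)
valid = inputs.input_features[0, : lengths[0]]
print(valid.mean(axis=0).max(), valid.var(axis=0).mean())   # ~0 and ~1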
60 changes: 39 additions & 21 deletions src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py
@@ -79,13 +79,25 @@ def __init__(
        self.do_normalize = do_normalize

    @staticmethod
-    def zero_mean_unit_var_norm(input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]:
+    def zero_mean_unit_var_norm(
+        input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0
+    ) -> List[np.ndarray]:
        """
        Every array in the list is normalized to have zero mean and unit variance
        """
-        normed_input_values = [
-            (x - np.mean(x[:i])) / np.sqrt(np.var(x[:i]) + 1e-5) for x, i in zip(input_values, input_lengths)
-        ]
+        if attention_mask is not None:
+            attention_mask = np.array(attention_mask, np.bool)
+            normed_input_values = []
+
+            for vector, length in zip(input_values, attention_mask.sum(-1)):
+                normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
+                if length > normed_slice.shape[0]:
+                    normed_slice[length:] = padding_value
+
+                normed_input_values.append(normed_slice)
+        else:
+            normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
+
        return normed_input_values

    def __call__(
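As a quick worked example of what the masked branch computes (numbers chosen purely for illustration; the padded tail is reset unconditionally here for clarity):

import numpy as np

x = np.array([1.0, 2.0, 3.0, 0.0, 0.0], dtype=np.float32)  # last two samples are padding
length = 3                                                  # per-utterance length from attention_mask.sum(-1)

normed = (x - x[:length].mean()) / np.sqrt(x[:length].var() + 1e-7)
normed[length:] = 0.0                                       # re-apply padding_value to the padded tail

# mean of the valid part is 2.0 and its variance 2/3, so the valid samples become
# roughly [-1.22, 0.0, 1.22] while the padded tail stays at 0.0
print(normed)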
@@ -172,14 +184,6 @@ def __call__(
            and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
        )

-        # make sure input is in list format
-        if is_batched and not isinstance(raw_speech[0], np.ndarray):
-            raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
-        elif not is_batched and not isinstance(raw_speech, np.ndarray):
-            raw_speech = np.asarray(raw_speech, dtype=np.float32)
-        elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.float64:
-            raw_speech = raw_speech.astype(np.float32)
-
        # always return batch
        if not is_batched:
            raw_speech = [raw_speech]
@@ -196,19 +200,33 @@ def __call__(
            return_attention_mask=return_attention_mask,
        )

-        if "attention_mask" in padded_inputs:
-            input_lengths = padded_inputs["attention_mask"].sum(-1)
-        else:
-            padded_input_values = padded_inputs["input_values"]
-            input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])]
-
-        if isinstance(padded_inputs["input_values"][0], np.ndarray):
-            padded_inputs["input_values"] = [x.astype(np.float32) for x in padded_inputs["input_values"]]
+        # convert input values to correct format
+        input_values = padded_inputs["input_values"]
+        if not isinstance(input_values[0], np.ndarray):
+            padded_inputs["input_values"] = [np.asarray(array, dtype=np.float32) for array in input_values]
+        elif (
+            not isinstance(input_values, np.ndarray)
+            and isinstance(input_values[0], np.ndarray)
+            and input_values[0].dtype is np.float64
+        ):
+            padded_inputs["input_values"] = [array.astype(np.float32) for array in input_values]
+        elif isinstance(input_values, np.ndarray) and input_values.dtype is np.float64:
+            padded_inputs["input_values"] = input_values.astype(np.float32)
+
+        # convert attention_mask to correct format
+        attention_mask = padded_inputs.get("attention_mask")
+        if attention_mask is not None:
+            padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.bool) for array in attention_mask]

        # zero-mean and unit-variance normalization
        if self.do_normalize:
+            attention_mask = (
+                attention_mask
+                if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
+                else None
+            )
            padded_inputs["input_values"] = self.zero_mean_unit_var_norm(
-                padded_inputs["input_values"], input_lengths=input_lengths
+                padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value
            )

        if return_tensors is not None:
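A usage sketch against the public Wav2Vec2FeatureExtractor API, mirroring the new tests below; the hand-built configuration values are assumptions rather than a released checkpoint, and parameter names follow the current transformers API:

import numpy as np
from transformers import Wav2Vec2FeatureExtractor

extractor = Wav2Vec2FeatureExtractor(
    feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True
)
speech = [np.random.randn(n).astype(np.float32) for n in (800, 1000, 1200)]

for padding, max_length in [("longest", None), ("max_length", 1600), ("do_not_pad", None)]:
    batch = extractor(speech, sampling_rate=16000, padding=padding, max_length=max_length)
    # whatever the padding strategy, each utterance should be zero-mean / unit-variance
    # over its own real samples
    for values, n in zip(batch.input_values, (800, 1000, 1200)):
        values = np.asarray(values)
        assert abs(float(values[:n].mean())) < 1e-3
        assert abs(float(values[:n].var()) - 1) < 1e-3

Under "do_not_pad" the gating above deliberately drops the attention mask, so every array is simply normalized over its full length.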
51 changes: 41 additions & 10 deletions tests/test_feature_extraction_speech_to_text.py
@@ -136,18 +136,49 @@ def test_call(self):
    def test_cepstral_mean_and_variance_normalization(self):
        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
-        inputs = feature_extractor(speech_inputs, padding=True, return_tensors="np", return_attention_mask=True)
-        input_features = inputs.input_features
-        attention_mask = inputs.attention_mask
-        fbank_feat_lengths = np.sum(attention_mask == 1, axis=1)

-        def _check_zero_mean_unit_variance(input_vector):
-            self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
-            self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < 1e-3))
+        paddings = ["longest", "max_length", "do_not_pad"]
+        max_lengths = [None, 16, None]
+        var_tolerances = [1e-3, 1e-3, 1e-1]
+        for max_length, padding, var_tol in zip(max_lengths, paddings, var_tolerances):

-        _check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
-        _check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]])
-        _check_zero_mean_unit_variance(input_features[2, : fbank_feat_lengths[2]])
+            inputs = feature_extractor(
+                speech_inputs, padding=padding, max_length=max_length, return_attention_mask=True
+            )
+            input_features = inputs.input_features
+            attention_mask = inputs.attention_mask
+            fbank_feat_lengths = [np.sum(x) for x in attention_mask]
+
+            def _check_zero_mean_unit_variance(input_vector, var_tol=1e-3):
+                self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
+                self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < var_tol))
+
+            _check_zero_mean_unit_variance(input_features[0][: fbank_feat_lengths[0]], var_tol)
+            _check_zero_mean_unit_variance(input_features[1][: fbank_feat_lengths[1]], var_tol)
+            _check_zero_mean_unit_variance(input_features[2][: fbank_feat_lengths[2]], var_tol)
+
+    def test_cepstral_mean_and_variance_normalization_np(self):
+        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
+
+        paddings = ["longest", "max_length", "do_not_pad"]
+        max_lengths = [None, 16, None]
+        var_tolerances = [1e-3, 1e-3, 1e-1]
+        for max_length, padding, var_tol in zip(max_lengths, paddings, var_tolerances):
+            inputs = feature_extractor(
+                speech_inputs, max_length=max_length, padding=padding, return_tensors="np", return_attention_mask=True
+            )
+            input_features = inputs.input_features
+            attention_mask = inputs.attention_mask
+            fbank_feat_lengths = [np.sum(x) for x in attention_mask]
+
+            def _check_zero_mean_unit_variance(input_vector, var_tol=1e-3):
+                self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
+                self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < var_tol))
+
+            _check_zero_mean_unit_variance(input_features[0][: fbank_feat_lengths[0]], var_tol)
+            _check_zero_mean_unit_variance(input_features[1][: fbank_feat_lengths[1]], var_tol)
+            _check_zero_mean_unit_variance(input_features[2][: fbank_feat_lengths[2]], var_tol)

    def test_cepstral_mean_and_variance_normalization_trunc(self):
        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
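The looped test cases above exercise the same gating idea as the extractors: an attention mask only matters when padding was actually applied. A standalone paraphrase of that check (hypothetical helper, not part of the library):

from typing import Optional

import numpy as np


def mask_for_normalization(attention_mask: Optional[np.ndarray], padding) -> Optional[np.ndarray]:
    # when no padding is requested, every frame is real, so per-utterance lengths
    # are unnecessary and normalization can simply run over each full array
    if attention_mask is None or padding in (False, "do_not_pad"):
        return None
    return np.asarray(attention_mask, dtype=bool)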
44 changes: 34 additions & 10 deletions tests/test_feature_extraction_wav2vec2.py
@@ -120,21 +120,45 @@ def test_call(self):
        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
            self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))

-    def test_zero_mean_unit_variance_normalization(self):
+    def test_zero_mean_unit_variance_normalization_np(self):
        feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
-        processed = feat_extract(speech_inputs, padding="longest", return_tensors="np")
-        input_values = processed.input_values

-        def _check_zero_mean_unit_variance(input_vector):
-            self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
-            self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
+        paddings = ["longest", "max_length", "do_not_pad"]
+        max_lengths = [None, 1600, None]
+        for max_length, padding in zip(max_lengths, paddings):
+            processed = feat_extract(speech_inputs, padding=padding, max_length=max_length, return_tensors="np")
+            input_values = processed.input_values

-        _check_zero_mean_unit_variance(input_values[0, :800])
-        _check_zero_mean_unit_variance(input_values[1, :1000])
-        _check_zero_mean_unit_variance(input_values[2])
+            def _check_zero_mean_unit_variance(input_vector):
+                self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
+                self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
+
+            _check_zero_mean_unit_variance(input_values[0][:800])
+            _check_zero_mean_unit_variance(input_values[1][:1000])
+            _check_zero_mean_unit_variance(input_values[2][:1200])
+
+    def test_zero_mean_unit_variance_normalization(self):
+        feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+        lengths = range(800, 1400, 200)
+        speech_inputs = [floats_list((1, x))[0] for x in lengths]
+
+        paddings = ["longest", "max_length", "do_not_pad"]
+        max_lengths = [None, 1600, None]
+
+        for max_length, padding in zip(max_lengths, paddings):
+            processed = feat_extract(speech_inputs, max_length=max_length, padding=padding)
+            input_values = processed.input_values
+
+            def _check_zero_mean_unit_variance(input_vector):
+                self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
+                self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
+
+            _check_zero_mean_unit_variance(input_values[0][:800])
+            _check_zero_mean_unit_variance(input_values[1][:1000])
+            _check_zero_mean_unit_variance(input_values[2][:1200])

-    def test_zero_mean_unit_variance_normalization_trunc(self):
+    def test_zero_mean_unit_variance_normalization_trunc_np(self):
        feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
        processed = feat_extract(
