diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index fc5aa11148359..4b58c2cb06e20 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -1,3 +1,27 @@ paddle.fluid.optimizer.PipelineOptimizer (paddle.fluid.optimizer.PipelineOptimizer, ('document', '2e55a29dbeb874934f7a1a1af3a22b8c')) paddle.fluid.optimizer.PipelineOptimizer.__init__ (ArgSpec(args=['self', 'optimizer', 'num_microbatches', 'start_cpu_core_id'], varargs=None, keywords=None, defaults=(1, 0)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.PipelineOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) +paddle.audio.features (ArgSpec(), ('document', 'd41d8cd98f00b204e9800998ecf8427e')) +paddle.audio.features.layers.LogMelSpectrogram (ArgSpec(), ('document', 'c38b53606aa89215c4f00d3833e158b8')) +paddle.audio.features.layers.LogMelSpectrogram.forward (ArgSpec(args=['self', 'x'], varargs=None, varkw=None, defaults=None, kwonlyargs=[], kwonlydefaults=None, annotations={'return': , 'x': }), ('document', '6c14f6f78dc697a6981cf90412e2f1ea')) +paddle.audio.features.layers.LogMelSpectrogram.load_dict (ArgSpec(args=[], varargs='args', varkw='kwargs', defaults=None, kwonlyargs=[], kwonlydefaults=None, annotations={}), ('document', '01221a60445ee437f439a8cbe293f759')) +paddle.audio.features.layers.LogMelSpectrogram.state_dict (ArgSpec(args=['self', 'destination', 'include_sublayers', 'structured_name_prefix', 'use_hook'], varargs=None, varkw=None, defaults=(None, True, '', True), kwonlyargs=[], kwonlydefaults=None, annotations={}), ('document', '0c01cb0c12220c9426ae49549b145b0b')) +paddle.audio.features.layers.MFCC (ArgSpec(), ('document', 'bcbe6499830d9228a4f746ddd63b6c0f')) +paddle.audio.features.layers.MFCC.forward (ArgSpec(args=['self', 'x'], varargs=None, varkw=None, defaults=None, kwonlyargs=[], kwonlydefaults=None, annotations={'return': , 'x': }), ('document', 'd86bcaa345f26851089bfdb3efecd9e7')) +paddle.audio.features.layers.MelSpectrogram (ArgSpec(), ('document', 'adf4012310984568ae9da6170aa89f91')) +paddle.audio.features.layers.MelSpectrogram.forward (ArgSpec(args=['self', 'x'], varargs=None, varkw=None, defaults=None, kwonlyargs=[], kwonlydefaults=None, annotations={'return': , 'x': }), ('document', '458e9d454c8773091567c6b400f48cf5')) +paddle.audio.features.layers.Spectrogram (ArgSpec(), ('document', '83811af6da032099bf147e3e01a458e1')) +paddle.audio.features.layers.Spectrogram.forward (ArgSpec(args=['self', 'x'], varargs=None, varkw=None, defaults=None, kwonlyargs=[], kwonlydefaults=None, annotations={'return': , 'x': }), ('document', 'ab11e318fca1410f743b5432394dea35')) +paddle.audio.functional (ArgSpec(), ('document', 'd41d8cd98f00b204e9800998ecf8427e')) +paddle.audio.functional.functional.compute_fbank_matrix (ArgSpec(args=['sr', 'n_fft', 'n_mels', 'f_min', 'f_max', 'htk', 'norm', 'dtype'], varargs=None, varkw=None, defaults=(64, 0.0, None, False, 'slaney', 'float32'), kwonlyargs=[], kwonlydefaults=None, annotations={'return': , 'sr': , 'n_fft': , 'n_mels': , 'f_min': , 'f_max': typing.Union[float, NoneType], 'htk': , 'norm': typing.Union[str, float], 'dtype': }), ('document', '3c5411caa6baedb68860b09c81e0147c')) +paddle.audio.functional.functional.create_dct (ArgSpec(args=['n_mfcc', 'n_mels', 'norm', 'dtype'], varargs=None, varkw=None, defaults=('ortho', 'float32'), kwonlyargs=[], kwonlydefaults=None, annotations={'return': , 'n_mfcc': , 'n_mels': , 'norm': typing.Union[str, NoneType], 'dtype': }), ('document', 'c9c57550671f9725b053769411d2f65a')) +paddle.audio.functional.functional.fft_frequencies (ArgSpec(args=['sr', 'n_fft', 'dtype'], varargs=None, varkw=None, defaults=('float32',), kwonlyargs=[], kwonlydefaults=None, annotations={'return': , 'sr': , 'n_fft': , 'dtype': }), ('document', '057b990e79c9c780622407267c0a43c6')) +paddle.audio.functional.functional.hz_to_mel (ArgSpec(args=['freq', 'htk'], varargs=None, varkw=None, defaults=(False,), kwonlyargs=[], kwonlydefaults=None, annotations={'return': typing.Union[paddle.Tensor, float], 'freq': typing.Union[paddle.Tensor, float], 'htk': }), ('document', '7ca01521dd0bf26cd3f72c67f7168dc4')) +paddle.audio.functional.functional.mel_frequencies (ArgSpec(args=['n_mels', 'f_min', 'f_max', 'htk', 'dtype'], varargs=None, varkw=None, defaults=(64, 0.0, 11025.0, False, 'float32'), kwonlyargs=[], kwonlydefaults=None, annotations={'return': , 'n_mels': , 'f_min': , 'f_max': , 'htk': , 'dtype': }), ('document', '2af3cf997ed1274214ec240b2b59a98d')) +paddle.audio.functional.functional.mel_to_hz (ArgSpec(args=['mel', 'htk'], varargs=None, varkw=None, defaults=(False,), kwonlyargs=[], kwonlydefaults=None, annotations={'return': typing.Union[float, paddle.Tensor], 'mel': typing.Union[float, paddle.Tensor], 'htk': }), ('document', 'e93b432d382f98c60d7c7599489e7072')) +paddle.audio.functional.functional.power_to_db (ArgSpec(args=['spect', 'ref_value', 'amin', 'top_db'], varargs=None, varkw=None, defaults=(1.0, 1e-10, 80.0), kwonlyargs=[], kwonlydefaults=None, annotations={'return': , 'spect': , 'ref_value': , 'amin': , 'top_db': typing.Union[float, NoneType]}), ('document', '28bbb1973e8399e856bfaea0415cecb9')) +paddle.audio.functional.window.get_window (ArgSpec(args=['window', 'win_length', 'fftbins', 'dtype'], varargs=None, varkw=None, defaults=(True, 'float64'), kwonlyargs=[], kwonlydefaults=None, annotations={'return': , 'window': typing.Union[str, typing.Tuple[str, float]], 'win_length': , 'fftbins': , 'dtype': }), ('document', '2418d63da10c0cd5da9ecf0a88ddf783')) +paddle.audio.utils (ArgSpec(), ('document', 'd41d8cd98f00b204e9800998ecf8427e')) +paddle.audio.utils.error.ParameterError (ArgSpec(), ('document', 'e12783df4d137af121ebadceb389bf7a')) +paddle.audio.utils.error.ParameterError.args (ArgSpec(), ('document', 'd41d8cd98f00b204e9800998ecf8427e')) +paddle.audio.utils.error.ParameterError.with_traceback (ArgSpec(), ('document', '3f2d1353ad5034ed0f4628f2c9f066cc')) diff --git a/python/paddle/audio/__init__.py b/python/paddle/audio/__init__.py new file mode 100644 index 0000000000000..e76a80300f5e6 --- /dev/null +++ b/python/paddle/audio/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import features +from . import functional +from . import utils + +__all__ = ["functional", "features", "utils"] diff --git a/python/paddle/audio/features/__init__.py b/python/paddle/audio/features/__init__.py new file mode 100644 index 0000000000000..e6b005e501988 --- /dev/null +++ b/python/paddle/audio/features/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .layers import LogMelSpectrogram +from .layers import MelSpectrogram +from .layers import MFCC +from .layers import Spectrogram diff --git a/python/paddle/audio/features/layers.py b/python/paddle/audio/features/layers.py new file mode 100644 index 0000000000000..cddb42635d6b0 --- /dev/null +++ b/python/paddle/audio/features/layers.py @@ -0,0 +1,323 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +from typing import Optional +from typing import Union + +import paddle +import paddle.nn as nn +from paddle import Tensor + +from ..functional import compute_fbank_matrix +from ..functional import create_dct +from ..functional import power_to_db +from ..functional.window import get_window + +__all__ = [ + 'Spectrogram', + 'MelSpectrogram', + 'LogMelSpectrogram', + 'MFCC', +] + + +class Spectrogram(nn.Layer): + """Compute spectrogram of given signals, typically audio waveforms. + The spectorgram is defined as the complex norm of the short-time Fourier transformation. + + Args: + n_fft (int, optional): The number of frequency components of the discrete Fourier transform. Defaults to 512. + hop_length (Optional[int], optional): The hop length of the short time FFT. If `None`, it is set to `win_length//4`. Defaults to None. + win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None. + window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. + power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True. + pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'. + dtype (str, optional): Data type of input and window. Defaults to 'float32'. + """ + + def __init__(self, + n_fft: int = 512, + hop_length: Optional[int] = 512, + win_length: Optional[int] = None, + window: str = 'hann', + power: float = 1.0, + center: bool = True, + pad_mode: str = 'reflect', + dtype: str = 'float32') -> None: + super(Spectrogram, self).__init__() + + assert power > 0, 'Power of spectrogram must be > 0.' + self.power = power + + if win_length is None: + win_length = n_fft + + self.fft_window = get_window(window, + win_length, + fftbins=True, + dtype=dtype) + self._stft = partial(paddle.signal.stft, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=self.fft_window, + center=center, + pad_mode=pad_mode) + self.register_buffer('fft_window', self.fft_window) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): Tensor of waveforms with shape `(N, T)` + + Returns: + Tensor: Spectrograms with shape `(N, n_fft//2 + 1, num_frames)`. + """ + stft = self._stft(x) + spectrogram = paddle.pow(paddle.abs(stft), self.power) + return spectrogram + + +class MelSpectrogram(nn.Layer): + """Compute the melspectrogram of given signals, typically audio waveforms. It is computed by multiplying spectrogram with Mel filter bank matrix. + + Args: + sr (int, optional): Sample rate. Defaults to 22050. + n_fft (int, optional): The number of frequency components of the discrete Fourier transform. Defaults to 512. + hop_length (Optional[int], optional): The hop length of the short time FFT. If `None`, it is set to `win_length//4`. Defaults to None. + win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None. + window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. + power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True. + pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'. + n_mels (int, optional): Number of mel bins. Defaults to 64. + f_min (float, optional): Minimum frequency in Hz. Defaults to 50.0. + f_max (Optional[float], optional): Maximum frequency in Hz. Defaults to None. + htk (bool, optional): Use HTK formula in computing fbank matrix. Defaults to False. + norm (Union[str, float], optional): Type of normalization in computing fbank matrix. Slaney-style is used by default. You can specify norm=1.0/2.0 to use customized p-norm normalization. Defaults to 'slaney'. + dtype (str, optional): Data type of input and window. Defaults to 'float32'. + """ + + def __init__(self, + sr: int = 22050, + n_fft: int = 2048, + hop_length: Optional[int] = 512, + win_length: Optional[int] = None, + window: str = 'hann', + power: float = 2.0, + center: bool = True, + pad_mode: str = 'reflect', + n_mels: int = 64, + f_min: float = 50.0, + f_max: Optional[float] = None, + htk: bool = False, + norm: Union[str, float] = 'slaney', + dtype: str = 'float32') -> None: + super(MelSpectrogram, self).__init__() + + self._spectrogram = Spectrogram(n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + power=power, + center=center, + pad_mode=pad_mode, + dtype=dtype) + self.n_mels = n_mels + self.f_min = f_min + self.f_max = f_max + self.htk = htk + self.norm = norm + if f_max is None: + f_max = sr // 2 + self.fbank_matrix = compute_fbank_matrix(sr=sr, + n_fft=n_fft, + n_mels=n_mels, + f_min=f_min, + f_max=f_max, + htk=htk, + norm=norm, + dtype=dtype) + self.register_buffer('fbank_matrix', self.fbank_matrix) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): Tensor of waveforms with shape `(N, T)` + + Returns: + Tensor: Mel spectrograms with shape `(N, n_mels, num_frames)`. + """ + spect_feature = self._spectrogram(x) + mel_feature = paddle.matmul(self.fbank_matrix, spect_feature) + return mel_feature + + +class LogMelSpectrogram(nn.Layer): + """Compute log-mel-spectrogram feature of given signals, typically audio waveforms. + + Args: + sr (int, optional): Sample rate. Defaults to 22050. + n_fft (int, optional): The number of frequency components of the discrete Fourier transform. Defaults to 512. + hop_length (Optional[int], optional): The hop length of the short time FFT. If `None`, it is set to `win_length//4`. Defaults to None. + win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None. + window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. + power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True. + pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'. + n_mels (int, optional): Number of mel bins. Defaults to 64. + f_min (float, optional): Minimum frequency in Hz. Defaults to 50.0. + f_max (Optional[float], optional): Maximum frequency in Hz. Defaults to None. + htk (bool, optional): Use HTK formula in computing fbank matrix. Defaults to False. + norm (Union[str, float], optional): Type of normalization in computing fbank matrix. Slaney-style is used by default. You can specify norm=1.0/2.0 to use customized p-norm normalization. Defaults to 'slaney'. + ref_value (float, optional): The reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. Defaults to 1.0. + amin (float, optional): The minimum value of input magnitude. Defaults to 1e-10. + top_db (Optional[float], optional): The maximum db value of spectrogram. Defaults to None. + dtype (str, optional): Data type of input and window. Defaults to 'float32'. + """ + + def __init__(self, + sr: int = 22050, + n_fft: int = 512, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: str = 'hann', + power: float = 2.0, + center: bool = True, + pad_mode: str = 'reflect', + n_mels: int = 64, + f_min: float = 50.0, + f_max: Optional[float] = None, + htk: bool = False, + norm: Union[str, float] = 'slaney', + ref_value: float = 1.0, + amin: float = 1e-10, + top_db: Optional[float] = None, + dtype: str = 'float32') -> None: + super(LogMelSpectrogram, self).__init__() + + self._melspectrogram = MelSpectrogram(sr=sr, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + power=power, + center=center, + pad_mode=pad_mode, + n_mels=n_mels, + f_min=f_min, + f_max=f_max, + htk=htk, + norm=norm, + dtype=dtype) + + self.ref_value = ref_value + self.amin = amin + self.top_db = top_db + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): Tensor of waveforms with shape `(N, T)` + + Returns: + Tensor: Log mel spectrograms with shape `(N, n_mels, num_frames)`. + """ + mel_feature = self._melspectrogram(x) + log_mel_feature = power_to_db(mel_feature, + ref_value=self.ref_value, + amin=self.amin, + top_db=self.top_db) + return log_mel_feature + + +class MFCC(nn.Layer): + """Compute mel frequency cepstral coefficients(MFCCs) feature of given waveforms. + + Args: + sr (int, optional): Sample rate. Defaults to 22050. + n_mfcc (int, optional): [description]. Defaults to 40. + n_fft (int, optional): The number of frequency components of the discrete Fourier transform. Defaults to 512. + hop_length (Optional[int], optional): The hop length of the short time FFT. If `None`, it is set to `win_length//4`. Defaults to None. + win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None. + window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. + power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True. + pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'. + n_mels (int, optional): Number of mel bins. Defaults to 64. + f_min (float, optional): Minimum frequency in Hz. Defaults to 50.0. + f_max (Optional[float], optional): Maximum frequency in Hz. Defaults to None. + htk (bool, optional): Use HTK formula in computing fbank matrix. Defaults to False. + norm (Union[str, float], optional): Type of normalization in computing fbank matrix. Slaney-style is used by default. You can specify norm=1.0/2.0 to use customized p-norm normalization. Defaults to 'slaney'. + ref_value (float, optional): The reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. Defaults to 1.0. + amin (float, optional): The minimum value of input magnitude. Defaults to 1e-10. + top_db (Optional[float], optional): The maximum db value of spectrogram. Defaults to None. + dtype (str, optional): Data type of input and window. Defaults to 'float32'. + """ + + def __init__(self, + sr: int = 22050, + n_mfcc: int = 40, + n_fft: int = 512, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: str = 'hann', + power: float = 2.0, + center: bool = True, + pad_mode: str = 'reflect', + n_mels: int = 64, + f_min: float = 50.0, + f_max: Optional[float] = None, + htk: bool = False, + norm: Union[str, float] = 'slaney', + ref_value: float = 1.0, + amin: float = 1e-10, + top_db: Optional[float] = None, + dtype: str = 'float32') -> None: + super(MFCC, self).__init__() + assert n_mfcc <= n_mels, 'n_mfcc cannot be larger than n_mels: %d vs %d' % ( + n_mfcc, n_mels) + self._log_melspectrogram = LogMelSpectrogram(sr=sr, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + power=power, + center=center, + pad_mode=pad_mode, + n_mels=n_mels, + f_min=f_min, + f_max=f_max, + htk=htk, + norm=norm, + ref_value=ref_value, + amin=amin, + top_db=top_db, + dtype=dtype) + self.dct_matrix = create_dct(n_mfcc=n_mfcc, n_mels=n_mels, dtype=dtype) + self.register_buffer('dct_matrix', self.dct_matrix) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): Tensor of waveforms with shape `(N, T)` + + Returns: + Tensor: Mel frequency cepstral coefficients with shape `(N, n_mfcc, num_frames)`. + """ + log_mel_feature = self._log_melspectrogram(x) + mfcc = paddle.matmul(log_mel_feature.transpose( + (0, 2, 1)), self.dct_matrix).transpose((0, 2, 1)) # (B, n_mels, L) + return mfcc diff --git a/python/paddle/audio/functional/__init__.py b/python/paddle/audio/functional/__init__.py new file mode 100644 index 0000000000000..0216172db1400 --- /dev/null +++ b/python/paddle/audio/functional/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .functional import compute_fbank_matrix +from .functional import create_dct +from .functional import fft_frequencies +from .functional import hz_to_mel +from .functional import mel_frequencies +from .functional import mel_to_hz +from .functional import power_to_db +from .window import get_window diff --git a/python/paddle/audio/functional/functional.py b/python/paddle/audio/functional/functional.py new file mode 100644 index 0000000000000..26c095a6e9ae2 --- /dev/null +++ b/python/paddle/audio/functional/functional.py @@ -0,0 +1,268 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from librosa(https://github.com/librosa/librosa) +import math +from typing import Optional +from typing import Union + +import paddle +from paddle import Tensor + +__all__ = [ + 'hz_to_mel', + 'mel_to_hz', + 'mel_frequencies', + 'fft_frequencies', + 'compute_fbank_matrix', + 'power_to_db', + 'create_dct', +] + + +def hz_to_mel(freq: Union[Tensor, float], + htk: bool = False) -> Union[Tensor, float]: + """Convert Hz to Mels. + + Args: + freq (Union[Tensor, float]): The input tensor with arbitrary shape. + htk (bool, optional): Use htk scaling. Defaults to False. + + Returns: + Union[Tensor, float]: Frequency in mels. + """ + + if htk: + if isinstance(freq, Tensor): + return 2595.0 * paddle.log10(1.0 + freq / 700.0) + else: + return 2595.0 * math.log10(1.0 + freq / 700.0) + + # Fill in the linear part + f_min = 0.0 + f_sp = 200.0 / 3 + + mels = (freq - f_min) / f_sp + + # Fill in the log-scale part + + min_log_hz = 1000.0 # beginning of log region (Hz) + min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) + logstep = math.log(6.4) / 27.0 # step size for log region + + if isinstance(freq, Tensor): + target = min_log_mel + paddle.log( + freq / min_log_hz + 1e-10) / logstep # prevent nan with 1e-10 + mask = (freq > min_log_hz).astype(freq.dtype) + mels = target * mask + mels * ( + 1 - mask) # will replace by masked_fill OP in future + else: + if freq >= min_log_hz: + mels = min_log_mel + math.log(freq / min_log_hz + 1e-10) / logstep + + return mels + + +def mel_to_hz(mel: Union[float, Tensor], + htk: bool = False) -> Union[float, Tensor]: + """Convert mel bin numbers to frequencies. + + Args: + mel (Union[float, Tensor]): The mel frequency represented as a tensor with arbitrary shape. + htk (bool, optional): Use htk scaling. Defaults to False. + + Returns: + Union[float, Tensor]: Frequencies in Hz. + """ + if htk: + return 700.0 * (10.0**(mel / 2595.0) - 1.0) + + f_min = 0.0 + f_sp = 200.0 / 3 + freqs = f_min + f_sp * mel + # And now the nonlinear scale + min_log_hz = 1000.0 # beginning of log region (Hz) + min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) + logstep = math.log(6.4) / 27.0 # step size for log region + if isinstance(mel, Tensor): + target = min_log_hz * paddle.exp(logstep * (mel - min_log_mel)) + mask = (mel > min_log_mel).astype(mel.dtype) + freqs = target * mask + freqs * ( + 1 - mask) # will replace by masked_fill OP in future + else: + if mel >= min_log_mel: + freqs = min_log_hz * math.exp(logstep * (mel - min_log_mel)) + return freqs + + +def mel_frequencies(n_mels: int = 64, + f_min: float = 0.0, + f_max: float = 11025.0, + htk: bool = False, + dtype: str = 'float32') -> Tensor: + """Compute mel frequencies. + + Args: + n_mels (int, optional): Number of mel bins. Defaults to 64. + f_min (float, optional): Minimum frequency in Hz. Defaults to 0.0. + fmax (float, optional): Maximum frequency in Hz. Defaults to 11025.0. + htk (bool, optional): Use htk scaling. Defaults to False. + dtype (str, optional): The data type of the return frequencies. Defaults to 'float32'. + + Returns: + Tensor: Tensor of n_mels frequencies in Hz with shape `(n_mels,)`. + """ + # 'Center freqs' of mel bands - uniformly spaced between limits + min_mel = hz_to_mel(f_min, htk=htk) + max_mel = hz_to_mel(f_max, htk=htk) + mels = paddle.linspace(min_mel, max_mel, n_mels, dtype=dtype) + freqs = mel_to_hz(mels, htk=htk) + return freqs + + +def fft_frequencies(sr: int, n_fft: int, dtype: str = 'float32') -> Tensor: + """Compute fourier frequencies. + + Args: + sr (int): Sample rate. + n_fft (int): Number of fft bins. + dtype (str, optional): The data type of the return frequencies. Defaults to 'float32'. + + Returns: + Tensor: FFT frequencies in Hz with shape `(n_fft//2 + 1,)`. + """ + return paddle.linspace(0, float(sr) / 2, int(1 + n_fft // 2), dtype=dtype) + + +def compute_fbank_matrix(sr: int, + n_fft: int, + n_mels: int = 64, + f_min: float = 0.0, + f_max: Optional[float] = None, + htk: bool = False, + norm: Union[str, float] = 'slaney', + dtype: str = 'float32') -> Tensor: + """Compute fbank matrix. + + Args: + sr (int): Sample rate. + n_fft (int): Number of fft bins. + n_mels (int, optional): Number of mel bins. Defaults to 64. + f_min (float, optional): Minimum frequency in Hz. Defaults to 0.0. + f_max (Optional[float], optional): Maximum frequency in Hz. Defaults to None. + htk (bool, optional): Use htk scaling. Defaults to False. + norm (Union[str, float], optional): Type of normalization. Defaults to 'slaney'. + dtype (str, optional): The data type of the return matrix. Defaults to 'float32'. + + Returns: + Tensor: Mel transform matrix with shape `(n_mels, n_fft//2 + 1)`. + """ + + if f_max is None: + f_max = float(sr) / 2 + + # Initialize the weights + weights = paddle.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype) + + # Center freqs of each FFT bin + fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft, dtype=dtype) + + # 'Center freqs' of mel bands - uniformly spaced between limits + mel_f = mel_frequencies(n_mels + 2, + f_min=f_min, + f_max=f_max, + htk=htk, + dtype=dtype) + + fdiff = mel_f[1:] - mel_f[:-1] #np.diff(mel_f) + ramps = mel_f.unsqueeze(1) - fftfreqs.unsqueeze(0) + #ramps = np.subtract.outer(mel_f, fftfreqs) + + for i in range(n_mels): + # lower and upper slopes for all bins + lower = -ramps[i] / fdiff[i] + upper = ramps[i + 2] / fdiff[i + 1] + + # .. then intersect them with each other and zero + weights[i] = paddle.maximum(paddle.zeros_like(lower), + paddle.minimum(lower, upper)) + + # Slaney-style mel is scaled to be approx constant energy per channel + if norm == 'slaney': + enorm = 2.0 / (mel_f[2:n_mels + 2] - mel_f[:n_mels]) + weights *= enorm.unsqueeze(1) + elif isinstance(norm, int) or isinstance(norm, float): + weights = paddle.nn.functional.normalize(weights, p=norm, axis=-1) + + return weights + + +def power_to_db(spect: Tensor, + ref_value: float = 1.0, + amin: float = 1e-10, + top_db: Optional[float] = 80.0) -> Tensor: + """Convert a power spectrogram (amplitude squared) to decibel (dB) units. The function computes the scaling `10 * log10(x / ref)` in a numerically stable way. + + Args: + spect (Tensor): STFT power spectrogram. + ref_value (float, optional): The reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. Defaults to 1.0. + amin (float, optional): Minimum threshold. Defaults to 1e-10. + top_db (Optional[float], optional): Threshold the output at `top_db` below the peak. Defaults to None. + + Returns: + Tensor: Power spectrogram in db scale. + """ + if amin <= 0: + raise Exception("amin must be strictly positive") + + if ref_value <= 0: + raise Exception("ref_value must be strictly positive") + + ones = paddle.ones_like(spect) + log_spec = 10.0 * paddle.log10(paddle.maximum(ones * amin, spect)) + log_spec -= 10.0 * math.log10(max(ref_value, amin)) + + if top_db is not None: + if top_db < 0: + raise Exception("top_db must be non-negative") + log_spec = paddle.maximum(log_spec, ones * (log_spec.max() - top_db)) + + return log_spec + + +def create_dct(n_mfcc: int, + n_mels: int, + norm: Optional[str] = 'ortho', + dtype: str = 'float32') -> Tensor: + """Create a discrete cosine transform(DCT) matrix. + + Args: + n_mfcc (int): Number of mel frequency cepstral coefficients. + n_mels (int): Number of mel filterbanks. + norm (Optional[str], optional): Normalizaiton type. Defaults to 'ortho'. + dtype (str, optional): The data type of the return matrix. Defaults to 'float32'. + + Returns: + Tensor: The DCT matrix with shape `(n_mels, n_mfcc)`. + """ + n = paddle.arange(n_mels, dtype=dtype) + k = paddle.arange(n_mfcc, dtype=dtype).unsqueeze(1) + dct = paddle.cos(math.pi / float(n_mels) * (n + 0.5) * + k) # size (n_mfcc, n_mels) + if norm is None: + dct *= 2.0 + else: + assert norm == "ortho" + dct[0] *= 1.0 / math.sqrt(2.0) + dct *= math.sqrt(2.0 / float(n_mels)) + return dct.T diff --git a/python/paddle/audio/functional/window.py b/python/paddle/audio/functional/window.py new file mode 100644 index 0000000000000..a4692dbc962df --- /dev/null +++ b/python/paddle/audio/functional/window.py @@ -0,0 +1,351 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +import math +from typing import List +from typing import Tuple +from typing import Union + +import paddle +from paddle import Tensor + +__all__ = [ + 'get_window', +] + + +def _cat(x: List[Tensor], data_type: str) -> Tensor: + l = [paddle.to_tensor(_, data_type) for _ in x] + return paddle.concat(l) + + +def _acosh(x: Union[Tensor, float]) -> Tensor: + if isinstance(x, float): + return math.log(x + math.sqrt(x**2 - 1)) + return paddle.log(x + paddle.sqrt(paddle.square(x) - 1)) + + +def _extend(M: int, sym: bool) -> bool: + """Extend window by 1 sample if needed for DFT-even symmetry. """ + if not sym: + return M + 1, True + else: + return M, False + + +def _len_guards(M: int) -> bool: + """Handle small or incorrect window lengths. """ + if int(M) != M or M < 0: + raise ValueError('Window length M must be a non-negative integer') + + return M <= 1 + + +def _truncate(w: Tensor, needed: bool) -> Tensor: + """Truncate window by 1 sample if needed for DFT-even symmetry. """ + if needed: + return w[:-1] + else: + return w + + +def _general_gaussian(M: int, + p, + sig, + sym: bool = True, + dtype: str = 'float64') -> Tensor: + """Compute a window with a generalized Gaussian shape. + This function is consistent with scipy.signal.windows.general_gaussian(). + """ + if _len_guards(M): + return paddle.ones((M, ), dtype=dtype) + M, needs_trunc = _extend(M, sym) + + n = paddle.arange(0, M, dtype=dtype) - (M - 1.0) / 2.0 + w = paddle.exp(-0.5 * paddle.abs(n / sig)**(2 * p)) + + return _truncate(w, needs_trunc) + + +def _general_cosine(M: int, + a: float, + sym: bool = True, + dtype: str = 'float64') -> Tensor: + """Compute a generic weighted sum of cosine terms window. + This function is consistent with scipy.signal.windows.general_cosine(). + """ + if _len_guards(M): + return paddle.ones((M, ), dtype=dtype) + M, needs_trunc = _extend(M, sym) + fac = paddle.linspace(-math.pi, math.pi, M, dtype=dtype) + w = paddle.zeros((M, ), dtype=dtype) + for k in range(len(a)): + w += a[k] * paddle.cos(k * fac) + return _truncate(w, needs_trunc) + + +def _general_hamming(M: int, + alpha: float, + sym: bool = True, + dtype: str = 'float64') -> Tensor: + """Compute a generalized Hamming window. + This function is consistent with scipy.signal.windows.general_hamming() + """ + return _general_cosine(M, [alpha, 1. - alpha], sym, dtype=dtype) + + +def _taylor(M: int, + nbar=4, + sll=30, + norm=True, + sym: bool = True, + dtype: str = 'float64') -> Tensor: + """Compute a Taylor window. + The Taylor window taper function approximates the Dolph-Chebyshev window's + constant sidelobe level for a parameterized number of near-in sidelobes. + """ + if _len_guards(M): + return paddle.ones((M, ), dtype=dtype) + M, needs_trunc = _extend(M, sym) + # Original text uses a negative sidelobe level parameter and then negates + # it in the calculation of B. To keep consistent with other methods we + # assume the sidelobe level parameter to be positive. + B = 10**(sll / 20) + A = _acosh(B) / math.pi + s2 = nbar**2 / (A**2 + (nbar - 0.5)**2) + ma = paddle.arange(1, nbar, dtype=dtype) + + Fm = paddle.empty((nbar - 1, ), dtype=dtype) + signs = paddle.empty_like(ma) + signs[::2] = 1 + signs[1::2] = -1 + m2 = ma * ma + for mi in range(len(ma)): + numer = signs[mi] * paddle.prod(1 - m2[mi] / s2 / (A**2 + + (ma - 0.5)**2)) + if mi == 0: + denom = 2 * paddle.prod(1 - m2[mi] / m2[mi + 1:]) + elif mi == len(ma) - 1: + denom = 2 * paddle.prod(1 - m2[mi] / m2[:mi]) + else: + denom = 2 * paddle.prod(1 - m2[mi] / m2[:mi]) * paddle.prod( + 1 - m2[mi] / m2[mi + 1:]) + + Fm[mi] = numer / denom + + def W(n): + return 1 + 2 * paddle.matmul( + Fm.unsqueeze(0), + paddle.cos(2 * math.pi * ma.unsqueeze(1) * (n - M / 2. + 0.5) / M)) + + w = W(paddle.arange(0, M, dtype=dtype)) + + # normalize (Note that this is not described in the original text [1]) + if norm: + scale = 1.0 / W((M - 1) / 2) + w *= scale + w = w.squeeze() + return _truncate(w, needs_trunc) + + +def _hamming(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: + """Compute a Hamming window. + The Hamming window is a taper formed by using a raised cosine with + non-zero endpoints, optimized to minimize the nearest side lobe. + """ + return _general_hamming(M, 0.54, sym, dtype=dtype) + + +def _hann(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: + """Compute a Hann window. + The Hann window is a taper formed by using a raised cosine or sine-squared + with ends that touch zero. + """ + return _general_hamming(M, 0.5, sym, dtype=dtype) + + +def _tukey(M: int, + alpha=0.5, + sym: bool = True, + dtype: str = 'float64') -> Tensor: + """Compute a Tukey window. + The Tukey window is also known as a tapered cosine window. + """ + if _len_guards(M): + return paddle.ones((M, ), dtype=dtype) + + if alpha <= 0: + return paddle.ones((M, ), dtype=dtype) + elif alpha >= 1.0: + return hann(M, sym=sym) + + M, needs_trunc = _extend(M, sym) + + n = paddle.arange(0, M, dtype=dtype) + width = int(alpha * (M - 1) / 2.0) + n1 = n[0:width + 1] + n2 = n[width + 1:M - width - 1] + n3 = n[M - width - 1:] + + w1 = 0.5 * (1 + paddle.cos(math.pi * (-1 + 2.0 * n1 / alpha / (M - 1)))) + w2 = paddle.ones(n2.shape, dtype=dtype) + w3 = 0.5 * (1 + paddle.cos(math.pi * (-2.0 / alpha + 1 + 2.0 * n3 / alpha / + (M - 1)))) + w = paddle.concat([w1, w2, w3]) + + return _truncate(w, needs_trunc) + + +def _kaiser(M: int, + beta: float, + sym: bool = True, + dtype: str = 'float64') -> Tensor: + """Compute a Kaiser window. + The Kaiser window is a taper formed by using a Bessel function. + """ + raise NotImplementedError() + + +def _gaussian(M: int, + std: float, + sym: bool = True, + dtype: str = 'float64') -> Tensor: + """Compute a Gaussian window. + The Gaussian widows has a Gaussian shape defined by the standard deviation(std). + """ + if _len_guards(M): + return paddle.ones((M, ), dtype=dtype) + M, needs_trunc = _extend(M, sym) + + n = paddle.arange(0, M, dtype=dtype) - (M - 1.0) / 2.0 + sig2 = 2 * std * std + w = paddle.exp(-n**2 / sig2) + + return _truncate(w, needs_trunc) + + +def _exponential(M: int, + center=None, + tau=1., + sym: bool = True, + dtype: str = 'float64') -> Tensor: + """Compute an exponential (or Poisson) window. """ + if sym and center is not None: + raise ValueError("If sym==True, center must be None.") + if _len_guards(M): + return paddle.ones((M, ), dtype=dtype) + M, needs_trunc = _extend(M, sym) + + if center is None: + center = (M - 1) / 2 + + n = paddle.arange(0, M, dtype=dtype) + w = paddle.exp(-paddle.abs(n - center) / tau) + + return _truncate(w, needs_trunc) + + +def _triang(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: + """Compute a triangular window. + """ + if _len_guards(M): + return paddle.ones((M, ), dtype=dtype) + M, needs_trunc = _extend(M, sym) + + n = paddle.arange(1, (M + 1) // 2 + 1, dtype=dtype) + if M % 2 == 0: + w = (2 * n - 1.0) / M + w = paddle.concat([w, w[::-1]]) + else: + w = 2 * n / (M + 1.0) + w = paddle.concat([w, w[-2::-1]]) + + return _truncate(w, needs_trunc) + + +def _bohman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: + """Compute a Bohman window. + The Bohman window is the autocorrelation of a cosine window. + """ + if _len_guards(M): + return paddle.ones((M, ), dtype=dtype) + M, needs_trunc = _extend(M, sym) + + fac = paddle.abs(paddle.linspace(-1, 1, M, dtype=dtype)[1:-1]) + w = (1 - fac) * paddle.cos(math.pi * fac) + 1.0 / math.pi * paddle.sin( + math.pi * fac) + w = _cat([0, w, 0], dtype) + + return _truncate(w, needs_trunc) + + +def _blackman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: + """Compute a Blackman window. + The Blackman window is a taper formed by using the first three terms of + a summation of cosines. It was designed to have close to the minimal + leakage possible. It is close to optimal, only slightly worse than a + Kaiser window. + """ + return _general_cosine(M, [0.42, 0.50, 0.08], sym, dtype=dtype) + + +def _cosine(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: + """Compute a window with a simple cosine shape. + """ + if _len_guards(M): + return paddle.ones((M, ), dtype=dtype) + M, needs_trunc = _extend(M, sym) + w = paddle.sin(math.pi / M * (paddle.arange(0, M, dtype=dtype) + .5)) + + return _truncate(w, needs_trunc) + + +def get_window(window: Union[str, Tuple[str, float]], + win_length: int, + fftbins: bool = True, + dtype: str = 'float64') -> Tensor: + """Return a window of a given length and type. + + Args: + window (Union[str, Tuple[str, float]]): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. + win_length (int): Number of samples. + fftbins (bool, optional): If True, create a "periodic" window. Otherwise, create a "symmetric" window, for use in filter design. Defaults to True. + dtype (str, optional): The data type of the return window. Defaults to 'float64'. + + Returns: + Tensor: The window represented as a tensor. + """ + sym = not fftbins + + args = () + if isinstance(window, tuple): + winstr = window[0] + if len(window) > 1: + args = window[1:] + elif isinstance(window, str): + if window in ['gaussian', 'exponential']: + raise ValueError("The '" + window + "' window needs one or " + "more parameters -- pass a tuple.") + else: + winstr = window + else: + raise ValueError("%s as window type is not supported." % + str(type(window))) + + try: + winfunc = eval('_' + winstr) + except NameError as e: + raise ValueError("Unknown window type.") from e + + params = (win_length, ) + args + kwargs = {'sym': sym} + return winfunc(*params, dtype=dtype, **kwargs) diff --git a/python/paddle/audio/utils/__init__.py b/python/paddle/audio/utils/__init__.py new file mode 100644 index 0000000000000..b3502b76cce25 --- /dev/null +++ b/python/paddle/audio/utils/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .error import ParameterError diff --git a/python/paddle/audio/utils/error.py b/python/paddle/audio/utils/error.py new file mode 100644 index 0000000000000..ab239a24970ad --- /dev/null +++ b/python/paddle/audio/utils/error.py @@ -0,0 +1,20 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ['ParameterError'] + + +class ParameterError(Exception): + """Exception class for Parameter checking""" + pass diff --git a/python/paddle/tests/test_audio_functions.py b/python/paddle/tests/test_audio_functions.py new file mode 100644 index 0000000000000..cc5c83e76dc39 --- /dev/null +++ b/python/paddle/tests/test_audio_functions.py @@ -0,0 +1,301 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import librosa +import numpy as np +import os +import paddle + +import paddle.audio +from scipy import signal +import itertools +from parameterized import parameterized + + +def parameterize(*params): + return parameterized.expand(list(itertools.product(*params))) + + +class TestAudioFuncitons(unittest.TestCase): + + def setUp(self): + self.initParmas() + + def initParmas(self): + + def get_wav_data(dtype: str, num_channels: int, num_frames: int): + dtype_ = getattr(paddle, dtype) + base = paddle.linspace(-1.0, 1.0, num_frames, dtype=dtype_) * 0.1 + data = base.tile([num_channels, 1]) + return data + + self.n_fft = 512 + self.hop_length = 128 + self.n_mels = 40 + self.n_mfcc = 20 + self.fmin = 0.0 + self.window_str = 'hann' + self.pad_mode = 'reflect' + self.top_db = 80.0 + self.duration = 0.5 + self.num_channels = 1 + self.sr = 16000 + self.dtype = "float32" + self.window_size = 1024 + waveform_tensor = get_wav_data(self.dtype, + self.num_channels, + num_frames=self.duration * self.sr) + self.waveform = waveform_tensor.numpy() + + @parameterize([1.0, 3.0, 9.0, 25.0], [True, False]) + def test_audio_function(self, val: float, htk_flag: bool): + mel_paddle = paddle.audio.functional.hz_to_mel(val, htk_flag) + mel_paddle_tensor = paddle.audio.functional.hz_to_mel( + paddle.to_tensor(val), htk_flag) + mel_librosa = librosa.hz_to_mel(val, htk_flag) + np.testing.assert_almost_equal(mel_paddle, mel_librosa, decimal=5) + np.testing.assert_almost_equal(mel_paddle_tensor.numpy(), + mel_librosa, + decimal=4) + + hz_paddle = paddle.audio.functional.mel_to_hz(val, htk_flag) + hz_paddle_tensor = paddle.audio.functional.mel_to_hz( + paddle.to_tensor(val), htk_flag) + hz_librosa = librosa.mel_to_hz(val, htk_flag) + np.testing.assert_almost_equal(hz_paddle, hz_librosa, decimal=4) + np.testing.assert_almost_equal(hz_paddle_tensor.numpy(), + hz_librosa, + decimal=4) + + decibel_paddle = paddle.audio.functional.power_to_db( + paddle.to_tensor(val)) + decibel_librosa = librosa.power_to_db(val) + np.testing.assert_almost_equal(decibel_paddle.numpy(), + decibel_paddle, + decimal=5) + + @parameterize([64, 128, 256], [0.0, 0.5, 1.0], [10000, 11025], + [False, True]) + def test_audio_function_mel(self, n_mels: int, f_min: float, f_max: float, + htk_flag: bool): + librosa_mel_freq = librosa.mel_frequencies(n_mels, f_min, f_max, + htk_flag) + paddle_mel_freq = paddle.audio.functional.mel_frequencies( + n_mels, f_min, f_max, htk_flag, 'float64') + np.testing.assert_almost_equal(paddle_mel_freq, + librosa_mel_freq, + decimal=3) + + @parameterize([8000, 16000], [64, 128, 256]) + def test_audio_function_fft(self, sr: int, n_fft: int): + librosa_fft = librosa.fft_frequencies(sr, n_fft) + paddle_fft = paddle.audio.functional.fft_frequencies(sr, n_fft) + np.testing.assert_almost_equal(paddle_fft, librosa_fft, decimal=5) + + @parameterize([1.0, 3.0, 9.0]) + def test_audio_function_exception(self, spect: float): + try: + paddle.audio.functional.power_to_db(paddle.to_tensor([spect]), + amin=0) + except Exception: + pass + + try: + paddle.audio.functional.power_to_db(paddle.to_tensor([spect]), + ref_value=0) + + except Exception: + pass + + try: + paddle.audio.functional.power_to_db(paddle.to_tensor([spect]), + top_db=-1) + except Exception: + pass + + @parameterize([ + "hamming", "hann", "triang", "bohman", "blackman", "cosine", "tukey", + "taylor" + ], [1, 512]) + def test_window(self, window_type: str, n_fft: int): + window_scipy = signal.get_window(window_type, n_fft) + window_paddle = paddle.audio.functional.get_window(window_type, n_fft) + np.testing.assert_array_almost_equal(window_scipy, + window_paddle.numpy(), + decimal=5) + + @parameterize([1, 512]) + def test_gussian_window_and_exception(self, n_fft: int): + window_scipy_gaussain = signal.windows.gaussian(n_fft, std=7) + window_paddle_gaussian = paddle.audio.functional.get_window( + ('gaussian', 7), n_fft, False) + np.testing.assert_array_almost_equal(window_scipy_gaussain, + window_paddle_gaussian.numpy(), + decimal=5) + window_scipy_general_gaussain = signal.windows.general_gaussian( + n_fft, 1, 7) + window_paddle_general_gaussian = paddle.audio.functional.get_window( + ('general_gaussian', 1, 7), n_fft, False) + np.testing.assert_array_almost_equal(window_scipy_gaussain, + window_paddle_gaussian.numpy(), + decimal=5) + + window_scipy_exp = signal.windows.exponential(n_fft) + window_paddle_exp = paddle.audio.functional.get_window( + ('exponential', None, 1), n_fft, False) + np.testing.assert_array_almost_equal(window_scipy_exp, + window_paddle_exp.numpy(), + decimal=5) + try: + window_paddle = paddle.audio.functional.get_window(("kaiser", 1.0), + self.n_fft) + except NotImplementedError: + pass + + try: + window_paddle = paddle.audio.functional.get_window("hann", -1) + except ValueError: + pass + + try: + window_paddle = paddle.audio.functional.get_window( + "fake_window", self.n_fft) + except ValueError: + pass + + try: + window_paddle = paddle.audio.functional.get_window(1043, self.n_fft) + except ValueError: + pass + + @parameterize([5, 13, 23], [257, 513, 1025]) + def test_create_dct(self, n_mfcc: int, n_mels: int): + + def dct(n_filters, n_input): + basis = np.empty((n_filters, n_input)) + basis[0, :] = 1.0 / np.sqrt(n_input) + samples = np.arange(1, 2 * n_input, 2) * np.pi / (2.0 * n_input) + + for i in range(1, n_filters): + basis[i, :] = np.cos(i * samples) * np.sqrt(2.0 / n_input) + return basis.T + + librosa_dct = dct(n_mfcc, n_mels) + paddle_dct = paddle.audio.functional.create_dct(n_mfcc, n_mels) + np.testing.assert_array_almost_equal(librosa_dct, paddle_dct, decimal=5) + + @parameterize([128, 256, 512], ["hamming", "hann", "triang", "bohman"], + [True, False]) + def test_stft_and_spect(self, n_fft: int, window_str: str, + center_flag: bool): + hop_length = int(n_fft / 4) + if len(self.waveform.shape) == 2: # (C, T) + self.waveform = self.waveform.squeeze( + 0) # 1D input for librosa.feature.melspectrogram + feature_librosa = librosa.core.stft( + y=self.waveform, + n_fft=n_fft, + hop_length=hop_length, + win_length=None, + window=window_str, + center=center_flag, + dtype=None, + pad_mode=self.pad_mode, + ) + x = paddle.to_tensor(self.waveform).unsqueeze(0) + window = paddle.audio.functional.get_window(window_str, + n_fft, + dtype=x.dtype) + feature_paddle = paddle.signal.stft( + x=x, + n_fft=n_fft, + hop_length=hop_length, + win_length=None, + window=window, + center=center_flag, + pad_mode=self.pad_mode, + normalized=False, + onesided=True, + ).squeeze(0) + np.testing.assert_array_almost_equal(feature_librosa, + feature_paddle, + decimal=5) + + feature_bg = np.power(np.abs(feature_librosa), 2.0) + feature_extractor = paddle.audio.features.Spectrogram( + n_fft=n_fft, + hop_length=hop_length, + win_length=None, + window=window_str, + power=2.0, + center=center_flag, + pad_mode=self.pad_mode, + ) + feature_layer = feature_extractor(x).squeeze(0) + np.testing.assert_array_almost_equal(feature_layer, + feature_bg, + decimal=3) + + @parameterize([128, 256, 512], [64, 82], + ["hamming", "hann", "triang", "bohman"]) + def test_istft(self, n_fft: int, hop_length: int, window_str: str): + if len(self.waveform.shape) == 2: # (C, T) + self.waveform = self.waveform.squeeze( + 0) # 1D input for librosa.feature.melspectrogram + # librosa + # Get stft result from librosa. + stft_matrix = librosa.core.stft( + y=self.waveform, + n_fft=n_fft, + hop_length=hop_length, + win_length=None, + window=window_str, + center=True, + pad_mode=self.pad_mode, + ) + feature_librosa = librosa.core.istft( + stft_matrix=stft_matrix, + hop_length=hop_length, + win_length=None, + window=window_str, + center=True, + dtype=None, + length=None, + ) + x = paddle.to_tensor(stft_matrix).unsqueeze(0) + window = paddle.audio.functional.get_window(window_str, + n_fft, + dtype=paddle.to_tensor( + self.waveform).dtype) + feature_paddle = paddle.signal.istft( + x=x, + n_fft=n_fft, + hop_length=hop_length, + win_length=None, + window=window, + center=True, + normalized=False, + onesided=True, + length=None, + return_complex=False, + ).squeeze(0) + + np.testing.assert_array_almost_equal(feature_librosa, + feature_paddle, + decimal=5) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tests/test_audio_logmel_feature.py b/python/paddle/tests/test_audio_logmel_feature.py new file mode 100644 index 0000000000000..a89dc583c3d58 --- /dev/null +++ b/python/paddle/tests/test_audio_logmel_feature.py @@ -0,0 +1,163 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import librosa +import numpy as np +import os +import paddle + +import paddle.audio +import scipy +from scipy import signal +import itertools +from parameterized import parameterized + + +def parameterize(*params): + return parameterized.expand(list(itertools.product(*params))) + + +class TestFeatures(unittest.TestCase): + + def setUp(self): + self.initParmas() + + def initParmas(self): + + def get_wav_data(dtype: str, num_channels: int, num_frames: int): + dtype_ = getattr(paddle, dtype) + base = paddle.linspace(-1.0, 1.0, num_frames, dtype=dtype_) * 0.1 + data = base.tile([num_channels, 1]) + return data + + self.fmin = 0.0 + self.top_db = 80.0 + self.duration = 0.5 + self.num_channels = 1 + self.sr = 16000 + self.dtype = "float32" + waveform_tensor = get_wav_data(self.dtype, + self.num_channels, + num_frames=self.duration * self.sr) + self.waveform = waveform_tensor.numpy() + + @parameterize([16000], ["hamming", "bohman"], [128], [128, 64], [64, 32], + [0.0, 50.0]) + def test_log_melspect(self, sr: int, window_str: str, n_fft: int, + hop_length: int, n_mels: int, fmin: float): + if len(self.waveform.shape) == 2: # (C, T) + self.waveform = self.waveform.squeeze( + 0) # 1D input for librosa.feature.melspectrogram + + # librosa: + feature_librosa = librosa.feature.melspectrogram(y=self.waveform, + sr=sr, + n_fft=n_fft, + hop_length=hop_length, + window=window_str, + n_mels=n_mels, + center=True, + fmin=fmin, + pad_mode='reflect') + feature_librosa = librosa.power_to_db(feature_librosa, top_db=None) + x = paddle.to_tensor(self.waveform, dtype=paddle.float64).unsqueeze( + 0) # Add batch dim. + feature_extractor = paddle.audio.features.LogMelSpectrogram( + sr=sr, + n_fft=n_fft, + hop_length=hop_length, + window=window_str, + center=True, + n_mels=n_mels, + f_min=fmin, + top_db=None, + dtype=x.dtype) + feature_layer = feature_extractor(x).squeeze(0).numpy() + np.testing.assert_array_almost_equal(feature_librosa, + feature_layer, + decimal=2) + # relative difference + np.testing.assert_allclose(feature_librosa, feature_layer, rtol=1e-4) + + @parameterize([16000], [256, 128], [40, 64], [64, 128], + ['float32', 'float64']) + def test_mfcc(self, sr: int, n_fft: int, n_mfcc: int, n_mels: int, + dtype: str): + if paddle.version.cuda() != 'False': + if float(paddle.version.cuda()) >= 11.0: + return + + if len(self.waveform.shape) == 2: # (C, T) + self.waveform = self.waveform.squeeze( + 0) # 1D input for librosa.feature.melspectrogram + + # librosa: + np_dtype = getattr(np, dtype) + feature_librosa = librosa.feature.mfcc(y=self.waveform, + sr=sr, + S=None, + n_mfcc=n_mfcc, + dct_type=2, + lifter=0, + n_fft=n_fft, + hop_length=64, + n_mels=n_mels, + fmin=50.0, + dtype=np_dtype) + # paddlespeech.audio.features.layer + x = paddle.to_tensor(self.waveform, + dtype=dtype).unsqueeze(0) # Add batch dim. + feature_extractor = paddle.audio.features.MFCC(sr=sr, + n_mfcc=n_mfcc, + n_fft=n_fft, + hop_length=64, + n_mels=n_mels, + top_db=self.top_db, + dtype=x.dtype) + feature_layer = feature_extractor(x).squeeze(0).numpy() + + np.testing.assert_array_almost_equal(feature_librosa, + feature_layer, + decimal=3) + + np.testing.assert_allclose(feature_librosa, feature_layer, rtol=1e-1) + + # split mffcc: logmel-->dct --> mfcc, which prove the difference. + # the dct module is correct. + feature_extractor = paddle.audio.features.LogMelSpectrogram( + sr=sr, + n_fft=n_fft, + hop_length=64, + n_mels=n_mels, + center=True, + pad_mode='reflect', + top_db=self.top_db, + dtype=x.dtype) + feature_layer_logmel = feature_extractor(x).squeeze(0).numpy() + + feature_layer_mfcc = scipy.fftpack.dct(feature_layer_logmel, + axis=0, + type=2, + norm="ortho")[:n_mfcc] + np.testing.assert_array_almost_equal(feature_layer_mfcc, + feature_librosa, + decimal=3) + np.testing.assert_allclose(feature_layer_mfcc, + feature_librosa, + rtol=1e-1) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tests/test_audio_mel_feature.py b/python/paddle/tests/test_audio_mel_feature.py new file mode 100644 index 0000000000000..427e9864117cd --- /dev/null +++ b/python/paddle/tests/test_audio_mel_feature.py @@ -0,0 +1,117 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import librosa +import numpy as np +import os +import paddle + +import paddle.audio +from scipy import signal +import itertools +from parameterized import parameterized + + +def parameterize(*params): + return parameterized.expand(list(itertools.product(*params))) + + +class TestFeatures(unittest.TestCase): + + def setUp(self): + self.initParmas() + + def initParmas(self): + + def get_wav_data(dtype: str, num_channels: int, num_frames: int): + dtype_ = getattr(paddle, dtype) + base = paddle.linspace(-1.0, 1.0, num_frames, dtype=dtype_) * 0.1 + data = base.tile([num_channels, 1]) + return data + + self.hop_length = 128 + self.duration = 0.5 + self.num_channels = 1 + self.sr = 16000 + self.dtype = "float32" + waveform_tensor = get_wav_data(self.dtype, + self.num_channels, + num_frames=self.duration * self.sr) + self.waveform = waveform_tensor.numpy() + + @parameterize([8000], [128, 256], [64, 32], [0.0, 1.0], + ['float32', 'float64']) + def test_mel(self, sr: int, n_fft: int, n_mels: int, fmin: float, + dtype: str): + feature_librosa = librosa.filters.mel( + sr=sr, + n_fft=n_fft, + n_mels=n_mels, + fmin=fmin, + fmax=None, + htk=False, + norm='slaney', + dtype=np.dtype(dtype), + ) + paddle_dtype = getattr(paddle, dtype) + feature_functional = paddle.audio.functional.compute_fbank_matrix( + sr=sr, + n_fft=n_fft, + n_mels=n_mels, + f_min=fmin, + f_max=None, + htk=False, + norm='slaney', + dtype=paddle_dtype, + ) + + np.testing.assert_array_almost_equal(feature_librosa, + feature_functional) + + @parameterize([8000, 16000], [128, 256], [64, 82], [40, 80], [False, True]) + def test_melspect(self, sr: int, n_fft: int, hop_length: int, n_mels: int, + htk: bool): + if len(self.waveform.shape) == 2: # (C, T) + self.waveform = self.waveform.squeeze( + 0) # 1D input for librosa.feature.melspectrogram + + # librosa: + feature_librosa = librosa.feature.melspectrogram(y=self.waveform, + sr=sr, + n_fft=n_fft, + hop_length=hop_length, + n_mels=n_mels, + htk=htk, + fmin=50.0) + + # paddle.audio.features.layer + x = paddle.to_tensor(self.waveform, dtype=paddle.float64).unsqueeze( + 0) # Add batch dim. + feature_extractor = paddle.audio.features.MelSpectrogram( + sr=sr, + n_fft=n_fft, + hop_length=hop_length, + n_mels=n_mels, + htk=htk, + dtype=x.dtype) + feature_layer = feature_extractor(x).squeeze(0).numpy() + + np.testing.assert_array_almost_equal(feature_librosa, + feature_layer, + decimal=5) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/setup.py.in b/python/setup.py.in index d919227450fe4..3d400881de382 100755 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -365,6 +365,10 @@ packages=['paddle', 'paddle.vision.models', 'paddle.vision.transforms', 'paddle.vision.datasets', + 'paddle.audio', + 'paddle.audio.functional', + 'paddle.audio.features', + 'paddle.audio.utils', 'paddle.text', 'paddle.text.datasets', 'paddle.incubate', diff --git a/python/unittest_py/requirements.txt b/python/unittest_py/requirements.txt index f70037e71611f..78c6518953bd2 100644 --- a/python/unittest_py/requirements.txt +++ b/python/unittest_py/requirements.txt @@ -15,3 +15,5 @@ prettytable distro numpy>=1.20,<1.22; python_version >= "3.7" autograd==1.4 +librosa==0.8.1 +parameterized diff --git a/tools/dockerfile/ci_dockerfile.sh b/tools/dockerfile/ci_dockerfile.sh index e5a6c240fe5f9..fbc21ec955d58 100644 --- a/tools/dockerfile/ci_dockerfile.sh +++ b/tools/dockerfile/ci_dockerfile.sh @@ -38,7 +38,7 @@ function make_ubuntu_dockerfile(){ ENV PATH=/usr/local/gcc-8.2/bin:\$PATH #g" ${dockerfile_name} sed -i "s#bash /build_scripts/install_nccl2.sh#wget -q --no-proxy https://nccl2-deb.cdn.bcebos.com/nccl-repo-ubuntu1604-2.7.8-ga-cuda10.1_1-1_amd64.deb \\ RUN dpkg -i nccl-repo-ubuntu1604-2.7.8-ga-cuda10.1_1-1_amd64.deb \\ - RUN apt remove -y libnccl* --allow-change-held-packages \&\& apt-get install -y libnccl2=2.7.8-1+cuda10.1 libnccl-dev=2.7.8-1+cuda10.1 zstd pigz --allow-change-held-packages #g" ${dockerfile_name} + RUN apt remove -y libnccl* --allow-change-held-packages \&\& apt-get install -y libsndfile1 libnccl2=2.7.8-1+cuda10.1 libnccl-dev=2.7.8-1+cuda10.1 zstd pigz --allow-change-held-packages #g" ${dockerfile_name} } function make_ubuntu_trt7_dockerfile(){ @@ -47,7 +47,7 @@ function make_ubuntu_trt7_dockerfile(){ sed -i "s#liblzma-dev#liblzma-dev openmpi-bin openmpi-doc libopenmpi-dev#g" ${dockerfile_name} dockerfile_line=$(wc -l ${dockerfile_name}|awk '{print $1}') sed -i "${dockerfile_line}i RUN apt remove -y libcudnn* --allow-change-held-packages \&\& \ - apt-get install -y --allow-unauthenticated libcudnn8=8.1.0.77-1+cuda10.2 libcudnn8-dev=8.1.0.77-1+cuda10.2 --allow-change-held-packages" ${dockerfile_name} + apt-get install -y --allow-unauthenticated libsndfile1 libcudnn8=8.1.0.77-1+cuda10.2 libcudnn8-dev=8.1.0.77-1+cuda10.2 --allow-change-held-packages" ${dockerfile_name} sed -i "${dockerfile_line}i RUN wget --no-check-certificate -q \ https://developer.download.nvidia.com/compute/cuda/10.2/Prod/patches/2/cuda_10.2.2_linux.run \&\& \ bash cuda_10.2.2_linux.run --silent --toolkit \&\& ldconfig" ${dockerfile_name} @@ -73,7 +73,7 @@ function make_ubuntu_trt7_dockerfile(){ RUN ln -s /usr/local/gcc-8.2/bin/g++ /usr/bin/g++ \\ ENV PATH=/usr/local/gcc-8.2/bin:\$PATH #g" ${dockerfile_name} sed -i "s#bash /build_scripts/install_nccl2.sh#wget -q --no-proxy https://nccl2-deb.cdn.bcebos.com/nccl-repo-ubuntu1604-2.7.8-ga-cuda10.1_1-1_amd64.deb \\ - RUN apt remove -y libnccl* --allow-change-held-packages \&\& apt-get install -y libnccl2=2.7.8-1+cuda10.1 libnccl-dev=2.7.8-1+cuda10.1 zstd pigz --allow-change-held-packages #g" ${dockerfile_name} + RUN apt remove -y libnccl* --allow-change-held-packages \&\& apt-get install -y libsndfile1 libnccl2=2.7.8-1+cuda10.1 libnccl-dev=2.7.8-1+cuda10.1 zstd pigz --allow-change-held-packages #g" ${dockerfile_name} } @@ -82,7 +82,7 @@ function make_centos_dockerfile(){ sed "s//11.0-cudnn8-devel-centos7/g" Dockerfile.centos >${dockerfile_name} sed -i "s#COPY build_scripts /build_scripts#COPY tools/dockerfile/build_scripts ./build_scripts#g" ${dockerfile_name} dockerfile_line=$(wc -l ${dockerfile_name}|awk '{print $1}') - sed -i "${dockerfile_line}i RUN yum install -y pigz graphviz zstd" ${dockerfile_name} + sed -i "${dockerfile_line}i RUN yum install -y pigz graphviz zstd libsndfile" ${dockerfile_name} sed -i "${dockerfile_line}i RUN pip3.7 install distro" ${dockerfile_name} sed -i "${dockerfile_line}i ENV LD_LIBRARY_PATH /opt/_internal/cpython-3.7.0/lib:/usr/local/ssl/lib:/opt/rh/devtoolset-2/root/usr/lib64:/opt/rh/devtoolset-2/root/usr/lib:/usr/local/lib64:/usr/local/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 " ${dockerfile_name} sed -i "${dockerfile_line}i ENV PATH /opt/_internal/cpython-3.7.0/bin:/usr/local/ssl:/usr/local/gcc-8.2/bin:/usr/local/go/bin:/root/gopath/bin:/opt/rh/devtoolset-2/root/usr/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/java/jdk1.8.0_192/bin " ${dockerfile_name} @@ -104,7 +104,7 @@ function make_cinn_dockerfile(){ sed -i 's###g' ${dockerfile_name} sed -i "7i ENV TZ=Asia/Beijing" ${dockerfile_name} sed -i "8i RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone" ${dockerfile_name} - sed -i "9i RUN apt-get update && apt-get install -y liblzma-dev openmpi-bin openmpi-doc libopenmpi-dev" ${dockerfile_name} + sed -i "27i RUN apt-get update && apt-get install -y liblzma-dev openmpi-bin openmpi-doc libopenmpi-dev libsndfile1" ${dockerfile_name} dockerfile_line=$(wc -l ${dockerfile_name}|awk '{print $1}') sed -i "${dockerfile_line}i RUN wget --no-check-certificate -q https://paddle-edl.bj.bcebos.com/hadoop-2.7.7.tar.gz \&\& \ tar -xzf hadoop-2.7.7.tar.gz && mv hadoop-2.7.7 /usr/local/" ${dockerfile_name}