-
Notifications
You must be signed in to change notification settings - Fork 0
/
audio.py
159 lines (135 loc) · 4.59 KB
/
audio.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from typing import Optional

import soundfile as sf
import torch
import torchaudio
from torch import nn
from torchaudio.sox_effects import apply_effects_tensor
from torchaudio.transforms import MelSpectrogram
class Wav2Mel(nn.Module):
    """Transform an audio waveform (or file) into a log mel spectrogram tensor.

    The pipeline is: sox effects (mono downmix, resample, normalization,
    silence removal) followed by a log mel spectrogram.
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        norm_db: float = -3.0,
        sil_threshold: float = 1.0,
        sil_duration: float = 0.1,
        fft_window_ms: float = 50.0,
        fft_hop_ms: float = 12.5,
        n_fft: int = 2048,
        f_min: float = 50.0,
        n_mels: int = 80,
        preemph: float = 0.97,
        ref_db: float = 20.0,
        dc_db: float = 100.0,
    ):
        """Build the sox-effects front end and the mel-spectrogram back end.

        Args:
            sample_rate: target sample rate the waveform is resampled to.
            norm_db: normalization level (dB) applied by the sox `norm` effect.
            sil_threshold: silence threshold (percent) for the sox `silence` effect.
            sil_duration: minimum silence duration (seconds) before trimming.
            fft_window_ms: STFT window length in milliseconds.
            fft_hop_ms: STFT hop length in milliseconds.
            n_fft: FFT size.
            f_min: lowest mel filterbank frequency (Hz).
            n_mels: number of mel bands.
            preemph: pre-emphasis coefficient.
            ref_db: reference level (dB) subtracted from the log spectrogram.
            dc_db: dynamic-range constant (dB) used to rescale the output.
        """
        super().__init__()
        # Keep raw settings around so callers can inspect the configuration.
        self.sample_rate = sample_rate
        self.norm_db = norm_db
        self.sil_threshold = sil_threshold
        self.sil_duration = sil_duration
        self.fft_window_ms = fft_window_ms
        self.fft_hop_ms = fft_hop_ms
        self.n_fft = n_fft
        self.f_min = f_min
        self.n_mels = n_mels
        self.preemph = preemph
        self.ref_db = ref_db
        self.dc_db = dc_db
        self.sox_effects = SoxEffects(sample_rate, norm_db, sil_threshold, sil_duration)
        self.log_melspectrogram = LogMelspectrogram(
            sample_rate,
            fft_window_ms,
            fft_hop_ms,
            n_fft,
            f_min,
            n_mels,
            preemph,
            ref_db,
            dc_db,
        )

    def forward(self, wav_tensor: torch.Tensor, sample_rate: int) -> Optional[torch.Tensor]:
        """Convert a waveform tensor into a log mel spectrogram.

        Returns:
            The log mel spectrogram tensor, or ``None`` when silence removal
            leaves no samples (the original annotation claimed ``torch.Tensor``
            but the empty-signal branch returns ``None``; callers must check).
        """
        wav_tensor = self.sox_effects(wav_tensor, sample_rate)
        # Silence trimming may consume the entire signal; signal that explicitly.
        if wav_tensor.numel() == 0:
            return None
        mel_tensor = self.log_melspectrogram(wav_tensor)
        return mel_tensor

    def parse_file(self, file_path: str) -> Optional[torch.Tensor]:
        """Load an audio file and run it through the full pipeline."""
        return self(*torchaudio.load(file_path))
class SoxEffects(nn.Module):
    """Apply a fixed chain of sox effects to waveform tensors.

    The chain downmixes to mono, resamples to ``sample_rate``, normalizes to
    ``norm_db`` and strips silent passages throughout the recording.
    """

    def __init__(
        self,
        sample_rate: int,
        norm_db: float,
        sil_threshold: float,
        sil_duration: float,
    ):
        super().__init__()
        # The `silence` effect takes (duration, threshold%) pairs; the same
        # pair is used for leading silence and (via "-1") for the rest.
        silence_args = [str(sil_duration), f"{sil_threshold}%"]
        self.effects = [
            ["channels", "1"],  # downmix to mono
            ["rate", str(sample_rate)],  # resample
            ["norm", str(norm_db)],  # normalize to norm_db
            ["silence", "1", *silence_args, "-1", *silence_args],  # trim silence everywhere
        ]

    def forward(self, wav_tensor: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Run the effect chain and return the processed waveform."""
        processed, _ = apply_effects_tensor(wav_tensor, sample_rate, self.effects)
        return processed
class LogMelspectrogram(nn.Module):
    """Turn waveform tensors into log mel spectrogram tensors."""

    def __init__(
        self,
        sample_rate: int,
        fft_window_ms: float,
        fft_hop_ms: float,
        n_fft: int,
        f_min: float,
        n_mels: int,
        preemph: float,
        ref_db: float,
        dc_db: float,
    ):
        super().__init__()
        # Window/hop lengths are given in milliseconds; convert to samples.
        win_length = int(sample_rate * fft_window_ms / 1000)
        hop_length = int(sample_rate * fft_hop_ms / 1000)
        self.melspectrogram = MelSpectrogram(
            sample_rate=sample_rate,
            win_length=win_length,
            hop_length=hop_length,
            n_fft=n_fft,
            f_min=f_min,
            n_mels=n_mels,
        )
        self.preemph = preemph
        self.ref_db = ref_db
        self.dc_db = dc_db

    def forward(self, wav_tensor: torch.Tensor) -> torch.Tensor:
        """Apply pre-emphasis, compute the mel spectrogram, convert to scaled dB."""
        # Pre-emphasis: y[t] = x[t] - preemph * x[t-1]; first sample kept as-is.
        head = wav_tensor[:, :1]
        tail = wav_tensor[:, 1:] - self.preemph * wav_tensor[:, :-1]
        emphasized = torch.cat((head, tail), dim=-1)
        # (n_mels, time) after dropping the channel dimension.
        mel_tensor = self.melspectrogram(emphasized).squeeze(0)
        # Convert to dB; the clamp guards against log10(0).
        mel_tensor = 20 * torch.log10(mel_tensor.clamp(min=1e-9))
        # Shift by the reference level and rescale by the dynamic-range constant.
        return (mel_tensor - self.ref_db + self.dc_db) / self.dc_db
class Mel2Wav(nn.Module):
    """Invert mel spectrograms back to waveforms with a TorchScript vocoder."""

    def __init__(self, sample_rate: int = 16000, vocoder_path: str = "pretrained/vocoder.pt"):
        super().__init__()
        self.vocoder = torch.jit.load(vocoder_path).eval()
        self.sample_rate = sample_rate

    def convert(self, mels):
        """Synthesize a waveform for each mel spectrogram in ``mels``."""
        # Each mel is transposed before synthesis — presumably the vocoder
        # expects (time, n_mels) input; verify against the vocoder's contract.
        transposed = [mel.T for mel in mels]
        return self.vocoder.generate(transposed)

    def to_files(self, mels, save_paths):
        """Synthesize waveforms and write one audio file per spectrogram."""
        for wav, save_path in zip(self.convert(mels), save_paths):
            sf.write(save_path, wav.data.cpu().numpy(), self.sample_rate)

    def to_file(self, mel, save_path):
        """Synthesize a single spectrogram and write it to ``save_path``."""
        self.to_files([mel], [save_path])