-
Notifications
You must be signed in to change notification settings - Fork 13
/
model_loader.py
643 lines (497 loc) · 24.4 KB
/
model_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
from abc import ABC, abstractmethod
import logging
import math
from typing import Literal
import numpy as np
import soundfile
import torch
import librosa
from torch import nn
from pathlib import Path
from hypy_utils.downloader import download_file
import torch.nn.functional as F
import importlib.util
log = logging.getLogger(__name__)
class ModelLoader(ABC):
    """
    Abstract base class for loading a model and getting embeddings from it.

    Subclasses load their network in `load_model` and implement `_get_embedding`.
    Callers use `get_embedding`, which converts the result to a NumPy array.
    """
    def __init__(self, name: str, num_features: int, sr: int):
        """
        :param name: Identifier of the model.
        :param num_features: Number of features in each embedding frame.
        :param sr: Sample rate (Hz) associated with the model.
        """
        self.model = None  # Populated by `load_model`
        self.sr = sr
        self.num_features = num_features
        self.name = name
        # Use CUDA when available, otherwise fall back to the CPU
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    def get_embedding(self, audio: np.ndarray):
        """
        Compute the embedding via `_get_embedding` and return it as a NumPy array.

        float32 results are down-cast to float16 to be space-efficient; other
        dtypes are returned unchanged.
        """
        embd = self._get_embedding(audio)

        # Always detach and copy to host memory. `.cpu()` is a no-op for tensors that
        # already live on the CPU, so this does not depend on `self.device` comparing
        # equal to torch.device('cuda') (which is False for e.g. 'cuda:0').
        embd = embd.detach().cpu().numpy()

        # If embedding is float32, convert to float16 to be space-efficient
        if embd.dtype == np.float32:
            embd = embd.astype(np.float16)

        return embd

    @abstractmethod
    def load_model(self):
        """Load the underlying model into `self.model`."""
        pass

    @abstractmethod
    def _get_embedding(self, audio: np.ndarray):
        """
        Returns the embedding of the audio file. The resulting vector should be of shape (n_frames, n_features).
        """
        pass

    def load_wav(self, wav_file: Path):
        """
        Read `wav_file` as 16-bit integer samples and scale them to [-1.0, +1.0).
        """
        wav_data, _ = soundfile.read(wav_file, dtype='int16')
        wav_data = wav_data / 32768.0  # Convert to [-1.0, +1.0]

        return wav_data
class VGGishModel(ModelLoader):
    """
    VGGish embedding extractor.

    S. Hershey et al., "CNN Architectures for Large-Scale Audio Classification", ICASSP 2017
    """
    def __init__(self, use_pca=False, use_activation=False):
        """
        :param use_pca: Keep the model's `postprocess` step enabled when True.
        :param use_activation: Keep the last child of the `embeddings` stack when True.
        """
        super().__init__("vggish", 128, 16000)
        self.use_activation = use_activation
        self.use_pca = use_pca

    def load_model(self):
        """Fetch the torchvggish model from torch.hub and configure it for inference."""
        net = torch.hub.load('harritaylor/torchvggish', 'vggish')
        self.model = net

        if not self.use_pca:
            net.postprocess = False

        if not self.use_activation:
            # Rebuild the embedding stack without its final child module
            layers = list(net.embeddings.children())
            net.embeddings = nn.Sequential(*layers[:-1])

        net.eval()
        net.to(self.device)

    def _get_embedding(self, audio: np.ndarray) -> np.ndarray:
        """Run the model's forward pass on `audio` at the loader's sample rate."""
        sample_rate = self.sr
        return self.model.forward(audio, sample_rate)
class EncodecEmbModel(ModelLoader):
    """
    Encodec model from https://github.com/facebookresearch/encodec

    This version uses the embedding outputs (continuous values of 128 features).
    """
    def __init__(self, variant: Literal['48k', '24k'] = '24k'):
        """
        :param variant: '24k' (default) or '48k'; selects the checkpoint and the sample rate.
        """
        is_24k = variant == '24k'
        model_name = 'encodec-emb' if is_24k else f"encodec-emb-{variant}"
        sample_rate = 24000 if is_24k else 48000
        super().__init__(model_name, 128, sr=sample_rate)
        self.variant = variant

    def load_model(self):
        """Instantiate the Encodec checkpoint for the chosen variant and move it to the device."""
        from encodec import EncodecModel

        if self.variant == '48k':
            factory, bandwidth = EncodecModel.encodec_model_48khz, 24
        else:
            factory, bandwidth = EncodecModel.encodec_model_24khz, 12

        self.model = factory()
        self.model.set_target_bandwidth(bandwidth)
        self.model.to(self.device)

    def _get_embedding(self, audio: np.ndarray) -> np.ndarray:
        """
        Encode `audio` into embedding frames.

        `audio` is used through tensor methods here (`.dim()`, `.to(...)`), i.e. it is
        the value produced by this class's `load_wav`.
        """
        seg_len = self.model.segment_length

        # The 24k model doesn't use segmenting
        if seg_len is None:
            return self._get_frame(audio)

        # The 48k model uses segmenting
        assert audio.dim() == 3
        _, n_channels, n_samples = audio.shape
        assert 0 < n_channels <= 2

        # Segments are taken back to back: the stride equals the segment length
        pieces = [
            self._get_frame(audio[:, :, start:start + seg_len])
            for start in range(0, n_samples, seg_len)
        ]

        # Concatenate
        return torch.cat(pieces, dim=0)  # [timeframes, 128]

    def _get_frame(self, audio: np.ndarray) -> np.ndarray:
        """Encode a single segment with the model's encoder, without tracking gradients."""
        with torch.no_grad():
            n_samples = audio.shape[-1]
            duration = n_samples / self.sr
            assert self.model.segment is None or duration <= 1e-5 + self.model.segment, f"Audio is too long ({duration} > {self.model.segment})"

            encoded = self.model.encoder(audio.to(self.device))  # [1, 128, timeframes]
            return encoded[0].transpose(0, 1)  # [128, timeframes] -> [timeframes, 128]

    def load_wav(self, wav_file: Path):
        """
        Load `wav_file` with torchaudio, convert it to the model's sample rate and
        channel count, cap it at 3 minutes and add a leading dimension.
        """
        import torchaudio
        from encodec.utils import convert_audio

        wav, file_sr = torchaudio.load(str(wav_file))
        wav = convert_audio(wav, file_sr, self.sr, self.model.channels)

        # If it's longer than 3 minutes, cut it
        max_samples = 3 * 60 * self.sr
        if wav.shape[1] > max_samples:
            wav = wav[:, :max_samples]

        return wav.unsqueeze(0)

    def _decode_frame(self, emb: np.ndarray) -> np.ndarray:
        """Run the model's decoder on an embedding array and return the first channel as NumPy."""
        with torch.no_grad():
            latent = torch.from_numpy(emb).float().to(self.device)  # [timeframes, 128]
            latent = latent.transpose(0, 1).unsqueeze(0)  # [1, 128, timeframes]
            decoded = self.model.decoder(latent)  # [1, 1, timeframes]
            return decoded[0, 0].cpu().numpy()  # [timeframes]
class DACModel(ModelLoader):
    """
    DAC model from https://github.com/descriptinc/descript-audio-codec

    pip install descript-audio-codec
    """
    def __init__(self):
        super().__init__("dac-44kHz", 1024, 44100)

    def load_model(self):
        """Load the latest 44khz DAC checkpoint, switch it to eval mode and move it to the device."""
        from dac.utils import load_model
        self.model = load_model(tag='latest', model_type='44khz')
        self.model.eval()
        self.model.to(self.device)

    def _get_embedding(self, audio) -> np.ndarray:
        """
        Encode `audio` window by window with the DAC encoder.

        :param audio: The audiotools `AudioSignal` returned by `load_wav`. It is
            modified in place (normalized, reshaped and zero-padded).
        :return: Encoder outputs of all windows concatenated along the first axis.
        """
        from audiotools import AudioSignal
        import time

        # Set variables
        win_len = 5.0
        overlap_hop_ratio = 0.5

        # Fix overlap window so that it's divisible by 4 in # of samples
        win_len = ((win_len * self.sr) // 4) * 4
        win_len = win_len / self.sr
        hop_len = win_len * overlap_hop_ratio

        stime = time.perf_counter()

        # Sanitize input
        audio.normalize(-16)
        audio.ensure_max_of_audio()

        # Fold the channel axis into the batch axis so every channel is a mono signal
        nb, nac, nt = audio.audio_data.shape
        audio.audio_data = audio.audio_data.reshape(nb * nac, 1, nt)

        # Pad up to a whole number of windows before splitting
        pad_length = math.ceil(audio.signal_duration / win_len) * win_len
        audio.zero_pad_to(int(pad_length * self.sr))
        audio = audio.collect_windows(win_len, hop_len)

        print(win_len, hop_len, audio.batch_size, f"(processed in {(time.perf_counter() - stime) * 1000:.0f}ms)")
        stime = time.perf_counter()

        emb = []
        # Inference only: without no_grad an autograd graph would be kept for every window
        with torch.no_grad():
            for i in range(audio.batch_size):
                signal_from_batch = AudioSignal(audio.audio_data[i, ...], self.sr)
                signal_from_batch.to(self.device)
                e1 = self.model.encoder(signal_from_batch.audio_data).cpu()  # [1, 1024, timeframes]
                e1 = e1[0]  # [1024, timeframes]
                e1 = e1.transpose(0, 1)  # [timeframes, 1024]
                emb.append(e1)

        emb = torch.cat(emb, dim=0)
        print(emb.shape, f'(computing finished in {(time.perf_counter() - stime) * 1000:.0f}ms)')

        return emb

    def load_wav(self, wav_file: Path):
        """Wrap `wav_file` in an audiotools `AudioSignal`."""
        from audiotools import AudioSignal
        return AudioSignal(wav_file)
class MERTModel(ModelLoader):
    """
    MERT model from https://huggingface.co/m-a-p (default checkpoint: MERT-v1-95M)

    Please specify the layer to use (1-12).
    """
    def __init__(self, size='v1-95M', layer=12, limit_minutes=6):
        """
        :param size: Checkpoint suffix; the model is loaded from `m-a-p/MERT-{size}`.
        :param layer: Index into the stacked hidden states used as the embedding.
        :param limit_minutes: Longer audio is truncated to this many minutes.
        """
        # The 330M checkpoint has 1024-dimensional hidden states; the others handled here use 768
        num_features = 1024 if str(size).endswith('330M') else 768
        super().__init__(f"MERT-{size}" + ("" if layer == 12 else f"-{layer}"), num_features, 24000)
        self.huggingface_id = f"m-a-p/MERT-{size}"
        self.layer = layer
        # Keep the limit an int so it can be used as a slice bound even for fractional minutes
        self.limit = int(limit_minutes * 60 * self.sr)

    def load_model(self):
        """Load the MERT model and its feature extractor from the Hugging Face hub."""
        from transformers import Wav2Vec2FeatureExtractor
        from transformers import AutoModel

        self.model = AutoModel.from_pretrained(self.huggingface_id, trust_remote_code=True)
        self.processor = Wav2Vec2FeatureExtractor.from_pretrained(self.huggingface_id, trust_remote_code=True)
        # self.sr = self.processor.sampling_rate
        self.model.to(self.device)

    def _get_embedding(self, audio: np.ndarray) -> np.ndarray:
        """Return the hidden states of the configured layer for `audio`."""
        # Truncate to the configured limit
        if audio.shape[0] > self.limit:
            log.warning(f"Audio is too long ({audio.shape[0] / self.sr / 60:.2f} minutes > {self.limit / self.sr / 60:.2f} minutes). Truncating.")
            audio = audio[:self.limit]

        inputs = self.processor(audio, sampling_rate=self.sr, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out = self.model(**inputs, output_hidden_states=True)
            out = torch.stack(out.hidden_states).squeeze()  # [layers, timeframes, features]
            out = out[self.layer]  # [timeframes, features]

        return out
class CLAPLaionModel(ModelLoader):
    """
    CLAP model from https://github.com/LAION-AI/CLAP
    """
    def __init__(self, type: Literal['audio', 'music']):
        """
        Resolve the checkpoint for `type`, downloading and patching it if needed.

        :param type: 'audio' or 'music'.
        :raises ValueError: If `type` is not one of the supported values.
        """
        super().__init__(f"clap-laion-{type}", 512, 48000)
        self.type = type

        if type == 'audio':
            url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/630k-audioset-best.pt'
        elif type == 'music':
            url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_audioset_epoch_15_esc_90.14.pt'
        else:
            # Fail with a clear message instead of an UnboundLocalError on `url` below
            raise ValueError(f"Unsupported CLAP-LAION type {type!r}; expected 'audio' or 'music'")

        self.model_file = Path(__file__).parent / ".model-checkpoints" / url.split('/')[-1]

        # Download file if it doesn't exist
        if not self.model_file.exists():
            self.model_file.parent.mkdir(parents=True, exist_ok=True)
            download_file(url, self.model_file)

        # Patch the model file to remove position_ids (will raise an error otherwise)
        self.patch_model_430(self.model_file)

    def patch_model_430(self, file: Path):
        """
        Patch the model file to remove position_ids (will raise an error otherwise)
        This is a new issue after the transformers 4.30.0 update
        Please refer to https://github.com/LAION-AI/CLAP/issues/127
        """
        # A marker file next to the checkpoint records that patching is done
        patched = file.parent / f"{file.name}.patched.430"
        if patched.exists():
            return

        OFFENDING_KEY = "module.text_branch.embeddings.position_ids"

        log.warning("Patching LAION-CLAP's model checkpoints")

        # Load the checkpoint from the given path
        # NOTE(review): torch.load unpickles the file; it is only used here on the
        # checkpoint downloaded above. Do not point this at untrusted files.
        checkpoint = torch.load(file, map_location="cpu")

        # Extract the state_dict from the checkpoint
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint

        # Delete the specific key from the state_dict. `state_dict` is the same object
        # that the checkpoint holds, so the removal is reflected in `checkpoint`.
        if OFFENDING_KEY in state_dict:
            del state_dict[OFFENDING_KEY]

        # Save the modified checkpoint
        torch.save(checkpoint, file)
        log.warning(f"Saved patched checkpoint to {file}")

        # Create a "patched" file when patching is done
        patched.touch()

    def load_model(self):
        """Build the CLAP module for the configured type, load the checkpoint and move it to the device."""
        import laion_clap

        self.model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-tiny' if self.type == 'audio' else 'HTSAT-base')
        self.model.load_ckpt(self.model_file)
        self.model.to(self.device)

    def _get_embedding(self, audio: np.ndarray) -> np.ndarray:
        """
        Embed `audio` in 10-second chunks taken every second.

        Chunks shorter than 10 seconds are zero-padded at the end.
        """
        audio = audio.reshape(1, -1)

        # The int16-float32 conversion is used for quantization
        audio = self.int16_to_float32(self.float32_to_int16(audio))

        # Split the audio into 10s chunks with 1s hop
        chunk_size = 10 * self.sr  # 10 seconds
        hop_size = self.sr  # 1 second

        # Calculate embeddings for each chunk
        embeddings = []
        with torch.no_grad():
            for start in range(0, audio.shape[1], hop_size):
                chunk = audio[:, start:start + chunk_size]
                missing = chunk_size - chunk.shape[1]
                if missing:
                    chunk = np.pad(chunk, ((0, 0), (0, missing)))
                chunk = torch.from_numpy(chunk).float().to(self.device)
                emb = self.model.get_audio_embedding_from_data(x=chunk, use_tensor=True)
                embeddings.append(emb)

        # Concatenate the embeddings
        emb = torch.cat(embeddings, dim=0)  # [timeframes, 512]
        return emb

    def int16_to_float32(self, x):
        """Scale int16-range values by 1/32767 and return them as float32."""
        return (x / 32767.0).astype(np.float32)

    def float32_to_int16(self, x):
        """Clip to [-1, 1], scale by 32767 and cast to int16."""
        x = np.clip(x, a_min=-1., a_max=1.)
        return (x * 32767.).astype(np.int16)
class CdpamModel(ModelLoader):
    """
    CDPAM model from https://github.com/pranaymanocha/PerceptualAudio/tree/master/cdpam
    """
    def __init__(self, mode: Literal['acoustic', 'content']) -> None:
        """
        :param mode: Which encoder output to use as the embedding: 'acoustic' or 'content'.
        """
        super().__init__(f"cdpam-{mode}", 512, 22050)
        self.mode = mode
        assert mode in ['acoustic', 'content'], "Mode must be 'acoustic' or 'content'"

    def load_model(self):
        """Instantiate CDPAM on the loader's device."""
        from cdpam import CDPAM
        self.model = CDPAM(dev=self.device)

    def _get_embedding(self, audio: np.ndarray) -> np.ndarray:
        """
        Embed `audio` (as returned by `load_wav`) in consecutive 1-second chunks.

        Each chunk's selected encoder output is L2-normalized along dim 1.
        """
        audio = torch.from_numpy(audio).float().to(self.device)

        # Take 1s chunks
        chunk_size = self.sr

        frames = []
        # Inference only: avoid keeping an autograd graph for every chunk
        with torch.no_grad():
            for i in range(0, audio.shape[1], chunk_size):
                chunk = audio[:, i:i + chunk_size]
                _, acoustic, content = self.model.model.base_encoder.forward(chunk.unsqueeze(1))
                v = acoustic if self.mode == 'acoustic' else content
                v = F.normalize(v, dim=1)
                frames.append(v)

        # Concatenate the embeddings
        emb = torch.cat(frames, dim=0)  # [timeframes, 512]
        return emb

    def load_wav(self, wav_file: Path):
        """
        Load `wav_file` at the model's sample rate and return a float32 array of
        shape (1, n_samples) with values scaled to the 16-bit integer range.
        """
        x, _ = librosa.load(wav_file, sr=self.sr)

        # Scale to the 16-bit integer range and round. `np.float64` replaces the
        # `np.float` alias, which was removed in NumPy 1.24.
        x = np.round(x.astype(np.float64) * 32768)
        x = np.reshape(x, [1, -1])
        x = np.float32(x)

        return x
class CLAPModel(ModelLoader):
    """
    CLAP model from https://github.com/microsoft/CLAP
    """
    def __init__(self, type: Literal['2023']):
        """
        Resolve the checkpoint for `type`, downloading it if needed.

        :param type: Model version; only '2023' has a checkpoint URL here.
        :raises ValueError: If `type` is not supported.
        """
        super().__init__(f"clap-{type}", 1024, 44100)
        self.type = type

        if type == '2023':
            url = 'https://huggingface.co/microsoft/msclap/resolve/main/CLAP_weights_2023.pth'
        else:
            # Fail with a clear message instead of an UnboundLocalError on `url` below
            raise ValueError(f"Unsupported CLAP type {type!r}; expected '2023'")

        self.model_file = Path(__file__).parent / ".model-checkpoints" / url.split('/')[-1]

        # Download file if it doesn't exist
        if not self.model_file.exists():
            self.model_file.parent.mkdir(parents=True, exist_ok=True)
            download_file(url, self.model_file)

    def load_model(self):
        """Create the msclap wrapper from the checkpoint, using CUDA when the loader's device is CUDA."""
        from msclap import CLAP

        self.model = CLAP(self.model_file, version=self.type, use_cuda=self.device == torch.device('cuda'))
        #self.model.to(self.device)

    def _get_embedding(self, audio: np.ndarray) -> np.ndarray:
        """
        Embed `audio` in 7-second chunks taken every second.

        Chunks shorter than 7 seconds are zero-padded at the end.
        """
        audio = audio.reshape(1, -1)

        # The int16-float32 conversion is used for quantization
        #audio = self.int16_to_float32(self.float32_to_int16(audio))

        # Split the audio into 7s chunks with 1s hop
        chunk_size = 7 * self.sr  # 7 seconds
        hop_size = self.sr  # 1 second
        chunks = [audio[:, i:i + chunk_size] for i in range(0, audio.shape[1], hop_size)]

        # zero-pad chunks to make equal length
        longest = max(ch.shape[1] for ch in chunks)
        chunks = [np.pad(ch, ((0, 0), (0, longest - ch.shape[1]))) for ch in chunks]

        # NOTE(review): the return value is discarded, so this call looks like it has
        # no effect on the result - confirm against msclap before removing it.
        self.model.default_collate(chunks)

        # Calculate embeddings for each chunk
        embeddings = []
        with torch.no_grad():
            for chunk in chunks:
                # Audio shorter than one chunk is still shorter than chunk_size here
                missing = chunk_size - chunk.shape[1]
                if missing:
                    chunk = np.pad(chunk, ((0, 0), (0, missing)))
                chunk = torch.from_numpy(chunk).float().to(self.device)
                emb = self.model.clap.audio_encoder(chunk)[0]
                embeddings.append(emb)

        # Concatenate the embeddings
        emb = torch.cat(embeddings, dim=0)  # [timeframes, 1024]
        return emb

    def int16_to_float32(self, x):
        """Scale int16-range values by 1/32767 and return them as float32."""
        return (x / 32767.0).astype(np.float32)

    def float32_to_int16(self, x):
        """Clip to [-1, 1], scale by 32767 and cast to int16."""
        x = np.clip(x, a_min=-1., a_max=1.)
        return (x * 32767.).astype(np.int16)
class W2V2Model(ModelLoader):
    """
    W2V2 model from https://huggingface.co/facebook/wav2vec2-base-960h, https://huggingface.co/facebook/wav2vec2-large-960h

    Please specify the size ('base' or 'large') and the layer to use (1-12 for 'base' or 1-24 for 'large').
    """
    def __init__(self, size: Literal['base', 'large'], layer: int, limit_minutes=6):
        """
        :param size: 'base' or 'large'.
        :param layer: Index into the stacked hidden states used as the embedding.
            Numeric strings such as '12' are accepted and converted to int.
        :param limit_minutes: Longer audio is truncated to this many minutes.
        """
        # The layer is compared against ints and used as an index, so normalize it once
        layer = int(layer)
        model_dim = 768 if size == 'base' else 1024
        # The last layer of each size is the default and gets no suffix in the name
        is_last_layer = (layer == 12 and size == 'base') or (layer == 24 and size == 'large')
        model_identifier = f"w2v2-{size}" + ("" if is_last_layer else f"-{layer}")

        super().__init__(model_identifier, model_dim, 16000)
        self.huggingface_id = f"facebook/wav2vec2-{size}-960h"
        self.layer = layer
        # Keep the limit an int so it can be used as a slice bound even for fractional minutes
        self.limit = int(limit_minutes * 60 * self.sr)

    def load_model(self):
        """Load the wav2vec 2.0 model and its processor from the Hugging Face hub."""
        from transformers import AutoProcessor, Wav2Vec2Model

        self.model = Wav2Vec2Model.from_pretrained(self.huggingface_id)
        self.processor = AutoProcessor.from_pretrained(self.huggingface_id)
        self.model.to(self.device)

    def _get_embedding(self, audio: np.ndarray) -> np.ndarray:
        """Return the hidden states of the configured layer for `audio`."""
        # Limit to specified minutes
        if audio.shape[0] > self.limit:
            log.warning(f"Audio is too long ({audio.shape[0] / self.sr / 60:.2f} minutes > {self.limit / self.sr / 60:.2f} minutes). Truncating.")
            audio = audio[:self.limit]

        inputs = self.processor(audio, sampling_rate=self.sr, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out = self.model(**inputs, output_hidden_states=True)
            out = torch.stack(out.hidden_states).squeeze()  # [13 or 25 layers, timeframes, 768 or 1024]
            out = out[self.layer]  # [timeframes, 768 or 1024]

        return out
class HuBERTModel(ModelLoader):
    """
    HuBERT model from https://huggingface.co/facebook/hubert-base-ls960, https://huggingface.co/facebook/hubert-large-ls960

    Please specify the size ('base' or 'large') and the layer to use (1-12 for 'base' or 1-24 for 'large').
    """
    def __init__(self, size: Literal['base', 'large'], layer: int, limit_minutes=6):
        """
        :param size: 'base' or 'large'.
        :param layer: Index into the stacked hidden states used as the embedding.
            Numeric strings such as '12' are accepted and converted to int.
        :param limit_minutes: Longer audio is truncated to this many minutes.
        """
        # The layer is compared against ints and used as an index, so normalize it once
        layer = int(layer)
        model_dim = 768 if size == 'base' else 1024
        # The last layer of each size is the default and gets no suffix in the name
        is_last_layer = (layer == 12 and size == 'base') or (layer == 24 and size == 'large')
        model_identifier = f"hubert-{size}" + ("" if is_last_layer else f"-{layer}")

        super().__init__(model_identifier, model_dim, 16000)
        self.huggingface_id = f"facebook/hubert-{size}-ls960"
        self.layer = layer
        # Keep the limit an int so it can be used as a slice bound even for fractional minutes
        self.limit = int(limit_minutes * 60 * self.sr)

    def load_model(self):
        """Load the HuBERT model for the configured size, plus a processor, from the Hugging Face hub."""
        from transformers import AutoProcessor, HubertModel

        self.model = HubertModel.from_pretrained(self.huggingface_id)
        # The processor always comes from the fine-tuned large checkpoint, whatever the size.
        # NOTE(review): presumably because the pre-trained checkpoints ship no processor
        # files - confirm that its preprocessing settings suit the 'base' model too.
        self.processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
        self.model.to(self.device)

    def _get_embedding(self, audio: np.ndarray) -> np.ndarray:
        """Return the hidden states of the configured layer for `audio`."""
        # Limit to specified minutes
        if audio.shape[0] > self.limit:
            log.warning(f"Audio is too long ({audio.shape[0] / self.sr / 60:.2f} minutes > {self.limit / self.sr / 60:.2f} minutes). Truncating.")
            audio = audio[:self.limit]

        inputs = self.processor(audio, sampling_rate=self.sr, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out = self.model(**inputs, output_hidden_states=True)
            out = torch.stack(out.hidden_states).squeeze()  # [13 or 25 layers, timeframes, 768 or 1024]
            out = out[self.layer]  # [timeframes, 768 or 1024]

        return out
class WavLMModel(ModelLoader):
    """
    WavLM model (see https://huggingface.co/microsoft/wavlm-base, https://huggingface.co/microsoft/wavlm-base-plus, https://huggingface.co/microsoft/wavlm-large)

    The checkpoint actually loaded is `patrickvonplaten/wavlm-libri-clean-100h-{size}`.
    Please specify the model size ('base', 'base-plus', or 'large') and the layer to use (1-12 for 'base' or 'base-plus' and 1-24 for 'large').
    """
    def __init__(self, size: Literal['base', 'base-plus', 'large'], layer: int, limit_minutes=6):
        """
        :param size: 'base', 'base-plus' or 'large'.
        :param layer: Index into the stacked hidden states used as the embedding.
            Numeric strings such as '12' are accepted and converted to int.
        :param limit_minutes: Longer audio is truncated to this many minutes.
        """
        # The layer is compared against ints and used as an index, so normalize it once
        layer = int(layer)
        is_base_family = size in ['base', 'base-plus']
        model_dim = 768 if is_base_family else 1024
        # The last layer of each size is the default and gets no suffix in the name
        is_last_layer = (layer == 12 and is_base_family) or (layer == 24 and size == 'large')
        model_identifier = f"wavlm-{size}" + ("" if is_last_layer else f"-{layer}")

        super().__init__(model_identifier, model_dim, 16000)
        self.huggingface_id = f"patrickvonplaten/wavlm-libri-clean-100h-{size}"
        self.layer = layer
        # Keep the limit an int so it can be used as a slice bound even for fractional minutes
        self.limit = int(limit_minutes * 60 * self.sr)

    def load_model(self):
        """Load the WavLM model and its processor from the Hugging Face hub."""
        # Aliased so the transformers class is not confused with this loader class
        from transformers import AutoProcessor
        from transformers import WavLMModel as HFWavLMModel

        self.model = HFWavLMModel.from_pretrained(self.huggingface_id)
        self.processor = AutoProcessor.from_pretrained(self.huggingface_id)
        self.model.to(self.device)

    def _get_embedding(self, audio: np.ndarray) -> np.ndarray:
        """Return the hidden states of the configured layer for `audio`."""
        # Limit to specified minutes
        if audio.shape[0] > self.limit:
            log.warning(f"Audio is too long ({audio.shape[0] / self.sr / 60:.2f} minutes > {self.limit / self.sr / 60:.2f} minutes). Truncating.")
            audio = audio[:self.limit]

        inputs = self.processor(audio, sampling_rate=self.sr, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out = self.model(**inputs, output_hidden_states=True)
            out = torch.stack(out.hidden_states).squeeze()  # [13 or 25 layers, timeframes, 768 or 1024]
            out = out[self.layer]  # [timeframes, 768 or 1024]

        return out
class WhisperModel(ModelLoader):
    """
    Whisper model from https://huggingface.co/openai/whisper-base

    Please specify the model size ('tiny', 'base', 'small', 'medium', or 'large').
    """
    def __init__(self, size: Literal['tiny', 'base', 'small', 'medium', 'large']):
        """
        :param size: Checkpoint size; the model is loaded from `openai/whisper-{size}`.
        """
        dimensions = {
            'tiny': 384,
            'base': 512,
            'small': 768,
            'medium': 1024,
            'large': 1280
        }
        # Sizes missing from the table get num_features=None
        model_dim = dimensions.get(size)
        model_identifier = f"whisper-{size}"

        super().__init__(model_identifier, model_dim, 16000)
        self.huggingface_id = f"openai/whisper-{size}"

    def load_model(self):
        """Load the Whisper model and its feature extractor from the Hugging Face hub."""
        from transformers import AutoFeatureExtractor
        # Aliased so the transformers class is not confused with this loader class
        from transformers import WhisperModel as HFWhisperModel

        self.model = HFWhisperModel.from_pretrained(self.huggingface_id)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.huggingface_id)
        self.model.to(self.device)

    def _get_embedding(self, audio: np.ndarray) -> np.ndarray:
        """Return the model's `last_hidden_state` for `audio`, squeezed."""
        inputs = self.feature_extractor(audio, sampling_rate=self.sr, return_tensors="pt").to(self.device)
        input_features = inputs.input_features

        # Two start tokens, shape [1, 2]. Created on the model's device: a CPU tensor
        # here would clash with the model and the features when running on CUDA.
        decoder_input_ids = torch.tensor([[1, 1]], device=self.device) * self.model.config.decoder_start_token_id

        with torch.no_grad():
            # NOTE(review): per the transformers docs `last_hidden_state` is the decoder
            # output, so the sequence axis presumably has one entry per decoder input id
            # (2) rather than one per audio frame - confirm this is the intended embedding.
            out = self.model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state  # [1, sequence, features]
            out = out.squeeze()  # [sequence, 384 or 512 or 768 or 1024 or 1280]

        return out
def get_all_models() -> list[ModelLoader]:
    """
    Instantiate every supported model loader.

    The DAC and CDPAM loaders are appended only when their packages can be found.
    """
    base_layers = range(1, 13)
    large_layers = range(1, 25)

    models = [CLAPModel('2023')]
    models.extend(CLAPLaionModel(kind) for kind in ('audio', 'music'))
    models.append(VGGishModel())
    models.extend(MERTModel(layer=n) for n in base_layers)
    models.extend(EncodecEmbModel(variant) for variant in ('24k', '48k'))

    # One loader per layer for each transformer-based speech model
    models.extend(W2V2Model('base', layer=n) for n in base_layers)
    models.extend(W2V2Model('large', layer=n) for n in large_layers)
    models.extend(HuBERTModel('base', layer=n) for n in base_layers)
    models.extend(HuBERTModel('large', layer=n) for n in large_layers)
    models.extend(WavLMModel('base', layer=n) for n in base_layers)
    models.extend(WavLMModel('base-plus', layer=n) for n in base_layers)
    models.extend(WavLMModel('large', layer=n) for n in large_layers)

    models.extend(WhisperModel(size) for size in ('tiny', 'small', 'base', 'medium', 'large'))

    # Optional dependencies
    dac_available = importlib.util.find_spec("dac") is not None
    if dac_available:
        models.append(DACModel())

    cdpam_available = importlib.util.find_spec("cdpam") is not None
    if cdpam_available:
        models.extend([CdpamModel('acoustic'), CdpamModel('content')])

    return models