Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(whisper): support whisper arch #2141

Merged
merged 38 commits into from Nov 22, 2023
Merged

feat(whisper): support whisper arch #2141

merged 38 commits into from Nov 22, 2023

Conversation

xingchensong
Copy link
Member

@xingchensong xingchensong commented Nov 13, 2023

whisper基本结构约等于wenet.transformer,复用大部分代码+修改ckpt命名 而不是直接将 openai-whisper的模型定义copy过来,是出于以下几点考虑:

  1. 方便直接进行u2pp-like流式训练,所有训练接口均无缝适配
  2. c++推理的接口同样无缝适配
  3. whisper的ckpt发布后不会被修改,所以无需担心ckpt的tensor name会变,也即ckpt的转换工作是一次性的

TODO (This PR)

TODO (Next PR)

@Mddct Mddct self-requested a review November 14, 2023 01:17
@robin1001 robin1001 self-requested a review November 14, 2023 01:52
@robin1001
Copy link
Collaborator

Looking forward to the whisper powered by wenet.

@xingchensong
Copy link
Member Author

Confirm that torchaudio is equal to whisper.load_audio:

import torchaudio
import numpy as np

from subprocess import CalledProcessError, run


wav_file = "BAC009S0724W0121.wav"

# 1. torchaudio
waveform_torchaudio, sample_rate = torchaudio.load(wav_file)
waveform_torchaudio = waveform_torchaudio.numpy().flatten().astype(np.float32)

# 2. whisper
SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
N_FRAMES = N_SAMPLES // HOP_LENGTH  # 3000 frames in a mel spectrogram input
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH  # 10ms per audio frame
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN  # 20ms per audio token

def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """

    # This launches a subprocess to decode audio while down-mixing
    # and resampling as necessary.  Requires the ffmpeg CLI in PATH.
    # fmt: off
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "-"
    ]
    # fmt: on
    try:
        out = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0

wavform_whisper = load_audio(wav_file)

# 3. compare
print(waveform_torchaudio.shape, waveform_torchaudio[:10])
print(wavform_whisper.shape, wavform_whisper[:10])
print(np.allclose(waveform_torchaudio, wavform_whisper))

image

@Mddct
Copy link
Collaborator

Mddct commented Nov 14, 2023

fbank这一块需要改动吗

@xingchensong
Copy link
Member Author

fbank这一块需要改动吗

正在验证的就是这个

@xingchensong
Copy link
Member Author

Compared to whisper.log_mel_spectrogram, the precision error of torchaudio and librosa is 1e-4 and 1e-6 respectively, thus we prefer to use librosa:

import torch
import torchaudio
import librosa
import numpy as np
import torch.nn.functional as F
import torchaudio.transforms as T

from functools import lru_cache
from typing import Optional, Union
from subprocess import CalledProcessError, run


wav_file = "BAC009S0724W0121.wav"
N_MEL = 128
SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
WINDOW_LENGTH = N_FFT


# 1. torchaudio
waveform_torchaudio, sample_rate = torchaudio.load(wav_file)

def torchaudio_log_mel_spectrogram(
    audio: torch.Tensor,
    n_mels: int = 80,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2

    mel_transform = T.MelScale(
        n_mels=N_MEL,
        sample_rate=SAMPLE_RATE,
        n_stft=(N_FFT // 2) + 1,
        norm="slaney",
        mel_scale="slaney"
    )
    mel_spec = mel_transform(magnitudes).squeeze(0)

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec

mat_torchaudio = torchaudio_log_mel_spectrogram(waveform_torchaudio).numpy()


# 2. librosa
def librosa_log_mel_spectrogram(
    audio: torch.Tensor,
    n_mels: int = 80,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2

    filters = torch.from_numpy(
        librosa.filters.mel(sr=SAMPLE_RATE, n_fft=N_FFT, n_mels=N_MEL)
    ).to(magnitudes.device)
    mel_spec = filters @ magnitudes

    log_spec = torch.clamp(mel_spec, min=1e-10).log10().squeeze(0)
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec

mat_librosa = librosa_log_mel_spectrogram(waveform_torchaudio).numpy()


# 3. whisper
def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """

    # This launches a subprocess to decode audio while down-mixing
    # and resampling as necessary.  Requires the ffmpeg CLI in PATH.
    # fmt: off
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "-"
    ]
    # fmt: on
    try:
        out = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0


@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
    """
    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
        )
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"

    filters_path = "mel_filters.npz"
    with np.load(filters_path, allow_pickle=False) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)


def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = 80,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the log-Mel spectrogram of

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor
        containing the audio waveform in 16 kHz

    n_mels: int
        The number of Mel-frequency filters, only 80 is supported

    padding: int
        Number of zero samples to pad to the right

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor, shape = (80, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2

    filters = mel_filters(audio.device, n_mels)
    mel_spec = filters @ magnitudes

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec

mat_whisper = log_mel_spectrogram(
    wav_file, n_mels=N_MEL, padding=0, device="cpu").numpy()


# 3. compare
print("torchaudio\n", mat_torchaudio.shape, mat_torchaudio[:10])
print("librosa\n", mat_librosa.shape, mat_librosa[:10])
print("whisper\n", mat_whisper.shape, mat_whisper[:10])
print("=================== librosa v.s. whisper =====================")
print("librosa v.s. whisper", np.allclose(mat_librosa, mat_whisper, atol=1e-06))
np.testing.assert_allclose(mat_librosa, mat_whisper, atol=1e-06)
print("=================== torchaudio v.s. whisper =====================")
print("torchaudio v.s. whisper", np.allclose(mat_torchaudio, mat_whisper, atol=1e-04))
np.testing.assert_allclose(mat_torchaudio, mat_whisper, atol=1e-04)

image

@xingchensong
Copy link
Member Author

xingchensong commented Nov 15, 2023

image

confirm that wenet.SinusoidalPositionalEncoding is different from whisper.SinusoidalPositionalEncoding, we need a new PE class for whisper

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright [2023-11-15] <sxc19@mails.tsinghua.edu.cn, Xingchen Song>

import torch
import math

import numpy as np


def wenet_sinusoids(length, channels):
    """Returns sinusoids for positional embedding"""
    d_model = channels
    xscale = math.sqrt(d_model)
    max_len = length

    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len,
                            dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) *
        -(math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe


def whisper_sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)


wenet_pe = wenet_sinusoids(100, 512).numpy()
whisper_pe = whisper_sinusoids(100, 512).numpy()
print(wenet_pe.shape)
print(whisper_pe.shape)
np.testing.assert_allclose(wenet_pe, whisper_pe, atol=1e-8)

image

@xingchensong
Copy link
Member Author

load checkpoint succeed:

image

@xingchensong
Copy link
Member Author

xingchensong commented Nov 16, 2023

Compared to whisper.log_mel_spectrogram, the precision error of torchaudio and librosa is 1e-4 and 1e-6 respectively, thus we prefer to use librosa:

有个提议,直接把 whisper 放进 requirments.txt 如何,这样可以直接调用 whiper.log_mel_spectrogram,使用他的tokenizer也比较方便,同时省去了对齐的麻烦,@robin1001 @Mddct

image

whisper的requirments和我们的基本不冲突

image

@Mddct
Copy link
Collaborator

Mddct commented Nov 16, 2023

model 也要用whisper仓库的吗?
我看着huggface也有:
Screenshot 2023-11-16 at 11 50 30

@xingchensong
Copy link
Member Author

xingchensong commented Nov 16, 2023

引入whisper的库只用他的log_mel_spec和tokenizer,引入这个库,那啥mel_filter.npz和xx.tiktoken也不用手动下载了。cli也可以复用他的download相关函数来下载ckpt

@xingchensong
Copy link
Member Author

pass training
2c029365-3aa1-43f3-a9e0-34fedcab3181

@xingchensong
Copy link
Member Author

whisper-style decoder input:

image

def add_whisper_tokens(
    tokenizer, ys_pad: torch.Tensor,
    ignore_id: int, task_id: int, no_timestamp: bool,
    language: str, use_prev: bool
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add whisper-style tokens.
    ([PREV] -> [previous text tokens or hotwords]).optional --
      ┌------------------------------------------------------↲

    [sot] -> [language id] -> [transcribe] -> [begin time] -> [text tokens] -> [end time] -> ... -> [eot]    # noqa
        |          |                |-------> [no timestamps] -> [text tokens] ----------------------↑       # noqa
        |          |                                                                                 |       # noqa
        |          |--------> [translate]  -> [begin time] -> [text tokens] -> [end time] -> ... --->|       # noqa
        |                           |-------> [no timestamps] -> [text tokens] --------------------->|       # noqa
        |                                                                                            |       # noqa
        |--> [no speech(VAD)] ---------------------------------------------------------------------->|       # noqa
    Args:
        tokenizer: get IDs of special tokens
        ignore_id (int): index of padding
        no_timestamp (bool): whether to add timestamps tokens
        language (str): language tag
    Returns:
        ys_in (torch.Tensor) : (B, Lmax + ?)
        ys_out (torch.Tensor) : (B, Lmax + ?)
    """
    if use_prev:
        # i.e., hotword list
        _prev = [tokenizer.sot_prev]
        # append hotword list to _prev
        # ...
        raise NotImplementedError
    else:
        _prev = []

    language_id = tokenizer.sot + 1 + WHISPER_LANGS.index(language)
    _sot = _prev + [tokenizer.sot, language_id, task_id]
    _eot = torch.tensor([tokenizer.eot],
                        dtype=torch.long,
                        requires_grad=False,
                        device=ys_pad.device)
    ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys

    if task_id == tokenizer.transcribe or task_id == tokenizer.translate:
        if no_timestamp:
            _sot.append(tokenizer.no_timestamps)
        else:
            _sot.append(tokenizer.timestamp_begin)
            # add subsequent tokens
            # ...
            raise NotImplementedError
    elif task_id == tokenizer.no_speech:
        _sot.append(tokenizer.no_speech)
    else:
        raise NotImplementedError

    _sot = torch.tensor(_sot, dtype=torch.long,
                        requires_grad=False, device=ys_pad.device)
    ys_in = [torch.cat([_sot, y], dim=0) for y in ys]
    ys_out = [torch.cat([_sot[1:], y, _eot], dim=0) for y in ys]
    return pad_list(ys_in, tokenizer.eot), pad_list(ys_out, ignore_id)

@xingchensong
Copy link
Member Author

xingchensong commented Nov 22, 2023

增加了unit test (单元测试包含 tiny/base/small/medium,large太大,github的cicd机器跑不起来),数值对齐结果如下:

  1. encoder_out, fp32 数值差异最大为 6e-3
  2. decoder_out, fp32 数值差异最大为 6e-2
  3. softmax(decoder_out), fp32 数值差异最大为 1e-10

从 softmax 后的数值可以看出,概率归一化之后,应该不影响解码结果

@xingchensong xingchensong marked this pull request as ready for review November 22, 2023 09:17
@xingchensong
Copy link
Member Author

xingchensong commented Nov 22, 2023

pass unit test, ready for final review @Mddct @robin1001 @whiteshirt0429

p.s. 现在单元测试大概需要7~8min

@xingchensong xingchensong changed the title feat(whisper)[WIP]: support whisper arch feat(whisper): support whisper arch Nov 22, 2023
@xingchensong
Copy link
Member Author

xingchensong commented Nov 22, 2023

From binbin: whisper.log_mel_spectrogramkaldi.fank 区别主要是前面的预加重和后面的归一化,后续有可能的话最好和 fbank 合并。实在不方便的话,可以参考fangjun csukuangfj/kaldifeat#82
在runtime中针对 whisper.log_mel_spectrogram 单独实现一下

@xingchensong
Copy link
Member Author

decoder.tie_or_clone_weight() 的作用是 emb 和 分类的 projection 训练时共享参数,导出时克隆参数(因为jit不支持参数共享)

wenet/whisper/whisper.py Outdated Show resolved Hide resolved
@robin1001 robin1001 merged commit 3b977ae into main Nov 22, 2023
6 checks passed
@robin1001 robin1001 deleted the xcsong-whisper branch November 22, 2023 15:16
@xingchensong
Copy link
Member Author

xingchensong commented Nov 23, 2023

一些开源工具在集成 whisper 时的相关实现(供参考):

  1. espnet, Add Full Whisper Model for Finetuning espnet/espnet#4793, 直接 model = whisper.load_model(),无法对模型结构魔改,不易流式改造
  2. funasr, add whisper model inference pipeline alibaba-damo-academy/FunASR#1003, 直接copy所有whisper的相关文件(模型定义,字典等)过来,只支持推理
  3. hugginface, Add WhisperModel to transformers huggingface/transformers#19166, 类似wenet的方案,先转openai的ckpt到hf的格式(tensor名字,shape等),然后单独为whisper定义了model以及forward流程

@xingchensong
Copy link
Member Author

finetuinng on aishell (u2 transformer, enc&dec init from whisper, ctc init randomly) seems work (torch-ddp in red, deepspeed-stage1 in blue)

image

@xingchensong
Copy link
Member Author

xingchensong commented Nov 23, 2023

指定 50362 为ctc的 blank_id, 原因如下:#2157

@xingchensong
Copy link
Member Author

whisper 潜在问题,whisper的encoder和decoder均采用固定长度emb(encoder 30s, decoder 448字符),因此当开启speed_perturb,变速1.1时 30s的audio可能会造成assertion error #2171

@xingchensong
Copy link
Member Author

xingchensong commented Nov 30, 2023

whisper的tokenizer 和 中文常用的char tokenizer,效率对比如下:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright [2023-11-30] <sxc19@mails.tsinghua.edu.cn, Xingchen Song>

import statistics as s

from whisper.tokenizer import get_tokenizer


tokenizer = get_tokenizer(multilingual=True, num_languages=100)
char_lens, token_lens = [], []
with open("data/train/text", "r") as f:
    lines = f.readlines()
    for l in lines:
        l = l.strip().split()[1]
        char_len = len(l)
        token_len = len(tokenizer.encoding.encode(l))
        char_lens.append(char_len)
        token_lens.append(token_len)
        print("{} {} {}".format(char_len, token_len, l))


print("Mean: CharTokenizer {}, WhisperTokenizer {}".format(s.mean(char_lens),
                                                           s.mean(token_lens)))
print("Var : CharTokenizer {}, WhisperTokenizer {}".format(s.variance(char_lens),
                                                           s.variance(token_lens)))
...
16 21 他还增设了一种叫消防费的收费项目
10 11 总共收上来一四万馀元
8 9 都揣进了自己腰包
20 32 改装小货车撞上出租车钢筋将出租车后座顶出
12 14 八月二零日晚将近一零点钟
15 23 一辆改装的小货车撞上一辆出租车
15 25 货车装载的钢筋将出租车后座对穿
9 12 幸亏车后座没有乘客
13 20 钢筋穿过后排座位并顶出车外
11 11 案件性质关系到国家安全

Mean: CharTokenizer 14.405843561091775, WhisperTokenizer 17.837832436843243
Var : CharTokenizer 18.59714879629654, WhisperTokenizer 35.86092595601854

@xingchensong
Copy link
Member Author

xingchensong commented Nov 30, 2023

whisper的tokenizer和英文常用的bpe tokenizer,效率对比如下:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright [2023-11-30] <sxc19@mails.tsinghua.edu.cn, Xingchen Song>

import statistics as s
import sentencepiece as spm

from whisper.tokenizer import get_tokenizer

from wenet.text.tokenize_utils import tokenize_by_bpe_model


tokenizer = get_tokenizer(multilingual=True, num_languages=100)
bpe_model = spm.SentencePieceProcessor()
bpe_model.load("../../../test/resources/librispeech.train_960_unigram5000.bpemodel")
bpe_lens, token_lens = [], []
with open("data/test_clean/text", "r") as f:
    lines = f.readlines()
    for l in lines:
        l = " ".join(l.strip().split()[1:])
        bpe_len = len(tokenize_by_bpe_model(bpe_model, l))
        token_len = len(tokenizer.encoding.encode(l))
        bpe_lens.append(bpe_len)
        token_lens.append(token_len)
        print("{} {} {}".format(bpe_len, token_len, l))


print("Mean: BpeTokenizer {}, WhisperTokenizer {}".format(s.mean(bpe_lens),
                                                          s.mean(token_lens)))
print("Var : BpeTokenizer {}, WhisperTokenizer {}".format(s.variance(bpe_lens),
                                                          s.variance(token_lens)))
...
18 25 I THANK ALL WHO HAVE LOVED ME IN THEIR HEARTS WITH THANKS AND LOVE FROM MINE
36 62 THEN I LONG TRIED BY NATURAL ILLS RECEIVED THE COMFORT FAST WHILE BUDDING AT THY SIGHT MY PILGRIM'S STAFF GAVE OUT GREEN LEAVES WITH MORNING DEWS IMPEARLED
24 32 I LOVE THEE FREELY AS MEN STRIVE FOR RIGHT I LOVE THEE PURELY AS THEY TURN FROM PRAISE
22 34 I LOVE THEE WITH THE PASSION PUT TO USE IN MY OLD GRIEFS AND WITH MY CHILDHOOD'S FAITH
43 58 I LOVE THEE WITH A LOVE I SEEMED TO LOSE WITH MY LOST SAINTS I LOVE THEE WITH THE BREATH SMILES TEARS OF ALL MY LIFE AND IF GOD CHOOSE I SHALL BUT LOVE THEE BETTER AFTER DEATH

Mean: BpeTokenizer 25.176335877862595, WhisperTokenizer 40.09809160305343
Var : BpeTokenizer 314.1025325790101, WhisperTokenizer 814.1969417556378
...

@xingchensong
Copy link
Member Author

The training of a 1.5 billion parameter model places high demands on CPU memory, and issues like the following are likely to occur on systems with memory equal to or less than 160GB.

pytorch/pytorch#8976

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants