feat(whisper): support whisper arch (#2141)
* feat(whisper): support whisper arch

* feat(whisper): add arg, key_bias

* feat(whisper): success load

* feat(whisper): format tensor name & dtype

* feat(whisper): add arg, key_bias

* feat(whisper): fix convert

* feat(whisper): add config

* feat(whisper): add config

* feat(whisper): fix lint

* feat(whisper): add compute_log_mel_spectrogram

* feat(whisper): add compute_log_mel_spectrogram

* feat(whisper): add compute_log_mel_spectrogram

* feat(whisper): add compute_log_mel_spectrogram

* feat(whisper): add ref

* feat(whisper): convert to units.txt

* feat(whisper): support learnable pe

* feat(whisper): support whisper pe

* feat(whisper): fix lint

* feat(whisper): fix config

* feat(whisper): load success

* feat(whisper): remove time align, add it in the future

* feat(whisper): load succeed

* feat(whisper): fix lint

* feat(whisper): fix property

* feat(whisper): training success

* feat(whisper): fix lint

* feat(whisper): fix cal_att_loss

* feat(whisper): add log for tie weights

* feat(whisper): fix dict

* feat(whisper): add unit test for pe and mel

* feat(whisper): fix apt install

* feat(whisper): pass unit test

* feat(whisper): pass unit test

* feat(whisper): try to pass unit test

* feat(whisper): try to pass unit test

* feat(whisper): add doc

* feat(whisper): remove r_att

* feat(whisper): fix lint
xingchensong committed Nov 22, 2023
1 parent db25890 commit 3b977ae
Showing 18 changed files with 1,030 additions and 21 deletions.
1 change: 1 addition & 0 deletions .github/workflows/unit_test.yml
@@ -40,6 +40,7 @@ jobs:
run: |
set -eux
pip install -r requirements.txt
sudo apt update && sudo apt install -y ffmpeg
- name: Run Pytest
run: |
set -eux
2 changes: 2 additions & 0 deletions requirements.txt
@@ -16,3 +16,5 @@ pyflakes==2.2.0
torch==1.13.0
torchaudio==0.13.0
deepspeed
librosa
openai-whisper
Binary file added test/resources/aishell-BAC009S0724W0121.wav
Binary file added test/resources/librispeech-1995-1837-0001.wav
358 changes: 358 additions & 0 deletions test/wenet/whisper/test_whisper.py

Large diffs are not rendered by default.

12 changes: 9 additions & 3 deletions wenet/dataset/dataset.py
@@ -122,7 +122,8 @@ def Dataset(data_type,
conf,
bpe_model=None,
non_lang_syms=None,
partition=True):
partition=True,
whisper_tokenizer=None):
""" Construct dataset from arguments
We have two shuffle stage in the Dataset. The first is global
@@ -145,7 +146,8 @@ def Dataset(data_type,
dataset = Processor(dataset, processor.parse_raw)

dataset = Processor(dataset, processor.tokenize, symbol_table, bpe_model,
non_lang_syms, conf.get('split_with_space', False))
non_lang_syms, conf.get('split_with_space', False),
whisper_tokenizer)
filter_conf = conf.get('filter_conf', {})
dataset = Processor(dataset, processor.filter, **filter_conf)

@@ -157,13 +159,17 @@ def Dataset(data_type,
dataset = Processor(dataset, processor.speed_perturb)

feats_type = conf.get('feats_type', 'fbank')
assert feats_type in ['fbank', 'mfcc']
assert feats_type in ['fbank', 'mfcc', 'log_mel_spectrogram']
if feats_type == 'fbank':
fbank_conf = conf.get('fbank_conf', {})
dataset = Processor(dataset, processor.compute_fbank, **fbank_conf)
elif feats_type == 'mfcc':
mfcc_conf = conf.get('mfcc_conf', {})
dataset = Processor(dataset, processor.compute_mfcc, **mfcc_conf)
elif feats_type == 'log_mel_spectrogram':
log_mel_spectrogram_conf = conf.get('log_mel_spectrogram_conf', {})
dataset = Processor(dataset, processor.compute_log_mel_spectrogram,
**log_mel_spectrogram_conf)

spec_aug = conf.get('spec_aug', True)
spec_sub = conf.get('spec_sub', False)
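With this change, a dataset config can select the new Whisper-style front end
through feats_type. A minimal illustrative fragment of the conf dict consumed
by Dataset() is sketched below; the feats_type / log_mel_spectrogram_conf keys
and their default values come from this diff, while everything else a real
config would carry (filter_conf, batch_conf, ...) is elided.

    conf = {
        'feats_type': 'log_mel_spectrogram',
        'log_mel_spectrogram_conf': {
            'n_fft': 400,        # 25 ms window at 16 kHz
            'hop_length': 160,   # 10 ms hop at 16 kHz
            'num_mel_bins': 80,
            'padding': 0,
        },
        # ... other dataset options unchanged ...
    }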
52 changes: 51 additions & 1 deletion wenet/dataset/processor.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import librosa
import logging
import json
import random
@@ -23,6 +24,7 @@
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from wenet.utils.tokenize_utils import tokenize_by_bpe_model

@@ -322,11 +324,54 @@ def compute_mfcc(data,
yield dict(key=sample['key'], label=sample['label'], feat=mat)


def compute_log_mel_spectrogram(data,
n_fft=400,
hop_length=160,
num_mel_bins=80,
padding=0):
""" Extract log mel spectrogram, modified from openai-whisper, see:
- https://github.com/openai/whisper/blob/main/whisper/audio.py
- https://github.com/wenet-e2e/wenet/pull/2141#issuecomment-1811765040
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'wav' in sample
assert 'key' in sample
assert 'label' in sample
sample_rate = sample['sample_rate']
waveform = sample['wav'].squeeze(0) # (channel=1, sample) -> (sample,)
if padding > 0:
waveform = F.pad(waveform, (0, padding))
window = torch.hann_window(n_fft)
stft = torch.stft(waveform, n_fft, hop_length,
window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2

filters = torch.from_numpy(
librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mel_bins)
)
mel_spec = filters @ magnitudes

# NOTE(xcsong): https://github.com/openai/whisper/discussions/269
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
yield dict(key=sample['key'], label=sample['label'],
feat=log_spec.transpose(0, 1))


def tokenize(data,
symbol_table,
bpe_model=None,
non_lang_syms=None,
split_with_space=False):
split_with_space=False,
whisper_tokenizer=None):
""" Decode text to chars or BPE
Inplace operation
@@ -352,6 +397,11 @@ def tokenize(data,
for sample in data:
assert 'txt' in sample
txt = sample['txt'].strip()
# TODO(xcsong): This is a dirty workaround for whisper tokernizer,
# refine it in the future
if whisper_tokenizer is not None:
sample['label'] = whisper_tokenizer.encode(txt)
yield sample
if non_lang_syms_pattern is not None:
parts = non_lang_syms_pattern.split(txt.upper())
parts = [w for w in parts if len(w.strip()) > 0]
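For intuition, the following is a standalone sketch (not part of the diff) that
reproduces the same log-Mel math as compute_log_mel_spectrogram above for a
single mono waveform; the function name and the commented-out usage are
illustrative only, while the formulas mirror the processor code and openai-whisper.

    import librosa
    import torch
    import torchaudio

    def log_mel_for_one_wav(wav_path, n_fft=400, hop_length=160, num_mel_bins=80):
        waveform, sample_rate = torchaudio.load(wav_path)   # (channel, sample)
        waveform = waveform.squeeze(0)                       # assumes mono input
        window = torch.hann_window(n_fft)
        stft = torch.stft(waveform, n_fft, hop_length,
                          window=window, return_complex=True)
        magnitudes = stft[..., :-1].abs() ** 2               # drop last frame, power spectrum
        filters = torch.from_numpy(
            librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mel_bins))
        mel_spec = filters @ magnitudes                       # (num_mel_bins, frames)
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)  # clamp dynamic range
        log_spec = (log_spec + 4.0) / 4.0                     # same rescaling as whisper
        return log_spec.transpose(0, 1)                       # (frames, num_mel_bins)

    # feats = log_mel_for_one_wav('test/resources/librispeech-1995-1837-0001.wav')
    # feats.shape is roughly (num_samples // hop_length, num_mel_bins)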
5 changes: 3 additions & 2 deletions wenet/transformer/attention.py
@@ -32,15 +32,16 @@ class MultiHeadedAttention(nn.Module):
dropout_rate (float): Dropout rate.
"""
def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
def __init__(self, n_head: int, n_feat: int, dropout_rate: float,
key_bias: bool = True):
"""Construct an MultiHeadedAttention object."""
super().__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.dropout = nn.Dropout(p=dropout_rate)
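The new key_bias flag only toggles the bias term of linear_k: openai-whisper
checkpoints store no key bias, so Whisper models pass key_bias=False, while the
default of True leaves existing WeNet configs unchanged. A small illustrative
check, assuming the constructor arguments shown in this diff:

    from wenet.transformer.attention import MultiHeadedAttention

    mha = MultiHeadedAttention(n_head=8, n_feat=512, dropout_rate=0.1,
                               key_bias=False)   # Whisper-style attention
    assert mha.linear_k.bias is None             # no bias on the key projection
    assert mha.linear_q.bias is not None         # query/value/output keep theirs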
61 changes: 55 additions & 6 deletions wenet/transformer/decoder.py
@@ -16,13 +16,16 @@
from typing import Tuple, List, Optional

import torch
import logging

from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.decoder_layer import DecoderLayer
from wenet.transformer.embedding import PositionalEncoding
from wenet.transformer.embedding import NoPositionalEncoding
from wenet.transformer.embedding import LearnablePositionalEncoding
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.utils.mask import (subsequent_mask, make_pad_mask)
from wenet.utils.common import get_activation


class TransformerDecoder(torch.nn.Module):
@@ -43,6 +46,7 @@ class TransformerDecoder(torch.nn.Module):
False: use layer_norm after each sub-block of a layer.
src_attention: if false, encoder-decoder cross attention is not
applied, such as CIF model
key_bias: whether use bias in attention.linear_k, False for whisper models.
"""

def __init__(
@@ -60,9 +64,12 @@ def __init__(
use_output_layer: bool = True,
normalize_before: bool = True,
src_attention: bool = True,
key_bias: bool = True,
activation_type: str = "relu",
):
super().__init__()
attention_dim = encoder_output_size
activation = get_activation(activation_type)

if input_layer == "embed":
self.embed = torch.nn.Sequential(
@@ -72,24 +79,33 @@ def __init__(
elif input_layer == 'none':
self.embed = NoPositionalEncoding(attention_dim,
positional_dropout_rate)
elif input_layer == 'embed_learnable_pe':
self.embed = torch.nn.Sequential(
torch.nn.Embedding(vocab_size, attention_dim),
LearnablePositionalEncoding(attention_dim,
positional_dropout_rate),
)
else:
raise ValueError(f"only 'embed' is supported: {input_layer}")

self.normalize_before = normalize_before
self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
self.use_output_layer = use_output_layer
self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
if use_output_layer:
self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
else:
self.output_layer = torch.nn.Identity()
self.num_blocks = num_blocks
self.decoders = torch.nn.ModuleList([
DecoderLayer(
attention_dim,
MultiHeadedAttention(attention_heads, attention_dim,
self_attention_dropout_rate),
self_attention_dropout_rate, key_bias),
MultiHeadedAttention(attention_heads, attention_dim,
src_attention_dropout_rate)
src_attention_dropout_rate, key_bias)
if src_attention else None,
PositionwiseFeedForward(attention_dim, linear_units,
dropout_rate),
dropout_rate, activation),
dropout_rate,
normalize_before,
) for _ in range(self.num_blocks)
@@ -185,6 +201,29 @@ def forward_one_step(
y = torch.log_softmax(self.output_layer(y), dim=-1)
return y, new_cache

def tie_or_clone_weights(self, jit_mode: bool = True):
"""Tie or clone module weights (between word_emb and output_layer)
depending of whether we are using TorchScript or not"""
if not self.use_output_layer:
return
if jit_mode:
logging.info("clone emb.weight to output.weight")
self.output_layer.weight = torch.nn.Parameter(self.embed[0].weight.clone())
else:
logging.info("tie emb.weight with output.weight")
self.output_layer.weight = self.embed[0].weight

if getattr(self.output_layer, "bias", None) is not None:
self.output_layer.bias.data = torch.nn.functional.pad(
self.output_layer.bias.data,
(
0,
self.output_layer.weight.shape[0] - self.output_layer.bias.shape[0],
),
"constant",
0,
)


class BiTransformerDecoder(torch.nn.Module):
"""Base class of Transfomer decoder module.
@@ -203,6 +242,7 @@ class BiTransformerDecoder(torch.nn.Module):
normalize_before:
True: use layer_norm before each sub-block of a layer.
False: use layer_norm after each sub-block of a layer.
key_bias: whether use bias in attention.linear_k, False for whisper models.
"""

def __init__(
@@ -220,20 +260,23 @@ def __init__(
input_layer: str = "embed",
use_output_layer: bool = True,
normalize_before: bool = True,
key_bias: bool = True,
):

super().__init__()
self.left_decoder = TransformerDecoder(
vocab_size, encoder_output_size, attention_heads, linear_units,
num_blocks, dropout_rate, positional_dropout_rate,
self_attention_dropout_rate, src_attention_dropout_rate,
input_layer, use_output_layer, normalize_before)
input_layer, use_output_layer, normalize_before,
key_bias=key_bias)

self.right_decoder = TransformerDecoder(
vocab_size, encoder_output_size, attention_heads, linear_units,
r_num_blocks, dropout_rate, positional_dropout_rate,
self_attention_dropout_rate, src_attention_dropout_rate,
input_layer, use_output_layer, normalize_before)
input_layer, use_output_layer, normalize_before,
key_bias=key_bias)

def forward(
self,
@@ -294,3 +337,9 @@ def forward_one_step(
"""
return self.left_decoder.forward_one_step(memory, memory_mask, tgt,
tgt_mask, cache)

def tie_or_clone_weights(self, jit_mode: bool = True):
"""Tie or clone module weights (between word_emb and output_layer)
depending of whether we are using TorchScript or not"""
self.left_decoder.tie_or_clone_weights(jit_mode)
self.right_decoder.tie_or_clone_weights(jit_mode)
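tie_or_clone_weights shares the decoder input embedding with the output
projection: tying makes both modules point at one Parameter, so a gradient step
updates both, while cloning copies the values into an independent Parameter,
which is the variant used when exporting with TorchScript. A toy sketch of the
distinction (not part of the diff, module sizes are arbitrary):

    import torch

    emb = torch.nn.Embedding(1000, 64)
    out = torch.nn.Linear(64, 1000, bias=False)

    # tie: both modules now share the exact same Parameter object
    out.weight = emb.weight
    assert out.weight is emb.weight

    # clone: an independent copy with identical values (jit-friendly)
    out.weight = torch.nn.Parameter(emb.weight.clone())
    assert out.weight is not emb.weight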
33 changes: 30 additions & 3 deletions wenet/transformer/embedding.py
@@ -20,6 +20,7 @@

import torch
import torch.nn.functional as F
import numpy as np

class PositionalEncoding(torch.nn.Module):
"""Positional encoding.
@@ -93,13 +94,13 @@ def position_encoding(self, offset: Union[int, torch.Tensor], size: int,
# How to subscript a Union type:
# https://github.com/pytorch/pytorch/issues/69434
if isinstance(offset, int):
assert offset + size < self.max_len
assert offset + size <= self.max_len
pos_emb = self.pe[:, offset:offset + size]
elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
assert offset + size < self.max_len
assert offset + size <= self.max_len
pos_emb = self.pe[:, offset:offset + size]
else: # for batched streaming decoding on GPU
assert torch.max(offset) + size < self.max_len
assert torch.max(offset) + size <= self.max_len
index = offset.unsqueeze(1) + \
torch.arange(0, size).to(offset.device) # B X T
flag = index > 0
@@ -140,6 +141,32 @@ def forward(self,
return self.dropout(x), self.dropout(pos_emb)


class WhisperPositionalEncoding(PositionalEncoding):
""" Sinusoids position encoding used in openai-whisper.encoder
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
super().__init__(d_model, dropout_rate, max_len)
self.xscale = 1.0
log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment *
torch.arange(d_model // 2))
scaled_time = torch.arange(max_len)[:, np.newaxis] * \
inv_timescales[np.newaxis, :]
pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
delattr(self, "pe")
self.register_buffer("pe", pe.unsqueeze(0))


class LearnablePositionalEncoding(PositionalEncoding):
""" Learnable position encoding used in openai-whisper.decoder
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
super().__init__(d_model, dropout_rate, max_len)
# NOTE(xcsong): overwrite self.pe & self.xscale
self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
self.xscale = 1.0


class NoPositionalEncoding(torch.nn.Module):
""" No position encoding
"""
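WhisperPositionalEncoding precomputes the same fixed sinusoid table as the
openai-whisper encoder, while LearnablePositionalEncoding simply makes the
table a trainable Parameter as in the whisper decoder. A standalone sketch
(not part of the diff) that rebuilds the sinusoid table with the formula above,
using an illustrative model dimension:

    import numpy as np
    import torch

    def whisper_sinusoids(length=1500, channels=512):
        log_timescale_increment = np.log(10000) / (channels // 2 - 1)
        inv_timescales = torch.exp(-log_timescale_increment *
                                   torch.arange(channels // 2))
        scaled_time = torch.arange(length)[:, np.newaxis] * \
            inv_timescales[np.newaxis, :]
        return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

    pe = whisper_sinusoids()
    print(pe.shape)  # torch.Size([1500, 512])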
