feat(whisper): support whisper arch #2141

Merged · 38 commits · Nov 22, 2023

Commits
66a71f4
feat(whisper): support whisper arch
xingchensong Nov 13, 2023
a184bbb
feat(whisper): add arg, key_bias
xingchensong Nov 13, 2023
1a5b1d3
feat(whisper): success load
xingchensong Nov 13, 2023
65101d4
feat(whisper): format tensor name & dtype
xingchensong Nov 13, 2023
93091ad
feat(whisper): add arg, key_bias
xingchensong Nov 13, 2023
176adf8
feat(whisper): fix convert
xingchensong Nov 14, 2023
4c29e34
feat(whisper): add config
xingchensong Nov 14, 2023
6c57e81
feat(whisper): add config
xingchensong Nov 14, 2023
5dffed6
feat(whisper): fix lint
xingchensong Nov 14, 2023
3889e2e
feat(whisper): add compute_log_mel_spectrogram
xingchensong Nov 15, 2023
35b0db1
feat(whisper): add compute_log_mel_spectrogram
xingchensong Nov 15, 2023
5bfdca2
feat(whisper): add compute_log_mel_spectrogram
xingchensong Nov 15, 2023
6cd3a02
feat(whisper): add compute_log_mel_spectrogram
xingchensong Nov 15, 2023
63705ca
feat(whisper): add ref
xingchensong Nov 15, 2023
c99a0ee
feat(whisper): convert to units.txt
xingchensong Nov 15, 2023
2903532
feat(whisper): support learnable pe
xingchensong Nov 15, 2023
cb94bff
feat(whisper): support whisper pe
xingchensong Nov 15, 2023
15a11d9
feat(whisper): fix lint
xingchensong Nov 15, 2023
b7e7cbe
feat(whisper): fix config
xingchensong Nov 15, 2023
d94bf50
feat(whisper): load success
xingchensong Nov 15, 2023
3d171eb
feat(whisper): remove time align, add it in the future
xingchensong Nov 15, 2023
3a127de
feat(whisper): load succeed
xingchensong Nov 15, 2023
becb954
feat(whisper): fix lint
xingchensong Nov 15, 2023
ff8c953
feat(whisper): fix property
xingchensong Nov 15, 2023
8a784f0
feat(whisper): training success
xingchensong Nov 20, 2023
9c9b441
feat(whisper): fix lint
xingchensong Nov 20, 2023
6c27648
feat(whisper): fix cal_att_loss
xingchensong Nov 20, 2023
3bbdb8b
feat(whisper): add log for tie weights
xingchensong Nov 20, 2023
ce78819
feat(whisper): fix dict
xingchensong Nov 20, 2023
c5c1aa7
feat(whisper): add unit test for pe and mel
xingchensong Nov 21, 2023
e7beeb5
feat(whisper): fix apt install
xingchensong Nov 21, 2023
affca09
feat(whisper): pass unit test
xingchensong Nov 22, 2023
a5c51d5
feat(whisper): pass unit test
xingchensong Nov 22, 2023
c31bc7f
feat(whisper): try to pass unit test
xingchensong Nov 22, 2023
166a78a
feat(whisper): try to pass unit test
xingchensong Nov 22, 2023
9131a57
feat(whisper): add doc
xingchensong Nov 22, 2023
7d1b2e9
feat(whisper): remove r_att
xingchensong Nov 22, 2023
fa0a357
feat(whisper): fix lint
xingchensong Nov 22, 2023
1 change: 1 addition & 0 deletions .github/workflows/unit_test.yml
@@ -40,6 +40,7 @@ jobs:
run: |
set -eux
pip install -r requirements.txt
sudo apt update && sudo apt install -y ffmpeg
- name: Run Pytest
run: |
set -eux
2 changes: 2 additions & 0 deletions requirements.txt
@@ -16,3 +16,5 @@ pyflakes==2.2.0
torch==1.13.0
torchaudio==0.13.0
deepspeed
librosa
openai-whisper
Binary file added test/resources/aishell-BAC009S0724W0121.wav
Binary file not shown.
Binary file added test/resources/librispeech-1995-1837-0001.wav
Binary file not shown.
358 changes: 358 additions & 0 deletions test/wenet/whisper/test_whisper.py

Large diffs are not rendered by default.

12 changes: 9 additions & 3 deletions wenet/dataset/dataset.py
@@ -122,7 +122,8 @@ def Dataset(data_type,
conf,
bpe_model=None,
non_lang_syms=None,
partition=True):
partition=True,
whisper_tokenizer=None):
""" Construct dataset from arguments

We have two shuffle stages in the Dataset. The first is global
@@ -145,7 +146,8 @@ def Dataset(data_type,
dataset = Processor(dataset, processor.parse_raw)

dataset = Processor(dataset, processor.tokenize, symbol_table, bpe_model,
non_lang_syms, conf.get('split_with_space', False))
non_lang_syms, conf.get('split_with_space', False),
whisper_tokenizer)
filter_conf = conf.get('filter_conf', {})
dataset = Processor(dataset, processor.filter, **filter_conf)

@@ -157,13 +159,17 @@
dataset = Processor(dataset, processor.speed_perturb)

feats_type = conf.get('feats_type', 'fbank')
assert feats_type in ['fbank', 'mfcc']
assert feats_type in ['fbank', 'mfcc', 'log_mel_spectrogram']
if feats_type == 'fbank':
fbank_conf = conf.get('fbank_conf', {})
dataset = Processor(dataset, processor.compute_fbank, **fbank_conf)
elif feats_type == 'mfcc':
mfcc_conf = conf.get('mfcc_conf', {})
dataset = Processor(dataset, processor.compute_mfcc, **mfcc_conf)
elif feats_type == 'log_mel_spectrogram':
log_mel_spectrogram_conf = conf.get('log_mel_spectrogram_conf', {})
dataset = Processor(dataset, processor.compute_log_mel_spectrogram,
**log_mel_spectrogram_conf)

spec_aug = conf.get('spec_aug', True)
spec_sub = conf.get('spec_sub', False)
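For orientation, here is a minimal sketch of the dataset conf that selects the new feature type. The keys mirror the conf.get(...) lookups above, and the values are the defaults of compute_log_mel_spectrogram further below; the recipe YAML shipped with the PR is not rendered here, so treat this fragment as an assumption:

# Hypothetical conf fragment (Python-dict form of the dataset_conf YAML);
# key names follow the conf.get(...) calls in Dataset above.
conf = {
    'feats_type': 'log_mel_spectrogram',
    'log_mel_spectrogram_conf': {
        'n_fft': 400,         # defaults of compute_log_mel_spectrogram
        'hop_length': 160,
        'num_mel_bins': 80,
        'padding': 0,
    },
}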
52 changes: 51 additions & 1 deletion wenet/dataset/processor.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import librosa
import logging
import json
import random
@@ -23,6 +24,7 @@
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from wenet.utils.tokenize_utils import tokenize_by_bpe_model

@@ -322,11 +324,54 @@ def compute_mfcc(data,
yield dict(key=sample['key'], label=sample['label'], feat=mat)


def compute_log_mel_spectrogram(data,
n_fft=400,
hop_length=160,
num_mel_bins=80,
padding=0):
""" Extract log mel spectrogram, modified from openai-whisper, see:
- https://github.com/openai/whisper/blob/main/whisper/audio.py
- https://github.com/wenet-e2e/wenet/pull/2141#issuecomment-1811765040

Args:
data: Iterable[{key, wav, label, sample_rate}]

Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'wav' in sample
assert 'key' in sample
assert 'label' in sample
sample_rate = sample['sample_rate']
waveform = sample['wav'].squeeze(0) # (channel=1, sample) -> (sample,)
if padding > 0:
waveform = F.pad(waveform, (0, padding))
window = torch.hann_window(n_fft)
stft = torch.stft(waveform, n_fft, hop_length,
window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2

filters = torch.from_numpy(
librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mel_bins)
)
mel_spec = filters @ magnitudes

# NOTE(xcsong): https://github.com/openai/whisper/discussions/269
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
yield dict(key=sample['key'], label=sample['label'],
feat=log_spec.transpose(0, 1))


def tokenize(data,
symbol_table,
bpe_model=None,
non_lang_syms=None,
split_with_space=False):
split_with_space=False,
whisper_tokenizer=None):
""" Decode text to chars or BPE
Inplace operation

@@ -352,6 +397,11 @@ def tokenize(data,
for sample in data:
assert 'txt' in sample
txt = sample['txt'].strip()
# TODO(xcsong): This is a dirty workaround for the whisper tokenizer,
# refine it in the future
if whisper_tokenizer is not None:
sample['label'] = whisper_tokenizer.encode(txt)
yield sample
continue
if non_lang_syms_pattern is not None:
parts = non_lang_syms_pattern.split(txt.upper())
parts = [w for w in parts if len(w.strip()) > 0]
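To make the new feature path concrete, a minimal sketch (not part of the PR) that runs the processor above on one of the bundled test wavs and cross-checks it against openai-whisper's reference features. whisper.load_audio and whisper.log_mel_spectrogram are real openai-whisper APIs; the tolerances are assumptions, and the PR's actual checks live in the unrendered test_whisper.py:

import torch
import whisper  # openai-whisper, added to requirements.txt above

# load_audio shells out to ffmpeg (hence the CI change above) and returns
# mono float32 audio resampled to 16 kHz.
audio = whisper.load_audio("test/resources/librispeech-1995-1837-0001.wav")
sample = {'key': 'utt1', 'label': [0], 'sample_rate': 16000,
          'wav': torch.from_numpy(audio).unsqueeze(0)}
feat = next(compute_log_mel_spectrogram([sample]))['feat']  # (num_frames, 80)
ref = whisper.log_mel_spectrogram(audio)                    # (80, num_frames)
torch.testing.assert_close(feat.transpose(0, 1), ref, rtol=1e-5, atol=1e-5)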
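Similarly, a sketch of how the new whisper_tokenizer argument is presumably meant to be constructed and threaded through Dataset. get_tokenizer is openai-whisper's public helper; the task value, data-list path, and the symbol_table/conf objects are illustrative assumptions:

from whisper.tokenizer import get_tokenizer

whisper_tokenizer = get_tokenizer(multilingual=True, task="transcribe")
ids = whisper_tokenizer.encode("hello world")   # list[int] of BPE token ids
dataset = Dataset('raw', 'data.list', symbol_table, conf,
                  whisper_tokenizer=whisper_tokenizer)  # hypothetical call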
5 changes: 3 additions & 2 deletions wenet/transformer/attention.py
@@ -32,15 +32,16 @@ class MultiHeadedAttention(nn.Module):
dropout_rate (float): Dropout rate.

"""
def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
def __init__(self, n_head: int, n_feat: int, dropout_rate: float,
key_bias: bool = True):
"""Construct an MultiHeadedAttention object."""
super().__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.dropout = nn.Dropout(p=dropout_rate)
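The effect of the new flag in one sketch (sizes illustrative): whisper checkpoints ship no bias on the key projection, so key_bias=False mirrors openai-whisper's Linear(n_state, n_state, bias=False) and leaves the other projections untouched:

mha = MultiHeadedAttention(n_head=8, n_feat=512, dropout_rate=0.0,
                           key_bias=False)
assert mha.linear_k.bias is None        # nothing to load for the key bias
assert mha.linear_q.bias is not None    # query/value/out projections keep theirs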
61 changes: 55 additions & 6 deletions wenet/transformer/decoder.py
@@ -16,13 +16,16 @@
from typing import Tuple, List, Optional

import torch
import logging

from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.decoder_layer import DecoderLayer
from wenet.transformer.embedding import PositionalEncoding
from wenet.transformer.embedding import NoPositionalEncoding
from wenet.transformer.embedding import LearnablePositionalEncoding
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.utils.mask import (subsequent_mask, make_pad_mask)
from wenet.utils.common import get_activation


class TransformerDecoder(torch.nn.Module):
@@ -43,6 +46,7 @@ class TransformerDecoder(torch.nn.Module):
False: use layer_norm after each sub-block of a layer.
src_attention: if false, encoder-decoder cross attention is not
applied, such as CIF model
key_bias: whether to use bias in attention.linear_k, False for whisper models.
"""

def __init__(
@@ -60,9 +64,12 @@ def __init__(
use_output_layer: bool = True,
normalize_before: bool = True,
src_attention: bool = True,
key_bias: bool = True,
activation_type: str = "relu",
):
super().__init__()
attention_dim = encoder_output_size
activation = get_activation(activation_type)

if input_layer == "embed":
self.embed = torch.nn.Sequential(
@@ -72,24 +79,33 @@ def __init__(
elif input_layer == 'none':
self.embed = NoPositionalEncoding(attention_dim,
positional_dropout_rate)
elif input_layer == 'embed_learnable_pe':
self.embed = torch.nn.Sequential(
torch.nn.Embedding(vocab_size, attention_dim),
LearnablePositionalEncoding(attention_dim,
positional_dropout_rate),
)
else:
raise ValueError(f"only 'embed' is supported: {input_layer}")

self.normalize_before = normalize_before
self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
self.use_output_layer = use_output_layer
self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
if use_output_layer:
self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
else:
self.output_layer = torch.nn.Identity()
self.num_blocks = num_blocks
self.decoders = torch.nn.ModuleList([
DecoderLayer(
attention_dim,
MultiHeadedAttention(attention_heads, attention_dim,
self_attention_dropout_rate),
self_attention_dropout_rate, key_bias),
MultiHeadedAttention(attention_heads, attention_dim,
src_attention_dropout_rate)
src_attention_dropout_rate, key_bias)
if src_attention else None,
PositionwiseFeedForward(attention_dim, linear_units,
dropout_rate),
dropout_rate, activation),
dropout_rate,
normalize_before,
) for _ in range(self.num_blocks)
@@ -185,6 +201,29 @@ def forward_one_step(
y = torch.log_softmax(self.output_layer(y), dim=-1)
return y, new_cache

def tie_or_clone_weights(self, jit_mode: bool = True):
"""Tie or clone module weights (between word_emb and output_layer)
depending on whether we are using TorchScript or not"""
if not self.use_output_layer:
return
if jit_mode:
logging.info("clone emb.weight to output.weight")
self.output_layer.weight = torch.nn.Parameter(self.embed[0].weight.clone())
else:
logging.info("tie emb.weight with output.weight")
self.output_layer.weight = self.embed[0].weight

if getattr(self.output_layer, "bias", None) is not None:
self.output_layer.bias.data = torch.nn.functional.pad(
self.output_layer.bias.data,
(
0,
self.output_layer.weight.shape[0] - self.output_layer.bias.shape[0],
),
"constant",
0,
)


class BiTransformerDecoder(torch.nn.Module):
"""Base class of Transfomer decoder module.
@@ -203,6 +242,7 @@ class BiTransformerDecoder(torch.nn.Module):
normalize_before:
True: use layer_norm before each sub-block of a layer.
False: use layer_norm after each sub-block of a layer.
key_bias: whether to use bias in attention.linear_k, False for whisper models.
"""

def __init__(
@@ -220,20 +260,23 @@ def __init__(
input_layer: str = "embed",
use_output_layer: bool = True,
normalize_before: bool = True,
key_bias: bool = True,
):

super().__init__()
self.left_decoder = TransformerDecoder(
vocab_size, encoder_output_size, attention_heads, linear_units,
num_blocks, dropout_rate, positional_dropout_rate,
self_attention_dropout_rate, src_attention_dropout_rate,
input_layer, use_output_layer, normalize_before)
input_layer, use_output_layer, normalize_before,
key_bias=key_bias)

self.right_decoder = TransformerDecoder(
vocab_size, encoder_output_size, attention_heads, linear_units,
r_num_blocks, dropout_rate, positional_dropout_rate,
self_attention_dropout_rate, src_attention_dropout_rate,
input_layer, use_output_layer, normalize_before)
input_layer, use_output_layer, normalize_before,
key_bias=key_bias)

def forward(
self,
@@ -294,3 +337,9 @@ def forward_one_step(
"""
return self.left_decoder.forward_one_step(memory, memory_mask, tgt,
tgt_mask, cache)

def tie_or_clone_weights(self, jit_mode: bool = True):
"""Tie or clone module weights (between word_emb and output_layer)
depending on whether we are using TorchScript or not"""
self.left_decoder.tie_or_clone_weights(jit_mode)
self.right_decoder.tie_or_clone_weights(jit_mode)
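A usage sketch for the new weight tying (vocab size hypothetical, remaining constructor arguments left at their defaults): tying shares a single Parameter between the input embedding and the output projection, while jit_mode=True clones the tensor instead, which per the docstring is what TorchScript export needs:

decoder = TransformerDecoder(vocab_size=5000, encoder_output_size=512,
                             input_layer="embed_learnable_pe")
decoder.tie_or_clone_weights(jit_mode=False)   # share one Parameter
assert decoder.output_layer.weight is decoder.embed[0].weight
decoder.tie_or_clone_weights(jit_mode=True)    # clone for TorchScript export
assert decoder.output_layer.weight is not decoder.embed[0].weight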
33 changes: 30 additions & 3 deletions wenet/transformer/embedding.py
@@ -20,6 +20,7 @@

import torch
import torch.nn.functional as F
import numpy as np

class PositionalEncoding(torch.nn.Module):
"""Positional encoding.
@@ -93,13 +94,13 @@ def position_encoding(self, offset: Union[int, torch.Tensor], size: int,
# How to subscript a Union type:
# https://github.com/pytorch/pytorch/issues/69434
if isinstance(offset, int):
assert offset + size < self.max_len
assert offset + size <= self.max_len
pos_emb = self.pe[:, offset:offset + size]
elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
assert offset + size < self.max_len
assert offset + size <= self.max_len
pos_emb = self.pe[:, offset:offset + size]
else: # for batched streaming decoding on GPU
assert torch.max(offset) + size < self.max_len
assert torch.max(offset) + size <= self.max_len
index = offset.unsqueeze(1) + \
torch.arange(0, size).to(offset.device) # B X T
flag = index > 0
@@ -140,6 +141,32 @@ def forward(self,
return self.dropout(x), self.dropout(pos_emb)


class WhisperPositionalEncoding(PositionalEncoding):
""" Sinusoids position encoding used in openai-whisper.encoder
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
super().__init__(d_model, dropout_rate, max_len)
self.xscale = 1.0
log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment *
torch.arange(d_model // 2))
scaled_time = torch.arange(max_len)[:, np.newaxis] * \
inv_timescales[np.newaxis, :]
pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
delattr(self, "pe")
self.register_buffer("pe", pe.unsqueeze(0))


class LearnablePositionalEncoding(PositionalEncoding):
""" Learnable position encoding used in openai-whisper.decoder
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
super().__init__(d_model, dropout_rate, max_len)
# NOTE(xcsong): overwrite self.pe & self.xscale
self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
self.xscale = 1.0


class NoPositionalEncoding(torch.nn.Module):
""" No position encoding
"""
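Finally, a sketch of the equivalence the "add unit test for pe and mel" commit presumably checks: WhisperPositionalEncoding rebuilds the same sinusoid table as whisper.model.sinusoids (a real openai-whisper helper), so the two should agree to float precision:

import torch
from whisper.model import sinusoids

pe = WhisperPositionalEncoding(d_model=512, dropout_rate=0.1, max_len=1500)
ref = sinusoids(length=1500, channels=512)         # (1500, 512)
torch.testing.assert_close(pe.pe.squeeze(0), ref)  # identical sin/cos layout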