forked from xdcesc/my_ch_speech_recognition
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
67 lines (56 loc) · 2.25 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#coding=utf-8
import os
import tensorflow as tf
from utils import get_data, data_hparams
# 0. Prepare the data needed for training ----------------------------------
# Start from the default data hyper-parameters, turn on shuffling for
# training, and build the data provider that both the acoustic-model and the
# language-model sections below read from (vocabularies, wav_lst, batches).
data_args = data_hparams()
data_args.shuffle = True
train_data = get_data(data_args)
# 1. Acoustic model training -----------------------------------------------
# Builds the CNN-CTC acoustic model, resumes from logs_am/model.h5 when it
# exists, trains for `epochs` passes over the data, and checkpoints the
# weights after every epoch so an interrupted run keeps finished epochs.
from model_speech.cnn_ctc import Am, am_hparams
am_args = am_hparams()
am_args.vocab_size = len(train_data.am_vocab)
am = Am(am_args)
if os.path.exists('logs_am/model.h5'):
    print('load acoustic model...')
    am.ctc_model.load_weights('logs_am/model.h5')
epochs = 10
# Number of generator steps per epoch. Kept as a module-level name because
# later sections of this script may also read it.
batch_num = len(train_data.wav_lst) // train_data.batch_size
if batch_num == 0:
    # Without this check, Keras would be asked for 0 steps per epoch and
    # later per-batch averages would divide by zero.
    raise ValueError('training set (%d utterances) is smaller than one batch '
                     '(batch_size=%d)' % (len(train_data.wav_lst),
                                          train_data.batch_size))
# save_weights cannot create a missing parent directory, so make sure it exists.
if not os.path.isdir('logs_am'):
    os.makedirs('logs_am')
for k in range(epochs):
    print('acoustic model: training epoch', k + 1, 'of', epochs)
    # A fresh generator per epoch, as in the original loop.
    batch = train_data.get_am_batch()
    am.ctc_model.fit_generator(batch, steps_per_epoch=batch_num, epochs=1)
    # Checkpoint after every epoch instead of only at the very end.
    am.ctc_model.save_weights('logs_am/model.h5')
# 2. Language model training -----------------------------------------------
# Builds the transformer pinyin -> hanzi language model, resumes from the
# newest checkpoint in logs_lm/ if there is one, trains for `epochs` more
# epochs, and writes TensorBoard summaries plus one checkpoint per epoch.
# Checkpoints are named model_<cumulative epoch count>; the resume logic
# parses that number back out.
from model_language.transformer import Lm, lm_hparams
lm_args = lm_hparams()
lm_args.input_vocab_size = len(train_data.pny_vocab)
lm_args.label_vocab_size = len(train_data.han_vocab)
lm = Lm(lm_args)
epochs = 10
# Batches per epoch, computed here so this section does not depend on the
# acoustic-model section. This uses the same formula as that section and
# assumes get_lm_batch() yields at least this many batches, aligned with
# wav_lst -- TODO confirm against utils.get_data.
lm_batch_num = len(train_data.wav_lst) // train_data.batch_size
if lm_batch_num == 0:
    raise ValueError('training set (%d utterances) is smaller than one batch '
                     '(batch_size=%d)' % (len(train_data.wav_lst),
                                          train_data.batch_size))
if not os.path.isdir('logs_lm'):
    os.makedirs('logs_lm')
# Keep lm.graph explicitly default for everything below, so merge_all() and
# global_variables_initializer() see the LM's ops and variables.
with lm.graph.as_default():
    saver = tf.train.Saver()
    with tf.Session(graph=lm.graph) as sess:
        merged = tf.summary.merge_all()
        sess.run(tf.global_variables_initializer())
        # Number of epochs already completed by a restored checkpoint.
        add_num = 0
        # latest_checkpoint returns None when there is no usable checkpoint,
        # even if a stale 'checkpoint' index file exists.
        latest = tf.train.latest_checkpoint('logs_lm')
        if latest is not None:
            print('loading language model...')
            # Checkpoint prefixes look like 'logs_lm/model_<N>'.
            add_num = int(os.path.basename(latest).rsplit('_', 1)[-1])
            saver.restore(sess, latest)
        writer = tf.summary.FileWriter('logs_lm/tensorboard', sess.graph)
        try:
            for k in range(epochs):
                total_loss = 0
                batch = train_data.get_lm_batch()
                for i in range(lm_batch_num):
                    input_batch, label_batch = next(batch)
                    feed = {lm.x: input_batch, lm.y: label_batch}
                    cost, _ = sess.run([lm.mean_loss, lm.train_op],
                                       feed_dict=feed)
                    total_loss += cost
                    # Global step continues across resumed runs, so
                    # TensorBoard curves do not overlap earlier runs.
                    step = (add_num + k) * lm_batch_num + i
                    if merged is not None and step % 10 == 0:
                        rs = sess.run(merged, feed_dict=feed)
                        writer.add_summary(rs, step)
                epoch_no = add_num + k + 1
                print('epochs', epoch_no, ': average loss = ',
                      total_loss / lm_batch_num)
                # Save every epoch; the last checkpoint is
                # model_<add_num + epochs>, the same name the original
                # end-of-run save used.
                saver.save(sess, 'logs_lm/model_%d' % epoch_no)
        finally:
            writer.close()