/
callbacks.py
72 lines (52 loc) · 2.5 KB
/
callbacks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os

import h5py
import numpy as np
import yaml

import keras.callbacks as callbacks
class MetaCheckpoint(callbacks.ModelCheckpoint):
    """
    Checkpoints some training information with the model. This should enable
    resuming training and having training information on every checkpoint.

    The per-epoch logs are accumulated in ``self.meta`` and written into a
    ``meta`` HDF5 group inside the checkpoint file itself.

    Thanks to Roberto Estevao @robertomest - robertomest@poli.ufrj.br
    """

    def __init__(self, filepath, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1, training_args=None, meta=None):
        """Create the checkpoint callback.

        Args:
            filepath: Checkpoint path template (may contain ``{epoch}`` and
                log-key placeholders, as with ``ModelCheckpoint``).
            monitor, verbose, save_best_only, save_weights_only, mode,
                period: Forwarded unchanged to ``ModelCheckpoint``.
            training_args: Optional namespace (e.g. ``argparse.Namespace``)
                whose attributes are recorded via ``vars()``.
            meta: Optional pre-existing meta dict, used to resume training;
                defaults to a fresh ``{'epochs': []}``.
        """
        super(MetaCheckpoint, self).__init__(
            filepath, monitor=monitor, verbose=verbose,
            save_best_only=save_best_only,
            save_weights_only=save_weights_only,
            mode=mode, period=period)

        self.filepath = filepath
        self.meta = meta or {'epochs': []}

        if training_args:
            training_args = vars(training_args)
            self.meta['training_args'] = training_args

    def on_train_begin(self, logs=None):
        # BUG FIX: `logs={}` was a mutable default argument shared across
        # calls; use the conventional None sentinel instead.
        super(MetaCheckpoint, self).on_train_begin(logs)

    def on_epoch_end(self, epoch, logs=None):
        # BUG FIX: mutable default argument `logs={}` replaced with None.
        logs = logs or {}
        super(MetaCheckpoint, self).on_epoch_end(epoch, logs)

        # Accumulate per-epoch statistics. setdefault gets the existing
        # list or installs (and returns) a fresh one.
        self.meta['epochs'].append(epoch)
        for k, v in logs.items():
            self.meta.setdefault(k, []).append(v)

        # Append the meta group to the checkpoint file written this epoch.
        filepath = self.filepath.format(epoch=epoch, **logs)
        # BUG FIX: with save_best_only=True the model file is not written
        # on non-improving epochs, so opening it 'r+' would raise; only
        # touch the file if it actually exists.
        if self.epochs_since_last_save == 0 and os.path.isfile(filepath):
            with h5py.File(filepath, 'r+') as f:
                # Replace any stale meta group from a previous save.
                if 'meta' in f.keys():
                    del f['meta']
                meta_group = f.create_group('meta')
                # BUG FIX: the fallback used to be the *string* '{}', which
                # yaml.dump serializes as a quoted scalar; an empty dict
                # yields a proper empty YAML mapping.
                meta_group.attrs['training_args'] = yaml.dump(
                    self.meta.get('training_args', {}))
                meta_group.create_dataset(
                    'epochs', data=np.array(self.meta['epochs']))
                for k in logs:
                    meta_group.create_dataset(k, data=np.array(self.meta[k]))
class ProgbarLogger(callbacks.ProgbarLogger):
    """Progress-bar logger that can limit which metrics are displayed.

    Behaves exactly like ``keras.callbacks.ProgbarLogger`` except that,
    when ``show_metrics`` is given, only those metric names are shown.
    """

    def __init__(self, show_metrics=None):
        """``show_metrics``: optional list of metric names to display."""
        super(ProgbarLogger, self).__init__()
        self.show_metrics = show_metrics

    def on_train_begin(self, logs=None):
        super(ProgbarLogger, self).on_train_begin(logs)
        # Nothing to override when no explicit metric list was requested.
        if not self.show_metrics:
            return
        self.params['metrics'] = self.show_metrics