-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_binary.py
86 lines (71 loc) · 2.36 KB
/
run_binary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import numpy as np
from skimage.io import imshow, imread
import pandas as pd
import keras.callbacks
import keras.utils
from model import get_model
from common import (
bound_gpu_usage,
get_rotated_subregion,
augment_images
)
# Experiment identifier; fit() uses it for the TensorBoard log dir
# and the checkpoint file name.
NAME = '01_demo'

# Optimisation settings.
LR = 5e-6          # learning rate handed to get_model
BATCH_SIZE = 16
EPOCHS = 128
STEPS = 192        # training batches per epoch (steps_per_epoch)
STEPS_VAL = 64     # validation batches per epoch (validation_steps)

# Data / model settings.
NUM_CLASSES = 2    # > 2 switches labels to one-hot encoding
IMAGE_SHAPE = (256,) * 2
AUGMENT = True     # apply augment_images to each training batch
def image_generator(df):
    """Yield an endless stream of randomly sampled (images, labels) batches.

    Every batch is cut from one randomly chosen row of ``df``. For that row,
    ``BATCH_SIZE`` random (angle, shift) pairs are drawn. Each pair is used to
    extract an ``IMAGE_SHAPE`` patch from the image at ``df['path']`` via
    ``get_rotated_subregion``. The same pair is used on the mask at
    ``df['mask']``, so each label patch lines up with its input patch.

    Args:
        df: table with ``'path'`` and ``'mask'`` columns holding file paths
            readable by ``skimage.io.imread``.

    Yields:
        ``(images, labels)``. ``images`` is the stacked patch array, passed
        through ``augment_images`` when ``AUGMENT`` is set. ``labels`` depends
        on ``NUM_CLASSES``:
        - ``NUM_CLASSES > 2``: the mask patches one-hot encoded and reshaped
          to ``(BATCH_SIZE, H, W, NUM_CLASSES)``.
        - otherwise: the mask patches reshaped to ``(BATCH_SIZE, H, W)``.

    Raises:
        ValueError: if ``df`` has no rows (raised on the first ``next()``).
    """
    if len(df) == 0:
        raise ValueError('image_generator needs a non-empty DataFrame')
    height, width = IMAGE_SHAPE
    while True:
        # Positional lookup, so the result does not depend on index labels.
        row = np.random.randint(0, len(df))
        big_image = imread(df['path'].iloc[row])
        # NOTE(review): newer scikit-image renamed `as_grey` to `as_gray`;
        # kept as-is to match the library version this script was written for.
        mask = imread(df['mask'].iloc[row], as_grey=True)

        # One (angle, shift) pair per patch, shared by image and mask.
        angles = np.random.uniform(0, np.pi * 2, (BATCH_SIZE,))
        shifts_x = np.random.uniform(0, mask.shape[0], (BATCH_SIZE,))
        shifts_y = np.random.uniform(0, mask.shape[1], (BATCH_SIZE,))
        params = list(zip(angles, zip(shifts_x, shifts_y)))

        images = np.array([
            get_rotated_subregion(big_image, IMAGE_SHAPE, angle, shift)
            for angle, shift in params
        ])
        if AUGMENT:
            images = augment_images(images)

        Ys = np.array([
            get_rotated_subregion(mask, IMAGE_SHAPE, angle, shift)
            for angle, shift in params
        ])
        if NUM_CLASSES > 2:
            labels = keras.utils.to_categorical(
                Ys.flatten(), NUM_CLASSES
            ).reshape(BATCH_SIZE, height, width, NUM_CLASSES)
        else:
            labels = Ys.reshape(BATCH_SIZE, height, width)
        yield images, labels
def fit():
    """Build the model and train it on the demo train/validation splits.

    Reads the tab-separated ``demo/train.tsv`` and ``demo/val.tsv`` tables and
    feeds them through ``image_generator``. Training runs for ``EPOCHS`` epochs
    of ``STEPS`` batches each, with ``STEPS_VAL`` validation batches per epoch.
    TensorBoard logs go to ``logs/<NAME>``. The model with the best
    ``val_loss`` is saved to ``models/<NAME>.h5``.

    Returns:
        The history object returned by ``model.fit_generator``.
    """
    import os

    checkpoint_path = 'models/{}.h5'.format(NAME)
    # The checkpoint callback writes straight to this path; make sure its
    # directory exists so the first save does not fail on a fresh checkout.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    model = get_model(NUM_CLASSES, LR)
    train_df = pd.read_csv('demo/train.tsv', sep='\t')
    val_df = pd.read_csv('demo/val.tsv', sep='\t')
    hist = model.fit_generator(
        image_generator(train_df),
        steps_per_epoch=STEPS,
        validation_data=image_generator(val_df),
        validation_steps=STEPS_VAL,
        epochs=EPOCHS,
        callbacks=[
            keras.callbacks.TensorBoard(
                'logs/{}'.format(NAME),
                write_images=False,
                batch_size=BATCH_SIZE
            ),
            keras.callbacks.ModelCheckpoint(
                checkpoint_path, verbose=False,
                save_best_only=True, monitor='val_loss'
            )
        ]
    )
    return hist
if __name__ == "__main__":
    # Script entry point: apply the GPU limits from common.bound_gpu_usage
    # (presumably a memory cap — see that helper) before any training starts.
    bound_gpu_usage()
    # Run the full training loop.
    fit()