# train.py
# Trains a small convolutional network on the images under data/mnist/images/train
# and saves the fitted model under models/.
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os
# TF1-compatibility session configuration:
#   * log_device_placement - log the device each op is placed on
#   * allow_growth         - grow GPU memory usage on demand instead of
#                            reserving it all at start-up
config = tf.compat.v1.ConfigProto(log_device_placement=True)
config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=config)
# Filesystem locations (directory paths keep their trailing slash).
model_path = 'models/'
data_path = 'data/mnist/images/'
train_path = f'{data_path}train'

# Dataset description and training hyper-parameters.
image_size = (28, 28)
classes = [str(digit) for digit in range(10)]  # sub-directory name per class
batch_size = 50
epochs_count = 10
# Fraction of the images in each class directory held out for validation.
# Both generators below must be given the same value: Keras derives the
# training/validation subsets from the file listing and this fraction, so a
# shared value keeps the two subsets complementary.
validation_split = 0.1

# Training images are randomly rotated, shifted and sheared on the fly.
train_data_generator = ImageDataGenerator(
    validation_split=validation_split,
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.15,
)

# Validation images must NOT be augmented: drawing the validation subset from
# the augmenting generator above would randomly distort every validation
# batch and make the reported validation metrics noisy and pessimistic.
validation_data_generator = ImageDataGenerator(validation_split=validation_split)

# Options shared by both directory iterators.
flow_options = dict(
    target_size=image_size,
    classes=classes,
    class_mode='sparse',  # integer labels, matches the sparse crossentropy loss
    batch_size=batch_size,
    color_mode='grayscale',
)

train_directory_iterator = train_data_generator.flow_from_directory(
    train_path,
    subset='training',
    **flow_options,
)
validation_directory_iterator = validation_data_generator.flow_from_directory(
    train_path,
    subset='validation',
    **flow_options,
)
# Small CNN: two convolution + max-pooling stages followed by a dense
# classifier head with one softmax output per digit class.
model = Sequential()

# Feature extractor (input is a single-channel 28x28 image).
model.add(Conv2D(32, (5, 5), activation='relu', input_shape=(28, 28, 1)))
model.add(MaxPooling2D(pool_size=(3, 3)))
model.add(Conv2D(32, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

# Classifier head.
model.add(Flatten())
model.add(Dense(48, activation='relu'))
model.add(Dense(10, activation='softmax'))

# Print the layer-by-layer overview before training starts.
model.summary()
# Sparse categorical crossentropy pairs with the integer labels produced by
# the iterators (class_mode='sparse').
model.compile(
    optimizer=Adam(learning_rate=0.0001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Step counts must be whole numbers. `samples / batch_size` is true division
# and yields a float (fractional whenever the sample count is not a multiple
# of the batch size). len() of a directory iterator is the integer number of
# batches per pass, including the final partial batch.
model.fit(
    train_directory_iterator,
    validation_data=validation_directory_iterator,
    steps_per_epoch=len(train_directory_iterator),
    validation_steps=len(validation_directory_iterator),
    epochs=epochs_count,
)
# Create the output directory if needed. exist_ok=True avoids the race between
# a separate existence check and the creation, and is a no-op when the
# directory is already there.
os.makedirs(model_path, exist_ok=True)

# os.path.join yields the same 'models/model.h5' as plain concatenation but
# keeps working if model_path is ever given without its trailing slash.
model.save(os.path.join(model_path, 'model.h5'))