-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_matthieu.py
40 lines (31 loc) · 1.95 KB
/
train_matthieu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# YOLO
# from lns.common.preprocess import Preprocessor
# dataset_scale = Preprocessor.preprocess('ScaleLights')
# dataset_utias = Preprocessor.preprocess('ScaleLights_New_Utias')
# dataset_youtube = Preprocessor.preprocess('ScaleLights_New_Youtube')
# dataset_all = dataset_scale + dataset_utias + dataset_youtube
# dataset_all = dataset_all.merge_classes({
# "green": ["goLeft", "Green", "GreenLeft", "GreenStraightRight", "go", "GreenStraightLeft", "GreenRight", "GreenStraight", "3-green", "4-green", "5-green"],
# "yellow": ["warning", "Yellow", "warningLeft", "3-yellow", "4-yellow", "5-yellow"],
# "red": ["stop", "stopLeft", "RedStraightLeft", "Red", "RedLeft", "RedStraight", "RedRight", "3-red", "4-red", "5-red"],
# "off": ["OFF", "off", "3-off", "3-other", "4-off", "4-other", "5-off", "5-other"]
# })
# from lns.yolo.train import YoloTrainer
# trainer = YoloTrainer('darknet25_416_matthieu_trial2', dataset_all)
# trainer.train()
# Squeezedet
from lns.common.preprocess import Preprocessor

# Label-merge table handed to `merge_classes` below: each key is the merged
# class name, each value lists the raw labels (from the three sources) that
# are folded into it. Exact merge semantics live in lns -- NOTE(review):
# confirm unlisted labels are dropped vs. kept as-is.
LIGHT_CLASS_MERGE = {
    "green": [
        "goLeft", "Green", "GreenLeft", "GreenStraightRight", "go",
        "GreenStraightLeft", "GreenRight", "GreenStraight",
        "3-green", "4-green", "5-green",
    ],
    "yellow": [
        "warning", "Yellow", "warningLeft",
        "3-yellow", "4-yellow", "5-yellow",
    ],
    "red": [
        "stop", "stopLeft", "RedStraightLeft", "Red", "RedLeft",
        "RedStraight", "RedRight",
        "3-red", "4-red", "5-red",
    ],
    "off": [
        "OFF", "off",
        "3-off", "3-other", "4-off", "4-other", "5-off", "5-other",
    ],
}

# Preprocess each source dataset by name.
dataset_scale = Preprocessor.preprocess('ScaleLights')
dataset_utias = Preprocessor.preprocess('ScaleLights_New_Utias')
dataset_youtube = Preprocessor.preprocess('ScaleLights_New_Youtube')

# Combine the sources left to right with the dataset type's `+` (presumably
# concatenation -- defined in lns) and collapse labels in a single expression,
# so the un-merged combination is not kept bound at module level.
dataset_all = (
    dataset_scale + dataset_utias + dataset_youtube
).merge_classes(LIGHT_CLASS_MERGE)

from lns.squeezedet.train import SqueezedetTrainer

# First argument selects the trainer configuration by name (presumably a
# config/checkpoint directory -- verify in lns.squeezedet.train).
# Previously used alternative: 'squeezedet_fullres_tiffany_copy1'.
trainer = SqueezedetTrainer('first-real_copy1', dataset_all)
trainer.train()