-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
42 lines (30 loc) · 1.29 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# Evaluation script: restore a trained DQN checkpoint and run it greedily
# in the environment for one episode, printing the total score.
#
# NOTE(review): this is TF1-style code (tf.Session / Saver). It relies on
# names defined earlier in the file: `tf`, `np`, `saver`, `env`,
# `DQNetwork`, `stack_frames`, `possible_actions`, `stacked_frames`,
# `state_size`, `stack_size`. Indentation below is reconstructed from the
# control-flow keywords — the pasted source had lost all nesting.
with tf.Session() as sess:
    total_test_rewards = []

    # Load the trained weights into the current graph.
    saver.restore(sess, "./models/model.ckpt")

    for episode in range(1):
        total_rewards = 0
        state = env.reset()
        # Build the initial stacked-frame state (True = reset the stack).
        state, stacked_frames = stack_frames(stacked_frames, state, True, stack_size)

        print("****************************************************")
        print("EPISODE ", episode)

        while True:
            # Add the batch dimension the network expects.
            # assumes state_size is the per-frame stacked shape — TODO confirm
            state = state.reshape((1, *state_size))

            # Greedy policy: take the action with the highest estimated Q-value.
            Qs = sess.run(DQNetwork.output, feed_dict={DQNetwork.inputs_: state})
            choice = np.argmax(Qs)
            action = possible_actions[choice]

            # Step the environment and render the frame for viewing.
            next_state, reward, done, _ = env.step(action)
            env.render()

            total_rewards += reward

            if done:
                print("Score", total_rewards)
                total_test_rewards.append(total_rewards)
                break

            # Stack the new frame onto the rolling state (False = no reset).
            next_state, stacked_frames = stack_frames(stacked_frames, next_state, False, stack_size)
            state = next_state

env.close()