-
Notifications
You must be signed in to change notification settings - Fork 4
/
neuralnet.py
39 lines (32 loc) · 1.26 KB
/
neuralnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch.nn as nn
from config import Args, configure
class Net(nn.Module):
    """
    Actor-Critic Network for PPO.

    A shared fully connected trunk feeds two branches:

    * a critic branch (``v``) producing one scalar per sample, and
    * an actor branch (``fc`` -> ``alpha_head`` / ``beta_head``) producing two
      positive vectors of length ``action_dim`` each, shifted by +1.
      NOTE(review): these are presumably the concentration parameters of a
      Beta action distribution -- confirm with the PPO agent that consumes them.

    Submodule attribute names (``cnn_base``, ``v``, ``fc``, ``alpha_head``,
    ``beta_head``) and their construction order are kept stable. The names
    determine ``state_dict`` keys (checkpoint compatibility), and the order
    determines how the default initialisation consumes the RNG.
    """

    # Width of the shared trunk's output; forward() flattens to this width.
    _FEATURES = 128
    # Width of the hidden layer in both the critic and the actor branches.
    _HIDDEN = 100

    def __init__(self, args: "Args", *, action_dim: int = 3):
        """
        Build the network.

        Args:
            args: configuration object; ``valueStackSize``, ``numberOfLasers``
                and ``actionStack`` are read from it.
            action_dim: number of action components. Defaults to 3, the value
                that was previously hard-coded in the input size and in both
                actor heads.
        """
        super().__init__()
        # Input = stacked previous laser distances plus the stacked previous actions.
        in_features = (
            args.valueStackSize * args.numberOfLasers + action_dim * args.actionStack
        )
        # Despite its name (kept for checkpoint compatibility) this trunk is a
        # single Linear + ReLU, not a CNN; its output has _FEATURES (128) columns.
        self.cnn_base = nn.Sequential(
            nn.Linear(in_features, self._FEATURES),
            nn.ReLU(),
        )
        # Critic: one value estimate per sample.
        self.v = nn.Sequential(
            nn.Linear(self._FEATURES, self._HIDDEN),
            nn.ReLU(),
            nn.Linear(self._HIDDEN, 1),
        )
        # Actor hidden layer shared by both parameter heads.
        self.fc = nn.Sequential(nn.Linear(self._FEATURES, self._HIDDEN), nn.ReLU())
        # Softplus keeps the raw head outputs positive.
        self.alpha_head = nn.Sequential(
            nn.Linear(self._HIDDEN, action_dim), nn.Softplus()
        )
        self.beta_head = nn.Sequential(
            nn.Linear(self._HIDDEN, action_dim), nn.Softplus()
        )
        self.apply(self._weights_init)

    @staticmethod
    def _weights_init(m):
        """
        Initialise a Conv2d layer: Xavier-uniform weights (ReLU gain), bias 0.1.

        NOTE(review): this network contains only nn.Linear layers, so this
        hook currently matches nothing and every layer keeps PyTorch's default
        initialisation. It is left unchanged to preserve existing training
        behaviour. Widen the isinstance check deliberately if custom
        initialisation of the Linear layers is wanted.
        """
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
            nn.init.constant_(m.bias, 0.1)

    def forward(self, x):
        """
        Run the critic and actor branches on a batch of inputs.

        Args:
            x: tensor whose last dimension equals the input feature count
                computed in ``__init__``.

        Returns:
            ``((alpha, beta), v)``:

            * ``alpha`` and ``beta`` have shape ``(N, action_dim)`` with
              entries >= 1.
            * ``v`` has shape ``(N, 1)``.

            ``N`` is the product of the leading dimensions of ``x``, so a
            single 1-D input yields ``N == 1``.
        """
        x = self.cnn_base(x)
        # Collapse any leading dimensions into a single batch dimension.
        x = x.view(-1, self._FEATURES)
        v = self.v(x)
        x = self.fc(x)
        # +1 shifts the Softplus range from (0, inf) up to (1, inf).
        alpha = self.alpha_head(x) + 1
        beta = self.beta_head(x) + 1
        return (alpha, beta), v
if __name__ == "__main__":
    # Script entry point: build the network from the first item returned by
    # configure() (presumably the Args instance -- confirm in config.py) and
    # print PyTorch's layer-by-layer summary of it.
    run_args = configure()[0]
    model = Net(run_args)
    print(model)