Problem Description
The LSTM weight matrices in `ppo_atari_lstm.py` seem to be initialized incorrectly if the goal is to have a separate orthogonal matrix for each gate. Since `lstm.weight_ih_l0` and `lstm.weight_hh_l0` each store the four gate matrices concatenated along the first dimension, shouldn't each of the four parts of the fused weight matrix be initialized to an orthogonal matrix separately?
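For reference, the PyTorch docs describe `weight_ih_l0` as $(W_{ii}|W_{if}|W_{ig}|W_{io})$ with shape `(4*hidden_size, input_size)` and `weight_hh_l0` as $(W_{hi}|W_{hf}|W_{hg}|W_{ho})$ with shape `(4*hidden_size, hidden_size)`. A small sketch, assuming the same `torch.nn.LSTM(512, 128)` used in the examples of this issue, that prints those shapes and splits the hidden-hidden weight into its four gate blocks:

```python
import torch

hidden_size = 128
lstm = torch.nn.LSTM(512, hidden_size)

# Each fused weight stacks the four gate matrices along dim 0,
# in the order (input, forget, cell, output), giving 4 * hidden_size rows.
print(lstm.weight_ih_l0.shape)  # torch.Size([512, 512])
print(lstm.weight_hh_l0.shape)  # torch.Size([512, 128])

# Split the hidden-hidden weight into its per-gate blocks (views, no copy).
W_hi, W_hf, W_hg, W_ho = torch.chunk(lstm.weight_hh_l0, 4, dim=0)
print(W_hi.shape)  # torch.Size([128, 128])
```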
Current Behavior
As a minimal example, checking just the $W_{hi}$ component of the hidden-hidden weights:
```python
import torch

lstm = torch.nn.LSTM(512, 128)
_ = torch.nn.init.orthogonal_(lstm.weight_hh_l0, 1.0)
W_hi = lstm.weight_hh_l0[:128]
torch.allclose(W_hi.T, torch.inverse(W_hi), atol=1e-05)  # check that W_hi is orthogonal
# -> False
```
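A likely explanation for the `False`, shown as a sketch: for a tall 512×128 tensor, `orthogonal_` gives the whole fused matrix orthonormal columns ($W^\top W = I$), but that property does not carry over to any single 128×128 row block. This check compares Gram matrices instead of using `torch.inverse`; the `atol=1e-4` tolerance is an assumption chosen to absorb float32 rounding.

```python
import torch

lstm = torch.nn.LSTM(512, 128)
torch.nn.init.orthogonal_(lstm.weight_hh_l0, 1.0)

with torch.no_grad():
    W = lstm.weight_hh_l0  # fused (512, 128) matrix
    eye = torch.eye(128)
    # The whole tall matrix has orthonormal columns: W^T W = I.
    full_ok = torch.allclose(W.T @ W, eye, atol=1e-4)
    # A single gate block (here W_hi, rows 0..127) does not.
    W_hi = W[:128]
    block_ok = torch.allclose(W_hi.T @ W_hi, eye, atol=1e-4)

print(full_ok, block_ok)  # True False
```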
Expected Behavior
```python
import torch

lstm = torch.nn.LSTM(512, 128)
_ = torch.nn.init.orthogonal_(lstm.weight_hh_l0[:128], 1.0)  # init a view with only W_hi
W_hi = lstm.weight_hh_l0[:128]
torch.allclose(W_hi.T, torch.inverse(W_hi), atol=1e-05)  # check that W_hi is orthogonal
# -> True
```
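The `torch.inverse` check only works for square blocks, so it cannot be reused for `weight_ih_l0`, whose gate blocks are 128×512 with these sizes. A sketch of a shape-agnostic check, using a hypothetical helper `is_semi_orthogonal` (not part of PyTorch or CleanRL) that tests for orthonormal rows or columns via the Gram matrix:

```python
import torch


def is_semi_orthogonal(W: torch.Tensor, atol: float = 1e-4) -> bool:
    """Hypothetical helper: True if a wide/square W has orthonormal rows
    (W @ W.T == I) or a tall W has orthonormal columns (W.T @ W == I)."""
    with torch.no_grad():
        rows, cols = W.shape
        gram = W @ W.T if rows <= cols else W.T @ W
        return torch.allclose(gram, torch.eye(min(rows, cols)), atol=atol)


lstm = torch.nn.LSTM(512, 128)
with torch.no_grad():
    # W_ii is the first 128 rows of the (512, 512) input-hidden weight.
    torch.nn.init.orthogonal_(lstm.weight_ih_l0[:128], 1.0)
W_ii = lstm.weight_ih_l0[:128]  # shape (128, 512): not square, so no inverse
print(is_semi_orthogonal(W_ii))  # True
```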
Possible Solution
```diff
 self.lstm = nn.LSTM(512, 128)
 for name, param in self.lstm.named_parameters():
     if "bias" in name:
         nn.init.constant_(param, 0)
     elif "weight" in name:
-        nn.init.orthogonal_(param, 1.0)
+        nn.init.orthogonal_(param[:128], 1.0)
+        nn.init.orthogonal_(param[128:128*2], 1.0)
+        nn.init.orthogonal_(param[128*2:128*3], 1.0)
+        nn.init.orthogonal_(param[128*3:], 1.0)
```
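The same idea without hard-coding 128, sketched as a hypothetical helper `init_lstm_per_gate_` (not an existing CleanRL or PyTorch function). It reads `hidden_size` from the module and assumes `proj_size == 0`, since a projection weight `weight_hr_l0` would not consist of four gate blocks:

```python
import torch
import torch.nn as nn


def init_lstm_per_gate_(lstm: nn.LSTM, gain: float = 1.0) -> None:
    """Hypothetical helper: zero all biases and give each gate block of every
    fused LSTM weight matrix its own orthogonal initialization."""
    # Assumption: no projection layer, so every weight has 4 * hidden_size rows.
    assert lstm.proj_size == 0, "weight_hr_* has no per-gate blocks"
    h = lstm.hidden_size
    with torch.no_grad():
        for name, param in lstm.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0)
            elif "weight" in name:
                # Rows [k*h, (k+1)*h) hold gate k: input, forget, cell, output.
                for k in range(4):
                    nn.init.orthogonal_(param[k * h:(k + 1) * h], gain)


lstm = nn.LSTM(512, 128)
init_lstm_per_gate_(lstm)
```

Inside the agent's `__init__`, this would amount to calling `init_lstm_per_gate_(self.lstm)` right after `self.lstm = nn.LSTM(512, 128)`. Running the writes under `torch.no_grad()` keeps the in-place initialization out of autograd.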
Checklist
- I have installed dependencies via `poetry install` (see CleanRL's installation guideline).