LSTM weights should have separate orthogonal initializations for each gate #358

Jammf commented Feb 11, 2023

Problem Description

The LSTM weight matrices in ppo_atari_lstm.py seem to be initialized incorrectly, if the goal is to have a separate orthogonal matrix for each gate. Since lstm.weight_ih_l0 and lstm.weight_hh_l0 have the four gate matrices concatenated together, shouldn't each of the four parts of the fused weight matrix be separately initialized to an orthogonal matrix?
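
For reference, per the PyTorch docs, each fused parameter stacks the four gate matrices along dimension 0, in the documented order i, f, g, o:

import torch
lstm = torch.nn.LSTM(512, 128)
# weight_ih_l0 stacks (W_ii | W_if | W_ig | W_io): shape (4*128, 512)
print(lstm.weight_ih_l0.shape)  # torch.Size([512, 512])
# weight_hh_l0 stacks (W_hi | W_hf | W_hg | W_ho): shape (4*128, 128)
print(lstm.weight_hh_l0.shape)  # torch.Size([512, 128])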

Current Behavior

As a minimal example, checking just the $W_{hi}$ component of the hidden-hidden weights:

import torch
lstm = torch.nn.LSTM(512, 128)
_ = torch.nn.init.orthogonal_(lstm.weight_hh_l0, 1.0)
W_hi = lstm.weight_hh_l0[:128]
torch.allclose(W_hi.T, torch.inverse(W_hi), atol=1e-05)  # check that W_hi is orthogonal
# -> False

The check fails because torch.nn.init.orthogonal_ treats the full 512×128 matrix as a single semi-orthogonal matrix (orthonormal columns), so the individual 128×128 gate blocks are not themselves orthogonal.

Expected Behavior

import torch
lstm = torch.nn.LSTM(512, 128)
_ = torch.nn.init.orthogonal_(lstm.weight_hh_l0[:128], 1.0)  # init a view with only W_hi
W_hi = lstm.weight_hh_l0[:128]
torch.allclose(W_hi.T, torch.inverse(W_hi), atol=1e-05)  # check that W_hi is orthogonal
# -> True

This works because the [:128] slice of the contiguous parameter is a view sharing storage with lstm.weight_hh_l0, and torch.nn.init.orthogonal_ writes into it in place.

Possible Solution

  self.lstm = nn.LSTM(512, 128)
  for name, param in self.lstm.named_parameters():
      if "bias" in name:
          nn.init.constant_(param, 0)
      elif "weight" in name:
-         nn.init.orthogonal_(param, 1.0)
+         nn.init.orthogonal_(param[:128], 1.0)
+         nn.init.orthogonal_(param[128:128*2], 1.0)
+         nn.init.orthogonal_(param[128*2:128*3], 1.0)
+         nn.init.orthogonal_(param[128*3:], 1.0)
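
A slightly more general sketch of the same idea (untested against the repo) avoids hardcoding 128 by reading hidden_size off the module; lstm here stands in for self.lstm:

import torch.nn as nn

lstm = nn.LSTM(512, 128)
h = lstm.hidden_size  # 128; each fused weight stacks 4 gate blocks of h rows
for name, param in lstm.named_parameters():
    if "bias" in name:
        nn.init.constant_(param, 0)
    elif "weight" in name:
        # initialize each of the four stacked gate blocks separately
        for i in range(4):
            nn.init.orthogonal_(param[i * h:(i + 1) * h], 1.0)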