
Make redundant features argument optional for recurrent cells #3717

Open

carlosgmartin opened this issue Feb 25, 2024 · 4 comments
Labels
Priority: P2 - no schedule Best effort response and resolution. We have no plan to work on this at the moment.

Comments

carlosgmartin (Contributor) commented Feb 25, 2024

For recurrent cells such as LSTMCell, GRUCell, and OptimizedLSTMCell, the features argument of the constructor is redundant: it can be inferred from the carry input to the cell's __call__ method. (The only cell that currently uses self.features in its __call__ method is ConvLSTMCell, which ought to be modified to infer it from its carry input.)
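As a rough illustration of the inference (plain NumPy, not the actual Flax code; the weights and recurrence body are placeholders), the feature count is already present in the trailing dimension of the carry:

```python
import numpy as np

def cell_step(carry, inputs, w_i, w_h):
    # The feature count never needs to be stored on the cell:
    # it is the trailing dimension of the carry.
    features = carry.shape[-1]
    new_carry = np.tanh(inputs @ w_i + carry @ w_h)
    assert new_carry.shape[-1] == features
    return new_carry, new_carry
```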

For each cell, the only place where self.features is needed is in the initialize_carry method. But in many models, the initial carry comes from "upstream" in the model, so this method is never used.

Proposal:

  1. Edit ConvLSTMCell to infer features in its __call__ method from its carry input.

  2. Set features=None by default in each cell's constructor.

  3. Add the following line to each initialize_carry method:

assert self.features is not None, "features cannot be None when calling initialize_carry"
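Put together, a cell under this proposal might look roughly like the following (a NumPy toy illustrating the three steps, not the real Flax classes; the recurrence body is a placeholder):

```python
import numpy as np

class SketchCell:
    """Toy cell sketching the proposal; not the actual Flax API."""

    def __init__(self, features=None):
        # Step 2: features is now optional.
        self.features = features

    def initialize_carry(self, batch_dims):
        # Step 3: features is only required here.
        assert self.features is not None, (
            "features cannot be None when calling initialize_carry")
        return np.zeros((*batch_dims, self.features))

    def __call__(self, carry, inputs):
        # Step 1: infer the feature count from the carry,
        # never from self.features.
        features = carry.shape[-1]
        new_carry = np.tanh(carry + inputs.mean(-1, keepdims=True))
        assert new_carry.shape[-1] == features
        return new_carry, new_carry
```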

I can submit a PR for this, if desired.

An alternative would be to pass features directly to the initialize_carry method.
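Under that alternative, the cell would store no feature count at all; something like this hypothetical sketch (again, illustrative names, not the real Flax API):

```python
import numpy as np

class StatelessCarryCell:
    # Hypothetical variant: no features attribute anywhere; the
    # caller passes it to initialize_carry directly.
    def initialize_carry(self, batch_dims, features):
        return np.zeros((*batch_dims, features))
```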

cgarciae (Collaborator) commented Mar 6, 2024

I think the features attribute is needed for RNN; if I remember correctly, we added it for exactly that reason 😅
Also, it feels more natural to specify hyperparameters explicitly in the constructor.

carlosgmartin (Contributor, Author) commented Mar 8, 2024

@cgarciae Isn't shape inference from inputs, as is already done for the inputs argument, more in line with Flax's init philosophy?

The RNN situation could be resolved as follows:

  1. Add a features argument to the cell's initialize_carry method.
  2. Add a features argument to RNN's constructor, and on this line, pass it to the self.cell.initialize_carry call.

That seems more natural and elegant to me, since the number of features may ultimately be determined by stuff upstream in the model (as opposed to being intrinsic to the cell itself).
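The two steps above can be sketched as follows (hypothetical names and a toy cell; the real Flax nn.RNN handles scanning, variable broadcasting, etc. differently):

```python
import numpy as np

class CarryArgCell:
    # Step 1: initialize_carry takes features as an argument.
    def initialize_carry(self, batch_dims, features):
        return np.zeros((*batch_dims, features))

    def __call__(self, carry, inputs):
        new_carry = np.tanh(carry + inputs.mean(-1, keepdims=True))
        return new_carry, new_carry

class SketchRNN:
    # Step 2: features lives on the RNN wrapper, which forwards it
    # to the cell's initialize_carry.
    def __init__(self, cell, features):
        self.cell = cell
        self.features = features

    def __call__(self, inputs):  # inputs: (batch, time, in_features)
        carry = self.cell.initialize_carry(inputs.shape[:1], self.features)
        outputs = []
        for t in range(inputs.shape[1]):
            carry, y = self.cell(carry, inputs[:, t])
            outputs.append(y)
        return np.stack(outputs, axis=1)
```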

@chiamp chiamp added the Priority: P2 - no schedule Best effort response and resolution. We have no plan to work on this at the moment. label Mar 19, 2024
cgarciae (Collaborator) commented
What you are describing is how the Flax recurrent API worked before. However, it was a bit inconsistent: some classes like ConvLSTM required passing the output features while others did not, and it lacked some of the structure needed to implement the RNN class in simple terms. The solution was to add features to all RNN layers and slightly simplify initialize_carry.

carlosgmartin (Contributor, Author) commented
@cgarciae Hmm, is there any reason ConvLSTM can't infer features from its inputs, like the other recurrent modules? I submitted a PR to address that here.
