
data-controlled ODE with general control signal #121

Open
Wang-Tianyu opened this issue Nov 7, 2021 · 3 comments
Labels
question Further information is requested

Comments

@Wang-Tianyu

Hi, thanks for the great library.

I am interested in ODEs of the form dz(s)/ds = f_θ(s, r, z(s)), where r can be a vector independent of the input x.
Does the current implementation support this feature?

Wang-Tianyu added the question (Further information is requested) label on Nov 7, 2021
@Zymrael
Member

Zymrael commented Nov 8, 2021

Hey, thanks for using torchdyn.

What you're asking for is not currently supported by the DataControl layer, but it can be done fairly easily in two ways:

  • Have your vector field nn.Module store r in self.r and use it inside a forward that takes only (s, z).
  • Modify DataControl. The data control is set by default when the integration is prepared: the only module with a u attribute that torchdyn considers is DataControl, which triggers that step to assign the current input to module.u. Depending on how general you'd like your implementation to be, you can simply write a modified DataControl:
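import torch
import torch.nn as nn

class CustomDataControl(nn.Module):
    """Data-control variant: concatenates a conditioning vector r (independent of
    the input x) to the state at arbitrary points of the DEFunc.
    """
    def __init__(self, r):
        super().__init__()
        self.u = None  # mirrors DataControl's attribute; not used in forward
        self.r = r     # conditioning vector, shape (r_dim,) or (batch, r_dim)

    def forward(self, x):
        # Match dtype/device of x and broadcast r across the batch dimension
        r = self.r.to(x).expand(x.shape[0], -1)
        return torch.cat([x, r], 1)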

This is slightly hacky but works. Alternatively, you can alter the logic in _prep_integration to allow custom assignments to module.u.
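A minimal sketch of the first option (storing r on the vector field), written in plain PyTorch so it runs on its own. The class ConditionedVectorField, the dimensions, and the fixed-step euler_solve loop are hypothetical stand-ins: in practice torchdyn's solver would do the integration, and how it passes (s, z) to your module may depend on the torchdyn version you use.

```python
import torch
import torch.nn as nn


class ConditionedVectorField(nn.Module):
    """Hypothetical f_theta(s, r, z): r is stored in self.r, so forward only takes (s, z)."""

    def __init__(self, z_dim, r_dim, hidden=32):
        super().__init__()
        self.r = None  # set before each solve; shape (r_dim,) or (batch, r_dim)
        self.net = nn.Sequential(
            nn.Linear(z_dim + r_dim + 1, hidden),  # +1 for the time input s
            nn.Tanh(),
            nn.Linear(hidden, z_dim),
        )

    def forward(self, s, z):
        batch = z.shape[0]
        r = self.r.to(z).expand(batch, -1)  # broadcast r over the batch
        s_col = torch.full((batch, 1), float(s), dtype=z.dtype, device=z.device)
        return self.net(torch.cat([z, r, s_col], dim=1))


def euler_solve(f, z0, s0=0.0, s1=1.0, steps=20):
    """Fixed-step explicit Euler; a stand-in for a proper ODE solver."""
    z, ds = z0, (s1 - s0) / steps
    for k in range(steps):
        z = z + ds * f(s0 + k * ds, z)
    return z


torch.manual_seed(0)
f = ConditionedVectorField(z_dim=3, r_dim=2)
f.r = torch.randn(2)   # one conditioning vector shared by the whole batch
x = torch.randn(8, 3)  # input batch; r does not depend on x
z1 = euler_solve(f, x)
print(z1.shape)  # torch.Size([8, 3])
```

Because r lives on the module rather than in the state, the same x can be integrated under different conditions just by reassigning f.r between solves.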

@Wang-Tianyu
Author

Hey, thanks for your quick reply.

I managed to make it work following your suggestion!

One remaining question: in this case, does the conditioning vector r receive any gradients?
I am interested in building a conditional CNF like this work: https://github.com/stevenygd/PointFlow
Do you think that is doable with this library?

@Zymrael
Member

Zymrael commented Nov 8, 2021

Glad to hear it worked. Depending on your implementation, r should get gradients just fine. It is certainly possible to build PointFlows with torchdyn :)
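As a rough illustration of both points (gradients reaching r, and a conditional CNF in the PointFlow spirit), here is a plain-PyTorch sketch. Everything in it is an assumption for illustration: the names CondField, exact_trace, cnf_log_prob and the linear encoder are hypothetical, integration is fixed-step Euler, the base density is a standard normal, and the trace is computed exactly, which is only practical for small dimensions. It does not use torchdyn, and PointFlow's actual architecture differs. Here r receives gradients because it is produced by a trainable encoder, is never detached, and the loss is backpropagated directly through the solver steps. With z(0) = x and dz/ds = f, the instantaneous change of variables gives log p(x | r) = log N(z(1); 0, I) + ∫₀¹ tr(∂f/∂z) ds.

```python
import math
import torch
import torch.nn as nn


class CondField(nn.Module):
    """Hypothetical conditional vector field f_theta(s, r, z)."""

    def __init__(self, z_dim, r_dim, hidden=32):
        super().__init__()
        self.r = None
        self.net = nn.Sequential(
            nn.Linear(z_dim + r_dim, hidden), nn.Tanh(), nn.Linear(hidden, z_dim)
        )

    def forward(self, s, z):
        # s is unused (autonomous field); kept for the (s, z) signature
        r = self.r.expand(z.shape[0], -1)
        return self.net(torch.cat([z, r], dim=1))


def exact_trace(fz, z):
    """Per-sample tr(df/dz), one autograd call per dimension."""
    tr = torch.zeros(z.shape[0], dtype=z.dtype, device=z.device)
    for i in range(z.shape[1]):
        tr = tr + torch.autograd.grad(fz[:, i].sum(), z, create_graph=True)[0][:, i]
    return tr


def cnf_log_prob(f, x, steps=10):
    """log p(x | r) = log N(z(1); 0, I) + integral_0^1 tr(df/dz) ds, via Euler."""
    z, ds = x.detach().requires_grad_(True), 1.0 / steps
    int_tr = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
    for k in range(steps):
        fz = f(k * ds, z)
        int_tr = int_tr + ds * exact_trace(fz, z)  # left Riemann sum of the trace
        z = z + ds * fz                            # Euler step of the state
    d = z.shape[1]
    log_base = -0.5 * (z ** 2).sum(dim=1) - 0.5 * d * math.log(2 * math.pi)
    return log_base + int_tr


torch.manual_seed(0)
z_dim, r_dim = 3, 4
encoder = nn.Linear(5, r_dim)   # hypothetical condition encoder
f = CondField(z_dim, r_dim)

context = torch.randn(5)         # e.g. a shape-level descriptor
points = torch.randn(16, z_dim)  # points conditioned on that context

r = encoder(context)             # r stays in the autograd graph
r.retain_grad()                  # keep r.grad for inspection (r is not a leaf)
f.r = r
loss = -cnf_log_prob(f, points).mean()
loss.backward()

print(r.grad is not None)               # True: gradients reach r
print(encoder.weight.grad is not None)  # True: and the encoder that produced it
```

The exact trace costs one backward pass per dimension; for higher-dimensional states a stochastic trace estimator is the usual substitute, at the price of noisier log-likelihood estimates.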
