Delayed param in minipyro #533

Open
ordabayevy opened this issue Apr 21, 2021 · 1 comment
Labels
discussion examples Examples and tutorials

Comments

@ordabayevy
Member

ordabayevy commented Apr 21, 2021

This issue proposes creating a delayed pyro.param in log_joint and then using funsor.adam.Adam to optimize the parameters. This would allow the optimization part of minipyro to be written in a way that is more consistent with funsor style (if there is such a thing) and would also make it backend-agnostic (with more backends supported in funsor.adam.Adam). Support for constrained variables (#502) might be useful here.

class log_joint(Messenger):
    ...
    def process_message(self, msg):
        if msg["type"] == "param":
            # Delay the param: substitute a free funsor Variable for its value,
            # so the loss becomes a lazy expression over the param names.
            msg["value"] = funsor.Variable(msg["name"], msg["output"])

class SVI:
    ...
    def run(self, *args, **kwargs):
        ...
        loss = ...  # a lazy funsor expression with params as free variables
        with funsor.montecarlo.MonteCarlo():
            with funsor.adam.Adam(**options):
                # Minimizing over the free variables runs the optimizer.
                loss.reduce(ops.min)

(I have tried this out for a simple guide where pyro.param doesn't have any constraints and isn't nested in pyro.plate, and it seems to work.)
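The sketch above leans on funsor internals, but the underlying idea can be shown end to end in plain Python. Below is a minimal illustrative sketch (the `Variable`, `Lazy`, and `reduce_min` names are invented for this example and are not the funsor or minipyro API): each param site becomes a named free variable, the loss becomes a lazy function of those variables, and an optimizer minimizes over all of them at once. Crude finite-difference coordinate descent stands in for funsor.adam.Adam here.

```python
class Variable:
    """A named free variable, standing in for a delayed pyro.param site."""
    def __init__(self, name):
        self.name = name


class Lazy:
    """A lazy loss: a function of named free variables, evaluated only
    when an optimizer asks for values."""
    def __init__(self, fn, variables):
        self.fn = fn  # maps {name: value} -> float
        self.variables = variables

    def reduce_min(self, lr=0.1, steps=200, eps=1e-4):
        """Minimize over all free variables by finite-difference gradient
        descent (a toy stand-in for the funsor.adam.Adam interpretation)."""
        values = {v.name: 0.0 for v in self.variables}
        for _ in range(steps):
            for name in values:
                up = dict(values); up[name] += eps
                dn = dict(values); dn[name] -= eps
                grad = (self.fn(up) - self.fn(dn)) / (2 * eps)
                values[name] -= lr * grad
        return values


def model(params):
    # Toy "negative ELBO": minimized at loc = 2.0, scale = -3.0.
    return (params["loc"] - 2.0) ** 2 + (params["scale"] + 3.0) ** 2


# Params are delayed (symbolic) until reduce_min supplies concrete values.
loss = Lazy(model, [Variable("loc"), Variable("scale")])
optimum = loss.reduce_min()
```

The point of the delayed form is that nothing about `Lazy` or the optimizer depends on a particular tensor backend; the loss is just a closed expression over named inputs, which is what makes the funsor version backend-agnostic.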

@fritzo
Member

fritzo commented Apr 21, 2021

This is a great idea! It's a near-literal translation of "(stochastic) (gradient descent)".

@fritzo fritzo added the examples Examples and tutorials label Apr 21, 2021