Gaussian funsor variable elimination #559

Open
1 of 3 tasks
fritzo opened this issue Sep 30, 2021 · 7 comments
Labels
discussion enhancement New feature or request

Comments

@fritzo
Member

fritzo commented Sep 30, 2021

Addresses pyro-ppl/pyro#2929
See design doc

This issue tracks changes needed to efficiently perform variable elimination in Gaussian graphical models with plates. While funsor.sum_product.sum_product() is a partial solution, we'd like to generalize it to a complete solution.
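For reference, the core linear-algebra step behind eliminating a real variable (a logaddexp contraction over a Gaussian) is marginalization in information form, which reduces to a Schur complement. Below is a minimal numpy sketch, assuming an unnormalized density exp(-0.5 xᵀPx + xᵀh) and omitting the log-normalizer. The helper name marginalize is hypothetical and is not funsor's API.

```python
import numpy as np

def marginalize(info_vec, precision, keep, drop):
    """Integrate out the variables at indices `drop` from an information-form
    Gaussian exp(-0.5 x.T P x + x.T h), returning (h', P') over `keep`.
    The log-normalizer contribution is omitted in this sketch."""
    h_a, h_b = info_vec[keep], info_vec[drop]
    p_aa = precision[np.ix_(keep, keep)]
    p_ab = precision[np.ix_(keep, drop)]
    p_bb = precision[np.ix_(drop, drop)]
    # k = P_ab @ inv(P_bb), computed with a solve instead of an explicit inverse.
    k = np.linalg.solve(p_bb, p_ab.T).T
    # Schur complement: P' = P_aa - P_ab P_bb^-1 P_ba,  h' = h_a - P_ab P_bb^-1 h_b
    return h_a - k @ h_b, p_aa - k @ p_ab.T

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 4))
precision = a @ a.T + 4 * np.eye(4)  # positive definite
info_vec = rng.standard_normal(4)
keep, drop = [0, 1], [2, 3]
h, p = marginalize(info_vec, precision, keep, drop)

# Cross-check in moment form: the marginal covariance is a sub-block of inv(P).
cov = np.linalg.inv(precision)
mean = cov @ info_vec
print(np.allclose(p, np.linalg.inv(cov[np.ix_(keep, keep)])))
print(np.allclose(np.linalg.solve(p, h), mean[keep]))
```

The cost is dominated by the solve against P_bb, so eliminating a large block (such as coef below) is where dense representations become expensive.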

Tasks

  • Introduce a new Funsor ConditionalGaussian(info_vec, precision, conditional, inputs) representing the batched conditional distribution of the rightmost real input variable, conditioned on other real input variables. This could be (i) a new Funsor in addition to Gaussian, (ii) a replacement or generalization of Gaussian, or (iii) a special case of Gaussian where the input info_vec and precision are structured (requires #556, "Refactor Gaussian info_vec,precision from backend arrays to Funsors"). A structured representation may allow cheaper linear algebra.

    Alternatively: switch to a sqrt(precision) representation in Gaussian? (#567)
    Temporary workaround: naively scatter the three parameters (info_vec, precision, conditional) into a dense Gaussian. This can be much more computationally expensive.

  • Handle collider variables where a latent variable outside a plate depends on an upstream latent variable inside a plate, thereby coupling the upstream variables via moralization. Currently such problems cannot even be specified in the plated-einsum DSL.
    Temporary workaround: Globally break all plates out of which any arrow leads; equivalent to .to_event().

  • Handle complete bipartite graphs resulting from the RBM motif (x_i --> y_ij <-- z_j). Currently sum_product() and the TVE algorithm give up in this case with "intractable!".
    Temporary workaround: no known workaround
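The "naively scatter into a dense Gaussian" workaround from the first task can be sketched as follows. This assumes an affine conditional x | y ~ N(a @ y + b, inv(precision)), in the spirit of AffineNormal, parametrized by a hypothetical (a, b, precision) triple rather than the issue's proposed (info_vec, precision, conditional) fields. The function name is illustrative only.

```python
import numpy as np

def conditional_to_dense(a, b, precision):
    """Scatter p(x | y) = N(x; a @ y + b, inv(precision)) into a dense
    information-form factor over z = concat([x, y]).

    Returns (info_vec, joint_precision) such that
        -0.5 * z @ joint_precision @ z + info_vec @ z
        == -0.5 * (x - a @ y - b) @ precision @ (x - a @ y - b) + 0.5 * b @ precision @ b
    """
    nx, ny = a.shape
    m = np.concatenate([np.eye(nx), -a], axis=1)  # x - a @ y == m @ z
    joint_precision = m.T @ precision @ m         # (nx + ny) x (nx + ny), but rank nx
    info_vec = m.T @ precision @ b
    return info_vec, joint_precision

rng = np.random.default_rng(0)
nx, ny = 2, 3
a = rng.standard_normal((nx, ny))
b = rng.standard_normal(nx)
l = rng.standard_normal((nx, nx))
precision = l @ l.T + np.eye(nx)
info_vec, joint = conditional_to_dense(a, b, precision)

x, y = rng.standard_normal(nx), rng.standard_normal(ny)
z = np.concatenate([x, y])
r = x - a @ y - b
lhs = -0.5 * z @ joint @ z + info_vec @ z
rhs = -0.5 * r @ precision @ r + 0.5 * b @ precision @ b  # add back the dropped constant
print(joint.shape, np.linalg.matrix_rank(joint))
print(np.isclose(lhs, rhs))
```

The dense joint precision stores (nx + ny)² entries even though its rank is only nx; that redundancy is what a structured ConditionalGaussian could avoid.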

@fritzo
Member Author

fritzo commented Oct 5, 2021

@eb8680 it looks like AutoGaussian(pyrocov_model) runs out of GPU memory when constructing low-rank precision matrices via precision = sqrt @ sqrt.T. One possible solution is to use a sqrt(precision) representation in funsor's Gaussian. I guess the crux is whether we can implement a cheap Gaussian tensordot without materializing intermediate low-rank precision matrices. @fehiepsi already worked out most of the sqrt representation in Pyro PR #2019, where ops.add becomes mere concatenation.
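To illustrate why ops.add becomes concatenation under a square-root representation, here is a minimal numpy sketch assuming precision = S @ S.T with S of shape (n, r). Summing two precisions concatenates their factors along the rank axis, and quadratic forms can be evaluated as ||S.T @ x||² without ever forming the n × n matrix. The helper names are hypothetical and do not reflect the code in Pyro PR #2019 or funsor.

```python
import numpy as np

def add_sqrt(s1, s2):
    """Represent s1 @ s1.T + s2 @ s2.T without forming either product:
    [s1, s2] @ [s1, s2].T == s1 @ s1.T + s2 @ s2.T."""
    return np.concatenate([s1, s2], axis=-1)

def quad_form(sqrt_prec, x):
    """x.T @ (S @ S.T) @ x computed as ||S.T @ x||^2, never forming S @ S.T."""
    y = sqrt_prec.T @ x
    return float(y @ y)

rng = np.random.default_rng(0)
n = 5
s1 = rng.standard_normal((n, 2))  # rank-2 factor
s2 = rng.standard_normal((n, 1))  # rank-1 factor
x = rng.standard_normal(n)

s = add_sqrt(s1, s2)  # shape (n, 3): O(n * r) memory instead of O(n^2)
dense = s1 @ s1.T + s2 @ s2.T  # materialized here only to check the result
print(s.shape)
print(np.isclose(quad_form(s, x), x @ dense @ x))
```

The trade-off is that the rank grows with each addition, so some compression step is eventually needed.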

@fehiepsi how much effort do you think it would take for us to port your Pyro PR #2019 to funsor (where it would also be available in NumPyro 😉)?

@fritzo
Member Author

fritzo commented Oct 6, 2021

Here is the optimized GFVE schedule for my pyro-cov model. It fits in main memory but runs out of GPU memory.

Contraction(ops.null, ops.add,
 frozenset(),
 (Contraction(ops.logaddexp, ops.add,
   frozenset({Variable('rate_loc_scale__BOUND_13', Real)}),
   (Gaussian(
   │ torch.tensor(...1..., dtype=torch.float32),
   │ torch.tensor(...1 x 1..., dtype=torch.float32),
   │ (('rate_loc_scale__BOUND_13', Real),)),
   │Contraction(ops.logaddexp, ops.add,
   │ frozenset({Variable('rate_scale__BOUND_14', Real)}),
   │ (Gaussian(
   │   torch.tensor(...1..., dtype=torch.float32),
   │   torch.tensor(...1 x 1..., dtype=torch.float32),
   │   (('rate_scale__BOUND_14', Real),)),
   │  Contraction(ops.logaddexp, ops.add,
   │   frozenset({Variable('coef__BOUND_12', Reals[2367])}),
   │   (Gaussian(
   │   │ torch.tensor(...2367..., dtype=torch.float32),
   │   │ torch.tensor(...2367 x 2367..., dtype=torch.float32),
   │   │ (('coef__BOUND_12', Reals[2367]),)),
   │   │Contraction(ops.add, ops.null,
   │   │ frozenset({Variable('strain__BOUND_11', Bint[1343])}),
   │   │ (Contraction(ops.logaddexp, ops.add,
   │   │   frozenset({Variable('rate_loc__BOUND_10', Real)}),
   │   │   (Gaussian(
   │   │   │ torch.tensor(...1343 x 2369..., dtype=torch.float32),
   │   │   │ torch.tensor(...1343 x 2369 x 2369..., dtype=torch.float32),
   │   │   │ (('strain__BOUND_11', Bint[1343]),
   │   │   │  ('rate_loc__BOUND_10', Real),
   │   │   │  ('rate_loc_scale__BOUND_13', Real),
   │   │   │  ('coef__BOUND_12', Reals[2367]),)),
   │   │   │Contraction(ops.add, ops.null,
   │   │   │ frozenset({Variable('place__BOUND_4', Bint[1372])}),
   │   │   │ (Contraction(ops.logaddexp, ops.null,
   │   │   │   frozenset({Variable('rate__BOUND_3', Real)}),
   │   │   │   (Gaussian(
   │   │   │   │ torch.tensor(...1372 x 1343 x 3..., dtype=torch.float32),
   │   │   │   │ torch.tensor(...1372 x 1343 x 3 x 3..., dtype=torch.float32),
   │   │   │   │ (('place__BOUND_4', Bint[1372]),
   │   │   │   │  ('strain__BOUND_11', Bint[1343]),
   │   │   │   │  ('rate__BOUND_3', Real),
   │   │   │   │  ('rate_scale__BOUND_14', Real),
   │   │   │   │  ('rate_loc__BOUND_10', Real),)),)),)),)),)),)),)),)),
  Contraction(ops.null, ops.add,
   frozenset(),
   (Contraction(ops.logaddexp, ops.add,
   │ frozenset({Variable('pois_loc__BOUND_16', Real)}),
   │ (Gaussian(
   │   torch.tensor(...1..., dtype=torch.float32),
   │   torch.tensor(...1 x 1..., dtype=torch.float32),
   │   (('pois_loc__BOUND_16', Real),)),
   │  Contraction(ops.logaddexp, ops.add,
   │   frozenset({Variable('pois_scale__BOUND_15', Real)}),
   │   (Gaussian(
   │   │ torch.tensor(...1..., dtype=torch.float32),
   │   │ torch.tensor(...1 x 1..., dtype=torch.float32),
   │   │ (('pois_scale__BOUND_15', Real),)),
   │   │Contraction(ops.add, ops.null,
   │   │ frozenset({Variable('place__BOUND_6', Bint[1372]), Variable('time__BOUND_7', Bint[49])}),
   │   │ (Contraction(ops.logaddexp, ops.null,
   │   │   frozenset({Variable('pois__BOUND_5', Real)}),
   │   │   (Gaussian(
   │   │   │ torch.tensor(...49 x 1372 x 3..., dtype=torch.float32),
   │   │   │ torch.tensor(...49 x 1372 x 3 x 3..., dtype=torch.float32),
   │   │   │ (('time__BOUND_7', Bint[49]),
   │   │   │  ('place__BOUND_6', Bint[1372]),
   │   │   │  ('pois__BOUND_5', Real),
   │   │   │  ('pois_loc__BOUND_16', Real),
   │   │   │  ('pois_scale__BOUND_15', Real),)),)),)),)),)),
   │Contraction(ops.logaddexp, ops.add,
   │ frozenset({Variable('init_loc_scale__BOUND_17', Real)}),
   │ (Gaussian(
   │   torch.tensor(...1..., dtype=torch.float32),
   │   torch.tensor(...1 x 1..., dtype=torch.float32),
   │   (('init_loc_scale__BOUND_17', Real),)),
   │  Contraction(ops.logaddexp, ops.add,
   │   frozenset({Variable('init_scale__BOUND_18', Real)}),
   │   (Gaussian(
   │   │ torch.tensor(...1..., dtype=torch.float32),
   │   │ torch.tensor(...1 x 1..., dtype=torch.float32),
   │   │ (('init_scale__BOUND_18', Real),)),
   │   │Contraction(ops.add, ops.null,
   │   │ frozenset({Variable('strain__BOUND_9', Bint[1343])}),
   │   │ (Contraction(ops.logaddexp, ops.add,
   │   │   frozenset({Variable('init_loc__BOUND_8', Real)}),
   │   │   (Gaussian(
   │   │   │ torch.tensor(...1343 x 2..., dtype=torch.float32),
   │   │   │ torch.tensor(...1343 x 2 x 2..., dtype=torch.float32),
   │   │   │ (('strain__BOUND_9', Bint[1343]),
   │   │   │  ('init_loc__BOUND_8', Real),
   │   │   │  ('init_loc_scale__BOUND_17', Real),)),
   │   │   │Contraction(ops.add, ops.null,
   │   │   │ frozenset({Variable('place__BOUND_2', Bint[1372])}),
   │   │   │ (Contraction(ops.logaddexp, ops.null,
   │   │   │   frozenset({Variable('init__BOUND_1', Real)}),
   │   │   │   (Gaussian(
   │   │   │   │ torch.tensor(...1372 x 1343 x 3..., dtype=torch.float32),
   │   │   │   │ torch.tensor(...1372 x 1343 x 3 x 3..., dtype=torch.float32),
   │   │   │   │ (('place__BOUND_2', Bint[1372]),
   │   │   │   │  ('strain__BOUND_9', Bint[1343]),
   │   │   │   │  ('init__BOUND_1', Real),
   │   │   │   │  ('init_scale__BOUND_18', Real),
   │   │   │   │  ('init_loc__BOUND_8', Real),)),)),)),)),)),)),)),)),))
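In the schedule above, each Contraction(ops.add, ops.null, {plate}, ...) multiplies factors across a plate. When the remaining real inputs are shared across the plate (e.g. rate_scale and rate_loc after rate has been eliminated for each place and strain), this is cheap in information form: log densities add, so info vectors and precisions are simply summed along the plate axis. A minimal numpy sketch, ignoring normalization constants; the helper names are hypothetical:

```python
import numpy as np

def plate_product(info_vec, precision, axis):
    """Multiply a batch of information-form Gaussians over shared real inputs
    along a plate axis: log densities add, so parameters sum."""
    return info_vec.sum(axis), precision.sum(axis)

def log_density(info_vec, precision, x):
    """Unnormalized log density -0.5 x.T P x + x.T h, batched over leading dims."""
    return -0.5 * (x @ precision @ x) + info_vec @ x

rng = np.random.default_rng(0)
num_places, n = 7, 3
h = rng.standard_normal((num_places, n))
l = rng.standard_normal((num_places, n, n))
p = l @ np.swapaxes(l, -1, -2)  # batch of positive semidefinite precisions
h_sum, p_sum = plate_product(h, p, axis=0)

x = rng.standard_normal(n)
print(np.isclose(log_density(h, p, x).sum(), log_density(h_sum, p_sum, x)))
```

The reduction itself is linear in the batch size; the memory problem comes from the size of each batched precision before the reduction.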

The crux is this pair of Gaussian contractions, involving a tensor with over 1e9 elements:

   │  Contraction(ops.logaddexp, ops.add,
   │   frozenset({Variable('coef__BOUND_12', Reals[2367])}),
   │   (Gaussian(
   │   │ torch.tensor(...2367..., dtype=torch.float32),
   │   │ torch.tensor(...2367 x 2367..., dtype=torch.float32),
   │   │ (('coef__BOUND_12', Reals[2367]),)),
   │   │Contraction(ops.add, ops.null,
   │   │ frozenset({Variable('strain__BOUND_11', Bint[1343])}),
   │   │ (Contraction(ops.logaddexp, ops.add,
   │   │   frozenset({Variable('rate_loc__BOUND_10', Real)}),
   │   │   (Gaussian(
   │   │   │ torch.tensor(...1343 x 2369..., dtype=torch.float32),
   │   │   │ torch.tensor(...1343 x 2369 x 2369..., dtype=torch.float32),  # <-------- OOM here
   │   │   │ (('strain__BOUND_11', Bint[1343]),
   │   │   │  ('rate_loc__BOUND_10', Real),
   │   │   │  ('rate_loc_scale__BOUND_13', Real),
   │   │   │  ('coef__BOUND_12', Reals[2367]),)),
   │   │   │Contraction(ops.add, ops.null,
   │   │   │ frozenset({Variable('place__BOUND_4', Bint[1372])}),
   │   │   │ (Contraction(ops.logaddexp, ops.null,
   │   │   │   frozenset({Variable('rate__BOUND_3', Real)}),
   │   │   │   (Gaussian(
   │   │   │   │ torch.tensor(...1372 x 1343 x 3..., dtype=torch.float32),
   │   │   │   │ torch.tensor(...1372 x 1343 x 3 x 3..., dtype=torch.float32),
   │   │   │   │ (('place__BOUND_4', Bint[1372]),
   │   │   │   │  ('strain__BOUND_11', Bint[1343]),
   │   │   │   │  ('rate__BOUND_3', Real),
   │   │   │   │  ('rate_scale__BOUND_14', Real),
   │   │   │   │  ('rate_loc__BOUND_10', Real),)),)),)),)),)),)),)),)),
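A back-of-the-envelope check of the OOM, counting only the dense float32 info_vec and precision buffers with the shapes printed above (intermediate workspace during the contraction would add more). The event dimension 2369 is rate_loc (1) + rate_loc_scale (1) + coef (2367).

```python
import math

def dense_gaussian_bytes(batch_shape, event_dim, itemsize=4):
    """Bytes for a dense batched (info_vec, precision) pair; float32 by default."""
    batch = math.prod(batch_shape)
    return batch * (event_dim + event_dim * event_dim) * itemsize

# Shapes copied from the schedule above.
coef_factor = dense_gaussian_bytes((1343,), 2369)    # strain x (rate_loc, rate_loc_scale, coef)
rate_factor = dense_gaussian_bytes((1372, 1343), 3)  # place x strain x 3 reals
print(f"{coef_factor / 1e9:.1f} GB vs {rate_factor / 1e9:.3f} GB")
```

The 1343 × 2369 × 2369 precision alone is about 7.5e9 entries, roughly 30 GB in float32, whereas the much larger place × strain batch of 3 × 3 factors needs under 0.1 GB.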

I believe we can work around this using a combination of @fehiepsi's prec_sqrt representation (pyro-ppl/pyro#2019) and a ConditionalGaussian that generalizes AffineNormal. Happy to discuss.

@fehiepsi
Member

fehiepsi commented Oct 7, 2021

My impression is that most of the details can be preserved (e.g. block vector, block matrix, align gaussian). Back then, one issue was that batched QR was very slow on GPU, but torch.linalg seems to have improved a lot since then.
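The QR step matters because concatenation makes the square-root factor grow with every addition; a QR decomposition compresses it back to at most n columns while preserving S @ S.T. A minimal numpy sketch of that idea follows; whether PR #2019 triangularizes in exactly this way is an assumption, and the helper name is hypothetical.

```python
import numpy as np

def triangularize(sqrt_prec):
    """Compress S (n x r) to a triangular factor T (n x min(n, r)) with
    T @ T.T == S @ S.T. Uses S.T = Q @ R with orthonormal Q, so
    S @ S.T = R.T @ Q.T @ Q @ R = R.T @ R."""
    r = np.linalg.qr(sqrt_prec.T, mode="r")
    return r.T  # lower-triangular when r >= n

rng = np.random.default_rng(0)
n = 4
# After several additions, the concatenated factor is much wider than tall.
s = np.concatenate([rng.standard_normal((n, 3)) for _ in range(5)], axis=-1)  # (4, 15)
s_tri = triangularize(s)
print(s.shape, s_tri.shape)
print(np.allclose(s_tri @ s_tri.T, s @ s.T))
print(np.allclose(s_tri, np.tril(s_tri)))
```

In PyTorch the analogous call would be torch.linalg.qr(..., mode="r"), which accepts batched inputs, so its GPU speed is what the comment above is concerned with.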

@fritzo
Member Author

fritzo commented Oct 7, 2021

@fehiepsi do you recall whether Cholesky was sufficient instead of QR? IIRC there was a PyTorch discussion about cheaply testing for positive definiteness or condition number using torch.linalg.cholesky_ex().

@fehiepsi
Member

fehiepsi commented Oct 7, 2021

Looking at the code, I guess we need to triangularize a non-positive-definite precision matrix (e.g. a zero matrix), but I can't recall when we need such triangularization. :( Probably it is unnecessary. (Anyway, we can switch to QR if we run into positive-definiteness issues.)
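The trade-off discussed here can be seen with a zero precision matrix, such as an empty factor: Cholesky of the dense precision fails, while QR of an existing square-root factor still yields a valid triangular factor. A minimal numpy sketch with hypothetical helper names; in PyTorch, torch.linalg.cholesky_ex reports failure through its info output instead of raising, which is the cheap check mentioned above.

```python
import numpy as np

def sqrt_via_cholesky(precision):
    """Lower-triangular L with L @ L.T == precision; requires positive definiteness."""
    return np.linalg.cholesky(precision)

def sqrt_via_qr(sqrt_prec):
    """Triangular factor from an existing square-root factor S; works even when
    S @ S.T is only positive semidefinite (e.g. all zeros)."""
    return np.linalg.qr(sqrt_prec.T, mode="r").T

n = 3
zero_sqrt = np.zeros((n, 2))  # square-root factor of an all-zero precision
zero_prec = zero_sqrt @ zero_sqrt.T

try:
    sqrt_via_cholesky(zero_prec)
    cholesky_ok = True
except np.linalg.LinAlgError:
    cholesky_ok = False  # zero matrix is not positive definite

tri = sqrt_via_qr(zero_sqrt)
print(cholesky_ok)
print(np.allclose(tri @ tri.T, zero_prec))
```

So Cholesky suffices whenever the accumulated precision is strictly positive definite, and QR is the fallback for degenerate cases.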

@fritzo
Member Author

fritzo commented Feb 19, 2022

@eb8680 want to pair code next week on the high-level algorithm for variable elimination, continuing our work from https://github.com/pyro-ppl/funsor/compare/tractable-for-gaussians ?

@eb8680
Member

eb8680 commented Feb 19, 2022

Sure!

Projects
None yet
Development

No branches or pull requests

3 participants