
funsor.joint.eager_reduce_exp behaves differently with memoize #561

Open
ordabayevy opened this issue Oct 1, 2021 · 0 comments

ordabayevy commented Oct 1, 2021

The if condition in eager_reduce_exp evaluates to False under the memoize interpretation, so the function returns None. Without memoize, it returns log_result.exp() as expected.

funsor/funsor/joint.py, lines 157 to 165 at commit ca1557b:

@eager.register(Reduce, ops.AddOp, Unary[ops.ExpOp, Funsor], frozenset)
def eager_reduce_exp(op, arg, reduced_vars):
    # x.exp().reduce(ops.add) == x.reduce(ops.logaddexp).exp()
    log_result = arg.arg.reduce(ops.logaddexp, reduced_vars)
    if log_result is not normalize.interpret(
        Reduce, ops.logaddexp, arg.arg, reduced_vars
    ):
        return log_result.exp()
    return None
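The rewrite rule rests on the identity stated in the snippet's comment: summing exp(x) over the reduced variables equals exponentiating the log-sum-exp of x. A minimal pure-Python check of that identity follows, using plain lists instead of funsor terms; the logsumexp helper is written here for illustration and is not taken from funsor.

```python
import math


def logsumexp(xs):
    # Numerically stable log(sum(exp(x))): shift by the max before exponentiating.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


xs = [0.5, -1.0, 2.0]

# Left-hand side: exponentiate first, then reduce with addition,
# mirroring x.exp().reduce(ops.add).
lhs = sum(math.exp(x) for x in xs)

# Right-hand side: reduce with logaddexp first, then exponentiate,
# mirroring x.reduce(ops.logaddexp).exp().
rhs = math.exp(logsumexp(xs))

print(lhs, rhs)
```

The two values agree up to floating-point rounding, which is why the eager pattern may replace a sum of exponentials with the exponential of a logaddexp reduction.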

Example code:

from funsor.cnf import Contraction
from funsor.tensor import Tensor
import torch
import funsor.ops as ops
from funsor import Bint, Real
from funsor.terms import Unary, Binary, Variable, Number, eager, lazy, to_data, Reduce
from funsor.constant import Constant
from funsor.delta import Delta
from funsor.integrate import Integrate
import funsor

funsor.set_backend("torch")

cls = Reduce
args = (
    ops.add,
    Unary(
        ops.exp,
        Contraction(
            ops.null,
            ops.add,
            frozenset(),
            (
                Delta(
                    (
                        (
                            'x__BOUND_16',
                            (
                                Tensor(
                                    torch.tensor([1, 0, 1, 0, 0, 0, 1, 0, 1, 1], dtype=torch.int64),
                                    (('plate__BOUND_17', Bint[10]),),
                                    3,
                                ),
                                Number(0.0),
                            ),
                        ),
                    )
                ),
                Tensor(
                    torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=torch.float64),
                    (('plate__BOUND_17', Bint[10]),),
                    'real',
                ),
            ),
        ),
    ),
    frozenset({Variable('x__BOUND_16', Bint[3])}),
)

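# Without memoize, eager_reduce_exp rewrites the term and this evaluates to a Tensor
result = eager.interpret(cls, *args)

with funsor.interpretations.memoize():
    # With memoize, eager_reduce_exp returns None and this evaluates to a lazy Contraction term
    result2 = eager.interpret(cls, *args)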
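Because the condition uses "is not", it can only be False when log_result and the normalize.interpret(...) result are the very same object. The toy model below is hypothetical and is not funsor's code: it shows one way a memo cache can produce exactly that. If the cache key ignores which interpretation is running, the second lookup returns the object stored by the first, and the identity test reports "no progress" even though the eager step did simplify the term. Whether funsor's memoize actually behaves this way is an assumption to check against the funsor source.

```python
# Hypothetical toy model of the identity check in eager_reduce_exp.
# None of these names (Term, evaluate, made_progress) exist in funsor.


class Term:
    """Stand-in for an expression node returned by an interpretation."""

    def __init__(self, name):
        self.name = name


_cache = {}


def evaluate(key, interpretation, memoize=False):
    # Hypothetical: the cache key ignores which interpretation is running.
    if memoize and key in _cache:
        return _cache[key]
    if interpretation == "eager":
        result = Term("eager:" + key)  # pretend eager simplified the term
    else:
        result = Term("lazy:" + key)  # normalize just builds a lazy term
    if memoize:
        _cache[key] = result
    return result


def made_progress(key, memoize):
    _cache.clear()
    log_result = evaluate(key, "eager", memoize)
    # Same shape of test as eager_reduce_exp: did eager return something
    # other than the lazy normal form?
    return log_result is not evaluate(key, "normalize", memoize)


print(made_progress("reduce_logaddexp", memoize=False))  # True
print(made_progress("reduce_logaddexp", memoize=True))   # False
```

In this toy, comparing by structure or checking the term type instead of object identity would not be fooled by the cache; which fix suits funsor depends on how its memoize interpretation keys and returns results.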