-
Notifications
You must be signed in to change notification settings - Fork 20
/
__init__.py
50 lines (36 loc) · 1.25 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from jax.core import Tracer
from jax.interpreters.xla import DeviceArray
from funsor.adjoint import adjoint_ops
from funsor.interpreter import children, recursion_reinterpret
from funsor.ops import AssociativeOp
from funsor.tensor import Tensor, tensor_to_funsor
from funsor.terms import Funsor, to_funsor
from funsor.util import quote
from . import distributions as _
from . import ops as _
del _ # flake8
@adjoint_ops.register(
    Tensor,
    AssociativeOp,
    AssociativeOp,
    Funsor,
    (DeviceArray, Tracer),
    tuple,
    object,
)
def adjoint_tensor(adj_redop, adj_binop, out_adj, data, inputs, dtype):
    """
    Adjoint rule for a ``Tensor`` whose data is a ``DeviceArray`` or ``Tracer``.

    None of the arguments are inspected; the rule always yields an empty
    dict.
    """
    # NOTE(review): presumably empty because a ground tensor has no
    # sub-terms to propagate adjoints into -- confirm against funsor.adjoint.
    return {}
@recursion_reinterpret.register(DeviceArray)
@recursion_reinterpret.register(Tracer)
def _recursion_reinterpret_ground(x):
    """
    Reinterpretation handler for ``DeviceArray`` and ``Tracer`` values.

    The value is handed back exactly as received.
    """
    return x
@children.register(DeviceArray)
@children.register(Tracer)
def _children_ground(x):
    """
    Children handler for ``DeviceArray`` and ``Tracer`` values.

    Always reports no children by returning an empty tuple.
    """
    return ()
# Register ``tensor_to_funsor`` as the ``to_funsor`` handler for both JAX
# array types, in the same order as before (DeviceArray, then Tracer).
for _ground_type in (DeviceArray, Tracer):
    to_funsor.register(_ground_type)(tensor_to_funsor)
del _ground_type  # keep the loop variable out of the module namespace
@quote.register(DeviceArray)
def _quote(x, indent, out):
    """
    Quote a ``DeviceArray`` as an ``np.array(...)`` expression.

    Appends one ``(indent, text)`` pair to ``out``, where ``text`` is built
    from the array's contents as a nested list and from its dtype. This
    works around JAX's DeviceArray not supporting a reproducible repr.
    """
    # Copy first, then convert to a plain Python list so that the repr is
    # that of the list, not of the array itself.
    values = repr(x.copy().tolist())
    text = "np.array({}, dtype=np.{})".format(values, x.dtype)
    out.append((indent, text))