From 0c22aae961e6f458bf6d7fef30569843348e9312 Mon Sep 17 00:00:00 2001
From: Du Phan
Date: Sat, 20 Feb 2021 11:47:18 -0600
Subject: [PATCH 1/8] add badge

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 3e675a1a0..ba8cf1294 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-[![Build Status](https://travis-ci.com/pyro-ppl/funsor.svg?branch=master)](https://travis-ci.com/pyro-ppl/funsor)
+[![Build Status](https://github.com/pyro-ppl/funsor/workflows/CI/badge.svg)](https://github.com/pyro-ppl/funsor/actions)
 [![Latest Version](https://badge.fury.io/py/funsor.svg)](https://pypi.python.org/pypi/funsor)
 [![Documentation Status](https://readthedocs.org/projects/funsor/badge)](http://funsor.readthedocs.io)

From 80283a2164e35f7f4f4352af9f163b98fa943ec5 Mon Sep 17 00:00:00 2001
From: Du Phan
Date: Sat, 20 Feb 2021 16:59:43 -0600
Subject: [PATCH 2/8] update document page

---
 docs/source/_static/img/pyro_logo_wide.png | Bin 0 -> 5564 bytes
 docs/source/_templates/breadcrumbs.html    |  28 ++
 docs/source/conf.py                        |  73 ++-
 docs/source/index.rst                      |  16 +
 examples/README.rst                        |   2 +
 examples/discrete_hmm.py                   |   6 +
 examples/eeg_slds.py                       |   4 +
 examples/kalman_filter.py                  |   6 +
 examples/minipyro.py                       |   6 +
 examples/mixed_hmm/model.py                |  10 +-
 examples/pcfg.py                           |   6 +
 examples/sensor.py                         |   6 +
 examples/slds.py                           |  11 +-
 examples/vae.py                            |   6 +
 funsor/distribution.py                     |   5 +-
 funsor/interpreter.py                      |   3 +-
 funsor/jax/__init__.py                     |   8 +-
 funsor/joint.py                            |   4 +-
 funsor/syntax.py                           |  15 +-
 funsor/testing.py                          |   6 +-
 setup.py                                   |  17 +-
 test/examples/test_bart.py                 |  92 +---
 test/examples/test_sensor_fusion.py        |  11 +-
 test/pyro/test_hmm.py                      | 139 +----
 test/test_adjoint.py                       |   7 +-
 test/test_distribution.py                  |   5 +-
 test/test_distribution_generic.py          |   4 +-
 test/test_domains.py                       |   8 +-
 test/test_factory.py                       |  22 +-
 test/test_gaussian.py                      |  13 +-
 test/test_minipyro.py                      |  17 +-
 test/test_optimizer.py                     |  24 +-
 test/test_samplers.py                      | 103 +---
 test/test_sum_product.py                   | 547 +++------------------
 test/test_tensor.py                        | 181 +------
 test/test_terms.py                         |  13 +-
 tutorials/sum_product_network.ipynb        |  34 ++
 37 files changed, 377 insertions(+), 1081 deletions(-)
 create mode 100644 docs/source/_static/img/pyro_logo_wide.png
 create mode 100644 docs/source/_templates/breadcrumbs.html
 create mode 100644 examples/README.rst
 create mode 100644 tutorials/sum_product_network.ipynb

diff --git a/docs/source/_static/img/pyro_logo_wide.png b/docs/source/_static/img/pyro_logo_wide.png
new file mode 100644
index 0000000000000000000000000000000000000000..cdd87ba014b22478dcba566f30262d6480249293
GIT binary patch
literal 5564
[5564 bytes of base85-encoded PNG image data omitted]

literal 0
HcmV?d00001

diff --git a/docs/source/_templates/breadcrumbs.html b/docs/source/_templates/breadcrumbs.html
new file mode 100644
index 000000000..5db8caf4d
--- /dev/null
+++ b/docs/source/_templates/breadcrumbs.html
@@ -0,0 +1,28 @@
+{%- extends "sphinx_rtd_theme/breadcrumbs.html" %}
+
+{% set display_vcs_links = display_vcs_links if display_vcs_links is defined else True %}
+
+{% block breadcrumbs_aside %}
+  <li class="wy-breadcrumbs-aside">
+    {% if hasdoc(pagename) and display_vcs_links %}
+      {% if display_github %}
+        {% if check_meta and 'github_url' in meta %}
+
+          {{ _('Edit on GitHub') }}
+        {% else %}
+          {% if 'examples/index' in pagename %}
+            {{ _('Edit on GitHub') }}
+          {% elif 'examples/' in pagename %}
+            {{ _('Edit on GitHub') }}
+          {% else %}
+            {{ _('Edit on GitHub') }}
+          {% endif %}
+        {% endif %}
+      {% elif show_source and source_url_prefix %}
+        {{ _('View page source') }}
+      {% elif show_source and has_source and sourcename %}
+        {{ _('View page source') }}
+      {% endif %}
+    {% endif %}
+  </li>
+{% endblock %}
\ No newline at end of file
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 73949c96e..a215c3422 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -1,7 +1,9 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
 
+import glob
 import os
+import shutil
 import sys
 
 import sphinx_rtd_theme
@@ -46,11 +48,13 @@
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
 extensions = [
+    "nbsphinx",
     "sphinx.ext.autodoc",
     "sphinx.ext.doctest",
     "sphinx.ext.intersphinx",
     "sphinx.ext.mathjax",
     "sphinx.ext.viewcode",
+    "sphinx_gallery.gen_gallery",
 ]
 
 # Disable documentation inheritance so as to avoid inheriting
@@ -76,7 +80,13 @@
 # You can specify multiple suffix as a list of string:
 #
 # source_suffix = ['.rst', '.md']
-source_suffix = ".rst"
+source_suffix = [".rst", ".ipynb"]
+
+# do not execute cells
+nbsphinx_execute = "never"
+
+# Don't add .txt suffix to source files:
+html_sourcelink_suffix = ""
 
 # The master toctree document.
 master_doc = "index"
@@ -91,7 +101,11 @@
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This pattern also affects html_static_path and html_extra_path .
-exclude_patterns = []
+exclude_patterns = [
+    ".ipynb_checkpoints",
+    "examples/*ipynb",
+    "examples/*py",
+]
 
 # The name of the Pygments (syntax highlighting) style to use.
 pygments_style = "sphinx"
@@ -100,6 +114,61 @@
 # do not prepend module name to functions
 add_module_names = False
 
+
+# This is processed by Jinja2 and inserted before each notebook
+nbsphinx_prolog = r"""
+{% set docname = 'tutorials/' + env.doc2path(env.docname, base=None).split('/')[-1] %}
+:github_url: https://github.com/pyro-ppl/funsor/blob/master/{{ docname }}
+.. raw:: html
+
+    <div class="admonition note">
+      <p>
+        Interactive online version:
+        <a href="https://colab.research.google.com/github/pyro-ppl/funsor/blob/master/{{ docname }}">
+          <img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg">
+        </a>
+      </p>
+    </div>
+"""  # noqa: E501
+
+
+# -- Copy notebook files
+
+if not os.path.exists("tutorials"):
+    os.makedirs("tutorials")
+
+for src_file in glob.glob("../../tutorials/*.ipynb"):
+    dst_file = os.path.join("tutorials", src_file.split("/")[-1])
+    shutil.copy(src_file, dst_file)
+
+
+# -- Convert scripts to notebooks
+
+sphinx_gallery_conf = {
+    "examples_dirs": ["../../examples"],
+    "gallery_dirs": ["examples"],
+    # only execute files beginning with plot_
+    "filename_pattern": "/plot_",
+    # 'ignore_pattern': '(minipyro|__init__)',
+    # do not display the total running time, because we do not execute the scripts
+    "min_reported_time": 1,
+}
+
+
+# -- Add thumbnails images
+
+nbsphinx_thumbnails = {}
+
+for src_file in glob.glob("../../tutorials/*.ipynb") + glob.glob("../../examples/*.py"):
+    toctree_path = "tutorials/" if src_file.endswith("ipynb") else "examples/"
+    filename = os.path.splitext(src_file.split("/")[-1])[0]
+    png_path = "_static/img/" + toctree_path + filename + ".png"
+    # fall back to the Pyro logo if no thumbnail PNG exists
+    if not os.path.exists(png_path):
+        png_path = "_static/img/pyro_logo_wide.png"
+    nbsphinx_thumbnails[toctree_path + filename] = png_path
+
+
 # -- Options for HTML output -------------------------------------------------
 
 # The theme to use for HTML and HTML Help pages.  See the documentation for
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 16b7f6252..9401cc1c5 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -32,6 +32,22 @@ Funsor is a tensor-like library for functions and distributions
    minipyro
    einsum
 
+.. nbgallery::
+   :maxdepth: 1
+   :caption: Tutorials and Examples
+   :name: tutorials-and-examples
+
+   tutorials/sum_product_network
+   examples/mixed_hmm/*
+   examples/discrete_hmm
+   examples/eeg_slds
+   examples/kalman_filter
+   examples/minipyro
+   examples/pcfg
+   examples/sensor
+   examples/slds
+   examples/vae
+
 Indices and tables
 ==================
 
diff --git a/examples/README.rst b/examples/README.rst
new file mode 100644
index 000000000..5354631a5
--- /dev/null
+++ b/examples/README.rst
@@ -0,0 +1,2 @@
+Code Examples
+=============
\ No newline at end of file
diff --git a/examples/discrete_hmm.py b/examples/discrete_hmm.py
index f2bfc8fea..8618f9bd0 100644
--- a/examples/discrete_hmm.py
+++ b/examples/discrete_hmm.py
@@ -1,6 +1,12 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
 
+"""
+Example: Discrete HMM
+=====================
+
+"""
+
 import argparse
 from collections import OrderedDict
 
diff --git a/examples/eeg_slds.py b/examples/eeg_slds.py
index e826912c7..ec4bac340 100644
--- a/examples/eeg_slds.py
+++ b/examples/eeg_slds.py
@@ -2,6 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
 """
+Example: Switching Linear Dynamical System EEG
+=================================================
+
 We use a switching linear dynamical system [1] to model an EEG time series
 dataset. For inference we use a moment-matching approximation enabled by
 `funsor.interpretation(funsor.terms.moment_matching)`.
 
 References
 
 [1] Anderson, B., and J. Moore. "Optimal filtering.
 Prentice-Hall, Englewood Cliffs." New Jersey (1979).
 """
+
 import argparse
 import time
 from collections import OrderedDict
diff --git a/examples/kalman_filter.py b/examples/kalman_filter.py
index f839566c0..1782edc3a 100644
--- a/examples/kalman_filter.py
+++ b/examples/kalman_filter.py
@@ -1,6 +1,12 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
 
+"""
+Example: Kalman Filter
+======================
+
+"""
+
 import argparse
 
 import torch
diff --git a/examples/minipyro.py b/examples/minipyro.py
index ff022f0f5..15cae9847 100644
--- a/examples/minipyro.py
+++ b/examples/minipyro.py
@@ -1,6 +1,12 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
 
+"""
+Example: Mini Pyro
+==================
+
+"""
+
 import argparse
 
 import torch
diff --git a/examples/mixed_hmm/model.py b/examples/mixed_hmm/model.py
index 4baabea8f..8da984a1c 100644
--- a/examples/mixed_hmm/model.py
+++ b/examples/mixed_hmm/model.py
@@ -164,12 +164,7 @@ def initialize_params(self):
             params["e_i"]["probs"] = Tensor(
                 pyro.param(
                     "probs_e_i",
-                    lambda: torch.randn(
-                        (
-                            N_c,
-                            N_v,
-                        )
-                    ).abs(),
+                    lambda: torch.randn((N_c, N_v,)).abs(),
                     constraint=constraints.simplex,
                 ),
                 OrderedDict([("g", Bint[N_c])]),  # different value per group
@@ -329,8 +324,7 @@ def __call__(self):
 
         # initialize gamma to uniform
         gamma = Tensor(
-            torch.zeros((N_state, N_state)),
-            OrderedDict([("y_prev", Bint[N_state])]),
+            torch.zeros((N_state, N_state)), OrderedDict([("y_prev", Bint[N_state])]),
         )
 
         N_v = self.config["sizes"]["random"]
diff --git a/examples/pcfg.py b/examples/pcfg.py
index 82bfd1047..fe5fd5144 100644
--- a/examples/pcfg.py
+++ b/examples/pcfg.py
@@ -1,6 +1,12 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
 
+"""
+Example: PCFG
+=============
+
+"""
+
 import argparse
 import math
 from collections import OrderedDict
diff --git a/examples/sensor.py b/examples/sensor.py
index 16af482ae..3e23e4dca 100644
--- a/examples/sensor.py
+++ b/examples/sensor.py
@@ -1,6 +1,12 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
 
+"""
+Example: Biased Kalman Filter
+=============================
+
+"""
+
 import argparse
 import itertools
 import math
diff --git a/examples/slds.py b/examples/slds.py
index 10aaed8ff..c88b3da95 100644
--- a/examples/slds.py
+++ b/examples/slds.py
@@ -1,6 +1,12 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
 
+"""
+Example: Switching Linear Dynamical System
+==========================================
+
+"""
+
 import argparse
 
 import torch
@@ -19,10 +25,7 @@ def main(args):
     )
     trans_noise = funsor.Tensor(
         torch.tensor(
-            [
-                0.1,  # low noise component
-                1.0,  # high noise component
-            ],
+            [0.1, 1.0,],  # low noise component  # high noise component
             requires_grad=True,
         )
     )
diff --git a/examples/vae.py b/examples/vae.py
index 3f938e34e..343d413e4 100644
--- a/examples/vae.py
+++ b/examples/vae.py
@@ -1,6 +1,12 @@
 # Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0 +""" +Example: VAE MNIST +================== + +""" + import argparse import os from collections import OrderedDict diff --git a/funsor/distribution.py b/funsor/distribution.py index b179905a5..e8b666509 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -379,10 +379,7 @@ def dist_init(self, **kwargs): dist_class = DistributionMeta( backend_dist_class.__name__.split("Wrapper_")[-1], (Distribution,), - { - "dist_class": backend_dist_class, - "__init__": dist_init, - }, + {"dist_class": backend_dist_class, "__init__": dist_init,}, ) if generate_eager: diff --git a/funsor/interpreter.py b/funsor/interpreter.py index 759d75587..54dda7183 100644 --- a/funsor/interpreter.py +++ b/funsor/interpreter.py @@ -80,8 +80,7 @@ def interpret(cls, *args): def interpretation(new): warnings.warn( - "'with interpretation(x)' should be replaced by 'with x'", - DeprecationWarning, + "'with interpretation(x)' should be replaced by 'with x'", DeprecationWarning, ) return new diff --git a/funsor/jax/__init__.py b/funsor/jax/__init__.py index ffefedd58..10fa35780 100644 --- a/funsor/jax/__init__.py +++ b/funsor/jax/__init__.py @@ -18,13 +18,7 @@ @adjoint_ops.register( - Tensor, - AssociativeOp, - AssociativeOp, - Funsor, - (DeviceArray, Tracer), - tuple, - object, + Tensor, AssociativeOp, AssociativeOp, Funsor, (DeviceArray, Tracer), tuple, object, ) def adjoint_tensor(adj_redop, adj_binop, out_adj, data, inputs, dtype): return {} diff --git a/funsor/joint.py b/funsor/joint.py index cdc2a092c..cb5693f98 100644 --- a/funsor/joint.py +++ b/funsor/joint.py @@ -104,9 +104,7 @@ def moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete, gauss discrete += gaussian.log_normalizer new_discrete = discrete.reduce(ops.logaddexp, approx_vars & discrete.input_vars) num_elements = reduce( - ops.mul, - [v.output.num_elements for v in approx_vars - discrete.input_vars], - 1, + ops.mul, [v.output.num_elements for v in approx_vars - discrete.input_vars], 1, ) if num_elements != 1: new_discrete -= math.log(num_elements) diff --git a/funsor/syntax.py b/funsor/syntax.py index ed2b7d6eb..05eb123be 100644 --- a/funsor/syntax.py +++ b/funsor/syntax.py @@ -59,10 +59,7 @@ def visit_UnaryOp(self, node): var = self.prefix.get(type(node.op)) if var is not None: node = ast.Call( - func=ast.Name( - id=var, - ctx=ast.Load(), - ), + func=ast.Name(id=var, ctx=ast.Load(),), args=[node.operand], keywords=[], ) @@ -73,10 +70,7 @@ def visit_BinOp(self, node): var = self.infix.get(type(node.op)) if var is not None: node = ast.Call( - func=ast.Name( - id=var, - ctx=ast.Load(), - ), + func=ast.Name(id=var, ctx=ast.Load(),), args=[node.left, node.right], keywords=[], ) @@ -98,10 +92,7 @@ def visit_Compare(self, node): var = self.infix.get(type(node_op)) if var is not None: node = ast.Call( - func=ast.Name( - id=var, - ctx=ast.Load(), - ), + func=ast.Name(id=var, ctx=ast.Load(),), args=[node.left, node_right], keywords=[], ) diff --git a/funsor/testing.py b/funsor/testing.py index ef5dc7ef9..1979734b9 100644 --- a/funsor/testing.py +++ b/funsor/testing.py @@ -114,9 +114,9 @@ def assert_close(actual, expected, atol=1e-6, rtol=1e-6): n for n, p in expected.terms ) actual = actual.align(tuple(n for n, p in expected.terms)) - for (actual_name, (actual_point, actual_log_density)), ( - expected_name, - (expected_point, expected_log_density), + for ( + (actual_name, (actual_point, actual_log_density)), + (expected_name, (expected_point, expected_log_density),), ) in zip(actual.terms, 
expected.terms): assert actual_name == expected_name assert_close(actual_point, expected_point, atol=atol, rtol=rtol) diff --git a/setup.py b/setup.py index e94d4e205..042d63859 100644 --- a/setup.py +++ b/setup.py @@ -27,9 +27,7 @@ description="A tensor-like library for functions and distributions", packages=find_packages(include=["funsor", "funsor.*"]), url="https://github.com/pyro-ppl/funsor", - project_urls={ - "Documentation": "https://funsor.pyro.ai", - }, + project_urls={"Documentation": "https://funsor.pyro.ai",}, author="Uber AI Labs", python_requires=">=3.6", install_requires=[ @@ -39,15 +37,8 @@ "opt_einsum>=2.3.2", ], extras_require={ - "torch": [ - "pyro-ppl>=1.5.2", - "torch>=1.7.0", - ], - "jax": [ - "numpyro>=0.2.4", - "jax>=0.1.57", - "jaxlib>=0.1.37", - ], + "torch": ["pyro-ppl>=1.5.2", "torch>=1.7.0",], + "jax": ["numpyro>=0.2.4", "jax>=0.1.57", "jaxlib>=0.1.37",], "test": [ "black", "flake8", @@ -63,11 +54,13 @@ "black", "flake8", "isort>=5.0", + "nbsphinx", "pandas", "pytest==4.3.1", "pytest-xdist==1.27.0", "scipy", "sphinx>=2.0", + "sphinx-gallery", "sphinx_rtd_theme", "torchvision", ], diff --git a/test/examples/test_bart.py b/test/examples/test_bart.py index afcdb7a27..c30f1a9f8 100644 --- a/test/examples/test_bart.py +++ b/test/examples/test_bart.py @@ -52,10 +52,7 @@ def unpack_gate_rate(gate_rate): @pytest.mark.parametrize( "analytic_kl", - [ - False, - xfail_param(True, reason="missing pattern"), - ], + [False, xfail_param(True, reason="missing pattern"),], ids=["monte-carlo-kl", "analytic-kl"], ) def test_bart(analytic_kl): @@ -96,16 +93,7 @@ def test_bart(analytic_kl): ], dtype=torch.float32, ), # noqa - ( - ( - "time_b4", - Bint[2], - ), - ( - "_event_1_b2", - Bint[8], - ), - ), + (("time_b4", Bint[2],), ("_event_1_b2", Bint[8],),), "real", ), Gaussian( @@ -160,18 +148,9 @@ def test_bart(analytic_kl): dtype=torch.float32, ), # noqa ( - ( - "time_b4", - Bint[2], - ), - ( - "_event_1_b2", - Bint[8], - ), - ( - "value_b1", - Real, - ), + ("time_b4", Bint[2],), + ("_event_1_b2", Bint[8],), + ("value_b1", Real,), ), ), ), @@ -241,14 +220,8 @@ def test_bart(analytic_kl): dtype=torch.float32, ), # noqa ( - ( - "state_b7", - Reals[2], - ), - ( - "state(time=1)_b8", - Reals[2], - ), + ("state_b7", Reals[2],), + ("state(time=1)_b8", Reals[2],), ), ), Subs( @@ -308,12 +281,7 @@ def test_bart(analytic_kl): ], dtype=torch.float32, ), # noqa - ( - ( - "time_b9", - Bint[2], - ), - ), + (("time_b9", Bint[2],),), "real", ), Tensor( @@ -342,12 +310,7 @@ def test_bart(analytic_kl): ], dtype=torch.float32, ), # noqa - ( - ( - "time_b9", - Bint[2], - ), - ), + (("time_b9", Bint[2],),), "real", ), Variable("state(time=1)_b8", Reals[2]), @@ -389,12 +352,7 @@ def test_bart(analytic_kl): ), Variable("value_b5", Reals[2]), ), - ( - ( - "value_b5", - Variable("state_b10", Reals[2]), - ), - ), + (("value_b5", Variable("state_b10", Reals[2]),),), ), ), ) @@ -491,18 +449,9 @@ def test_bart(analytic_kl): dtype=torch.float32, ), # noqa ( - ( - "time_b17", - Bint[2], - ), - ( - "origin_b15", - Bint[2], - ), - ( - "destin_b16", - Bint[2], - ), + ("time_b17", Bint[2],), + ("origin_b15", Bint[2],), + ("destin_b16", Bint[2],), ), "real", ), @@ -527,18 +476,9 @@ def test_bart(analytic_kl): dtype=torch.float32, ), # noqa ( - ( - "time_b17", - Bint[2], - ), - ( - "origin_b15", - Bint[2], - ), - ( - "destin_b16", - Bint[2], - ), + ("time_b17", Bint[2],), + ("origin_b15", Bint[2],), + ("destin_b16", Bint[2],), ), "real", ), diff --git a/test/examples/test_sensor_fusion.py 
b/test/examples/test_sensor_fusion.py index f8fc8b77f..48dc52e31 100644 --- a/test/examples/test_sensor_fusion.py +++ b/test/examples/test_sensor_fusion.py @@ -142,16 +142,7 @@ def test_affine_subs(): ], dtype=torch.float32, ), # noqa - ( - ( - "state_1_b6", - Reals[3], - ), - ( - "obs_b2", - Reals[2], - ), - ), + (("state_1_b6", Reals[3],), ("obs_b2", Reals[2],),), ), ( ( diff --git a/test/pyro/test_hmm.py b/test/pyro/test_hmm.py index 69d645ca5..28614a305 100644 --- a/test/pyro/test_hmm.py +++ b/test/pyro/test_hmm.py @@ -245,134 +245,27 @@ def test_gaussian_mrf_log_prob(init_shape, trans_shape, obs_shape, hidden_dim, o ] ) SLHMM_SHAPES = [ + ((2,), (), (1, 2,), (1, 3, 3), (1,), (1, 3, 4), (1,),), + ((2,), (), (5, 1, 2,), (1, 3, 3), (1,), (1, 3, 4), (1,),), + ((2,), (), (1, 2,), (5, 1, 3, 3), (1,), (1, 3, 4), (1,),), + ((2,), (), (1, 2,), (1, 3, 3), (5, 1), (1, 3, 4), (1,),), + ((2,), (), (1, 2,), (1, 3, 3), (1,), (5, 1, 3, 4), (1,),), + ((2,), (), (1, 2,), (1, 3, 3), (1,), (1, 3, 4), (5, 1),), + ((2,), (), (5, 1, 2,), (5, 1, 3, 3), (5, 1), (5, 1, 3, 4), (5, 1),), + ((2,), (2,), (5, 2, 2,), (5, 2, 3, 3), (5, 2), (5, 2, 3, 4), (5, 2),), ( - (2,), + (7, 2,), (), - ( - 1, - 2, - ), - (1, 3, 3), - (1,), - (1, 3, 4), - (1,), - ), - ( - (2,), - (), - ( - 5, - 1, - 2, - ), - (1, 3, 3), - (1,), - (1, 3, 4), - (1,), - ), - ( - (2,), - (), - ( - 1, - 2, - ), - (5, 1, 3, 3), - (1,), - (1, 3, 4), - (1,), - ), - ( - (2,), - (), - ( - 1, - 2, - ), - (1, 3, 3), - (5, 1), - (1, 3, 4), - (1,), - ), - ( - (2,), - (), - ( - 1, - 2, - ), - (1, 3, 3), - (1,), - (5, 1, 3, 4), - (1,), - ), - ( - (2,), - (), - ( - 1, - 2, - ), - (1, 3, 3), - (1,), - (1, 3, 4), - (5, 1), - ), - ( - (2,), - (), - ( - 5, - 1, - 2, - ), - (5, 1, 3, 3), - (5, 1), - (5, 1, 3, 4), - (5, 1), - ), - ( - (2,), - (2,), - ( - 5, - 2, - 2, - ), - (5, 2, 3, 3), - (5, 2), - (5, 2, 3, 4), - (5, 2), - ), - ( - ( - 7, - 2, - ), - (), - ( - 7, - 5, - 1, - 2, - ), + (7, 5, 1, 2,), (7, 5, 1, 3, 3), (7, 5, 1), (7, 5, 1, 3, 4), (7, 5, 1), ), ( - ( - 7, - 2, - ), + (7, 2,), (7, 2), - ( - 7, - 5, - 2, - 2, - ), + (7, 5, 2, 2,), (7, 5, 2, 3, 3), (7, 5, 2), (7, 5, 2, 3, 4), @@ -518,13 +411,7 @@ def test_switching_linear_hmm_log_prob_alternating(exact, num_steps, num_compone -1, num_components, -1, -1 ) - trans_mvn = random_mvn( - ( - num_steps, - num_components, - ), - hidden_dim, - ) + trans_mvn = random_mvn((num_steps, num_components,), hidden_dim,) hmm_obs_matrix = torch.randn(num_steps, hidden_dim, obs_dim) switching_obs_matrix = hmm_obs_matrix.unsqueeze(-3).expand( -1, num_components, -1, -1 diff --git a/test/test_adjoint.py b/test/test_adjoint.py index b098c1c95..0e2d6407b 100644 --- a/test/test_adjoint.py +++ b/test/test_adjoint.py @@ -201,12 +201,7 @@ def test_optimized_plated_einsum_adjoint(equation, plates, backend): ids=lambda d: ",".join(d.keys()), ) @pytest.mark.parametrize( - "impl", - [ - sequential_sum_product, - naive_sequential_sum_product, - MarkovProduct, - ], + "impl", [sequential_sum_product, naive_sequential_sum_product, MarkovProduct,], ) def test_sequential_sum_product_adjoint( impl, sum_op, prod_op, batch_inputs, state_domain, num_steps diff --git a/test/test_distribution.py b/test/test_distribution.py index d0794a1eb..33d4c1655 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -1459,10 +1459,7 @@ def test_power_transform(shape): @pytest.mark.parametrize("shape", [(10,), (4, 3)], ids=str) @pytest.mark.parametrize( "to_event", - [ - True, - xfail_param(False, reason="bug in 
to_funsor(TransformedDistribution)"), - ], + [True, xfail_param(False, reason="bug in to_funsor(TransformedDistribution)"),], ) def test_haar_transform(shape, to_event): try: diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index 5ffa99bf6..66c5ba466 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -186,9 +186,7 @@ def __hash__(self): # Chi2 DistTestCase( - "dist.Chi2(df=case.df)", - (("df", f"rand({batch_shape})"),), - funsor.Real, + "dist.Chi2(df=case.df)", (("df", f"rand({batch_shape})"),), funsor.Real, ) # ContinuousBernoulli diff --git a/test/test_domains.py b/test/test_domains.py index d721ee03e..fb1461412 100644 --- a/test/test_domains.py +++ b/test/test_domains.py @@ -10,13 +10,7 @@ @pytest.mark.parametrize( - "expr", - [ - "Bint[2]", - "Real", - "Reals[4]", - "Reals[3, 2]", - ], + "expr", ["Bint[2]", "Real", "Reals[4]", "Reals[3, 2]",], ) def test_pickle(expr): x = eval(expr) diff --git a/test/test_factory.py b/test/test_factory.py index a9fe8b78c..b28df08d9 100644 --- a/test/test_factory.py +++ b/test/test_factory.py @@ -19,9 +19,7 @@ def test_lambda_lambda(): @make_funsor def LambdaLambda( - i: Bound, - j: Bound, - x: Funsor, + i: Bound, j: Bound, x: Funsor, ) -> Fresh[lambda i, j, x: Array[x.dtype, (i.size, j.size) + x.shape]]: assert i in x.inputs assert j in x.inputs @@ -51,10 +49,7 @@ def GetitemGetitem( def test_flatten(): @make_funsor def Flatten21( - x: Funsor, - i: Bound, - j: Bound, - ij: Fresh[lambda i, j: Bint[i.size * j.size]], + x: Funsor, i: Bound, j: Bound, ij: Fresh[lambda i, j: Bint[i.size * j.size]], ) -> Fresh[lambda x: x.dtype]: m = to_funsor(i, x.inputs.get(i, None)).output.size n = to_funsor(j, x.inputs.get(j, None)).output.size @@ -120,9 +115,7 @@ def Cat2( def test_normal(): @make_funsor def Normal( - loc: Funsor, - scale: Funsor, - value: Fresh[lambda loc: loc], + loc: Funsor, scale: Funsor, value: Fresh[lambda loc: loc], ) -> Fresh[Real]: return None @@ -147,11 +140,7 @@ def _(loc, scale, value): def test_matmul(): @make_funsor - def MatMul( - x: Funsor, - y: Funsor, - i: Bound, - ) -> Fresh[lambda x: x]: + def MatMul(x: Funsor, y: Funsor, i: Bound,) -> Fresh[lambda x: x]: return (x * y).reduce(ops.add, i) x = random_tensor(OrderedDict(a=Bint[3], b=Bint[4])) @@ -182,8 +171,7 @@ def Scatter1( def test_value_dependence(): @make_funsor def Sum( - x: Funsor, - dim: Value[int], + x: Funsor, dim: Value[int], ) -> Fresh[lambda x, dim: Array[x.dtype, x.shape[:dim] + x.shape[dim + 1 :]]]: return None diff --git a/test/test_gaussian.py b/test/test_gaussian.py index d9e66af02..369e908c3 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -577,20 +577,11 @@ def test_reduce_logsumexp(int_inputs, real_inputs): @pytest.mark.parametrize( - "int_inputs", - [ - {}, - {"i": Bint[2]}, - ], - ids=id_from_inputs, + "int_inputs", [{}, {"i": Bint[2]},], ids=id_from_inputs, ) @pytest.mark.parametrize( "real_inputs", - [ - {"x": Real}, - {"x": Reals[4]}, - {"x": Reals[2, 3]}, - ], + [{"x": Real}, {"x": Reals[4]}, {"x": Reals[2, 3]},], ids=id_from_inputs, ) def test_integrate_variable(int_inputs, real_inputs): diff --git a/test/test_minipyro.py b/test/test_minipyro.py index 5224ab25c..db654100d 100644 --- a/test/test_minipyro.py +++ b/test/test_minipyro.py @@ -36,8 +36,9 @@ def Vindex(x): def _check_loss_and_grads(expected_loss, actual_loss, atol=1e-4, rtol=1e-4): # copied from pyro - expected_loss, actual_loss = funsor.to_data(expected_loss), funsor.to_data( - actual_loss + expected_loss, 
actual_loss = ( + funsor.to_data(expected_loss), + funsor.to_data(actual_loss), ) assert ops.allclose(actual_loss, expected_loss, atol=atol, rtol=rtol) names = pyro.get_param_store().keys() @@ -301,11 +302,7 @@ def guide(): @pytest.mark.parametrize( - "backend", - [ - "pyro", - xfail_param("funsor", reason="missing patterns"), - ], + "backend", ["pyro", xfail_param("funsor", reason="missing patterns"),], ) def test_mean_field_ok(backend): def model(): @@ -323,11 +320,7 @@ def guide(): @pytest.mark.parametrize( - "backend", - [ - "pyro", - xfail_param("funsor", reason="missing patterns"), - ], + "backend", ["pyro", xfail_param("funsor", reason="missing patterns"),], ) def test_mean_field_warn(backend): def model(): diff --git a/test/test_optimizer.py b/test/test_optimizer.py index ee16c9d75..c8f0d4550 100644 --- a/test/test_optimizer.py +++ b/test/test_optimizer.py @@ -45,18 +45,10 @@ @pytest.mark.parametrize("equation", OPTIMIZED_EINSUM_EXAMPLES) @pytest.mark.parametrize( - "backend", - [ - "pyro.ops.einsum.torch_log", - "pyro.ops.einsum.torch_map", - ], + "backend", ["pyro.ops.einsum.torch_log", "pyro.ops.einsum.torch_map",], ) @pytest.mark.parametrize( - "einsum_impl", - [ - naive_einsum, - naive_contract_einsum, - ], + "einsum_impl", [naive_einsum, naive_contract_einsum,], ) def test_optimized_einsum(equation, backend, einsum_impl): inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation) @@ -79,11 +71,7 @@ def test_optimized_einsum(equation, backend, einsum_impl): @pytest.mark.parametrize( - "eqn1,eqn2", - [ - ("a,ab->b", "bc->"), - ("ab,bc,cd->d", "de,ef,fg->"), - ], + "eqn1,eqn2", [("a,ab->b", "bc->"), ("ab,bc,cd->d", "de,ef,fg->"),], ) @pytest.mark.parametrize("optimize1", [False, True]) @pytest.mark.parametrize("optimize2", [False, True]) @@ -151,11 +139,7 @@ def test_nested_einsum( @pytest.mark.parametrize("equation,plates", PLATED_EINSUM_EXAMPLES) @pytest.mark.parametrize( - "backend", - [ - "pyro.ops.einsum.torch_log", - "pyro.ops.einsum.torch_map", - ], + "backend", ["pyro.ops.einsum.torch_log", "pyro.ops.einsum.torch_map",], ) def test_optimized_plated_einsum(equation, plates, backend): inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation) diff --git a/test/test_samplers.py b/test/test_samplers.py index aafc29a21..68e7fd075 100644 --- a/test/test_samplers.py +++ b/test/test_samplers.py @@ -37,28 +37,17 @@ @pytest.mark.parametrize( "sample_inputs", - [ - (), - (("s", Bint[6]),), - (("s", Bint[6]), ("t", Bint[7])), - ], + [(), (("s", Bint[6]),), (("s", Bint[6]), ("t", Bint[7])),], ids=id_from_inputs, ) @pytest.mark.parametrize( "batch_inputs", - [ - (), - (("b", Bint[4]),), - (("b", Bint[4]), ("c", Bint[5])), - ], + [(), (("b", Bint[4]),), (("b", Bint[4]), ("c", Bint[5])),], ids=id_from_inputs, ) @pytest.mark.parametrize( "event_inputs", - [ - (("e", Bint[2]),), - (("e", Bint[2]), ("f", Bint[3])), - ], + [(("e", Bint[2]),), (("e", Bint[2]), ("f", Bint[3])),], ids=id_from_inputs, ) def test_tensor_shape(sample_inputs, batch_inputs, event_inputs): @@ -92,29 +81,17 @@ def test_tensor_shape(sample_inputs, batch_inputs, event_inputs): @pytest.mark.parametrize( "sample_inputs", - [ - (), - (("s", Bint[3]),), - (("s", Bint[3]), ("t", Bint[4])), - ], + [(), (("s", Bint[3]),), (("s", Bint[3]), ("t", Bint[4])),], ids=id_from_inputs, ) @pytest.mark.parametrize( "batch_inputs", - [ - (), - (("b", Bint[2]),), - (("c", Real),), - (("b", Bint[2]), ("c", Real)), - ], + [(), (("b", Bint[2]),), (("c", Real),), (("b", Bint[2]), ("c", Real)),], 
ids=id_from_inputs, ) @pytest.mark.parametrize( "event_inputs", - [ - (("e", Real),), - (("e", Real), ("f", Reals[2])), - ], + [(("e", Real),), (("e", Real), ("f", Reals[2])),], ids=id_from_inputs, ) def test_gaussian_shape(sample_inputs, batch_inputs, event_inputs): @@ -155,29 +132,17 @@ def test_gaussian_shape(sample_inputs, batch_inputs, event_inputs): @pytest.mark.parametrize( "sample_inputs", - [ - (), - (("s", Bint[3]),), - (("s", Bint[3]), ("t", Bint[4])), - ], + [(), (("s", Bint[3]),), (("s", Bint[3]), ("t", Bint[4])),], ids=id_from_inputs, ) @pytest.mark.parametrize( "batch_inputs", - [ - (), - (("b", Bint[2]),), - (("c", Real),), - (("b", Bint[2]), ("c", Real)), - ], + [(), (("b", Bint[2]),), (("c", Real),), (("b", Bint[2]), ("c", Real)),], ids=id_from_inputs, ) @pytest.mark.parametrize( "event_inputs", - [ - (("e", Real),), - (("e", Real), ("f", Reals[2])), - ], + [(("e", Real),), (("e", Real), ("f", Reals[2])),], ids=id_from_inputs, ) def test_transformed_gaussian_shape(sample_inputs, batch_inputs, event_inputs): @@ -226,28 +191,17 @@ def test_transformed_gaussian_shape(sample_inputs, batch_inputs, event_inputs): @pytest.mark.parametrize( "sample_inputs", - [ - (), - (("s", Bint[6]),), - (("s", Bint[6]), ("t", Bint[7])), - ], + [(), (("s", Bint[6]),), (("s", Bint[6]), ("t", Bint[7])),], ids=id_from_inputs, ) @pytest.mark.parametrize( "int_event_inputs", - [ - (), - (("d", Bint[2]),), - (("d", Bint[2]), ("e", Bint[3])), - ], + [(), (("d", Bint[2]),), (("d", Bint[2]), ("e", Bint[3])),], ids=id_from_inputs, ) @pytest.mark.parametrize( "real_event_inputs", - [ - (("g", Real),), - (("g", Real), ("h", Reals[4])), - ], + [(("g", Real),), (("g", Real), ("h", Reals[4])),], ids=id_from_inputs, ) def test_joint_shape(sample_inputs, int_event_inputs, real_event_inputs): @@ -289,19 +243,12 @@ def test_joint_shape(sample_inputs, int_event_inputs, real_event_inputs): @pytest.mark.parametrize( "batch_inputs", - [ - (), - (("b", Bint[4]),), - (("b", Bint[2]), ("c", Bint[2])), - ], + [(), (("b", Bint[4]),), (("b", Bint[2]), ("c", Bint[2])),], ids=id_from_inputs, ) @pytest.mark.parametrize( "event_inputs", - [ - (("e", Bint[3]),), - (("e", Bint[2]), ("f", Bint[2])), - ], + [(("e", Bint[3]),), (("e", Bint[2]), ("f", Bint[2])),], ids=id_from_inputs, ) @pytest.mark.parametrize("test_grad", [False, True], ids=["value", "grad"]) @@ -347,19 +294,12 @@ def diff_fn(p_data): @pytest.mark.parametrize( "batch_inputs", - [ - (), - (("b", Bint[3]),), - (("b", Bint[3]), ("c", Bint[4])), - ], + [(), (("b", Bint[3]),), (("b", Bint[3]), ("c", Bint[4])),], ids=id_from_inputs, ) @pytest.mark.parametrize( "event_inputs", - [ - (("e", Real),), - (("e", Real), ("f", Reals[2])), - ], + [(("e", Real),), (("e", Real), ("f", Reals[2])),], ids=id_from_inputs, ) def test_gaussian_distribution(event_inputs, batch_inputs): @@ -396,19 +336,12 @@ def test_gaussian_distribution(event_inputs, batch_inputs): @pytest.mark.parametrize( "batch_inputs", - [ - (), - (("b", Bint[3]),), - (("b", Bint[3]), ("c", Bint[2])), - ], + [(), (("b", Bint[3]),), (("b", Bint[3]), ("c", Bint[2])),], ids=id_from_inputs, ) @pytest.mark.parametrize( "event_inputs", - [ - (("e", Real), ("f", Bint[3])), - (("e", Reals[2]), ("f", Bint[2])), - ], + [(("e", Real), ("f", Bint[3])), (("e", Reals[2]), ("f", Bint[2])),], ids=id_from_inputs, ) def test_gaussian_mixture_distribution(batch_inputs, event_inputs): diff --git a/test/test_sum_product.py b/test/test_sum_product.py index 8977cc22c..9fc82aae4 100644 --- a/test/test_sum_product.py +++ 
b/test/test_sum_product.py @@ -101,11 +101,7 @@ def test_partition(inputs, dims, expected_num_components): ], ) @pytest.mark.parametrize( - "impl", - [ - partial_sum_product, - modified_partial_sum_product, - ], + "impl", [partial_sum_product, modified_partial_sum_product,], ) def test_partial_sum_product(impl, sum_op, prod_op, inputs, plates, vars1, vars2): inputs = inputs.split(",") @@ -145,12 +141,7 @@ def test_partial_sum_product(impl, sum_op, prod_op, inputs, plates, vars1, vars2 ], ) @pytest.mark.parametrize( - "x_dim,time", - [ - (3, 1), - (1, 5), - (3, 5), - ], + "x_dim,time", [(3, 1), (1, 5), (3, 5),], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -159,22 +150,10 @@ def test_modified_partial_sum_product_0(sum_op, prod_op, vars1, vars2, x_dim, ti f1 = random_tensor(OrderedDict({})) - f2 = random_tensor( - OrderedDict( - { - "x_0": Bint[x_dim], - } - ) - ) + f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim],})) f3 = random_tensor( - OrderedDict( - { - "time": Bint[time], - "x_prev": Bint[x_dim], - "x_curr": Bint[x_dim], - } - ) + OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim],}) ) factors = [f1, f2, f3] @@ -207,13 +186,7 @@ def test_modified_partial_sum_product_0(sum_op, prod_op, vars1, vars2, x_dim, ti ], ) @pytest.mark.parametrize( - "x_dim,y_dim,time", - [ - (2, 3, 5), - (1, 3, 5), - (2, 1, 5), - (2, 3, 1), - ], + "x_dim,y_dim,time", [(2, 3, 5), (1, 3, 5), (2, 1, 5), (2, 3, 1),], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -224,41 +197,16 @@ def test_modified_partial_sum_product_1( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor( - OrderedDict( - { - "x_0": Bint[x_dim], - } - ) - ) + f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim],})) f3 = random_tensor( - OrderedDict( - { - "time": Bint[time], - "x_prev": Bint[x_dim], - "x_curr": Bint[x_dim], - } - ) + OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim],}) ) - f4 = random_tensor( - OrderedDict( - { - "x_0": Bint[x_dim], - "y_0": Bint[y_dim], - } - ) - ) + f4 = random_tensor(OrderedDict({"x_0": Bint[x_dim], "y_0": Bint[y_dim],})) f5 = random_tensor( - OrderedDict( - { - "time": Bint[time], - "x_curr": Bint[x_dim], - "y_curr": Bint[y_dim], - } - ) + OrderedDict({"time": Bint[time], "x_curr": Bint[x_dim], "y_curr": Bint[y_dim],}) ) factors = [f1, f2, f3, f4, f5] @@ -296,13 +244,7 @@ def test_modified_partial_sum_product_1( ], ) @pytest.mark.parametrize( - "x_dim,y_dim,time", - [ - (2, 3, 5), - (1, 3, 5), - (2, 1, 5), - (2, 3, 1), - ], + "x_dim,y_dim,time", [(2, 3, 5), (1, 3, 5), (2, 1, 5), (2, 3, 1),], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -313,40 +255,16 @@ def test_modified_partial_sum_product_2( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor( - OrderedDict( - { - "x_0": Bint[x_dim], - } - ) - ) + f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim],})) f3 = random_tensor( - OrderedDict( - { - "time": Bint[time], - "x_prev": Bint[x_dim], - "x_curr": Bint[x_dim], - } - ) + OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim],}) ) - f4 = random_tensor( - OrderedDict( - { - "y_0": Bint[y_dim], - } - ) - ) + f4 = random_tensor(OrderedDict({"y_0": Bint[y_dim],})) f5 = random_tensor( - OrderedDict( - { - "time": Bint[time], - "y_prev": Bint[y_dim], - "y_curr": Bint[y_dim], - } - ) + OrderedDict({"time": Bint[time], "y_prev": Bint[y_dim], "y_curr": Bint[y_dim],}) ) factors = 
[f1, f2, f3, f4, f5] @@ -386,13 +304,7 @@ def test_modified_partial_sum_product_2( ], ) @pytest.mark.parametrize( - "x_dim,y_dim,time", - [ - (2, 3, 5), - (1, 3, 5), - (2, 1, 5), - (2, 3, 1), - ], + "x_dim,y_dim,time", [(2, 3, 5), (1, 3, 5), (2, 1, 5), (2, 3, 1),], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -403,32 +315,13 @@ def test_modified_partial_sum_product_3( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor( - OrderedDict( - { - "x_0": Bint[x_dim], - } - ) - ) + f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim],})) f3 = random_tensor( - OrderedDict( - { - "time": Bint[time], - "x_prev": Bint[x_dim], - "x_curr": Bint[x_dim], - } - ) + OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim],}) ) - f4 = random_tensor( - OrderedDict( - { - "x_0": Bint[x_dim], - "y_0": Bint[y_dim], - } - ) - ) + f4 = random_tensor(OrderedDict({"x_0": Bint[x_dim], "y_0": Bint[y_dim],})) f5 = random_tensor( OrderedDict( @@ -509,12 +402,7 @@ def test_modified_partial_sum_product_3( ) @pytest.mark.parametrize( "x_dim,y_dim,sequences,time,tones", - [ - (2, 3, 2, 5, 4), - (1, 3, 2, 5, 4), - (2, 1, 2, 5, 4), - (2, 3, 2, 1, 4), - ], + [(2, 3, 2, 5, 4), (1, 3, 2, 5, 4), (2, 1, 2, 5, 4), (2, 3, 2, 1, 4),], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -525,14 +413,7 @@ def test_modified_partial_sum_product_4( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor( - OrderedDict( - { - "sequences": Bint[sequences], - "x_0": Bint[x_dim], - } - ) - ) + f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim],})) f3 = random_tensor( OrderedDict( @@ -547,11 +428,7 @@ def test_modified_partial_sum_product_4( f4 = random_tensor( OrderedDict( - { - "sequences": Bint[sequences], - "tones": Bint[tones], - "y_0": Bint[y_dim], - } + {"sequences": Bint[sequences], "tones": Bint[tones], "y_0": Bint[y_dim],} ) ) @@ -657,12 +534,7 @@ def test_modified_partial_sum_product_4( ) @pytest.mark.parametrize( "x_dim,y_dim,sequences,days,weeks,tones", - [ - (2, 3, 2, 5, 4, 3), - (1, 3, 2, 5, 4, 3), - (2, 1, 2, 5, 4, 3), - (2, 3, 2, 1, 4, 3), - ], + [(2, 3, 2, 5, 4, 3), (1, 3, 2, 5, 4, 3), (2, 1, 2, 5, 4, 3), (2, 3, 2, 1, 4, 3),], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -675,11 +547,7 @@ def test_modified_partial_sum_product_5( f2 = random_tensor( OrderedDict( - { - "sequences": Bint[sequences], - "tones": Bint[tones], - "x_0": Bint[x_dim], - } + {"sequences": Bint[sequences], "tones": Bint[tones], "x_0": Bint[x_dim],} ) ) @@ -695,14 +563,7 @@ def test_modified_partial_sum_product_5( ) ) - f4 = random_tensor( - OrderedDict( - { - "sequences": Bint[sequences], - "y_0": Bint[y_dim], - } - ) - ) + f4 = random_tensor(OrderedDict({"sequences": Bint[sequences], "y_0": Bint[y_dim],})) f5 = random_tensor( OrderedDict( @@ -786,12 +647,7 @@ def test_modified_partial_sum_product_5( ) @pytest.mark.parametrize( "x_dim,y_dim,sequences,time,tones", - [ - (2, 3, 2, 5, 4), - (1, 3, 2, 5, 4), - (2, 1, 2, 5, 4), - (2, 3, 2, 1, 4), - ], + [(2, 3, 2, 5, 4), (1, 3, 2, 5, 4), (2, 1, 2, 5, 4), (2, 3, 2, 1, 4),], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -802,14 +658,7 @@ def test_modified_partial_sum_product_6( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor( - OrderedDict( - { - "sequences": Bint[sequences], - "x_0": Bint[x_dim], - } - ) - ) + f2 = random_tensor(OrderedDict({"sequences": 
Bint[sequences], "x_0": Bint[x_dim],})) f3 = random_tensor( OrderedDict( @@ -915,12 +764,7 @@ def test_modified_partial_sum_product_6( ) @pytest.mark.parametrize( "x_dim,y_dim,sequences,time,tones", - [ - (2, 3, 2, 5, 4), - (1, 3, 2, 5, 4), - (2, 1, 2, 5, 4), - (2, 3, 2, 1, 4), - ], + [(2, 3, 2, 5, 4), (1, 3, 2, 5, 4), (2, 1, 2, 5, 4), (2, 3, 2, 1, 4),], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -931,14 +775,7 @@ def test_modified_partial_sum_product_7( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor( - OrderedDict( - { - "sequences": Bint[sequences], - "x_0": Bint[x_dim], - } - ) - ) + f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim],})) f3 = random_tensor( OrderedDict( @@ -978,12 +815,7 @@ def test_modified_partial_sum_product_7( factors = [f1, f2, f3, f4, f5] plate_to_step = { "sequences": {}, - "time": frozenset( - { - ("x_0", "x_prev", "x_curr"), - ("y_0", "y_prev", "y_curr"), - } - ), + "time": frozenset({("x_0", "x_prev", "x_curr"), ("y_0", "y_prev", "y_curr"),}), "tones": {}, } @@ -1072,12 +904,7 @@ def test_modified_partial_sum_product_7( ) @pytest.mark.parametrize( "w_dim,x_dim,y_dim,sequences,time,tones", - [ - (3, 2, 3, 2, 5, 4), - (3, 1, 3, 2, 5, 4), - (3, 2, 1, 2, 5, 4), - (3, 2, 3, 2, 1, 4), - ], + [(3, 2, 3, 2, 5, 4), (3, 1, 3, 2, 5, 4), (3, 2, 1, 2, 5, 4), (3, 2, 3, 2, 1, 4),], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -1088,14 +915,7 @@ def test_modified_partial_sum_product_8( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor( - OrderedDict( - { - "sequences": Bint[sequences], - "w_0": Bint[w_dim], - } - ) - ) + f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim],})) f3 = random_tensor( OrderedDict( @@ -1108,14 +928,7 @@ def test_modified_partial_sum_product_8( ) ) - f4 = random_tensor( - OrderedDict( - { - "sequences": Bint[sequences], - "x_0": Bint[x_dim], - } - ) - ) + f4 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim],})) f5 = random_tensor( OrderedDict( @@ -1156,12 +969,7 @@ def test_modified_partial_sum_product_8( factors = [f1, f2, f3, f4, f5, f6, f7] plate_to_step = { "sequences": {}, - "time": frozenset( - { - ("x_0", "x_prev", "x_curr"), - ("w_0", "w_prev", "w_curr"), - } - ), + "time": frozenset({("x_0", "x_prev", "x_curr"), ("w_0", "w_prev", "w_curr"),}), "tones": {}, } @@ -1259,12 +1067,7 @@ def test_modified_partial_sum_product_8( ) @pytest.mark.parametrize( "w_dim,x_dim,y_dim,sequences,time,tones", - [ - (3, 2, 3, 2, 5, 4), - (3, 1, 3, 2, 5, 4), - (3, 2, 1, 2, 5, 4), - (3, 2, 3, 2, 1, 4), - ], + [(3, 2, 3, 2, 5, 4), (3, 1, 3, 2, 5, 4), (3, 2, 1, 2, 5, 4), (3, 2, 3, 2, 1, 4),], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -1275,14 +1078,7 @@ def test_modified_partial_sum_product_9( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor( - OrderedDict( - { - "sequences": Bint[sequences], - "w_0": Bint[w_dim], - } - ) - ) + f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim],})) f3 = random_tensor( OrderedDict( @@ -1297,11 +1093,7 @@ def test_modified_partial_sum_product_9( f4 = random_tensor( OrderedDict( - { - "sequences": Bint[sequences], - "w_0": Bint[w_dim], - "x_0": Bint[x_dim], - } + {"sequences": Bint[sequences], "w_0": Bint[w_dim], "x_0": Bint[x_dim],} ) ) @@ -1345,12 +1137,7 @@ def test_modified_partial_sum_product_9( factors = [f1, f2, f3, f4, f5, f6, f7] 
plate_to_step = { "sequences": {}, - "time": frozenset( - { - ("x_0", "x_prev", "x_curr"), - ("w_0", "w_prev", "w_curr"), - } - ), + "time": frozenset({("x_0", "x_prev", "x_curr"), ("w_0", "w_prev", "w_curr"),}), "tones": {}, } @@ -1437,12 +1224,7 @@ def test_modified_partial_sum_product_9( ) @pytest.mark.parametrize( "w_dim,x_dim,y_dim,sequences,time,tones", - [ - (3, 2, 3, 2, 5, 4), - (3, 1, 3, 2, 5, 4), - (3, 2, 1, 2, 5, 4), - (3, 2, 3, 2, 1, 4), - ], + [(3, 2, 3, 2, 5, 4), (3, 1, 3, 2, 5, 4), (3, 2, 1, 2, 5, 4), (3, 2, 3, 2, 1, 4),], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -1453,32 +1235,17 @@ def test_modified_partial_sum_product_10( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor( - OrderedDict( - { - "sequences": Bint[sequences], - "w_0": Bint[w_dim], - } - ) - ) + f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim],})) f3 = random_tensor( OrderedDict( - { - "sequences": Bint[sequences], - "time": Bint[time], - "w_curr": Bint[w_dim], - } + {"sequences": Bint[sequences], "time": Bint[time], "w_curr": Bint[w_dim],} ) ) f4 = random_tensor( OrderedDict( - { - "sequences": Bint[sequences], - "w_0": Bint[w_dim], - "x_0": Bint[x_dim], - } + {"sequences": Bint[sequences], "w_0": Bint[w_dim], "x_0": Bint[x_dim],} ) ) @@ -1654,30 +1421,13 @@ def test_modified_partial_sum_product_11( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor( - OrderedDict( - { - "a": Bint[a_dim], - } - ) - ) + f2 = random_tensor(OrderedDict({"a": Bint[a_dim],})) - f3 = random_tensor( - OrderedDict( - { - "sequences": Bint[sequences], - "b": Bint[b_dim], - } - ) - ) + f3 = random_tensor(OrderedDict({"sequences": Bint[sequences], "b": Bint[b_dim],})) f4 = random_tensor( OrderedDict( - { - "a": Bint[a_dim], - "sequences": Bint[sequences], - "w_0": Bint[w_dim], - } + {"a": Bint[a_dim], "sequences": Bint[sequences], "w_0": Bint[w_dim],} ) ) @@ -1829,12 +1579,7 @@ def test_modified_partial_sum_product_11( ) @pytest.mark.parametrize( "w_dim,x_dim,y_dim,sequences,time,tones", - [ - (3, 2, 3, 2, 5, 4), - (3, 1, 3, 2, 5, 4), - (3, 2, 1, 2, 5, 4), - (3, 2, 3, 2, 1, 4), - ], + [(3, 2, 3, 2, 5, 4), (3, 1, 3, 2, 5, 4), (3, 2, 1, 2, 5, 4), (3, 2, 3, 2, 1, 4),], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -1845,22 +1590,11 @@ def test_modified_partial_sum_product_12( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor( - OrderedDict( - { - "sequences": Bint[sequences], - "w_0": Bint[w_dim], - } - ) - ) + f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim],})) f3 = random_tensor( OrderedDict( - { - "sequences": Bint[sequences], - "time": Bint[time], - "w_curr": Bint[w_dim], - } + {"sequences": Bint[sequences], "time": Bint[time], "w_curr": Bint[w_dim],} ) ) @@ -2069,11 +1803,7 @@ def test_modified_partial_sum_product_13( f4 = random_tensor( OrderedDict( - { - "w": Bint[w_dim], - "sequences": Bint[sequences], - "y_0": Bint[y_dim], - } + {"w": Bint[w_dim], "sequences": Bint[sequences], "y_0": Bint[y_dim],} ) ) @@ -2194,12 +1924,7 @@ def test_modified_partial_sum_product_13( ) @pytest.mark.parametrize( "x_dim,y_dim,sequences,time,tones", - [ - (2, 3, 2, 3, 2), - (1, 3, 2, 3, 2), - (2, 1, 2, 3, 2), - (2, 3, 2, 1, 2), - ], + [(2, 3, 2, 3, 2), (1, 3, 2, 3, 2), (2, 1, 2, 3, 2), (2, 3, 2, 1, 2),], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -2210,14 +1935,7 @@ def test_modified_partial_sum_product_14( f1 = 
random_tensor(OrderedDict({})) - f2 = random_tensor( - OrderedDict( - { - "sequences": Bint[sequences], - "x_0": Bint[x_dim], - } - ) - ) + f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim],})) f3 = random_tensor( OrderedDict( @@ -2232,11 +1950,7 @@ def test_modified_partial_sum_product_14( f4 = random_tensor( OrderedDict( - { - "sequences": Bint[sequences], - "x_0": Bint[x_dim], - "y0_0": Bint[y_dim], - } + {"sequences": Bint[sequences], "x_0": Bint[x_dim], "y0_0": Bint[y_dim],} ) ) @@ -2281,10 +1995,7 @@ def test_modified_partial_sum_product_14( "sequences": {}, "time": frozenset({("x_0", "x_prev", "x_curr")}), "tones": frozenset( - { - ("y0_0", "y0_prev", "y0_curr"), - ("ycurr_0", "ycurr_prev", "ycurr_curr"), - } + {("y0_0", "y0_prev", "y0_curr"), ("ycurr_0", "ycurr_prev", "ycurr_curr"),} ), } @@ -2320,13 +2031,7 @@ def test_modified_partial_sum_product_14( ], ) @pytest.mark.parametrize( - "x_dim,y_dim,time", - [ - (2, 3, 5), - (1, 3, 5), - (2, 1, 5), - (2, 3, 1), - ], + "x_dim,y_dim,time", [(2, 3, 5), (1, 3, 5), (2, 1, 5), (2, 3, 1),], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -2337,50 +2042,21 @@ def test_modified_partial_sum_product_16( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor( - OrderedDict( - { - "x_0": Bint[x_dim], - } - ) - ) + f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim],})) f3 = random_tensor( - OrderedDict( - { - "time": Bint[time], - "y_prev": Bint[y_dim], - "x_curr": Bint[x_dim], - } - ) + OrderedDict({"time": Bint[time], "y_prev": Bint[y_dim], "x_curr": Bint[x_dim],}) ) - f4 = random_tensor( - OrderedDict( - { - "y_0": Bint[y_dim], - } - ) - ) + f4 = random_tensor(OrderedDict({"y_0": Bint[y_dim],})) f5 = random_tensor( - OrderedDict( - { - "time": Bint[time], - "x_prev": Bint[x_dim], - "y_curr": Bint[y_dim], - } - ) + OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "y_curr": Bint[y_dim],}) ) factors = [f1, f2, f3, f4, f5] plate_to_step = { - "time": frozenset( - { - ("x_0", "x_prev", "x_curr"), - ("y_0", "y_prev", "y_curr"), - } - ), + "time": frozenset({("x_0", "x_prev", "x_curr"), ("y_0", "y_prev", "y_curr"),}), } factors1 = modified_partial_sum_product( @@ -2450,13 +2126,7 @@ def test_modified_partial_sum_product_16( ], ) @pytest.mark.parametrize( - "x_dim,y_dim,z_dim,time", - [ - (2, 3, 2, 5), - (1, 3, 2, 5), - (2, 1, 2, 5), - (2, 3, 2, 1), - ], + "x_dim,y_dim,z_dim,time", [(2, 3, 2, 5), (1, 3, 2, 5), (2, 1, 2, 5), (2, 3, 2, 1),], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -2467,22 +2137,10 @@ def test_modified_partial_sum_product_17( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor( - OrderedDict( - { - "x_0": Bint[x_dim], - } - ) - ) + f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim],})) f3 = random_tensor( - OrderedDict( - { - "time": Bint[time], - "x_prev": Bint[x_dim], - "x_curr": Bint[x_dim], - } - ) + OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim],}) ) f4 = random_tensor( @@ -2532,13 +2190,7 @@ def test_modified_partial_sum_product_17( ) f8 = random_tensor( - OrderedDict( - { - "x_0": Bint[x_dim], - "y_0": Bint[y_dim], - "z2_0": Bint[z_dim], - } - ) + OrderedDict({"x_0": Bint[x_dim], "y_0": Bint[y_dim], "z2_0": Bint[z_dim],}) ) f9 = random_tensor( @@ -2655,11 +2307,7 @@ def test_sequential_sum_product( ) @pytest.mark.parametrize( "x_domain,y_domain", - [ - (Bint[2], Bint[3]), - (Real, Reals[2, 2]), - (Bint[2], Reals[2]), - ], + [(Bint[2], Bint[3]), (Real, 
Reals[2, 2]), (Bint[2], Reals[2]),], ids=str, ) @pytest.mark.parametrize( @@ -2728,29 +2376,15 @@ def test_sequential_sum_product_multi( @pytest.mark.parametrize("dim", [1, 2, 3]) def test_sequential_sum_product_bias_1(num_steps, dim): time = Variable("time", Bint[num_steps]) - bias_dist = random_gaussian( - OrderedDict( - [ - ("bias", Reals[dim]), - ] - ) - ) + bias_dist = random_gaussian(OrderedDict([("bias", Reals[dim]),])) trans = random_gaussian( OrderedDict( - [ - ("time", Bint[num_steps]), - ("x_prev", Reals[dim]), - ("x_curr", Reals[dim]), - ] + [("time", Bint[num_steps]), ("x_prev", Reals[dim]), ("x_curr", Reals[dim]),] ) ) obs = random_gaussian( OrderedDict( - [ - ("time", Bint[num_steps]), - ("x_curr", Reals[dim]), - ("bias", Reals[dim]), - ] + [("time", Bint[num_steps]), ("x_curr", Reals[dim]), ("bias", Reals[dim]),] ) ) factor = trans + obs + bias_dist @@ -2769,29 +2403,15 @@ def test_sequential_sum_product_bias_1(num_steps, dim): def test_sequential_sum_product_bias_2(num_steps, num_sensors, dim): time = Variable("time", Bint[num_steps]) bias = Variable("bias", Reals[num_sensors, dim]) - bias_dist = random_gaussian( - OrderedDict( - [ - ("bias", Reals[num_sensors, dim]), - ] - ) - ) + bias_dist = random_gaussian(OrderedDict([("bias", Reals[num_sensors, dim]),])) trans = random_gaussian( OrderedDict( - [ - ("time", Bint[num_steps]), - ("x_prev", Reals[dim]), - ("x_curr", Reals[dim]), - ] + [("time", Bint[num_steps]), ("x_prev", Reals[dim]), ("x_curr", Reals[dim]),] ) ) obs = random_gaussian( OrderedDict( - [ - ("time", Bint[num_steps]), - ("x_curr", Reals[dim]), - ("bias", Reals[dim]), - ] + [("time", Bint[num_steps]), ("x_curr", Reals[dim]), ("bias", Reals[dim]),] ) ) @@ -2837,14 +2457,7 @@ def _check_sarkka_bilmes(trans, expected_inputs, global_vars, num_periods=1): @pytest.mark.parametrize("duration", [2, 3, 4, 5, 6]) def test_sarkka_bilmes_example_0(duration): - trans = random_tensor( - OrderedDict( - { - "time": Bint[duration], - "a": Bint[3], - } - ) - ) + trans = random_tensor(OrderedDict({"time": Bint[duration], "a": Bint[3],})) expected_inputs = { "a": Bint[3], @@ -2858,12 +2471,7 @@ def test_sarkka_bilmes_example_1(duration): trans = random_tensor( OrderedDict( - { - "time": Bint[duration], - "a": Bint[3], - "b": Bint[2], - "_PREV_b": Bint[2], - } + {"time": Bint[duration], "a": Bint[3], "b": Bint[2], "_PREV_b": Bint[2],} ) ) @@ -2957,12 +2565,7 @@ def test_sarkka_bilmes_example_5(duration): trans = random_tensor( OrderedDict( - { - "time": Bint[duration], - "a": Bint[3], - "_PREV_a": Bint[3], - "x": Bint[2], - } + {"time": Bint[duration], "a": Bint[3], "_PREV_a": Bint[3], "x": Bint[2],} ) ) @@ -3007,11 +2610,7 @@ def test_sarkka_bilmes_example_6(duration): @pytest.mark.parametrize("time_input", [("time", Bint[t]) for t in range(6, 11)]) @pytest.mark.parametrize( - "global_inputs", - [ - (), - (("x", Bint[2]),), - ], + "global_inputs", [(), (("x", Bint[2]),),], ) @pytest.mark.parametrize( "local_inputs", diff --git a/test/test_tensor.py b/test/test_tensor.py index 2050fb105..f46611358 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -126,15 +126,7 @@ def test_indexing(): def test_advanced_indexing_shape(): I, J, M, N = 4, 4, 2, 3 - x = Tensor( - randn((I, J)), - OrderedDict( - [ - ("i", Bint[I]), - ("j", Bint[J]), - ] - ), - ) + x = Tensor(randn((I, J)), OrderedDict([("i", Bint[I]), ("j", Bint[J]),]),) m = Tensor(numeric_array([2, 3]), OrderedDict([("m", Bint[M])]), I) n = Tensor(numeric_array([0, 1, 1]), OrderedDict([("n", Bint[N])]), J) assert 
x.data.shape == (I, J) @@ -231,54 +223,17 @@ def test_advanced_indexing_tensor(output_shape): # x output = Reals[output_shape] x = random_tensor( - OrderedDict( - [ - ("i", Bint[2]), - ("j", Bint[3]), - ("k", Bint[4]), - ] - ), - output, - ) - i = random_tensor( - OrderedDict( - [ - ("u", Bint[5]), - ] - ), - Bint[2], - ) - j = random_tensor( - OrderedDict( - [ - ("v", Bint[6]), - ("u", Bint[5]), - ] - ), - Bint[3], - ) - k = random_tensor( - OrderedDict( - [ - ("v", Bint[6]), - ] - ), - Bint[4], + OrderedDict([("i", Bint[2]), ("j", Bint[3]), ("k", Bint[4]),]), output, ) + i = random_tensor(OrderedDict([("u", Bint[5]),]), Bint[2],) + j = random_tensor(OrderedDict([("v", Bint[6]), ("u", Bint[5]),]), Bint[3],) + k = random_tensor(OrderedDict([("v", Bint[6]),]), Bint[4],) expected_data = empty((5, 6) + output_shape) for u in range(5): for v in range(6): expected_data[u, v] = x.data[i.data[u], j.data[v, u], k.data[v]] - expected = Tensor( - expected_data, - OrderedDict( - [ - ("u", Bint[5]), - ("v", Bint[6]), - ] - ), - ) + expected = Tensor(expected_data, OrderedDict([("u", Bint[5]), ("v", Bint[6]),]),) assert_equiv(expected, x(i, j, k)) assert_equiv(expected, x(i=i, j=j, k=k)) @@ -303,13 +258,7 @@ def test_advanced_indexing_tensor(output_shape): def test_advanced_indexing_lazy(output_shape): x = Tensor( randn((2, 3, 4) + output_shape), - OrderedDict( - [ - ("i", Bint[2]), - ("j", Bint[3]), - ("k", Bint[4]), - ] - ), + OrderedDict([("i", Bint[2]), ("j", Bint[3]), ("k", Bint[4]),]), ) u = Variable("u", Bint[2]) v = Variable("v", Bint[3]) @@ -325,15 +274,7 @@ def test_advanced_indexing_lazy(output_shape): for u in range(2): for v in range(3): expected_data[u, v] = x.data[i_data[u], j_data[v], k_data[u, v]] - expected = Tensor( - expected_data, - OrderedDict( - [ - ("u", Bint[2]), - ("v", Bint[3]), - ] - ), - ) + expected = Tensor(expected_data, OrderedDict([("u", Bint[2]), ("v", Bint[3]),]),) assert_equiv(expected, x(i, j, k)) assert_equiv(expected, x(i=i, j=j, k=k)) @@ -363,18 +304,7 @@ def unary_eval(symbol, x): @pytest.mark.parametrize("dims", [(), ("a",), ("a", "b")]) @pytest.mark.parametrize( "symbol", - [ - "~", - "-", - "abs", - "atanh", - "sqrt", - "exp", - "log", - "log1p", - "sigmoid", - "tanh", - ], + ["~", "-", "abs", "atanh", "sqrt", "exp", "log", "log1p", "sigmoid", "tanh",], ) def test_unary(symbol, dims): sizes = {"a": 3, "b": 4} @@ -908,13 +838,7 @@ def test_function_of_numeric_array(): def test_align(): x = Tensor( randn((2, 3, 4)), - OrderedDict( - [ - ("i", Bint[2]), - ("j", Bint[3]), - ("k", Bint[4]), - ] - ), + OrderedDict([("i", Bint[2]), ("j", Bint[3]), ("k", Bint[4]),]), ) y = x.align(("j", "k", "i")) assert isinstance(y, Tensor) @@ -1027,41 +951,13 @@ def test_tensor_stack(n, shape, dim): @pytest.mark.parametrize("output", [Bint[2], Real, Reals[4], Reals[2, 3]], ids=str) def test_funsor_stack(output): - x = random_tensor( - OrderedDict( - [ - ("i", Bint[2]), - ] - ), - output, - ) - y = random_tensor( - OrderedDict( - [ - ("j", Bint[3]), - ] - ), - output, - ) - z = random_tensor( - OrderedDict( - [ - ("i", Bint[2]), - ("k", Bint[4]), - ] - ), - output, - ) + x = random_tensor(OrderedDict([("i", Bint[2]),]), output,) + y = random_tensor(OrderedDict([("j", Bint[3]),]), output,) + z = random_tensor(OrderedDict([("i", Bint[2]), ("k", Bint[4]),]), output,) xy = Stack("t", (x, y)) assert isinstance(xy, Tensor) - assert xy.inputs == OrderedDict( - [ - ("t", Bint[2]), - ("i", Bint[2]), - ("j", Bint[3]), - ] - ) + assert xy.inputs == OrderedDict([("t", Bint[2]), ("i", 
Bint[2]), ("j", Bint[3]),]) assert xy.output == output for j in range(3): assert_close(xy(t=0, j=j), x) @@ -1071,12 +967,7 @@ def test_funsor_stack(output): xyz = Stack("t", (x, y, z)) assert isinstance(xyz, Tensor) assert xyz.inputs == OrderedDict( - [ - ("t", Bint[3]), - ("i", Bint[2]), - ("j", Bint[3]), - ("k", Bint[4]), - ] + [("t", Bint[3]), ("i", Bint[2]), ("j", Bint[3]), ("k", Bint[4]),] ) assert xy.output == output for j in range(3): @@ -1091,32 +982,9 @@ def test_funsor_stack(output): @pytest.mark.parametrize("output", [Bint[2], Real, Reals[4], Reals[2, 3]], ids=str) def test_cat_simple(output): - x = random_tensor( - OrderedDict( - [ - ("i", Bint[2]), - ] - ), - output, - ) - y = random_tensor( - OrderedDict( - [ - ("i", Bint[3]), - ("j", Bint[4]), - ] - ), - output, - ) - z = random_tensor( - OrderedDict( - [ - ("i", Bint[5]), - ("k", Bint[6]), - ] - ), - output, - ) + x = random_tensor(OrderedDict([("i", Bint[2]),]), output,) + y = random_tensor(OrderedDict([("i", Bint[3]), ("j", Bint[4]),]), output,) + z = random_tensor(OrderedDict([("i", Bint[5]), ("k", Bint[6]),]), output,) assert Cat("i", (x,)) is x assert Cat("i", (y,)) is y @@ -1124,22 +992,13 @@ def test_cat_simple(output): xy = Cat("i", (x, y)) assert isinstance(xy, Tensor) - assert xy.inputs == OrderedDict( - [ - ("i", Bint[2 + 3]), - ("j", Bint[4]), - ] - ) + assert xy.inputs == OrderedDict([("i", Bint[2 + 3]), ("j", Bint[4]),]) assert xy.output == output xyz = Cat("i", (x, y, z)) assert isinstance(xyz, Tensor) assert xyz.inputs == OrderedDict( - [ - ("i", Bint[2 + 3 + 5]), - ("j", Bint[4]), - ("k", Bint[6]), - ] + [("i", Bint[2 + 3 + 5]), ("j", Bint[4]), ("k", Bint[6]),] ) assert xy.output == output diff --git a/test/test_terms.py b/test/test_terms.py index daa5f49a5..f7f3b550f 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -260,18 +260,7 @@ def unary_eval(symbol, x): @pytest.mark.parametrize("data", [0, 0.5, 1]) @pytest.mark.parametrize( "symbol", - [ - "~", - "-", - "atanh", - "abs", - "sqrt", - "exp", - "log", - "log1p", - "sigmoid", - "tanh", - ], + ["~", "-", "atanh", "abs", "sqrt", "exp", "log", "log1p", "sigmoid", "tanh",], ) def test_unary(symbol, data): dtype = "real" diff --git a/tutorials/sum_product_network.ipynb b/tutorials/sum_product_network.ipynb new file mode 100644 index 000000000..fa2f417fe --- /dev/null +++ b/tutorials/sum_product_network.ipynb @@ -0,0 +1,34 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Sum Product Network\n", + "\n", + "(in preparation)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 274841da4020a906ce44014c1ef4261af38010f9 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sun, 21 Feb 2021 09:52:31 -0600 Subject: [PATCH 3/8] sketch for sum product network --- docs/source/conf.py | 10 +- examples/eeg_slds.py | 2 +- examples/mixed_hmm/model.py | 12 +- examples/slds.py | 2 +- funsor/affine.py | 6 +- funsor/cnf.py | 6 +- funsor/delta.py | 5 +- funsor/distribution.py | 4 +- funsor/gaussian.py | 7 +- funsor/instrument.py | 9 +- funsor/integrate.py | 4 +- funsor/interpreter.py | 2 +- funsor/jax/__init__.py | 2 +- funsor/jax/ops.py | 5 +- funsor/joint.py | 2 +- 
funsor/memoize.py | 4 +- funsor/montecarlo.py | 4 +- funsor/registry.py | 4 +- funsor/syntax.py | 14 +- funsor/terms.py | 2 +- funsor/testing.py | 2 +- scripts/update_headers.py | 5 +- setup.py | 13 +- test/examples/test_bart.py | 32 +-- test/examples/test_sensor_fusion.py | 2 +- test/pyro/test_hmm.py | 32 ++- test/test_adjoint.py | 2 +- test/test_distribution.py | 2 +- test/test_distribution_generic.py | 6 +- test/test_domains.py | 4 +- test/test_factory.py | 10 +- test/test_gaussian.py | 6 +- test/test_memoize.py | 4 +- test/test_minipyro.py | 4 +- test/test_optimizer.py | 12 +- test/test_samplers.py | 44 ++-- test/test_sum_product.py | 172 +++++++--------- test/test_tensor.py | 41 ++-- test/test_terms.py | 18 +- tutorials/sum_product_network.ipynb | 301 +++++++++++++++++++++++++++- 40 files changed, 502 insertions(+), 316 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index a215c3422..d14e5865d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -101,11 +101,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path . -exclude_patterns = [ - ".ipynb_checkpoints", - "examples/*ipynb", - "examples/*py", -] +exclude_patterns = [".ipynb_checkpoints", "examples/*ipynb", "examples/*py"] # The name of the Pygments (syntax highlighting) style to use. pygments_style = "sphinx" @@ -226,7 +222,7 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, "Funsor.tex", u"Funsor Documentation", u"Uber AI Labs", "manual"), + (master_doc, "Funsor.tex", u"Funsor Documentation", u"Uber AI Labs", "manual") ] # -- Options for manual page output ------------------------------------------ @@ -249,7 +245,7 @@ "Funsor", "Functional analysis + tensors + symbolic algebra.", "Miscellaneous", - ), + ) ] diff --git a/examples/eeg_slds.py b/examples/eeg_slds.py index ec4bac340..fa31a5059 100644 --- a/examples/eeg_slds.py +++ b/examples/eeg_slds.py @@ -155,7 +155,7 @@ def get_tensors_and_dists(self): self.observation_matrix, obs_mvn, event_dims, "x", "y" ) - return trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist + return (trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist) # compute the marginal log probability of the observed data using a moment-matching approximation @funsor.interpretation(funsor.terms.moment_matching) diff --git a/examples/mixed_hmm/model.py b/examples/mixed_hmm/model.py index 8da984a1c..b79fcc00e 100644 --- a/examples/mixed_hmm/model.py +++ b/examples/mixed_hmm/model.py @@ -24,10 +24,7 @@ def __init__(self, config): def initialize_params(self): # dictionary of guide random effect parameters - params = { - "eps_g": {}, - "eps_i": {}, - } + params = {"eps_g": {}, "eps_i": {}} N_state = self.config["sizes"]["state"] @@ -153,8 +150,7 @@ def initialize_params(self): ) params["eps_g"]["scale"] = Tensor( - torch.ones((N_state, N_state)), - OrderedDict([("y_prev", Bint[N_state])]), + torch.ones((N_state, N_state)), OrderedDict([("y_prev", Bint[N_state])]) ) # initialize individual-level random effect parameters @@ -164,7 +160,7 @@ def initialize_params(self): params["e_i"]["probs"] = Tensor( pyro.param( "probs_e_i", - lambda: torch.randn((N_c, N_v,)).abs(), + lambda: torch.randn((N_c, N_v)).abs(), constraint=constraints.simplex, ), OrderedDict([("g", Bint[N_c])]), # different value per group @@ -324,7 +320,7 @@ def 
__call__(self): # initialize gamma to uniform gamma = Tensor( - torch.zeros((N_state, N_state)), OrderedDict([("y_prev", Bint[N_state])]), + torch.zeros((N_state, N_state)), OrderedDict([("y_prev", Bint[N_state])]) ) N_v = self.config["sizes"]["random"] diff --git a/examples/slds.py b/examples/slds.py index c88b3da95..b9f36eb73 100644 --- a/examples/slds.py +++ b/examples/slds.py @@ -25,7 +25,7 @@ def main(args): ) trans_noise = funsor.Tensor( torch.tensor( - [0.1, 1.0,], # low noise component # high noisy component + [0.1, 1.0], # low noise component # high noisy component requires_grad=True, ) ) diff --git a/funsor/affine.py b/funsor/affine.py index c9e181911..1d6c9a2d1 100644 --- a/funsor/affine.py +++ b/funsor/affine.py @@ -159,8 +159,4 @@ def extract_affine(fn): return const, coeffs -__all__ = [ - "affine_inputs", - "extract_affine", - "is_affine", -] +__all__ = ["affine_inputs", "extract_affine", "is_affine"] diff --git a/funsor/cnf.py b/funsor/cnf.py index 2200a6e74..0e570df50 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -594,11 +594,7 @@ def unary_contract(op, arg): ) -BACKEND_TO_EINSUM_BACKEND = { - "numpy": "numpy", - "torch": "torch", - "jax": "jax.numpy", -} +BACKEND_TO_EINSUM_BACKEND = {"numpy": "numpy", "torch": "torch", "jax": "jax.numpy"} # NB: numpy_log, numpy_map is backend-agnostic so they also work for torch backend; # however, we might need to profile to make a switch BACKEND_TO_LOGSUMEXP_BACKEND = { diff --git a/funsor/delta.py b/funsor/delta.py index f87359b8f..341280050 100644 --- a/funsor/delta.py +++ b/funsor/delta.py @@ -248,7 +248,4 @@ def eager_independent_delta(delta, reals_var, bint_var, diag_var): return None -__all__ = [ - "Delta", - "solve", -] +__all__ = ["Delta", "solve"] diff --git a/funsor/distribution.py b/funsor/distribution.py index e8b666509..f67ceb5ef 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -184,7 +184,7 @@ def eager_log_prob(cls, *params): params, value = params[:-1], params[-1] params = params + (Variable("value", value.output),) instance = reflect.interpret(cls, *params) - raw_dist, value_name, value_output, dim_to_name = instance._get_raw_dist() + (raw_dist, value_name, value_output, dim_to_name) = instance._get_raw_dist() assert value.output == value_output name_to_dim = {v: k for k, v in dim_to_name.items()} dim_to_name.update( @@ -379,7 +379,7 @@ def dist_init(self, **kwargs): dist_class = DistributionMeta( backend_dist_class.__name__.split("Wrapper_")[-1], (Distribution,), - {"dist_class": backend_dist_class, "__init__": dist_init,}, + {"dist_class": backend_dist_class, "__init__": dist_init}, ) if generate_eager: diff --git a/funsor/gaussian.py b/funsor/gaussian.py index caa278c9c..13edad9d1 100644 --- a/funsor/gaussian.py +++ b/funsor/gaussian.py @@ -779,9 +779,4 @@ def eager_neg(op, arg): return Gaussian(info_vec, precision, arg.inputs) -__all__ = [ - "BlockMatrix", - "BlockVector", - "Gaussian", - "align_gaussian", -] +__all__ = ["BlockMatrix", "BlockVector", "Gaussian", "align_gaussian"] diff --git a/funsor/instrument.py b/funsor/instrument.py index 71f797cf9..569c771a4 100644 --- a/funsor/instrument.py +++ b/funsor/instrument.py @@ -108,11 +108,4 @@ def print_counters(): print("-" * 80) -__all__ = [ - "DEBUG", - "PROFILE", - "STACK_SIZE", - "debug_logged", - "get_indent", - "profile", -] +__all__ = ["DEBUG", "PROFILE", "STACK_SIZE", "debug_logged", "get_indent", "profile"] diff --git a/funsor/integrate.py b/funsor/integrate.py index b75b6f50e..7212628af 100644 --- a/funsor/integrate.py +++ 
b/funsor/integrate.py @@ -230,6 +230,4 @@ def eager_integrate(log_measure, integrand, reduced_vars): return None # defer to default implementation -__all__ = [ - "Integrate", -] +__all__ = ["Integrate"] diff --git a/funsor/interpreter.py b/funsor/interpreter.py index 54dda7183..5059672b4 100644 --- a/funsor/interpreter.py +++ b/funsor/interpreter.py @@ -80,7 +80,7 @@ def interpret(cls, *args): def interpretation(new): warnings.warn( - "'with interpretation(x)' should be replaced by 'with x'", DeprecationWarning, + "'with interpretation(x)' should be replaced by 'with x'", DeprecationWarning ) return new diff --git a/funsor/jax/__init__.py b/funsor/jax/__init__.py index 10fa35780..dae7b7435 100644 --- a/funsor/jax/__init__.py +++ b/funsor/jax/__init__.py @@ -18,7 +18,7 @@ @adjoint_ops.register( - Tensor, AssociativeOp, AssociativeOp, Funsor, (DeviceArray, Tracer), tuple, object, + Tensor, AssociativeOp, AssociativeOp, Funsor, (DeviceArray, Tracer), tuple, object ) def adjoint_tensor(adj_redop, adj_binop, out_adj, data, inputs, dtype): return {} diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index 07e38c9d6..0570b92b6 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -257,10 +257,7 @@ def _triangular_solve(x, y, upper=False, transpose=False): x_new_shape = batch_shape[:prepend_ndim] for (sy, sx) in zip(y.shape[:-2], batch_shape[prepend_ndim:]): x_new_shape += (sx // sy, sy) - x_new_shape += ( - n, - m, - ) + x_new_shape += (n, m) x = np.reshape(x, x_new_shape) # Permute y to make it have shape (..., 1, j, m, i, 1, n) batch_ndim = x.ndim - 2 diff --git a/funsor/joint.py b/funsor/joint.py index cb5693f98..8658248db 100644 --- a/funsor/joint.py +++ b/funsor/joint.py @@ -104,7 +104,7 @@ def moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete, gauss discrete += gaussian.log_normalizer new_discrete = discrete.reduce(ops.logaddexp, approx_vars & discrete.input_vars) num_elements = reduce( - ops.mul, [v.output.num_elements for v in approx_vars - discrete.input_vars], 1, + ops.mul, [v.output.num_elements for v in approx_vars - discrete.input_vars], 1 ) if num_elements != 1: new_discrete -= math.log(num_elements) diff --git a/funsor/memoize.py b/funsor/memoize.py index 90fea683f..baf45a471 100644 --- a/funsor/memoize.py +++ b/funsor/memoize.py @@ -40,6 +40,4 @@ def interpret(self, cls, *args): return value -__all__ = [ - "memoize", -] +__all__ = ["memoize"] diff --git a/funsor/montecarlo.py b/funsor/montecarlo.py index b533c0f1a..3ef205b4f 100644 --- a/funsor/montecarlo.py +++ b/funsor/montecarlo.py @@ -40,6 +40,4 @@ def monte_carlo_integrate(state, log_measure, integrand, reduced_vars): return Integrate(sample, integrand, reduced_vars) -__all__ = [ - "MonteCarlo", -] +__all__ = ["MonteCarlo"] diff --git a/funsor/registry.py b/funsor/registry.py index 07f9f5542..353693861 100644 --- a/funsor/registry.py +++ b/funsor/registry.py @@ -84,6 +84,4 @@ def dispatch(self, key, *args): return self[key].partial_call(*args) -__all__ = [ - "KeyedRegistry", -] +__all__ = ["KeyedRegistry"] diff --git a/funsor/syntax.py b/funsor/syntax.py index 05eb123be..7b4291548 100644 --- a/funsor/syntax.py +++ b/funsor/syntax.py @@ -59,9 +59,7 @@ def visit_UnaryOp(self, node): var = self.prefix.get(type(node.op)) if var is not None: node = ast.Call( - func=ast.Name(id=var, ctx=ast.Load(),), - args=[node.operand], - keywords=[], + func=ast.Name(id=var, ctx=ast.Load()), args=[node.operand], keywords=[] ) return node @@ -70,7 +68,7 @@ def visit_BinOp(self, node): var = self.infix.get(type(node.op)) if 
var is not None: node = ast.Call( - func=ast.Name(id=var, ctx=ast.Load(),), + func=ast.Name(id=var, ctx=ast.Load()), args=[node.left, node.right], keywords=[], ) @@ -92,7 +90,7 @@ def visit_Compare(self, node): var = self.infix.get(type(node_op)) if var is not None: node = ast.Call( - func=ast.Name(id=var, ctx=ast.Load(),), + func=ast.Name(id=var, ctx=ast.Load()), args=[node.left, node_right], keywords=[], ) @@ -163,8 +161,4 @@ def decorator(fn): return decorator -__all__ = [ - "INFIX_OPERATORS", - "PREFIX_OPERATORS", - "rewrite_ops", -] +__all__ = ["INFIX_OPERATORS", "PREFIX_OPERATORS", "rewrite_ops"] diff --git a/funsor/terms.py b/funsor/terms.py index 3b13ced73..efc1cc33c 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1520,7 +1520,7 @@ def eager_subs(self, subs): n -= size assert False elif isinstance(value, Slice): - start, stop, step = value.slice.start, value.slice.stop, value.slice.step + start, stop, step = (value.slice.start, value.slice.stop, value.slice.step) new_parts = [] pos = 0 for part in self.parts: diff --git a/funsor/testing.py b/funsor/testing.py index 1979734b9..aefcb52d8 100644 --- a/funsor/testing.py +++ b/funsor/testing.py @@ -116,7 +116,7 @@ def assert_close(actual, expected, atol=1e-6, rtol=1e-6): actual = actual.align(tuple(n for n, p in expected.terms)) for ( (actual_name, (actual_point, actual_log_density)), - (expected_name, (expected_point, expected_log_density),), + (expected_name, (expected_point, expected_log_density)), ) in zip(actual.terms, expected.terms): assert actual_name == expected_name assert_close(actual_point, expected_point, atol=atol, rtol=rtol) diff --git a/scripts/update_headers.py b/scripts/update_headers.py index 8faa7dd8b..37a78522d 100644 --- a/scripts/update_headers.py +++ b/scripts/update_headers.py @@ -8,10 +8,7 @@ root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) blacklist = ["/build/", "/dist/"] -file_types = [ - ("*.py", "# {}"), - ("*.cpp", "// {}"), -] +file_types = [("*.py", "# {}"), ("*.cpp", "// {}")] parser = argparse.ArgumentParser() parser.add_argument("--check", action="store_true") diff --git a/setup.py b/setup.py index 042d63859..d0b6de0e0 100644 --- a/setup.py +++ b/setup.py @@ -27,18 +27,13 @@ description="A tensor-like library for functions and distributions", packages=find_packages(include=["funsor", "funsor.*"]), url="https://github.com/pyro-ppl/funsor", - project_urls={"Documentation": "https://funsor.pyro.ai",}, + project_urls={"Documentation": "https://funsor.pyro.ai"}, author="Uber AI Labs", python_requires=">=3.6", - install_requires=[ - "makefun", - "multipledispatch", - "numpy>=1.7", - "opt_einsum>=2.3.2", - ], + install_requires=["makefun", "multipledispatch", "numpy>=1.7", "opt_einsum>=2.3.2"], extras_require={ - "torch": ["pyro-ppl>=1.5.2", "torch>=1.7.0",], - "jax": ["numpyro>=0.2.4", "jax>=0.1.57", "jaxlib>=0.1.37",], + "torch": ["pyro-ppl>=1.5.2", "torch>=1.7.0"], + "jax": ["numpyro>=0.2.4", "jax>=0.1.57", "jaxlib>=0.1.37"], "test": [ "black", "flake8", diff --git a/test/examples/test_bart.py b/test/examples/test_bart.py index c30f1a9f8..c6f1762f1 100644 --- a/test/examples/test_bart.py +++ b/test/examples/test_bart.py @@ -52,7 +52,7 @@ def unpack_gate_rate(gate_rate): @pytest.mark.parametrize( "analytic_kl", - [False, xfail_param(True, reason="missing pattern"),], + [False, xfail_param(True, reason="missing pattern")], ids=["monte-carlo-kl", "analytic-kl"], ) def test_bart(analytic_kl): @@ -93,7 +93,7 @@ def test_bart(analytic_kl): ], dtype=torch.float32, ), # noqa - 
(("time_b4", Bint[2],), ("_event_1_b2", Bint[8],),), + (("time_b4", Bint[2]), ("_event_1_b2", Bint[8])), "real", ), Gaussian( @@ -148,9 +148,9 @@ def test_bart(analytic_kl): dtype=torch.float32, ), # noqa ( - ("time_b4", Bint[2],), - ("_event_1_b2", Bint[8],), - ("value_b1", Real,), + ("time_b4", Bint[2]), + ("_event_1_b2", Bint[8]), + ("value_b1", Real), ), ), ), @@ -220,8 +220,8 @@ def test_bart(analytic_kl): dtype=torch.float32, ), # noqa ( - ("state_b7", Reals[2],), - ("state(time=1)_b8", Reals[2],), + ("state_b7", Reals[2]), + ("state(time=1)_b8", Reals[2]), ), ), Subs( @@ -281,7 +281,7 @@ def test_bart(analytic_kl): ], dtype=torch.float32, ), # noqa - (("time_b9", Bint[2],),), + (("time_b9", Bint[2]),), "real", ), Tensor( @@ -310,7 +310,7 @@ def test_bart(analytic_kl): ], dtype=torch.float32, ), # noqa - (("time_b9", Bint[2],),), + (("time_b9", Bint[2]),), "real", ), Variable("state(time=1)_b8", Reals[2]), @@ -352,7 +352,7 @@ def test_bart(analytic_kl): ), Variable("value_b5", Reals[2]), ), - (("value_b5", Variable("state_b10", Reals[2]),),), + (("value_b5", Variable("state_b10", Reals[2])),), ), ), ) @@ -449,9 +449,9 @@ def test_bart(analytic_kl): dtype=torch.float32, ), # noqa ( - ("time_b17", Bint[2],), - ("origin_b15", Bint[2],), - ("destin_b16", Bint[2],), + ("time_b17", Bint[2]), + ("origin_b15", Bint[2]), + ("destin_b16", Bint[2]), ), "real", ), @@ -476,9 +476,9 @@ def test_bart(analytic_kl): dtype=torch.float32, ), # noqa ( - ("time_b17", Bint[2],), - ("origin_b15", Bint[2],), - ("destin_b16", Bint[2],), + ("time_b17", Bint[2]), + ("origin_b15", Bint[2]), + ("destin_b16", Bint[2]), ), "real", ), diff --git a/test/examples/test_sensor_fusion.py b/test/examples/test_sensor_fusion.py index 48dc52e31..f52e483a0 100644 --- a/test/examples/test_sensor_fusion.py +++ b/test/examples/test_sensor_fusion.py @@ -142,7 +142,7 @@ def test_affine_subs(): ], dtype=torch.float32, ), # noqa - (("state_1_b6", Reals[3],), ("obs_b2", Reals[2],),), + (("state_1_b6", Reals[3]), ("obs_b2", Reals[2])), ), ( ( diff --git a/test/pyro/test_hmm.py b/test/pyro/test_hmm.py index 28614a305..9db9dcad7 100644 --- a/test/pyro/test_hmm.py +++ b/test/pyro/test_hmm.py @@ -245,27 +245,19 @@ def test_gaussian_mrf_log_prob(init_shape, trans_shape, obs_shape, hidden_dim, o ] ) SLHMM_SHAPES = [ - ((2,), (), (1, 2,), (1, 3, 3), (1,), (1, 3, 4), (1,),), - ((2,), (), (5, 1, 2,), (1, 3, 3), (1,), (1, 3, 4), (1,),), - ((2,), (), (1, 2,), (5, 1, 3, 3), (1,), (1, 3, 4), (1,),), - ((2,), (), (1, 2,), (1, 3, 3), (5, 1), (1, 3, 4), (1,),), - ((2,), (), (1, 2,), (1, 3, 3), (1,), (5, 1, 3, 4), (1,),), - ((2,), (), (1, 2,), (1, 3, 3), (1,), (1, 3, 4), (5, 1),), - ((2,), (), (5, 1, 2,), (5, 1, 3, 3), (5, 1), (5, 1, 3, 4), (5, 1),), - ((2,), (2,), (5, 2, 2,), (5, 2, 3, 3), (5, 2), (5, 2, 3, 4), (5, 2),), + ((2,), (), (1, 2), (1, 3, 3), (1,), (1, 3, 4), (1,)), + ((2,), (), (5, 1, 2), (1, 3, 3), (1,), (1, 3, 4), (1,)), + ((2,), (), (1, 2), (5, 1, 3, 3), (1,), (1, 3, 4), (1,)), + ((2,), (), (1, 2), (1, 3, 3), (5, 1), (1, 3, 4), (1,)), + ((2,), (), (1, 2), (1, 3, 3), (1,), (5, 1, 3, 4), (1,)), + ((2,), (), (1, 2), (1, 3, 3), (1,), (1, 3, 4), (5, 1)), + ((2,), (), (5, 1, 2), (5, 1, 3, 3), (5, 1), (5, 1, 3, 4), (5, 1)), + ((2,), (2,), (5, 2, 2), (5, 2, 3, 3), (5, 2), (5, 2, 3, 4), (5, 2)), + ((7, 2), (), (7, 5, 1, 2), (7, 5, 1, 3, 3), (7, 5, 1), (7, 5, 1, 3, 4), (7, 5, 1)), ( - (7, 2,), - (), - (7, 5, 1, 2,), - (7, 5, 1, 3, 3), - (7, 5, 1), - (7, 5, 1, 3, 4), - (7, 5, 1), - ), - ( - (7, 2,), (7, 2), - (7, 5, 2, 2,), + (7, 2), + (7, 5, 2, 
2), (7, 5, 2, 3, 3), (7, 5, 2), (7, 5, 2, 3, 4), @@ -411,7 +403,7 @@ def test_switching_linear_hmm_log_prob_alternating(exact, num_steps, num_compone -1, num_components, -1, -1 ) - trans_mvn = random_mvn((num_steps, num_components,), hidden_dim,) + trans_mvn = random_mvn((num_steps, num_components), hidden_dim) hmm_obs_matrix = torch.randn(num_steps, hidden_dim, obs_dim) switching_obs_matrix = hmm_obs_matrix.unsqueeze(-3).expand( -1, num_components, -1, -1 diff --git a/test/test_adjoint.py b/test/test_adjoint.py index 0e2d6407b..4756479eb 100644 --- a/test/test_adjoint.py +++ b/test/test_adjoint.py @@ -201,7 +201,7 @@ def test_optimized_plated_einsum_adjoint(equation, plates, backend): ids=lambda d: ",".join(d.keys()), ) @pytest.mark.parametrize( - "impl", [sequential_sum_product, naive_sequential_sum_product, MarkovProduct,], + "impl", [sequential_sum_product, naive_sequential_sum_product, MarkovProduct] ) def test_sequential_sum_product_adjoint( impl, sum_op, prod_op, batch_inputs, state_domain, num_steps diff --git a/test/test_distribution.py b/test/test_distribution.py index 33d4c1655..c73da8567 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -1459,7 +1459,7 @@ def test_power_transform(shape): @pytest.mark.parametrize("shape", [(10,), (4, 3)], ids=str) @pytest.mark.parametrize( "to_event", - [True, xfail_param(False, reason="bug in to_funsor(TransformedDistribution)"),], + [True, xfail_param(False, reason="bug in to_funsor(TransformedDistribution)")], ) def test_haar_transform(shape, to_event): try: diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index 66c5ba466..db1c4c263 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -186,7 +186,7 @@ def __hash__(self): # Chi2 DistTestCase( - "dist.Chi2(df=case.df)", (("df", f"rand({batch_shape})"),), funsor.Real, + "dist.Chi2(df=case.df)", (("df", f"rand({batch_shape})"),), funsor.Real ) # ContinuousBernoulli @@ -368,9 +368,7 @@ def __hash__(self): # Poisson DistTestCase( - "dist.Poisson(rate=case.rate)", - (("rate", f"rand({batch_shape})"),), - funsor.Real, + "dist.Poisson(rate=case.rate)", (("rate", f"rand({batch_shape})"),), funsor.Real ) # RelaxedBernoulli diff --git a/test/test_domains.py b/test/test_domains.py index fb1461412..29bfc5cdd 100644 --- a/test/test_domains.py +++ b/test/test_domains.py @@ -9,9 +9,7 @@ from funsor.domains import Bint, Real, Reals # noqa F401 -@pytest.mark.parametrize( - "expr", ["Bint[2]", "Real", "Reals[4]", "Reals[3, 2]",], -) +@pytest.mark.parametrize("expr", ["Bint[2]", "Real", "Reals[4]", "Reals[3, 2]"]) def test_pickle(expr): x = eval(expr) f = io.BytesIO() diff --git a/test/test_factory.py b/test/test_factory.py index b28df08d9..e23f50e20 100644 --- a/test/test_factory.py +++ b/test/test_factory.py @@ -19,7 +19,7 @@ def test_lambda_lambda(): @make_funsor def LambdaLambda( - i: Bound, j: Bound, x: Funsor, + i: Bound, j: Bound, x: Funsor ) -> Fresh[lambda i, j, x: Array[x.dtype, (i.size, j.size) + x.shape]]: assert i in x.inputs assert j in x.inputs @@ -49,7 +49,7 @@ def GetitemGetitem( def test_flatten(): @make_funsor def Flatten21( - x: Funsor, i: Bound, j: Bound, ij: Fresh[lambda i, j: Bint[i.size * j.size]], + x: Funsor, i: Bound, j: Bound, ij: Fresh[lambda i, j: Bint[i.size * j.size]] ) -> Fresh[lambda x: x.dtype]: m = to_funsor(i, x.inputs.get(i, None)).output.size n = to_funsor(j, x.inputs.get(j, None)).output.size @@ -115,7 +115,7 @@ def Cat2( def test_normal(): @make_funsor def Normal( - loc: Funsor, 
scale: Funsor, value: Fresh[lambda loc: loc], + loc: Funsor, scale: Funsor, value: Fresh[lambda loc: loc] ) -> Fresh[Real]: return None @@ -140,7 +140,7 @@ def _(loc, scale, value): def test_matmul(): @make_funsor - def MatMul(x: Funsor, y: Funsor, i: Bound,) -> Fresh[lambda x: x]: + def MatMul(x: Funsor, y: Funsor, i: Bound) -> Fresh[lambda x: x]: return (x * y).reduce(ops.add, i) x = random_tensor(OrderedDict(a=Bint[3], b=Bint[4])) @@ -171,7 +171,7 @@ def Scatter1( def test_value_dependence(): @make_funsor def Sum( - x: Funsor, dim: Value[int], + x: Funsor, dim: Value[int] ) -> Fresh[lambda x, dim: Array[x.dtype, x.shape[:dim] + x.shape[dim + 1 :]]]: return None diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 369e908c3..f3c5ac636 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -576,12 +576,10 @@ def test_reduce_logsumexp(int_inputs, real_inputs): ) -@pytest.mark.parametrize( - "int_inputs", [{}, {"i": Bint[2]},], ids=id_from_inputs, -) +@pytest.mark.parametrize("int_inputs", [{}, {"i": Bint[2]}], ids=id_from_inputs) @pytest.mark.parametrize( "real_inputs", - [{"x": Real}, {"x": Reals[4]}, {"x": Reals[2, 3]},], + [{"x": Real}, {"x": Reals[4]}, {"x": Reals[2, 3]}], ids=id_from_inputs, ) def test_integrate_variable(int_inputs, real_inputs): diff --git a/test/test_memoize.py b/test/test_memoize.py index 14b11b3aa..e54b18cb2 100644 --- a/test/test_memoize.py +++ b/test/test_memoize.py @@ -169,10 +169,10 @@ def test_nested_einsum_complete_sharing( eqn1, eqn2, einsum_impl1, einsum_impl2, backend1, backend2 ): - inputs1, outputs1, sizes1, operands1, funsor_operands1 = make_einsum_example( + (inputs1, outputs1, sizes1, operands1, funsor_operands1) = make_einsum_example( eqn1, sizes=(3,) ) - inputs2, outputs2, sizes2, operands2, funsor_operands2 = make_einsum_example( + (inputs2, outputs2, sizes2, operands2, funsor_operands2) = make_einsum_example( eqn2, sizes=(3,) ) diff --git a/test/test_minipyro.py b/test/test_minipyro.py index db654100d..b9c1eb937 100644 --- a/test/test_minipyro.py +++ b/test/test_minipyro.py @@ -302,7 +302,7 @@ def guide(): @pytest.mark.parametrize( - "backend", ["pyro", xfail_param("funsor", reason="missing patterns"),], + "backend", ["pyro", xfail_param("funsor", reason="missing patterns")] ) def test_mean_field_ok(backend): def model(): @@ -320,7 +320,7 @@ def guide(): @pytest.mark.parametrize( - "backend", ["pyro", xfail_param("funsor", reason="missing patterns"),], + "backend", ["pyro", xfail_param("funsor", reason="missing patterns")] ) def test_mean_field_warn(backend): def model(): diff --git a/test/test_optimizer.py b/test/test_optimizer.py index c8f0d4550..7c5399622 100644 --- a/test/test_optimizer.py +++ b/test/test_optimizer.py @@ -45,11 +45,9 @@ @pytest.mark.parametrize("equation", OPTIMIZED_EINSUM_EXAMPLES) @pytest.mark.parametrize( - "backend", ["pyro.ops.einsum.torch_log", "pyro.ops.einsum.torch_map",], -) -@pytest.mark.parametrize( - "einsum_impl", [naive_einsum, naive_contract_einsum,], + "backend", ["pyro.ops.einsum.torch_log", "pyro.ops.einsum.torch_map"] ) +@pytest.mark.parametrize("einsum_impl", [naive_einsum, naive_contract_einsum]) def test_optimized_einsum(equation, backend, einsum_impl): inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation) expected = pyro_einsum(equation, *operands, backend=backend)[0] @@ -71,7 +69,7 @@ def test_optimized_einsum(equation, backend, einsum_impl): @pytest.mark.parametrize( - "eqn1,eqn2", [("a,ab->b", "bc->"), ("ab,bc,cd->d", "de,ef,fg->"),], + "eqn1,eqn2", 
[("a,ab->b", "bc->"), ("ab,bc,cd->d", "de,ef,fg->")] ) @pytest.mark.parametrize("optimize1", [False, True]) @pytest.mark.parametrize("optimize2", [False, True]) @@ -86,7 +84,7 @@ def test_nested_einsum( eqn1, eqn2, optimize1, optimize2, backend1, backend2, einsum_impl ): inputs1, outputs1, sizes1, operands1, _ = make_einsum_example(eqn1, sizes=(3,)) - inputs2, outputs2, sizes2, operands2, funsor_operands2 = make_einsum_example( + (inputs2, outputs2, sizes2, operands2, funsor_operands2) = make_einsum_example( eqn2, sizes=(3,) ) @@ -139,7 +137,7 @@ def test_nested_einsum( @pytest.mark.parametrize("equation,plates", PLATED_EINSUM_EXAMPLES) @pytest.mark.parametrize( - "backend", ["pyro.ops.einsum.torch_log", "pyro.ops.einsum.torch_map",], + "backend", ["pyro.ops.einsum.torch_log", "pyro.ops.einsum.torch_map"] ) def test_optimized_plated_einsum(equation, plates, backend): inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation) diff --git a/test/test_samplers.py b/test/test_samplers.py index 68e7fd075..d30aca38e 100644 --- a/test/test_samplers.py +++ b/test/test_samplers.py @@ -37,17 +37,17 @@ @pytest.mark.parametrize( "sample_inputs", - [(), (("s", Bint[6]),), (("s", Bint[6]), ("t", Bint[7])),], + [(), (("s", Bint[6]),), (("s", Bint[6]), ("t", Bint[7]))], ids=id_from_inputs, ) @pytest.mark.parametrize( "batch_inputs", - [(), (("b", Bint[4]),), (("b", Bint[4]), ("c", Bint[5])),], + [(), (("b", Bint[4]),), (("b", Bint[4]), ("c", Bint[5]))], ids=id_from_inputs, ) @pytest.mark.parametrize( "event_inputs", - [(("e", Bint[2]),), (("e", Bint[2]), ("f", Bint[3])),], + [(("e", Bint[2]),), (("e", Bint[2]), ("f", Bint[3]))], ids=id_from_inputs, ) def test_tensor_shape(sample_inputs, batch_inputs, event_inputs): @@ -81,18 +81,16 @@ def test_tensor_shape(sample_inputs, batch_inputs, event_inputs): @pytest.mark.parametrize( "sample_inputs", - [(), (("s", Bint[3]),), (("s", Bint[3]), ("t", Bint[4])),], + [(), (("s", Bint[3]),), (("s", Bint[3]), ("t", Bint[4]))], ids=id_from_inputs, ) @pytest.mark.parametrize( "batch_inputs", - [(), (("b", Bint[2]),), (("c", Real),), (("b", Bint[2]), ("c", Real)),], + [(), (("b", Bint[2]),), (("c", Real),), (("b", Bint[2]), ("c", Real))], ids=id_from_inputs, ) @pytest.mark.parametrize( - "event_inputs", - [(("e", Real),), (("e", Real), ("f", Reals[2])),], - ids=id_from_inputs, + "event_inputs", [(("e", Real),), (("e", Real), ("f", Reals[2]))], ids=id_from_inputs ) def test_gaussian_shape(sample_inputs, batch_inputs, event_inputs): be_inputs = OrderedDict(batch_inputs + event_inputs) @@ -132,18 +130,16 @@ def test_gaussian_shape(sample_inputs, batch_inputs, event_inputs): @pytest.mark.parametrize( "sample_inputs", - [(), (("s", Bint[3]),), (("s", Bint[3]), ("t", Bint[4])),], + [(), (("s", Bint[3]),), (("s", Bint[3]), ("t", Bint[4]))], ids=id_from_inputs, ) @pytest.mark.parametrize( "batch_inputs", - [(), (("b", Bint[2]),), (("c", Real),), (("b", Bint[2]), ("c", Real)),], + [(), (("b", Bint[2]),), (("c", Real),), (("b", Bint[2]), ("c", Real))], ids=id_from_inputs, ) @pytest.mark.parametrize( - "event_inputs", - [(("e", Real),), (("e", Real), ("f", Reals[2])),], - ids=id_from_inputs, + "event_inputs", [(("e", Real),), (("e", Real), ("f", Reals[2]))], ids=id_from_inputs ) def test_transformed_gaussian_shape(sample_inputs, batch_inputs, event_inputs): be_inputs = OrderedDict(batch_inputs + event_inputs) @@ -191,17 +187,17 @@ def test_transformed_gaussian_shape(sample_inputs, batch_inputs, event_inputs): @pytest.mark.parametrize( "sample_inputs", - [(), 
(("s", Bint[6]),), (("s", Bint[6]), ("t", Bint[7])),], + [(), (("s", Bint[6]),), (("s", Bint[6]), ("t", Bint[7]))], ids=id_from_inputs, ) @pytest.mark.parametrize( "int_event_inputs", - [(), (("d", Bint[2]),), (("d", Bint[2]), ("e", Bint[3])),], + [(), (("d", Bint[2]),), (("d", Bint[2]), ("e", Bint[3]))], ids=id_from_inputs, ) @pytest.mark.parametrize( "real_event_inputs", - [(("g", Real),), (("g", Real), ("h", Reals[4])),], + [(("g", Real),), (("g", Real), ("h", Reals[4]))], ids=id_from_inputs, ) def test_joint_shape(sample_inputs, int_event_inputs, real_event_inputs): @@ -243,12 +239,12 @@ def test_joint_shape(sample_inputs, int_event_inputs, real_event_inputs): @pytest.mark.parametrize( "batch_inputs", - [(), (("b", Bint[4]),), (("b", Bint[2]), ("c", Bint[2])),], + [(), (("b", Bint[4]),), (("b", Bint[2]), ("c", Bint[2]))], ids=id_from_inputs, ) @pytest.mark.parametrize( "event_inputs", - [(("e", Bint[3]),), (("e", Bint[2]), ("f", Bint[2])),], + [(("e", Bint[3]),), (("e", Bint[2]), ("f", Bint[2]))], ids=id_from_inputs, ) @pytest.mark.parametrize("test_grad", [False, True], ids=["value", "grad"]) @@ -271,7 +267,7 @@ def diff_fn(p_data): _, (p_data, mq_data) = align_tensors(p, mq) assert p_data.shape == mq_data.shape - return (ops.exp(mq_data) * probe).sum() - (ops.exp(p_data) * probe).sum(), mq + return ((ops.exp(mq_data) * probe).sum() - (ops.exp(p_data) * probe).sum(), mq) if test_grad: if get_backend() == "jax": @@ -294,13 +290,11 @@ def diff_fn(p_data): @pytest.mark.parametrize( "batch_inputs", - [(), (("b", Bint[3]),), (("b", Bint[3]), ("c", Bint[4])),], + [(), (("b", Bint[3]),), (("b", Bint[3]), ("c", Bint[4]))], ids=id_from_inputs, ) @pytest.mark.parametrize( - "event_inputs", - [(("e", Real),), (("e", Real), ("f", Reals[2])),], - ids=id_from_inputs, + "event_inputs", [(("e", Real),), (("e", Real), ("f", Reals[2]))], ids=id_from_inputs ) def test_gaussian_distribution(event_inputs, batch_inputs): num_samples = 100000 @@ -336,12 +330,12 @@ def test_gaussian_distribution(event_inputs, batch_inputs): @pytest.mark.parametrize( "batch_inputs", - [(), (("b", Bint[3]),), (("b", Bint[3]), ("c", Bint[2])),], + [(), (("b", Bint[3]),), (("b", Bint[3]), ("c", Bint[2]))], ids=id_from_inputs, ) @pytest.mark.parametrize( "event_inputs", - [(("e", Real), ("f", Bint[3])), (("e", Reals[2]), ("f", Bint[2])),], + [(("e", Real), ("f", Bint[3])), (("e", Reals[2]), ("f", Bint[2]))], ids=id_from_inputs, ) def test_gaussian_mixture_distribution(batch_inputs, event_inputs): diff --git a/test/test_sum_product.py b/test/test_sum_product.py index 9fc82aae4..ff2fb5ee8 100644 --- a/test/test_sum_product.py +++ b/test/test_sum_product.py @@ -100,9 +100,7 @@ def test_partition(inputs, dims, expected_num_components): ("abcij", ""), ], ) -@pytest.mark.parametrize( - "impl", [partial_sum_product, modified_partial_sum_product,], -) +@pytest.mark.parametrize("impl", [partial_sum_product, modified_partial_sum_product]) def test_partial_sum_product(impl, sum_op, prod_op, inputs, plates, vars1, vars2): inputs = inputs.split(",") factors = [random_tensor(OrderedDict((d, Bint[2]) for d in ds)) for ds in inputs] @@ -140,9 +138,7 @@ def test_partial_sum_product(impl, sum_op, prod_op, inputs, plates, vars1, vars2 (frozenset({"time", "x_0", "x_prev", "x_curr"}), frozenset()), ], ) -@pytest.mark.parametrize( - "x_dim,time", [(3, 1), (1, 5), (3, 5),], -) +@pytest.mark.parametrize("x_dim,time", [(3, 1), (1, 5), (3, 5)]) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] ) @@ -150,10 +146,10 
@@ def test_modified_partial_sum_product_0(sum_op, prod_op, vars1, vars2, x_dim, ti f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim],})) + f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) f3 = random_tensor( - OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim],}) + OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim]}) ) factors = [f1, f2, f3] @@ -186,7 +182,7 @@ def test_modified_partial_sum_product_0(sum_op, prod_op, vars1, vars2, x_dim, ti ], ) @pytest.mark.parametrize( - "x_dim,y_dim,time", [(2, 3, 5), (1, 3, 5), (2, 1, 5), (2, 3, 1),], + "x_dim,y_dim,time", [(2, 3, 5), (1, 3, 5), (2, 1, 5), (2, 3, 1)] ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -197,16 +193,16 @@ def test_modified_partial_sum_product_1( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim],})) + f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) f3 = random_tensor( - OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim],}) + OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim]}) ) - f4 = random_tensor(OrderedDict({"x_0": Bint[x_dim], "y_0": Bint[y_dim],})) + f4 = random_tensor(OrderedDict({"x_0": Bint[x_dim], "y_0": Bint[y_dim]})) f5 = random_tensor( - OrderedDict({"time": Bint[time], "x_curr": Bint[x_dim], "y_curr": Bint[y_dim],}) + OrderedDict({"time": Bint[time], "x_curr": Bint[x_dim], "y_curr": Bint[y_dim]}) ) factors = [f1, f2, f3, f4, f5] @@ -244,7 +240,7 @@ def test_modified_partial_sum_product_1( ], ) @pytest.mark.parametrize( - "x_dim,y_dim,time", [(2, 3, 5), (1, 3, 5), (2, 1, 5), (2, 3, 1),], + "x_dim,y_dim,time", [(2, 3, 5), (1, 3, 5), (2, 1, 5), (2, 3, 1)] ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -255,16 +251,16 @@ def test_modified_partial_sum_product_2( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim],})) + f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) f3 = random_tensor( - OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim],}) + OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim]}) ) - f4 = random_tensor(OrderedDict({"y_0": Bint[y_dim],})) + f4 = random_tensor(OrderedDict({"y_0": Bint[y_dim]})) f5 = random_tensor( - OrderedDict({"time": Bint[time], "y_prev": Bint[y_dim], "y_curr": Bint[y_dim],}) + OrderedDict({"time": Bint[time], "y_prev": Bint[y_dim], "y_curr": Bint[y_dim]}) ) factors = [f1, f2, f3, f4, f5] @@ -304,7 +300,7 @@ def test_modified_partial_sum_product_2( ], ) @pytest.mark.parametrize( - "x_dim,y_dim,time", [(2, 3, 5), (1, 3, 5), (2, 1, 5), (2, 3, 1),], + "x_dim,y_dim,time", [(2, 3, 5), (1, 3, 5), (2, 1, 5), (2, 3, 1)] ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -315,13 +311,13 @@ def test_modified_partial_sum_product_3( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim],})) + f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) f3 = random_tensor( - OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim],}) + OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim]}) ) - f4 = random_tensor(OrderedDict({"x_0": Bint[x_dim], "y_0": Bint[y_dim],})) + f4 = random_tensor(OrderedDict({"x_0": Bint[x_dim], "y_0": Bint[y_dim]})) f5 = random_tensor( OrderedDict( @@ -402,7 +398,7 @@ def 
test_modified_partial_sum_product_3( ) @pytest.mark.parametrize( "x_dim,y_dim,sequences,time,tones", - [(2, 3, 2, 5, 4), (1, 3, 2, 5, 4), (2, 1, 2, 5, 4), (2, 3, 2, 1, 4),], + [(2, 3, 2, 5, 4), (1, 3, 2, 5, 4), (2, 1, 2, 5, 4), (2, 3, 2, 1, 4)], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -413,7 +409,7 @@ def test_modified_partial_sum_product_4( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim],})) + f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim]})) f3 = random_tensor( OrderedDict( @@ -428,7 +424,7 @@ def test_modified_partial_sum_product_4( f4 = random_tensor( OrderedDict( - {"sequences": Bint[sequences], "tones": Bint[tones], "y_0": Bint[y_dim],} + {"sequences": Bint[sequences], "tones": Bint[tones], "y_0": Bint[y_dim]} ) ) @@ -534,7 +530,7 @@ def test_modified_partial_sum_product_4( ) @pytest.mark.parametrize( "x_dim,y_dim,sequences,days,weeks,tones", - [(2, 3, 2, 5, 4, 3), (1, 3, 2, 5, 4, 3), (2, 1, 2, 5, 4, 3), (2, 3, 2, 1, 4, 3),], + [(2, 3, 2, 5, 4, 3), (1, 3, 2, 5, 4, 3), (2, 1, 2, 5, 4, 3), (2, 3, 2, 1, 4, 3)], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -547,7 +543,7 @@ def test_modified_partial_sum_product_5( f2 = random_tensor( OrderedDict( - {"sequences": Bint[sequences], "tones": Bint[tones], "x_0": Bint[x_dim],} + {"sequences": Bint[sequences], "tones": Bint[tones], "x_0": Bint[x_dim]} ) ) @@ -563,7 +559,7 @@ def test_modified_partial_sum_product_5( ) ) - f4 = random_tensor(OrderedDict({"sequences": Bint[sequences], "y_0": Bint[y_dim],})) + f4 = random_tensor(OrderedDict({"sequences": Bint[sequences], "y_0": Bint[y_dim]})) f5 = random_tensor( OrderedDict( @@ -647,7 +643,7 @@ def test_modified_partial_sum_product_5( ) @pytest.mark.parametrize( "x_dim,y_dim,sequences,time,tones", - [(2, 3, 2, 5, 4), (1, 3, 2, 5, 4), (2, 1, 2, 5, 4), (2, 3, 2, 1, 4),], + [(2, 3, 2, 5, 4), (1, 3, 2, 5, 4), (2, 1, 2, 5, 4), (2, 3, 2, 1, 4)], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -658,7 +654,7 @@ def test_modified_partial_sum_product_6( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim],})) + f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim]})) f3 = random_tensor( OrderedDict( @@ -764,7 +760,7 @@ def test_modified_partial_sum_product_6( ) @pytest.mark.parametrize( "x_dim,y_dim,sequences,time,tones", - [(2, 3, 2, 5, 4), (1, 3, 2, 5, 4), (2, 1, 2, 5, 4), (2, 3, 2, 1, 4),], + [(2, 3, 2, 5, 4), (1, 3, 2, 5, 4), (2, 1, 2, 5, 4), (2, 3, 2, 1, 4)], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -775,7 +771,7 @@ def test_modified_partial_sum_product_7( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim],})) + f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim]})) f3 = random_tensor( OrderedDict( @@ -815,7 +811,7 @@ def test_modified_partial_sum_product_7( factors = [f1, f2, f3, f4, f5] plate_to_step = { "sequences": {}, - "time": frozenset({("x_0", "x_prev", "x_curr"), ("y_0", "y_prev", "y_curr"),}), + "time": frozenset({("x_0", "x_prev", "x_curr"), ("y_0", "y_prev", "y_curr")}), "tones": {}, } @@ -904,7 +900,7 @@ def test_modified_partial_sum_product_7( ) @pytest.mark.parametrize( 
"w_dim,x_dim,y_dim,sequences,time,tones", - [(3, 2, 3, 2, 5, 4), (3, 1, 3, 2, 5, 4), (3, 2, 1, 2, 5, 4), (3, 2, 3, 2, 1, 4),], + [(3, 2, 3, 2, 5, 4), (3, 1, 3, 2, 5, 4), (3, 2, 1, 2, 5, 4), (3, 2, 3, 2, 1, 4)], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -915,7 +911,7 @@ def test_modified_partial_sum_product_8( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim],})) + f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim]})) f3 = random_tensor( OrderedDict( @@ -928,7 +924,7 @@ def test_modified_partial_sum_product_8( ) ) - f4 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim],})) + f4 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim]})) f5 = random_tensor( OrderedDict( @@ -969,7 +965,7 @@ def test_modified_partial_sum_product_8( factors = [f1, f2, f3, f4, f5, f6, f7] plate_to_step = { "sequences": {}, - "time": frozenset({("x_0", "x_prev", "x_curr"), ("w_0", "w_prev", "w_curr"),}), + "time": frozenset({("x_0", "x_prev", "x_curr"), ("w_0", "w_prev", "w_curr")}), "tones": {}, } @@ -1067,7 +1063,7 @@ def test_modified_partial_sum_product_8( ) @pytest.mark.parametrize( "w_dim,x_dim,y_dim,sequences,time,tones", - [(3, 2, 3, 2, 5, 4), (3, 1, 3, 2, 5, 4), (3, 2, 1, 2, 5, 4), (3, 2, 3, 2, 1, 4),], + [(3, 2, 3, 2, 5, 4), (3, 1, 3, 2, 5, 4), (3, 2, 1, 2, 5, 4), (3, 2, 3, 2, 1, 4)], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -1078,7 +1074,7 @@ def test_modified_partial_sum_product_9( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim],})) + f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim]})) f3 = random_tensor( OrderedDict( @@ -1093,7 +1089,7 @@ def test_modified_partial_sum_product_9( f4 = random_tensor( OrderedDict( - {"sequences": Bint[sequences], "w_0": Bint[w_dim], "x_0": Bint[x_dim],} + {"sequences": Bint[sequences], "w_0": Bint[w_dim], "x_0": Bint[x_dim]} ) ) @@ -1137,7 +1133,7 @@ def test_modified_partial_sum_product_9( factors = [f1, f2, f3, f4, f5, f6, f7] plate_to_step = { "sequences": {}, - "time": frozenset({("x_0", "x_prev", "x_curr"), ("w_0", "w_prev", "w_curr"),}), + "time": frozenset({("x_0", "x_prev", "x_curr"), ("w_0", "w_prev", "w_curr")}), "tones": {}, } @@ -1224,7 +1220,7 @@ def test_modified_partial_sum_product_9( ) @pytest.mark.parametrize( "w_dim,x_dim,y_dim,sequences,time,tones", - [(3, 2, 3, 2, 5, 4), (3, 1, 3, 2, 5, 4), (3, 2, 1, 2, 5, 4), (3, 2, 3, 2, 1, 4),], + [(3, 2, 3, 2, 5, 4), (3, 1, 3, 2, 5, 4), (3, 2, 1, 2, 5, 4), (3, 2, 3, 2, 1, 4)], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -1235,17 +1231,17 @@ def test_modified_partial_sum_product_10( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim],})) + f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim]})) f3 = random_tensor( OrderedDict( - {"sequences": Bint[sequences], "time": Bint[time], "w_curr": Bint[w_dim],} + {"sequences": Bint[sequences], "time": Bint[time], "w_curr": Bint[w_dim]} ) ) f4 = random_tensor( OrderedDict( - {"sequences": Bint[sequences], "w_0": Bint[w_dim], "x_0": Bint[x_dim],} + {"sequences": Bint[sequences], "w_0": Bint[w_dim], "x_0": Bint[x_dim]} ) ) @@ -1421,13 +1417,13 @@ def 
test_modified_partial_sum_product_11( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"a": Bint[a_dim],})) + f2 = random_tensor(OrderedDict({"a": Bint[a_dim]})) - f3 = random_tensor(OrderedDict({"sequences": Bint[sequences], "b": Bint[b_dim],})) + f3 = random_tensor(OrderedDict({"sequences": Bint[sequences], "b": Bint[b_dim]})) f4 = random_tensor( OrderedDict( - {"a": Bint[a_dim], "sequences": Bint[sequences], "w_0": Bint[w_dim],} + {"a": Bint[a_dim], "sequences": Bint[sequences], "w_0": Bint[w_dim]} ) ) @@ -1579,7 +1575,7 @@ def test_modified_partial_sum_product_11( ) @pytest.mark.parametrize( "w_dim,x_dim,y_dim,sequences,time,tones", - [(3, 2, 3, 2, 5, 4), (3, 1, 3, 2, 5, 4), (3, 2, 1, 2, 5, 4), (3, 2, 3, 2, 1, 4),], + [(3, 2, 3, 2, 5, 4), (3, 1, 3, 2, 5, 4), (3, 2, 1, 2, 5, 4), (3, 2, 3, 2, 1, 4)], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -1590,11 +1586,11 @@ def test_modified_partial_sum_product_12( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim],})) + f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim]})) f3 = random_tensor( OrderedDict( - {"sequences": Bint[sequences], "time": Bint[time], "w_curr": Bint[w_dim],} + {"sequences": Bint[sequences], "time": Bint[time], "w_curr": Bint[w_dim]} ) ) @@ -1803,7 +1799,7 @@ def test_modified_partial_sum_product_13( f4 = random_tensor( OrderedDict( - {"w": Bint[w_dim], "sequences": Bint[sequences], "y_0": Bint[y_dim],} + {"w": Bint[w_dim], "sequences": Bint[sequences], "y_0": Bint[y_dim]} ) ) @@ -1924,7 +1920,7 @@ def test_modified_partial_sum_product_13( ) @pytest.mark.parametrize( "x_dim,y_dim,sequences,time,tones", - [(2, 3, 2, 3, 2), (1, 3, 2, 3, 2), (2, 1, 2, 3, 2), (2, 3, 2, 1, 2),], + [(2, 3, 2, 3, 2), (1, 3, 2, 3, 2), (2, 1, 2, 3, 2), (2, 3, 2, 1, 2)], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -1935,7 +1931,7 @@ def test_modified_partial_sum_product_14( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim],})) + f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim]})) f3 = random_tensor( OrderedDict( @@ -1950,7 +1946,7 @@ def test_modified_partial_sum_product_14( f4 = random_tensor( OrderedDict( - {"sequences": Bint[sequences], "x_0": Bint[x_dim], "y0_0": Bint[y_dim],} + {"sequences": Bint[sequences], "x_0": Bint[x_dim], "y0_0": Bint[y_dim]} ) ) @@ -1995,7 +1991,7 @@ def test_modified_partial_sum_product_14( "sequences": {}, "time": frozenset({("x_0", "x_prev", "x_curr")}), "tones": frozenset( - {("y0_0", "y0_prev", "y0_curr"), ("ycurr_0", "ycurr_prev", "ycurr_curr"),} + {("y0_0", "y0_prev", "y0_curr"), ("ycurr_0", "ycurr_prev", "ycurr_curr")} ), } @@ -2031,7 +2027,7 @@ def test_modified_partial_sum_product_14( ], ) @pytest.mark.parametrize( - "x_dim,y_dim,time", [(2, 3, 5), (1, 3, 5), (2, 1, 5), (2, 3, 1),], + "x_dim,y_dim,time", [(2, 3, 5), (1, 3, 5), (2, 1, 5), (2, 3, 1)] ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -2042,21 +2038,21 @@ def test_modified_partial_sum_product_16( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim],})) + f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) f3 = random_tensor( - OrderedDict({"time": Bint[time], "y_prev": Bint[y_dim], "x_curr": Bint[x_dim],}) + OrderedDict({"time": Bint[time], "y_prev": 
Bint[y_dim], "x_curr": Bint[x_dim]}) ) - f4 = random_tensor(OrderedDict({"y_0": Bint[y_dim],})) + f4 = random_tensor(OrderedDict({"y_0": Bint[y_dim]})) f5 = random_tensor( - OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "y_curr": Bint[y_dim],}) + OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "y_curr": Bint[y_dim]}) ) factors = [f1, f2, f3, f4, f5] plate_to_step = { - "time": frozenset({("x_0", "x_prev", "x_curr"), ("y_0", "y_prev", "y_curr"),}), + "time": frozenset({("x_0", "x_prev", "x_curr"), ("y_0", "y_prev", "y_curr")}) } factors1 = modified_partial_sum_product( @@ -2126,7 +2122,7 @@ def test_modified_partial_sum_product_16( ], ) @pytest.mark.parametrize( - "x_dim,y_dim,z_dim,time", [(2, 3, 2, 5), (1, 3, 2, 5), (2, 1, 2, 5), (2, 3, 2, 1),], + "x_dim,y_dim,z_dim,time", [(2, 3, 2, 5), (1, 3, 2, 5), (2, 1, 2, 5), (2, 3, 2, 1)] ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -2137,10 +2133,10 @@ def test_modified_partial_sum_product_17( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim],})) + f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) f3 = random_tensor( - OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim],}) + OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim]}) ) f4 = random_tensor( @@ -2190,7 +2186,7 @@ def test_modified_partial_sum_product_17( ) f8 = random_tensor( - OrderedDict({"x_0": Bint[x_dim], "y_0": Bint[y_dim], "z2_0": Bint[z_dim],}) + OrderedDict({"x_0": Bint[x_dim], "y_0": Bint[y_dim], "z2_0": Bint[z_dim]}) ) f9 = random_tensor( @@ -2205,9 +2201,7 @@ def test_modified_partial_sum_product_17( ) factors = [f1, f2, f3, f4, f5, f6, f7, f8, f9] - plate_to_step = { - "time": frozenset({("x_0", "x_prev", "x_curr")}), - } + plate_to_step = {"time": frozenset({("x_0", "x_prev", "x_curr")})} with (lazy if use_lazy else eager): factors1 = modified_partial_sum_product( @@ -2307,7 +2301,7 @@ def test_sequential_sum_product( ) @pytest.mark.parametrize( "x_domain,y_domain", - [(Bint[2], Bint[3]), (Real, Reals[2, 2]), (Bint[2], Reals[2]),], + [(Bint[2], Bint[3]), (Real, Reals[2, 2]), (Bint[2], Reals[2])], ids=str, ) @pytest.mark.parametrize( @@ -2376,15 +2370,15 @@ def test_sequential_sum_product_multi( @pytest.mark.parametrize("dim", [1, 2, 3]) def test_sequential_sum_product_bias_1(num_steps, dim): time = Variable("time", Bint[num_steps]) - bias_dist = random_gaussian(OrderedDict([("bias", Reals[dim]),])) + bias_dist = random_gaussian(OrderedDict([("bias", Reals[dim])])) trans = random_gaussian( OrderedDict( - [("time", Bint[num_steps]), ("x_prev", Reals[dim]), ("x_curr", Reals[dim]),] + [("time", Bint[num_steps]), ("x_prev", Reals[dim]), ("x_curr", Reals[dim])] ) ) obs = random_gaussian( OrderedDict( - [("time", Bint[num_steps]), ("x_curr", Reals[dim]), ("bias", Reals[dim]),] + [("time", Bint[num_steps]), ("x_curr", Reals[dim]), ("bias", Reals[dim])] ) ) factor = trans + obs + bias_dist @@ -2403,15 +2397,15 @@ def test_sequential_sum_product_bias_1(num_steps, dim): def test_sequential_sum_product_bias_2(num_steps, num_sensors, dim): time = Variable("time", Bint[num_steps]) bias = Variable("bias", Reals[num_sensors, dim]) - bias_dist = random_gaussian(OrderedDict([("bias", Reals[num_sensors, dim]),])) + bias_dist = random_gaussian(OrderedDict([("bias", Reals[num_sensors, dim])])) trans = random_gaussian( OrderedDict( - [("time", Bint[num_steps]), ("x_prev", Reals[dim]), ("x_curr", Reals[dim]),] + [("time", 
Bint[num_steps]), ("x_prev", Reals[dim]), ("x_curr", Reals[dim])] ) ) obs = random_gaussian( OrderedDict( - [("time", Bint[num_steps]), ("x_curr", Reals[dim]), ("bias", Reals[dim]),] + [("time", Bint[num_steps]), ("x_curr", Reals[dim]), ("bias", Reals[dim])] ) ) @@ -2457,11 +2451,9 @@ def _check_sarkka_bilmes(trans, expected_inputs, global_vars, num_periods=1): @pytest.mark.parametrize("duration", [2, 3, 4, 5, 6]) def test_sarkka_bilmes_example_0(duration): - trans = random_tensor(OrderedDict({"time": Bint[duration], "a": Bint[3],})) + trans = random_tensor(OrderedDict({"time": Bint[duration], "a": Bint[3]})) - expected_inputs = { - "a": Bint[3], - } + expected_inputs = {"a": Bint[3]} _check_sarkka_bilmes(trans, expected_inputs, frozenset()) @@ -2471,15 +2463,11 @@ def test_sarkka_bilmes_example_1(duration): trans = random_tensor( OrderedDict( - {"time": Bint[duration], "a": Bint[3], "b": Bint[2], "_PREV_b": Bint[2],} + {"time": Bint[duration], "a": Bint[3], "b": Bint[2], "_PREV_b": Bint[2]} ) ) - expected_inputs = { - "a": Bint[3], - "b": Bint[2], - "_PREV_b": Bint[2], - } + expected_inputs = {"a": Bint[3], "b": Bint[2], "_PREV_b": Bint[2]} _check_sarkka_bilmes(trans, expected_inputs, frozenset()) @@ -2565,15 +2553,11 @@ def test_sarkka_bilmes_example_5(duration): trans = random_tensor( OrderedDict( - {"time": Bint[duration], "a": Bint[3], "_PREV_a": Bint[3], "x": Bint[2],} + {"time": Bint[duration], "a": Bint[3], "_PREV_a": Bint[3], "x": Bint[2]} ) ) - expected_inputs = { - "a": Bint[3], - "_PREV_a": Bint[3], - "x": Bint[2], - } + expected_inputs = {"a": Bint[3], "_PREV_a": Bint[3], "x": Bint[2]} global_vars = frozenset(["x"]) @@ -2609,9 +2593,7 @@ def test_sarkka_bilmes_example_6(duration): @pytest.mark.parametrize("time_input", [("time", Bint[t]) for t in range(6, 11)]) -@pytest.mark.parametrize( - "global_inputs", [(), (("x", Bint[2]),),], -) +@pytest.mark.parametrize("global_inputs", [(), (("x", Bint[2]),)]) @pytest.mark.parametrize( "local_inputs", [ diff --git a/test/test_tensor.py b/test/test_tensor.py index f46611358..ff06ce771 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -126,7 +126,7 @@ def test_indexing(): def test_advanced_indexing_shape(): I, J, M, N = 4, 4, 2, 3 - x = Tensor(randn((I, J)), OrderedDict([("i", Bint[I]), ("j", Bint[J]),]),) + x = Tensor(randn((I, J)), OrderedDict([("i", Bint[I]), ("j", Bint[J])])) m = Tensor(numeric_array([2, 3]), OrderedDict([("m", Bint[M])]), I) n = Tensor(numeric_array([0, 1, 1]), OrderedDict([("n", Bint[N])]), J) assert x.data.shape == (I, J) @@ -223,17 +223,17 @@ def test_advanced_indexing_tensor(output_shape): # x output = Reals[output_shape] x = random_tensor( - OrderedDict([("i", Bint[2]), ("j", Bint[3]), ("k", Bint[4]),]), output, + OrderedDict([("i", Bint[2]), ("j", Bint[3]), ("k", Bint[4])]), output ) - i = random_tensor(OrderedDict([("u", Bint[5]),]), Bint[2],) - j = random_tensor(OrderedDict([("v", Bint[6]), ("u", Bint[5]),]), Bint[3],) - k = random_tensor(OrderedDict([("v", Bint[6]),]), Bint[4],) + i = random_tensor(OrderedDict([("u", Bint[5])]), Bint[2]) + j = random_tensor(OrderedDict([("v", Bint[6]), ("u", Bint[5])]), Bint[3]) + k = random_tensor(OrderedDict([("v", Bint[6])]), Bint[4]) expected_data = empty((5, 6) + output_shape) for u in range(5): for v in range(6): expected_data[u, v] = x.data[i.data[u], j.data[v, u], k.data[v]] - expected = Tensor(expected_data, OrderedDict([("u", Bint[5]), ("v", Bint[6]),]),) + expected = Tensor(expected_data, OrderedDict([("u", Bint[5]), ("v", Bint[6])])) 
assert_equiv(expected, x(i, j, k)) assert_equiv(expected, x(i=i, j=j, k=k)) @@ -258,7 +258,7 @@ def test_advanced_indexing_tensor(output_shape): def test_advanced_indexing_lazy(output_shape): x = Tensor( randn((2, 3, 4) + output_shape), - OrderedDict([("i", Bint[2]), ("j", Bint[3]), ("k", Bint[4]),]), + OrderedDict([("i", Bint[2]), ("j", Bint[3]), ("k", Bint[4])]), ) u = Variable("u", Bint[2]) v = Variable("v", Bint[3]) @@ -274,7 +274,7 @@ def test_advanced_indexing_lazy(output_shape): for u in range(2): for v in range(3): expected_data[u, v] = x.data[i_data[u], j_data[v], k_data[u, v]] - expected = Tensor(expected_data, OrderedDict([("u", Bint[2]), ("v", Bint[3]),]),) + expected = Tensor(expected_data, OrderedDict([("u", Bint[2]), ("v", Bint[3])])) assert_equiv(expected, x(i, j, k)) assert_equiv(expected, x(i=i, j=j, k=k)) @@ -304,7 +304,7 @@ def unary_eval(symbol, x): @pytest.mark.parametrize("dims", [(), ("a",), ("a", "b")]) @pytest.mark.parametrize( "symbol", - ["~", "-", "abs", "atanh", "sqrt", "exp", "log", "log1p", "sigmoid", "tanh",], + ["~", "-", "abs", "atanh", "sqrt", "exp", "log", "log1p", "sigmoid", "tanh"], ) def test_unary(symbol, dims): sizes = {"a": 3, "b": 4} @@ -837,8 +837,7 @@ def test_function_of_numeric_array(): def test_align(): x = Tensor( - randn((2, 3, 4)), - OrderedDict([("i", Bint[2]), ("j", Bint[3]), ("k", Bint[4]),]), + randn((2, 3, 4)), OrderedDict([("i", Bint[2]), ("j", Bint[3]), ("k", Bint[4])]) ) y = x.align(("j", "k", "i")) assert isinstance(y, Tensor) @@ -951,13 +950,13 @@ def test_tensor_stack(n, shape, dim): @pytest.mark.parametrize("output", [Bint[2], Real, Reals[4], Reals[2, 3]], ids=str) def test_funsor_stack(output): - x = random_tensor(OrderedDict([("i", Bint[2]),]), output,) - y = random_tensor(OrderedDict([("j", Bint[3]),]), output,) - z = random_tensor(OrderedDict([("i", Bint[2]), ("k", Bint[4]),]), output,) + x = random_tensor(OrderedDict([("i", Bint[2])]), output) + y = random_tensor(OrderedDict([("j", Bint[3])]), output) + z = random_tensor(OrderedDict([("i", Bint[2]), ("k", Bint[4])]), output) xy = Stack("t", (x, y)) assert isinstance(xy, Tensor) - assert xy.inputs == OrderedDict([("t", Bint[2]), ("i", Bint[2]), ("j", Bint[3]),]) + assert xy.inputs == OrderedDict([("t", Bint[2]), ("i", Bint[2]), ("j", Bint[3])]) assert xy.output == output for j in range(3): assert_close(xy(t=0, j=j), x) @@ -967,7 +966,7 @@ def test_funsor_stack(output): xyz = Stack("t", (x, y, z)) assert isinstance(xyz, Tensor) assert xyz.inputs == OrderedDict( - [("t", Bint[3]), ("i", Bint[2]), ("j", Bint[3]), ("k", Bint[4]),] + [("t", Bint[3]), ("i", Bint[2]), ("j", Bint[3]), ("k", Bint[4])] ) assert xy.output == output for j in range(3): @@ -982,9 +981,9 @@ def test_funsor_stack(output): @pytest.mark.parametrize("output", [Bint[2], Real, Reals[4], Reals[2, 3]], ids=str) def test_cat_simple(output): - x = random_tensor(OrderedDict([("i", Bint[2]),]), output,) - y = random_tensor(OrderedDict([("i", Bint[3]), ("j", Bint[4]),]), output,) - z = random_tensor(OrderedDict([("i", Bint[5]), ("k", Bint[6]),]), output,) + x = random_tensor(OrderedDict([("i", Bint[2])]), output) + y = random_tensor(OrderedDict([("i", Bint[3]), ("j", Bint[4])]), output) + z = random_tensor(OrderedDict([("i", Bint[5]), ("k", Bint[6])]), output) assert Cat("i", (x,)) is x assert Cat("i", (y,)) is y @@ -992,13 +991,13 @@ def test_cat_simple(output): xy = Cat("i", (x, y)) assert isinstance(xy, Tensor) - assert xy.inputs == OrderedDict([("i", Bint[2 + 3]), ("j", Bint[4]),]) + assert xy.inputs == 
OrderedDict([("i", Bint[2 + 3]), ("j", Bint[4])]) assert xy.output == output xyz = Cat("i", (x, y, z)) assert isinstance(xyz, Tensor) assert xyz.inputs == OrderedDict( - [("i", Bint[2 + 3 + 5]), ("j", Bint[4]), ("k", Bint[6]),] + [("i", Bint[2 + 3 + 5]), ("j", Bint[4]), ("k", Bint[6])] ) assert xy.output == output diff --git a/test/test_terms.py b/test/test_terms.py index f7f3b550f..ab33c2566 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -260,7 +260,7 @@ def unary_eval(symbol, x): @pytest.mark.parametrize("data", [0, 0.5, 1]) @pytest.mark.parametrize( "symbol", - ["~", "-", "atanh", "abs", "sqrt", "exp", "log", "log1p", "sigmoid", "tanh",], + ["~", "-", "atanh", "abs", "sqrt", "exp", "log", "log1p", "sigmoid", "tanh"], ) def test_unary(symbol, data): dtype = "real" @@ -276,21 +276,7 @@ def test_unary(symbol, data): check_funsor(actual, {}, Array[dtype, ()], expected_data) -BINARY_OPS = [ - "+", - "-", - "*", - "/", - "**", - "==", - "!=", - "<", - "<=", - ">", - ">=", - "min", - "max", -] +BINARY_OPS = ["+", "-", "*", "/", "**", "==", "!=", "<", "<=", ">", ">=", "min", "max"] BOOLEAN_OPS = ["&", "|", "^"] diff --git a/tutorials/sum_product_network.ipynb b/tutorials/sum_product_network.ipynb index fa2f417fe..65eed7901 100644 --- a/tutorials/sum_product_network.ipynb +++ b/tutorials/sum_product_network.ipynb @@ -4,9 +4,306 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Sum Product Network\n", + "# Sum Product Network" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [], + "source": [ + "from collections import OrderedDict\n", + "\n", + "import torch\n", "\n", - "(in preparation)" + "import funsor\n", + "import funsor.torch.distributions as dist\n", + "import funsor.ops as ops\n", + "\n", + "funsor.set_backend(\"torch\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### network" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Tensor(tensor([[[0.0341, 0.0371],\n", + " [0.0571, 0.0717]],\n", + "\n", + " [[0.1363, 0.1485],\n", + " [0.2285, 0.2867]]]), OrderedDict([('v0', Bint[2, ]), ('v1', Bint[2, ]), ('v2', Bint[2, ])]), 'real')" + ] + }, + "execution_count": 70, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# sum_op = +, prod_op = *\n", + "# alternatively, we can use rewrite_ops as in\n", + "# https://github.com/pyro-ppl/funsor/pull/456\n", + "# and switch to sum_op = logsumexp, prod_op = +\n", + "spn = 0.4 * (dist.Categorical(torch.tensor([0.2, 0.8]), value=\"v0\").exp() *\n", + " (0.3 * (dist.Categorical(torch.tensor([0.3, 0.7]), value=\"v1\").exp() *\n", + " dist.Categorical(torch.tensor([0.4, 0.6]), value=\"v2\").exp())\n", + " + 0.7 * (dist.Categorical(torch.tensor([0.5, 0.5]), value=\"v1\").exp() *\n", + " dist.Categorical(torch.tensor([0.6, 0.4]), value=\"v2\").exp()))) \\\n", + " + 0.6 * (dist.Categorical(torch.tensor([0.2, 0.8]), value=\"v0\").exp() *\n", + " dist.Categorical(torch.tensor([0.3, 0.7]), value=\"v1\").exp() *\n", + " dist.Categorical(torch.tensor([0.4, 0.6]), value=\"v2\").exp())\n", + "spn" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### marginalize" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tensor(tensor([[0.1704, 0.1856],\n", + " [0.2856, 0.3584]]), OrderedDict([('v1', Bint[2, ]), ('v2', Bint[2, ])]))\n" + ] + 
} + ], + "source": [ + "spn_marg = spn.reduce(ops.add, \"v0\")\n", + "print(spn_marg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### likelihood" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": {}, + "outputs": [], + "source": [ + "test_data = {\"v0\": 1, \"v1\": 0, \"v2\": 1}" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(-1.9073) tensor(0.1485)\n" + ] + } + ], + "source": [ + "ll_exp = spn(**test_data)\n", + "print(ll_exp.log(), ll_exp)" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(-1.6842) tensor(0.1856)\n" + ] + } + ], + "source": [ + "llm_exp = spn_marg(**test_data)\n", + "print(llm_exp.log(), llm_exp)" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(-1.6842) tensor(0.1856)\n" + ] + } + ], + "source": [ + "test_data2 = {\"v1\": 0, \"v2\": 1}\n", + "llom_exp = spn(**test_data2).reduce(ops.add)\n", + "print(llom_exp.log(), llom_exp)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### sample" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Delta((('v0', (Tensor(tensor([0, 1, 0, 0, 1]), OrderedDict([('particle', Bint[5, ])]), 2), Number(0.0))),)) + Tensor(-0.8297846913337708, OrderedDict(), 'real').reduce(nullop, set())" + ] + }, + "execution_count": 77, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sample_inputs = OrderedDict(particle=funsor.Bint[5])\n", + "spn(v1=0, v2=0).sample(frozenset({\"v0\"}), sample_inputs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "what is `-0.8297846913337708`? a normalization factor?" 
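+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### train parameters"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Below is a minimal, hypothetical sketch of what training could look like, assuming we expose the mixture weight and one leaf's probabilities as learnable `torch` parameters and run gradient ascent on the log-likelihood; `logits_w`, `logits_v0`, and `make_spn` are illustrative names, not part of funsor's API."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# hypothetical sketch: fit a toy SPN over v0 by gradient ascent on log-likelihood\n",
+ "logits_w = torch.zeros((), requires_grad=True)  # mixture weight logit\n",
+ "logits_v0 = torch.zeros(2, requires_grad=True)  # leaf logits for v0\n",
+ "\n",
+ "def make_spn():\n",
+ "    w = logits_w.sigmoid()\n",
+ "    # mixture of a learnable leaf and a fixed uniform leaf over v0\n",
+ "    return (dist.Categorical(logits_v0.softmax(-1), value=\"v0\").exp() * w\n",
+ "            + dist.Categorical(torch.tensor([0.5, 0.5]), value=\"v0\").exp() * (1 - w))\n",
+ "\n",
+ "optim = torch.optim.Adam([logits_w, logits_v0], lr=0.1)\n",
+ "for step in range(100):\n",
+ "    optim.zero_grad()\n",
+ "    loss = -funsor.to_data(make_spn()(v0=1).log())  # NLL of a single observation v0=1\n",
+ "    loss.backward()\n",
+ "    optim.step()"
+ ]
+ },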
+ {
+ "cell_type": "code",
+ "execution_count": 82,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(-2.0612e-09)"
+ ]
+ },
+ "execution_count": 82,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "-torch.nn.functional.softplus(-torch.tensor(20.))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# TODO: parameter optimization"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### most probable explanation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### multivariate leaf"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### cutset networks"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### expectations and moments"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# expectation of the integrand x under the log-measure q, eliminating q_vars\n",
+ "Integrate(q, x, q_vars)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### pareto"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tensor(-0.5232)\n"
+ ]
+ }
+ ],
+ "source": [
+ "spn = 0.3 * dist.Pareto(1., 2., value=\"v0\").exp() + 0.7 * dist.Pareto(1., 3., value=\"v0\").exp()\n",
+ "print(spn(v0=1.5).log())"
 ]
 }
 ],
From e7dc9d1db7637ed29c9ea8341a987360be607ccd Mon Sep 17 00:00:00 2001
From: Du Phan
Date: Sun, 21 Feb 2021 10:17:42 -0600
Subject: [PATCH 4/8] revert black changes

---
 funsor/affine.py | 6 +-
 funsor/cnf.py | 6 +-
 funsor/delta.py | 5 +-
 funsor/distribution.py | 7 +-
 funsor/gaussian.py | 7 +-
 funsor/instrument.py | 9 +-
 funsor/integrate.py | 4 +-
 funsor/interpreter.py | 3 +-
 funsor/jax/__init__.py | 8 +-
 funsor/jax/ops.py | 5 +-
 funsor/joint.py | 4 +-
 funsor/memoize.py | 4 +-
 funsor/montecarlo.py | 4 +-
 funsor/registry.py | 4 +-
 funsor/syntax.py | 23 +-
 funsor/terms.py | 2 +-
 funsor/testing.py | 6 +-
 test/examples/test_bart.py | 92 ++++-
 test/examples/test_sensor_fusion.py | 11 +-
 test/pyro/test_hmm.py | 145 ++++++-
 test/test_adjoint.py | 7 +-
 test/test_distribution.py | 5 +-
 test/test_distribution_generic.py | 8 +-
 test/test_domains.py | 10 +-
 test/test_factory.py | 22 +-
 test/test_gaussian.py | 15 +-
 test/test_memoize.py | 4 +-
 test/test_minipyro.py | 17 +-
 test/test_optimizer.py | 28 +-
 test/test_samplers.py | 111 +++++-
 test/test_sum_product.py | 573 ++++++++++++++++++++++++----
 test/test_tensor.py | 182 ++++++++-
 test/test_terms.py | 29 +-
 33 files changed, 1172 insertions(+), 194 deletions(-)

diff --git a/funsor/affine.py b/funsor/affine.py
index 1d6c9a2d1..c9e181911 100644
--- a/funsor/affine.py
+++ b/funsor/affine.py
@@ -159,4 +159,8 @@ def extract_affine(fn):
 return const, coeffs
-__all__ = ["affine_inputs", "extract_affine", "is_affine"]
+__all__ = [
+ "affine_inputs",
+ "extract_affine",
+ "is_affine",
+]
diff --git a/funsor/cnf.py b/funsor/cnf.py
index 0e570df50..2200a6e74 100644
--- a/funsor/cnf.py
+++ b/funsor/cnf.py
@@ -594,7 +594,11 @@ 
def unary_contract(op, arg): ) -BACKEND_TO_EINSUM_BACKEND = {"numpy": "numpy", "torch": "torch", "jax": "jax.numpy"} +BACKEND_TO_EINSUM_BACKEND = { + "numpy": "numpy", + "torch": "torch", + "jax": "jax.numpy", +} # NB: numpy_log, numpy_map is backend-agnostic so they also work for torch backend; # however, we might need to profile to make a switch BACKEND_TO_LOGSUMEXP_BACKEND = { diff --git a/funsor/delta.py b/funsor/delta.py index 341280050..f87359b8f 100644 --- a/funsor/delta.py +++ b/funsor/delta.py @@ -248,4 +248,7 @@ def eager_independent_delta(delta, reals_var, bint_var, diag_var): return None -__all__ = ["Delta", "solve"] +__all__ = [ + "Delta", + "solve", +] diff --git a/funsor/distribution.py b/funsor/distribution.py index f67ceb5ef..b179905a5 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -184,7 +184,7 @@ def eager_log_prob(cls, *params): params, value = params[:-1], params[-1] params = params + (Variable("value", value.output),) instance = reflect.interpret(cls, *params) - (raw_dist, value_name, value_output, dim_to_name) = instance._get_raw_dist() + raw_dist, value_name, value_output, dim_to_name = instance._get_raw_dist() assert value.output == value_output name_to_dim = {v: k for k, v in dim_to_name.items()} dim_to_name.update( @@ -379,7 +379,10 @@ def dist_init(self, **kwargs): dist_class = DistributionMeta( backend_dist_class.__name__.split("Wrapper_")[-1], (Distribution,), - {"dist_class": backend_dist_class, "__init__": dist_init}, + { + "dist_class": backend_dist_class, + "__init__": dist_init, + }, ) if generate_eager: diff --git a/funsor/gaussian.py b/funsor/gaussian.py index 13edad9d1..caa278c9c 100644 --- a/funsor/gaussian.py +++ b/funsor/gaussian.py @@ -779,4 +779,9 @@ def eager_neg(op, arg): return Gaussian(info_vec, precision, arg.inputs) -__all__ = ["BlockMatrix", "BlockVector", "Gaussian", "align_gaussian"] +__all__ = [ + "BlockMatrix", + "BlockVector", + "Gaussian", + "align_gaussian", +] diff --git a/funsor/instrument.py b/funsor/instrument.py index 569c771a4..71f797cf9 100644 --- a/funsor/instrument.py +++ b/funsor/instrument.py @@ -108,4 +108,11 @@ def print_counters(): print("-" * 80) -__all__ = ["DEBUG", "PROFILE", "STACK_SIZE", "debug_logged", "get_indent", "profile"] +__all__ = [ + "DEBUG", + "PROFILE", + "STACK_SIZE", + "debug_logged", + "get_indent", + "profile", +] diff --git a/funsor/integrate.py b/funsor/integrate.py index 7212628af..b75b6f50e 100644 --- a/funsor/integrate.py +++ b/funsor/integrate.py @@ -230,4 +230,6 @@ def eager_integrate(log_measure, integrand, reduced_vars): return None # defer to default implementation -__all__ = ["Integrate"] +__all__ = [ + "Integrate", +] diff --git a/funsor/interpreter.py b/funsor/interpreter.py index 5059672b4..759d75587 100644 --- a/funsor/interpreter.py +++ b/funsor/interpreter.py @@ -80,7 +80,8 @@ def interpret(cls, *args): def interpretation(new): warnings.warn( - "'with interpretation(x)' should be replaced by 'with x'", DeprecationWarning + "'with interpretation(x)' should be replaced by 'with x'", + DeprecationWarning, ) return new diff --git a/funsor/jax/__init__.py b/funsor/jax/__init__.py index dae7b7435..ffefedd58 100644 --- a/funsor/jax/__init__.py +++ b/funsor/jax/__init__.py @@ -18,7 +18,13 @@ @adjoint_ops.register( - Tensor, AssociativeOp, AssociativeOp, Funsor, (DeviceArray, Tracer), tuple, object + Tensor, + AssociativeOp, + AssociativeOp, + Funsor, + (DeviceArray, Tracer), + tuple, + object, ) def adjoint_tensor(adj_redop, adj_binop, out_adj, data, inputs, dtype): 
return {} diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index 0570b92b6..07e38c9d6 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -257,7 +257,10 @@ def _triangular_solve(x, y, upper=False, transpose=False): x_new_shape = batch_shape[:prepend_ndim] for (sy, sx) in zip(y.shape[:-2], batch_shape[prepend_ndim:]): x_new_shape += (sx // sy, sy) - x_new_shape += (n, m) + x_new_shape += ( + n, + m, + ) x = np.reshape(x, x_new_shape) # Permute y to make it have shape (..., 1, j, m, i, 1, n) batch_ndim = x.ndim - 2 diff --git a/funsor/joint.py b/funsor/joint.py index 8658248db..cdc2a092c 100644 --- a/funsor/joint.py +++ b/funsor/joint.py @@ -104,7 +104,9 @@ def moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete, gauss discrete += gaussian.log_normalizer new_discrete = discrete.reduce(ops.logaddexp, approx_vars & discrete.input_vars) num_elements = reduce( - ops.mul, [v.output.num_elements for v in approx_vars - discrete.input_vars], 1 + ops.mul, + [v.output.num_elements for v in approx_vars - discrete.input_vars], + 1, ) if num_elements != 1: new_discrete -= math.log(num_elements) diff --git a/funsor/memoize.py b/funsor/memoize.py index baf45a471..90fea683f 100644 --- a/funsor/memoize.py +++ b/funsor/memoize.py @@ -40,4 +40,6 @@ def interpret(self, cls, *args): return value -__all__ = ["memoize"] +__all__ = [ + "memoize", +] diff --git a/funsor/montecarlo.py b/funsor/montecarlo.py index 3ef205b4f..b533c0f1a 100644 --- a/funsor/montecarlo.py +++ b/funsor/montecarlo.py @@ -40,4 +40,6 @@ def monte_carlo_integrate(state, log_measure, integrand, reduced_vars): return Integrate(sample, integrand, reduced_vars) -__all__ = ["MonteCarlo"] +__all__ = [ + "MonteCarlo", +] diff --git a/funsor/registry.py b/funsor/registry.py index 353693861..07f9f5542 100644 --- a/funsor/registry.py +++ b/funsor/registry.py @@ -84,4 +84,6 @@ def dispatch(self, key, *args): return self[key].partial_call(*args) -__all__ = ["KeyedRegistry"] +__all__ = [ + "KeyedRegistry", +] diff --git a/funsor/syntax.py b/funsor/syntax.py index 7b4291548..ed2b7d6eb 100644 --- a/funsor/syntax.py +++ b/funsor/syntax.py @@ -59,7 +59,12 @@ def visit_UnaryOp(self, node): var = self.prefix.get(type(node.op)) if var is not None: node = ast.Call( - func=ast.Name(id=var, ctx=ast.Load()), args=[node.operand], keywords=[] + func=ast.Name( + id=var, + ctx=ast.Load(), + ), + args=[node.operand], + keywords=[], ) return node @@ -68,7 +73,10 @@ def visit_BinOp(self, node): var = self.infix.get(type(node.op)) if var is not None: node = ast.Call( - func=ast.Name(id=var, ctx=ast.Load()), + func=ast.Name( + id=var, + ctx=ast.Load(), + ), args=[node.left, node.right], keywords=[], ) @@ -90,7 +98,10 @@ def visit_Compare(self, node): var = self.infix.get(type(node_op)) if var is not None: node = ast.Call( - func=ast.Name(id=var, ctx=ast.Load()), + func=ast.Name( + id=var, + ctx=ast.Load(), + ), args=[node.left, node_right], keywords=[], ) @@ -161,4 +172,8 @@ def decorator(fn): return decorator -__all__ = ["INFIX_OPERATORS", "PREFIX_OPERATORS", "rewrite_ops"] +__all__ = [ + "INFIX_OPERATORS", + "PREFIX_OPERATORS", + "rewrite_ops", +] diff --git a/funsor/terms.py b/funsor/terms.py index efc1cc33c..3b13ced73 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1520,7 +1520,7 @@ def eager_subs(self, subs): n -= size assert False elif isinstance(value, Slice): - start, stop, step = (value.slice.start, value.slice.stop, value.slice.step) + start, stop, step = value.slice.start, value.slice.stop, value.slice.step new_parts = [] pos = 0 
for part in self.parts: diff --git a/funsor/testing.py b/funsor/testing.py index aefcb52d8..ef5dc7ef9 100644 --- a/funsor/testing.py +++ b/funsor/testing.py @@ -114,9 +114,9 @@ def assert_close(actual, expected, atol=1e-6, rtol=1e-6): n for n, p in expected.terms ) actual = actual.align(tuple(n for n, p in expected.terms)) - for ( - (actual_name, (actual_point, actual_log_density)), - (expected_name, (expected_point, expected_log_density)), + for (actual_name, (actual_point, actual_log_density)), ( + expected_name, + (expected_point, expected_log_density), ) in zip(actual.terms, expected.terms): assert actual_name == expected_name assert_close(actual_point, expected_point, atol=atol, rtol=rtol) diff --git a/test/examples/test_bart.py b/test/examples/test_bart.py index c6f1762f1..afcdb7a27 100644 --- a/test/examples/test_bart.py +++ b/test/examples/test_bart.py @@ -52,7 +52,10 @@ def unpack_gate_rate(gate_rate): @pytest.mark.parametrize( "analytic_kl", - [False, xfail_param(True, reason="missing pattern")], + [ + False, + xfail_param(True, reason="missing pattern"), + ], ids=["monte-carlo-kl", "analytic-kl"], ) def test_bart(analytic_kl): @@ -93,7 +96,16 @@ def test_bart(analytic_kl): ], dtype=torch.float32, ), # noqa - (("time_b4", Bint[2]), ("_event_1_b2", Bint[8])), + ( + ( + "time_b4", + Bint[2], + ), + ( + "_event_1_b2", + Bint[8], + ), + ), "real", ), Gaussian( @@ -148,9 +160,18 @@ def test_bart(analytic_kl): dtype=torch.float32, ), # noqa ( - ("time_b4", Bint[2]), - ("_event_1_b2", Bint[8]), - ("value_b1", Real), + ( + "time_b4", + Bint[2], + ), + ( + "_event_1_b2", + Bint[8], + ), + ( + "value_b1", + Real, + ), ), ), ), @@ -220,8 +241,14 @@ def test_bart(analytic_kl): dtype=torch.float32, ), # noqa ( - ("state_b7", Reals[2]), - ("state(time=1)_b8", Reals[2]), + ( + "state_b7", + Reals[2], + ), + ( + "state(time=1)_b8", + Reals[2], + ), ), ), Subs( @@ -281,7 +308,12 @@ def test_bart(analytic_kl): ], dtype=torch.float32, ), # noqa - (("time_b9", Bint[2]),), + ( + ( + "time_b9", + Bint[2], + ), + ), "real", ), Tensor( @@ -310,7 +342,12 @@ def test_bart(analytic_kl): ], dtype=torch.float32, ), # noqa - (("time_b9", Bint[2]),), + ( + ( + "time_b9", + Bint[2], + ), + ), "real", ), Variable("state(time=1)_b8", Reals[2]), @@ -352,7 +389,12 @@ def test_bart(analytic_kl): ), Variable("value_b5", Reals[2]), ), - (("value_b5", Variable("state_b10", Reals[2])),), + ( + ( + "value_b5", + Variable("state_b10", Reals[2]), + ), + ), ), ), ) @@ -449,9 +491,18 @@ def test_bart(analytic_kl): dtype=torch.float32, ), # noqa ( - ("time_b17", Bint[2]), - ("origin_b15", Bint[2]), - ("destin_b16", Bint[2]), + ( + "time_b17", + Bint[2], + ), + ( + "origin_b15", + Bint[2], + ), + ( + "destin_b16", + Bint[2], + ), ), "real", ), @@ -476,9 +527,18 @@ def test_bart(analytic_kl): dtype=torch.float32, ), # noqa ( - ("time_b17", Bint[2]), - ("origin_b15", Bint[2]), - ("destin_b16", Bint[2]), + ( + "time_b17", + Bint[2], + ), + ( + "origin_b15", + Bint[2], + ), + ( + "destin_b16", + Bint[2], + ), ), "real", ), diff --git a/test/examples/test_sensor_fusion.py b/test/examples/test_sensor_fusion.py index f52e483a0..f8fc8b77f 100644 --- a/test/examples/test_sensor_fusion.py +++ b/test/examples/test_sensor_fusion.py @@ -142,7 +142,16 @@ def test_affine_subs(): ], dtype=torch.float32, ), # noqa - (("state_1_b6", Reals[3]), ("obs_b2", Reals[2])), + ( + ( + "state_1_b6", + Reals[3], + ), + ( + "obs_b2", + Reals[2], + ), + ), ), ( ( diff --git a/test/pyro/test_hmm.py b/test/pyro/test_hmm.py index 9db9dcad7..69d645ca5 100644 
--- a/test/pyro/test_hmm.py +++ b/test/pyro/test_hmm.py @@ -245,19 +245,134 @@ def test_gaussian_mrf_log_prob(init_shape, trans_shape, obs_shape, hidden_dim, o ] ) SLHMM_SHAPES = [ - ((2,), (), (1, 2), (1, 3, 3), (1,), (1, 3, 4), (1,)), - ((2,), (), (5, 1, 2), (1, 3, 3), (1,), (1, 3, 4), (1,)), - ((2,), (), (1, 2), (5, 1, 3, 3), (1,), (1, 3, 4), (1,)), - ((2,), (), (1, 2), (1, 3, 3), (5, 1), (1, 3, 4), (1,)), - ((2,), (), (1, 2), (1, 3, 3), (1,), (5, 1, 3, 4), (1,)), - ((2,), (), (1, 2), (1, 3, 3), (1,), (1, 3, 4), (5, 1)), - ((2,), (), (5, 1, 2), (5, 1, 3, 3), (5, 1), (5, 1, 3, 4), (5, 1)), - ((2,), (2,), (5, 2, 2), (5, 2, 3, 3), (5, 2), (5, 2, 3, 4), (5, 2)), - ((7, 2), (), (7, 5, 1, 2), (7, 5, 1, 3, 3), (7, 5, 1), (7, 5, 1, 3, 4), (7, 5, 1)), ( + (2,), + (), + ( + 1, + 2, + ), + (1, 3, 3), + (1,), + (1, 3, 4), + (1,), + ), + ( + (2,), + (), + ( + 5, + 1, + 2, + ), + (1, 3, 3), + (1,), + (1, 3, 4), + (1,), + ), + ( + (2,), + (), + ( + 1, + 2, + ), + (5, 1, 3, 3), + (1,), + (1, 3, 4), + (1,), + ), + ( + (2,), + (), + ( + 1, + 2, + ), + (1, 3, 3), + (5, 1), + (1, 3, 4), + (1,), + ), + ( + (2,), + (), + ( + 1, + 2, + ), + (1, 3, 3), + (1,), + (5, 1, 3, 4), + (1,), + ), + ( + (2,), + (), + ( + 1, + 2, + ), + (1, 3, 3), + (1,), + (1, 3, 4), + (5, 1), + ), + ( + (2,), + (), + ( + 5, + 1, + 2, + ), + (5, 1, 3, 3), + (5, 1), + (5, 1, 3, 4), + (5, 1), + ), + ( + (2,), + (2,), + ( + 5, + 2, + 2, + ), + (5, 2, 3, 3), + (5, 2), + (5, 2, 3, 4), + (5, 2), + ), + ( + ( + 7, + 2, + ), + (), + ( + 7, + 5, + 1, + 2, + ), + (7, 5, 1, 3, 3), + (7, 5, 1), + (7, 5, 1, 3, 4), + (7, 5, 1), + ), + ( + ( + 7, + 2, + ), (7, 2), - (7, 2), - (7, 5, 2, 2), + ( + 7, + 5, + 2, + 2, + ), (7, 5, 2, 3, 3), (7, 5, 2), (7, 5, 2, 3, 4), @@ -403,7 +518,13 @@ def test_switching_linear_hmm_log_prob_alternating(exact, num_steps, num_compone -1, num_components, -1, -1 ) - trans_mvn = random_mvn((num_steps, num_components), hidden_dim) + trans_mvn = random_mvn( + ( + num_steps, + num_components, + ), + hidden_dim, + ) hmm_obs_matrix = torch.randn(num_steps, hidden_dim, obs_dim) switching_obs_matrix = hmm_obs_matrix.unsqueeze(-3).expand( -1, num_components, -1, -1 diff --git a/test/test_adjoint.py b/test/test_adjoint.py index 4756479eb..b098c1c95 100644 --- a/test/test_adjoint.py +++ b/test/test_adjoint.py @@ -201,7 +201,12 @@ def test_optimized_plated_einsum_adjoint(equation, plates, backend): ids=lambda d: ",".join(d.keys()), ) @pytest.mark.parametrize( - "impl", [sequential_sum_product, naive_sequential_sum_product, MarkovProduct] + "impl", + [ + sequential_sum_product, + naive_sequential_sum_product, + MarkovProduct, + ], ) def test_sequential_sum_product_adjoint( impl, sum_op, prod_op, batch_inputs, state_domain, num_steps diff --git a/test/test_distribution.py b/test/test_distribution.py index c73da8567..d0794a1eb 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -1459,7 +1459,10 @@ def test_power_transform(shape): @pytest.mark.parametrize("shape", [(10,), (4, 3)], ids=str) @pytest.mark.parametrize( "to_event", - [True, xfail_param(False, reason="bug in to_funsor(TransformedDistribution)")], + [ + True, + xfail_param(False, reason="bug in to_funsor(TransformedDistribution)"), + ], ) def test_haar_transform(shape, to_event): try: diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index db1c4c263..5ffa99bf6 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -186,7 +186,9 @@ def __hash__(self): # Chi2 DistTestCase( - 
"dist.Chi2(df=case.df)", (("df", f"rand({batch_shape})"),), funsor.Real + "dist.Chi2(df=case.df)", + (("df", f"rand({batch_shape})"),), + funsor.Real, ) # ContinuousBernoulli @@ -368,7 +370,9 @@ def __hash__(self): # Poisson DistTestCase( - "dist.Poisson(rate=case.rate)", (("rate", f"rand({batch_shape})"),), funsor.Real + "dist.Poisson(rate=case.rate)", + (("rate", f"rand({batch_shape})"),), + funsor.Real, ) # RelaxedBernoulli diff --git a/test/test_domains.py b/test/test_domains.py index 29bfc5cdd..d721ee03e 100644 --- a/test/test_domains.py +++ b/test/test_domains.py @@ -9,7 +9,15 @@ from funsor.domains import Bint, Real, Reals # noqa F401 -@pytest.mark.parametrize("expr", ["Bint[2]", "Real", "Reals[4]", "Reals[3, 2]"]) +@pytest.mark.parametrize( + "expr", + [ + "Bint[2]", + "Real", + "Reals[4]", + "Reals[3, 2]", + ], +) def test_pickle(expr): x = eval(expr) f = io.BytesIO() diff --git a/test/test_factory.py b/test/test_factory.py index e23f50e20..a9fe8b78c 100644 --- a/test/test_factory.py +++ b/test/test_factory.py @@ -19,7 +19,9 @@ def test_lambda_lambda(): @make_funsor def LambdaLambda( - i: Bound, j: Bound, x: Funsor + i: Bound, + j: Bound, + x: Funsor, ) -> Fresh[lambda i, j, x: Array[x.dtype, (i.size, j.size) + x.shape]]: assert i in x.inputs assert j in x.inputs @@ -49,7 +51,10 @@ def GetitemGetitem( def test_flatten(): @make_funsor def Flatten21( - x: Funsor, i: Bound, j: Bound, ij: Fresh[lambda i, j: Bint[i.size * j.size]] + x: Funsor, + i: Bound, + j: Bound, + ij: Fresh[lambda i, j: Bint[i.size * j.size]], ) -> Fresh[lambda x: x.dtype]: m = to_funsor(i, x.inputs.get(i, None)).output.size n = to_funsor(j, x.inputs.get(j, None)).output.size @@ -115,7 +120,9 @@ def Cat2( def test_normal(): @make_funsor def Normal( - loc: Funsor, scale: Funsor, value: Fresh[lambda loc: loc] + loc: Funsor, + scale: Funsor, + value: Fresh[lambda loc: loc], ) -> Fresh[Real]: return None @@ -140,7 +147,11 @@ def _(loc, scale, value): def test_matmul(): @make_funsor - def MatMul(x: Funsor, y: Funsor, i: Bound) -> Fresh[lambda x: x]: + def MatMul( + x: Funsor, + y: Funsor, + i: Bound, + ) -> Fresh[lambda x: x]: return (x * y).reduce(ops.add, i) x = random_tensor(OrderedDict(a=Bint[3], b=Bint[4])) @@ -171,7 +182,8 @@ def Scatter1( def test_value_dependence(): @make_funsor def Sum( - x: Funsor, dim: Value[int] + x: Funsor, + dim: Value[int], ) -> Fresh[lambda x, dim: Array[x.dtype, x.shape[:dim] + x.shape[dim + 1 :]]]: return None diff --git a/test/test_gaussian.py b/test/test_gaussian.py index f3c5ac636..d9e66af02 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -576,10 +576,21 @@ def test_reduce_logsumexp(int_inputs, real_inputs): ) -@pytest.mark.parametrize("int_inputs", [{}, {"i": Bint[2]}], ids=id_from_inputs) +@pytest.mark.parametrize( + "int_inputs", + [ + {}, + {"i": Bint[2]}, + ], + ids=id_from_inputs, +) @pytest.mark.parametrize( "real_inputs", - [{"x": Real}, {"x": Reals[4]}, {"x": Reals[2, 3]}], + [ + {"x": Real}, + {"x": Reals[4]}, + {"x": Reals[2, 3]}, + ], ids=id_from_inputs, ) def test_integrate_variable(int_inputs, real_inputs): diff --git a/test/test_memoize.py b/test/test_memoize.py index e54b18cb2..14b11b3aa 100644 --- a/test/test_memoize.py +++ b/test/test_memoize.py @@ -169,10 +169,10 @@ def test_nested_einsum_complete_sharing( eqn1, eqn2, einsum_impl1, einsum_impl2, backend1, backend2 ): - (inputs1, outputs1, sizes1, operands1, funsor_operands1) = make_einsum_example( + inputs1, outputs1, sizes1, operands1, funsor_operands1 = make_einsum_example( eqn1, sizes=(3,) ) 
- (inputs2, outputs2, sizes2, operands2, funsor_operands2) = make_einsum_example( + inputs2, outputs2, sizes2, operands2, funsor_operands2 = make_einsum_example( eqn2, sizes=(3,) ) diff --git a/test/test_minipyro.py b/test/test_minipyro.py index b9c1eb937..5224ab25c 100644 --- a/test/test_minipyro.py +++ b/test/test_minipyro.py @@ -36,9 +36,8 @@ def Vindex(x): def _check_loss_and_grads(expected_loss, actual_loss, atol=1e-4, rtol=1e-4): # copied from pyro - expected_loss, actual_loss = ( - funsor.to_data(expected_loss), - funsor.to_data(actual_loss), + expected_loss, actual_loss = funsor.to_data(expected_loss), funsor.to_data( + actual_loss ) assert ops.allclose(actual_loss, expected_loss, atol=atol, rtol=rtol) names = pyro.get_param_store().keys() @@ -302,7 +301,11 @@ def guide(): @pytest.mark.parametrize( - "backend", ["pyro", xfail_param("funsor", reason="missing patterns")] + "backend", + [ + "pyro", + xfail_param("funsor", reason="missing patterns"), + ], ) def test_mean_field_ok(backend): def model(): @@ -320,7 +323,11 @@ def guide(): @pytest.mark.parametrize( - "backend", ["pyro", xfail_param("funsor", reason="missing patterns")] + "backend", + [ + "pyro", + xfail_param("funsor", reason="missing patterns"), + ], ) def test_mean_field_warn(backend): def model(): diff --git a/test/test_optimizer.py b/test/test_optimizer.py index 7c5399622..ee16c9d75 100644 --- a/test/test_optimizer.py +++ b/test/test_optimizer.py @@ -45,9 +45,19 @@ @pytest.mark.parametrize("equation", OPTIMIZED_EINSUM_EXAMPLES) @pytest.mark.parametrize( - "backend", ["pyro.ops.einsum.torch_log", "pyro.ops.einsum.torch_map"] + "backend", + [ + "pyro.ops.einsum.torch_log", + "pyro.ops.einsum.torch_map", + ], +) +@pytest.mark.parametrize( + "einsum_impl", + [ + naive_einsum, + naive_contract_einsum, + ], ) -@pytest.mark.parametrize("einsum_impl", [naive_einsum, naive_contract_einsum]) def test_optimized_einsum(equation, backend, einsum_impl): inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation) expected = pyro_einsum(equation, *operands, backend=backend)[0] @@ -69,7 +79,11 @@ def test_optimized_einsum(equation, backend, einsum_impl): @pytest.mark.parametrize( - "eqn1,eqn2", [("a,ab->b", "bc->"), ("ab,bc,cd->d", "de,ef,fg->")] + "eqn1,eqn2", + [ + ("a,ab->b", "bc->"), + ("ab,bc,cd->d", "de,ef,fg->"), + ], ) @pytest.mark.parametrize("optimize1", [False, True]) @pytest.mark.parametrize("optimize2", [False, True]) @@ -84,7 +98,7 @@ def test_nested_einsum( eqn1, eqn2, optimize1, optimize2, backend1, backend2, einsum_impl ): inputs1, outputs1, sizes1, operands1, _ = make_einsum_example(eqn1, sizes=(3,)) - (inputs2, outputs2, sizes2, operands2, funsor_operands2) = make_einsum_example( + inputs2, outputs2, sizes2, operands2, funsor_operands2 = make_einsum_example( eqn2, sizes=(3,) ) @@ -137,7 +151,11 @@ def test_nested_einsum( @pytest.mark.parametrize("equation,plates", PLATED_EINSUM_EXAMPLES) @pytest.mark.parametrize( - "backend", ["pyro.ops.einsum.torch_log", "pyro.ops.einsum.torch_map"] + "backend", + [ + "pyro.ops.einsum.torch_log", + "pyro.ops.einsum.torch_map", + ], ) def test_optimized_plated_einsum(equation, plates, backend): inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation) diff --git a/test/test_samplers.py b/test/test_samplers.py index d30aca38e..aafc29a21 100644 --- a/test/test_samplers.py +++ b/test/test_samplers.py @@ -37,17 +37,28 @@ @pytest.mark.parametrize( "sample_inputs", - [(), (("s", Bint[6]),), (("s", Bint[6]), ("t", Bint[7]))], + [ + (), + (("s", 
Bint[6]),), + (("s", Bint[6]), ("t", Bint[7])), + ], ids=id_from_inputs, ) @pytest.mark.parametrize( "batch_inputs", - [(), (("b", Bint[4]),), (("b", Bint[4]), ("c", Bint[5]))], + [ + (), + (("b", Bint[4]),), + (("b", Bint[4]), ("c", Bint[5])), + ], ids=id_from_inputs, ) @pytest.mark.parametrize( "event_inputs", - [(("e", Bint[2]),), (("e", Bint[2]), ("f", Bint[3]))], + [ + (("e", Bint[2]),), + (("e", Bint[2]), ("f", Bint[3])), + ], ids=id_from_inputs, ) def test_tensor_shape(sample_inputs, batch_inputs, event_inputs): @@ -81,16 +92,30 @@ def test_tensor_shape(sample_inputs, batch_inputs, event_inputs): @pytest.mark.parametrize( "sample_inputs", - [(), (("s", Bint[3]),), (("s", Bint[3]), ("t", Bint[4]))], + [ + (), + (("s", Bint[3]),), + (("s", Bint[3]), ("t", Bint[4])), + ], ids=id_from_inputs, ) @pytest.mark.parametrize( "batch_inputs", - [(), (("b", Bint[2]),), (("c", Real),), (("b", Bint[2]), ("c", Real))], + [ + (), + (("b", Bint[2]),), + (("c", Real),), + (("b", Bint[2]), ("c", Real)), + ], ids=id_from_inputs, ) @pytest.mark.parametrize( - "event_inputs", [(("e", Real),), (("e", Real), ("f", Reals[2]))], ids=id_from_inputs + "event_inputs", + [ + (("e", Real),), + (("e", Real), ("f", Reals[2])), + ], + ids=id_from_inputs, ) def test_gaussian_shape(sample_inputs, batch_inputs, event_inputs): be_inputs = OrderedDict(batch_inputs + event_inputs) @@ -130,16 +155,30 @@ def test_gaussian_shape(sample_inputs, batch_inputs, event_inputs): @pytest.mark.parametrize( "sample_inputs", - [(), (("s", Bint[3]),), (("s", Bint[3]), ("t", Bint[4]))], + [ + (), + (("s", Bint[3]),), + (("s", Bint[3]), ("t", Bint[4])), + ], ids=id_from_inputs, ) @pytest.mark.parametrize( "batch_inputs", - [(), (("b", Bint[2]),), (("c", Real),), (("b", Bint[2]), ("c", Real))], + [ + (), + (("b", Bint[2]),), + (("c", Real),), + (("b", Bint[2]), ("c", Real)), + ], ids=id_from_inputs, ) @pytest.mark.parametrize( - "event_inputs", [(("e", Real),), (("e", Real), ("f", Reals[2]))], ids=id_from_inputs + "event_inputs", + [ + (("e", Real),), + (("e", Real), ("f", Reals[2])), + ], + ids=id_from_inputs, ) def test_transformed_gaussian_shape(sample_inputs, batch_inputs, event_inputs): be_inputs = OrderedDict(batch_inputs + event_inputs) @@ -187,17 +226,28 @@ def test_transformed_gaussian_shape(sample_inputs, batch_inputs, event_inputs): @pytest.mark.parametrize( "sample_inputs", - [(), (("s", Bint[6]),), (("s", Bint[6]), ("t", Bint[7]))], + [ + (), + (("s", Bint[6]),), + (("s", Bint[6]), ("t", Bint[7])), + ], ids=id_from_inputs, ) @pytest.mark.parametrize( "int_event_inputs", - [(), (("d", Bint[2]),), (("d", Bint[2]), ("e", Bint[3]))], + [ + (), + (("d", Bint[2]),), + (("d", Bint[2]), ("e", Bint[3])), + ], ids=id_from_inputs, ) @pytest.mark.parametrize( "real_event_inputs", - [(("g", Real),), (("g", Real), ("h", Reals[4]))], + [ + (("g", Real),), + (("g", Real), ("h", Reals[4])), + ], ids=id_from_inputs, ) def test_joint_shape(sample_inputs, int_event_inputs, real_event_inputs): @@ -239,12 +289,19 @@ def test_joint_shape(sample_inputs, int_event_inputs, real_event_inputs): @pytest.mark.parametrize( "batch_inputs", - [(), (("b", Bint[4]),), (("b", Bint[2]), ("c", Bint[2]))], + [ + (), + (("b", Bint[4]),), + (("b", Bint[2]), ("c", Bint[2])), + ], ids=id_from_inputs, ) @pytest.mark.parametrize( "event_inputs", - [(("e", Bint[3]),), (("e", Bint[2]), ("f", Bint[2]))], + [ + (("e", Bint[3]),), + (("e", Bint[2]), ("f", Bint[2])), + ], ids=id_from_inputs, ) @pytest.mark.parametrize("test_grad", [False, True], ids=["value", "grad"]) @@ 
-267,7 +324,7 @@ def diff_fn(p_data): _, (p_data, mq_data) = align_tensors(p, mq) assert p_data.shape == mq_data.shape - return ((ops.exp(mq_data) * probe).sum() - (ops.exp(p_data) * probe).sum(), mq) + return (ops.exp(mq_data) * probe).sum() - (ops.exp(p_data) * probe).sum(), mq if test_grad: if get_backend() == "jax": @@ -290,11 +347,20 @@ def diff_fn(p_data): @pytest.mark.parametrize( "batch_inputs", - [(), (("b", Bint[3]),), (("b", Bint[3]), ("c", Bint[4]))], + [ + (), + (("b", Bint[3]),), + (("b", Bint[3]), ("c", Bint[4])), + ], ids=id_from_inputs, ) @pytest.mark.parametrize( - "event_inputs", [(("e", Real),), (("e", Real), ("f", Reals[2]))], ids=id_from_inputs + "event_inputs", + [ + (("e", Real),), + (("e", Real), ("f", Reals[2])), + ], + ids=id_from_inputs, ) def test_gaussian_distribution(event_inputs, batch_inputs): num_samples = 100000 @@ -330,12 +396,19 @@ def test_gaussian_distribution(event_inputs, batch_inputs): @pytest.mark.parametrize( "batch_inputs", - [(), (("b", Bint[3]),), (("b", Bint[3]), ("c", Bint[2]))], + [ + (), + (("b", Bint[3]),), + (("b", Bint[3]), ("c", Bint[2])), + ], ids=id_from_inputs, ) @pytest.mark.parametrize( "event_inputs", - [(("e", Real), ("f", Bint[3])), (("e", Reals[2]), ("f", Bint[2]))], + [ + (("e", Real), ("f", Bint[3])), + (("e", Reals[2]), ("f", Bint[2])), + ], ids=id_from_inputs, ) def test_gaussian_mixture_distribution(batch_inputs, event_inputs): diff --git a/test/test_sum_product.py b/test/test_sum_product.py index ff2fb5ee8..8977cc22c 100644 --- a/test/test_sum_product.py +++ b/test/test_sum_product.py @@ -100,7 +100,13 @@ def test_partition(inputs, dims, expected_num_components): ("abcij", ""), ], ) -@pytest.mark.parametrize("impl", [partial_sum_product, modified_partial_sum_product]) +@pytest.mark.parametrize( + "impl", + [ + partial_sum_product, + modified_partial_sum_product, + ], +) def test_partial_sum_product(impl, sum_op, prod_op, inputs, plates, vars1, vars2): inputs = inputs.split(",") factors = [random_tensor(OrderedDict((d, Bint[2]) for d in ds)) for ds in inputs] @@ -138,7 +144,14 @@ def test_partial_sum_product(impl, sum_op, prod_op, inputs, plates, vars1, vars2 (frozenset({"time", "x_0", "x_prev", "x_curr"}), frozenset()), ], ) -@pytest.mark.parametrize("x_dim,time", [(3, 1), (1, 5), (3, 5)]) +@pytest.mark.parametrize( + "x_dim,time", + [ + (3, 1), + (1, 5), + (3, 5), + ], +) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] ) @@ -146,10 +159,22 @@ def test_modified_partial_sum_product_0(sum_op, prod_op, vars1, vars2, x_dim, ti f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) + f2 = random_tensor( + OrderedDict( + { + "x_0": Bint[x_dim], + } + ) + ) f3 = random_tensor( - OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim]}) + OrderedDict( + { + "time": Bint[time], + "x_prev": Bint[x_dim], + "x_curr": Bint[x_dim], + } + ) ) factors = [f1, f2, f3] @@ -182,7 +207,13 @@ def test_modified_partial_sum_product_0(sum_op, prod_op, vars1, vars2, x_dim, ti ], ) @pytest.mark.parametrize( - "x_dim,y_dim,time", [(2, 3, 5), (1, 3, 5), (2, 1, 5), (2, 3, 1)] + "x_dim,y_dim,time", + [ + (2, 3, 5), + (1, 3, 5), + (2, 1, 5), + (2, 3, 1), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -193,16 +224,41 @@ def test_modified_partial_sum_product_1( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) + f2 = random_tensor( + OrderedDict( + { + "x_0": 
Bint[x_dim], + } + ) + ) f3 = random_tensor( - OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim]}) + OrderedDict( + { + "time": Bint[time], + "x_prev": Bint[x_dim], + "x_curr": Bint[x_dim], + } + ) ) - f4 = random_tensor(OrderedDict({"x_0": Bint[x_dim], "y_0": Bint[y_dim]})) + f4 = random_tensor( + OrderedDict( + { + "x_0": Bint[x_dim], + "y_0": Bint[y_dim], + } + ) + ) f5 = random_tensor( - OrderedDict({"time": Bint[time], "x_curr": Bint[x_dim], "y_curr": Bint[y_dim]}) + OrderedDict( + { + "time": Bint[time], + "x_curr": Bint[x_dim], + "y_curr": Bint[y_dim], + } + ) ) factors = [f1, f2, f3, f4, f5] @@ -240,7 +296,13 @@ def test_modified_partial_sum_product_1( ], ) @pytest.mark.parametrize( - "x_dim,y_dim,time", [(2, 3, 5), (1, 3, 5), (2, 1, 5), (2, 3, 1)] + "x_dim,y_dim,time", + [ + (2, 3, 5), + (1, 3, 5), + (2, 1, 5), + (2, 3, 1), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -251,16 +313,40 @@ def test_modified_partial_sum_product_2( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) + f2 = random_tensor( + OrderedDict( + { + "x_0": Bint[x_dim], + } + ) + ) f3 = random_tensor( - OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim]}) + OrderedDict( + { + "time": Bint[time], + "x_prev": Bint[x_dim], + "x_curr": Bint[x_dim], + } + ) ) - f4 = random_tensor(OrderedDict({"y_0": Bint[y_dim]})) + f4 = random_tensor( + OrderedDict( + { + "y_0": Bint[y_dim], + } + ) + ) f5 = random_tensor( - OrderedDict({"time": Bint[time], "y_prev": Bint[y_dim], "y_curr": Bint[y_dim]}) + OrderedDict( + { + "time": Bint[time], + "y_prev": Bint[y_dim], + "y_curr": Bint[y_dim], + } + ) ) factors = [f1, f2, f3, f4, f5] @@ -300,7 +386,13 @@ def test_modified_partial_sum_product_2( ], ) @pytest.mark.parametrize( - "x_dim,y_dim,time", [(2, 3, 5), (1, 3, 5), (2, 1, 5), (2, 3, 1)] + "x_dim,y_dim,time", + [ + (2, 3, 5), + (1, 3, 5), + (2, 1, 5), + (2, 3, 1), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -311,13 +403,32 @@ def test_modified_partial_sum_product_3( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) + f2 = random_tensor( + OrderedDict( + { + "x_0": Bint[x_dim], + } + ) + ) f3 = random_tensor( - OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim]}) + OrderedDict( + { + "time": Bint[time], + "x_prev": Bint[x_dim], + "x_curr": Bint[x_dim], + } + ) ) - f4 = random_tensor(OrderedDict({"x_0": Bint[x_dim], "y_0": Bint[y_dim]})) + f4 = random_tensor( + OrderedDict( + { + "x_0": Bint[x_dim], + "y_0": Bint[y_dim], + } + ) + ) f5 = random_tensor( OrderedDict( @@ -398,7 +509,12 @@ def test_modified_partial_sum_product_3( ) @pytest.mark.parametrize( "x_dim,y_dim,sequences,time,tones", - [(2, 3, 2, 5, 4), (1, 3, 2, 5, 4), (2, 1, 2, 5, 4), (2, 3, 2, 1, 4)], + [ + (2, 3, 2, 5, 4), + (1, 3, 2, 5, 4), + (2, 1, 2, 5, 4), + (2, 3, 2, 1, 4), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -409,7 +525,14 @@ def test_modified_partial_sum_product_4( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim]})) + f2 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "x_0": Bint[x_dim], + } + ) + ) f3 = random_tensor( OrderedDict( @@ -424,7 +547,11 @@ def test_modified_partial_sum_product_4( f4 = random_tensor( OrderedDict( - {"sequences": 
Bint[sequences], "tones": Bint[tones], "y_0": Bint[y_dim]} + { + "sequences": Bint[sequences], + "tones": Bint[tones], + "y_0": Bint[y_dim], + } ) ) @@ -530,7 +657,12 @@ def test_modified_partial_sum_product_4( ) @pytest.mark.parametrize( "x_dim,y_dim,sequences,days,weeks,tones", - [(2, 3, 2, 5, 4, 3), (1, 3, 2, 5, 4, 3), (2, 1, 2, 5, 4, 3), (2, 3, 2, 1, 4, 3)], + [ + (2, 3, 2, 5, 4, 3), + (1, 3, 2, 5, 4, 3), + (2, 1, 2, 5, 4, 3), + (2, 3, 2, 1, 4, 3), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -543,7 +675,11 @@ def test_modified_partial_sum_product_5( f2 = random_tensor( OrderedDict( - {"sequences": Bint[sequences], "tones": Bint[tones], "x_0": Bint[x_dim]} + { + "sequences": Bint[sequences], + "tones": Bint[tones], + "x_0": Bint[x_dim], + } ) ) @@ -559,7 +695,14 @@ def test_modified_partial_sum_product_5( ) ) - f4 = random_tensor(OrderedDict({"sequences": Bint[sequences], "y_0": Bint[y_dim]})) + f4 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "y_0": Bint[y_dim], + } + ) + ) f5 = random_tensor( OrderedDict( @@ -643,7 +786,12 @@ def test_modified_partial_sum_product_5( ) @pytest.mark.parametrize( "x_dim,y_dim,sequences,time,tones", - [(2, 3, 2, 5, 4), (1, 3, 2, 5, 4), (2, 1, 2, 5, 4), (2, 3, 2, 1, 4)], + [ + (2, 3, 2, 5, 4), + (1, 3, 2, 5, 4), + (2, 1, 2, 5, 4), + (2, 3, 2, 1, 4), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -654,7 +802,14 @@ def test_modified_partial_sum_product_6( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim]})) + f2 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "x_0": Bint[x_dim], + } + ) + ) f3 = random_tensor( OrderedDict( @@ -760,7 +915,12 @@ def test_modified_partial_sum_product_6( ) @pytest.mark.parametrize( "x_dim,y_dim,sequences,time,tones", - [(2, 3, 2, 5, 4), (1, 3, 2, 5, 4), (2, 1, 2, 5, 4), (2, 3, 2, 1, 4)], + [ + (2, 3, 2, 5, 4), + (1, 3, 2, 5, 4), + (2, 1, 2, 5, 4), + (2, 3, 2, 1, 4), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -771,7 +931,14 @@ def test_modified_partial_sum_product_7( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim]})) + f2 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "x_0": Bint[x_dim], + } + ) + ) f3 = random_tensor( OrderedDict( @@ -811,7 +978,12 @@ def test_modified_partial_sum_product_7( factors = [f1, f2, f3, f4, f5] plate_to_step = { "sequences": {}, - "time": frozenset({("x_0", "x_prev", "x_curr"), ("y_0", "y_prev", "y_curr")}), + "time": frozenset( + { + ("x_0", "x_prev", "x_curr"), + ("y_0", "y_prev", "y_curr"), + } + ), "tones": {}, } @@ -900,7 +1072,12 @@ def test_modified_partial_sum_product_7( ) @pytest.mark.parametrize( "w_dim,x_dim,y_dim,sequences,time,tones", - [(3, 2, 3, 2, 5, 4), (3, 1, 3, 2, 5, 4), (3, 2, 1, 2, 5, 4), (3, 2, 3, 2, 1, 4)], + [ + (3, 2, 3, 2, 5, 4), + (3, 1, 3, 2, 5, 4), + (3, 2, 1, 2, 5, 4), + (3, 2, 3, 2, 1, 4), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -911,7 +1088,14 @@ def test_modified_partial_sum_product_8( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim]})) + f2 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "w_0": Bint[w_dim], + } + ) + ) f3 = random_tensor( 
OrderedDict( @@ -924,7 +1108,14 @@ def test_modified_partial_sum_product_8( ) ) - f4 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim]})) + f4 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "x_0": Bint[x_dim], + } + ) + ) f5 = random_tensor( OrderedDict( @@ -965,7 +1156,12 @@ def test_modified_partial_sum_product_8( factors = [f1, f2, f3, f4, f5, f6, f7] plate_to_step = { "sequences": {}, - "time": frozenset({("x_0", "x_prev", "x_curr"), ("w_0", "w_prev", "w_curr")}), + "time": frozenset( + { + ("x_0", "x_prev", "x_curr"), + ("w_0", "w_prev", "w_curr"), + } + ), "tones": {}, } @@ -1063,7 +1259,12 @@ def test_modified_partial_sum_product_8( ) @pytest.mark.parametrize( "w_dim,x_dim,y_dim,sequences,time,tones", - [(3, 2, 3, 2, 5, 4), (3, 1, 3, 2, 5, 4), (3, 2, 1, 2, 5, 4), (3, 2, 3, 2, 1, 4)], + [ + (3, 2, 3, 2, 5, 4), + (3, 1, 3, 2, 5, 4), + (3, 2, 1, 2, 5, 4), + (3, 2, 3, 2, 1, 4), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -1074,7 +1275,14 @@ def test_modified_partial_sum_product_9( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim]})) + f2 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "w_0": Bint[w_dim], + } + ) + ) f3 = random_tensor( OrderedDict( @@ -1089,7 +1297,11 @@ def test_modified_partial_sum_product_9( f4 = random_tensor( OrderedDict( - {"sequences": Bint[sequences], "w_0": Bint[w_dim], "x_0": Bint[x_dim]} + { + "sequences": Bint[sequences], + "w_0": Bint[w_dim], + "x_0": Bint[x_dim], + } ) ) @@ -1133,7 +1345,12 @@ def test_modified_partial_sum_product_9( factors = [f1, f2, f3, f4, f5, f6, f7] plate_to_step = { "sequences": {}, - "time": frozenset({("x_0", "x_prev", "x_curr"), ("w_0", "w_prev", "w_curr")}), + "time": frozenset( + { + ("x_0", "x_prev", "x_curr"), + ("w_0", "w_prev", "w_curr"), + } + ), "tones": {}, } @@ -1220,7 +1437,12 @@ def test_modified_partial_sum_product_9( ) @pytest.mark.parametrize( "w_dim,x_dim,y_dim,sequences,time,tones", - [(3, 2, 3, 2, 5, 4), (3, 1, 3, 2, 5, 4), (3, 2, 1, 2, 5, 4), (3, 2, 3, 2, 1, 4)], + [ + (3, 2, 3, 2, 5, 4), + (3, 1, 3, 2, 5, 4), + (3, 2, 1, 2, 5, 4), + (3, 2, 3, 2, 1, 4), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -1231,17 +1453,32 @@ def test_modified_partial_sum_product_10( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim]})) + f2 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "w_0": Bint[w_dim], + } + ) + ) f3 = random_tensor( OrderedDict( - {"sequences": Bint[sequences], "time": Bint[time], "w_curr": Bint[w_dim]} + { + "sequences": Bint[sequences], + "time": Bint[time], + "w_curr": Bint[w_dim], + } ) ) f4 = random_tensor( OrderedDict( - {"sequences": Bint[sequences], "w_0": Bint[w_dim], "x_0": Bint[x_dim]} + { + "sequences": Bint[sequences], + "w_0": Bint[w_dim], + "x_0": Bint[x_dim], + } ) ) @@ -1417,13 +1654,30 @@ def test_modified_partial_sum_product_11( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"a": Bint[a_dim]})) + f2 = random_tensor( + OrderedDict( + { + "a": Bint[a_dim], + } + ) + ) - f3 = random_tensor(OrderedDict({"sequences": Bint[sequences], "b": Bint[b_dim]})) + f3 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "b": Bint[b_dim], + } + ) + ) f4 = random_tensor( OrderedDict( - {"a": Bint[a_dim], "sequences": 
Bint[sequences], "w_0": Bint[w_dim]} + { + "a": Bint[a_dim], + "sequences": Bint[sequences], + "w_0": Bint[w_dim], + } ) ) @@ -1575,7 +1829,12 @@ def test_modified_partial_sum_product_11( ) @pytest.mark.parametrize( "w_dim,x_dim,y_dim,sequences,time,tones", - [(3, 2, 3, 2, 5, 4), (3, 1, 3, 2, 5, 4), (3, 2, 1, 2, 5, 4), (3, 2, 3, 2, 1, 4)], + [ + (3, 2, 3, 2, 5, 4), + (3, 1, 3, 2, 5, 4), + (3, 2, 1, 2, 5, 4), + (3, 2, 3, 2, 1, 4), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -1586,11 +1845,22 @@ def test_modified_partial_sum_product_12( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim]})) + f2 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "w_0": Bint[w_dim], + } + ) + ) f3 = random_tensor( OrderedDict( - {"sequences": Bint[sequences], "time": Bint[time], "w_curr": Bint[w_dim]} + { + "sequences": Bint[sequences], + "time": Bint[time], + "w_curr": Bint[w_dim], + } ) ) @@ -1799,7 +2069,11 @@ def test_modified_partial_sum_product_13( f4 = random_tensor( OrderedDict( - {"w": Bint[w_dim], "sequences": Bint[sequences], "y_0": Bint[y_dim]} + { + "w": Bint[w_dim], + "sequences": Bint[sequences], + "y_0": Bint[y_dim], + } ) ) @@ -1920,7 +2194,12 @@ def test_modified_partial_sum_product_13( ) @pytest.mark.parametrize( "x_dim,y_dim,sequences,time,tones", - [(2, 3, 2, 3, 2), (1, 3, 2, 3, 2), (2, 1, 2, 3, 2), (2, 3, 2, 1, 2)], + [ + (2, 3, 2, 3, 2), + (1, 3, 2, 3, 2), + (2, 1, 2, 3, 2), + (2, 3, 2, 1, 2), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -1931,7 +2210,14 @@ def test_modified_partial_sum_product_14( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim]})) + f2 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "x_0": Bint[x_dim], + } + ) + ) f3 = random_tensor( OrderedDict( @@ -1946,7 +2232,11 @@ def test_modified_partial_sum_product_14( f4 = random_tensor( OrderedDict( - {"sequences": Bint[sequences], "x_0": Bint[x_dim], "y0_0": Bint[y_dim]} + { + "sequences": Bint[sequences], + "x_0": Bint[x_dim], + "y0_0": Bint[y_dim], + } ) ) @@ -1991,7 +2281,10 @@ def test_modified_partial_sum_product_14( "sequences": {}, "time": frozenset({("x_0", "x_prev", "x_curr")}), "tones": frozenset( - {("y0_0", "y0_prev", "y0_curr"), ("ycurr_0", "ycurr_prev", "ycurr_curr")} + { + ("y0_0", "y0_prev", "y0_curr"), + ("ycurr_0", "ycurr_prev", "ycurr_curr"), + } ), } @@ -2027,7 +2320,13 @@ def test_modified_partial_sum_product_14( ], ) @pytest.mark.parametrize( - "x_dim,y_dim,time", [(2, 3, 5), (1, 3, 5), (2, 1, 5), (2, 3, 1)] + "x_dim,y_dim,time", + [ + (2, 3, 5), + (1, 3, 5), + (2, 1, 5), + (2, 3, 1), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -2038,21 +2337,50 @@ def test_modified_partial_sum_product_16( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) + f2 = random_tensor( + OrderedDict( + { + "x_0": Bint[x_dim], + } + ) + ) f3 = random_tensor( - OrderedDict({"time": Bint[time], "y_prev": Bint[y_dim], "x_curr": Bint[x_dim]}) + OrderedDict( + { + "time": Bint[time], + "y_prev": Bint[y_dim], + "x_curr": Bint[x_dim], + } + ) ) - f4 = random_tensor(OrderedDict({"y_0": Bint[y_dim]})) + f4 = random_tensor( + OrderedDict( + { + "y_0": Bint[y_dim], + } + ) + ) f5 = random_tensor( - OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], 
"y_curr": Bint[y_dim]}) + OrderedDict( + { + "time": Bint[time], + "x_prev": Bint[x_dim], + "y_curr": Bint[y_dim], + } + ) ) factors = [f1, f2, f3, f4, f5] plate_to_step = { - "time": frozenset({("x_0", "x_prev", "x_curr"), ("y_0", "y_prev", "y_curr")}) + "time": frozenset( + { + ("x_0", "x_prev", "x_curr"), + ("y_0", "y_prev", "y_curr"), + } + ), } factors1 = modified_partial_sum_product( @@ -2122,7 +2450,13 @@ def test_modified_partial_sum_product_16( ], ) @pytest.mark.parametrize( - "x_dim,y_dim,z_dim,time", [(2, 3, 2, 5), (1, 3, 2, 5), (2, 1, 2, 5), (2, 3, 2, 1)] + "x_dim,y_dim,z_dim,time", + [ + (2, 3, 2, 5), + (1, 3, 2, 5), + (2, 1, 2, 5), + (2, 3, 2, 1), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -2133,10 +2467,22 @@ def test_modified_partial_sum_product_17( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) + f2 = random_tensor( + OrderedDict( + { + "x_0": Bint[x_dim], + } + ) + ) f3 = random_tensor( - OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim]}) + OrderedDict( + { + "time": Bint[time], + "x_prev": Bint[x_dim], + "x_curr": Bint[x_dim], + } + ) ) f4 = random_tensor( @@ -2186,7 +2532,13 @@ def test_modified_partial_sum_product_17( ) f8 = random_tensor( - OrderedDict({"x_0": Bint[x_dim], "y_0": Bint[y_dim], "z2_0": Bint[z_dim]}) + OrderedDict( + { + "x_0": Bint[x_dim], + "y_0": Bint[y_dim], + "z2_0": Bint[z_dim], + } + ) ) f9 = random_tensor( @@ -2201,7 +2553,9 @@ def test_modified_partial_sum_product_17( ) factors = [f1, f2, f3, f4, f5, f6, f7, f8, f9] - plate_to_step = {"time": frozenset({("x_0", "x_prev", "x_curr")})} + plate_to_step = { + "time": frozenset({("x_0", "x_prev", "x_curr")}), + } with (lazy if use_lazy else eager): factors1 = modified_partial_sum_product( @@ -2301,7 +2655,11 @@ def test_sequential_sum_product( ) @pytest.mark.parametrize( "x_domain,y_domain", - [(Bint[2], Bint[3]), (Real, Reals[2, 2]), (Bint[2], Reals[2])], + [ + (Bint[2], Bint[3]), + (Real, Reals[2, 2]), + (Bint[2], Reals[2]), + ], ids=str, ) @pytest.mark.parametrize( @@ -2370,15 +2728,29 @@ def test_sequential_sum_product_multi( @pytest.mark.parametrize("dim", [1, 2, 3]) def test_sequential_sum_product_bias_1(num_steps, dim): time = Variable("time", Bint[num_steps]) - bias_dist = random_gaussian(OrderedDict([("bias", Reals[dim])])) + bias_dist = random_gaussian( + OrderedDict( + [ + ("bias", Reals[dim]), + ] + ) + ) trans = random_gaussian( OrderedDict( - [("time", Bint[num_steps]), ("x_prev", Reals[dim]), ("x_curr", Reals[dim])] + [ + ("time", Bint[num_steps]), + ("x_prev", Reals[dim]), + ("x_curr", Reals[dim]), + ] ) ) obs = random_gaussian( OrderedDict( - [("time", Bint[num_steps]), ("x_curr", Reals[dim]), ("bias", Reals[dim])] + [ + ("time", Bint[num_steps]), + ("x_curr", Reals[dim]), + ("bias", Reals[dim]), + ] ) ) factor = trans + obs + bias_dist @@ -2397,15 +2769,29 @@ def test_sequential_sum_product_bias_1(num_steps, dim): def test_sequential_sum_product_bias_2(num_steps, num_sensors, dim): time = Variable("time", Bint[num_steps]) bias = Variable("bias", Reals[num_sensors, dim]) - bias_dist = random_gaussian(OrderedDict([("bias", Reals[num_sensors, dim])])) + bias_dist = random_gaussian( + OrderedDict( + [ + ("bias", Reals[num_sensors, dim]), + ] + ) + ) trans = random_gaussian( OrderedDict( - [("time", Bint[num_steps]), ("x_prev", Reals[dim]), ("x_curr", Reals[dim])] + [ + ("time", Bint[num_steps]), + ("x_prev", Reals[dim]), + ("x_curr", Reals[dim]), + ] ) ) 
obs = random_gaussian( OrderedDict( - [("time", Bint[num_steps]), ("x_curr", Reals[dim]), ("bias", Reals[dim])] + [ + ("time", Bint[num_steps]), + ("x_curr", Reals[dim]), + ("bias", Reals[dim]), + ] ) ) @@ -2451,9 +2837,18 @@ def _check_sarkka_bilmes(trans, expected_inputs, global_vars, num_periods=1): @pytest.mark.parametrize("duration", [2, 3, 4, 5, 6]) def test_sarkka_bilmes_example_0(duration): - trans = random_tensor(OrderedDict({"time": Bint[duration], "a": Bint[3]})) + trans = random_tensor( + OrderedDict( + { + "time": Bint[duration], + "a": Bint[3], + } + ) + ) - expected_inputs = {"a": Bint[3]} + expected_inputs = { + "a": Bint[3], + } _check_sarkka_bilmes(trans, expected_inputs, frozenset()) @@ -2463,11 +2858,20 @@ def test_sarkka_bilmes_example_1(duration): trans = random_tensor( OrderedDict( - {"time": Bint[duration], "a": Bint[3], "b": Bint[2], "_PREV_b": Bint[2]} + { + "time": Bint[duration], + "a": Bint[3], + "b": Bint[2], + "_PREV_b": Bint[2], + } ) ) - expected_inputs = {"a": Bint[3], "b": Bint[2], "_PREV_b": Bint[2]} + expected_inputs = { + "a": Bint[3], + "b": Bint[2], + "_PREV_b": Bint[2], + } _check_sarkka_bilmes(trans, expected_inputs, frozenset()) @@ -2553,11 +2957,20 @@ def test_sarkka_bilmes_example_5(duration): trans = random_tensor( OrderedDict( - {"time": Bint[duration], "a": Bint[3], "_PREV_a": Bint[3], "x": Bint[2]} + { + "time": Bint[duration], + "a": Bint[3], + "_PREV_a": Bint[3], + "x": Bint[2], + } ) ) - expected_inputs = {"a": Bint[3], "_PREV_a": Bint[3], "x": Bint[2]} + expected_inputs = { + "a": Bint[3], + "_PREV_a": Bint[3], + "x": Bint[2], + } global_vars = frozenset(["x"]) @@ -2593,7 +3006,13 @@ def test_sarkka_bilmes_example_6(duration): @pytest.mark.parametrize("time_input", [("time", Bint[t]) for t in range(6, 11)]) -@pytest.mark.parametrize("global_inputs", [(), (("x", Bint[2]),)]) +@pytest.mark.parametrize( + "global_inputs", + [ + (), + (("x", Bint[2]),), + ], +) @pytest.mark.parametrize( "local_inputs", [ diff --git a/test/test_tensor.py b/test/test_tensor.py index ff06ce771..2050fb105 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -126,7 +126,15 @@ def test_indexing(): def test_advanced_indexing_shape(): I, J, M, N = 4, 4, 2, 3 - x = Tensor(randn((I, J)), OrderedDict([("i", Bint[I]), ("j", Bint[J])])) + x = Tensor( + randn((I, J)), + OrderedDict( + [ + ("i", Bint[I]), + ("j", Bint[J]), + ] + ), + ) m = Tensor(numeric_array([2, 3]), OrderedDict([("m", Bint[M])]), I) n = Tensor(numeric_array([0, 1, 1]), OrderedDict([("n", Bint[N])]), J) assert x.data.shape == (I, J) @@ -223,17 +231,54 @@ def test_advanced_indexing_tensor(output_shape): # x output = Reals[output_shape] x = random_tensor( - OrderedDict([("i", Bint[2]), ("j", Bint[3]), ("k", Bint[4])]), output + OrderedDict( + [ + ("i", Bint[2]), + ("j", Bint[3]), + ("k", Bint[4]), + ] + ), + output, + ) + i = random_tensor( + OrderedDict( + [ + ("u", Bint[5]), + ] + ), + Bint[2], + ) + j = random_tensor( + OrderedDict( + [ + ("v", Bint[6]), + ("u", Bint[5]), + ] + ), + Bint[3], + ) + k = random_tensor( + OrderedDict( + [ + ("v", Bint[6]), + ] + ), + Bint[4], ) - i = random_tensor(OrderedDict([("u", Bint[5])]), Bint[2]) - j = random_tensor(OrderedDict([("v", Bint[6]), ("u", Bint[5])]), Bint[3]) - k = random_tensor(OrderedDict([("v", Bint[6])]), Bint[4]) expected_data = empty((5, 6) + output_shape) for u in range(5): for v in range(6): expected_data[u, v] = x.data[i.data[u], j.data[v, u], k.data[v]] - expected = Tensor(expected_data, OrderedDict([("u", Bint[5]), ("v", Bint[6])])) + 
expected = Tensor( + expected_data, + OrderedDict( + [ + ("u", Bint[5]), + ("v", Bint[6]), + ] + ), + ) assert_equiv(expected, x(i, j, k)) assert_equiv(expected, x(i=i, j=j, k=k)) @@ -258,7 +303,13 @@ def test_advanced_indexing_tensor(output_shape): def test_advanced_indexing_lazy(output_shape): x = Tensor( randn((2, 3, 4) + output_shape), - OrderedDict([("i", Bint[2]), ("j", Bint[3]), ("k", Bint[4])]), + OrderedDict( + [ + ("i", Bint[2]), + ("j", Bint[3]), + ("k", Bint[4]), + ] + ), ) u = Variable("u", Bint[2]) v = Variable("v", Bint[3]) @@ -274,7 +325,15 @@ def test_advanced_indexing_lazy(output_shape): for u in range(2): for v in range(3): expected_data[u, v] = x.data[i_data[u], j_data[v], k_data[u, v]] - expected = Tensor(expected_data, OrderedDict([("u", Bint[2]), ("v", Bint[3])])) + expected = Tensor( + expected_data, + OrderedDict( + [ + ("u", Bint[2]), + ("v", Bint[3]), + ] + ), + ) assert_equiv(expected, x(i, j, k)) assert_equiv(expected, x(i=i, j=j, k=k)) @@ -304,7 +363,18 @@ def unary_eval(symbol, x): @pytest.mark.parametrize("dims", [(), ("a",), ("a", "b")]) @pytest.mark.parametrize( "symbol", - ["~", "-", "abs", "atanh", "sqrt", "exp", "log", "log1p", "sigmoid", "tanh"], + [ + "~", + "-", + "abs", + "atanh", + "sqrt", + "exp", + "log", + "log1p", + "sigmoid", + "tanh", + ], ) def test_unary(symbol, dims): sizes = {"a": 3, "b": 4} @@ -837,7 +907,14 @@ def test_function_of_numeric_array(): def test_align(): x = Tensor( - randn((2, 3, 4)), OrderedDict([("i", Bint[2]), ("j", Bint[3]), ("k", Bint[4])]) + randn((2, 3, 4)), + OrderedDict( + [ + ("i", Bint[2]), + ("j", Bint[3]), + ("k", Bint[4]), + ] + ), ) y = x.align(("j", "k", "i")) assert isinstance(y, Tensor) @@ -950,13 +1027,41 @@ def test_tensor_stack(n, shape, dim): @pytest.mark.parametrize("output", [Bint[2], Real, Reals[4], Reals[2, 3]], ids=str) def test_funsor_stack(output): - x = random_tensor(OrderedDict([("i", Bint[2])]), output) - y = random_tensor(OrderedDict([("j", Bint[3])]), output) - z = random_tensor(OrderedDict([("i", Bint[2]), ("k", Bint[4])]), output) + x = random_tensor( + OrderedDict( + [ + ("i", Bint[2]), + ] + ), + output, + ) + y = random_tensor( + OrderedDict( + [ + ("j", Bint[3]), + ] + ), + output, + ) + z = random_tensor( + OrderedDict( + [ + ("i", Bint[2]), + ("k", Bint[4]), + ] + ), + output, + ) xy = Stack("t", (x, y)) assert isinstance(xy, Tensor) - assert xy.inputs == OrderedDict([("t", Bint[2]), ("i", Bint[2]), ("j", Bint[3])]) + assert xy.inputs == OrderedDict( + [ + ("t", Bint[2]), + ("i", Bint[2]), + ("j", Bint[3]), + ] + ) assert xy.output == output for j in range(3): assert_close(xy(t=0, j=j), x) @@ -966,7 +1071,12 @@ def test_funsor_stack(output): xyz = Stack("t", (x, y, z)) assert isinstance(xyz, Tensor) assert xyz.inputs == OrderedDict( - [("t", Bint[3]), ("i", Bint[2]), ("j", Bint[3]), ("k", Bint[4])] + [ + ("t", Bint[3]), + ("i", Bint[2]), + ("j", Bint[3]), + ("k", Bint[4]), + ] ) assert xy.output == output for j in range(3): @@ -981,9 +1091,32 @@ def test_funsor_stack(output): @pytest.mark.parametrize("output", [Bint[2], Real, Reals[4], Reals[2, 3]], ids=str) def test_cat_simple(output): - x = random_tensor(OrderedDict([("i", Bint[2])]), output) - y = random_tensor(OrderedDict([("i", Bint[3]), ("j", Bint[4])]), output) - z = random_tensor(OrderedDict([("i", Bint[5]), ("k", Bint[6])]), output) + x = random_tensor( + OrderedDict( + [ + ("i", Bint[2]), + ] + ), + output, + ) + y = random_tensor( + OrderedDict( + [ + ("i", Bint[3]), + ("j", Bint[4]), + ] + ), + output, + ) + z = 
random_tensor( + OrderedDict( + [ + ("i", Bint[5]), + ("k", Bint[6]), + ] + ), + output, + ) assert Cat("i", (x,)) is x assert Cat("i", (y,)) is y @@ -991,13 +1124,22 @@ def test_cat_simple(output): xy = Cat("i", (x, y)) assert isinstance(xy, Tensor) - assert xy.inputs == OrderedDict([("i", Bint[2 + 3]), ("j", Bint[4])]) + assert xy.inputs == OrderedDict( + [ + ("i", Bint[2 + 3]), + ("j", Bint[4]), + ] + ) assert xy.output == output xyz = Cat("i", (x, y, z)) assert isinstance(xyz, Tensor) assert xyz.inputs == OrderedDict( - [("i", Bint[2 + 3 + 5]), ("j", Bint[4]), ("k", Bint[6])] + [ + ("i", Bint[2 + 3 + 5]), + ("j", Bint[4]), + ("k", Bint[6]), + ] ) assert xy.output == output diff --git a/test/test_terms.py b/test/test_terms.py index ab33c2566..daa5f49a5 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -260,7 +260,18 @@ def unary_eval(symbol, x): @pytest.mark.parametrize("data", [0, 0.5, 1]) @pytest.mark.parametrize( "symbol", - ["~", "-", "atanh", "abs", "sqrt", "exp", "log", "log1p", "sigmoid", "tanh"], + [ + "~", + "-", + "atanh", + "abs", + "sqrt", + "exp", + "log", + "log1p", + "sigmoid", + "tanh", + ], ) def test_unary(symbol, data): dtype = "real" @@ -276,7 +287,21 @@ def test_unary(symbol, data): check_funsor(actual, {}, Array[dtype, ()], expected_data) -BINARY_OPS = ["+", "-", "*", "/", "**", "==", "!=", "<", "<=", ">", ">=", "min", "max"] +BINARY_OPS = [ + "+", + "-", + "*", + "/", + "**", + "==", + "!=", + "<", + "<=", + ">", + ">=", + "min", + "max", +] BOOLEAN_OPS = ["&", "|", "^"] From 503d0eabed92e0c1bf448ebe52a2756919ec68bc Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sun, 21 Feb 2021 11:06:14 -0600 Subject: [PATCH 5/8] fix some typos --- .gitignore | 2 ++ docs/source/index.rst | 1 - examples/eeg_slds.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index f57e109d9..096b0504a 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,8 @@ run_outputs* data .data results +docs/source/examples/ +docs/source/tutorials/ examples/*/processed examples/*/results examples/*/raw diff --git a/docs/source/index.rst b/docs/source/index.rst index 9401cc1c5..49a13b25e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -38,7 +38,6 @@ Funsor is a tensor-like library for functions and distributions :name: tutorials-and-examples tutorials/sum_product_network - examples/mixed_hmm/* examples/discrete_hmm examples/eeg_slds examples/kalman_filter diff --git a/examples/eeg_slds.py b/examples/eeg_slds.py index fa31a5059..05f20b0c5 100644 --- a/examples/eeg_slds.py +++ b/examples/eeg_slds.py @@ -3,7 +3,7 @@ """ Example: Switching Linear Dynamical System EEG -================================================= +============================================== We use a switching linear dynamical system [1] to model a EEG time series dataset. 
For inference we use a moment-matching approximation enabled by @@ -155,7 +155,7 @@ def get_tensors_and_dists(self): self.observation_matrix, obs_mvn, event_dims, "x", "y" ) - return (trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist) + return trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist # compute the marginal log probability of the observed data using a moment-matching approximation @funsor.interpretation(funsor.terms.moment_matching) From 56a1b4a0f52cf57606fc58fb5a295751a428a7b6 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sun, 21 Feb 2021 18:01:42 -0600 Subject: [PATCH 6/8] add a note in conf file --- docs/source/_templates/breadcrumbs.html | 2 +- docs/source/conf.py | 22 ++++++++++- docs/source/index.rst | 2 +- examples/README.rst | 5 ++- funsor/__init__.py | 3 ++ tutorials/README.md | 1 + tutorials/sum_product_network.ipynb | 51 ++++++++++--------------- 7 files changed, 52 insertions(+), 34 deletions(-) create mode 100644 tutorials/README.md diff --git a/docs/source/_templates/breadcrumbs.html b/docs/source/_templates/breadcrumbs.html index 5db8caf4d..c49d12ad3 100644 --- a/docs/source/_templates/breadcrumbs.html +++ b/docs/source/_templates/breadcrumbs.html @@ -25,4 +25,4 @@ {% endif %} {% endif %} -{% endblock %} \ No newline at end of file +{% endblock %} diff --git a/docs/source/conf.py b/docs/source/conf.py index d14e5865d..5c68f052d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -34,8 +34,16 @@ # The short X.Y version version = u"0.0" + +if 'READTHEDOCS' not in os.environ: + # if developing locally, use funsor.__version__ as version + from funsor import __version__ # noqaE402 + version = __version__ + + html_context = {'github_version': 'master'} + # The full version, including alpha/beta/rc tags -release = u"0.0" +release = version # -- General configuration --------------------------------------------------- @@ -115,7 +123,9 @@ nbsphinx_prolog = r""" {% set docname = 'tutorials/' + env.doc2path(env.docname, base=None).split('/')[-1] %} :github_url: https://github.com/pyro-ppl/funsor/blob/master/{{ docname }} + .. raw:: html +
    Interactive online version: @@ -129,6 +139,16 @@ # -- Copy notebook files +# NB: tutorials and examples can be added to `index.rst` file using the paths +# tutorials/foo +# examples/foo +# without extensions .ipynb or .py +# TODO: find a solution for an example subfolder, e.g. examples/mixed_hmm folder +# +# To add thumbnail images for tutorials/examples in funsor docs, using +# .. nbgallery:: instead of .. toctree:: and add png thumnail images +# with corresponding names in _static/img/tutorials or _static/img/examples folders. +# For example, we can add minipyro.png to _static/img/examples/ folder. if not os.path.exists("tutorials"): os.makedirs("tutorials") diff --git a/docs/source/index.rst b/docs/source/index.rst index 49a13b25e..ce842d995 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -32,7 +32,7 @@ Funsor is a tensor-like library for functions and distributions minipyro einsum -.. nbgallery:: +.. toctree:: :maxdepth: 1 :caption: Tutorials and Examples :name: tutorials-and-examples diff --git a/examples/README.rst b/examples/README.rst index 5354631a5..ab2ec2739 100644 --- a/examples/README.rst +++ b/examples/README.rst @@ -1,2 +1,5 @@ Code Examples -============= \ No newline at end of file +============= + +Please check out `Sphinx-Gallery syntax `_ +for how to structure Python scripts to generate nicely rendered example pages. diff --git a/funsor/__init__.py b/funsor/__init__.py index 4372f1bad..ec70808cc 100644 --- a/funsor/__init__.py +++ b/funsor/__init__.py @@ -43,7 +43,10 @@ testing, ) +__version__ = "0.4.0" + __all__ = [ + "__version__", "Array", "Bint", "Cat", diff --git a/tutorials/README.md b/tutorials/README.md new file mode 100644 index 000000000..a415f9b39 --- /dev/null +++ b/tutorials/README.md @@ -0,0 +1 @@ +# Notebook tutorials diff --git a/tutorials/sum_product_network.ipynb b/tutorials/sum_product_network.ipynb index 65eed7901..b99aa47f0 100644 --- a/tutorials/sum_product_network.ipynb +++ b/tutorials/sum_product_network.ipynb @@ -4,12 +4,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Sum Product Network" + "# Sum Product Network\n", + "\n", + "Some text" ] }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -33,7 +35,7 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -46,7 +48,7 @@ " [0.2285, 0.2867]]]), OrderedDict([('v0', Bint[2, ]), ('v1', Bint[2, ]), ('v2', Bint[2, ])]), 'real')" ] }, - "execution_count": 70, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -76,7 +78,7 @@ }, { "cell_type": "code", - "execution_count": 71, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -102,7 +104,7 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -111,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": 73, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -129,7 +131,7 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -147,7 +149,7 @@ }, { "cell_type": "code", - "execution_count": 76, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -173,16 +175,16 @@ }, { "cell_type": "code", - "execution_count": 77, + "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Delta((('v0', (Tensor(tensor([0, 1, 0, 0, 1]), OrderedDict([('particle', Bint[5, ])]), 2), Number(0.0))),)) + 
Tensor(-0.8297846913337708, OrderedDict(), 'real').reduce(nullop, set())" + "Delta((('v0', (Tensor(tensor([1, 1, 1, 0, 1]), OrderedDict([('particle', Bint[5, ])]), 2), Number(0.0))),)) + Tensor(-0.8297846913337708, OrderedDict(), 'real').reduce(nullop, set())" ] }, - "execution_count": 77, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -208,7 +210,7 @@ }, { "cell_type": "code", - "execution_count": 82, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -217,7 +219,7 @@ "tensor(-2.0612e-09)" ] }, - "execution_count": 82, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -227,21 +229,10 @@ ] }, { - "cell_type": "code", - "execution_count": 79, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "ename": "SyntaxError", - "evalue": "invalid syntax (, line 1)", - "output_type": "error", - "traceback": [ - "\u001b[0;36m File \u001b[0;32m\"\"\u001b[0;36m, line \u001b[0;32m1\u001b[0m\n\u001b[0;31m parameter optimization\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n" - ] - } - ], "source": [ - "parameter optimization" + "### parameter optimization" ] }, { @@ -274,11 +265,11 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ - "Integrate(q, x, q_vars)" + "# Integrate(q, x, q_vars)" ] }, { @@ -290,7 +281,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 11, "metadata": {}, "outputs": [ { From 33503dc3e2534ac61af12315a827a7dbac366dca Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sun, 21 Feb 2021 18:02:57 -0600 Subject: [PATCH 7/8] remove WIP sum product network notebook --- docs/source/index.rst | 1 - tutorials/sum_product_network.ipynb | 322 ---------------------------- 2 files changed, 323 deletions(-) delete mode 100644 tutorials/sum_product_network.ipynb diff --git a/docs/source/index.rst b/docs/source/index.rst index ce842d995..230a68b5d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -37,7 +37,6 @@ Funsor is a tensor-like library for functions and distributions :caption: Tutorials and Examples :name: tutorials-and-examples - tutorials/sum_product_network examples/discrete_hmm examples/eeg_slds examples/kalman_filter diff --git a/tutorials/sum_product_network.ipynb b/tutorials/sum_product_network.ipynb deleted file mode 100644 index b99aa47f0..000000000 --- a/tutorials/sum_product_network.ipynb +++ /dev/null @@ -1,322 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Sum Product Network\n", - "\n", - "Some text" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from collections import OrderedDict\n", - "\n", - "import torch\n", - "\n", - "import funsor\n", - "import funsor.torch.distributions as dist\n", - "import funsor.ops as ops\n", - "\n", - "funsor.set_backend(\"torch\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### network" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Tensor(tensor([[[0.0341, 0.0371],\n", - " [0.0571, 0.0717]],\n", - "\n", - " [[0.1363, 0.1485],\n", - " [0.2285, 0.2867]]]), OrderedDict([('v0', Bint[2, ]), ('v1', Bint[2, ]), ('v2', Bint[2, ])]), 'real')" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# sum_op = +, prod_op = *\n", - "# 
alternatively, we can use rewrite_ops as in\n", - "# https://github.com/pyro-ppl/funsor/pull/456\n", - "# and switch to sum_op = logsumexp, prod_op = +\n", - "spn = 0.4 * (dist.Categorical(torch.tensor([0.2, 0.8]), value=\"v0\").exp() *\n", - " (0.3 * (dist.Categorical(torch.tensor([0.3, 0.7]), value=\"v1\").exp() *\n", - " dist.Categorical(torch.tensor([0.4, 0.6]), value=\"v2\").exp())\n", - " + 0.7 * (dist.Categorical(torch.tensor([0.5, 0.5]), value=\"v1\").exp() *\n", - " dist.Categorical(torch.tensor([0.6, 0.4]), value=\"v2\").exp()))) \\\n", - " + 0.6 * (dist.Categorical(torch.tensor([0.2, 0.8]), value=\"v0\").exp() *\n", - " dist.Categorical(torch.tensor([0.3, 0.7]), value=\"v1\").exp() *\n", - " dist.Categorical(torch.tensor([0.4, 0.6]), value=\"v2\").exp())\n", - "spn" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### marginalize" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(tensor([[0.1704, 0.1856],\n", - " [0.2856, 0.3584]]), OrderedDict([('v1', Bint[2, ]), ('v2', Bint[2, ])]))\n" - ] - } - ], - "source": [ - "spn_marg = spn.reduce(ops.add, \"v0\")\n", - "print(spn_marg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### likelihood" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "test_data = {\"v0\": 1, \"v1\": 0, \"v2\": 1}" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(-1.9073) tensor(0.1485)\n" - ] - } - ], - "source": [ - "ll_exp = spn(**test_data)\n", - "print(ll_exp.log(), ll_exp)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(-1.6842) tensor(0.1856)\n" - ] - } - ], - "source": [ - "llm_exp = spn_marg(**test_data)\n", - "print(llm_exp.log(), llm_exp)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(-1.6842) tensor(0.1856)\n" - ] - } - ], - "source": [ - "test_data2 = {\"v1\": 0, \"v2\": 1}\n", - "llom_exp = spn(**test_data2).reduce(ops.add)\n", - "print(llom_exp.log(), llom_exp)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### sample" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Delta((('v0', (Tensor(tensor([1, 1, 1, 0, 1]), OrderedDict([('particle', Bint[5, ])]), 2), Number(0.0))),)) + Tensor(-0.8297846913337708, OrderedDict(), 'real').reduce(nullop, set())" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sample_inputs = OrderedDict(particle=funsor.Bint[5])\n", - "spn(v1=0, v2=0).sample(frozenset({\"v0\"}), sample_inputs)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "what is `-0.8297846913337708`? a normalization factor?" 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### train parameters" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(-2.0612e-09)" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "-torch.nn.functional.softplus(-torch.tensor(20.))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### parameter optimization" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### most probable explanation" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### multivariate leaf" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### cutset networks" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### expectations and moments" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "# Integrate(q, x, q_vars)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### pareto" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(-0.5232)\n" - ] - } - ], - "source": [ - "spn = 0.3 * dist.Pareto(1., 2., value=\"v0\").exp() + 0.7 * dist.Pareto(1., 3., value=\"v0\").exp()\n", - "print(spn(v0=1.5).log())" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.5" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} From 65d2946893cf17529926c69181e43b3f685745ad Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sun, 21 Feb 2021 18:05:32 -0600 Subject: [PATCH 8/8] black again --- docs/source/conf.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 5c68f052d..42bfdb1ed 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -35,12 +35,13 @@ # The short X.Y version version = u"0.0" -if 'READTHEDOCS' not in os.environ: +if "READTHEDOCS" not in os.environ: # if developing locally, use funsor.__version__ as version from funsor import __version__ # noqaE402 + version = __version__ - html_context = {'github_version': 'master'} + html_context = {"github_version": "master"} # The full version, including alpha/beta/rc tags release = version
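
Editor's note on the test churn above: most of it is mechanical `black` reformatting of fixtures for funsor's sum-product machinery, where factors are funsors with named `Bint` inputs and a Markov time dimension is eliminated in a single call. For orientation, here is a minimal sketch of the API those tests exercise. It is an illustration only, not part of any patch in this series; it assumes funsor (circa 0.4) with the torch backend installed, and its variable names merely mirror the test fixtures.

    # Minimal sketch of the API exercised by test/test_sum_product.py.
    # Assumptions: funsor ~0.4 with the torch backend; names mirror the
    # test fixtures (trans, x_prev, x_curr) but are otherwise illustrative.
    from collections import OrderedDict

    import funsor
    import funsor.ops as ops
    from funsor.domains import Bint
    from funsor.sum_product import sequential_sum_product
    from funsor.terms import Variable
    from funsor.testing import random_tensor

    funsor.set_backend("torch")

    # A random transition factor over a named time axis and a Markov pair
    # of discrete states, like the trans/f3 factors reformatted above.
    trans = random_tensor(
        OrderedDict(
            [
                ("time", Bint[5]),
                ("x_prev", Bint[2]),
                ("x_curr", Bint[2]),
            ]
        )
    )

    # Eliminate "time" in one call.  The (logaddexp, add) pair performs
    # log-space sum-product, matching one of the (sum_op, prod_op)
    # parametrizations used throughout the tests; (add, mul) would give
    # the ordinary sum-product semiring instead.
    result = sequential_sum_product(
        ops.logaddexp,
        ops.add,
        trans,
        Variable("time", Bint[5]),
        {"x_prev": "x_curr"},
    )

    # The time dimension is gone; only the Markov endpoints remain.
    assert set(result.inputs) == {"x_prev", "x_curr"}

The `(ops.logaddexp, ops.add)` and `(ops.add, ops.mul)` pairs parametrized throughout the tests are simply two semirings plugged into the same elimination routine, which is why each test runs under both.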