zoj613/aehmc (forked from aesara-devs/aehmc)

An HMC/NUTS implementation in Aesara


Aehmc


AeHMC provides implementations of the Hamiltonian Monte Carlo (HMC) and No-U-Turn Sampler (NUTS) algorithms in Aesara.
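For readers new to the method, here is a minimal NumPy sketch of a single HMC transition. It assumes a diagonal inverse mass matrix and a fixed number of leapfrog steps. It is not AeHMC's implementation: `leapfrog`, `hmc_step` and the standard-normal helpers are hypothetical names chosen for illustration, and NUTS differs from this sketch by choosing the trajectory length adaptively instead of using a fixed `num_steps`. The sketch shows what the step size, the inverse mass matrix, the potential energy (the negative log-density) and the acceptance probability each do in a transition.

```python
import numpy as np


def leapfrog(q, p, grad_potential, step_size, inverse_mass_matrix, num_steps):
    """Leapfrog integration of H(q, p) = U(q) + p^T M^{-1} p / 2.

    `inverse_mass_matrix` is assumed to hold the diagonal of M^{-1}.
    """
    p = p - 0.5 * step_size * grad_potential(q)  # initial half step for momentum
    for i in range(num_steps):
        q = q + step_size * inverse_mass_matrix * p  # full step for position
        if i < num_steps - 1:
            p = p - step_size * grad_potential(q)  # full step for momentum
    p = p - 0.5 * step_size * grad_potential(q)  # final half step for momentum
    return q, p


def hmc_step(rng, q, potential, grad_potential, step_size, inverse_mass_matrix, num_steps=10):
    """One HMC transition; returns the next position and the acceptance probability."""
    # Draw momentum p ~ N(0, M), where M is the diagonal 1 / inverse_mass_matrix.
    p = rng.normal(size=q.shape) / np.sqrt(inverse_mass_matrix)
    q_new, p_new = leapfrog(q, p, grad_potential, step_size, inverse_mass_matrix, num_steps)

    def hamiltonian(q, p):
        # Potential energy U(q) = -log density, plus kinetic energy p^T M^{-1} p / 2.
        return potential(q) + 0.5 * np.sum(inverse_mass_matrix * p**2)

    # Metropolis correction for the integrator's energy error.
    # min(0, delta) avoids overflow; a non-finite energy is always rejected.
    delta = hamiltonian(q, p) - hamiltonian(q_new, p_new)
    acceptance_prob = float(np.exp(min(0.0, delta))) if np.isfinite(delta) else 0.0
    if rng.uniform() < acceptance_prob:
        return q_new, acceptance_prob
    return q, acceptance_prob


# Usage: sample a standard normal, U(q) = q^2 / 2, the same target as Y_rv below.
def std_normal_potential(q):
    return 0.5 * np.sum(q**2)


def std_normal_grad(q):
    return q


rng = np.random.default_rng(0)
q = np.zeros(1)
samples, accept_probs = [], []
for _ in range(4000):
    q, a = hmc_step(rng, q, std_normal_potential, std_normal_grad,
                    step_size=0.2, inverse_mass_matrix=np.ones(1))
    samples.append(q[0])
    accept_probs.append(a)
samples = np.array(samples)
print(samples.mean(), samples.var(), np.mean(accept_probs))
```

For this target the sample mean and variance should land close to 0 and 1, and with such a small step size the average acceptance probability stays high.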


Get started

```python
import aesara
from aesara import tensor as at
from aesara.tensor.random.utils import RandomStream

from aeppl import joint_logprob

from aehmc import nuts

# A simple normal distribution
Y_rv = at.random.normal(0, 1)


def logprob_fn(y):
    # Log-density of Y_rv evaluated at the value y
    return joint_logprob(realized={Y_rv: y})[0]


# Build the transition kernel
srng = RandomStream(seed=0)
kernel = nuts.new_kernel(srng, logprob_fn)

# Compile a function that updates the chain
y_vv = Y_rv.clone()  # value variable that stands in for Y_rv
initial_state = nuts.new_state(y_vv, logprob_fn)

step_size = at.as_tensor(1e-2)
inverse_mass_matrix = at.as_tensor(1.0)
(
    next_state,
    potential_energy,
    potential_energy_grad,
    acceptance_prob,
    num_doublings,
    is_turning,
    is_diverging,
), updates = kernel(*initial_state, step_size, inverse_mass_matrix)

# `updates` advances the random stream's state between calls
next_step_fn = aesara.function([y_vv], next_state, updates=updates)

print(next_step_fn(0))
# 0.14344008534533775
```
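The `num_doublings` and `is_turning` outputs refer to how NUTS builds its trajectory: it keeps doubling the trajectory until its two ends start moving back toward each other. The sketch below shows the original no-U-turn check from Hoffman and Gelman's NUTS paper, written with velocities M^{-1} p so that it accepts a diagonal inverse mass matrix (an assumption of this sketch). AeHMC may use a different variant of the criterion; `is_turning` and `exact_flow` here are hypothetical standalone functions, not AeHMC's API.

```python
import numpy as np


def is_turning(q_minus, q_plus, p_minus, p_plus, inverse_mass_matrix):
    """No-U-turn check between the leftmost (minus) and rightmost (plus) states.

    Velocities are M^{-1} p for a diagonal inverse mass matrix. The trajectory
    is turning when the span between its endpoints points against the velocity
    at either end.
    """
    span = q_plus - q_minus
    v_minus = inverse_mass_matrix * p_minus
    v_plus = inverse_mass_matrix * p_plus
    return bool(np.dot(span, v_minus) < 0 or np.dot(span, v_plus) < 0)


def exact_flow(t):
    # Exact Hamiltonian flow for a standard normal target with M = I,
    # starting at q = 0 with momentum p = 1: q(t) = sin(t), p(t) = cos(t).
    return np.array([np.sin(t)]), np.array([np.cos(t)])


q0, p0 = exact_flow(0.0)
for t in (0.5, 1.0, 1.5, 2.0, 2.5):
    q_t, p_t = exact_flow(t)
    print(t, is_turning(q0, q_t, p0, p_t, np.ones(1)))
```

Along this trajectory the check prints `False` for t = 0.5, 1.0 and 1.5 and `True` for t = 2.0 and 2.5: once t passes π/2 the momentum at the right end points back toward the start, which is the point where NUTS would stop doubling.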

Install

The latest release of AeHMC can be installed from PyPI using pip:

pip install aehmc

Or via conda-forge:

conda install -c conda-forge aehmc

The current development branch of AeHMC can be installed from GitHub using pip:

pip install git+https://github.com/aesara-devs/aehmc

Get help

Report bugs by opening an issue. If you have a question about using AeHMC, start a discussion. For real-time feedback or more general chat about AeHMC, use our Discord server or Gitter room.

Contribute

AeHMC welcomes contributions. A good way to start contributing is to look at the open issues.

If you want to implement a new feature, open a discussion or come chat with us on Discord or Gitter.
