jaxGPs

jaxGPs is a lightweight framework for Gaussian Process Regression built on JAX.

Its main purpose is to provide a thin but useful abstraction that makes JAX-based inference comfortable to use while remaining easy to customize and experiment with.

To achieve this, two of the major design goals were to depend on no library other than JAX (so no objax, for example) while still providing an object-based interface to the Gaussian Process Regression components. As a result, jaxGPs makes heavy use of JAX pytrees and closures, and it uses Python descriptors for parameter tracking.
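To illustrate what descriptor-based parameter tracking can look like, here is a minimal pure-Python sketch. It is not the jaxGPs implementation: the names `Parameter`, `ToyKernel` and `_param_names` are hypothetical and chosen only for this example.

```python
class Parameter:
    """Hypothetical descriptor: stores a float per instance and
    registers its own name on the owning class."""

    def __set_name__(self, owner, name):
        self.name = name
        self.private = "_" + name
        # Build a new tuple so a subclass never mutates its parent's registry.
        owner._param_names = getattr(owner, "_param_names", ()) + (name,)

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self  # accessed on the class, return the descriptor itself
        return getattr(obj, self.private)

    def __set__(self, obj, value):
        setattr(obj, self.private, float(value))


class ToyKernel:
    """Hypothetical kernel-like object with two tracked parameters."""

    lengthscale = Parameter()
    variance = Parameter()

    def __init__(self, lengthscale=1.0, variance=1.0):
        self.lengthscale = lengthscale
        self.variance = variance

    def parameters(self):
        # Collect every tracked parameter into a plain dict.
        return {name: getattr(self, name) for name in self._param_names}


kernel = ToyKernel(lengthscale=2.0, variance=0.5)
print(kernel.parameters())
```

A plain dict of numbers like the one returned by `parameters()` is itself a valid JAX pytree, which is one reason this pattern fits well with JAX transformations.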

Usage

from pprint import pprint
import jax.numpy as jnp

from jaxGPs import GPR
from jaxGPs import ExponentialQuadratic

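# get_toy_data is not defined in this snippet; substitute your own data
x, y = get_toy_data()

# Build a GPR model with an ExponentialQuadratic kernel and attach the data
gpr = GPR(ExponentialQuadratic())
gpr.update_data(x, y)
# Fit the model (judging by the name, this optimizes via SciPy)
gpr.fit_scipy()
pprint(gpr.parameters())

# Predict at 200 evenly spaced inputs, shaped (200, 1)
atx = jnp.linspace(x.min(), x.max(), 200)[:, jnp.newaxis]
mu, cov = gpr.predict_f(atx)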
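The usage snippet calls `get_toy_data()` without defining it. A possible stand-in is sketched below, assuming column-vector inputs of shape (n, 1) to match the shape of `atx`. The function body, its arguments and the (n, 1) target shape are assumptions made for illustration, not part of jaxGPs.

```python
import jax
import jax.numpy as jnp


def get_toy_data(n=50, noise=0.1, seed=0):
    """Hypothetical stand-in: noisy samples of sin(x) on [-3, 3)."""
    key_x, key_noise = jax.random.split(jax.random.PRNGKey(seed))
    # Inputs as a column vector of shape (n, 1), like `atx` in the usage snippet.
    x = jax.random.uniform(key_x, (n, 1), minval=-3.0, maxval=3.0)
    # Targets: sin(x) plus Gaussian noise; the (n, 1) shape is an assumption.
    y = jnp.sin(x) + noise * jax.random.normal(key_noise, (n, 1))
    return x, y


x, y = get_toy_data()
print(x.shape, y.shape)
```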

Setup

Using a Python virtual environment, and assuming the jaxGPs repository has been cloned and cd'ed into:

python -m venv venv
source venv/bin/activate
pip install .

Examples

Classes provided by jaxGPs

Kernels

  • GaussMarkov
    • also known as White Noise, Brownian or Wiener
  • Exponential
    • also known as Matern1/2, but with slightly different scaling
  • ExponentialQuadratic
    • also known as Gaussian, RBF, SquaredExponential,…
  • HavExponential
  • HavExponentialQuad
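For orientation, the textbook forms of the Exponential and ExponentialQuadratic kernels can be sketched in plain JAX as follows. These are standalone functions with hypothetical names and a conventional variance/lengthscale parametrization. As noted above, the scaling used by the jaxGPs classes may differ, so treat this as a sketch rather than the library's definition.

```python
import jax.numpy as jnp


def sq_dists(x1, x2):
    """Pairwise squared Euclidean distances: (n, d), (m, d) -> (n, m)."""
    diff = x1[:, None, :] - x2[None, :, :]
    return jnp.sum(diff ** 2, axis=-1)


def exponential_quadratic(x1, x2, variance=1.0, lengthscale=1.0):
    """Textbook RBF kernel: variance * exp(-r^2 / (2 * lengthscale^2))."""
    return variance * jnp.exp(-0.5 * sq_dists(x1, x2) / lengthscale ** 2)


def exponential(x1, x2, variance=1.0, lengthscale=1.0):
    """Textbook exponential kernel: variance * exp(-r / lengthscale)."""
    # Clamp before the square root so gradients stay finite at r = 0.
    r = jnp.sqrt(jnp.maximum(sq_dists(x1, x2), 1e-12))
    return variance * jnp.exp(-r / lengthscale)


x = jnp.array([[0.0], [1.0]])
print(exponential_quadratic(x, x))
print(exponential(x, x))
```

Both functions return a symmetric (n, n) matrix when called with the same inputs twice, with the variance on the diagonal (up to the small clamp in `exponential`).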

Means

  • ConstantMean

Models

  • GPR