clwgg/jaxGPs

jaxGPs

jaxGPs is a lightweight framework for Gaussian Process Regression built on JAX.

Its main purpose is to provide a thin but useful abstraction that makes JAX-based inference comfortable to use while remaining easy to customize and experiment with.

To achieve this, two major design goals were to avoid depending on any library besides JAX (such as objax) while still providing an object-based interface to the Gaussian Process Regression components. As a result, jaxGPs makes heavy use of JAX pytrees and closures, and uses Python descriptors for parameter tracking.
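As a rough illustration of what descriptor-based parameter tracking can look like, here is a minimal sketch in plain Python. The names `Parameter`, `ToyKernel`, and `_param_names` are hypothetical and chosen for this example only; this is not the jaxGPs implementation, which may work quite differently.

```python
class Parameter:
    """Hypothetical descriptor: registers its name on the owning class."""

    def __set_name__(self, owner, name):
        # Called once per attribute when the owning class is created.
        self.name = name
        self.private = "_" + name
        # Copy any inherited registry so subclasses do not mutate the parent's.
        owner._param_names = tuple(getattr(owner, "_param_names", ())) + (name,)

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self  # accessed on the class, not an instance
        return getattr(obj, self.private)

    def __set__(self, obj, value):
        # Store the value per instance under a private attribute name.
        setattr(obj, self.private, float(value))


class ToyKernel:
    """Hypothetical kernel-like object with two tracked parameters."""

    lengthscale = Parameter()
    variance = Parameter()

    def __init__(self, lengthscale=1.0, variance=1.0):
        self.lengthscale = lengthscale
        self.variance = variance

    def parameters(self):
        # Collect all tracked parameters into a plain dict.
        return {name: getattr(self, name) for name in self._param_names}


kernel = ToyKernel(lengthscale=0.5)
params = kernel.parameters()
```

A plain dict of floats like `params` is already a valid JAX pytree, which is one reason this pattern pairs naturally with JAX transformations.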

Usage

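from pprint import pprint
import jax.numpy as jnp

from jaxGPs import GPR
from jaxGPs import ExponentialQuadratic

# get_toy_data() is not defined in this snippet;
# substitute your own training inputs x and targets y
x, y = get_toy_data()

# Build a GP regression model, attach the data, fit, and inspect the parameters
gpr = GPR(ExponentialQuadratic())
gpr.update_data(x, y)
gpr.fit_scipy()
pprint(gpr.parameters())

# Predict at 200 evenly spaced inputs spanning the training range
atx = jnp.linspace(x.min(), x.max(), 200)[:, jnp.newaxis]
mu, cov = gpr.predict_f(atx)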
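For readers who want to see the underlying computation, the following is a minimal sketch of zero-mean GP regression written directly in JAX, using the textbook posterior equations. It is not the jaxGPs code path: `exp_quad`, `gp_posterior`, the fixed hyperparameters, and this particular `get_toy_data` are assumptions made for illustration, and no hyperparameter fitting is performed.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve, solve_triangular


def exp_quad(xa, xb, lengthscale=1.0, variance=1.0):
    """Exponentiated quadratic kernel between rows of xa (n, d) and xb (m, d)."""
    sq_dist = jnp.sum((xa[:, None, :] - xb[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def gp_posterior(x, y, x_new, noise=1e-2):
    """Posterior mean and covariance of a zero-mean GP at x_new."""
    k_xx = exp_quad(x, x) + noise * jnp.eye(x.shape[0])
    k_xs = exp_quad(x, x_new)
    k_ss = exp_quad(x_new, x_new)
    chol = jnp.linalg.cholesky(k_xx)           # lower-triangular factor
    alpha = cho_solve((chol, True), y)         # solves k_xx @ alpha = y
    v = solve_triangular(chol, k_xs, lower=True)
    mu = k_xs.T @ alpha
    cov = k_ss - v.T @ v
    return mu, cov


def get_toy_data(n=20, seed=0):
    """Hypothetical toy data: a noisy sine curve on [-3, 3]."""
    key = jax.random.PRNGKey(seed)
    x = jnp.linspace(-3.0, 3.0, n)[:, jnp.newaxis]
    y = jnp.sin(x) + 0.1 * jax.random.normal(key, x.shape)
    return x, y


x, y = get_toy_data()
atx = jnp.linspace(x.min(), x.max(), 200)[:, jnp.newaxis]
mu, cov = gp_posterior(x, y, atx)
```

Here `mu` has shape `(200, 1)` and `cov` has shape `(200, 200)`; the square roots of the diagonal of `cov` give pointwise predictive standard deviations.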

Setup

Using a Python virtual environment, and assuming the jaxGPs repository has been cloned and is the current working directory:

python -m venv venv
source venv/bin/activate
pip install .

Examples

Classes provided by jaxGPs

Kernels

  • GaussMarkov
    • also known as White Noise, Brownian or Wiener
  • Exponential
    • also known as Matern1/2, but with slightly different scaling
  • ExponentialQuadratic
    • also known as Gaussian, RBF, SquaredExponential,…
  • HavExponential
  • HavExponentialQuad

Means

  • ConstantMean

Models

  • GPR
