!pip install bayeux-ml
import arviz as az
import bayeux as bx
import jax
import matplotlib.pyplot as plt
import numpy as np
import numpyro
dist = numpyro.distributions
Using bayeux with numpyro
bayeux has a built-in function, bx.Model.from_numpyro, that makes it easy to work with numpyro models. More on NumPyro can be found in its own documentation.
We implement a common hierarchical model of the eight schools dataset (Rubin 1981¹), whose details can be found in the Stan, PyMC, TFP, and numpyro documentation, among others.
¹ Rubin, Donald B. 1981. “Estimation in Parallel Randomized Experiments.” Journal of Educational and Behavioral Statistics 6 (4): 377–401.
num_schools = 8
treatment_effects = np.array([28, 8, -3, 7, -1, 1, 18, 12], dtype=np.float32)
treatment_stddevs = np.array([15, 10, 16, 11, 9, 11, 10, 18], dtype=np.float32)
def numpyro_model():
avg_effect = numpyro.sample('avg_effect', dist.Normal(0., 10.))
avg_stddev = numpyro.sample('avg_stddev', dist.HalfNormal(10.))
with numpyro.plate('J', 8):
school_effects = numpyro.sample('school_effects', dist.Normal(0., 1.))
numpyro.sample('observed',
dist.Normal(avg_effect[..., None] + avg_stddev[..., None] * school_effects,
treatment_stddevs),
obs=treatment_effects)
bx_model = bx.Model.from_numpyro(numpyro_model)
idata = bx_model.mcmc.blackjax_nuts(seed=jax.random.key(0))
az.summary(idata)
opt_results = bx_model.optimize.optax_adam(seed=jax.random.PRNGKey(0))
fig, ax = plt.subplots(figsize=(12, 2))
ax.plot(opt_results.loss.T)
opt_results.params
surrogate_posterior, losses = bx_model.vi.tfp_factored_surrogate_posterior(
seed=jax.random.PRNGKey(0))
fig, ax = plt.subplots(figsize=(12, 2))
ax.plot(losses.T)
draws = surrogate_posterior.sample(100, seed=jax.random.PRNGKey(1))
jax.tree.map(lambda x: np.mean(x, axis=(0, 1)), draws)