!pip install bayeux-ml
import arviz as az
import bayeux as bx
import jax
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
Using bayeux with PyMC
bayeux has a built-in function, bx.Model.from_pymc, that makes it easy to work with PyMC models. More on PyMC here.
We implement a common hierarchical model of the eight schools dataset (Rubin 1981¹), whose details can be seen on the Stan documentation, PyMC documentation, TFP documentation, numpyro documentation, among others.
¹ Rubin, Donald B. 1981. “Estimation in Parallel Randomized Experiments.” Journal of Educational and Behavioral Statistics 6 (4): 377–401.
# Eight schools data (Rubin 1981): per-school estimated treatment effects
# and their known standard errors, as float32 arrays for use with JAX.
_effects = [28, 8, -3, 7, -1, 1, 18, 12]
_stddevs = [15, 10, 16, 11, 9, 11, 10, 18]
treatment_effects = np.asarray(_effects, dtype=np.float32)
treatment_stddevs = np.asarray(_stddevs, dtype=np.float32)
# Non-centered hierarchical eight-schools model: each school's effect is a
# standard-normal offset (pm.Normal defaults: mu=0, sigma=1) scaled by a
# shared spread and shifted by a shared mean, observed with known noise.
with pm.Model() as model:
    avg_effect = pm.Normal('avg_effect', mu=0.0, sigma=10.0)
    avg_stddev = pm.HalfNormal('avg_stddev', sigma=10.0)
    school_effects = pm.Normal('school_effects', shape=8)
    pm.Normal(
        'observed',
        mu=avg_effect + avg_stddev * school_effects,
        sigma=treatment_stddevs,
        observed=treatment_effects,
    )

# Wrap the PyMC model so bayeux's inference backends (MCMC / optimization /
# VI) can all be run against it.
bx_model = bx.Model.from_pymc(model)
# Draw posterior samples with the BlackJAX NUTS backend and summarize them.
# Use jax.random.PRNGKey for the seed, consistent with the other cells in
# this file (the original mixed new-style jax.random.key with PRNGKey).
idata = bx_model.mcmc.blackjax_nuts(seed=jax.random.PRNGKey(0))
az.summary(idata)
# Optimize the (transformed) log density with Adam via optax, then plot the
# loss traces; the bare `opt_results.params` expression displays the
# optimized parameter values in a notebook cell.
opt_seed = jax.random.PRNGKey(0)
opt_results = bx_model.optimize.optax_adam(seed=opt_seed)
fig, ax = plt.subplots(figsize=(12, 2))
ax.plot(opt_results.loss.T)
opt_results.params
# Fit a TFP factored surrogate posterior with variational inference, plot
# the loss traces, then draw 100 samples and reduce each parameter to its
# mean over the leading two axes.
vi_seed = jax.random.PRNGKey(0)
sample_seed = jax.random.PRNGKey(1)
surrogate_posterior, losses = bx_model.vi.tfp_factored_surrogate_posterior(
    seed=vi_seed)
fig, ax = plt.subplots(figsize=(12, 2))
ax.plot(losses.T)
draws = surrogate_posterior.sample(100, seed=sample_seed)
jax.tree.map(lambda leaf: np.mean(leaf, axis=(0, 1)), draws)