import arviz as az
import bayeux as bx
import jax
import matplotlib.pyplot as plt
import numpy as np
import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions
Using bayeux with TFP-on-JAX

bayeux has a built-in function for working with TFP models, bx.Model.from_tfp, demonstrated below. See the TensorFlow Probability documentation for more on TFP itself.
We implement a common hierarchical model of the eight schools dataset (Rubin 1981¹); its details can be found in the Stan, PyMC, TFP, and NumPyro documentation, among others.
¹ Rubin, Donald B. 1981. “Estimation in Parallel Randomized Experiments.” Journal of Educational and Behavioral Statistics 6 (4): 377–401.
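For reference, the non-centered parameterization implemented in the code below can be written as follows, with μ standing for avg_effect, τ for avg_stddev, θ for the standardized school_effects, and σ for the known treatment_stddevs:

$$
\begin{aligned}
\mu &\sim \mathrm{Normal}(0,\, 10) \\
\tau &\sim \mathrm{HalfNormal}(10) \\
\theta_i &\sim \mathrm{Normal}(0,\, 1) \\
y_i &\sim \mathrm{Normal}(\mu + \tau\,\theta_i,\ \sigma_i), \qquad i = 1, \dots, 8
\end{aligned}
$$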
num_schools = 8
treatment_effects = np.array([28, 8, -3, 7, -1, 1, 18, 12], dtype=np.float32)
treatment_stddevs = np.array([15, 10, 16, 11, 9, 11, 10, 18], dtype=np.float32)
@tfd.JointDistributionCoroutineAutoBatched
def tfp_model():
  # Non-centered parameterization: standardized school effects are scaled
  # by the population-level standard deviation and shifted by the mean.
  avg_effect = yield tfd.Normal(0., 10., name='avg_effect')
  avg_stddev = yield tfd.HalfNormal(10., name='avg_stddev')
  school_effects = yield tfd.Sample(
      tfd.Normal(0., 1.), sample_shape=num_schools, name='school_effects')
  yield tfd.Normal(loc=avg_effect + avg_stddev * school_effects,
                   scale=treatment_stddevs, name='observed')
# Pin the observed node to the data so bayeux works with the posterior.
bx_model = bx.Model.from_tfp(
    tfp_model.experimental_pin(observed=treatment_effects))
# Sample the posterior with blackjax's NUTS implementation.
idata = bx_model.mcmc.blackjax_nuts(seed=jax.random.key(0))
az.summary(idata)
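Since the sampler returns an ArviZ-compatible InferenceData object, the usual diagnostics apply directly; a minimal sketch of a trace plot for a quick convergence check:

# Trace plot of all posterior variables from the NUTS run.
az.plot_trace(idata)
plt.tight_layout()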
# Point estimation: optimize the model's log density with Adam (via optax).
opt_results = bx_model.optimize.optax_adam(seed=jax.random.key(0))
fig, ax = plt.subplots(figsize=(12, 2))
ax.plot(opt_results.loss.T)  # loss curves for the optimization runs

# Inspect the resulting parameter estimates.
opt_results.params
# Variational inference: fit a factored surrogate posterior with TFP.
fit_key, draw_key = jax.random.split(jax.random.key(0))
surrogate_posterior, losses = bx_model.vi.tfp_factored_surrogate_posterior(
    seed=fit_key)
fig, ax = plt.subplots(figsize=(12, 2))
ax.plot(losses.T)  # loss curves from the VI fit

# Sample from the fitted surrogate and average over the sample and batch axes.
draws = surrogate_posterior.sample(100, seed=draw_key)
jax.tree.map(lambda x: np.mean(x, axis=(0, 1)), draws)
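As a quick sanity check, the surrogate means above can be compared against the posterior means from the earlier NUTS run; a minimal sketch using the posterior group of the idata object:

# Posterior means from NUTS, averaged over the chain and draw dimensions.
idata.posterior.mean(dim=("chain", "draw"))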