Skip to content
!pip install bayeux-ml
import arviz as az
import bayeux as bx
import jax
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm

Using bayeux with PyMC

bayeux has a built-in function, bx.Model.from_pymc, that makes it easy to work with PyMC models. For more on PyMC, see the PyMC documentation (https://www.pymc.io).

We implement a common hierarchical model of the eight schools dataset (Rubin 1981¹). Details of the model can be found in the Stan, PyMC, TFP, and numpyro documentation, among others.

¹ Rubin, Donald B. 1981. “Estimation in Parallel Randomized Experiments.” Journal of Educational and Behavioral Statistics 6 (4): 377–401.

# Eight schools data: one observed treatment effect per school, together with
# the standard deviation attached to each observation.
treatment_effects = np.array(
    [28, 8, -3, 7, -1, 1, 18, 12], dtype=np.float32)
treatment_stddevs = np.array(
    [15, 10, 16, 11, 9, 11, 10, 18], dtype=np.float32)

# Hierarchical model in non-centered form: `school_effects` are per-school
# offsets drawn with PyMC's default Normal parameters (mu=0, sigma=1); they are
# scaled by `avg_stddev` and shifted by `avg_effect` to give each school's mean.
with pm.Model() as model:
  avg_effect = pm.Normal('avg_effect', mu=0., sigma=10.)
  avg_stddev = pm.HalfNormal('avg_stddev', sigma=10.)
  school_effects = pm.Normal('school_effects', shape=8)
  school_means = avg_effect + avg_stddev * school_effects
  # Likelihood: the observed effects are centered on the per-school means,
  # with the known per-school standard deviations as noise scale.
  pm.Normal(
      'observed',
      mu=school_means,
      sigma=treatment_stddevs,
      observed=treatment_effects)

# Wrap the PyMC model so that bayeux's inference methods can be used on it.
bx_model = bx.Model.from_pymc(model)
# Run NUTS through blackjax with a fixed seed so the run is reproducible.
nuts_seed = jax.random.key(0)
idata = bx_model.mcmc.blackjax_nuts(seed=nuts_seed)

# Tabulate posterior summary statistics and convergence diagnostics.
az.summary(idata)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
avg_effect 6.392 4.207 -1.522 13.967 0.069 0.049 3717.0 2153.0 1.00
avg_stddev 4.681 3.728 0.001 11.391 0.075 0.053 1835.0 1255.0 1.00
school_effects[0] 0.344 0.973 -1.571 2.073 0.016 0.015 3784.0 2729.0 1.01
school_effects[1] 0.051 0.901 -1.784 1.671 0.014 0.014 4265.0 2982.0 1.00
school_effects[2] -0.126 0.943 -1.917 1.594 0.015 0.015 4037.0 2689.0 1.00
school_effects[3] 0.020 0.911 -1.710 1.733 0.014 0.015 4367.0 3061.0 1.00
school_effects[4] -0.245 0.899 -1.942 1.423 0.014 0.013 4306.0 2970.0 1.00
school_effects[5] -0.131 0.921 -1.916 1.554 0.015 0.014 4039.0 2874.0 1.00
school_effects[6] 0.362 0.907 -1.350 2.034 0.013 0.013 4549.0 3204.0 1.00
school_effects[7] 0.066 0.972 -1.678 1.942 0.015 0.016 4481.0 3004.0 1.00
# Optimize the model with optax's Adam, using a fixed seed.
adam_seed = jax.random.PRNGKey(0)
opt_results = bx_model.optimize.optax_adam(seed=adam_seed)


# Plot the recorded loss (transposed before plotting), then display the
# parameters the optimizer ended at.
# NOTE(review): the transpose presumably yields one curve per optimization
# run -- confirm against the shape of `opt_results.loss`.
fig, ax = plt.subplots(figsize=(12, 2))
ax.plot(opt_results.loss.T)
opt_results.params
{'avg_effect': Array([6.17051501, 6.17051501, 6.17051501, 6.17051501, 6.17051501,
        6.17051501, 6.1705141 , 6.17051501], dtype=float64),
 'avg_stddev': Array([10.6251654 , 10.6251654 , 10.6251654 , 10.6251654 , 10.6251654 ,
        10.6251654 , 10.62516158, 10.6251654 ], dtype=float64),
 'school_effects': Array([[ 0.68643359,  0.09130632, -0.26413604,  0.03768121, -0.39293559,
         -0.23488223,  0.59038842,  0.14177174],
        [ 0.68643359,  0.09130632, -0.26413604,  0.03768121, -0.39293559,
         -0.23488223,  0.59038842,  0.14177174],
        [ 0.68643359,  0.09130632, -0.26413604,  0.03768121, -0.39293559,
         -0.23488223,  0.59038842,  0.14177174],
        [ 0.68643359,  0.09130632, -0.26413604,  0.03768121, -0.39293559,
         -0.23488223,  0.59038842,  0.14177174],
        [ 0.68643359,  0.09130632, -0.26413604,  0.03768121, -0.39293559,
         -0.23488223,  0.59038842,  0.14177174],
        [ 0.68643359,  0.09130632, -0.26413604,  0.03768121, -0.39293559,
         -0.23488223,  0.59038842,  0.14177174],
        [ 0.68643307,  0.09130638, -0.26413599,  0.03768124, -0.39293562,
         -0.23494317,  0.59038848,  0.14177166],
        [ 0.68643359,  0.09130632, -0.26413604,  0.03768121, -0.39293559,
         -0.23488223,  0.59038842,  0.14177174]], dtype=float64)}
No description has been provided for this image
# Fit a factored surrogate posterior with TFP variational inference; the call
# returns the fitted surrogate together with the recorded losses.
vi_seed = jax.random.PRNGKey(0)
surrogate_posterior, losses = bx_model.vi.tfp_factored_surrogate_posterior(
    seed=vi_seed)


# Plot the recorded losses (transposed before plotting).
fig, ax = plt.subplots(figsize=(12, 2))
ax.plot(losses.T)

# Draw 100 samples from the fitted surrogate posterior and average every
# parameter over the two leading axes of its sample array.
# NOTE(review): the leading axes are presumably (draw, run) -- confirm against
# the shapes in `draws`.
draws = surrogate_posterior.sample(100, seed=jax.random.PRNGKey(1))
# `jax.tree_map` was deprecated in JAX 0.4.26 and removed in later releases;
# `jax.tree_util.tree_map` is the equivalent that works on old and new JAX.
jax.tree_util.tree_map(lambda x: np.mean(x, axis=(0, 1)), draws)
{'avg_effect': Array(6.462984, dtype=float32),
 'avg_stddev': Array(3.947107, dtype=float32),
 'school_effects': Array([ 0.36255622,  0.05804485, -0.17273399, -0.01884761, -0.2191218 ,
        -0.16507229,  0.3044642 ,  0.06365623], dtype=float32)}
No description has been provided for this image