!pip install -Uq bayeux-ml
from sklearn import datasets
iris = datasets.load_iris()
features, labels = iris['data'], iris['target']
num_features = features.shape[-1]
num_classes = len(iris.target_names)
import functools
import bayeux as bx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import oryx.core.ppl as oryx_ppl
import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions
Bayeux and Oryx
Oryx is a library for probabilistic programming and deep learning built on top of JAX, and is used by bayeux for automatically computing inverse transforms and log-determinant Jacobians. But oryx also includes tools for defining joint log densities.
This notebook uses the example from the oryx documentation. The differences are that we additionally run an Adam optimizer, and use NumPyro's NUTS sampler (instead of TFP's HMC sampler).
We load the iris dataset, which classifies iris species from four measurements of each flower, and will train a Bayesian neural network (BNN) on this data.
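Before defining the BNN, here is a minimal sketch of the two oryx primitives the model below relies on: joint_sample turns a sampling function into a dictionary of draws of all its named random variables, and joint_log_prob turns it into a joint log density. The coin model here is purely illustrative and is not part of the original example.

def coin(key):
  # Toy two-variable model: a latent probability and one observation.
  p_key, obs_key = jax.random.split(key)
  p = oryx_ppl.random_variable(tfd.Beta(2., 2.), name='p')(p_key)
  return oryx_ppl.random_variable(tfd.Bernoulli(probs=p), name='heads')(obs_key)

draw = oryx_ppl.joint_sample(coin)(jax.random.key(1))  # {'p': ..., 'heads': ...}
log_p = oryx_ppl.joint_log_prob(coin)(draw)            # scalar joint log density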
def dense(dim_out, activation=jax.nn.relu):
  # A dense layer whose weights and biases get standard normal priors.
  def forward(key, x):
    dim_in = x.shape[-1]
    w_key, b_key = jax.random.split(key)
    w = oryx_ppl.random_variable(
        tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out, dim_in)),
        name='w')(w_key)
    b = oryx_ppl.random_variable(
        tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out,)),
        name='b')(b_key)
    return activation(jnp.dot(w, x) + b)
  return forward
def mlp(hidden_sizes, num_classes):
  # Stack dense layers, giving each one its own name scope and PRNG key.
  num_hidden = len(hidden_sizes)
  def forward(key, x):
    keys = jax.random.split(key, num_hidden + 1)
    for i, (subkey, hidden_size) in enumerate(zip(keys[:-1], hidden_sizes)):
      x = oryx_ppl.nest(dense(hidden_size), scope=f'layer_{i + 1}')(subkey, x)
    logits = oryx_ppl.nest(dense(num_classes, activation=lambda x: x),
                           scope=f'layer_{num_hidden + 1}')(keys[-1], x)
    return logits
  return forward
def predict(mlp):
  # Run the network over a batch of inputs and attach a categorical
  # likelihood over the labels, named 'y'.
  def forward(key, xs):
    mlp_key, label_key = jax.random.split(key)
    logits = jax.vmap(functools.partial(mlp, mlp_key))(xs)
    return oryx_ppl.random_variable(
        tfd.Independent(tfd.Categorical(logits=logits), 1), name='y')(label_key)
  return forward
(init_seed,
 bx_opt_seed,
 output_opt_seed,
 bx_mcmc_seed,
 output_mcmc_seed) = jax.random.split(jax.random.key(0), 5)
bnn = mlp([200, 200], num_classes)
weights = oryx_ppl.joint_sample(bnn)(init_seed, jnp.ones(num_features))
def target_log_prob(weights):
  return oryx_ppl.joint_log_prob(predict(bnn))(dict(weights, y=labels), features)
Converting to a bayeux model
We convert the log density to a bayeux model, and run model.debug with a higher verbosity than normal to get a feel for what is going on. This lets us see that the biggest part of the model is layer_2.w, which is 200 x 200, meaning we'll be running inference on around 40,000 parameters, and that every posterior sample, stored as float32, will cost us roughly 160kB of memory. Optimization with 128 particles will therefore only take a few tens of MB.
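As a quick check of those estimates, we can count the latent parameters directly from the test point. This snippet is a small sketch added here, not part of the original example:

num_params = sum(x.size for x in jax.tree_util.tree_leaves(weights))
print(f'{num_params} parameters, '
      f'~{num_params * 4 / 1e3:.0f}kB per float32 posterior sample')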
model = bx.Model(log_density=jax.jit(target_log_prob), test_point=weights)
model.optimize.optax_adam.debug(seed=bx_opt_seed, verbosity=3,
                                kwargs={'num_particles': 128})
%%time
params, state, loss = model.optimize.optax_adam(
    seed=bx_opt_seed, num_particles=128)
Interpreting the optimization results
An ensemble of optimization results is not a posterior, but with a BNN the particles will mostly converge to different points, and it will be tempting to treat them as one. How to properly capture uncertainty with an ensemble of point estimates is beyond the scope of this example, but we can still compute predictions in two different ways: taking an argmax per particle, or averaging the predicted probabilities across particles before taking the argmax.
# Run the network at each optimized particle (substituting the optimized weights
# via `intervene`), then compare per-particle and ensemble-averaged predictions.
output_logits = jax.vmap(lambda weights: jax.vmap(lambda x: oryx_ppl.intervene(bnn, **weights)(
    output_opt_seed, x))(features))(params)
output_probs = jax.nn.softmax(output_logits)
print('Average sample accuracy:', (
    output_probs.argmax(axis=-1) == labels[None]).mean())
print('BMA accuracy:', (
    output_probs.mean(axis=0).argmax(axis=-1) == labels[None]).mean())
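To see how much the 128 particles actually differ, a small addition (not in the original example) is to look at the spread of per-particle accuracies:

per_particle_acc = (output_probs.argmax(axis=-1) == labels[None]).mean(axis=-1)
print('Per-particle accuracy range:',
      per_particle_acc.min(), 'to', per_particle_acc.max())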
Running MCMC
In order to make MCMC run in a reasonable amount of time, we:
1. Run on a GPU.
2. Limit max_tree_depth to 8, since the No-U-Turn criterion is often not hit, especially during tuning, and a shorter tree allows more frequent updates of the tuning parameters.
3. Increase num_warmup to 2,000, mostly to simulate burn-in.
4. Set return_pytree to True, since arviz is still not able to process nested dictionaries as of this being published (a flattening workaround is sketched after the MCMC results below).
Note that by default we will get 8 chains of 500 draws each (roughly 650MB of float32 draws), plus the sampler uses a few times this much internally, so we'll expect memory usage of a few GB.
%%time
params_mcmc = model.mcmc.numpyro_nuts(seed=bx_mcmc_seed,
                                      step_size=1e-3,
                                      max_tree_depth=8,
                                      num_warmup=2_000,
                                      return_pytree=True)
# Same prediction pipeline as before, but now vmapping over both chains and draws.
output_logits = jax.vmap(jax.vmap(lambda weights: jax.vmap(lambda x: oryx_ppl.intervene(bnn, **weights)(
    output_mcmc_seed, x))(features)))(params_mcmc)
output_probs = jax.nn.softmax(output_logits)
print('Average sample accuracy:', (
    output_probs.argmax(axis=-1) == labels[None]).mean())
# Average over both chains and draws for the Bayesian model average.
print('BMA accuracy:', (
    output_probs.mean(axis=(0, 1)).argmax(axis=-1) == labels[None]).mean())
# Trace of the joint log density for each of the 8 chains.
plt.plot(jax.vmap(jax.vmap(model.log_density))(params_mcmc).T);
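Since return_pytree=True was only needed because arviz cannot ingest the nested sample dictionary, one possible workaround, sketched here and not part of the original example, is to flatten the nested names before handing the draws to arviz (the flatten helper is an ad hoc assumption):

import arviz as az

def flatten(tree, prefix=''):
  # Join nested scope names like {'layer_1': {'w': ...}} into 'layer_1.w'.
  out = {}
  for name, value in tree.items():
    if isinstance(value, dict):
      out.update(flatten(value, prefix + name + '.'))
    else:
      out[prefix + name] = value
  return out

idata = az.from_dict(posterior=flatten(params_mcmc))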
Note:
I ran this with max_tree_depth=7, which makes the MCMC finish roughly twice as fast, and got the below plot of the 8 chains' log probabilities, which makes it more convincing that the chains were not quite mixing. I changed the tree depth to 8, and it did better, but I could probably also have increased the number of adaptation steps. I have also fit this on TPUs, though you run into a matrix multiplication precision issue, and need to use config.update('jax_default_matmul_precision', 'float32') at the start of the colab.
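For reference, that TPU precision workaround would look something like the following, placed before any model code runs:

import jax
# Force float32 matmul precision on TPU to avoid the issue noted above.
jax.config.update('jax_default_matmul_precision', 'float32')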