!pip install -Uq bayeux-ml
from sklearn import datasets
iris = datasets.load_iris()
features, labels = iris['data'], iris['target']

num_features = features.shape[-1]
num_classes = len(iris.target_names)

import functools

import bayeux as bx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import oryx.core.ppl as oryx_ppl
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions

Bayeux and Oryx

Oryx is a library for probabilistic programming and deep learning built on top of JAX, and is used by bayeux for automatically computing inverse transforms and log determinant Jacobians. Oryx also includes tools for defining joint log densities.

This notebook uses the example from the oryx documentation. The differences are that we additionally run an Adam optimizer, and use NumPyro's NUTS sampler (instead of TFP's HMC sampler).

We use the iris dataset, in which each flower is labeled with one of three species and described by four measurements, and train a Bayesian neural network (BNN) on this data.

def dense(dim_out, activation=jax.nn.relu):
  """Dense layer with independent standard normal priors on `w` and `b`."""
  def forward(key, x):
    dim_in = x.shape[-1]
    w_key, b_key = jax.random.split(key)
    # `random_variable` tags each draw with a name, so Oryx can later
    # condition on it or compute its log density.
    w = oryx_ppl.random_variable(
          tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out, dim_in)),
          name='w')(w_key)
    b = oryx_ppl.random_variable(
          tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out,)),
          name='b')(b_key)
    return activation(jnp.dot(w, x) + b)
  return forward


def mlp(hidden_sizes, num_classes):
  num_hidden = len(hidden_sizes)
  def forward(key, x):
    keys = jax.random.split(key, num_hidden + 1)
    for i, (subkey, hidden_size) in enumerate(zip(keys[:-1], hidden_sizes)):
      x = oryx_ppl.nest(dense(hidden_size), scope=f'layer_{i + 1}')(subkey, x)
    logits = oryx_ppl.nest(dense(num_classes, activation=lambda x: x),
                  scope=f'layer_{num_hidden + 1}')(keys[-1], x)
    return logits
  return forward

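def predict(mlp):
  """Wraps an MLP with a categorical likelihood over the labels `y`."""
  def forward(key, xs):
    mlp_key, label_key = jax.random.split(key)
    # Every example is passed through the network with the same key, so all
    # examples share one set of weights.
    logits = jax.vmap(functools.partial(mlp, mlp_key))(xs)
    return oryx_ppl.random_variable(
        tfd.Independent(tfd.Categorical(logits=logits), 1), name='y')(label_key)
  return forward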
(init_seed,
 bx_opt_seed,
 output_opt_seed,
 bx_mcmc_seed,
 output_mcmc_seed) = jax.random.split(jax.random.key(0), 5)

bnn = mlp([200, 200], num_classes)

weights = oryx_ppl.joint_sample(bnn)(init_seed, jnp.ones(num_features))

def target_log_prob(weights):
  return oryx_ppl.joint_log_prob(predict(bnn))(dict(weights, y=labels), features)
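
For intuition, the value returned by target_log_prob can be written out by hand. The sketch below is plain NumPy rather than Oryx, and assumes the model exactly as defined above: independent standard normal priors on every w and b, ReLU hidden layers, an identity output layer, and a categorical likelihood on the logits. The name manual_log_density and the tiny shapes are illustrative only, not part of bayeux or Oryx.

```python
import numpy as np


def manual_log_density(weights, features, labels):
  """Log prior of all weights plus categorical log likelihood (NumPy sketch)."""
  # Standard normal log density, summed over every weight and bias.
  log_prior = sum(
      np.sum(-0.5 * v**2 - 0.5 * np.log(2 * np.pi))
      for layer in weights.values() for v in layer.values())

  # Forward pass: ReLU on hidden layers, identity on the last layer.
  # Sorting the names is enough for fewer than ten layers.
  names = sorted(weights)
  x = features
  for i, name in enumerate(names):
    x = x @ weights[name]['w'].T + weights[name]['b']
    if i < len(names) - 1:
      x = np.maximum(x, 0.)

  # Numerically stable log softmax, evaluated at the observed labels.
  shifted = x - x.max(axis=-1, keepdims=True)
  log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
  log_lik = log_probs[np.arange(len(labels)), labels].sum()
  return log_prior + log_lik


# With all-zero weights the logits are zero, so each of the 10 examples
# contributes log(1/3) and each of the 43 parameters contributes
# -0.5 * log(2 * pi).
zero_weights = {
    'layer_1': {'w': np.zeros((5, 4)), 'b': np.zeros(5)},
    'layer_2': {'w': np.zeros((3, 5)), 'b': np.zeros(3)},
}
value = manual_log_density(
    zero_weights, np.ones((10, 4)), np.zeros(10, dtype=int))
```

The prior term grows with the number of parameters and the likelihood term with the number of examples, which is why the log density of a randomly initialized network of this size can be very negative.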

Converting to a bayeux model

We convert the log density to a bayeux model, and run the optimizer's debug method with a higher verbosity than normal to get a feel for what is going on. This lets us see that the biggest part of the model is layer_2.w, which is 200 x 200, meaning we'll be running inference on around 40,000 parameters. At 4 bytes per float32 value, every posterior sample will cost us about 160 kB of memory, and optimization with 128 particles needs roughly 20 MB for the parameters.

model = bx.Model(log_density=jax.jit(target_log_prob), test_point=weights)
model.optimize.optax_adam.debug(seed=bx_opt_seed, verbosity=3,
                                kwargs={'num_particles': 128})
Checking test_point shape ✓ 
Test point has shape
{'layer_1': {'b': (200,), 'w': (200, 4)}, 'layer_2': {'b': (200,), 'w': (200, 200)}, 'layer_3': {'b': (3,), 'w': (3, 200)}}
✓✓✓✓✓✓✓✓✓✓

Computing test point log density ✓ 
Test point has log density
Array(-168474.19, dtype=float32)
✓✓✓✓✓✓✓✓✓✓

Loading keyword arguments... ✓ 
Keyword arguments are
{<function adam at 0x7ab5e9d4cd30>: {'b1': 0.9,
                                     'b2': 0.999,
                                     'eps': 1e-08,
                                     'eps_root': 0.0,
                                     'learning_rate': 0.1,
                                     'mu_dtype': None,
                                     'nesterov': False},
 'extra_parameters': {'apply_transform': True,
                      'chain_method': 'vectorized',
                      'num_iters': 1000,
                      'num_particles': 128}}
✓✓✓✓✓✓✓✓✓✓

Checking it is possible to compute an initial state ✓ 
Initial state has shape
{'layer_1': {'b': (128, 200), 'w': (128, 200, 4)}, 'layer_2': {'b': (128, 200), 'w': (128, 200, 200)}, 'layer_3': {'b': (128, 3), 'w': (128, 3, 200)}}
✓✓✓✓✓✓✓✓✓✓

Checking initial state is has no NaN ✓ 
No nans detected!
✓✓✓✓✓✓✓✓✓✓

Computing initial state log density ✓ 
Initial state log density has shape
(128,)
✓✓✓✓✓✓✓✓✓✓

Transforming model to R^n ✓ 
Transformed state has shape
{'layer_1': {'b': (128, 200), 'w': (128, 200, 4)}, 'layer_2': {'b': (128, 200), 'w': (128, 200, 200)}, 'layer_3': {'b': (128, 3), 'w': (128, 3, 200)}}
✓✓✓✓✓✓✓✓✓✓

Computing transformed state log density shape ✓ 
Transformed state log density has shape
(128,)
✓✓✓✓✓✓✓✓✓✓

Computing gradients of transformed log density ✓ 
Initial gradient has shape
{'layer_1': {'b': (128, 200), 'w': (128, 200, 4)}, 'layer_2': {'b': (128, 200), 'w': (128, 200, 200)}, 'layer_3': {'b': (128, 3), 'w': (128, 3, 200)}}
✓✓✓✓✓✓✓✓✓✓


True
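
The shapes in the debug output above make the size estimate easy to check. This is a small sketch using only the standard library; it assumes the parameters are stored as float32, which matches the dtype of the reported log density.

```python
import math

# Shapes copied from the test point reported by the debug output.
shapes = {
    'layer_1': {'b': (200,), 'w': (200, 4)},
    'layer_2': {'b': (200,), 'w': (200, 200)},
    'layer_3': {'b': (3,), 'w': (3, 200)},
}

num_params = sum(
    math.prod(shape) for layer in shapes.values() for shape in layer.values())
bytes_per_sample = 4 * num_params  # Assumes float32 (4 bytes per value).

print(num_params)                    # 41803
print(bytes_per_sample / 1e3)        # 167.212 (kB per sample)
print(128 * bytes_per_sample / 1e6)  # about 21.4 (MB for 128 particles)
```
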
%%time
params, state, loss = model.optimize.optax_adam(
    seed=bx_opt_seed, num_particles=128)
CPU times: user 4.38 s, sys: 1.52 s, total: 5.89 s
Wall time: 8.81 s

Interpreting the optimization results

An ensemble of optimization results is not a posterior, but with a BNN the particles will mostly converge to different points, and it is tempting to treat the ensemble as one. It is beyond the scope of this example to talk about how to capture uncertainty with an ensemble of point estimates, but we can still compute predictions in two different ways: taking the argmax for each particle separately and averaging the resulting accuracies, or averaging the predicted probabilities across particles and taking a single argmax.
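
The difference between the two summaries is easiest to see on a toy example. The probabilities below are made up for illustration; only the array layout (particles, examples, classes) mirrors the real output_probs.

```python
import numpy as np

# Toy class probabilities with shape (particles, examples, classes).
probs = np.array([
    [[0.60, 0.40], [0.40, 0.60]],  # particle 0 predicts [0, 1]
    [[0.60, 0.40], [0.55, 0.45]],  # particle 1 predicts [0, 0]
    [[0.45, 0.55], [0.40, 0.60]],  # particle 2 predicts [1, 1]
])
toy_labels = np.array([0, 1])

# Option 1: score each particle separately, then average the accuracies.
per_particle = (probs.argmax(axis=-1) == toy_labels[None]).mean()

# Option 2: average probabilities across particles, then take one argmax.
averaged = (probs.mean(axis=0).argmax(axis=-1) == toy_labels).mean()

print(per_particle)  # 4 of 6 particle-level predictions are correct
print(averaged)      # both averaged predictions are correct
```

Averaging the probabilities first lets confident particles outvote uncertain ones, so the second number can be higher than the first even though no single particle is perfect.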

output_logits = jax.vmap(lambda weights: jax.vmap(lambda x: oryx_ppl.intervene(bnn, **weights)(
    output_opt_seed, x))(features))(params)
output_probs = jax.nn.softmax(output_logits)

print('Average sample accuracy:', (
    output_probs.argmax(axis=-1) == labels[None]).mean())
print('BMA accuracy:', (
    output_probs.mean(axis=0).argmax(axis=-1) == labels[None]).mean())
Average sample accuracy: 0.9825
BMA accuracy: 0.9866667

Running MCMC

In order to make MCMC run in a reasonable amount of time, we:

1. Run on a GPU.
2. Limit the max_tree_depth to 8, since the No-U-Turn criterion is often not hit, especially during tuning, and this allows more frequent updating of tuning parameters.
3. Increase the num_warmup to 2,000, mostly to simulate burn-in.
4. Set return_pytree to True, since arviz is still not able to process nested dictionaries as of this being published.

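Note that this run produces 8 chains of 1,000 draws each (the progress bar below reports 3,000 iterations, 2,000 of which are warmup). At 4 bytes per float32 value that is roughly 1.3 GB of samples, and the sampler uses a few times this much internally, so we should expect memory usage of several GB.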
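
The cost of a given max_tree_depth can be bounded with simple arithmetic. The sketch assumes the usual NUTS construction, in which a tree of depth d takes at most 2^d - 1 leapfrog steps and each step costs one gradient evaluation; the helper name is illustrative.

```python
def max_leapfrog_steps(max_tree_depth):
  """Upper bound on leapfrog steps per NUTS iteration at a given tree depth."""
  return 2**max_tree_depth - 1


for depth in (7, 8, 10):
  print(depth, max_leapfrog_steps(depth))
# 7 127
# 8 255
# 10 1023
```

When most iterations hit the cap, each level of depth removed roughly halves the gradient evaluations per iteration, so a depth of 8 is about four times cheaper per iteration than NumPyro's default of 10.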

%%time
params_mcmc = model.mcmc.numpyro_nuts(seed=bx_mcmc_seed,
                                      step_size=1e-3,
                                      max_tree_depth=8,
                                      num_warmup=2_000,
                                      return_pytree=True)
sample: 100%|██████████| 3000/3000 [20:23<00:00,  2.45it/s]

CPU times: user 20min 10s, sys: 3.89 s, total: 20min 14s
Wall time: 20min 34s

output_logits = jax.vmap(jax.vmap(lambda weights: jax.vmap(lambda x: oryx_ppl.intervene(bnn, **weights)(
    output_mcmc_seed, x))(features)))(params_mcmc)
output_probs = jax.nn.softmax(output_logits)

print('Average sample accuracy:', (
    output_probs.argmax(axis=-1) == labels[None]).mean())
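# Note: output_probs has shape (chains, draws, examples, classes), so
# mean(axis=0) averages over chains only. Each draw index is scored separately
# and the accuracies are then averaged; axis=(0, 1) would average every draw.
print('BMA accuracy:', (
    output_probs.mean(axis=0).argmax(axis=-1) == labels[None]).mean())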
Average sample accuracy: 0.9882125
BMA accuracy: 0.9919867

plt.plot(jax.vmap(jax.vmap(model.log_density))(params_mcmc).T);
[Figure: log density of each MCMC chain, plotted against draw index]

Note:

I first ran this with max_tree_depth=7, which makes the MCMC finish roughly twice as fast, and got the plot below of the 8 chains' log probabilities, which shows more clearly that the chains were not quite mixing. Changing the tree depth to 8 did better, though increasing the number of adaptation steps would probably also have helped. I have also fit this on TPUs, though you run into a matrix multiplication precision issue there, and need to call jax.config.update('jax_default_matmul_precision', 'float32') at the start of the colab.

[Figure: log probabilities of the 8 chains with max_tree_depth=7]
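
Judging mixing by eye from a trace of log probabilities can be backed up with a number. The sketch below computes the classic Gelman-Rubin R-hat in NumPy on synthetic chains. It is not the rank-normalized split R-hat that arviz reports, and the function name and the synthetic data are illustrative.

```python
import numpy as np


def classic_rhat(x):
  """Classic Gelman-Rubin R-hat for an array of shape (chains, draws)."""
  num_draws = x.shape[1]
  # W: mean of the within-chain variances.
  within = x.var(axis=1, ddof=1).mean()
  # B: draws times the variance of the per-chain means.
  between = num_draws * x.mean(axis=1).var(ddof=1)
  pooled = (num_draws - 1) / num_draws * within + between / num_draws
  return np.sqrt(pooled / within)


rng = np.random.default_rng(0)
mixed = rng.normal(size=(8, 1000))     # all chains share one distribution
stuck = mixed + np.arange(8)[:, None]  # each chain sits at a different level

print(classic_rhat(mixed))  # close to 1
print(classic_rhat(stuck))  # well above 1
```

The same function could be applied to the (chains, draws) array produced by jax.vmap(jax.vmap(model.log_density))(params_mcmc) after converting it with np.asarray. Values well above 1 would support the visual impression that the chains have not mixed.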