try:
    import dynamax
except ModuleNotFoundError:
    print('installing dynamax')
    %pip install -qq dynamax
    import dynamax
import bayeux as bx
import jax
from itertools import count
from dynamax.linear_gaussian_ssm import LinearGaussianSSM
from dynamax.parameters import log_det_jac_constrain
from dynamax.parameters import to_unconstrained, from_unconstrained
from dynamax.utils.utils import ensure_array_has_batch_dim
from functools import partial
Using bayeux with Dynamax
dynamax is a library for probabilistic state space models written in JAX. It builds a pure JAX likelihood function from a model, which makes it compatible with the rest of the JAX ecosystem: model parameters can be estimated with libraries such as optax (via stochastic gradient descent) or blackjax (via sampling).
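For example, the marginal log likelihood of a dynamax model is a pure JAX function, so it composes directly with jax.jit and jax.grad. The following is a minimal sketch (using a throwaway one-dimensional model, not part of the original example) of what that looks like:
# Sketch: the marginal log likelihood is a pure JAX function, so it can be
# jit-compiled and differentiated with respect to the parameter pytree.
demo_model = LinearGaussianSSM(state_dim=1, emission_dim=1)
demo_params, _ = demo_model.initialize(jax.random.key(1))
_, demo_emissions = demo_model.sample(demo_params, jax.random.key(2), num_timesteps=20)
mll = jax.jit(demo_model.marginal_log_prob)
print(mll(demo_params, demo_emissions))  # scalar log likelihood
print(jax.grad(lambda p: mll(p, demo_emissions))(demo_params))  # gradients, same pytree structure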
Here, we will provide minimal steps to recreate the inference stage of an example from the dynamax documentation: Bayesian parameter estimation for a linear Gaussian state space model using HMC. Writing inference loops in blackjax, especially for multiple chains, can be quite cumbersome, so we will use bayeux to reduce some of the boilerplate code.
This example shows how we can take any model in a JAX library and use bayeux to perform inference.
state_dim = 2
emission_dim = 10
num_timesteps = 100
seed = jax.random.key(0)
init_key, sample_key, test_init_key, inference_key = jax.random.split(seed, 4)
# simulate synthetic data from true model
true_model = LinearGaussianSSM(state_dim, emission_dim)
true_params, _ = true_model.initialize(init_key)
true_states, emissions = true_model.sample(true_params, sample_key, num_timesteps)
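As a quick check (a sketch, not part of the original example), the sampled latent states and emissions have shapes (num_timesteps, state_dim) and (num_timesteps, emission_dim):
# Shapes of the simulated data: states are (num_timesteps, state_dim),
# emissions are (num_timesteps, emission_dim).
print(true_states.shape)  # (100, 2)
print(emissions.shape)    # (100, 10)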
test_model = LinearGaussianSSM(state_dim, emission_dim)
initial_params, param_props = test_model.initialize(test_init_key)
def dynamax_logdensity(
    model,
    props,
    emissions,
    inputs=None,
):
    """Convert a dynamax model into a log-density function."""
    # Make sure the emissions and inputs have batch dimensions
    batch_emissions = ensure_array_has_batch_dim(emissions, model.emission_shape)
    batch_inputs = ensure_array_has_batch_dim(inputs, model.inputs_shape)

    # log density (in unconstrained space) that HMC samples from
    def _logprob(unc_params):
        # map the unconstrained parameters back to the constrained space
        params = from_unconstrained(unc_params, props)
        # marginal log likelihood, summed over batches of emissions
        batch_lls = jax.vmap(partial(model.marginal_log_prob, params))(batch_emissions, batch_inputs)
        # add the log prior and the log-determinant of the Jacobian of the
        # constraining transformation
        lp = model.log_prior(params) + batch_lls.sum()
        lp += log_det_jac_constrain(params, props)
        return lp

    return _logprob
log_density = dynamax_logdensity(test_model, param_props, emissions)
initial_unc_params = to_unconstrained(initial_params, param_props)
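As a sanity check (a sketch, not in the original example), the log density can be evaluated directly at the initial unconstrained parameters; it should return a finite scalar:
# Evaluate the log density at the initial unconstrained parameters.
print(log_density(initial_unc_params))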
ssm_density = bx.Model(
    log_density=log_density,
    test_point=initial_unc_params,
)
We can use debug mode to help check if the model is correctly implemented.
ssm_density.mcmc.blackjax_hmc.debug(seed=inference_key)
samples = ssm_density.mcmc.blackjax_hmc(
    seed=inference_key,
    chain_method="vectorized",
    num_chains=2,
    num_draws=500,
    num_integration_steps=30,
    progress_bar=True,
    return_pytree=True,
)
constrained_samples = from_unconstrained(samples, param_props)
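With return_pytree=True the samples come back as a parameter pytree; assuming each leaf carries leading (chain, draw) dimensions, posterior means in the constrained space can be computed with a tree map (a sketch, not part of the original example):
# Average each leaf over the (chain, draw) axes to get posterior means
# in the constrained parameter space.
posterior_means = jax.tree_util.tree_map(lambda x: x.mean(axis=(0, 1)), constrained_samples)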
We are not limited to blackjax: we can use another sampling backend such as numpyro, or switch to gradient descent and other methods, all with the same simple interface thanks to bayeux. Printing the model lists the available methods.
print(ssm_density)
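For instance, assuming numpyro is installed, the same model object can be sampled with NUTS through the numpyro backend (a sketch; the available method names are those printed above):
# Sketch: run NUTS through the numpyro backend on the same bayeux model.
nuts_samples = ssm_density.mcmc.numpyro_nuts(seed=inference_key)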