try:
    import dynamax
except ModuleNotFoundError:
    print('installing dynamax')
    %pip install -qq dynamax
    import dynamax
import bayeux as bx
import jax
from functools import partial

from dynamax.linear_gaussian_ssm import LinearGaussianSSM
from dynamax.parameters import log_det_jac_constrain
from dynamax.parameters import to_unconstrained, from_unconstrained
from dynamax.utils.utils import ensure_array_has_batch_dim

Using bayeux with Dynamax

dynamax is a library for probabilistic state space models written in JAX. Because dynamax builds a pure JAX log-likelihood function from a model, it is compatible with the rest of the JAX ecosystem: we can estimate model parameters with other JAX libraries such as optax (via stochastic gradient descent) or blackjax (via sampling).

Here we provide minimal steps to recreate the inference stage of an example from the dynamax documentation: Bayesian parameter estimation for a linear Gaussian state space model using HMC. Writing inference loops in blackjax, especially for multiple chains, can be quite cumbersome, so we use bayeux to cut down on the boilerplate.

This example shows how we can take any model in a JAX library and use bayeux to perform inference.

state_dim = 2
emission_dim = 10
num_timesteps = 100

seed = jax.random.key(0)
init_key, sample_key, test_init_key, inference_key = jax.random.split(seed, 4)

# simulate synthetic data from true model
true_model = LinearGaussianSSM(state_dim, emission_dim)
true_params, _ = true_model.initialize(init_key)
true_states, emissions = true_model.sample(true_params, sample_key, num_timesteps)
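
The sampled trajectories are plain JAX arrays, and their shapes follow directly from the constants above (a quick sanity check, not in the original example):

print(true_states.shape)  # (100, 2)  == (num_timesteps, state_dim)
print(emissions.shape)    # (100, 10) == (num_timesteps, emission_dim)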

# initialize a second model whose parameters we will estimate
test_model = LinearGaussianSSM(state_dim, emission_dim)
initial_params, param_props = test_model.initialize(test_init_key)

def dynamax_logdensity(
        model,
        props,
        emissions,
        inputs=None,
    ):
    """Convert dynamax model into log-desnity function."""
    # Make sure the emissions and inputs have batch dimensions
    batch_emissions = ensure_array_has_batch_dim(emissions, model.emission_shape)
    batch_inputs = ensure_array_has_batch_dim(inputs, model.inputs_shape)

    # unnormalized log posterior (in unconstrained space) that HMC targets
    def _logprob(unc_params):
        params = from_unconstrained(unc_params, props)
        batch_lls = jax.vmap(partial(model.marginal_log_prob, params))(batch_emissions, batch_inputs)
        lp = model.log_prior(params) + batch_lls.sum()
        lp += log_det_jac_constrain(params, props)
        return lp

    return _logprob
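
The log_det_jac_constrain term is the standard change-of-variables correction: HMC samples in unconstrained space, so the target log density must include the log-determinant of the Jacobian of the constraining transform. A minimal sketch for a scalar positivity constraint theta = exp(u) (illustrative only, not part of the dynamax model above):

import jax.numpy as jnp

# for theta = exp(u), log|d theta / d u| = u;
# automatic differentiation recovers exactly that:
u = 0.7
log_det_jac = jnp.log(jnp.abs(jax.grad(jnp.exp)(u)))
print(log_det_jac)  # 0.7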

# build the log density and move the initial parameters to
# unconstrained space, where the sampler operates
log_density = dynamax_logdensity(test_model, param_props, emissions)
initial_unc_params = to_unconstrained(initial_params, param_props)

ssm_density = bx.Model(
    log_density=log_density,
    test_point=initial_unc_params,
)
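
Before handing the model to an inference algorithm, we can evaluate the log density at the test point directly (a quick sanity check, not in the original example; the exact value depends on the random initialization):

print(log_density(initial_unc_params))  # should be a finite scalar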

We can use debug mode to help check whether the model is implemented correctly.

ssm_density.mcmc.blackjax_hmc.debug(seed=inference_key)
Checking test_point shape ✓ 
Computing test point log density ✓ 
Loading keyword arguments... ✓ 
Checking it is possible to compute an initial state ✓ 
Checking initial state is has no NaN ✓ 
Computing initial state log density ✓ 
Transforming model to R^n ✓ 
Computing transformed state log density shape ✓ 
Computing gradients of transformed log density ✓ 

True
samples = ssm_density.mcmc.blackjax_hmc(
    seed=inference_key,
    chain_method="vectorized",
    num_chains=2,
    num_draws=500,
    num_integration_steps=30,
    progress_bar=True,
    return_pytree=True,
)

# map the unconstrained samples back to the constrained parameter space
constrained_samples = from_unconstrained(samples, param_props)
Running window adaptation

100.00% [500/500 00:00<?]
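
The returned pytree carries leading chain and draw dimensions on every leaf. A minimal sketch for inspecting the result, assuming only standard JAX utilities (with the settings above, each leaf should start with shape (2, 500)):

import jax.tree_util as jtu

print(jtu.tree_map(lambda x: x.shape, constrained_samples))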


We are not limited to blackjax: we can use another sampling backend such as numpyro, or switch to gradient descent and other methods, all through the same simple interface thanks to bayeux. Printing the model lists everything that is available; a sketch of calling an alternative backend follows the listing.

print(ssm_density)
mcmc
    .tfp_hmc
    .tfp_nuts
    .tfp_snaper_hmc
    .blackjax_hmc
    .blackjax_chees_hmc
    .blackjax_meads_hmc
    .blackjax_nuts
    .blackjax_hmc_pathfinder
    .blackjax_nuts_pathfinder
    .flowmc_rqspline_hmc
    .flowmc_rqspline_mala
    .flowmc_realnvp_hmc
    .flowmc_realnvp_mala
    .numpyro_hmc
    .numpyro_nuts
optimize
    .jaxopt_bfgs
    .jaxopt_gradient_descent
    .jaxopt_lbfgs
    .jaxopt_nonlinear_cg
    .optimistix_bfgs
    .optimistix_chord
    .optimistix_dogleg
    .optimistix_gauss_newton
    .optimistix_indirect_levenberg_marquardt
    .optimistix_levenberg_marquardt
    .optimistix_nelder_mead
    .optimistix_newton
    .optimistix_nonlinear_cg
    .optax_adabelief
    .optax_adafactor
    .optax_adagrad
    .optax_adam
    .optax_adamw
    .optax_adamax
    .optax_amsgrad
    .optax_fromage
    .optax_lamb
    .optax_lion
    .optax_noisy_sgd
    .optax_novograd
    .optax_radam
    .optax_rmsprop
    .optax_sgd
    .optax_sm3
    .optax_yogi
vi
    .tfp_factored_surrogate_posterior
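
For example, switching to numpyro's NUTS sampler or to an optax optimizer is just a different attribute on the same model. A minimal sketch (the keyword arguments mirror the blackjax call above and assume bayeux's shared num_chains, num_draws, and return_pytree arguments):

# sample with numpyro's NUTS instead of blackjax HMC
nuts_samples = ssm_density.mcmc.numpyro_nuts(
    seed=jax.random.key(1),
    num_chains=2,
    num_draws=500,
    return_pytree=True,
)

# or fit the same log density with a gradient-based optimizer
opt_results = ssm_density.optimize.optax_adam(seed=jax.random.key(2))

As before, the draws live in unconstrained space and can be mapped back to model parameters with from_unconstrained.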