JAX-Triton documentation
JAX-Triton is a repository containing integrations between JAX and Triton.
JAX is a Python library for accelerated numerical computing and Triton is a Python library and compiler for writing custom GPU kernels. When we put the two together, we get JAX-Triton, which enables writing custom GPU kernels using Triton that can be embedded inside of JAX programs.
Getting started
Installing JAX-Triton
You can install JAX-Triton with pip:
$ pip install jax-triton
This will also install a compatible JAX and Triton.
JAX-Triton only works with JAX on GPU, so you'll need to make sure you have a CUDA-compatible jaxlib
installed.
For example, you could run:
$ pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Installation at HEAD
JAX-Triton and Pallas are developed at JAX and Jaxlib HEAD and close to Triton HEAD. To get a bleeding-edge installation of JAX-Triton, run:
$ pip install 'jax-triton @ git+https://github.com/jax-ml/jax-triton'
This should install compatible versions of JAX and Triton. JAX-Triton does depend on Jaxlib, but it's usually a more stable dependency. You might be able to get away with using a recent jaxlib release:
$ pip install jaxlib[cuda]
$ # or
$ pip install jaxlib[cuda11_pip]
$ # or
$ pip install jaxlib[cuda12_pip]
If you find there are issues with the latest Jaxlib release, you can try using a Jaxlib nightly. To install a new jaxlib, find a link to a CUDA 11 nightly or CUDA 12 nightly, then install it via:
$ pip install <link to nightly>
or, to install CUDA via pip automatically, you can do:
$ pip install 'jaxlib[cuda11_pip] @ <link to nightly>'
$ # or
$ pip install 'jaxlib[cuda12_pip] @ <link to nightly>'
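To double-check which versions you ended up with, you can print the version strings; this sketch assumes only the standard __version__ attributes on jax, jaxlib, and triton:
$ python -c "import jax, jaxlib, triton; print(jax.__version__, jaxlib.__version__, triton.__version__)"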
Quickstart
The main function of interest is jax_triton.triton_call, which applies Triton functions to JAX arrays, including inside jax.jit-compiled functions. For example, we can define a kernel from the Triton tutorial:
import triton
import triton.language as tl

@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    block_size: tl.constexpr,
):
  """Adds two vectors."""
  pid = tl.program_id(axis=0)
  block_start = pid * block_size
  offsets = block_start + tl.arange(0, block_size)
  mask = offsets < 8
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  output = x + y
  tl.store(output_ptr + offsets, output, mask=mask)
Then we can apply it to JAX arrays using jax_triton.triton_call:
import jax
import jax.numpy as jnp
import jax_triton as jt

def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  block_size = 8
  return jt.triton_call(
      x,
      y,
      kernel=add_kernel,
      out_shape=out_shape,
      grid=(x.size // block_size,),
      block_size=block_size)

x_val = jnp.arange(8)
y_val = jnp.arange(8, 16)
print(add(x_val, y_val))
print(jax.jit(add)(x_val, y_val))
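The kernel above hard-codes its mask to 8 elements, so add only handles length-8 inputs. Below is a minimal sketch of a more general version; the names add_any_kernel and add_any are made up for this example, and it assumes that scalar keyword arguments such as n_elements are forwarded to the Triton kernel by name, just like block_size:
import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl

@triton.jit
def add_any_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    n_elements,  # runtime scalar: total number of elements to process
    block_size: tl.constexpr,
):
  """Adds two vectors of arbitrary length."""
  pid = tl.program_id(axis=0)
  block_start = pid * block_size
  offsets = block_start + tl.arange(0, block_size)
  mask = offsets < n_elements  # guard the final, partially filled block
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  tl.store(output_ptr + offsets, x + y, mask=mask)

def add_any(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  block_size = 8
  grid = ((x.size + block_size - 1) // block_size,)  # ceil-divide so every element is covered
  return jt.triton_call(
      x,
      y,
      kernel=add_any_kernel,
      out_shape=out_shape,
      grid=grid,
      n_elements=x.size,
      block_size=block_size)

x_val = jnp.arange(10)
y_val = jnp.arange(10, 20)
print(jax.jit(add_any)(x_val, y_val))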
See the examples directory, especially fused_attention.py and the fused attention ipynb.