
Calling Triton kernels from JAX

The primary way to use JAX Triton is jax_triton.triton_call, which calls handwritten Triton kernels from inside jitted JAX programs.

jax_triton.triton_call

Calls a Triton kernel with jax.Array arguments.

Example usage:

First we define a simple kernel that adds two vectors.

import triton
import triton.language as tl

@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    block_size: tl.constexpr,
):
  pid = tl.program_id(axis=0)
  block_start = pid * block_size
  offsets = block_start + tl.arange(0, block_size)
  # The length is hardcoded, so this kernel only handles 8-element vectors.
  mask = offsets < 8
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  output = x + y
  tl.store(output_ptr + offsets, output, mask=mask)

Then we use triton_call to call it from JAX.

import jax
import jax.numpy as jnp
import jax_triton as jt

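def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  # The single output has the same shape and dtype as x.
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  block_size = 8
  return jt.triton_call(
      x,
      y,
      kernel=add_kernel,
      out_shape=out_shape,
      # One program instance per block of block_size elements.
      grid=(x.size // block_size,),
      # Forwarded to the kernel as its constexpr block_size argument.
      block_size=block_size)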

x_val = jnp.arange(8)
y_val = jnp.arange(8, 16)
print(add(x_val, y_val))
print(jax.jit(add)(x_val, y_val))

Parameters:

*args (Union[jax.Array, bool, int, float], default: ()): Inputs for the Triton kernel, passed in order. Each jax.Array is passed to the kernel as a pointer.

kernel (triton.JITFunction, required): A Triton kernel (e.g. a function decorated with triton.jit). All static values should be annotated with triton.language.constexpr.

out_shape (Union[ShapeDtype, Sequence[ShapeDtype]], required): A jax.ShapeDtypeStruct (or any object with .shape and .dtype attributes), or a sequence of them, specifying the kernel's output(s). A pointer for each entry in out_shape is passed to the kernel after the input parameters.

grid (GridOrLambda, required): An integer, a tuple of up to 3 integers, or a function that returns a tuple of up to 3 integers. If grid is an integer, the kernel is launched as grid parallel program instances. If grid is a sequence of integers, it is launched as prod(grid) parallel program instances. If grid is a function, it is called with the **metaparams (as a dict) and should return a tuple of up to 3 integers.

input_output_aliases (Optional[Dict[int, int]], default: None): A dictionary mapping input argument indices to output indices. Each mapped input and output share the same buffer.

zeroed_outputs (Union[Sequence[int], Callable[[Dict[str, Any]], Sequence[int]]], default: ()): A sequence of output indices, or a function returning one, identifying outputs that should be zeroed before the kernel is launched.

num_warps (int, default: 4): The number of warps used to execute the Triton kernel.

num_stages (int, default: 2): The number of pipeline stages emitted by the Triton compiler.

debug (bool, default: False): If True, prints the intermediate IRs for debugging.

serialized_metadata (bytes, default: b''): Arbitrary metadata that is added to the serialized kernel call.

**metaparams (Any, default: {}): Additional keyword arguments that are passed to grid (if it is a function) and to the Triton kernel as constexpr arguments.

Returns:

Any: Outputs from the Triton kernel, one per entry in out_shape.