Calling Triton kernels from JAX
The primary way to use JAX Triton is to call handwritten Triton kernels from inside JIT-compiled JAX programs with jax_triton.triton_call.
jax_triton.triton_call
Calls a Triton kernel with jax.Array arguments.
Example usage:
First we define a simple kernel that adds two vectors.
import triton
import triton.language as tl

@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    block_size: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * block_size
    offsets = block_start + tl.arange(0, block_size)
    mask = offsets < 8
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)
Then we use triton_call to call it from JAX.
import jax
import jax.numpy as jnp
import jax_triton as jt

def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
    block_size = 8
    return jt.triton_call(
        x,
        y,
        kernel=add_kernel,
        out_shape=out_shape,
        grid=(x.size // block_size,),
        block_size=block_size)

x_val = jnp.arange(8)
y_val = jnp.arange(8, 16)
print(add(x_val, y_val))
print(jax.jit(add)(x_val, y_val))
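To make the roles of `grid`, `block_size`, and the mask more concrete, the following pure-Python sketch emulates what the launch above does. It is not Triton code and does not call jax_triton: `emulate_add` and `ceil_div` are hypothetical helpers introduced only for illustration, and the loop over `pid` runs sequentially, whereas Triton runs program instances in parallel. Unlike the example kernel, which hardcodes a length of 8 in its mask, the sketch takes the length from the input and rounds the grid up.

```python
def ceil_div(a: int, b: int) -> int:
    """Smallest integer >= a / b, for positive b."""
    return -(-a // b)


def emulate_add(x: list, y: list, block_size: int) -> list:
    """Pure-Python stand-in for launching an add kernel over a 1-D grid."""
    n = len(x)
    out = [None] * n
    grid = ceil_div(n, block_size)  # number of program instances
    for pid in range(grid):  # one iteration per program instance
        block_start = pid * block_size
        for offset in range(block_start, block_start + block_size):
            if offset < n:  # plays the role of `mask`: skip out-of-range lanes
                out[offset] = x[offset] + y[offset]
    return out


# Same inputs as the JAX example: one block of 8 covers everything.
result = emulate_add(list(range(8)), list(range(8, 16)), block_size=8)
print(result)

# A length that is not a multiple of block_size needs a second,
# partially masked block.
ragged = emulate_add(list(range(10)), list(range(10)), block_size=8)
print(ragged)
```

With 10 elements and a block size of 8, floor division would give a grid of 1 and leave the last two elements unwritten, which is why the sketch rounds up.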
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`*args` | `Array \| bool \| int \| float \| float32` | Inputs for the Triton kernel. | `()` |
`kernel` | `JITFunction` | A Triton kernel (e.g. a function decorated with `triton.jit`). | required |
`out_shape` | `ShapeDtype \| Sequence[ShapeDtype]` | A … | required |
`grid` | `GridOrLambda` | An integer, tuple of up to 3 integers, or a function that returns a tuple of up to 3 integers. When … | required |
`input_output_aliases` | `dict[int, int] \| None` | A dictionary mapping input argument indices to output indices. Providing a mapping will alias the corresponding buffers. | `None` |
`zeroed_outputs` | `Sequence[int] \| Callable[[dict[str, Any]], Sequence[int]]` | A sequence of indices, or a function returning a sequence of indices, for outputs that should be zeroed before the kernel is launched. | `()` |
`num_warps` | `int \| None` | The number of warps used to execute the Triton kernel. | `None` |
`num_stages` | `int \| None` | The number of stages emitted by the Triton compiler. | `None` |
`num_ctas` | `int` | The size of thread blocks per cluster to be used on GPUs with compute capabilities >= 9.0. It must be less than or equal to 8. | `1` |
`debug` | `bool` | If True, prints out intermediate IRs for debugging purposes. | `False` |
`serialized_metadata` | `bytes` | Arbitrary metadata that will be added into the serialized kernel call. | `b''` |
`**metaparams` | `Any` | Additional keyword arguments that will be provided to a … | `{}` |
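The `grid` parameter accepts a function as well as a fixed tuple. A minimal sketch of such a function follows, assuming the callable is handed the keyword metaparameters (such as `block_size`) as a dictionary, which this page does not spell out. `make_grid` and `ceil_div` are hypothetical helpers, not part of jax_triton.

```python
from typing import Any, Callable, Dict, Tuple


def ceil_div(a: int, b: int) -> int:
    """Smallest integer >= a / b, for positive b."""
    return -(-a // b)


def make_grid(n_elements: int) -> Callable[[Dict[str, Any]], Tuple[int]]:
    """Build a 1-D grid function for an input of n_elements items."""

    def grid(metaparams: Dict[str, Any]) -> Tuple[int]:
        # Assumption: metaparams carries the same block_size that is
        # passed to the kernel as a keyword argument.
        return (ceil_div(n_elements, metaparams["block_size"]),)

    return grid


grid_fn = make_grid(1000)
print(grid_fn({"block_size": 128}))
print(grid_fn({"block_size": 256}))
```

Deriving the grid from the metaparameters keeps the launch size and the block size from drifting apart when `block_size` changes.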
Returns:
Type | Description |
---|---|
`Any` | Outputs from the Triton kernel. |