Calling Triton kernels from JAX
The primary way to use JAX Triton is `jax_triton.triton_call`, which calls handwritten Triton kernels from inside jit-compiled JAX programs.
jax_triton.triton_call
Calls a Triton kernel with `jax.Array` arguments.
Example usage:
First we define a simple kernel that adds two vectors.
```python
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    block_size: tl.constexpr,
):
    # Each program instance handles one contiguous block of `block_size` elements.
    pid = tl.program_id(axis=0)
    block_start = pid * block_size
    offsets = block_start + tl.arange(0, block_size)
    # Switch off lanes past the end of the data; this example hard-codes length 8.
    mask = offsets < 8
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)
```
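To make the offset and mask arithmetic concrete, here is a pure-Python sketch (no GPU or Triton needed) of which lanes a single program instance touches. It replaces the kernel's hard-coded length 8 with a parameter `n`; `block_lanes` is a hypothetical helper written for illustration, not part of Triton or JAX Triton.

```python
from __future__ import annotations


def block_lanes(pid: int, block_size: int, n: int) -> tuple[list[int], list[bool]]:
    """Hypothetical mirror of add_kernel's index math for program instance `pid`.

    Returns the offsets the block covers and the mask that keeps only
    offsets that fall inside an array of length `n`.
    """
    block_start = pid * block_size                          # pid * block_size
    offsets = [block_start + i for i in range(block_size)]  # block_start + tl.arange(0, block_size)
    mask = [off < n for off in offsets]                      # offsets < n
    return offsets, mask


# 10 elements in blocks of 8 need two program instances; the second has
# only 2 valid lanes, and the mask switches off the remaining 6.
offsets, mask = block_lanes(pid=1, block_size=8, n=10)
print(offsets)    # [8, 9, 10, 11, 12, 13, 14, 15]
print(sum(mask))  # 2
```

Masked loads and stores are what make the last, partially filled block safe: without the mask, lanes 10 to 15 would read and write past the end of the buffers.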
Then we use `triton_call` to call it from JAX.
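```python
import jax
import jax.numpy as jnp
import jax_triton as jt


def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    # The single output has the same shape and dtype as `x`.
    out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
    block_size = 8
    return jt.triton_call(
        x,
        y,
        kernel=add_kernel,
        out_shape=out_shape,
        # One program instance per block of `block_size` elements.
        grid=(x.size // block_size,),
        # Arrives as the kernel's `block_size: tl.constexpr` argument.
        block_size=block_size)


x_val = jnp.arange(8)
y_val = jnp.arange(8, 16)
print(add(x_val, y_val))           # element-wise sum: 8, 10, ..., 22
print(jax.jit(add)(x_val, y_val))  # same result when traced under jax.jit
```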
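The example's kernel signature reflects how the call is wired up: the positional inputs (`x`, `y`) come first, then one pointer per output (`output_ptr`), then keyword meta-parameters such as `block_size`. The pure-Python sketch below emulates that convention, with lists standing in for device buffers and a loop standing in for the parallel grid. `emulate_launch`, `add_block`, and `cdiv` are hypothetical names for illustration, and passing `pid` as an argument is a simplification (a real kernel reads it from `tl.program_id`). The sketch also sizes the grid with ceiling division: `x.size // block_size` in the example is exact only because 8 is a multiple of `block_size`; for other lengths, floor division would drop the tail.

```python
from __future__ import annotations

from typing import Callable


def cdiv(a: int, b: int) -> int:
    """Ceiling division: the number of size-b blocks needed to cover a elements."""
    return (a + b - 1) // b


def emulate_launch(kernel: Callable, inputs: list[list[int]], out_size: int,
                   grid: int, **metaparams) -> list[int]:
    """Hypothetical stand-in for a launch: allocate one output, run every program id."""
    output = [0] * out_size
    for pid in range(grid):  # on a GPU these instances run in parallel
        # Inputs first, then the output buffer, then keyword meta-parameters.
        kernel(pid, *inputs, output, **metaparams)
    return output


def add_block(pid, x, y, out, *, block_size, n):
    """Python analogue of add_kernel, with the length passed in as `n`."""
    for off in range(pid * block_size, pid * block_size + block_size):
        if off < n:  # plays the role of the mask
            out[off] = x[off] + y[off]


n, block_size = 10, 8
x, y = list(range(n)), list(range(n, 2 * n))
print(n // block_size, cdiv(n, block_size))  # 1 2
print(emulate_launch(add_block, [x, y], n, cdiv(n, block_size),
                     block_size=block_size, n=n))
# [10, 12, 14, 16, 18, 20, 22, 24, 26, 28]
```

With a grid of `n // block_size == 1`, only the first 8 elements would be written; `cdiv` launches the second, partially masked block as well.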
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`*args` | `Union[jax.Array, bool, int, float]` | Inputs for the Triton kernel. | `()` |
`kernel` | `triton.JITFunction` | A Triton kernel (e.g. a function decorated with …), such as `add_kernel` in the example above. | required |
`out_shape` | `Union[ShapeDtype, Sequence[ShapeDtype]]` | A … (the example above passes a `jax.ShapeDtypeStruct` giving the output's shape and dtype). | required |
`grid` | `GridOrLambda` | An integer, a tuple of up to 3 integers, or a function that returns a tuple of up to 3 integers. When … | required |
`input_output_aliases` | `Optional[Dict[int, int]]` | A dictionary mapping input argument indices to output indices. Providing a mapping aliases the corresponding buffers. | `None` |
`zeroed_outputs` | `Union[Sequence[int], Callable[[Dict[str, Any]], Sequence[int]]]` | A sequence of indices, or a function returning a sequence of indices, for outputs that should be zeroed before the kernel is launched. | `()` |
`num_warps` | `int` | The number of warps used to execute the Triton kernel. | `4` |
`num_stages` | `int` | The number of stages emitted by the Triton compiler. | `2` |
`debug` | `bool` | If `True`, prints the intermediate IRs for debugging. | `False` |
`serialized_metadata` | `bytes` | Arbitrary metadata that will be added to the serialized kernel call. | `b''` |
`**metaparams` | `Any` | Additional keyword arguments that will be provided to a … In the example above, `block_size` is passed this way and arrives as the kernel's `block_size: tl.constexpr` argument. | `{}` |
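The `grid` parameter accepts three forms. The pure-Python sketch below normalizes each of them to a 3-tuple of program counts, padding unused dimensions with 1. Two details are assumptions not stated in the table: the padding with 1, and that a callable grid receives the meta-parameters as a dict (mirroring the dict-taking callable accepted by `zeroed_outputs`). `normalize_grid` and `Grid` are hypothetical names, not part of JAX Triton's API.

```python
import math
from typing import Any, Callable, Dict, Tuple, Union

# Hypothetical alias mirroring the table's GridOrLambda.
Grid = Union[int, Tuple[int, ...], Callable[[Dict[str, Any]], Tuple[int, ...]]]


def normalize_grid(grid: Grid, metaparams: Dict[str, Any]) -> Tuple[int, int, int]:
    """Hypothetical helper: turn any accepted grid form into a 3-tuple."""
    if callable(grid):
        # Assumption: a callable grid is given the meta-parameters as a dict.
        grid = grid(metaparams)
    if isinstance(grid, int):
        grid = (grid,)
    grid = tuple(grid)
    if not 1 <= len(grid) <= 3:
        raise ValueError(f"grid must have 1 to 3 dimensions, got {len(grid)}")
    return grid + (1,) * (3 - len(grid))  # pad unused dimensions with 1


meta = {"block_size": 8}
n = 20
print(normalize_grid(2, meta))       # (2, 1, 1)
print(normalize_grid((2, 3), meta))  # (2, 3, 1)
# A callable grid can size itself from the meta-parameters.
g = normalize_grid(lambda m: (math.ceil(n / m["block_size"]),), meta)
print(g, math.prod(g))               # (3, 1, 1) 3
```

The callable form is useful when the block size is a meta-parameter: the grid is then derived from the same value the kernel sees, instead of being computed separately.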
Returns:
Type | Description |
---|---|
`Any` | Outputs from the Triton kernel. |
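Two parameters change how the returned outputs relate to the inputs: `input_output_aliases` makes an output share an input's buffer, and `zeroed_outputs` clears selected outputs before launch. The pure-Python sketch below models both with lists, where "aliased" means the output is the very same list object as the input. This is only a model: JAX arrays are immutable, so aliasing acts as buffer reuse inside the compiled program rather than visible in-place mutation of your array. `prepare_outputs` is a hypothetical helper, NaN as a marker for uninitialized memory is a modeling choice, and passing the meta-parameters dict to a callable `zeroed_outputs` is an assumption.

```python
from typing import Any, Callable, Dict, List, Optional, Sequence, Union


def prepare_outputs(
    inputs: List[List[float]],
    out_sizes: Sequence[int],
    input_output_aliases: Optional[Dict[int, int]] = None,
    zeroed_outputs: Union[Sequence[int], Callable[[Dict[str, Any]], Sequence[int]]] = (),
    metaparams: Optional[Dict[str, Any]] = None,
) -> List[List[float]]:
    """Hypothetical model of the output buffers handed to a kernel."""
    aliases = input_output_aliases or {}
    metaparams = metaparams or {}
    # Invert {input index: output index} to find the buffer behind each output.
    source_of = {out_idx: in_idx for in_idx, out_idx in aliases.items()}
    outputs = []
    for j, size in enumerate(out_sizes):
        if j in source_of:
            outputs.append(inputs[source_of[j]])  # same object: the buffer is shared
        else:
            outputs.append([float("nan")] * size)  # fresh, "uninitialized" buffer
    if callable(zeroed_outputs):
        # Assumption: the callable receives the meta-parameters dict.
        zeroed_outputs = zeroed_outputs(metaparams)
    for j in zeroed_outputs:
        outputs[j][:] = [0.0] * len(outputs[j])  # clear in place before "launch"
    return outputs


x = [1.0, 2.0, 3.0]
outs = prepare_outputs([x], out_sizes=[3, 2],
                       input_output_aliases={0: 0}, zeroed_outputs=[1])
print(outs[0] is x)  # True: output 0 reuses input 0's buffer
print(outs[1])       # [0.0, 0.0]
```

Zeroing matters for kernels that accumulate into an output (for example with atomic adds), since a freshly allocated buffer has unspecified contents.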