Part 10 of How To Scale Your Model (Part 9: Profiling | Part 11: Conclusions)
How to use JAX to program TPUs efficiently! Much of this section is taken from here.
JAX supports two schools of thought for multi-device programming: let the compiler partition the program and insert communication for you, or write the per-device program and its communication yourself.
Correspondingly, JAX provides an API for each of these schools: jit (jax.jit) and shard_map (jax.experimental.shard_map.shard_map).
jax.jit lets you specify the sharding of the inputs and outputs to a program (via in_shardings and out_shardings) and infers the rest using the GSPMD compiler. While it isn’t perfect, it usually does a decent job at automatically scaling your program to any number of chips.
jax.experimental.shard_map.shard_map is the more explicit counterpart. You get a device-local view of the program and have to write any communication you want explicitly. Have a sharded array and want the whole thing on each device? Add a jax.lax.all_gather. Want to sum an array across your devices? Add a jax.lax.psum (an AllReduce). Programming is harder but far less likely to do something you don’t want.
jax.jit plays two roles inside JAX. As the name suggests, it “just-in-time” compiles a function from Python into bytecode (via XLA/HLO/LLO) so it runs faster. But if the input is sharded or the user specifies in_shardings or out_shardings, it also lets XLA distribute the computation across multiple devices and add communication as needed. For example, here’s how you could write a sharded matmul using jax.jit:
import jax
import jax.numpy as jnp
import jax.sharding as shd
# Running on a TPU v5e 2x2. This assigns names to the two physical axes of the hardware.
mesh = jax.make_mesh(axis_shapes=(2, 2), axis_names=('X', 'Y'))
def P(*args):
    return shd.NamedSharding(mesh, shd.PartitionSpec(*args))
# We create a matrix W and input activations In sharded across our devices.
In = jnp.zeros((8, 2048), dtype=jnp.bfloat16, device=P('X', 'Y'))
W = jnp.zeros((2048, 8192), dtype=jnp.bfloat16, device=P('Y', None))
def matmul_square(In, W):
    return jnp.einsum('bd,df->bf', jnp.square(In), W)
# We can explicitly compile the sharded matmul function here. This adds all the
# necessary comms (e.g. an AllReduce after the matmul).
jit_matmul = jax.jit(matmul_square, out_shardings=P('X', None)).lower(In, W).compile()
out = jit_matmul(In, W)
This will run automatically with any sharding and partition the computation across our devices. But what’s actually happening at the hardware level?
On a single device, matmul_square would simply square the input and perform a simple matmul. But because we specify the out_shardings as P('X', None), our output will be sharded along the batch but replicated across the model dimension and will require an AllReduce to compute.
Using our notation from previous sections, this will likely do something like a local matmul In[B_X, D_Y] *_D W[D_Y, F], which leaves each device with a partial sum over its slice of D, followed by an AllReduce over Y to produce Out[B_X, F].
jax.jit will add this for us automatically! We can actually print the HLO with jit_matmul.as_text() and see the following HLO (abbreviated dramatically):
# This fusion is the actual matmul of the sharded inputs and matrix
%fusion = bf16[4,8192]{1,0:T(4,128)(2,1)S(1)} fusion(bf16[4,1024]{1,0:T(4,128)(2,1)} %param, bf16[8192,1024]{1,0:T(8,128)(2,1)S(1)} %copy-done)
# We reduce the partially summed results across devices
ROOT %AllReduce = bf16[4,8192]{1,0:T(4,128)(2,1)} AllReduce(bf16[4,8192]{1,0:T(4,128)(2,1)S(1)} %fusion)
We can see the matmul (the fusion) and the AllReduce above. Pay particular attention to the shapes. bf16[4, 1024] is a local view of the activations, since our batch_size=8 is split across 2 devices and our d_model=2048 is likewise split 2 ways.
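The local shapes can also be checked directly from Python, without reading HLO. The sketch below is an illustration rather than the original setup: it assumes no TPU is available and asks XLA for virtual CPU devices (the `XLA_FLAGS` line must run before JAX initializes), uses float32 instead of bfloat16, and introduces a hypothetical helper named `sharding`. Each array's `addressable_shards` attribute exposes the per-device pieces.

```python
import os
# Assumption: simulate 8 CPU devices; this must be set before JAX starts up.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 2x2 mesh built from the first four (virtual) CPU devices.
mesh = Mesh(np.array(jax.devices("cpu")[:4]).reshape(2, 2), axis_names=("X", "Y"))

def sharding(*spec):
    # Hypothetical helper: a NamedSharding on the mesh above.
    return NamedSharding(mesh, PartitionSpec(*spec))

In = jax.device_put(jnp.ones((8, 2048), dtype=jnp.float32), sharding("X", "Y"))
W = jax.device_put(jnp.ones((2048, 8192), dtype=jnp.float32), sharding("Y", None))

def matmul_square(In, W):
    return jnp.einsum("bd,df->bf", jnp.square(In), W)

out = jax.jit(matmul_square, out_shardings=sharding("X", None))(In, W)

# Shape of the piece held by one device for each array.
in_local = In.addressable_shards[0].data.shape
w_local = W.addressable_shards[0].data.shape
out_local = out.addressable_shards[0].data.shape
print(in_local, w_local, out_local)  # (4, 1024) (1024, 8192) (4, 8192)
```

The (4, 1024) local view of In matches the bf16[4, 1024] operand in the HLO above, and the (4, 8192) output shard matches the result of the AllReduce.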
This is pretty magical! No matter how complicated our program is, GSPMD and jit will attempt to find shardings for all the intermediate activations and add communication as needed. With that said, GSPMD has its flaws. It can make mistakes. Sometimes you’ll look at a profile and notice something has gone wrong: a giant AllGather takes up 80% of the profile when it doesn’t need to. When this happens, we can try to correct the compiler by explicitly annotating intermediate tensors with jax.lax.with_sharding_constraint. For instance, with two matmuls I can force the intermediate activations to be sharded along the y dimension (not that this is a good idea) with the following:
import jax
import jax.numpy as jnp
# Assumes a mesh with axes named 'x' and 'y' and a P helper like the one defined above.
def matmul(x, Win, Wout):
    hidden = jnp.einsum('bd,df->bf', x, Win)
    hidden = jax.lax.with_sharding_constraint(hidden, P('x', 'y'))
    return jnp.einsum('bf,df->bd', hidden, Wout)
This makes up like 60% of JAX parallel programming in the jit world, since it’s our only way of intervening with the compiler. It’s worth playing around with with_sharding_constraint in a Colab and getting a sense for how it works. When we write LLMs using jax.jit, 90% of what we do to control shardings is changing the input and output shardings (via in_shardings and out_shardings) and annotating intermediate tensors with with_sharding_constraint to ensure the correct comms are happening. For more jax.jit examples, this is a great doc to read.
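For playing with with_sharding_constraint without TPUs, a small end-to-end sketch can help. It makes several assumptions that are not in the text above: virtual CPU devices stand in for the hardware, the shapes are tiny, the helper `sharding` and the function name `two_matmuls` are made up for illustration, and the input shardings are just one plausible choice. The point is only that the constraint changes where data lives, not what the function computes.

```python
import os
# Assumption: simulate 8 CPU devices; this must be set before JAX starts up.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices("cpu")[:8]).reshape(2, 4), axis_names=("x", "y"))

def sharding(*spec):
    # Hypothetical helper: a NamedSharding on the mesh above.
    return NamedSharding(mesh, PartitionSpec(*spec))

@jax.jit
def two_matmuls(x, Win, Wout):
    hidden = jnp.einsum("bd,df->bf", x, Win)
    # Ask the compiler to keep the intermediate sharded as [B_x, F_y].
    hidden = jax.lax.with_sharding_constraint(hidden, sharding("x", "y"))
    return jnp.einsum("bf,df->bd", hidden, Wout)

rng = np.random.default_rng(0)
x_np = rng.standard_normal((8, 16)).astype(np.float32)
Win_np = rng.standard_normal((16, 32)).astype(np.float32)
Wout_np = rng.standard_normal((16, 32)).astype(np.float32)

x = jax.device_put(x_np, sharding("x", None))
Win = jax.device_put(Win_np, sharding(None, "y"))
Wout = jax.device_put(Wout_np, sharding(None, "y"))

out = two_matmuls(x, Win, Wout)
expected = (x_np @ Win_np) @ Wout_np.T  # unsharded NumPy reference

np.testing.assert_allclose(np.asarray(out), expected, rtol=1e-4, atol=1e-3)
```

Changing the PartitionSpec passed to the constraint and re-profiling is a cheap way to see which collectives the compiler inserts for each choice.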
While GSPMD is the “compiler take the wheel” mode, shard_map puts everything in your hands. You specify the sharding of the inputs, like in jax.jit, but then you write all communication explicitly. Whereas jax.jit leaves you with a global cross-device view of the program, shard_map gives you a local per-device view.
Here’s an example. Try to reason about what this function does:
import jax
import jax.numpy as jnp
import jax.lax
import jax.sharding as shd
from jax.experimental.shard_map import shard_map as shmap
P = shd.PartitionSpec
mesh = jax.make_mesh(axis_shapes=(2,4), axis_names=('x','y'))
x = jnp.arange(0, 512, dtype=jnp.int32, device=shd.NamedSharding(mesh, P(('x', 'y'))))
# This function will operate on 1/8th of the array.
def slice_and_average(x):
    assert x.shape == (512 // 8,)
    return jax.lax.pmean(x[:4], axis_name=('x', 'y'))
out = shmap(slice_and_average, mesh, in_specs=P(('x', 'y')), out_specs=P(None,))(x)
assert out.shape == (4,)
What does this do? slice_and_average is run on each TPU with 1/8th of the array, from which we slice the first 4 elements and average them across the full mesh. This means we’re effectively doing mean(x[:4], x[64:68], x[128:132], …). This is pretty cool, because that’s not an easy operation to express in JAX otherwise.
Why do this instead of jax.jit? If we’d used jax.jit, slice_and_average would have seen a global view of the array (the full [512,] array). We’d have had to slice out this non-uniform slice and then perform an average which XLA would have had to interpret correctly. XLA might have added the wrong communication or gotten confused. Here we see the local view and write only the communication we need.
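To check the claim about what slice_and_average computes, the sketch below runs the same function on simulated devices and compares it with a global-view NumPy reference. The assumptions are loud ones: 8 virtual CPU devices replace the TPUs, the input is float32 rather than int32 so the mean has an unambiguous dtype, and the import falls back between the newer `jax.shard_map` and the older `jax.experimental.shard_map` location depending on the installed JAX version.

```python
import os
# Assumption: simulate 8 CPU devices; this must be set before JAX starts up.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases

mesh = Mesh(np.array(jax.devices("cpu")[:8]).reshape(2, 4), axis_names=("x", "y"))
x = jax.device_put(
    jnp.arange(512, dtype=jnp.float32), NamedSharding(mesh, P(("x", "y")))
)

def slice_and_average(x_local):
    # Each device sees 512 / 8 = 64 elements of the global array.
    assert x_local.shape == (64,)
    return jax.lax.pmean(x_local[:4], axis_name=("x", "y"))

out = jax.jit(
    shard_map(slice_and_average, mesh=mesh, in_specs=P(("x", "y")), out_specs=P())
)(x)

# Global-view reference: first 4 elements of each of the 8 shards, averaged.
reference = np.arange(512, dtype=np.float32).reshape(8, 64)[:, :4].mean(axis=0)
np.testing.assert_allclose(np.asarray(out), reference)
print(np.asarray(out))  # [224. 225. 226. 227.]
```

The only communication in the shard_map version is the single pmean over 4 elements per device.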
Example [Collective Matmul]: To take a more realistic example, say we want to implement model parallelism where the activations are initially model sharded, i.e. A[B_X, D_Y] *_D W[D, F_Y] -> Out[B_X, F_Y]. Naively, we would do this by AllGathering A first followed by a local matrix multiplication:
A[B_X, D] = AllGather_Y(A[B_X, D_Y])
Out[B_X, F_Y] = A[B_X, D] *_D W[D, F_Y]
Sadly, this is bad because it doesn’t allow us to overlap the communication with the computation. Overlapping them can be done with a “collective matmul”, as described in Wang et al. 2023. The algorithm is basically as follows:
On each Y shard, multiply the local chunk of A by the matching rows of the local W, producing a result of shape [B / X, F / Y]. Simultaneously, permute A so you get the next chunk locally, perform the matmul, and sum the result.
We can implement that quite easily with shard_map:
import functools
import jax
import jax.numpy as jnp
import jax.sharding as shd
import numpy as np
from jax.experimental.shard_map import shard_map
mesh = jax.make_mesh(axis_shapes=(2, 4), axis_names=('X', 'Y'))
def P(*args):
    return shd.NamedSharding(mesh, shd.PartitionSpec(*args))
B, D, F = 1024, 2048, 8192
A = jnp.arange(np.prod((B, D))).reshape((B, D))
W = jnp.arange(np.prod((D, F))).reshape((D, F))
A = jax.device_put(A, P('X', 'Y'))
W = jax.device_put(W, P(None, 'Y'))
@functools.partial(jax.jit, out_shardings=P('X', 'Y'))
def matmul(lhs, rhs):
    return lhs @ rhs
def collective_matmul_allgather_lhs_contracting(lhs, rhs):
    # lhs is the looped operand; rhs is the local operand
    axis_size = jax.lax.psum(1, axis_name='Y')  # axis_size = 4 for this example
    idx = jax.lax.axis_index('Y')
    chunk_size = lhs.shape[1]
    assert rhs.shape[0] % chunk_size == 0
    def f(i, carrys):
        accum, lhs = carrys
        rhs_chunk = jax.lax.dynamic_slice_in_dim(rhs, (idx + i) % axis_size * chunk_size, chunk_size)
        # Matmul for a chunk
        update = lhs @ rhs_chunk
        # Circular shift to the left: device j sends its chunk to device j - 1
        lhs = jax.lax.ppermute(
            lhs,
            axis_name='Y',
            perm=[(j, (j - 1) % axis_size) for j in range(axis_size)]
        )
        return accum + update, lhs
    accum = jnp.zeros((lhs.shape[0], rhs.shape[1]), dtype=lhs.dtype)
    accum, lhs = jax.lax.fori_loop(0, axis_size - 1, f, (accum, lhs), unroll=True)
    # Compute the last chunk after the final permute to leave lhs in the state we found it
    i = axis_size - 1
    rhs_chunk = jax.lax.dynamic_slice_in_dim(rhs, (idx + i) % axis_size * chunk_size, chunk_size)
    update = lhs @ rhs_chunk
    return accum + update
jit_sharded_f = jax.jit(shard_map(
    collective_matmul_allgather_lhs_contracting, mesh,
    in_specs=(shd.PartitionSpec('X', 'Y'), shd.PartitionSpec(None, 'Y')),
    out_specs=shd.PartitionSpec('X', 'Y')))
shmapped_out = jit_sharded_f(A, W)
expected_out = matmul(A, W)
np.testing.assert_array_equal(shmapped_out, expected_out)
This is pretty neat! We can benchmark this and see that it’s also a lot faster! The profile of the default jit matmul takes 311 us, with a big blocking AllGather at the beginning.
The version above takes 244 us, and its profile doesn’t have the AllGather. It’s all useful work! Our FLOPs utilization is also a lot higher.
It’s also worth noting that the matmul time with no sharding on the contracting dimension is 224 us, so we’re remarkably close to the unsharded baseline here. This is a good example of the kind of performance engineering you might end up doing to improve TPU utilization. For more shard_map examples, this note is great.
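If the indexing in the collective matmul is hard to follow, a single-process simulation may help. The sketch below is plain NumPy, not JAX, and only models the Y axis of the mesh with small made-up shapes. Python lists stand in for the per-device buffers and a list rotation stands in for ppermute. It shows why device idx multiplies its current chunk of A against rows (idx + i) % Y of its local W at step i.

```python
import numpy as np

Y = 4                      # number of devices along the Y axis
B, D, F = 8, 16, 12        # small illustrative shapes
chunk = D // Y             # each device holds a [B, D / Y] slice of A

rng = np.random.default_rng(0)
A = rng.integers(-3, 4, size=(B, D))
W = rng.integers(-3, 4, size=(D, F))

# Per-device state: device j starts with D-chunk j of A and columns F_j of W.
lhs = [A[:, j * chunk:(j + 1) * chunk] for j in range(Y)]
rhs = [W[:, j * (F // Y):(j + 1) * (F // Y)] for j in range(Y)]
accum = [np.zeros((B, F // Y), dtype=A.dtype) for _ in range(Y)]

for i in range(Y):
    for idx in range(Y):
        # After i left shifts, device idx holds the chunk that started on idx + i.
        k = (idx + i) % Y
        accum[idx] += lhs[idx] @ rhs[idx][k * chunk:(k + 1) * chunk]
    # Circular shift to the left: device j sends to j - 1, so idx receives from idx + 1.
    lhs = [lhs[(idx + 1) % Y] for idx in range(Y)]

out = np.concatenate(accum, axis=1)  # stitch the F_Y shards back together
assert np.array_equal(out, A @ W)
```

In the real kernel the shift and the matmul of each step are independent of one another, which is what lets the hardware overlap them.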
Now here are a couple of useful worked problems to try and implement using jax.jit or shard_map!
Here are some random JAX-related problems. I’ll add some more later. For all of these, you’ll need some number of TPUs in a Colab. You can use a public Colab with TPUv2-8. From now on, we’ll assume you have N devices available.
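Several of the problems below ask how long something takes. A profiler is the right tool for seeing where the time goes, but for quick comparisons a rough wall-clock timer is often enough. The helper below is a hypothetical sketch (the name `benchmark` and its parameters are not from any library). The one detail that matters is calling jax.block_until_ready, since JAX dispatches work asynchronously and the call would otherwise return before the computation finishes.

```python
import time

import jax
import jax.numpy as jnp

def benchmark(fn, *args, warmup=2, iters=10):
    """Hypothetical helper: mean wall-clock seconds per call of fn(*args)."""
    for _ in range(warmup):
        # Warmup calls trigger compilation so it is not counted in the timing.
        jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(iters):
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / iters

x = jnp.ones((256, 256), dtype=jnp.float32)
seconds = benchmark(jax.jit(lambda a: a @ a), x)
print(f"{seconds * 1e6:.1f} us per call")
```

Wall-clock numbers from a helper like this include dispatch overhead, so they will not match profiler timings exactly.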
Problem 1: For the next several parts, we’ll let A be an array of activations of shape float32[S_X, D_Y] with X * Y = N. Do the following:
(a) Write a function in JAX that computes the average over each X shard, i.e. it returns an array of size [X, D_Y] where arr[i] is the average over shard i. Do this with both jax.jit and shard_map. Profile each and see how long they took. Was there any communication added? Hint: there shouldn’t be, but sometimes XLA adds it anyway. Here’s the answer.
(b) Write a function in JAX that returns roll(x, shift) - x for some shift within each shard X. I’m not enough of a masochist to make you do this in jax.jit, so just do this with shard_map.
Problem 2: Here we’ll make a basic “mixture of experts” model together. Let W: float32[E_X, D, F_Y] be a set of E “expert” matrices. Let A be as above (our activations) and let B be a set of “routing assignments” where B[i] is an integer in the range [0, E) telling us which matrix we want to use to process that activation. We want to write a function in JAX that returns Out[i] = W[B[i]] @ A[i].
(a) Let’s start by ignoring sharding altogether. Make all of these tensors small enough so they fit in one device. Write a local implementation of this function. Make sure you don’t materialize an array of shape [S, D, F]! Hint: try sorting the tokens into a new buffer of shape [E, S, D] with some attention to masking (why do we need the second dimension to have size S?).
(b) If you just jax.jit the above method, something will happen. Profile this and see what communication it decided to do. How long does it take?
(c) One problem you’ll notice with the above is that it likely gathers the full set of activations A locally, i.e. AllGather_X([S_X, D_Y]). Not only is this expensive communication-wise, it’s also incredibly expensive memory-wise if we can’t fit the full set of activations locally. Implement the above using shard_map and explicit communication.
For a first pass, it might be easiest to use a jax.lax.all_gather and reorder as in (a).
For a second pass, try to avoid materializing any array of size [E, S, D], i.e. try to perform the computation in a ragged fashion using a jax.lax.all_to_all inside a jax.lax.while_loop. This way, you can avoid materializing the full activations and wasting compute on padding. How much faster is this than your original implementation?
(d) Most MoEs route to multiple (k) experts and then average the result. Refactor the above to implement this. Let B: int32[S, k] in this case for the k experts to route to.
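When working through Problem 2, it helps to have a slow but obviously correct oracle to test against. The NumPy sketch below is one such oracle, not a solution: it uses data-dependent indexing that would not work under jax.jit, and the name `moe_reference` is made up. It also assumes each W[e] has shape [D, F], so the product is written as A[i] @ W[B[i]] to make the shapes line up.

```python
import numpy as np

def moe_reference(W, A, B):
    """Loop-based oracle for Out[i] = A[i] @ W[B[i]].

    W: [E, D, F] expert matrices, A: [S, D] activations, B: [S] expert ids.
    Processes one expert at a time, so no [S, D, F] array is ever built.
    """
    S, D = A.shape
    E, D_w, F = W.shape
    assert D == D_w and B.shape == (S,)
    out = np.zeros((S, F), dtype=A.dtype)
    for e in range(E):
        rows = np.nonzero(B == e)[0]  # tokens routed to expert e (may be empty)
        out[rows] = A[rows] @ W[e]
    return out

S, D, F, E = 6, 4, 5, 3
rng = np.random.default_rng(0)
W = rng.standard_normal((E, D, F))
A = rng.standard_normal((S, D))
B = rng.integers(0, E, size=S)

out = moe_reference(W, A, B)

# At this tiny size we can afford the naive version that does materialize
# the [S, D, F] array, purely as a cross-check.
naive = np.einsum("sd,sdf->sf", A, W[B])
assert np.allclose(out, naive)
```

Comparing a jit or shard_map implementation against this oracle on small random inputs catches most masking and reordering mistakes early.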
Problem 3: The collective matmul example above is actually super relevant for real LLMs. Let’s tweak the example to do the full Transformer stack.
(a) As an exercise, let’s start by implementing an AllReduce collective matmul, i.e. A[B_X, D_Y] *_D W[D_Y, F] -> Out[B_X, F]. Note that the output isn’t replicated. The naive algorithm is discussed above, basically just a local matmul followed by an AllReduce. Try to make a comms overlapped “collective” version of this operation. Hint: tile over the output dimension and feel free to use jax.lax.psum (aka AllReduce). Note: due to the way XLA handles this, it may not actually be faster than the baseline.
(b) The complement to the AllReduce collective matmul above is a ReduceScatter collective matmul, as in Tmp[B_X, F_Y] *_F W2[F_Y, D] -> Out[B_X, D_Y]. This occurs in the down-projection matrix in a Transformer. Implement a collective, overlapped version of this in JAX. Be careful about passing only the minimal amount of data you need. Hint: try permuting the result as you accumulate it.
(c) Put these two together into an end-to-end Transformer block that performs In[B_X, D_Y] *_D Win[D, F_Y] *_F Wout[F_Y, D] -> Out[B_X, D_Y] with overlapped communication.
Problem 4: All of the collective matmuls implemented above are unidirectional: they only permute in one direction. Rewrite the collective AllReduce matmul and the collective ReduceScatter matmul to use bidirectional communication. How much faster are these?