Part 10 of How To Scale Your Model (Part 9: Profiling | Part 11: Conclusions)
How to use JAX to program TPUs efficiently! Much of this section is taken from here. You can run the code examples in this section with free TPUs on Google Colab.
JAX supports three schools of thought for multi-device programming:
Mode | View? | Explicit sharding? | Explicit collectives?
---|---|---|---
Auto | Global | ❌ | ❌
Explicit | Global | ✅ | ❌
Manual | Per-device | ✅ | ✅
Correspondingly, JAX provides APIs for each of these modes:
1. jax.jit (with Auto mesh axes) lets you take any existing JAX function and call it with sharded inputs. JAX then uses XLA's Shardy compiler, which automatically parallelizes the program. XLA will add communication for you (AllGathers, ReduceScatters, AllReduces, etc.) when needed to facilitate existing operations. While it isn't perfect, it usually does a decent job at automatically scaling your program to any number of chips without code changes.
2. jax.jit with Explicit mesh axes looks similar to (1), but lets JAX handle the sharding propagation instead of XLA. That means the sharding of an array is actually part of the JAX type system, and JAX can error out when it detects ambiguous communication and let the user resolve it.
3. jax.shard_map is the more manual counterpart. You get a device-local view of the program and have to write any communication you want explicitly. Have a sharded array and want the whole thing on each device? Add a jax.lax.all_gather. Want to sum an array across your devices? Add a jax.lax.psum (an AllReduce). Programming is harder but far less likely to do something you don't want.

jax.jit plays two roles inside JAX. As the name suggests, it "just-in-time" compiles a function from Python into bytecode (via XLA/HLO/LLO) so it runs faster. But if the input is sharded or the user specifies an in_sharding or out_sharding, it also lets XLA distribute the computation across multiple devices and add communication as needed. For example, here's how you could write a sharded matmul using jax.jit:
```python
import jax
import jax.numpy as jnp

# Running on a TPU v5e 4x2. This assigns names to the two physical axes of the hardware.
mesh = jax.make_mesh(axis_shapes=(4, 2), axis_names=('X', 'Y'))

# This tells JAX to use this mesh for all operations, so you can just specify the PartitionSpec P.
jax.set_mesh(mesh)

# We create a matrix W and input activations In sharded across our devices.
In = jnp.zeros((8, 2048), dtype=jnp.bfloat16, device=jax.NamedSharding(mesh, jax.P('X', 'Y')))
W = jnp.zeros((2048, 8192), dtype=jnp.bfloat16, device=jax.NamedSharding(mesh, jax.P('Y', None)))

def matmul_square(In, W):
  return jnp.einsum('bd,df->bf', jnp.square(In), W)

# We can explicitly compile the sharded matmul function here. This adds all the
# necessary comms (e.g. an AllReduce after the matmul).
jit_matmul = jax.jit(matmul_square, out_shardings=jax.P('X', None)).lower(In, W).compile()

out = jit_matmul(In, W)
```
This will run automatically with any sharding and partition the computation across our devices. But what’s actually happening at the hardware level?
matmul_square simply squares the input and performs a matmul. But because we specify the out_shardings as P('X', None), the output will be sharded along the batch but replicated across the model dimension, and will require an AllReduce to compute.

Using our notation from previous sections, this will likely do something like:

Tmp[B_X, F] { U_Y } = In[B_X, D_Y] *_D W[D_Y, F] (local matmuls producing partial sums, unreduced over Y)

AllReduce_Y(Tmp[B_X, F] { U_Y }) -> Out[B_X, F]
jax.jit
will add this for us automatically! We can actually print the HLO with jit_matmul.as_text()
and see the following HLO (abbreviated dramatically):
```
# This fusion is the actual matmul of the sharded inputs and matrix
%fusion = bf16[2,8192]{1,0:T(4,128)(2,1)S(1)} fusion(bf16[2,1024]{1,0:T(4,128)(2,1)} %param, bf16[8192,1024]{1,0:T(8,128)(2,1)S(1)} %copy-done)

# We reduce the partially summed results across devices
ROOT %AllReduce = bf16[2,8192]{1,0:T(4,128)(2,1)} AllReduce(bf16[2,8192]{1,0:T(4,128)(2,1)S(1)} %fusion)
```
We can see the matmul (the fusion) and the AllReduce above. Pay particular attention to the shapes. bf16[2, 1024]
is a local view of the activations, since our batch_size=8
is split across 4 devices and our d_model=2048
is likewise split 2 ways.
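If you want to double-check where a given array actually lives, you can also inspect its sharding directly or render a per-device view. A quick sketch, continuing from the snippet above (the exact reprs in the comments are illustrative):

```python
print(In.sharding)   # e.g. NamedSharding(mesh=..., spec=PartitionSpec('X', 'Y'))
print(out.sharding)  # e.g. NamedSharding(mesh=..., spec=PartitionSpec('X', None))
jax.debug.visualize_array_sharding(out)  # prints an ASCII map of which device holds which slice
```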
This is pretty magical! No matter how complicated our program is, Shardy and jit will attempt to find shardings for all the intermediate activations and add communication as needed. With that said, Shardy has its flaws: it can make mistakes, and sometimes you'll look at a profile and notice something has gone wrong, e.g. a giant AllGather that takes up 80% of the profile when it doesn't need to be there. When this happens, we can try to correct the compiler by explicitly annotating intermediate tensors with jax.lax.with_sharding_constraint. For instance, with two matmuls I can force the intermediate activations to be sharded along the Y dimension (not that this is a good idea) with the following:
```python
import jax
import jax.numpy as jnp

mesh = jax.make_mesh((4, 2), ('X', 'Y'))
jax.set_mesh(mesh)

def matmul(x, Win, Wout):
  hidden = jnp.einsum('bd,df->bf', x, Win)
  # Force the intermediate activations to be sharded over both mesh axes.
  hidden = jax.lax.with_sharding_constraint(hidden, jax.P('X', 'Y'))
  return jnp.einsum('bf,df->bd', hidden, Wout)
```
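To see the constraint take effect, you can jit the function with sharded inputs. A small sketch (the shapes below are illustrative, not from the original example):

```python
# Any batch/d_model/d_ff sizes that divide evenly over the mesh will do.
x = jax.device_put(jnp.zeros((128, 2048), dtype=jnp.bfloat16), jax.NamedSharding(mesh, jax.P('X', 'Y')))
Win = jax.device_put(jnp.zeros((2048, 8192), dtype=jnp.bfloat16), jax.NamedSharding(mesh, jax.P(None, 'Y')))
Wout = jax.device_put(jnp.zeros((2048, 8192), dtype=jnp.bfloat16), jax.NamedSharding(mesh, jax.P(None, 'Y')))

out = jax.jit(matmul)(x, Win, Wout)
print(out.sharding)  # whatever output sharding the compiler settled on
```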
This makes up like 60% of JAX parallel programming in the automatic partitioning world where you control the intermediate shardings via jax.lax.with_sharding_constraint
. But “compiler tickling” is famously not a fun programming model. You could annotate every intermediate variable and still not know if you’ll get the right outcome. Instead, what if JAX itself could handle and control sharding propagation?
Explicit sharding (or “sharding in types”) looks a lot like automatic sharding, but sharding propagation happens at the JAX level! Each JAX operation has a sharding rule that takes the shardings of the op’s arguments and produces a sharding for the op’s result. You can see the resulting sharding using jax.typeof
:
```python
import jax
import jax.numpy as jnp
import jax.sharding as shd
import numpy as np

P = jax.P  # alias for jax.sharding.PartitionSpec

# Running on a TPU v5e 2x2. This assigns names to the two physical axes of the hardware.
mesh = jax.make_mesh(axis_shapes=(2, 2), axis_names=('X', 'Y'),
                     axis_types=(shd.AxisType.Explicit, shd.AxisType.Explicit))

# This tells JAX to use this mesh for all operations, so you can just specify the PartitionSpec P.
jax.set_mesh(mesh)

x = jax.device_put(np.arange(16).reshape(8, 2), P('X', 'Y'))

@jax.jit
def f(x):
  print(jax.typeof(x))    # int32[8@X,2@Y]
  out = x * 2
  print(jax.typeof(out))  # int32[8@X,2@Y]
  return out

f(x)
```
As you can see, JAX propagated the sharding from the input (x) to the output (out), both of which are inspectable at trace time via jax.typeof. For most operations these rules are simple and obvious because there's only one reasonable choice (e.g. elementwise ops retain the same sharding). But for some operations (e.g. jnp.einsum, jnp.reshape, etc.) it's ambiguous how to shard the result, in which case JAX throws a trace-time error and asks the programmer to provide an out_sharding argument explicitly. Let's see another example where you have a conflict:
```python
# We create a matrix W and input activations In sharded across our devices.
In = jnp.zeros((8, 2048), dtype=jnp.bfloat16, out_sharding=jax.P('X', 'Y'))
W = jnp.zeros((2048, 8192), dtype=jnp.bfloat16, out_sharding=jax.P('Y', None))

@jax.jit
def matmul_square(In, W):
  print(jax.typeof(In))  # bfloat16[8@X, 2048@Y]
  print(jax.typeof(W))   # bfloat16[2048@Y, 8192]
  return jnp.einsum('bd,df->bf', jnp.square(In), W)

matmul_square(In, W)  # This will error
```
This code errors with:

```
Contracting dimensions are sharded and it is ambiguous how the output should be sharded.
Please specify the output sharding via the `out_sharding` parameter.
Got lhs_contracting_spec=('Y',) and rhs_contracting_spec=('Y',)
```

This is awesome, because it really is ambiguous how the output of the einsum should be sharded: the contraction runs over the Y-sharded D axis, so the resulting partial sums could be resolved with an AllReduce (leaving the output replicated over Y, e.g. P('X', None)) or with a ReduceScatter (leaving the output sharded over Y, e.g. P('X', 'Y')). Unlike Auto mode, Explicit mode errors out when it detects this kind of ambiguous communication and requires the user to resolve it. So here you can do:
```python
@jax.jit
def matmul_square(In, W):
  return jnp.einsum('bd,df->bf', jnp.square(In), W, out_sharding=P('X', 'Y'))

out = matmul_square(In, W)
print(jax.typeof(out))  # bfloat16[8@X,8192@Y]
```
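You could equally well ask for the AllReduce-style resolution instead. A small sketch of that alternative (the function name is ours, not from the original):

```python
@jax.jit
def matmul_square_allreduce(In, W):
  # Resolve the partial sums with an AllReduce: output replicated over Y.
  return jnp.einsum('bd,df->bf', jnp.square(In), W, out_sharding=P('X', None))

print(jax.typeof(matmul_square_allreduce(In, W)))  # bfloat16[8@X,8192]
```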
Auto mode and Explicit mode can be composed via the jax.sharding.auto_axes and jax.sharding.explicit_axes APIs. This is a great doc to read for more information.
While Shardy is the "compiler take the wheel" mode, jax.shard_map puts everything in your hands. You specify the sharding of the inputs, as in jax.jit, but then you write all communication explicitly. Whereas jax.jit leaves you with a global cross-device view of the program, shard_map gives you a local per-device view.
Here’s an example. Try to reason about what this function does:
```python
import jax
import jax.numpy as jnp
import jax.sharding as shd

P = jax.P

mesh = jax.make_mesh((2, 4), ('x', 'y'),
                     axis_types=(shd.AxisType.Explicit, shd.AxisType.Explicit))
jax.set_mesh(mesh)

x = jnp.arange(0, 512, dtype=jnp.int32, out_sharding=P(('x', 'y')))

# This function will operate on 1/8th of the array.
@jax.shard_map(in_specs=P(('x', 'y')), out_specs=P())
def slice_and_average(x):
  assert x.shape == (512 // 8,)
  return jax.lax.pmean(x[:4], axis_name=('x', 'y'))

out = slice_and_average(x)
assert out.shape == (4,)
```
What does this do? slice_and_average
is run on each TPU with 1/8th of the array, from which we slice the first 4 elements and average them across the full mesh. This means we’re effectively doing mean(x[:4], x[64:68], x[128:132], …)
. This is pretty cool, because that’s not an easy operation to express in JAX otherwise.
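As a quick sanity check (assuming the snippet above has just run), the result should match averaging the first four elements of each of the eight local shards computed directly in NumPy:

```python
import numpy as np

# Reference: average x[0:4], x[64:68], ..., x[448:452] across the 8 shards.
x_np = np.arange(0, 512, dtype=np.int32)
reference = np.mean([x_np[i * 64 : i * 64 + 4] for i in range(8)], axis=0)

np.testing.assert_allclose(np.asarray(out), reference)
```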
Why do this instead of jax.jit? If we’d used jax.jit
, slice_and_average
would have seen a global view of the array (the full [512,]
array). We’d have had to slice out this non-uniform slice and then perform an average which XLA would have had to interpret correctly. XLA might have added the wrong communication or gotten confused. Here we see the local view and write only the communication we need.
Example [Collective Matmul]: To take a more realistic example, say we want to implement model parallelism where the activations are initially model sharded, i.e. A[B_X, D_Y] *_D W[D, F_Y] -> Out[B_X, F_Y]. Naively, we would do this by AllGathering A first, followed by a local matrix multiplication:

AllGather_Y(A[B_X, D_Y]) -> A[B_X, D]

A[B_X, D] *_D W[D, F_Y] -> Out[B_X, F_Y]

Sadly, this is bad because it doesn't allow us to overlap the communication with the computation. Overlapping them can be done with a "collective matmul", as described in Wang et al. 2023. The algorithm is basically as follows: at each step, multiply the chunk of A you currently hold locally by the matching slice of W (the rows of W corresponding to that chunk of D), accumulating a result of shape [B / X, F / Y]. Simultaneously, permute A around the Y ring so you get the next chunk locally, perform the next matmul, and sum the result. We can implement that quite easily with shard_map:
```python
import functools

import jax
import jax.numpy as jnp
import jax.sharding as shd
import numpy as np

mesh = jax.make_mesh(axis_shapes=(2, 4), axis_names=('X', 'Y'),
                     axis_types=(shd.AxisType.Explicit, shd.AxisType.Explicit))
jax.set_mesh(mesh)

B, D, F = 1024, 2048, 8192

A = jnp.arange(np.prod((B, D))).reshape((B, D))
W = jnp.arange(np.prod((D, F))).reshape((D, F))

A = jax.device_put(A, jax.P('X', 'Y'))
W = jax.device_put(W, jax.P(None, 'Y'))

@functools.partial(jax.jit, out_shardings=jax.P('X', 'Y'))
def matmul(lhs, rhs):
  return lhs @ rhs

def collective_matmul_allgather_lhs_contracting(lhs, rhs):
  # lhs is the looped operand; rhs is the local operand
  axis_size = jax.lax.axis_size('Y')  # axis_size = 4 for this example
  idx = jax.lax.axis_index('Y')

  chunk_size = lhs.shape[1]
  assert rhs.shape[0] % chunk_size == 0

  def f(i, carrys):
    accum, lhs = carrys
    rhs_chunk = jax.lax.dynamic_slice_in_dim(rhs, (idx + i) % axis_size * chunk_size, chunk_size)
    # Matmul for a chunk
    update = lhs @ rhs_chunk
    # Circular shift to the left
    lhs = jax.lax.ppermute(
        lhs,
        axis_name='Y',
        perm=[(j, (j - 1) % axis_size) for j in range(axis_size)]
    )
    return accum + update, lhs

  accum = jnp.zeros((lhs.shape[0], rhs.shape[1]), dtype=lhs.dtype)
  accum = jax.lax.pvary(accum, ('X', 'Y'))
  accum, lhs = jax.lax.fori_loop(0, axis_size - 1, f, (accum, lhs), unroll=True)

  # Compute the last chunk after the final permute to leave lhs in the state we found it
  i = axis_size - 1
  rhs_chunk = jax.lax.dynamic_slice_in_dim(rhs, (idx + i) % axis_size * chunk_size, chunk_size)
  update = lhs @ rhs_chunk
  return accum + update

jit_sharded_f = jax.jit(jax.shard_map(
    collective_matmul_allgather_lhs_contracting,
    in_specs=(jax.P('X', 'Y'), jax.P(None, 'Y')), out_specs=jax.P('X', 'Y')))

shmapped_out = jit_sharded_f(A, W)
expected_out = matmul(A, W)

np.testing.assert_array_equal(shmapped_out, expected_out)
```
This is pretty neat! We can benchmark this and see that it's also a lot faster! Here's the profile of the default jit matmul, which takes 311us and has a big blocking AllGather at the beginning:

And here's the version above, which takes 244us. You can see the profile doesn't have the AllGather: it's all useful work! Our FLOPs utilization is also a lot higher.
It’s also worth noting that the matmul time with no sharding on the contracting dimension is 224us, so we’re remarkably close to the unsharded baseline here. This is a good example of the kind of performance engineering you might end up doing to improve TPU utilization. For more shard_map
examples, this note is great.
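If you want to reproduce rough timings like the ones above yourself, here's a minimal wall-clock benchmarking sketch. The helper is not part of the original example, the iteration count is arbitrary, and the exact numbers will depend on your slice:

```python
import time

def benchmark(f, *args, iters=10):
  # Run once first so compilation isn't included in the timing.
  jax.block_until_ready(f(*args))
  start = time.perf_counter()
  for _ in range(iters):
    out = f(*args)
  jax.block_until_ready(out)
  return (time.perf_counter() - start) / iters

print('collective matmul:', benchmark(jit_sharded_f, A, W))
print('naive jit matmul: ', benchmark(matmul, A, W))
```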
Now here are a couple of useful worked problems to try and implement using jax.jit
or shard_map
!
Here are some random JAX-related problems. I’ll add some more later. For all of these, you’ll need some number of TPUs in a Colab. You can use a public Colab with TPUv2-8. From now on, we’ll assume you have N devices available.
Problem 1: Let A be an array of activations of shape float32[S_X, D_Y] with X * Y = N
. Do the following:
Write a function in JAX that computes the average within each (X, Y)
shard, i.e. it returns an array of size [X, Y] where arr[i, j]
is the average over shard (i, j)
. Do this with both jax.jit
and shard_map
. Profile each and see how long they took. Was there any communication added? Hint: there shouldn’t be, but sometimes XLA adds it anyway.
Write a function in JAX that returns roll(x, shift, axis=0) - x for some shift within each shard X. I’m not enough of a masochist to make you do this in jax.jit, so just do this with shard_map
.
Part 1: Here is a solution to part 1. Note the fairly complex reshapes we have to do for the jax.jit
solution.
```python
import numpy as np

import jax
import jax.numpy as jnp

P = jax.sharding.PartitionSpec

mesh = jax.make_mesh((4, 2), ('X', 'Y'))

average_shmap = jax.shard_map(
    lambda x: x.mean(keepdims=True),
    mesh=mesh,
    in_specs=P('X', 'Y'), out_specs=P('X', 'Y')
)

def average(x):
  X, Y = mesh.axis_sizes
  return x.reshape(X, x.shape[0] // X, Y, x.shape[1] // Y).mean(axis=(1, 3))

average_jit = jax.jit(average, out_shardings=jax.NamedSharding(mesh, P('X', 'Y')))

x = jnp.arange(8 * 64 * 8, dtype=jnp.int32).reshape(8 * 64, 8)
x = jax.device_put(x, jax.NamedSharding(mesh, P('X', 'Y')))

y1 = average_shmap(x)
y2 = average_jit(x)

np.testing.assert_array_equal(y1, y2)
```
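The problem also asks you to profile both versions. One way to do that (the trace directory below is just a placeholder) is to wrap the calls in jax.profiler.trace, as covered in Part 9, and open the resulting trace in your profiler of choice:

```python
# '/tmp/average-profile' is an arbitrary output directory.
with jax.profiler.trace('/tmp/average-profile'):
  jax.block_until_ready(average_shmap(x))
  jax.block_until_ready(average_jit(x))
```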
Part 2: Here is a similar solution.
```python
import functools

import numpy as np

import jax
import jax.numpy as jnp

P = jax.sharding.PartitionSpec

mesh = jax.make_mesh((4, 2), ('X', 'Y'))

def shift_shmap(x, shift: int):
  shmapped = jax.shard_map(
      # roll(x, shift, axis=0) - x within each local shard
      lambda x: jnp.roll(x, shift, axis=0) - x,
      mesh=mesh,
      in_specs=P('X', 'Y'), out_specs=P('X', 'Y')
  )
  return shmapped(x)

@functools.partial(jax.jit, static_argnames=['shift'], out_shardings=jax.NamedSharding(mesh, P('X', 'Y')))
def shift_jit(x, shift: int):
  X, Y = mesh.axis_sizes
  # Group the rows belonging to each X shard, roll within each group, then subtract x.
  reshaped = x.reshape(X, x.shape[0] // X, -1)
  return (jnp.roll(reshaped, shift, axis=1) - reshaped).reshape(x.shape[0], x.shape[1])

x = jnp.arange(8 * 64 * 8, dtype=jnp.int32).reshape(8 * 64, 8)
x = jax.device_put(x, jax.NamedSharding(mesh, P('X', 'Y')))

y1 = shift_shmap(x, 5)
y2 = shift_jit(x, 5)

np.testing.assert_array_equal(y1, y2)
```
Problem 2: Here we'll make a basic "mixture of experts" model together. Let W: float32[E_X, D, F_Y] be a set of E "expert" matrices. Let A: float32[S_X, D_Y] be our activations, and let B be a set of "routing assignments", where B[i] is an integer in the range [0, E) telling us which expert we want to process that activation with. We want to write a function in JAX that returns Out[i] = A[i] @ W[B[i]].
Let's start by ignoring sharding altogether. Make all of these tensors small enough that they fit on one device. Write a local implementation of this function. Make sure you don't materialize an array of shape [S, D, F]! Hint: try sorting the tokens into a new buffer of shape [E, S, D] with some attention to masking (why do we need the second dimension to have size S?). A possible reference implementation is sketched after this problem.
If you just jax.jit
the above method, something will happen. Profile this and see what communication it decided to do. How long does it take?
One problem you'll notice with the above is that it likely gathers the full set of activations A locally, i.e. AllGather_X([S_X, D_Y]). Not only is this expensive communication-wise, it's also incredibly expensive memory-wise if we can't fit the full set of activations locally. Implement the above using shard_map and explicit communication.
For a first pass, it might be easiest to use a jax.lax.all_gather
and reorder as in (a).
For a second pass, try to avoid materializing any array of size [E, S, D]
, i.e. try to perform the computation in a ragged fashion using a jax.lax.all_to_all
inside a jax.lax.while_loop
. This way, you can avoid materializing the full activations and wasting compute on padding. How much faster is this than your original implementation?
Most MoEs route to multiple (k) experts and then average the result. Refactor the above to implement this. Let B: int32[S, k] in this case for the k experts to route to.
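For part (a), here is one possible local (unsharded) reference implementation to check a fancier sorted version against. It masks tokens into an [E, S, D] buffer with a one-hot of the routing assignments rather than sorting, which is wasteful but never materializes [S, D, F]. The function name and the shapes below are just illustrative:

```python
import jax
import jax.numpy as jnp

def moe_local(A, W, B):
  # A: [S, D] activations, W: [E, D, F] expert weights, B: int32[S] routing assignments.
  E = W.shape[0]
  mask = jax.nn.one_hot(B, E, dtype=A.dtype)      # [S, E], 1 where token s routes to expert e
  grouped = jnp.einsum('se,sd->esd', mask, A)     # [E, S, D], zeros for tokens not routed to e
  outs = jnp.einsum('esd,edf->esf', grouped, W)   # [E, S, F], one matmul per expert
  return outs[B, jnp.arange(A.shape[0])]          # [S, F], pick each token's own expert's row

# Tiny example just to exercise the function.
S, D, F, E = 32, 16, 24, 4
A = jnp.ones((S, D), dtype=jnp.float32)
W = jnp.ones((E, D, F), dtype=jnp.float32)
B = jnp.zeros((S,), dtype=jnp.int32)
print(moe_local(A, W, B).shape)  # (32, 24)
```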
Problem 3: The collective matmul example above is actually super relevant for real LLMs. Let’s tweak the example to do the full Transformer stack.
As an exercise, let's start by implementing an AllReduce collective matmul, i.e. A[B_X, D_Y] *_D W[D_Y, F] -> Out[B_X, F]. Note that the output isn't replicated. The naive algorithm is discussed above: basically just a local matmul followed by an AllReduce (a sketch of this baseline is given after this problem). Try to make a comms-overlapped "collective" version of this operation. Hint: tile over the output dimension and feel free to use jax.lax.psum (aka AllReduce). Note: due to the way XLA handles this, it may not actually be faster than the baseline.
The complement to the AllReduce collective matmul above is a ReduceScatter collective matmul, as in Tmp[B_X, F_Y] *_F W2[F_Y, D] -> Out[B_X, D_Y]. This occurs in the down-projection matrix in a Transformer. Implement a collective, overlapped version of this in JAX. Be careful about passing only the minimal amount of data you need. Hint: try permuting the result as you accumulate it.
Put these two together into an end-to-end Transformer block that performs In[B_X, D_Y] *_D Win[D, F_Y] *_F Wout[F_Y, D] -> Out[B_X, D_Y] with overlapped communication. How does this compare to a naive jax.jit implementation?
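For reference, here is a hedged sketch of the naive (non-overlapped) baseline for part (a): a local matmul over the local D chunk followed by a psum (AllReduce) over Y. The function name and shapes are ours, and it assumes a mesh with axes 'X' and 'Y' has already been set via jax.set_mesh, as in the collective matmul example above:

```python
import jax
import jax.numpy as jnp

@jax.shard_map(in_specs=(jax.P('X', 'Y'), jax.P('Y', None)), out_specs=jax.P('X', None))
def naive_allreduce_matmul(a, w):
  # a: local [B/X, D/Y] chunk, w: local [D/Y, F] chunk.
  w = jax.lax.pvary(w, 'X')              # mark w as (trivially) varying over X as well
  partial_out = a @ w                    # [B/X, F] partial sums over this D chunk
  return jax.lax.psum(partial_out, 'Y')  # AllReduce the partial sums across Y

# Illustrative inputs: A[B_X, D_Y], W[D_Y, F].
A_test = jax.device_put(jnp.ones((256, 512), dtype=jnp.bfloat16), jax.P('X', 'Y'))
W_test = jax.device_put(jnp.ones((512, 1024), dtype=jnp.bfloat16), jax.P('Y', None))
out = jax.jit(naive_allreduce_matmul)(A_test, W_test)  # Out[B_X, F], replicated over Y
```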
Problem 4: All of the collective matmuls implemented above are unidirectional: they only permute in one direction. Rewrite the collective AllReduce matmul and the collective ReduceScatter matmuls to use bidirectional communication. How much faster are these?