Part 3 of How To Scale Your Model (Part 2: TPUs | Part 4: Transformer Math)
Here we'll explain how sharding works, how TPUs communicate with each other (emphasizing 4 core communication primitives) and how communication is performed by our hardware.
When we train an LLM on ten thousand TPUs, we’re still doing abstractly the same computation as when we’re training on one. The difference is that our arrays don’t fit in the HBM of a single TPU, so we have to split them up.
Here’s an example 2D array A sharded across 4 TPUs:
Note how the sharded array still has the same global or logical shape as the unsharded array, say (4, 128), but it also has a device local shape, like (2, 64), which tells us the actual chunk (and hence the number of bytes) that each TPU is holding (in the figure above, each TPU holds ¼ of the total array). Now we'll generalize this to arbitrary arrays.
We use a variant of named-axis notation to describe how the tensor is sharded in blocks across the devices: we assume the existence of a 2D or 3D grid of devices called the device mesh where each axis has been given mesh axis names e.g. X, Y, and Z. We can then specify how the matrix data is laid out across the device mesh by describing how each named dimension of the array is partitioned across the physical mesh axes. We call this assignment a sharding.
Example (the diagram above): For the above diagram, we have Mesh(devices=((0, 1), (2, 3)), axis_names=('X', 'Y')), which tells us we have 4 TPUs in a 2x2 grid with axis names X and Y, together with a sharding A[I_X, J_Y], which tells us to partition the first dimension I across mesh axis X and the second dimension J across mesh axis Y. Taken together, we know that the local shape of the array (the size of the shard that an individual device holds) is (|I| / 2, |J| / 2) = (2, 64).
Example (2D sharding across 1 axis): We can also shard a single array dimension across more than one mesh axis, e.g. A[I_XY, J], which partitions the first dimension I across both X and Y together (4 devices), leaving J unsharded. The local shape is then (|I| / 4, |J|).
Visualizing these shardings: Let's visualize these shardings by looking at a 2D array of data split over 4 devices:
We write the fully-replicated form of the matrix simply as A[I, J], with no subscripts: every device holds a complete copy of the array.
When we wish to indicate that one of these dimensions has been partitioned across a mesh axis, we indicate so using a mesh-axis subscript. For instance, A[I_X, J] means the first dimension I is partitioned across the X mesh axis while the second dimension J is left unpartitioned.
We illustrate the other possibilities in the figure below:
Here, for example, A[I_XY, J] means the first dimension is sharded jointly across both the X and Y mesh axes, so each device holds only 1 / (|X| · |Y|) of the rows.
Lastly, note that we cannot have multiple named dimensions sharded along the same mesh axis: e.g. A[I_X, J_X] is not a valid sharding, since the X mesh axis can only be used once to partition the array.
Pop Quiz: Let A be an array with shape int8[128, 2048], sharded A[I_XY, J] over Mesh({'X': 2, 'Y': 8, 'Z': 2}) (so 32 devices total). How much memory does A use per device? How much total memory does A use across all devices?
Answer: Our array A is sharded over X and Y and replicated over Z, so per device it has shape int8[128 / (2 * 8), 2048] = int8[8, 2048], with size 8 * 2048 = 16,384 bytes. Because it's replicated over Z, while within a Z-plane it's fully sharded over X and Y, there's one full copy per Z-plane and 2 such planes, so the total size across all devices is 128 * 2048 * 2 = 512kiB.
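To make the arithmetic concrete, here's a minimal sketch (assuming a 32-device runtime, e.g. forced on CPU with XLA_FLAGS=--xla_force_host_platform_device_count=32) that builds this exact sharding in JAX and checks the per-device and total sizes:

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec

mesh = jax.make_mesh(axis_shapes=(2, 8, 2), axis_names=('X', 'Y', 'Z'))
# Shard the first dimension over X and Y together; replicate over Z.
A = jnp.zeros((128, 2048), dtype=jnp.int8,
              device=NamedSharding(mesh, PartitionSpec(('X', 'Y'), None)))

print(A.addressable_shards[0].data.shape)                # (8, 2048) -> 16,384 bytes per device
print(sum(s.data.nbytes for s in A.addressable_shards))  # 524,288 bytes = 512 kiB across all devices
```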
JAX uses a named sharding syntax that very closely matches the abstract syntax we describe above. We'll talk more about this in Section 10, but here's a quick preview. You can play with this in a Google Colab here and profile the result to see how JAX handles different shardings. This snippet does 3 things: it creates a device mesh, shards two arrays A and B over it, and performs a matmul on the sharded arrays.
```python
import jax
import jax.numpy as jnp
import jax.sharding as shd
# Create our mesh! We're running on a TPU v2-8 4x2 slice with names 'X' and 'Y'.
assert len(jax.devices()) == 8
mesh = jax.make_mesh(axis_shapes=(4, 2), axis_names=('X', 'Y'))
# A little utility function to help define our sharding. A PartitionSpec is our
# sharding (a mapping from axes to names).
def P(*args):
  return shd.NamedSharding(mesh, shd.PartitionSpec(*args))
# We shard both A and B over the non-contracting dimension and A over the contracting dim.
A = jnp.zeros((8, 2048), dtype=jnp.bfloat16, device=P('X', 'Y'))
B = jnp.zeros((2048, 8192), dtype=jnp.bfloat16, device=P(None, 'Y'))
# We can perform a matmul on these sharded arrays! out_shardings tells us how we want
# the output to be sharded. JAX/XLA handles the rest of the sharding for us.
compiled = jax.jit(lambda A, B: jnp.einsum('BD,DF->BF', A, B), out_shardings=P('X', 'Y')).lower(A, B).compile()
y = compiled(A, B)
```
The cool thing about JAX is that these arrays behave as if they're unsharded! B.shape will tell us the global or logical shape (2048, 8192). We have to actually look at B.addressable_shards to see how it's locally sharded. We can perform operations on these arrays and JAX will attempt to figure out how to broadcast or reshape them to perform the operations. For instance, in the above example, the local shape of A is [2, 1024] and for B it is [2048, 4096]. JAX/XLA will automatically add communication across these arrays as necessary to perform the final multiplication.
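For example, continuing from the snippet above, a quick way to inspect these local shapes (jax.debug.visualize_array_sharding prints a small diagram of which devices hold which slice):

```python
print(A.shape)                              # global/logical shape: (8, 2048)
print(A.addressable_shards[0].data.shape)   # local shard shape: (2, 1024)
print(B.addressable_shards[0].data.shape)   # local shard shape: (2048, 4096)
jax.debug.visualize_array_sharding(A)       # picture of the device layout
```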
If you have an array of data that’s distributed across many devices and wish to perform mathematical operations on it, what are the overheads associated with sharding both the data and the computation?
Obviously, this depends on the computation involved.
The rest of this section will deal with how to multiply sharded matrices. To a first approximation, this involves moving chunks of a matrix around so you can fully multiply or sum each chunk. Each sharding will involve different communication. For example, A[I_X, J] · B[J, K] → C[I_X, K] requires no communication at all (the contracting dimension is unsharded), while A[I, J_X] · B[J_X, K] → C[I, K] requires us either to gather shards or to sum partial results across devices.
First let’s recall the concept of a “block matrix”, or a nested matrix of matrices:
Matrix multiplication has the nice property that when the matrix multiplicands are written in terms of blocks, the product can be written in terms of block matmuls following the standard rule:
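For example, with 2×2 blocks the rule reads:

$$\begin{pmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{pmatrix}
\begin{pmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{pmatrix}
=
\begin{pmatrix} A_{11}B_{11} + A_{12}B_{21} & A_{11}B_{12} + A_{12}B_{22} \\ A_{21}B_{11} + A_{22}B_{21} & A_{21}B_{12} + A_{22}B_{22} \end{pmatrix}$$

so each block of the product is a sum of matmuls between blocks of the inputs, exactly as in the scalar case.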
What this means is that implementing distributed matrix multiplications reduces down to moving these sharded blocks over the network, performing local matrix multiplications on the blocks, and summing their results. The question then is what communication to add, and how expensive it is.
Conveniently, we can boil down all possible shardings into roughly 4 cases we need to consider, each of which has a rule for what communication we need to add:

1) Case 1: neither multiplicand has its contracting dimension sharded. No communication is needed.
2) Case 2: one multiplicand has its contracting dimension sharded. We typically AllGather the sharded input before multiplying.
3) Case 3: both multiplicands have their contracting dimension sharded along the same mesh axis. We multiply the local shards, then AllReduce (or ReduceScatter) the partial sums.
4) Case 4: both multiplicands have a non-contracting dimension sharded along the same mesh axis. This is invalid as written; we must first AllGather one of the inputs.

You can think of these as rules that simply need to be followed, but it's also valuable to understand why these rules hold and how expensive they are. We'll go through each one of these in detail now.
Lemma: when multiplying partitioned tensors, the computation is valid and the output follows the sharding of the inputs, unless the contracting dimension is sharded or both tensors have a non-contracting dimension sharded along the same axis. For example, this works fine:

A[I_X, J] · B[J, K_Y] → C[I_X, K_Y]

with no communication whatsoever, and results in a tensor sharded across both the X and Y hardware dimensions. Try to think about why this is. Basically, the computation is independent of the sharding: each device holds the full contracting axis for its local chunk of the non-contracting dimensions, so it can multiply and reduce locally. Any of these cases work fine and follow this rule:
Because neither A nor B has a sharded contracting dimension J, we can simply perform the local block matrix multiplies of the inputs and the results will already be sharded according to the desired output shardings. When both multiplicands have non-contracting dimensions sharded along the same axis, this is no longer true (see the invalid shardings section for details).
Let us consider the simple case of the distributed matrix multiply of A, sharded in the contracting J dimension, against a fully replicated B:

A[I, J_X] · B[J, K] → C[I, K]

We cannot simply perform local matrix multiplies of the local A, B blocks against one another, since each device is missing the full data from the contracting axis of A. Typically, we first "AllGather" the shards of A together locally, and only then multiply against B:

AllGather_X(A[I, J_X]) → A[I, J], then A[I, J] · B[J, K] → C[I, K]
AllGathers remove sharding along an axis and reassemble the shards spread across devices onto each device along that axis. Using the notation above, an AllGather removes a subscript from a set of axes, e.g.

AllGather_X(A[I_X, J]) → A[I, J]

We also don't have to remove all subscripts for a given dimension, e.g. AllGather_Y(A[I_XY, J]) → A[I_X, J] gathers only over the Y axis and leaves the array sharded over X.
Note that we may also wish to use an AllGather to remove non-contracting dimension sharding, for instance in the matrix multiply

A[I_X, J] · B[J, K] → C[I, K]

where we would similarly AllGather along X to remove the output sharding. However, in this case we have the freedom to do so before or after the matrix multiply, unlike when AllGathering the contracting dimension, where we are forced to do so before performing the matrix multiply.
How is an AllGather actually performed? To perform an AllGather along a single axis, we need to pass all the shards around the axis until every device has a copy. Figure 1 shows an example: each of the 8 devices starts with 1/8th of the array and ends up with all of it. One efficient way to do this is to have each device pass its shard around the ring formed by the sharding dimension, either in one direction or both directions. If we do one direction, it takes |X| - 1 hops; if we use both directions, it takes roughly ⌈|X| / 2⌉ hops.
How long does this take? Let's take the bidirectional AllGather and calculate how long it takes. Let V be the number of bytes in the full array and |X| the number of shards along the gathered axis. Then each hop sends a shard of V / |X| bytes in each direction over each link, and we need roughly |X| / 2 hops, so the total time is

$$T_{\text{total}} \approx \frac{2 \cdot V / |X|}{W_{\text{ici}}} \cdot \frac{|X|}{2} = \frac{V}{W_{\text{ici}}}$$

where W_ici is the bidirectional bandwidth of a single ICI link. Note that this doesn't depend on |X|, the number of devices over which the array is sharded!
Takeaway: when performing an AllGather (or a ReduceScatter or AllReduce) in a throughput-bound regime, the actual communication time depends only on the size of the array and the available bandwidth, not the number of devices over which our array is sharded!
A note on ICI latency: Each hop over an ICI link has some intrinsic overhead regardless of the data volume, typically around 1us. This means that when our array is small enough, the per-hop transfer time drops below this fixed latency and the AllGather becomes latency bound rather than bandwidth bound. Concretely, with 4.5e10 bytes/s of unidirectional ICI bandwidth, a hop takes roughly max(1us, bytes per hop / 4.5e10), so sending any per-hop buffer under 4.5e10 * 1e-6 = 45kB will be latency bound.
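As a quick back-of-the-envelope sketch (assuming the ~1us hop latency and 4.5e10 bytes/s per-direction link bandwidth quoted above):

```python
# Rough per-hop time model: the max of the fixed hop latency and the serialization time.
HOP_LATENCY = 1e-6   # ~1us of intrinsic overhead per ICI hop
W_UNI = 4.5e10       # bytes/s over one direction of one ICI link

def hop_time(buffer_bytes):
    return max(HOP_LATENCY, buffer_bytes / W_UNI)

print(hop_time(16_000))     # 1e-06   -> latency bound (buffer < 45kB)
print(hop_time(45_000))     # 1e-06   -> right at the crossover
print(hop_time(1_000_000))  # ~2.2e-05 -> bandwidth bound
```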
What happens when we AllGather over multiple axes? When we gather over multiple axes, we have multiple dimensions of ICI over which to perform the gather. For instance, AllGather_XY([B, D_XY]) → [B, D] operates over two hardware mesh axes. This increases the available bandwidth by a factor of n_axes, the number of axes we gather over. In general we have

$$T_{\text{total}} \approx \frac{V}{W_{\text{ici}} \cdot n_{\text{axes}}}$$

where n_axes is the number of mesh axes over which we perform the gather.
Pop Quiz 2 [AllGather time]: Using the numbers from Part 2, how long does it take to perform AllGather_Y([E_Y, F]) → [E, F] on a TPU v5e with a 2D mesh {'X': 8, 'Y': 4}, with E = 64 and F = 256 in bfloat16?

Answer: Let's start by calculating some basic quantities:

1) A TPU v5e has 4.5e10 bytes/s of unidirectional ICI bandwidth for each of its 2 axes.
2) In bfloat16, the array has 64 * 256 * 2 = 32kB total.

Using the formula above, if we were bandwidth bound, the AllGather over one axis would take about 32e3 / 4.5e10 = 0.7us, which is below the per-hop latency, so we're latency bound. Since the farthest device on the Y axis is 3 hops away, this will take roughly 3 * 1us = 3us. In practice, it's closer to 8us.
The third fundamental case is when both multiplicands are sharded on their contracting dimensions, along the same mesh axis:

A[I, J_X] · B[J_X, K]

In this case the local sharded block matrix multiplies are at least possible to perform, since they share the same sets of contracting indices. But each product will only represent a partial sum of the full desired product, and each device along the X dimension will be left with a different partial sum of this final desired product. This is so common that we extend our notation to explicitly mark this condition:

A[I, J_X] · B[J_X, K] → C[I, K]{ U_X }

The notation { U_X } reads "unreduced along the X mesh axis" and refers to the operation being "incomplete" in a sense, in that it will only be finished pending a final sum: each device along X holds a full-shaped C[I, K] that contains only its own partial contribution.
This can be seen as the following result about matrix multiplications and outer products:

$$A \cdot B = \sum_i A_{:,i} \otimes B_{i,:}$$

where ⊗ is the outer product (here applied block-wise to A's block-columns and B's block-rows). Thus, if TPU i on axis X has the ith (block) column of A and the ith (block) row of B, we can do a local matrix multiplication to obtain the ith term A_{:,i} ⊗ B_{i,:} of this sum; what remains is to sum these terms across the devices on axis X.
We can perform this summation using a full AllReduce across the X axis to remedy this:

AllReduce_X(C[I, K]{ U_X }) → C[I, K]
AllReduce removes partial sums, resulting in each device along the axis having the same fully-summed value. AllReduce is the second of several key communications we'll discuss in this section, the first being the AllGather, and the others being ReduceScatter and AllToAll. An AllReduce takes an array with an unreduced (partially summed) axis and performs the sum by passing those shards around the unreduced axis and accumulating the result. The signature is

AllReduce_X(A[I, J]{ U_X }) → A[I, J]

This means it simply removes the { U_X } suffix, leaving the sharding of the named dimensions unchanged.
How expensive is an AllReduce? One mental model for how an AllReduce is performed is that every device sends its shard to its neighbors, and sums up all the shards that it receives. Clearly, this is more expensive than an AllGather because each “shard” has the same shape as the full array. Generally, an AllReduce is twice as expensive as an AllGather. One way to see this is to note that an AllReduce can be expressed as a composition of two other primitives: a ReduceScatter and an AllGather. Like an AllReduce, a ReduceScatter resolves partial sums on an array but results in an output ‘scattered’ or partitioned along a given dimension. AllGather collects all those pieces and ‘unpartitions/unshards/replicates’ the logical axis along that physical axis.
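Here's a minimal sketch (assuming an 8-device runtime and a recent JAX with jax.experimental.shard_map) checking that an AllReduce gives the same result as a ReduceScatter followed by an AllGather over the same axis:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

assert len(jax.devices()) == 8
mesh = Mesh(jax.devices(), axis_names=('X',))

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P('X'), out_specs=(P('X'), P('X')))
def allreduce_two_ways(x):
  summed = jax.lax.psum(x, 'X')                              # AllReduce
  scattered = jax.lax.psum_scatter(x, 'X', tiled=True)       # ReduceScatter ...
  gathered = jax.lax.all_gather(scattered, 'X', tiled=True)  # ... then AllGather
  return summed, gathered

x = jnp.arange(64.0)       # local shape (8,) on each of the 8 devices
a, b = allreduce_two_ways(x)
assert jnp.allclose(a, b)  # identical, but the second path lets us defer the AllGather
```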
What about a ReduceScatter? Just as the AllReduce removes an unreduced suffix (A[I, J]{ U_X } → A[I, J]), a ReduceScatter performs the same sum but leaves the output sharded along that axis:

ReduceScatter_X,I(A[I, J]{ U_X }) → A[I_X, J]

How long does it take? Like the AllGather, we pass shards around the ring and accumulate as we go. The communication time for each hop is simply the per-shard bytes V / |X| divided by the link bandwidth, and we need roughly |X| hops overall (split between the two directions), so

$$T_{\text{total}} \approx \frac{V}{W_{\text{ici}}}$$

where V is the size of the full (per-device, unreduced) array in bytes and W_ici is the bidirectional ICI bandwidth; this is the same cost as an AllGather.
Each mesh dimension can appear at most once when sharding a tensor. Performing the above rules can sometimes lead to a situation where this rule is violated, such as:

A[I_X, J] · B[J, K_X] → C[I_X, K_X]
This is invalid because a given shard, say i, along dimension X, would have the (i, i)th shard of C, that is, a diagonal entry. There is not enough information among all shards, then, to recover anything but the diagonal entries of the result, so we cannot allow this sharding.
The way to resolve this is to AllGather one of the inputs first. Here we have two choices:

AllGather_X(A[I_X, J]) → A[I, J], then A[I, J] · B[J, K_X] → C[I, K_X]

or

AllGather_X(B[J, K_X]) → B[J, K], then A[I_X, J] · B[J, K] → C[I_X, K]
In either case, the result will only mention X once in its shape. Which one we pick will be based on what sharding the following operations need.
The previous 4 cases have introduced several "core communication primitives" used to perform sharded matrix multiplications:

1) AllGather: removes a subscript, collecting the shards of a sharded dimension onto every device along an axis.
2) ReduceScatter: removes an "unreduced" suffix by summing shards across an axis, leaving the array sharded along that axis.
3) AllReduce: removes an "unreduced" suffix and leaves the result unsharded along that axis; equivalent to a ReduceScatter followed by an AllGather.
There’s one more core communication primitive to mention that arises in the case of Mixture of Experts (MoE) models and other computations: the AllToAll.
A final fundamental collective which does not occur naturally when considering sharded matrix multiplies, but which comes up constantly in practice, is the AllToAll collective, or more precisely the special case of a sharded transposition or resharding operation, e.g.

AllToAll_X(A[I_X, J]) → A[I, J_X]

AllToAlls are typically required to rearrange sharded layouts between different regions of a sharded computation that don't have compatible layout schemes. They arise naturally when considering sharded mixture-of-experts models. You can think of an AllToAll as moving a subscript from one array dimension to another. Because an AllToAll doesn't need to replicate all of the data of each shard across the ring, it's actually cheaper than an AllGather, costing roughly ¼ as much (see the table and Question 10 below).
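Here's a minimal sketch of this resharding view (same assumptions as the sketch above: 8 devices, jax.experimental.shard_map), moving the sharded subscript from the first dimension to the second:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

assert len(jax.devices()) == 8
mesh = Mesh(jax.devices(), axis_names=('X',))

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P('X', None), out_specs=P(None, 'X'))
def reshard(a):
  # AllToAll as a sharded transposition: A[I_X, J] -> A[I, J_X].
  return jax.lax.all_to_all(a, 'X', split_axis=1, concat_axis=0, tiled=True)

A = jnp.arange(64.0).reshape(8, 8)
print(reshard(A).shape)  # still (8, 8) globally; each device now holds a column block
```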
ReduceScatter is a more fundamental operation than it first appears, as it is actually the derivative of an AllGather, and vice versa. That is, if in the forward pass we have:

AllGather_X(A[I_X, J]) → A[I, J]

then we ReduceScatter the reverse-mode derivatives A' (which will in general be different on each shard) to derive the sharded A':

ReduceScatter_X,I(A'[I, J]{ U_X }) → A'[I_X, J]

Likewise, a ReduceScatter_X,I(A[I, J]{ U_X }) → A[I_X, J] in the forward pass gives an AllGather_X(A'[I_X, J]) → A'[I, J] in the backward pass.
Turning an AllReduce into an AllGather and a ReduceScatter also has the convenient property that we can defer the final AllGather until some later moment. Very commonly we'd rather not pay the cost of reassembling the full matrix product replicated across the devices. Rather, we'd like to preserve a sharded state even in this case of combining two multiplicands with sharded contracting dimensions:

A[I, J_X] · B[J_X, K] → C[I_X, K]

In this case, we can perform a ReduceScatter instead of an AllReduce, and then optionally perform the AllGather at some later time, i.e.

A[I, J_X] · B[J_X, K] → C[I, K]{ U_X }, then ReduceScatter_X,I(C[I, K]{ U_X }) → C[I_X, K]
Note that ReduceScatter introduces a sharded dimension, and so has a natural freedom to shard along either the I or K named dimensions in this case. We generally need to choose which named dimension to introduce a new sharding to when using a ReduceScatter (though the choice is usually forced by the larger modeling context). This is why we use the syntax ReduceScatter_X,K to specify the axis to shard.
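As a sketch of this choice (again assuming 8 devices and jax.experimental.shard_map), the scatter_dimension argument of jax.lax.psum_scatter plays the role of the ",I" or ",K" in our notation:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

assert len(jax.devices()) == 8
mesh = Mesh(jax.devices(), axis_names=('X',))

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P(None, 'X'), P('X', None)),
         out_specs=(P('X', None), P(None, 'X')))
def matmul_reduce_scatter(a, b):
  # A[I, J_X] . B[J_X, K]: the local matmul produces partial sums C[I, K]{U_X}.
  c_partial = a @ b
  # ReduceScatter_X,I -> C[I_X, K] versus ReduceScatter_X,K -> C[I, K_X].
  c_i = jax.lax.psum_scatter(c_partial, 'X', scatter_dimension=0, tiled=True)
  c_k = jax.lax.psum_scatter(c_partial, 'X', scatter_dimension=1, tiled=True)
  return c_i, c_k

A = jnp.ones((16, 64))  # I=16, J=64
B = jnp.ones((64, 32))  # J=64, K=32
c_i, c_k = matmul_reduce_scatter(A, B)
print(c_i.shape, c_k.shape)  # both (16, 32) globally, sharded along different dimensions
```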
Arithmetic with sharded arrays works exactly like with unsharded arrays unless you perform a contraction along a sharded axis. In that case, we have to introduce some communication. We consider four cases:
| Operation | Description | Syntax | Runtime |
|---|---|---|---|
| AllGather | Gathers all the shards of a sharded array along an axis, removing a subscript. | AllGather_X(A[I_X, J]) → A[I, J] | bytes / (bidirectional ICI bandwidth * num_axes) |
| ReduceScatter | Sums a partially summed array along an axis and shards it along another dimension (adding a subscript). | ReduceScatter_X,I(A[I, J]{ U_X }) → A[I_X, J] | Same as AllGather |
| AllReduce | Sums a partially summed array along an axis. Removes a { U_X }. Combines an AllGather and a ReduceScatter. | AllReduce_X(A[I, J]{ U_X }) → A[I, J] | 2 * AllGather |
| AllToAll | Gathers (replicates) an axis and shards a different dimension along the same axis. | AllToAll_X(A[I_X, J]) → A[I, J_X] | AllGather / 4 for a bidirectional ring |
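Written as formulas, with V the total bytes of the array, W_ici the bidirectional ICI bandwidth per axis, and n_axes the number of mesh axes communicated over:

$$T_{\text{AllGather}} = T_{\text{ReduceScatter}} \approx \frac{V}{W_{\text{ici}} \cdot n_{\text{axes}}}, \qquad T_{\text{AllReduce}} \approx 2\,T_{\text{AllGather}}, \qquad T_{\text{AllToAll}} \approx \frac{T_{\text{AllGather}}}{4}$$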
Here are some instructive problems based on content in this section. We won’t include all answers at the moment but we’ll write up more answers as we can.
Question 1 [replicated sharding]: An array is sharded A[I_X, J] over Mesh({'X': 4, 'Y': 8, 'Z': 2}). What is the ratio of the total number of bytes taken up by A across all devices to the number of bytes in a single copy of the array?

Answer: Our array is only sharded along X, which has size 4, so effectively each shard has size |I| · |J| / 4, and each shard is replicated across the Y and Z axes (8 * 2 = 16 devices). Across all 4 * 8 * 2 = 64 devices, the total is 64 * |I| · |J| / 4 = 16 full copies of the array, so the ratio is 16.
Question 2 [AllGather latency]: How long should AllGather_X([B_X, D_Y]) → [B, D_Y] take on a TPU v4p with Mesh({'X': 4, 'Y': 4, 'Z': 4}), if B = 1024 and D = 4096 in bfloat16? How about AllGather_XY([B_X, D_Y]) → [B, D]? How about AllReduce_Z([B_X, D_Y]{ U_Z })?

Answer: We have a wraparound link on all axes because we have a full 4x4x4 cube, so we have W = 9e10 bytes/s of bidirectional bandwidth to work with on each axis.

1) Because we're just gathering over one axis and the other dimension stays sharded over Y, we're effectively gathering only 2BD / 4 bytes, so T = 2BD / (4 * W) = 2 * 1024 * 4096 / (4 * 9e10) ≈ 23us.

2) We have twice the bandwidth as before (two axes) but we're AllGathering the full array, so T = 2BD / (2 * W) = 2 * 1024 * 4096 / (2 * 9e10) ≈ 46us. This is far from the latency bound of roughly 4us (1us per hop), so we're fine.

3) The cost of an AllReduce is twice that of an AllGather. Each shard has size 2BD / 16 bytes (the array is also sharded over X and Y), so T ≈ 2 * 2BD / (16 * W) = 4 * 1024 * 4096 / (16 * 9e10) ≈ 11.6us.
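A quick numeric check of these three estimates (using W = 9e10 bytes/s of bidirectional bandwidth per axis, as above):

```python
B, D, W = 1024, 4096, 9e10
print(2 * B * D / (4 * W))   # AllGather over X only: ~23us (array stays sharded over Y)
print(2 * B * D / (2 * W))   # AllGather over X and Y: ~46us
print(4 * B * D / (16 * W))  # AllReduce over Z:       ~11.6us
```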
Question 3 [latency-bound AllGather]: Let's say we're performing an AllGather_X([B_X]) with B = 128 on a TPU v4p with Mesh({'X': 4, 'Y': 4, 'Z': 4}), in bfloat16. How long does it take? Hint: you're probably latency bound.
Our array in bfloat16 uses only 256 bytes total, and only 64 bytes per device. Since we have an axis of size 4 on a TPU v4p, we have a wraparound link, so we can send shards in both directions. With 4.5e10 bytes/s of unidirectional bandwidth, each hop would take roughly 64 / 4.5e10 ≈ 1.4ns, essentially zero, so we're definitely latency bound. Counting the number of hops, we can do the full gather in only 2 hops, so roughly 2us is a good estimate.
Question 4 [matmul strategies]: To perform A[I, J_X] · B[J, K] → C[I, K], where only A has its contracting dimension sharded, this section says to AllGather A and then perform the matmul locally (Strategy 1). Instead, we could multiply the local shard of A against the matching rows of B and then AllReduce the resulting partial sums of C (Strategy 2). Which is better, and when?
Let's start with our baseline (Strategy 1). As we've shown, the cost of the AllGather is just the bytes of A over the ICI bandwidth, roughly 2IJ / W_ici in bfloat16, and the gathered A has to be available before (or overlapped with) the matmul.

By comparison, the new strategy (Strategy 2) does an AllReduce over the unreduced output C[I, K]{ U_X }, which costs about twice an AllGather of C, roughly 4IK / W_ici, while the local matmuls can start immediately with no up-front communication.
The question is: which of these is bigger? Roughly, Strategy 2 wins when the output C is much smaller than the gathered input A, i.e. when 4IK < 2IJ, or K < J / 2: when the contracting dimension J is much larger than the output dimension K, reducing the small output is cheaper than gathering the large input.
Why don't we always do this? Well, in practice we do sometimes, but it's typically rare to have the contracting dimension of one of the inputs to a matmul sharded along an axis that the other input isn't sharded over. For instance, if we're doing FSDP (explained in Section 5), we'll shard our parameters over the data dimension, but our activations will also be sharded along the data dimension. So this situation doesn't show up much in practice.
Question 5 [minimum latency]: Let’s say I want to do a matmul
Question 6: Let’s say we want to perform
Question 7: A typical Transformer block has two matrices
Question 8 [challenge]: Using the short code snippet above as a template, allocate a sharded array and benchmark each of the 4 main communication primitives (AllGather, AllReduce, ReduceScatter, and AllToAll) using pmap or shard_map. You will want to use jax.lax.all_gather, jax.lax.psum, jax.lax.psum_scatter, and jax.lax.all_to_all. Do you understand the semantics of these functions? How long do they take?
Question 9 [another strategy for sharded matmuls?]: Above we claimed that when only one input to a matmul is sharded along its contracting dimension, we should AllGather the sharded matrix and perform the resulting contraction locally. Another strategy you might think of is to perform the sharded matmul and then AllReduce the result (as if both inputs were sharded along the contracting dimension), i.e. multiply the local shards to get partial sums of the output and then AllReduce them away.
Answer the following: what is the communication cost of each strategy, and under what condition on the ratio M/K (the output dimension over the contracting dimension) is the AllReduce-based strategy cheaper?
.Question 10: Fun with AllToAll: In the table above, it was noted that the time to perform an AllToAll is a factor of 4 lower than the time to perform an AllGather or ReduceScatter (in the regime where we are throughput-bound). In this problem we will see where that factor of 4 comes from, and also see how this factor would change if we only had single-direction ICI links, rather than bidirectional ICI links.
(1) Solution: The process is simple: in each step of the algorithm, each device sends a single-shard "strip" of the matrix (totalling N²/D scalars, for an N×N matrix sharded over D devices) to its neighboring device. This happens D - 1 times, until every device holds every strip.

Answer: each link carries about (D - 1) · N²/D ≈ N² scalars over the course of the AllGather.
(2) Solution: The key difference between an AllToAll and an AllGather, from the perspective of communications, is that in an AllToAll, the entirety of the shard that lives on a particular device does not need to be communicated to every other device. Imagine the shard stored on a particular device (call it device 0) is itself divided into D blocks of N²/D² scalars, one destined for each device; the block destined for device j only needs to travel j hops around the ring. Summing over blocks (and averaging over devices by symmetry), each link carries about (N²/D²) · (1 + 2 + ⋯ + (D - 1)) = N²(D - 1) / (2D) scalars.

Answer: each link carries roughly N²/2 scalars, half of the AllGather's traffic.
(3) Solution: The factor is simply 2: in a unidirectional ring, the AllToAll moves about N²/2 scalars per link versus about N² for the AllGather, so the AllToAll is twice as fast.
(4) Solution: The total number of scalars that any one link has to carry now reduces by a factor of 2, since in a bidirectional ring, each “sharded strip” can be sent two ways simultaneously.
(5) Solution: In this case, we win a factor of 4 compared to the unidirectional case. This is easiest to see by considering the fate of each of the size-(N²/D²) blocks in a single sharded strip, say the one which originates on device 0. Instead of (as in the unidirectional case) sending one of these blocks a distance of D - 1, another block a distance D - 2, etc. all the way to 1, we now divide the strip into blocks which move right or left, moving a maximum distance of ceil(D/2). So the corresponding sum now becomes roughly 2 · (1 + 2 + ⋯ + D/2) ≈ D²/4 block-hops instead of ≈ D²/2, and that traffic is spread over twice as many links, for a total factor-of-4 win.
(6) Solution: In a unidirectional ring, we saw that the AllToAll was already twice as fast as the AllGather; this comes from the fact that we don't need to send our full strip to every single device. Then, when we added bidirectionality, we saw that it was a 4x win for the AllToAll, and only a 2x win for the AllGather. Putting these ratios together, we get our sought-after factor of 4.