When we run algorithms on hardware, we're bounded by three things: how fast it can do math (OPs/second), the bandwidth available for moving data around (bytes/second), and the total memory available to store data (bytes). These “roofline” constraints let us upper and lower bound the time of a given computation.
Let’s start with an extremely simple question: why does an algorithm take 50ms instead of 50s or 5ms? What is actually happening within the model that takes substantial time and how long should we expect it to take?
Computation: A deep learning model is effectively a bunch of matrix multiplications, each composed of floating-point multiplication and addition ‘operations’ (FLOPs). Our accelerator speed determines how long these take to compute:
\[\begin{equation} T_\text{math} = \frac{\text{Computation FLOPs}}{\text{Accelerator FLOPs/s}} \end{equation}\]

Communication within a chip: Within an accelerator, tensors need to be transferred between on-chip memory (HBM) and the compute cores. You’ll see the bandwidth of this link referred to as ‘HBM bandwidth’.
Communication between chips: When we distribute a model across multiple accelerators, tensors frequently need to be transferred between them. There are often a few options for this on our hardware (ICI, DCN, and PCIe), each with different bandwidths.
Whether the communication happens within a chip or between chips, we measure the bandwidth in GB/s and estimate the total communication time with:
\[\begin{equation} T_\text{comms} = \frac{\text{Communication GB}}{\text{Network/Memory Bandwidth GB/s}} \end{equation}\]

Typically (but not always), computation within a single chip can be overlapped with communication within a chip and between chips. This means we can lower-bound training and inference time by using the maximum of computation and communication time. We can also upper-bound with their sum. In practice, we optimize against the maximum as the algebra is simpler and we can usually come close to this bound by overlapping our communication and computation. If we optimize with the maximum in mind then the lower and upper bounds differ by at most a factor of 2, since $T_\text{math} + T_\text{comms} \leq 2 * \max(T_\text{math}, T_\text{comms})$. We then increase accuracy beyond this by modeling ‘overlap regions’ and overheads, which can be informed by profiling your specific model and target system.
\[\begin{equation} T_\text{lower} = \max(T_\text{math}, T_\text{comms}) \end{equation}\]

\[\begin{equation} T_\text{upper} = T_\text{math} + T_\text{comms} \end{equation}\]

If we assume we can perfectly overlap communication and computation, then when $T_\text{math} > T_\text{comms}$ we see full utilization from our hardware. We call this being “compute-bound”. When $T_\text{comms} > T_\text{math}$, we tend to be “communication-bound” and at least some fraction of our accelerator FLOPs/s is wasted waiting for data to be passed around. One way to tell if an operation will be compute or communication-bound is to look at its “arithmetic intensity” or “operational intensity”.
Definition: the arithmetic intensity of an algorithm is given by the ratio of the total FLOPs it performs to the number of bytes it needs to communicate — either within a chip or between chips.
\[\begin{equation} \text{Arithmetic Intensity} = \frac{\text{Computation FLOPs}}{\text{Communication GB}} \end{equation}\]

Arithmetic intensity measures the “FLOPs per byte” of a given operation. To first order, when our arithmetic intensity is high, $T_\text{math}$ is large compared to $T_\text{comms}$ and we typically use most of the available FLOPs. When the opposite is true, we spend more time on comms and waste FLOPs. The point where this crossover happens is the “peak arithmetic intensity” of our hardware, the ratio of peak accelerator FLOPs/s to accelerator bandwidth.
\[\begin{align*} T_\text{math} > T_\text{comms} \Leftrightarrow \frac{\text{Computation FLOPs}} {\text{Accelerator FLOPs/s}} > \frac{\text{Communication GB}}{\text{Network Bandwidth GB/s}} & \\[0.5em] \Leftrightarrow \frac{\text{Computation FLOPs}}{\text{Communication GB}} > \frac{\text{Accelerator FLOPs/s}}{\text{Network Bandwidth GB/s}} & \\[0.5em] \Leftrightarrow \text{Intensity}(\text{Algorithm}) > \text{Intensity}(\text{Accelerator}) & \\ \end{align*}\]

The quantity $\text{Intensity}(\text{Accelerator})$ is the arithmetic intensity at which our accelerator achieves its peak FLOPs/s. For the TPU v5e MXU, this is about 240 FLOPs/byte, since it can perform 1.97e14 FLOPs/s and load 8.2e11 bytes/s from HBM. That means if an algorithm has a lower arithmetic intensity than 240 FLOPs/byte, it will be bound by byte loading and won’t make good use of our hardware.
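To make this bookkeeping concrete, here is a minimal Python sketch of the rooflines above, using the TPU v5e numbers just quoted. The constant and function names are purely illustrative, not a real API.

```python
# A minimal roofline calculator using the TPU v5e numbers quoted above.
# (These helper names are illustrative, not a real API.)

TPU_V5E_FLOPS = 1.97e14   # peak bf16 FLOPs/s
TPU_V5E_HBM_BW = 8.2e11   # HBM bytes/s

def roofline(flops, bytes_moved, peak_flops=TPU_V5E_FLOPS, bandwidth=TPU_V5E_HBM_BW):
    """Estimate runtime bounds and whether an op is compute- or comms-bound."""
    t_math = flops / peak_flops
    t_comms = bytes_moved / bandwidth
    t_lower = max(t_math, t_comms)              # perfect overlap
    t_upper = t_math + t_comms                  # no overlap
    intensity = flops / bytes_moved
    critical_intensity = peak_flops / bandwidth # ~240 FLOPs/byte on the v5e
    bound = "compute-bound" if intensity > critical_intensity else "comms-bound"
    return t_math, t_comms, t_lower, t_upper, bound
```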
Example (dot product): to compute the dot product of two vectors in bfloat16 precision, `x • y: bf16[N], bf16[N] → bf16[1]`, we need to load $x$ and $y$ from memory, each of which has $2N$ bytes, perform $N$ multiplications and $N-1$ additions, and write $2$ bytes back into HBM:

\[\begin{equation} \text{Intensity}(\text{dot product}) = \frac{\text{Total FLOPs}}{\text{Total Bytes}} = \frac{N + N - 1}{2N + 2N + 2} = \frac{2N - 1}{4N + 2} \rightarrow \frac{1}{2} \end{equation}\]
as $N\rightarrow\infty$. So the dot product has an arithmetic intensity of $\frac{1}{2}$ or, put another way, the dot product does 0.5 floating point operations per byte loaded. This means our arithmetic intensity is lower than that of our hardware and we will be communication-bound.
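Plugging the dot-product counts into the sketch above (reusing the hypothetical `roofline` helper) confirms it sits far below the ~240 FLOPs/byte threshold:

```python
N = 1_000_000                     # example vector length
flops = N + (N - 1)               # N multiplications and N - 1 additions
bytes_moved = 2 * N + 2 * N + 2   # load x and y in bf16, write 2 bytes back

print(flops / bytes_moved)               # ~0.5 FLOPs/byte
print(roofline(flops, bytes_moved)[-1])  # 'comms-bound'
```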
We can visualize the tradeoff between memory and compute using a roofline plot, which plots the peak achievable FLOPs/s (throughput) of an algorithm on our hardware (the y-axis) against the arithmetic intensity of that algorithm (the x-axis). Here’s a fake diagram:
Above, as the intensity increases (moving left to right), we initially see a linear increase in the performance of our algorithm (in FLOPs/s) until we hit the critical arithmetic intensity of the hardware, 240 in the case of the TPU v5e. Any algorithm with a lower intensity will be bandwidth (BW) bound and limited by the peak memory bandwidth (shown in red). Any algorithm to the right will fully utilize our FLOPs (shown in green). Here, Algo 1 is comms-bound and uses only a fraction of the total hardware FLOPs/s. Algo 2 is compute-bound. We can generally improve the performance of an algorithm either by increasing its arithmetic intensity or by increasing the memory bandwidth available (moving from BW1 to BW2).
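If you want to draw this kind of roofline plot yourself, here is a rough matplotlib sketch; the peak FLOPs/s and the first bandwidth are the v5e numbers used above, while the second bandwidth is purely illustrative.

```python
import numpy as np
import matplotlib.pyplot as plt

peak_flops = 1.97e14                          # TPU v5e peak bf16 FLOPs/s
bandwidths = {"BW1": 8.2e11, "BW2": 1.6e12}   # bytes/s; BW2 is made up

intensity = np.logspace(0, 4, 200)            # FLOPs/byte (x-axis)
for name, bw in bandwidths.items():
    # Achievable throughput is the min of the bandwidth and compute rooflines.
    plt.plot(intensity, np.minimum(intensity * bw, peak_flops), label=name)

plt.xscale("log")
plt.yscale("log")
plt.xlabel("Arithmetic intensity (FLOPs/byte)")
plt.ylabel("Achievable FLOPs/s")
plt.legend()
plt.show()
```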
Let’s look at our soon-to-be favorite algorithm: matrix multiplication (aka matmul). We write \(X * Y \rightarrow Z\) where \(X\) has shape $\text{bf16}[B, D]$, $Y$ has shape $\text{bf16}[D, F]$, and $Z$ has shape \(\text{bf16}[B, F]\). To do the matmul we need to load $2BD + 2DF$ bytes, perform $2BDF$ FLOPs, and write $2BF$ bytes back.
We can get a nice simplification if we assume our local “batch size” \(B\) is small relative to \(D\) and \(F\). Then we get
\[\begin{equation} \text{Intensity}(\text{matmul}) = \frac{2BDF}{2BD + 2DF + 2BF} = \frac{BDF}{BD + DF + BF} \approxeq \frac{BDF}{DF} = B \end{equation}\]

\[\begin{equation} \text{Intensity}(\text{matmul}) > \text{Intensity}(\text{TPU}) \implies B > \frac{1.97e14}{8.20e11} = 240 \end{equation}\]

This is a reasonable assumption for Transformer matmuls, since for most of our models the local batch size in tokens satisfies \(B < 1024\) while $D$ and $F > 8000$. Thus we become compute-bound when our local batch size is greater than 240 tokens, a very simple rule!
Takeaway: for a bfloat16 matmul to be compute-bound on most TPUs, we need our local batch size in tokens to be greater than 240.
This comes with a few notable caveats we’ll explore in the problems below, particularly with respect to quantization (e.g., if we quantize our activations but still do full-precision FLOPs), but it’s a good rule to remember. For GPUs, this number is slightly higher (closer to 500), but the same conclusion generally holds. We’ll discuss the lower-level GPU and TPU details in the next section.
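As a sanity check on the $B > 240$ rule, we can evaluate the exact (unapproximated) matmul intensity for Transformer-like dimensions and watch it cross the hardware’s ~240 FLOPs/byte. The sizes below are just illustrative.

```python
def matmul_intensity(B, D, F, bytes_per_elem=2):
    """Arithmetic intensity of a [B, D] x [D, F] matmul (bf16 by default)."""
    flops = 2 * B * D * F
    bytes_moved = bytes_per_elem * (B * D + D * F + B * F)
    return flops / bytes_moved

D, F = 8192, 32768   # illustrative Transformer-like dimensions
for B in (64, 128, 240, 512, 1024):
    # Roughly equal to B when D, F >> B; crosses ~240 around B ≈ 250 here.
    print(B, round(matmul_intensity(B, D, F)))
```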
All the rooflines we’ve discussed so far have been memory-bandwidth rooflines, all within a single chip. This shouldn’t be taken as a rule. In fact, most of the rooflines we’ll care about in this book involve communication between chips: usually matrix multiplications that involve matrices sharded across multiple TPUs.
To pick a somewhat contrived example, say we want to multiply two big matrices $X\sim \text{bfloat16[B, D]}$ and $Y \sim \text{bfloat16[D, F]}$ which are split evenly across 2 TPUs/GPUs (along the $D$ dimension). To do this multiplication (as we’ll see in Section 3), we can multiply half of each matrix on each TPU (e.g. `X[:, :D // 2] @ Y[:D // 2, :]`) and then copy the resulting “partial sums” to the other TPU and add them together. Say we can copy 5e10 bytes/s in each direction and perform 1.97e14 FLOPs/s on each chip. What are $T_\text{math}$ and $T_\text{comms}$?
$T_\text{math}$ is clearly half of what it was before, since each TPU is doing half the work, i.e.

\[T_\text{math} = \frac{2BDF / 2}{1.97e14} = \frac{BDF}{1.97e14}\]
Now what about $T_\text{comms}$? This now refers to the communication time between chips! This is just the total bytes sent divided by the network bandwidth, i.e.
\[T_\text{comms} = \frac{2BF}{\text{Network Bandwidth}} = \frac{2BF}{5e10}\]

Therefore we become compute-bound (now with respect to the inter-chip network) when \(\text{Intensity}(\text{matmul (2-chips)}) > \text{Intensity}(\text{TPU w.r.t. inter-chip network})\), or equivalently when $\frac{BDF}{2BF} = \frac{D}{2} > \frac{1.97e14}{5e10} = 3940$, i.e. $D > 7880$. Note that, unlike before, the critical threshold now depends on $D$ and not $B$! Try to think about why that is. This is just one such example, but we highlight that this kind of roofline is critical for knowing when we can parallelize an operation across multiple TPUs.
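Here is the same arithmetic as a short sketch, using 5e10 bytes/s of inter-chip bandwidth and 1.97e14 FLOPs/s per chip as above; the example shapes are made up.

```python
def two_chip_matmul_times(B, D, F, flops_per_sec=1.97e14, link_bytes_per_sec=5e10):
    """T_math and T_comms for a bf16 [B, D] x [D, F] matmul sharded over D on 2 chips."""
    t_math = (2 * B * D * F / 2) / flops_per_sec  # each chip does half the FLOPs
    t_comms = (2 * B * F) / link_bytes_per_sec    # one bf16 partial sum copied across
    return t_math, t_comms

# Compute-bound once D / 2 > 1.97e14 / 5e10 = 3940, i.e. D > 7880:
print(two_chip_matmul_times(B=512, D=4096, F=16384))   # comms dominates
print(two_chip_matmul_times(B=512, D=16384, F=16384))  # compute dominates
```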
Problem 1 [int8 matmul]: Say we want to do $\text{int8[B, D]} *_D \text{int8[D, F]} \rightarrow \text{int8[B, F]}$ (an int8 matmul with some “batch size” $B$). At what batch size does the matmul become compute-bound? Throughout you can assume our HBM bandwidth is 8.1e11 bytes/s and our int8 peak OPs/s is 3.94e14.

Answer: The critical intensity of the hardware is 3.94e14 / 8.1e11 ≈ 486. Since every tensor is now 1 byte per element, the matmul performs $2BDF$ OPs against roughly $DF$ bytes of weights (assuming $B$ small), giving an intensity of about $2B$, so the rule is $B > 486 / 2 = 243$. Note that this is basically unchanged!

Problem 2 [int8 + bf16 matmul]: In practice we sometimes do different weight vs. activation quantization, so we might quantize weights in int8 but keep activations in bfloat16 (and consequently perform the matmul in bfloat16). At what batch size do we become compute-bound? As above, assume 1.97e14 bfloat16 FLOPs/s.
Answer: Again assuming $B$ is small, we have $2BDF$ bfloat16 FLOPs but only $DF$ bytes of int8 weights. This means we become compute-bound when \(2B > 240\), or \(B > 120\). This is a lot lower, meaning that if we can do int8 weight quantization (which is fairly easy to do) but still do bfloat16 FLOPs, we get a meaningful win in efficiency (although int8 OPs would be better still).
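A quick sketch comparing the three regimes (all-bf16, all-int8, and int8 weights with bf16 activations), using the hardware numbers assumed in the problems above; the small difference from the 240/120 figures in the text comes from using 8.1e11 rather than 8.2e11 bytes/s.

```python
HBM_BW = 8.1e11       # bytes/s, as assumed above
BF16_FLOPS = 1.97e14  # peak bf16 FLOPs/s
INT8_OPS = 3.94e14    # peak int8 OPs/s

# Matmul intensities for B << D, F (only the DF weight bytes matter):
#   bf16 weights + bf16 FLOPs: 2BDF / 2DF = B   -> compute-bound above B ≈ 243
#   int8 weights + int8 OPs:   2BDF / DF  = 2B  -> compute-bound above B ≈ 243
#   int8 weights + bf16 FLOPs: 2BDF / DF  = 2B  -> compute-bound above B ≈ 122
print(BF16_FLOPS / HBM_BW)      # critical B, all-bf16
print(INT8_OPS / HBM_BW / 2)    # critical B, all-int8
print(BF16_FLOPS / HBM_BW / 2)  # critical B, int8 weights + bf16 FLOPs
```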
Problem 3: For the problem above, make a roofline plot of peak FLOPs vs. B for several values of D and F.
Problem 4: What if we wanted to perform $\text{int8[B, D]} *_D \text{int8[B, D, F]} \rightarrow \text{int8[B, F]}$, where we imagine having a different matrix for each batch element? What is the arithmetic intensity of this operation?
Let’s start by looking at the total FLOPs and comms.