Here we'll do a quick review of the Transformer architecture, specifically how to calculate FLOPs, bytes, and other quantities of interest.
Let’s start with vectors \(x\),\(y\) and matrices \(A\),\(B\) of the following shapes:
\[\def \red#1{\textcolor{red}{#1}} \def \green#1{\textcolor{green}{#1}} \def \blue#1{\textcolor{blue}{#1}} \def \purple#1{\textcolor{purple}{#1}} \def \orange#1{\textcolor{orange}{#1}} \def \gray#1{\textcolor{gray}{#1}} \begin{array}{cc} \textrm{array} & \textrm{shape} \\ \hline x & \textrm{[P]} \\ y & \textrm{[P]} \\ A & \textrm{[N P]} \\ B & \textrm{[P M]} \\ \hline \end {array}\]Make note of the fact that for a matrix-matrix multiply, the compute scales cubically \(O(N^3)\) while the data transfer only scales quadratically \(O(N^2)\) - this means that as we scale up our matmul size, it becomes easier to hit the compute-saturated limit. This is extremely unusual, and explains in large part why we use architectures dominated by matrix multiplication - they’re amenable to being scaled!
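To make this concrete, here’s a small helper (ours, not from the text) that counts the FLOPs and bytes of a bf16 matmul and prints the resulting arithmetic intensity as the matrices grow:

```python
def matmul_stats(N: int, P: int, M: int, bytes_per_elem: int = 2):
    """FLOPs, bytes moved, and arithmetic intensity for A[N, P] @ B[P, M] in bf16."""
    flops = 2 * N * P * M                                   # one multiply-add per output term
    bytes_moved = bytes_per_elem * (N * P + P * M + N * M)  # read A, read B, write C
    return flops, bytes_moved, flops / bytes_moved

# Square matmuls: FLOPs grow as O(N^3) but bytes only as O(N^2),
# so arithmetic intensity grows linearly with N.
for n in (256, 1024, 4096):
    flops, bts, intensity = matmul_stats(n, n, n)
    print(f"N={n}: {flops:.2e} FLOPs, {bts:.2e} bytes, intensity ~{intensity:.0f}")
```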
During training, we don’t particularly care about the result of a given matrix multiply; we really care about its derivative. That means we do significantly more FLOPs during backpropagation.
If we imagine B is just one matrix in a larger network and A is our matrix of input activations, with C = A B, the derivative of the loss L with respect to B is given by the chain rule:
\[\frac{\partial L}{\partial B} = \frac{\partial L}{\partial C}\frac{\partial C}{\partial B} = A^T \left(\frac{\partial L}{\partial C}\right)\]which is a sum of outer products (one per row of \(A\)) and requires $2NPM$ FLOPs to compute, since it contracts over the $N$ dimension. Likewise, the derivative of the loss with respect to A is
\[\frac{\partial L}{\partial A} = \frac{\partial L}{\partial C}\frac{\partial C}{\partial A} = \left(\frac{\partial L}{\partial C}\right) B^T\]which is again $2NPM$ FLOPs, since dL/dC has shape \([N, M]\). While this quantity isn’t the derivative wrt. a parameter, it’s used to compute derivatives for earlier layers of the network (just as dL/dC was used to compute dL/dB above).
Adding these up, we see that during training we perform a total of 6NPM FLOPs (2NPM in the forward pass and 4NPM in the backward pass), compared to just 2NPM during inference. Since PM is the number of parameters in the matrix, this is the simplest form of the famous \(6 * \text{num parameters} * \text{num tokens}\) approximation of Transformer FLOPs during training: each token requires \(6 * \text{num parameters}\) FLOPs. We’ll show a more correct derivation below.
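As a sanity check on these shapes and FLOP counts, here’s a toy JAX sketch (ours, using an arbitrary scalar loss) that compares the manual expressions \(A^T (\partial L / \partial C)\) and \((\partial L / \partial C) B^T\) against jax.grad:

```python
import jax
import jax.numpy as jnp

N, P, M = 4, 8, 16
A = jax.random.normal(jax.random.PRNGKey(0), (N, P))
B = jax.random.normal(jax.random.PRNGKey(1), (P, M))

def loss(A, B):
    C = A @ B              # forward matmul: 2*N*P*M FLOPs
    return jnp.sum(C * C)  # arbitrary scalar loss for illustration

C = A @ B
dC = 2 * C                 # dL/dC for this particular loss

dB_manual = A.T @ dC       # contracts over N: another 2*N*P*M FLOPs
dA_manual = dC @ B.T       # contracts over M: another 2*N*P*M FLOPs

dA_auto, dB_auto = jax.grad(loss, argnums=(0, 1))(A, B)
print(jnp.allclose(dA_manual, dA_auto, atol=1e-4),
      jnp.allclose(dB_manual, dB_auto, atol=1e-4))
```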
Transformers are the future. Well, they’re the present at least. Maybe a few years ago, they were one of many architectures. But today, it’s worth knowing pretty much every detail of the architecture. We won’t reintroduce the architecture but this blog and the original Transformer paper may be helpful references.
Here’s a basic diagram of the Transformer decoder architecture:
Note [gating einsum]: Not all LLMs use “gating einsums”; some use a single input matmul in the MLP block, in which case the MLP has $2DF$ parameters and roughly $12BTDF$ training FLOPs instead of the figures below.
Note 2 [MHA attention]: With self-attention, T and S are the same, but for cross-attention they may be different. With vanilla Multi-Head Attention (MHA), the number of query heads N and the number of KV heads K are the same, while for Multi-Query Attention (MQA) K=1, and for grouped-query attention K merely divides N.
For the below we’re going to compute per-layer FLOPs to avoid having to stick factors of L everywhere.
The MLPs of a Transformer typically consist of 2 input matmuls that are element-wise combined and a single output matmul:
\[\begin{array}{ccc} \textrm{operation} & \textrm{train FLOPs} & \textrm{params} \\ \hline \\ A[B,T,\red{D}] \cdot W_{in1}[\red{D}, F] & 6BTDF & DF \\[10pt] A[B,T,\red{D}] \cdot W_{in2}[\red{D}, F] & 6BTDF & DF \\[10pt] \sigma\left(A_{in1}\right)[B,T, F] * A_{in2}[B,T, F] & \gray{O(BTF)} \\[10pt] A[B,T,\red{F}] \cdot W_{out}[\red{F}, D] & 6BTDF & DF \\[10pt] \hline \\ & \approx 18BTDF & 3DF \end{array}\]For multi-headed attention, let us assume equal head dimension H for Q,K,V projections, and estimate the cost of the QKVO matmuls:
\[\begin{array}{ccc} \textrm{operation} & \textrm{train FLOPs} & \textrm{params} \\ \hline \\ A[B,T,\red{D}] \cdot W_{Q}[\red{D}, N, H] & 6BTDNH & DNH \\[10pt] A[B,T,\red{D}] \cdot W_{K}[\red{D}, N, H] & 6BTDNH & DNH \\[10pt] A[B,T,\red{D}] \cdot W_{V}[\red{D}, N, H] & 6BTDNH & DNH \\[10pt] A[B,T,\red{N}, \red{H}] \cdot W_{O}[\red{N}, \red{H}, D] & 6BTDNH & DNH \\[10pt] \hline \\ & 24BTDNH & 4DNH \end{array}\]The dot-product attention operation is more subtle, effectively being a \(TH \cdot HS\) matmul batched over the \(B\), \(N\) dimensions, a softmax, and a \(TS \cdot SH\) matmul again batched over the \(B\), \(N\) dimensions. We highlight the batched dims in blue:
\[\begin{array}{cc} \textrm{operation} & \textrm{train FLOPs} \\ \hline \\[3pt] Q[\blue{B}, T, \blue{N}, \red{H}] \cdot K[\blue{B}, S, \blue{N}, \red{H}] & 6BTSNH \\[3pt] \textrm{softmax}_S \;\; L[\blue{B}, T, S, \blue{N}] & \gray{O(BTSN)} \\[3pt] S[\blue{B}, T, \red{S}, \blue{N}] \cdot V[\blue{B}, \red{S}, \blue{N}, H] & 6BTSNH \\[3pt] \hline \\ & \approx 12BTSNH = 12BT^2NH \\ \end{array}\]There are several other operations happening in a Transformer. Layernorms are comparatively cheap and can be ignored for first-order cost estimates. There is also the final enormous (though not per-layer) unembedding matrix multiply.
\[\begin{array}{ccc} \textsf{operation} & \textsf{train FLOPs} & \textsf{params} \\ \hline \\ \textrm{layernorm}_D \;\; A[B,T,\red{D}] & \gray{O\left(BTD\right)} & \gray{D} \\[10pt] A[B,T,\red{D}] \cdot W_{unembed}[\red{D}, V] & 6BTDV & DV \\ \end{array}\]If we neglect the cost of dot-product attention for shorter-context training, then the total FLOPs across all layers is
\[\begin{align*} 18BTDF + 24BTDNH = 6 * BT * (3DF + 4DNH) \\ = 6 * \textrm{num tokens} * \textrm{parameter count} \end{align*}\]This leads to a famous rule of thumb for estimating Transformer FLOP count, ignoring the attention FLOPs. (Unembedding is another simple matmul with $6BTDV$ FLOPs and $DV$ params, and follows the same rule of thumb.)
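Here’s a minimal sketch (with made-up dimensions) checking that the per-layer matmul FLOPs above are exactly \(6 \cdot \text{tokens} \cdot \text{params}\):

```python
def layer_matmul_flops_and_params(B, T, D, F, N, H):
    """Per-layer training FLOPs from the MLP and attention projection matmuls
    (ignoring dot-product attention), plus the parameter count they correspond to."""
    mlp_flops = 18 * B * T * D * F            # W_in1, W_in2, W_out, x6 for fwd + bwd
    attn_proj_flops = 24 * B * T * D * N * H  # Q, K, V, O projections
    params = 3 * D * F + 4 * D * N * H
    return mlp_flops + attn_proj_flops, params

B, T, D, F, N, H = 8, 4096, 8192, 4 * 8192, 64, 128  # an arbitrary example config
flops, params = layer_matmul_flops_and_params(B, T, D, F, N, H)
assert flops == 6 * (B * T) * params                  # the 6 * tokens * params rule, per layer
```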
If we do account for dot-product attention above and assume \(F=4D\) and \(D=NH\) (as is typical):
\[\small{\frac{\textrm{attention FLOPs}}{\textrm{matmul FLOPs}} = \frac{12BT^2NH}{18BTDF + 24BTDNH} = \frac{12BT^2D}{4*18 BTD^2 + 24 BTD^2} = \frac{12BT^2D}{96 BTD^2} = \frac{T}{8D}}\]So the takeaway is that dot-product attention FLOPs only become dominant during training once T > 8D. For D ~ 8k, this would be ~64K tokens. This makes some sense: as the MLP size increases, the attention FLOPs become a smaller fraction of the total. For large models, the quadratic cost of attention is therefore not a huge obstacle to longer-context training. For smaller models, however (e.g. Gemma-27B, where D=4608), attention becomes dominant around 32k sequence lengths. Flash Attention also helps alleviate the cost of long context, which we discuss briefly in Appendix A.
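Here’s a tiny helper (ours) that evaluates this ratio and the crossover point \(T = 8D\) for a couple of example values of D:

```python
def attention_to_matmul_flops_ratio(T, D):
    # Assumes F = 4D and N * H = D, as in the derivation above.
    return (12 * T * D) / (96 * D * D)   # simplifies to T / (8 * D)

for D in (8192, 4608):                   # D ~ 8k, and Gemma-27B's D = 4608
    print(f"D={D}: ratio at T=32768 is {attention_to_matmul_flops_ratio(32768, D):.2f}, "
          f"crossover at T={8 * D}")
```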
We’d be remiss not to briefly discuss Mixture of Experts (MoE) models, which replace the single dense MLP block with a set of independent expert MLPs, only a few of which are activated for each token.
Compared to a dense model, an MoE introduces new comms, primarily two AllToAlls (one before and one after the MoE block) that route tokens to the correct expert and bring them back to their home device.
Backpropagation as an algorithm trades memory for compute: instead of a backward pass that recomputes everything and costs \(O(n_\text{layers}^2)\) FLOPs, we save every intermediate activation generated during the forward pass and use only \(O(n_\text{layers})\) memory. While this is better than quadratic compute, it’s incredibly expensive memory-wise: a model with \(B * T=4M\) (4M total tokens per batch), L=64, and D=8192 that avoids all unnecessary backward pass compute would have to save roughly \(2 * 20 * B * T * D * L = 84TB\) of activations in bfloat16. The factor of 20 comes from (roughly) counting every intermediate node in the Transformer diagram above, since e.g.
\[f(x) = \exp(g(x))\] \[\frac{df}{dx} = \exp(g(x)) \cdot \frac{dg}{dx}\]so to avoid recomputing we need to save \(g(x)\) and \(\exp(g(x))\) from the forward pass. To avoid saving this much memory, we can choose to only save some fraction of the intermediate activations. Here are a few strategies we use.
These strategies are by no means comprehensive. When using JAX, they are typically controlled by jax.remat/jax.checkpoint (you can read more here).
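As a toy illustration (the block below is a made-up two-matmul MLP, not code from the text), jax.checkpoint lets us pick a rematerialization policy per block:

```python
import jax
import jax.numpy as jnp

def ffw_block(x, w_in, w_out):
    # A stand-in for one Transformer MLP block.
    h = jax.nn.gelu(jnp.dot(x, w_in))
    return jnp.dot(h, w_out)

# Full rematerialization: save nothing from this block and recompute it all
# during the backward pass (minimum memory, maximum recompute FLOPs).
ffw_full_remat = jax.checkpoint(
    ffw_block, policy=jax.checkpoint_policies.nothing_saveable)

# A middle ground: save only matmul outputs with no batch dimensions,
# and recompute the cheap elementwise ops.
ffw_save_dots = jax.checkpoint(
    ffw_block, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)

def loss(x, w_in, w_out):
    return jnp.sum(ffw_full_remat(x, w_in, w_out) ** 2)

x = jnp.ones((128, 512))
w_in, w_out = jnp.ones((512, 2048)), jnp.ones((2048, 512))
grads = jax.grad(loss, argnums=(1, 2))(x, w_in, w_out)
```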
As we’ll see in Section 7, LLM inference has two key parts, prefill and generation.
During generation we cache the key and value projections of every previous token so we don’t have to recompute them. Each KV cache is then effectively an array of size $[2, S, L, N, H]$, where the 2 accounts for the keys and values. This is quite large! The total size of the key-value cache in int8 is $2SLNH$ bytes. For a moderately-sized model with 8k context length, 64 layers, and $NH = D = 8192$, this is $2 \cdot 8192 \cdot 64 \cdot 8192 = 8\text{GiB}$ per sequence. We often use fewer KV heads than query heads to reduce this size.
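A quick sketch of this arithmetic (N=64 and H=128 are assumptions chosen so that \(NH = 8192\)):

```python
def kv_cache_bytes(S, L, N, H, bytes_per_value=1):
    """KV cache size for one sequence: keys and values for every layer, head, and token."""
    return 2 * S * L * N * H * bytes_per_value   # the leading 2 is for K and V

# The example from the text: 8k context, 64 layers, N * H = D = 8192, int8 KVs.
print(kv_cache_bytes(S=8192, L=64, N=64, H=128) / 2**30, "GiB")   # -> 8.0 GiB
```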
| Component | Params per layer | Training FLOPs per layer |
|---|---|---|
| MLP | $3DF$ | $18BTDF$ |
| Attention | $4DNH$ | $24BTDNH + 12BT^2NH$ |
| Other | $D$ | $BTD$ |
| Vocab | $DV$ (total, not per-layer) | $12BTDV$ |
Question 1: How many parameters does a model with $D=4096$, $F=4 \cdot D$, $V=32,000$, and $L=64$ have? What fraction of these are attention parameters? How large are our KV caches per token? You can assume $N\cdot H=D$ and multi-head attention with int8 KVs.
Answer: per layer, the MLP block has $3DF = 3 \cdot 4096 \cdot 16384 \approx 201\text{M}$ parameters and attention has $4DNH = 4 \cdot 4096^2 \approx 67\text{M}$, so the total is roughly $64 \cdot (201\text{M} + 67\text{M}) + 4096 \cdot 32000 \approx 17.3\text{B}$ parameters, of which about 25% are attention parameters. The KV caches store keys and values for every layer in int8, i.e. $2 \cdot L \cdot N \cdot H = 2 \cdot 64 \cdot 4096$ bytes = 512kB / token.
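Here’s a small script (ours) verifying these numbers:

```python
D, F, V, L = 4096, 4 * 4096, 32_000, 64
NH = D                                      # the question assumes N * H = D

mlp_params_per_layer = 3 * D * F            # W_in1, W_in2, W_out
attn_params_per_layer = 4 * D * NH          # W_Q, W_K, W_V, W_O
total = L * (mlp_params_per_layer + attn_params_per_layer) + D * V   # + unembedding

print(f"total params: {total / 1e9:.1f}B")                            # ~17.3B
print(f"attention fraction (per layer): "
      f"{attn_params_per_layer / (mlp_params_per_layer + attn_params_per_layer):.0%}")  # 25%
print(f"KV cache per token (int8): {2 * L * NH} bytes")               # 524,288 bytes ~ 512kB
```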
Question 2: How many total FLOPs are required to perform $A[B_X, D_Y] \cdot_D W[D_Y, F]$ on the mesh {'X': 4, 'Y': 8, 'Z': 4}? How many FLOPs are performed by each TPU?
The total “theoretical” FLOPs of the operation is \(2 \cdot B \cdot D \cdot F\). However, because the computation isn’t sharded across the Z dimension, it is duplicated Z times, meaning we actually perform \(2 \cdot B \cdot D \cdot F \cdot Z\) FLOPs across the cluster. Since the computation is sharded across the other dimensions, each device performs roughly \(2 \cdot B \cdot D \cdot F / (X \cdot Y)\) FLOPs.
Question 3: How many FLOPs are involved in performing $A[I,J,K,L] * B[I,J,M,N,O] \rightarrow C[K,L,M,N,O]$?
Following the rule above, we have I and J as contracting dimensions and K, L, M, N, and O as non-contracting dimensions. We have no “batching dimensions”, so this is just \(2 \cdot I \cdot J \cdot K \cdot L \cdot M \cdot N \cdot O\), twice the product of all the axis sizes. If we had a shared axis, it would only be counted once.
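For example (with our own toy sizes), as an einsum:

```python
import jax.numpy as jnp

I, J, K, L, M, N, O = 2, 3, 4, 5, 6, 7, 8
A = jnp.ones((I, J, K, L))
B = jnp.ones((I, J, M, N, O))

C = jnp.einsum('ijkl,ijmno->klmno', A, B)   # contracts over the shared i, j axes

# Rule of thumb: 2 x (product of every axis size, counting each axis once).
flops = 2 * I * J * K * L * M * N * O
print(C.shape, flops)
```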
Question 4: What is the arithmetic intensity of self-attention (ignoring the Q/K/V/O projections)? Give the answer as a function of the Q and KV lengths T and S. At what context length is attention FLOPs-bound? Given the HBM bandwidth of our TPUs, plot the effective relative cost of attention to the FFW block as the context length grows.
Self-attention requires loading the \(Q\), \(K\), and \(V\) activations, then computing \(\text{softmax}(Q \cdot K) \cdot V\), then writing the result back to HBM. This will be done with Flash Attention so there are some caveats to this math, but basically in bf16 self-attention performs
\[\text{Q[B,T,N,H]} \rightarrow_\text{reshape} \text{Q[B, T, K, G, H]} \cdot \text{K[B, S, K, H]} \rightarrow \text{O[B, T, S, K, G]}\] \[U=\text{softmax}_S(\text{O[B, T, S, K, G]})\] \[\text{U[B, T, S, K, G]} \cdot \text{V[B, S, K, H]} \rightarrow \text{X[B, T, K, G, H]}\]So our total bytes is \(2 * \text{sizeof}(Q) + 2 * \text{sizeof(K or V)} = 4BTNH + 4BSKH = 4BHK * (TG + S)\), total FLOPs is \(4BTSNH + O(BTSN)\) and the arithmetic intensity is \(4BTSKGH / (4BHK * (TG + S))\).
So basically, during prefill we have \(S=T\) so we have an arithmetic intensity of \(4BT^2KGH / (4BHKT(G+1)) = TG/(G + 1) = O(T)\). During generation, \(T=1\) so we have \(4BSKGH / (4BHK \cdot (G + S)) = SG / (G + S) \rightarrow G\) assuming \(S\) is very large. During prefill (or training), where \(S = T\), self-attention therefore becomes compute-bound around a sequence length of roughly 240 (the bf16 arithmetic-intensity threshold of our TPUs), assuming no sequence sharding. During generation we are never compute-bound because \(G\) is small, but you can see that increasing \(G\) brings us closer to being compute-bound.
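To sketch the plot the question asks for, here is a rough roofline model. The chip numbers are assumptions (roughly a TPU v5e), and the byte counts reuse the crude estimates above, so treat this as illustrative rather than exact:

```python
import numpy as np
import matplotlib.pyplot as plt

PEAK_FLOPS = 1.97e14   # assumed bf16 FLOPs/s (roughly TPU v5e)
HBM_BW = 8.2e11        # assumed HBM bytes/s (roughly TPU v5e)

B, D, N, H, K, G = 1, 8192, 64, 128, 8, 8   # toy config with N = K * G and N * H = D
F = 4 * D

def roofline_time(flops, bytes_moved):
    # We are limited by whichever of compute or memory is slower.
    return max(flops / PEAK_FLOPS, bytes_moved / HBM_BW)

def attention_time(T, S):
    flops = 4 * B * T * S * N * H                                    # QK^T and A @ V (forward)
    bytes_moved = 2 * (2 * B * T * N * H) + 2 * (2 * B * S * K * H)  # load Q / write O, load K and V (bf16)
    return roofline_time(flops, bytes_moved)

def ffw_time(T):
    flops = 6 * B * T * D * F                              # three D x F matmuls (forward)
    bytes_moved = 2 * (B * T * D + B * T * F + 3 * D * F)  # rough: activations in/out + weights (bf16)
    return roofline_time(flops, bytes_moved)

Ts = np.arange(512, 65536, 512)
plt.plot(Ts, [attention_time(t, t) / ffw_time(t) for t in Ts])
plt.xlabel("context length T")
plt.ylabel("attention time / FFW time")
plt.show()
```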
Question 5: At what sequence length are self-attention FLOPs equal to the QKVO projection FLOPs?
This is purely a question of when \(24BTDNH == 12BT^2NH\). Simplifying we get \(2D = T\), so e.g. for \(D=4096\), this is \(8192\). This tells us that for most reasonable context lengths, matmul FLOPs are greater.
Question 6: Say we only save the output of each of the 7 main matmuls in a Transformer layer during our forward pass (Q, K, V, O + the three FFW matrices). How many extra FLOPs do we need to “rematerialize” during the backwards pass?
The traditional objection to scaling Transformers to very long context is that the attention FLOPs and memory usage scale quadratically with context length. While it’s true that the attention QK product has shape \([B, T, S, N]\), where B is the batch size, T and S are the Q and K sequence dims, and N is the number of heads, this claim comes with some serious caveats:

1. As we noted above, attention FLOPs only dominate the matmul FLOPs once \(T > 8D\), so the quadratic term only matters at fairly long context.
2. We never need to materialize the full attention matrix in memory; we can compute it in chunks and accumulate the softmax as we go.
This second observation was first made by Rabe et al. 2021 and later in the Flash Attention paper (Dao et al. 2022). The basic idea is to compute the attention in chunks of K/V, where we compute the local softmax and some auxiliary statistics, then pass them on to the next chunk, which combines them with its own local results. Specifically, for each chunk we compute the local max of the attention logits, the local exponential (softmax) sum, and the chunk’s unnormalized output contribution.
With these, we can compute the new max, the new running sum, and the new output with only a constant amount of memory. To give a sketchy description of how this works, attention is roughly this operation:
\[\text{Attn}(Q, K, V) = \sum_i \frac{\exp(Q \cdot K_i - \max_j Q \cdot K_j) V_i}{\sum_l \exp(Q \cdot K_l - \max_j Q \cdot K_j)}\]The max is subtracted for numerical stability and doesn’t affect the result, since \(\sum_i \exp(a_i + b) = \exp(b) \sum_i \exp(a_i)\) and the \(\exp(b)\) factor cancels between numerator and denominator. Looking just at the denominator above, if we imagine having two contiguous chunks of key vectors, \(K^1\) and \(K^2\), and we compute the local softmax sums \(L^1\) and \(L^2\) for each
\[L^1 = \sum_i \exp(Q \cdot K_i^1 - \max_j Q \cdot K_j^1)\] \[L^2 = \sum_i \exp(Q \cdot K_i^2 - \max_j Q \cdot K_j^2)\]Then we can combine these into the full softmax sum for these two chunks together by using
\[L^\text{combined} = \exp(M^1 - \max(M^1, M^2)) \cdot L^1 + \exp(M^2 - \max(M^1, M^2)) \cdot L^2\]where
\[M^1 = \max_j Q \cdot K_j^1 \text{ and } M^2 = \max_j Q \cdot K_j^2\]This can be done for the full softmax as well, giving us a way of accumulating arbitrarily large softmax sums. Here’s the full algorithm from the Flash Attention paper.
From a hardware standpoint, this lets us fit our chunk of Q into VMEM (what the algorithm above calls on-chip SRAM) so we only have to stream the KV chunks from HBM on each iteration, which reduces memory traffic and improves the arithmetic intensity. We can also keep the running statistics in VMEM.
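Here’s a simplified sketch of the accumulation trick (ours; not the paper’s full algorithm, with no blocking of Q and no masking) showing that combining per-chunk (max, sum, output) statistics reproduces exact attention:

```python
import jax
import jax.numpy as jnp
import numpy as np

def chunk_stats(q, k_chunk, v_chunk):
    """Local (max, sum, unnormalized output) statistics for one chunk of keys/values."""
    s = q @ k_chunk.T                     # [T, chunk] attention logits
    m = s.max(axis=-1, keepdims=True)     # local max, for numerical stability
    p = jnp.exp(s - m)                    # local unnormalized probabilities
    return m, p.sum(axis=-1, keepdims=True), p @ v_chunk

def combine(m1, l1, o1, m2, l2, o2):
    """Merge the running statistics of two chunks, as in the text."""
    m = jnp.maximum(m1, m2)
    a1, a2 = jnp.exp(m1 - m), jnp.exp(m2 - m)
    return m, a1 * l1 + a2 * l2, a1 * o1 + a2 * o2

T, S, H, chunk = 4, 256, 16, 64
rng = np.random.default_rng(0)
q = jnp.array(rng.normal(size=(T, H)))
k = jnp.array(rng.normal(size=(S, H)))
v = jnp.array(rng.normal(size=(S, H)))

m, l, o = chunk_stats(q, k[:chunk], v[:chunk])
for i in range(chunk, S, chunk):
    m, l, o = combine(m, l, o, *chunk_stats(q, k[i:i + chunk], v[i:i + chunk]))

reference = jax.nn.softmax(q @ k.T, axis=-1) @ v
print(jnp.allclose(o / l, reference, atol=1e-5))   # the normalized running output matches
```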
One last subtle point worth emphasizing is an attention softmax property that’s used to make the Flash VJP (reverse mode derivative) calculation practical for training. If we define an intermediate softmax array as:
\[S_{ij} = \frac{e^{\tau q_i \cdot k_j}}{\sum_k e^{\tau q_i \cdot k_k}}\]In attention, we obtain dS from reverse-mode dO and V arrays:
\[dS_{ij} = dO_{id} \cdot_d V_{jd} = \sum_d dO_{id} V_{jd}\]During the backpropagation of this gradient to Q and K, we need the gradient of the attention logits:
\[d(q_i \cdot k_j) = (dS_{ij} - S_{ij} \cdot_j dS_{ij}) S_{ij}\]We exploit an identity that allows us to exchange a contraction along the large key length dimension with a local contraction along the feature depth dimension.
\[\begin{align*} S_{ij} \cdot_j dS_{ij} &= \sum_j \frac{e^{\tau q_i \cdot k_j}}{\sum_k e^{\tau q_i \cdot k_k}} \sum_d dO_{id} V_{jd} \\ &= \sum_d dO_{id} \sum_j \frac{e^{\tau q_i \cdot k_j}}{\sum_k e^{\tau q_i \cdot k_k}} V_{jd} \\ &= \sum_d dO_{id} O_{id} \\ &= dO_{id} \cdot_d O_{id} \end{align*}\]This replacement is crucial for being able to implement a sequence-block local calculation for the VJP, and enables further clever sharding schemes like ring attention.
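Here’s a small numerical check of this identity on toy shapes (ours):

```python
import jax
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
T, S, H = 8, 32, 16
tau = 1.0 / np.sqrt(H)

q = jnp.array(rng.normal(size=(T, H)))
k = jnp.array(rng.normal(size=(S, H)))
v = jnp.array(rng.normal(size=(S, H)))
dO = jnp.array(rng.normal(size=(T, H)))        # incoming cotangent of the attention output

S_mat = jax.nn.softmax(tau * q @ k.T, axis=-1)  # S_ij
O = S_mat @ v                                   # O_id
dS = dO @ v.T                                   # dS_ij = sum_d dO_id V_jd

lhs = jnp.sum(S_mat * dS, axis=-1)   # sum_j S_ij dS_ij: contraction over the (long) key axis
rhs = jnp.sum(dO * O, axis=-1)       # sum_d dO_id O_id: contraction over the (short) feature axis
print(jnp.allclose(lhs, rhs, atol=1e-5))
```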