How to Think About TPUs

Part 2 of How To Scale Your Model (Part 1: Rooflines | Part 3: Sharding)

All about how TPUs work, how they're networked together to enable multi-chip training and inference, and how they limit the performance of our favorite algorithms. While this may seem a little dry, it's super important for actually making models efficient.

What Is a TPU?

A TPU is basically a compute core that specializes in matrix multiplication (called a TensorCore) attached to a stack of fast memory (called high-bandwidth memory or HBM). Here’s a diagram:

Figure: the basic components of a TPU chip. The TensorCore is the gray left-hand box, containing the matrix-multiply unit (MXU), vector unit (VPU), and vector memory (VMEM).

You can think of the TensorCore as basically just being a really good matrix multiplication machine, but it has a few other functions worth noting. The TensorCore has three key units: the MXU (Matrix Multiply Unit), which does the actual matrix multiplication; the VPU (Vector Processing Unit), which performs general elementwise vector operations; and VMEM (Vector Memory), a small on-chip scratchpad that stages data for the MXU and VPU.

TPUs are very, very fast at matrix multiplication. It’s mainly what they do and they do it well. TPU v5p, one of the most powerful TPUs to date, can do 2.5e14 bf16 FLOPs / second / core or 5e14 bf16 FLOPs / sec / chip. A single pod of 8960 chips can do 4 exaflops / second. That’s a lot. That’s one of the most powerful supercomputers in the world. And Google has a lot of them.

TPUs, and their systolic arrays in particular, are such powerful hardware accelerators because matrix multiplication is one of the few algorithms that uses $O(n^3)$ compute for $O(n^2)$ bytes. That makes it very easy for an ordinary ALU to be bottlenecked by compute and not by memory bandwidth.
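To make the $O(n^3)$-compute-for-$O(n^2)$-bytes point concrete, here's a small illustrative Python sketch (the function names are ours) comparing the arithmetic intensity of a square bf16 matmul with that of an elementwise add:

```python
def matmul_intensity(n, bytes_per_elem=2):
    """Arithmetic intensity (FLOPs per byte) of an n x n @ n x n matmul in bf16."""
    flops = 2 * n**3                         # n^2 outputs, each a length-n dot product
    bytes_moved = 3 * n**2 * bytes_per_elem  # read two inputs, write one output
    return flops / bytes_moved

def elementwise_intensity(n, bytes_per_elem=2):
    """Arithmetic intensity of adding two n x n arrays elementwise."""
    flops = n**2                             # one add per output element
    bytes_moved = 3 * n**2 * bytes_per_elem
    return flops / bytes_moved

print(matmul_intensity(4096))       # ~1365 FLOPs/byte, and it grows linearly with n
print(elementwise_intensity(4096))  # ~0.17 FLOPs/byte, constant: always bandwidth-bound
```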

The diagram above also includes a few other components like SMEM and the scalar unit, which are used for control flow handling and are discussed briefly in Appendix C, but aren’t crucial to understand. On the other hand, HBM is important and fairly simple: it's a big pool of fast memory (16-96GB per chip, depending on the generation) that stores the tensors (weights, activations, and so on) that the TensorCore operates on.

Generally, all TPU operations are pipelined and overlapped. To perform a matmul $X \cdot A \to Y$, a TPU would first need to copy chunks of matrices $A$ and $X$ from HBM into VMEM, then load them into the MXU which multiplies chunks of 8x128 (for $X$) and 128x128 (for $A$), then copy the result chunk by chunk back to HBM. To do this efficiently, the matmul is pipelined so the copies to/from VMEM are overlapped with the MXU work. This allows the MXU to continue working instead of waiting on memory transfers, keeping matmuls compute-bound, not memory-bound.
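For what it's worth, none of this staging is something you write by hand: a plain jnp.dot under jax.jit lets the compiler handle the HBM $\to$ VMEM $\to$ MXU pipelining described above. A minimal sketch, with arbitrary shapes chosen for illustration:

```python
import jax
import jax.numpy as jnp

@jax.jit
def matmul(x, a):
    # The compiler tiles this into MXU-sized chunks and overlaps the
    # HBM <-> VMEM copies with the MXU work, as described above.
    return jnp.dot(x, a)

x = jnp.ones((512, 4096), dtype=jnp.bfloat16)
a = jnp.ones((4096, 16384), dtype=jnp.bfloat16)
y = matmul(x, a)  # bf16[512, 16384]
```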

Here’s an example of how you might perform an elementwise product from HBM:

Figure: an animation showing a pointwise product performed on TPU, with bytes loaded from HBM. Note how bytes are streamed out of memory in chunks and partial results are pipelined back without waiting for the full array to be materialized.

A matmul would look nearly identical except it would load into the MXU instead of the VPU/Vector unit, and the loads and stores would occur in a different order, since the same weight chunk is used for multiple chunks of activations. You can see chunks of data streaming into VMEM, then into the VREGs (vector registers), then into the Vector Unit, then back into VMEM and HBM. As we’re about to see, if the load from HBM to VMEM is slower than the FLOPs in the Vector Unit (or MXU), we become “bandwidth bound” since we’re starving the VPU or MXU of work.

Key takeaway: TPUs are very simple. They load weights from HBM into VMEM, then from VMEM into a systolic array which can perform around 200 trillion multiply-adds per second. The HBM $\leftrightarrow$ VMEM and VMEM $\leftrightarrow$ systolic array bandwidths set fundamental limits on what computations TPUs can do efficiently.

VMEM and arithmetic intensity: VMEM is much smaller than HBM but it has a much higher bandwidth to the MXU. As we saw in Section 1, this means if an algorithm can fit all its inputs/outputs in VMEM, it’s much less likely to hit communication bottlenecks. This is particularly helpful when a computation has poor arithmetic intensity: VMEM bandwidth is around 22x higher than HBM bandwidth, which means an MXU operation reading from/writing to VMEM requires an arithmetic intensity of only 10-20 to achieve peak FLOPs utilization. That means if we can fit our weights into VMEM instead of HBM, our matrix multiplications can be FLOPs bound at much smaller batch sizes. And it means algorithms that fundamentally have a lower arithmetic intensity can still be efficient. VMEM is just so small that this is often a challenge.

We sometimes talk about VMEM prefetching, which refers to loading weights ahead of time into VMEM so we can mask the cost of loading for our matmuls. For instance, in a normal Transformer we can sometimes load our big feed-forward weights into VMEM during attention, which can hide the cost of the weight load if we're memory-bandwidth bound. This requires our weights to be small enough or sharded enough to fit a single layer into VMEM with space to spare.
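As a rough sketch of why this matters (using the TPU v5e numbers from the tables below and the ~22x VMEM:HBM bandwidth ratio mentioned above), here's the critical batch size for a bf16 matmul with $B \ll D, F$ when the weights sit in HBM versus in VMEM:

```python
# Roofline sketch using TPU v5e numbers from the tables below.
peak_bf16_flops = 1.97e14   # FLOPs/s per chip
hbm_bw = 8.1e11             # bytes/s, HBM <-> VMEM
vmem_bw = 22 * hbm_bw       # assumed ~22x ratio, as described above

# For a bf16 matmul with B << D, F, arithmetic intensity ~= B
# (2*B*D*F FLOPs for ~2*D*F bytes of weights), so the critical
# intensity is also the critical batch size.
print(peak_bf16_flops / hbm_bw)   # ~243: batch needed to be FLOPs-bound from HBM
print(peak_bf16_flops / vmem_bw)  # ~11: batch needed if the weights live in VMEM
```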

A TPU chip typically (but not always) consists of two TPU cores which share memory and can be thought of as one large accelerator with twice the FLOPs (a configuration known as “megacore”, used since TPU v4). On older chips (TPU v3 and earlier), the two cores have separate memory and are regarded as two separate accelerators. Inference-optimized chips like the TPU v5e only have one TPU core per chip.

Chips are arranged in sets of 4 on a ‘tray’ connected to a CPU host over PCIe. This is the format most readers will be familiar with: 4 chips (8 cores, though usually treated as 4 logical megacores) exposed through Colab or a single TPU-VM. For inference chips like the TPU v5e, we have 2 trays per host instead of 1, but also only 1 core per chip, giving us 8 chips = 8 cores. (On Cloud TPU VMs, each tray is exposed as part of a separate VM, so there are once again 4 cores visible.)

PCIe bandwidth is limited: Like the HBM $\leftrightarrow$ VMEM link, the CPU $\leftrightarrow$ HBM PCIe connection has a specific bandwidth that limits how quickly you can load from host memory to HBM or vice-versa. PCIe bandwidth for TPU v4 is 16GB / second each way, for example, so close to 100x slower than HBM. We can load/offload data into the host (CPU) RAM, but not very quickly.

TPU Networking

Chips are connected to each other through the ICI network in a Pod. In older generations (TPU v2 and TPU v3), on inference chips (e.g., TPU v5e), and on Trillium (TPU v6e), ICI (“inter-chip interconnects”) connects the 4 nearest neighbors (with edge links to form a 2D torus). TPU v4 and TPU v5p are connected to their nearest 6 neighbors (forming a 3D torus). Note these connections do not go through their hosts; they are direct links between chips.

The toroidal structure reduces the maximum distance between any two nodes from $N$ to $N / 2$, making communication much faster. TPUs also have a “twisted torus” configuration that wraps the torus in a Möbius-strip-like topology to further reduce the average distance between nodes.

TPU pods (connected by ICI) can get really big: the maximum pod size (called a superpod) is 16x16x16 for TPU v4 and 16x20x28 for TPU v5p. These large pods are composed of reconfigurable cubes of 4x4x4 chips connected by optical wraparound links that we can reconfigure to connect very large topologies. (The optical switch is simply a reconfigurable connection with the same ICI bandwidth; it just lets us connect cubes while retaining a wraparound link.)

Smaller topologies (e.g. 2x2x1, 2x2x2) can also be requested, albeit with no wraparounds. This is an important caveat, since it typically doubles the time of most communication. Any multiple of a full cube (e.g. 4x4x4 or 4x4x8) will have wraparounds provided by the optical switches. (Note that a `2x2x4` won't have any wraparounds, since those are provided by the optical switches, which are only available on a full cube. A TPU v5e 8x16 _will_ have a wraparound on the longer axis, however, since it doesn't use reconfigurable optical networking.)

For TPU v5e and Trillium we have pods which consist of a 16x16 2D torus. TPUs v5e and v6e (Trillium) cannot expand beyond a 16x16 torus but pods can still communicate with each other over standard Data Center Networking (DCN). Again, smaller topologies can be requested without wraps on dims $<16$.

This nearest-neighbor connectivity is a key difference between TPUs and GPUs. GPUs connect up to 256 H100s in an all-to-all configuration (called a node), rather than using local connections. On the one hand, that means GPUs can send arbitrary data within a node in a single low-latency hop. On the other hand, TPUs are dramatically cheaper and simpler to wire together, and can scale to much larger topologies because the number of links per device is constant.

ICI is very fast relative to DCN, but is still slower than HBM bandwidth. For instance, a TPU v5p has 2.5e12 bytes/s (2.5 TB/s) of HBM bandwidth and 1e11 bytes/s (100 GB/s) of ICI bandwidth per axis, which is about 25x lower. This means that when we split models across multiple chips, we need to be careful to avoid bottlenecking the MXU with slower cross-device communication.

Multi-slice training: A set of ICI-connected TPUs is called a slice. Different slices can be connected to each other using DCN, for instance to link slices on different pods. Since DCN is a much slower connection than ICI, we should try to limit how much our computation has to wait for data from DCN.

Key Takeaways

| Model | Pod size | Host size | HBM capacity/chip | HBM BW/chip (bytes/s) | FLOPs/s/chip (bf16) | FLOPs/s/chip (int8) |
|---|---|---|---|---|---|---|
| TPU v3 | 32x32 | 4x2 | 32GB | 9.0e11 | 1.4e14 | 1.4e14 |
| TPU v4p | 16x16x16 | 2x2x1 | 32GB | 1.2e12 | 2.75e14 | 2.75e14 |
| TPU v5p | 16x20x28 | 2x2x1 | 96GB | 2.8e12 | 4.59e14 | 9.18e14 |
| TPU v5e | 16x16 | 4x2 | 16GB | 8.1e11 | 1.97e14 | 3.94e14 |
| TPU v6e | 16x16 | 4x2 | 32GB | 1.6e12 | 9.20e14 | 1.84e15 |

Host size refers to the topology of TPUs connected to a single host (e.g. TPU v5e has a single CPU host connected to 8 TPUs in a 4x2 topology). Here are interconnect figures:

| Model | ICI BW/link (one-way, bytes/s) | ICI BW/link (bidi, bytes/s) |
|---|---|---|
| TPU v3 | 1e11 | 2e11 |
| TPU v4p | 4.5e10 | 9e10 |
| TPU v5p | 1e11 | 2e11 |
| TPU v5e | 5e10 | 1e11 |
| TPU v6e | 1e11 | 2e11 |

We include both one-way (unidirectional) bandwidth and bidi (bidirectional) bandwidth since unidirectional bandwidth is more true to the hardware, but bidirectional bandwidth occurs more often in equations involving a full ring. By bidi (bidirectional) bandwidth we mean the total bytes that can be sent along a single link in both directions, or equally, the total number of outgoing bytes from a single TPU along a particular axis, assuming we can use both links efficiently. This is true when we have a functioning ring, AKA when we have a wraparound connection on the particular axis. This occurs on inference chips when we have a full axis of 16, or on training chips (v*p) when we have an axis which is a multiple of 4. We prefer to use the bidirectional bandwidth because it appears frequently in calculations involving bidirectional comms.

PCIe bandwidth is typically around 1.5e10 bytes / second per chip (Trillium (TPU v6e) has 32GB/s, about 2x higher than v5), while DCN bandwidth is typically around 2.5e10 bytes / second per host. We include both unidirectional and bidirectional bandwidth for completeness. Typically bidirectional bandwidth is the more useful number when we have access to a full wraparound ring, while one-way bandwidth is more true to the hardware.

Worked Problems

These numbers are a little dry, but they let you make basic roofline estimates for model performance. Let’s work a few problems to explain why this is useful. You’ll see more examples in Part 3.

Question 1 [bounding LLM latency]: Say you want to sample from a 200B parameter model in bf16 that’s split across 32 TPU v4p. How long would it take to load all the parameters from HBM into the systolic array? Hint: use the numbers above.

Click here for the answer.

Answer: We’re loading sizeof(bf16) * 200e9 = 400e9 bytes on 32 chips, meaning 12.5e9 bytes / chip, each with an HBM bandwidth of 1.2e12 bytes/s (from the table above). So the load takes 12.5e9 / 1.2e12 ≈ 10ms.

That’s pretty cool, because that’s a reasonable lower bound on the latency of sampling from the model. Each sampling step needs to load all parameters from HBM, so it cannot take less than 10 ms. In practice, at small batch sizes, this is close to being achievable.
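Here's the same arithmetic as a quick Python sketch (variable names are ours; the bandwidth comes from the table above):

```python
# Question 1: lower bound on per-step latency for a 200B-parameter bf16 model on 32 TPU v4p.
n_params = 200e9
bytes_total = 2 * n_params        # bf16 = 2 bytes per parameter
n_chips = 32
hbm_bw_v4p = 1.2e12               # bytes/s per chip (table above)

bytes_per_chip = bytes_total / n_chips
print(bytes_per_chip / hbm_bw_v4p)  # ~0.0104 s, i.e. roughly 10 ms per sampling step
```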

Question 2 [TPU details]: Consider a full TPU v5e pod. How many total CPU hosts are there? How many TPU TensorCores? What is the total FLOPs/s for the whole pod? What is the total HBM? Do the same exercise for TPU v5p pod.

Click here for the answer.

Answer: For TPU v5e, each pod is 16x16 and each host is a 4x2 slice, so we have 16*16 / 8 = 32 hosts. Each TPU v5e chip has only one core, so we have 256 TensorCores. The total FLOPs/s is 256 * 1.97e14 ≈ 5e16 in bfloat16. Each chip has 16GB of HBM, so that’s 256 * 16GB = 4TB of memory.

For a full TPU v5p pod, we have 16x20x28 = 8960 chips and each host is 2x2x1, so we have 8960 / 4 = 2,240 hosts. Each TPU v5p chip has two TensorCores, so we have 8960 * 2 = 17,920 cores. The total FLOPs/s is 8960 * 4.59e14 ≈ 4e18 in bfloat16. Each chip has 96GB of HBM, so that’s 8960 * 96GB ≈ 860TB of memory.
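The same totals as a quick sketch:

```python
# Question 2: totals for a full TPU v5e pod (16x16) and a full TPU v5p pod (16x20x28).
v5e_chips = 16 * 16
print(v5e_chips // 8)          # 32 hosts (each 4x2 host has 8 chips)
print(v5e_chips)               # 256 TensorCores (1 core per v5e chip)
print(v5e_chips * 1.97e14)     # ~5e16 bf16 FLOPs/s
print(v5e_chips * 16)          # 4096 GB, i.e. ~4 TB of HBM

v5p_chips = 16 * 20 * 28
print(v5p_chips // 4)          # 2,240 hosts (each 2x2x1 host has 4 chips)
print(v5p_chips * 2)           # 17,920 TensorCores
print(v5p_chips * 4.59e14)     # ~4.1e18 bf16 FLOPs/s
print(v5p_chips * 96)          # 860,160 GB, i.e. ~860 TB of HBM
```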

Question 3 [PCIe operational intensity]: Imagine we’re forced to store a big weight matrix $A$ of type $\text{bfloat16}[D, F]$, and a batch of activations $x$ of type $\text{bfloat16}[B, D]$ in host DRAM and want to do a matrix multiplication on them. This is running on a single host, and we’re using a single TPU v6e chip attached to it. You can assume $B \ll D$, and $F = 4D$ (we’ll see in future chapters why these are reasonable assumptions). What is the smallest batch size $B$ we need to remain FLOPs bound over PCIe? Assume PCIe bandwidth of 1.5e10 bytes / second.

Click here for the answer.

Answer: We have to perform $2BDF$ floating point operations, and each chip can perform 9.2e14 floating point operations per second. This then requires $2BDF / 9.2e14$ seconds to perform. We have to load $2DF + 2BD$ bytes from DRAM, and write $2BF$ bytes back to it. We are bottlenecked by PCIe transfer speeds, so we need $2 \cdot (BD + DF + BF) / 1.5e10$ seconds to transfer data to and from the TPU. Since we want computation to take longer than weight loading, assuming we can overlap all weight loading with computation, we want $2BDF / 9.2e14 > 2 \cdot (BD + DF + BF) / 1.5e10$. We can simplify this using our assumptions that $B \ll D$, and $F = 4D$, to get

\[\frac{8BD^2}{9.2e14} > \frac{8D^2}{1.5e10}\]

or

\[B > \frac{9.2e14}{1.5e10} \simeq 61,000\]
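As a sketch, the simplified bound is just the ratio of compute throughput to PCIe bandwidth:

```python
# Question 3: with B << D and F = 4D, the comparison 2*B*D*F / flops > 2*D*F / pcie_bw
# reduces to B > flops / pcie_bw.
peak_flops_v6e = 9.2e14   # bf16 FLOPs/s per TPU v6e chip (table above)
pcie_bw = 1.5e10          # bytes/s, as assumed in the question

print(peak_flops_v6e / pcie_bw)   # ~61,333: the minimum batch size to stay FLOPs-bound
```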

Question 4 [general matmul latency]: Let’s say we want to multiply a weight matrix int8[16384, 4096] by an activation matrix of size int8[B, 4096] where B is some unknown batch size. Let’s say we’re on 1 TPU v5e to start.

  1. How long will this multiplication take as a function of B? Hint: it may help to calculate how long it will take to load the arrays from HBM and how long the multiplication will actually take. Which is bottlenecking you?
  2. What if we wanted to run this operation out of VMEM? How long would it take as a function of B?
Click here for the answer.

Answer: (1) The number of floating point operations we need to perform is $2 \cdot 4096 \cdot 16384 \cdot B = 1.3e8 \cdot B$. So $T_{\text{math}} = (1.3e8 \cdot B) / 3.94e14$ seconds. We need to load $16384 \cdot 4096 + 4096 \cdot B$ bytes from HBM to VMEM, and write back $16384 \cdot B$ bytes from VMEM to HBM. This means $T_{\text{comms}} = (6.7e7 + 2e4\cdot B) / 8.1e11$ seconds. Assuming as much overlap of communication and computation as possible, the whole multiplication will take approximately

\[\max\{T_{\text{math}}, T_{\text{comms}}\} = \max\left\{\frac{6.7e7 + 2e4\cdot B}{8.1e11}, \frac{1.3e8 \cdot B}{3.94e14}\right\}\]

We’ll be FLOPs-bound when $\frac{6.7e7 + 2e4\cdot B}{8.1e11} < \frac{1.3e8 \cdot B}{3.94e14}$, or equivalently, $B > 271$. This is slightly larger than the 240 number we derive below because we factor in the full impact of \(D\) and \(F\).

(2) If instead we are loading from VMEM, let’s consider VMEM bandwidth to the MXU as 22 times the HBM $\leftrightarrow$ VMEM bandwidth. This turns our data loading denominator from 8.1e11 to 1.78e13, and we get $B > 11$. Note that in practice we cannot dedicate all of our VMEM bandwidth to loading $W$, so the real threshold will be closer to 20.
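Here's a sketch that finds both crossover points numerically (with exact rather than rounded constants, the HBM crossover lands slightly below the 271 above):

```python
# Question 4: int8[16384, 4096] weights times int8[B, 4096] activations on one TPU v5e.
int8_flops = 3.94e14     # int8 FLOPs/s per chip (table above)
hbm_bw = 8.1e11          # bytes/s
vmem_bw = 22 * hbm_bw    # assumed ~22x HBM bandwidth, as above
D, F = 4096, 16384

def t_math(B):
    return 2 * B * D * F / int8_flops

def t_comms(B, bw):
    return (D * F + B * D + B * F) / bw   # int8 = 1 byte per element

for bw, label in [(hbm_bw, "HBM"), (vmem_bw, "VMEM")]:
    crossover = next(B for B in range(1, 10_000) if t_math(B) > t_comms(B, bw))
    print(label, crossover)   # ~263 for HBM (vs. 271 with rounded constants), ~12 for VMEM
```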

Question 5 [ICI bandwidth]: Let’s say we have a TPU v5e 4x4 slice. Let’s say we want to send an array of type bfloat16[8, 128, 8192] from TPU{0,0} to TPU{3, 3}. Let’s say the per-hop latency for TPU v5e is $1\mu s$.

  1. How soon will the first byte arrive at its destination?
  2. How long will the total transfer take?
Click here for the answer.

Answer: On a TPU v5e we have 2D connectivity. Because TPU{0,0} sits on a corner of the slice, it has two ICI ports from which it can send data, and likewise the target chip can receive on two ports. The amount of data we have to transfer is 2 * 8 * 128 * 8192 = 1.7e7 bytes. We can transfer from both ports simultaneously (i.e. send half the array right and half down), so we get 2 * 5e10 = 1e11 bytes transferred per second, which means streaming the whole array over those links takes 1.7e7 / 1e11 = 170us. In a 4x4 slice, we have six hops between chips $(0, 0)$ and $(3, 3)$, since there are no wraparound links for axes with fewer than 16 chips. Since the latency of each hop is about $1\mu s$, the first byte will arrive in about $6\mu s$, and because the transfer is pipelined through the intermediate chips, the total transfer takes roughly $170\mu s$ (the few microseconds of hop latency are negligible).
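The same arithmetic as a sketch:

```python
# Question 5: bf16[8, 128, 8192] from TPU{0,0} to TPU{3,3} on a TPU v5e 4x4 slice.
array_bytes = 2 * 8 * 128 * 8192   # ~1.7e7 bytes
ici_bw_oneway = 5e10               # bytes/s per link (v5e, table above)
ports = 2                          # a corner chip can send along two axes at once
hops = 3 + 3                       # no wraparounds on axes shorter than 16 chips
hop_latency = 1e-6                 # seconds, given in the question

print(hops * hop_latency)                     # ~6e-6 s until the first byte arrives
print(array_bytes / (ports * ici_bw_oneway))  # ~1.7e-4 s to stream the whole array
```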

Question 6 [pulling it all together, hard]: Imagine you have a big matrix A: int8[128 * 1024, 128 * 1024] sharded evenly across a TPU v5e 4x4 slice but offloaded to host DRAM on each chip. Let’s say you want to copy the entire array to TPU{0, 0} and multiply it by a vector bf16[8, 128 * 1024]. How long will this take? Hint: use the numbers above.

Click here for the answer.

Answer: Let’s start by outlining the operations we have to perform. Our array is about 16GB. From the table above, a TPU v5e host has a 4x2 topology, so a 4x4 slice has 2 hosts. Since our array is evenly sharded, each host effectively contains 1/2 of the array, or 8GB. We need to copy these chunks all to TPU{0,0}, which gives us two options:

  1. We can copy over DCN and then load the entire unsharded array over PCIe into HBM.
  2. We can load our sharded arrays onto their corresponding TPUs, then perform a gather over ICI, then perform the matmul on TPU{0,0}.

It should be clear that option (2) is better. DCN is slow compared to ICI and we’d much prefer to load a big array over many PCIe links rather than just a few (the 8 on host 0). Here’s a diagram of part of the system. As described above, note that TPUs are connected to their neighbors by ICI (even across hosts), all TPUs are connected to their host CPU (via PCIe), and hosts are connected by DCN.

Now let’s work through how long each piece will take:

  1. PCIe load: we’re loading chunks of 16GB / 2 = 8GB over 8 PCIe links, each of which has 1.5e10 bytes/second bandwidth. Thus this will take about 66ms.

  2. ICI copy: each TPU now has 16GB / 16 = 1GB of our array. Our ICI bandwidth is 1e11 bytes/second per link bidirectionally (5e10 each way), and you’ll notice from the above diagram that only 2 of the 4 ICI links on the TPU v5e are in use in this topology. Since TPU{0,0} needs to receive a total of 15GB along 2 axes at 5e10 bytes/s/link, we can lower bound the time by 15e9 / (5e10 * 2) = 150ms. In practice this probably isn’t achievable because the load is very uneven, but it’s probably within a factor of 2. As you’ll see in the next section, performing a full AllGather would also take roughly 16e9 / (5e10 * 2), so this is close to optimal.

  3. HBM $\rightarrow$ MXU load: to perform our final matmul, we need to load these 16e9 bytes plus the bf16[8, 128 * 1024] array (another 2MB, so negligible) over HBM bandwidth into the MXU, which will take 16e9 / 8.1e11 = 19ms.

  4. FLOPs: we’re performing a total of \(2 \cdot 8 \cdot 128 \cdot 1024 \cdot 128 \cdot 1024 = 2.7e11\) FLOPs, and since we can perform 1.97e14 bf16 FLOPs/s, we get 1.3ms.

An upper bound for the total time is the sum of all of these times, but since the TPU can typically overlap these operations, we can think of this as a pipelining problem that’s bottlenecked by the slowest piece. Assuming that’s true, then the answer is about 150-200ms.
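Here's a sketch of the per-stage estimates (using the rounded 16GB figure from above):

```python
# Question 6: per-stage time estimates; with good overlap the slowest stage dominates.
array_bytes = 16e9                                     # int8[128*1024, 128*1024], ~16 GB
pcie_bw, ici_bw_oneway, hbm_bw, flops = 1.5e10, 5e10, 8.1e11, 1.97e14

t_pcie = (array_bytes / 2) / (8 * pcie_bw)             # each host loads its 8GB half over 8 PCIe links
t_ici = (array_bytes * 15 / 16) / (2 * ici_bw_oneway)  # TPU{0,0} receives 15/16 of the array over 2 axes
t_hbm = array_bytes / hbm_bw                           # stream A into the MXU (2MB of activations is negligible)
t_flops = 2 * 8 * 128 * 1024 * 128 * 1024 / flops      # 2 * B * D * F with B = 8

print(t_pcie, t_ici, t_hbm, t_flops)   # ~0.067, ~0.15, ~0.02, ~0.0014 s: the ICI gather dominates
```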

That’s it for Part 2! For Part 3, covering partitioning and cross-TPU communication, click here.

Appendix

Appendix A: Let’s talk briefly about GPUs

Compared to TPUs, GPUs have a simpler communication model and a more complicated programming model.

Appendix B: How does a systolic array work?

At the core of the TPU MXU is a 128x128 systolic array (256x256 on TPU v6e). When fully saturated, the systolic array can perform one `bf16[8,128] @ bf16[128,128] -> f32[8,128]` multiplication per 8 clock cycles. (If you are not familiar with this notation, it means: multiplying an `8x128` matrix with bfloat16 elements by a `128x128` matrix with bfloat16 elements and storing the result in an `8x128` matrix with float32 elements.)

Here is a simplified animation of multiplying a set of weights (blue) with a set of activations (green). You’ll notice that the weights (RHS) are partially loaded first, diagonally, and then the activations are fed in, also diagonally. In each frame above, we multiply all the overlapped green and blue units, sum the result with any residual passed in from above, and then pass the result in turn down one unit.

Here’s a more general version of this animation showing the output being streamed out of computation:

Here’s a diagram showing how this can be pipelined across multiple RHS and LHS arrays:

There is an initial pipeline bubble as the weights (RHS) and activations (LHS) are loaded. After that initial bubble, new inputs and weights can be loaded in without an additional bubble.

Here’s a bad animation of a bf16[2, 3] x bf16[3, 3] matrix multiplication, which you could imagine as a matmul of a 2x3 weight matrix with an input activation of batch 1 and size 3. This is rotated compared to the previous slides and inputs flow out to the right instead of down, but you can roughly see the structure.

We can efficiently pipeline this to multiply large matrices without too large a pipeline bubble. With that said, it’s important that our matrices have dimensions larger than the side dimension of the MXU, which is generally 128. Some TPUs (since TPU v3) have multiple MXUs, 2 for TPU v3 and 4 for TPU v4/5, so we need to ensure tiling dimensions are larger than 128 * the number of MXUs. Here’s a good animation for this.

Trillium (TPU v6e) has a 256x256 systolic array, which means it can perform 4x more FLOPs / cycle. This also means the dimensions of your tensors need to be twice as large to utilize the MXU fully.
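As a rough illustration of the padding cost (the helper below is ours, not a real API), you can estimate how much of the padded tiles along one dimension does useful work:

```python
import math

def mxu_tile_utilization(dim, tile=128):
    """Fraction of the padded dimension that does useful work, for one MXU-tiled axis."""
    padded = math.ceil(dim / tile) * tile
    return dim / padded

print(mxu_tile_utilization(4096))        # 1.0: multiples of 128 fill every tile
print(mxu_tile_utilization(4096, 256))   # 1.0 on Trillium's 256-wide MXU as well
print(mxu_tile_utilization(300))         # ~0.78: ~22% of the padded axis is wasted
print(mxu_tile_utilization(300, 256))    # ~0.59: the waste grows with the larger MXU
```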

Appendix C: TPU internals

Scalar Core

The TPU scalar core processes all of the instructions and executes all of the transfers from HBM into vector memory (VMEM). The scalar core is also responsible for fetching instructions for the VPU, MXU and XLU components of the chip. One side-effect of this is that each core of the TPU is only capable of creating one DMA request per cycle.

To put this in context, a single scalar core controls a VPU consisting of 2048 ALUs, 4 MXUs, 2 XLUs, and multiple DMA engines. The highly skewed ratio of control to compute is a source of hardware efficiency, but it also limits the ability to do data-dependent vectorization in any interesting way.

VPU

The TPU vector core consists of a two-dimensional vector machine (the VPU) that performs vector operations like vadd (vector addition) or vmax (elementwise max), and a set of vector registers called VREGs that hold data for the VPU and MXU. The VPU is effectively a 2D vector arithmetic unit of shape (8, 128), where the 128 dimension is referred to as a lane and the dimension of 8 is referred to as a sublane. Each (lane, sublane) pair on v4 contains 2 standard floating-point and integer ALUs. From a software point of view, this creates the appearance of an 8x128 vector unit with a total of 2048 floating point adders in v4. TPU v4 has 32 VREGs of size (8, 128) which the VPU loads from and writes to.

The VPU executes most arithmetic instructions (like vadd, or vector add) in one cycle in each of its ALUs, with a latency of 2 cycles, so e.g. in v5 you can add 4 pairs of f32 values together from VREGs in each cycle. A typical VPU instruction might look like {v2 = vadd.8x128.f32 v0, v1} where v0 and v1 are input VREGs and v2 is an output VREG.

All lanes and sublanes execute the same program every cycle in a pure SIMD manner, but each ALU can perform a different operation. So we can e.g. process 1 vadd and 1 vsub in a single cycle, each of which operates on two full VREGs and writes the output to a third.

Miscellaneous

*Work done at Google DeepMind, now at MatX.