原文链接:https://jax-ml.github.io/scaling-book/tpus/


目录

什么是 TPU?

从根本上说,TPU 是一个专门用于矩阵乘法的计算核心(称为 TensorCore),连接着一堆高速内存(称为高带宽内存HBM)。

下图展示了 TPU 芯片的基本组件。TensorCore 是左侧的灰色框,包含矩阵乘法单元 (MXU)、向量单元 (VPU) 和向量内存 (VMEM)。

![[Pasted image 20251129220903.png]]

你可以把 TensorCore 看作是一台非常擅长矩阵乘法的机器,但它还有其他几个值得注意的功能。TensorCore 有三个关键单元:

  1. MXU (矩阵乘法单元):这是 TensorCore 的核心。在大多数 TPU 世代中,它利用脉动阵列(详见附录 B),每 8 个时钟周期执行一次 bfloat16[8,128] @ bf16[128,128] -> f32[8,128] 的矩阵乘法 1

    • 在 TPU v5e 上(1.5GHz),这大约相当于每个 MXU 每秒 5e13 次 bf16 浮点运算 (FLOPs)。
    • 大多数 TensorCore 有 2 或 4 个 MXU,因此例如 TPU v5e 的总 bf16 FLOPs/s 为 2e14
    • TPU 还支持更高吞吐量的低精度矩阵乘法(例如,每个 TPU v5e 芯片每秒可执行 4e14 次 int8 操作)。
  2. VPU (向量处理单元):执行通用的数学运算,如 ReLU 激活、逐点加法或向量间的乘法。规约(求和)也在这里执行。附录 A 提供了更多细节。

  3. VMEM (向量内存):是位于 TensorCore 内部、靠近计算单元的片上暂存器 (scratchpad)。它比 HBM 小得多(例如,TPU v5e 上为 128 MiB),但到 MXU 的带宽要高得多。VMEM 的运作方式有点像 CPU 上的 L1/L2 缓存,但它更庞大且完全由程序员控制。HBM 中的数据需要先复制到 VMEM 中,TensorCore 才能对其进行计算。


TPU 在矩阵乘法方面非常、非常快。这是它们的主要工作,而且干得很好。TPU v5p 是迄今为止最强大的 TPU 之一,每个核心每秒可进行 2.5e14 次 bf16 FLOPs,或者每个芯片每秒 5e14 次。一个包含 8960 个芯片的 Pod 每秒可以进行 4 exaflops (百亿亿次) 的运算。这是非常强大的算力。这是世界上最强大的超级计算机之一。而 Google 拥有很多这样的计算机。[^2]

上面的图还包括其他一些组件,如 SMEM标量单元 (Scalar Unit),用于处理控制流,在附录 A 中有简要讨论,但理解它们并不关键。另一方面,HBM 很重要且相当简单:

通常,所有 TPU 操作都是流水线化 (pipelined) 且重叠 (overlapped) 的。

为了执行矩阵乘法 $X \cdot A \rightarrow Y$,TPU 首先需要将矩阵 $A$ 和 $X$ 的块 (chunks) 从 HBM 复制到 VMEM,然后将它们加载到 MXU 中(MXU 会对 $X$ 的 8x128 块和 $A$ 的 128x128 块进行乘法),最后将结果逐块复制回 HBM。

为了高效地做到这一点,矩阵乘法是流水线化的,因此与 VMEM 的数据复制是与 MXU 的工作重叠进行的。这允许 MXU 继续工作而不是等待内存传输,从而保持矩阵乘法处于计算受限 (compute-bound) 状态,而不是内存受限。

这是一个如何从 HBM 执行逐元素 (elementwise) 乘积的示例:

![[pointwise-product 2.gif]]

图:动画显示在 TPU 上执行逐点乘积,字节从 HBM 加载。

注意字节是如何分块流出内存的,部分结果被流水线回传,而无需等待完整数组具体化。

矩阵乘法看起来几乎一样,只是它会加载到 MXU 而不是 VPU / 向量单元,并且加载和存储的顺序会不同,因为同一个权重块会被用于多个激活块。你可以看到数据块流在 VMEM 中,然后进入 VREGs (向量寄存器),然后进入向量单元,最后回到 VMEM 和 HBM。正如我们即将看到的,如果从 HBM 到 VMEM 的加载速度慢于向量单元(或 MXU)中的 FLOPs 速度,我们就变成了“带宽受限”,因为我们让 VPU 或 MXU 处于饥饿状态,没有工作可做。

关键要点:TPU 非常简单。它们将权重从 HBM 加载到 VMEM,然后从 VMEM 加载到一个每秒可以执行约 200 万亿次乘加运算的脉动阵列。HBM $\leftrightarrow$ VMEM 和 VMEM $\leftrightarrow$ 脉动阵列 的带宽设定了 TPU 能高效执行哪些计算的基本限制。

![[Pasted image 20251129221416.png]]

VMEM 和算术强度:VMEM 比 HBM 小得多,但它到 MXU 的带宽要高得多。正如我们在第一部分看到的,这意味着如果一个算法能将其所有输入/输出放入 VMEM,它遇到通信瓶颈的可能性就会小得多。

这在计算的算术强度较差时特别有用:VMEM 带宽大约是 HBM 带宽的 22 倍,这意味着一个从 VMEM 读取/写入的操作只需要 10-20 的算术强度即可达到峰值 FLOPs 利用率。这意味着如果我们能将权重放入 VMEM 而不是 HBM,我们的矩阵乘法可以在更小的 batch sizes 下达到 FLOPs 受限状态。这也意味着那些本质上算术强度较低的算法仍然可以是高效的。只是 VMEM 太小了,这通常是一个挑战。[^3]

![[Pasted image 20251129221344.png]]

![[Pasted image 20251129221335.png]]

PCIe 带宽是有限的:就像 HBM <-> VMEM 链路一样,CPU <-> HBM 的 PCIe 连接也有特定的带宽,限制了你从主机内存加载到 HBM 或反之亦然的速度。例如,TPU v4 的 PCIe 带宽每个方向为 16GB/s,比 HBM 慢近 100 倍。我们可以将数据加载/卸载到主机 (CPU) RAM 中,但这并不快。


TPU 互联

芯片在 Pod 中通过 ICI 网络相互连接。在老一代(TPU v2 和 TPU v3)、推理芯片(如 TPU v5e)和 Trillium (TPU v6e) 中,ICI (“芯片间互联”) 连接最近的 4 个邻居(带有边缘链路以形成 2D 环面)。TPU v4 和 TPU v5p 连接到最近的 6 个邻居(形成 3D 环面)。注意这些连接经过它们的主机,它们是芯片间的直接链路。

![[Pasted image 20251129222652.png]]

环面 (Toroidal) 结构将任意两个节点之间的最大距离从 $N$ 减少到 $N/2$,使通信更快。TPU 还有一个“扭曲环面 (twisted torus)”配置,将环面 包裹在类似莫比乌斯带的拓扑中,以进一步减少节点间的平均距离。

TPU Pods(通过 ICI 连接)可以变得非常大:最大 Pod 大小(称为 superpod)对于 TPU v4 是 16x16x16,对于 TPU v5p 是 16x20x28。这些大型 Pod 由 4x4x4 芯片的可重构立方体组成,通过光学环绕链路 [^5] 连接,我们可以重新配置它们以连接非常大的拓扑结构。

![[Pasted image 20251129222717.png]]

也可以请求较小的拓扑结构(例如 2x2x1, 2x2x2),尽管没有环绕链路。这是一个重要的注意事项,因为它通常会使大多数通信的时间加倍。任何完整立方体的倍数(例如 4x4x4 或 4x4x8)都将具有由光学开关提供的环绕链路。[^6]

![[Pasted image 20251129222850.png]]

TPU v5e 和 Trillium Pods 由单个 16x16 2D 环面组成,在任何大小为 16 的轴上都有环绕(意味着 8x16 在长轴上有环绕)。TPU v5e 和 v6e (Trillium) 不能扩展超过 16x16 环面,但 Pods 之间仍然可以通过标准数据中心网络 (DCN) 通信,DCN 连接着 TPU 主机。同样,如果请求维度 <16 的较小拓扑,则没有环绕链路。

![[Pasted image 20251129222901.png]]

这种最近邻连接是 TPU 和 GPU 之间的一个关键区别。GPU 使用分层交换机连接,近似于每个 GPU 之间的点对点连接,而不是像 TPU 那样使用本地连接。通常,一个节点内的 GPU(H100 为 8 个 GPU,B200 NVLink72 多达 72 个)是直接连接的,而更大的拓扑结构需要在每个 GPU 之间进行 $O(\log(N))$ 跳。一方面,这意味着 GPU 可以在少量跳数内发送任意数据。另一方面,TPU 便宜得多(因为 NVLink 交换机很贵),接线更简单,并且可以扩展到更大的拓扑结构,因为每个设备的链路数量和每个设备的带宽是恒定的。在此处阅读更多内容。

ICI 相对于 DCN 非常快,但仍慢于 HBM 带宽。例如,一个 TPU v5p 拥有:

这意味着当我们跨多个芯片分割模型时,我们需要小心避免用较慢的跨设备通信成为 MXU 的瓶颈。

多切片 (Multi-slice) 训练:一组通过 ICI 连接的 TPU 称为一个 slice (切片)。不同的 slice 可以通过 DCN 相互连接,例如连接不同 Pod 上的 slice。由于 DCN 是比 ICI 慢得多的连接,我们应该尽量限制计算等待来自 DCN 数据的时间。

DCN 是主机对主机的,所以要通过 DCN 将缓冲区从 TPU 传输到 TPU,我们需要先通过 PCIe 传输到主机,然后通过网络出口,再通过目标主机网络入口,最后通过 PCIe 进入 HBM。


TPU 芯片在 Pod 内部通过 ICI (Inter-Chip Interconnects) 网络互联。

TPU vs GPU 网络差异:

NVLink Switch 是什么?

带宽层级:

  1. HBM 带宽:~ 2.5 TB/s (最快)
  2. ICI 带宽:~ 90 GB/s (芯片间互联)
  3. PCIe 带宽:~ 16 GB/s (主机到芯片,很慢)
  4. DCN 带宽 (Data Center Network):~ 6 GB/s (跨主机/跨 Pod,最慢)

扩展策略:
当我们将模型切分到多个芯片时,必须小心避免让 MXU 等待慢速的跨设备通信。通信量必须与该通道的物理带宽成正比。

[[TPU Networking 通俗解释]]


关键要点

这句话是什么意思?
TPU 是吞吐量怪兽,但它有很高的“启动门槛”。如果你喂给它的数据块(Batch Size 或 Hidden Dimension)小于 128,TPU 就会被迫通过补零来“空转”。


TPU 规格参数

以下是 TPU 芯片的一些具体数字:

型号 Pod 大小 主机大小 HBM 容量/芯片 HBM 带宽/芯片 (bytes/s) FLOPs/s/芯片 (bf16) FLOPs/s/芯片 (int8)
TPU v3 32x32 4x2 32GB 9.0e11 1.4e14 1.4e14
TPU v4p 16x16x16 2x2x1 32GB 1.2e12 2.75e14 2.75e14
TPU v5p 16x20x28 2x2x1 96GB 2.8e12 4.59e14 9.18e14
TPU v5e 16x16 4x2 16GB 8.1e11 1.97e14 3.94e14
TPU v6e 16x16 4x2 32GB 1.6e12 9.20e14 1.84e15

截止现在 v7e 已经出了

主机大小指的是连接到单个主机的 TPU 拓扑(例如,TPU v5e 有一个连接到 8 个 TPU 的单个 CPU 主机,采用 4x2 拓扑)。

以下是互联数据:

型号 ICI 带宽/链路 (单向, bytes/s) ICI 带宽/链路 (双向, bytes/s)
TPU v3 1e11 2e11
TPU v4p 4.5e10 9e10
TPU v5p 9e10 1.8e11
TPU v5e 4.5e10 9e10
TPU v6e 9e10 1.8e11

我们包含单向带宽和双向 (bidi) 带宽,因为单向带宽更符合硬件实际,但双向带宽在涉及完整环的方程中更常出现。[9]

PCIe 带宽通常约为每个 TPU 1.6e10 bytes/s (TPU v6e 为 3.2e10),而 DCN 带宽通常约为每个 TPU 6.25e9 bytes/s (TPU v6e 为 12.5e9,TPU v5e 为 3.125e9)。


练习题 (Worked Problems)

这些数字有点枯燥,但它们让你能够对模型性能进行基本的 Roofline 估算。让我们做几个练习来解释为什么这很有用。你将在第三部分看到更多例子。

问题 1 [估算 LLM 延迟]:

假设你想从分布在 32 个 TPU v4p 上的 200B 参数模型(bf16)中进行采样。将所有参数从 HBM 加载到脉动阵列需要多长时间? 提示:使用上面的数字。

我理解 sizeof(bf16) * 200B / 32 / 1.2e12
我们在 32 个芯片上加载 sizeof(bf16) * 200e9 = 400e9 字节,意味着每个芯片 12.5e9 字节,每个芯片的 HBM 带宽为 1.23e12。所以加载大约需要 10ms。
这很酷,因为这是从模型采样的延迟的一个合理的下限。每个采样步骤都需要从 HBM 加载所有参数,所以它不可能少于 10 ms。实际上,在小 batch sizes 下,这接近于可实现的数值。

问题 2 [TPU 细节]:

考虑一个完整的 TPU v5e pod。总共有多少个 CPU 主机?有多少个 TPU TensorCores?整个 pod 的总 FLOPs/s 是多少?总 HBM 是多少?对 TPU v5p pod 做同样的练习。

TPU v5e pod 不是 TPU 的集合吗?为什么会涉及到 CPU 呢,这里的主机到底是什么概念
pod 难道是机器集群的意思吗?
对于 TPU v5e,每个 pod 是 16x16,每个主机是 4x2 slice,所以我们有 16*16 / 8 = 32 个主机。对于 TPU v5e,每个 TPU 只有一个核心,所以我们有 256 个 TensorCores。bfloat16 下的总 FLOPs/s 是 16*16*2e14 = 5.1e16。每个芯片有 16GB HBM,所以总共有 256 * 16 = 4TB 内存。

对于一个完整的 TPU v5p pod,我们有 16x20x28 个芯片,每个主机是 2x2x1,所以我们有 16*20*28 / 2*2 = 2,240 个主机。对于 TPU v5p,每个 TPU 有两个 TensorCores,所以我们有 8960 * 2 = 17,920 个核心。bfloat16 下的总 FLOPs/s 是 8960 * 4.5e14 = 4e18。每个芯片有 96GB HBM,所以总共有 8960 * 96 = 860TB 内存。

Pod 是 Google 数据中心里的一个物理部署单位。你可以把它想象成一整排(或几排)专用的机柜。
特点:在一个 Pod 里的所有 TPU 芯片,都通过 ICI(那个专用的高速互联网络)物理连接在一起了。
关键点:Pod 定义了“最大连通域”。同一个 Pod 里的 TPU 可以直接“聊天”(通过 ICI)。Pod A 的 TPU 想和 Pod B 的 TPU 聊天,必须走慢速的公网(DCN),经过 Host 的 CPU 和网卡转发。

对于 TPU v5e Pod:文案里说它是 16x16。意思是这个 Pod 里的 TPU 芯片排列成一个 $16 \times 16$ 的二维网格。总芯片数:$16 \times 16 = 256$ 颗 TPU 芯片。这 256 颗芯片被物理安装在若干台 Host 服务器里。

让我们来拆解这句让你困惑的话:

“对于 TPU v5e,每个 pod 是 16x16,每个主机是 4x2 slice,所以我们有 16*16 / 8 = 32 个主机。”

这其实是在算物理服务器的数量。

A. 单台 Host 的物理配置 (The Atomic Unit)

对于 TPU v5e 这一代硬件,Google 设计的主板结构是这样的:

B. 整个 Pod 的规模

C. 需要多少台服务器?

如果你要拼出一个 256 颗芯片的大网格,而你手里的“积木块”(Host)每块包含 8 颗芯片。
你需要多少块积木?

$$\text{Host Count} = \frac{\text{Total Chips in Pod}}{\text{Chips per Host}}$$
$$\text{Host Count} = \frac{256}{8} = 32 \text{ 台服务器}$$
物理画面感:

走进机房,你会看到 32 台服务器。

问题 3 [PCIe 运算强度]:

想象我们被迫在主机 DRAM 中存储一个大的权重矩阵 $A$ (类型 bfloat16[D,F]) 和一批激活值 $x$ (类型 bfloat16[B,D]),并想对它们进行矩阵乘法。这是在单个主机上运行的,我们使用连接到它的单个 TPU v6e 芯片。你可以假设 $B \ll D$,并且 $F=4D$(我们将在以后的章节中看到为什么这些是合理的假设)。为了保持在 PCIe 上 FLOPs 受限 (FLOPs bound),我们需要多大的最小 batch size $B$?假设 PCIe 带宽为 1.5e10 bytes/s。

权重矩阵 激活值 分别是什么?
什么叫做为了保持在 PCIe 上 FLOPs 受限 (FLOPs bound) 「这是计算受限的意思吗?」
计算量 (FLOPs):$2 \times B \times D \times (4D) = 8 B D^2$。
计算时间:$8 B D^2 / 9.2e14$ 秒。

![[Pasted image 20251130115024.png]]

数据传输量:主要由权重 $A$ 决定(因为 $B$ 很小),大小为 $2 \times D \times 4D = 8D^2$ bytes。
传输时间:$8D^2 / 1.5e10$ 秒。

平衡点:计算时间 > 传输时间 $$\frac{8 B D^2}{9.2 \times 10^{14}} &gt; \frac{8 D^2}{1.5 \times 10^{10}}$$$$\frac{B}{9.2 \times 10^{14}} > \frac{1}{1.5 \times 10^{10}}$$$$B > \frac{9.2 \times 10^{14}}{1.5 \times 10^{10}} \approx 61333$$
结论:这非常反直觉。通过 PCIe 传输权重极慢,你需要极大的 Batch Size 才能掩盖传输延迟。
==这就是为什么我们要把权重常驻 HBM==。

问题 4 [通用矩阵乘法延迟]:

假设我们想将一个 int8[16384, 4096] 的权重矩阵乘以一个大小为 int8[B, 4096] 的激活矩阵,其中 B 是某个未知的 batch size。
假设我们开始是在 1 个 TPUv5e 上。作为一个关于 B 的函数,这个乘法需要多长时间?
提示:计算从 HBM 加载数组需要多长时间以及乘法实际需要多长时间可能会有所帮助。哪个是瓶颈?
如果我们想在 VMEM 中运行此操作怎么办?作为一个关于 B 的函数,它需要多长时间?

(1) 我们需要执行的浮点运算次数是 $2 \cdot 16384 \cdot 4096 \cdot B \approx 1.34e8 \cdot B$。所以 $1.34e8 \cdot B / 3.94e14 \approx 3.4e^{-7} \cdot B$ 秒。
我们需要从 HBM 加载 $16384 \cdot 4096 + B \cdot 4096$ 字节到 VMEM,并从 VMEM 写回 $B \cdot 16384$ 字节到 HBM。这意味着 $(6.7e7 + 20480 \cdot B) / 8.1e11 \approx 8.2e^{-5} + 2.5e^{-8} \cdot B$ 秒。假设通信和计算尽可能多地重叠,整个乘法将花费大约

$$\max(3.4e^{-7} \cdot B, 8.2e^{-5} + 2.5e^{-8} \cdot B)$$

当 $3.4e^{-7} \cdot B &gt; 8.2e^{-5}$,或者等价地 $B &gt; 240$ 时,我们将是 FLOPs 受限的。这比我们在下面得出的 240 数字略大,因为我们考虑了 $x$ 和 $y$ 加载/存储的全部影响。

(2) 如果相反我们是从 VMEM 加载,让我们把 VMEM 到 MXU 的带宽视为 HBM $\leftrightarrow$ VMEM 带宽的 22 倍。这将我们的数据加载分母从 8.1e11 变为 1.78e13,我们得到 $B &gt; 11$。注意在实践中,我们不能将所有的 VMEM 带宽都用于加载 $W$,所以在实践中它会接近 20。

问题 5 [ICI 带宽]:

假设我们有一个 TPU v5e 4x4 slice。假设我们想把一个类型为 bfloat16[8, 128, 8192] 的数组从 TPU{0,0} 发送到 TPU{3, 3}。假设 TPU v5e 的每跳延迟是 $1\mu s$。
第一个字节多久会到达目的地?
总传输需要多长时间?

答案: 在 TPUv5e 中我们有 2D 连接。因为我们只有 4x4 slice(没有大小为 16 的轴),我们没有环绕连接。因此我们的目标芯片有两个端口可以接收数据,同样源芯片也有两个端口可以发送数据。

我们要传输的数据量是 2 * 8 * 128 * 8192 = 1.7e7 字节。我们可以同时从两个端口传输(即发送一半数组向右,一半向下),所以我们得到每秒传输 2 * 4.5e10 = 9e10 字节,这意味着大约需要 1.7e7 / 9e10 = 188us 来传输整个数组(假设我们是带宽受限的)。

在一个 4x4 slice 中,我们在芯片 (0,0) 和 (3,3) 之间有六跳,因为对于少于 16 个芯片的轴没有环绕链路。由于每跳的延迟大约是 $1\mu s$,第一个字节将在大约 $6\mu s$ 内到达,总传输将花费 188us。

![[Pasted image 20251130164537.png]]

问题 6 [综合运用] 「困难」:

想象你有一个大的矩阵 $A$: int8[128 * 1024, 128 * 1024] 均匀分片在一个 TPU v5e 4x4 slice 上,但卸载到了每个芯片的主机 DRAM 上。假设你想把整个数组复制到 TPU{0, 0} 并将其乘以一个向量 bf16[8, 128 * 1024]。这将需要多长时间? 提示:使用上面的数字。

什么叫做 「矩阵均匀分片」?
什么叫做 「卸载到了每个芯片的主机 DRAM 上」?

均匀分片 = 数据被打散存储在 16 个节点上。
卸载到 DRAM = 数据在冷存储(Host Memory)里,不在热存储(HBM)里。


附录 A: 更多关于 TPU 内部细节

在这里我们将更深入地研究 TPU 的内部操作。除非另有说明,我们将提供 TPU v5p 的规格。

VPU (向量处理单元)

VPU 是 TPU 的向量算术核心。VPU 由一个二维 SIMD 向量机(即 VPU)和一组称为 VREGs 的向量寄存器组成,VPU 执行逐元素算术操作,如 vadd(向量加法)或 vmax(逐元素最大值),VREGs 保存 VPU 和 MXU 的数据。

在 x86/ARM 中,你熟悉的 SIMD 寄存器(如 AVX-512 的 zmm0)是一个 512-bit 的一维数组(存 16 个 float)但在 TPU v5p 中,一个 VREG(向量寄存器)不是一维的,而是二维的。
物理形状:8 (Sublane) x 128 (Lane)
单个寄存器大小:$8 \times 128 \times 4 \text{ bytes (f32)} = 4 \text{ KB}$。
总量:每个核心有 64 个这样的寄存器。总大小 $64 \times 4 \text{ KB} = 256 \text{ KB}$。
把 VREG 想象成一个芯片内的 L0 Cache。它非常大,而且带宽极高(每周期 3 读 1 写)。这意味着 VPU 指令(如 vadd)一次性操作的不是 1 个数,也不是 16 个数,而是 1024 个数(8x128)。这体现了 TPU 用空间换时间的设计哲学。

这里提到的 (8, 128) 维度非常关键,这决定了编程模型和通信成本。Lane (128 维):这是主并行轴。你可以把它想象成 128 个独立的“工位”。Sublane (8 维):这是次级并行轴。每个工位上坐着 8 个“工人”。
ALU 密度:在每个 (Lane, Sublane) 点上,实际上有 4 个独立的 ALU。
指令级并行 (ILP): 文中提到 “一个周期内处理 1 个 vadd 和 1 个 vsub”。 这意味着 VPU 采用了类似 VLIW(超长指令字) 的设计。 在一个时钟周期内,编译器可以打包多条指令,同时利用那 4 个 ALU。 比如:
ALU 0: 计算 vREG1 + vREG2
ALU 1: 计算 vREG3 - vREG4
ALU 2 & 3: 空闲

所有 lanes 和 sublanes 每个周期都以纯 SIMD 方式执行相同的程序,但每个 ALU 可以执行不同的操作。所以我们可以例如在一个周期内处理 1 个 vadd 和 1 个 vsub,每一个都操作两个完整的 VREGs 并将输出写入第三个。

SIMD 单指令多数据

突击测验 [计算 VPU 吞吐量]

使用上述信息,计算一个 TPU v5p 可以执行多少向量 FLOPs/s。TPU v5p 的时钟速度约为 1.75GHz。

已知条件:

计算过程:
$$\text{Total FLOPs/s} = (\text{Cores}) \times (\text{Grid Size}) \times (\text{ALUs/Grid}) \times (\text{Frequency})$$$$\text{Total FLOPs/s} = 2 \times (8 \times 128) \times 4 \times 1.75 \times 10^9$$$$\text{Total FLOPs/s} = 2 \times 1024 \times 4 \times 1.75 \times 10^9$$$$\text{Total FLOPs/s} = 8192 \times 1.75 \times 10^9 \approx 14,336 \times 10^9 \approx 14.3 \text{ TFLOPS}$$
对比:MXU(矩阵单元)的算力高达 459 TFLOPS。这里会发现 VPU 的算力(14T)只有 MXU 的 3% 左右。结论:TPU 严重偏科。它极其擅长矩阵乘法,但如果你用它做大量复杂的通用向量运算(在 VPU 上跑),性能会暴跌。

规约 (Reductions): 通常,跨 sublane 维度的通信或规约比跨 lane 维度更容易。

例如,VPU 支持一种 intra-lane shuffle 操作,可以在大约一个周期内在大小为 8 的轴上滚动。这可以用来沿 sublane 维度执行高效的规约(只需 shuffle 4, 2, 和 1 并做 3 对逐元素求和)。

规约是什么?
可以类比函数式编程中的 Reduce 操作
底层实现的痛苦: 对于 SIMD 硬件,加法(Map) 是极其快乐的: c[i] = a[i] + b[i]。每个 i 互不干扰,完全并行。
但 规约(Reduce) 是痛苦的: sum = a[0] + a[1] + … + a[N]。 这包含数据依赖。你必须拿到前一步的结果才能加下一个。为了并行,通常使用树状归约(Tree Reduction)(两两相加,然后再两两相加),但这需要数据在不同的 ALU 之间移动。

跨 lane 规约要困难得多,涉及一个单独的硬件单元称为 XLU 或“跨 lane 单元”,它很慢且相当昂贵。

为什么跨纬度规约这么难呢?

场景:你要把 Sublane 维度(大小为 8)的 8 个数加起来。 物理含义:这 8 个数都在同一个“工位”上,只是属于不同的“工人”。他们坐得很近。 实现:Shuffle(洗牌)操作
Shift 4: 前 4 个工人把数据递给后 4 个。相加。 (现在剩下 4 个有效数据)
Shift 2: 再递给后 2 个。相加。 (剩下 2 个)
Shift 1: 再递给最后 1 个。相加。 (剩下 1 个结果) 耗时:仅需约 1-3 个时钟周期。因为连线短,不需要复杂的路由。

场景:你要把 Lane 维度(大小为 128)的 128 个数加起来。 物理含义:这 128 个数分布在 128 个不同的工位上,物理距离可能跨越了半个芯片核心。 瓶颈:工位 0 不能直接把数据给工位 127。必须通过 XLU (Cross Lane Unit)。这是一个专门的硬件路由器,类似于芯片内部的“总线”或“交换机”。耗时:非常慢,且昂贵。

程序员的启示(Optimization Tip): 如果你在写 TPU 的 Kernel(比如用 JAX/Triton):
尽可能沿着 Sublane (8) 维度做 Sum/Max。
极力避免沿着 Lane (128) 维度做 Sum/Max。 如果非要做,编译器可能会通过转置(Transpose)把数据换个方向,但这也很慢。

与 GPU 的比较: 对于那些熟悉 NVIDIA GPU 的人来说,VPU 中的每个 ALU 类似于一个 CUDA 核心,单个 VPU lane 类似于一个“Warp Scheduler”,即通常执行 SIMD 算术的 32 个 CUDA 核心的集合。Lane 内的规约相当容易,但如果我们需要跨 lane,我们需要通过至少 VMEM/XLU/SMEM,这要慢得多。

可惜我不熟悉 GPU

更多细节请参阅 [GPU 章节]。

标量核心 (Scalar Core)

标量核心是 TPU 的控制单元。它获取并分发所有指令并执行从 HBM 到 VMEM 的传输,并且可以编程做标量元数据工作。因为标量核心是单线程的,这的一个副作用是 TPU 的每个核心每个周期只能创建一个 DMA 请求。
为了把这放在上下文中,单个标量核心控制一个 VPU(由 4096 个 ALU 组成)、4 个 MXU、2 个 XLU 和多个 DMA 引擎。这种单位计算控制的高度倾斜是硬件效率的一个来源,但也限制了以任何有趣的方式做数据依赖的向量化的能力。


附录 B: 脉动阵列是如何工作的?

TPU MXU 的核心是一个 128x128 脉动阵列(TPU v6e 上为 256x256)。当完全饱和时,脉动阵列可以每 8 个时钟周期执行一次 bfloat16[8,128] @ bf16[128x128] -> f32[8,128] [10] 乘法。

在其核心,脉动阵列是一个 128x128 (=16,384) 的 ALU 网格,每个都能执行乘法和加法操作。
权重 (W, 128x128 输入) 从上方传递下来(称为 RHS),而输入 (X, 8x128 输入) 从左侧传递进来(称为 LHS)。

这是一个权重集(蓝色)与激活集(绿色)相乘的简化动画。你会注意到权重 (RHS) 首先被部分加载,对角线地,然后激活被送入,也是对角线地。在下面的每一帧中,我们将所有重叠的绿色和蓝色单元相乘,将结果与从上方传入的任何残差求和,然后将结果依次向下传递一个单元。

![[systolic-array.gif]]

这是一个更通用的动画版本,显示输出正从计算中流出:

![[systolic-array2.gif]]

这是一张图表,展示了如何在多个 RHS 和 LHS 数组之间进行流水线操作:

![[Pasted image 20251130104857.png]]

随着权重 (RHS) 和激活 (LHS) 的加载,会有一个初始的流水线气泡 (bubble)。在那个初始气泡之后,新的输入和权重可以被加载进来而没有额外的气泡。

这是一个 bf16[2, 3] x bf16[3, 3] 矩阵乘法的糟糕动画,你可以把它想象成一个 2x3 权重矩阵与一个 batch 为 1、大小为 3 的输入激活的 matmul。这与之前的幻灯片相比旋转了,输入流向右侧而不是下方,但你可以大致看到结构。

[动画:bf16[2, 3] x bf16[3, 3] 矩阵乘法]

我们可以有效地对此进行流水线处理以倍增大型矩阵,而不会产生太大的流水线气泡。话虽如此,重要的是我们的矩阵形状要大于 MXU 的边维度,通常是 128x128。一些 TPU(自 TPU v3 起)有多个 MXU,TPU v3 为 2 个,TPU v4/5 为 4 个,所以我们需要确保切片维度大于 128 * MXU 数量。这里有一个很好的动画展示了这一点。

Trillium (TPU v6e) 有一个 256x256 脉动阵列,这意味着它可以每周期执行 4 倍多的 FLOPs。这也意味着你的张量维度需要大两倍才能完全利用 MXU。

这篇博客文章 有另一个关于固定权重矩阵的脉动阵列乘法的精彩动画。

随着人工智能推理需求的爆炸式增长,许多硬件初创公司正在设计特定领域架构(DSA)。
我们将从 Transformer 工作负载出发,反向推导,找出最优设计方案、前景广阔的硬件,并预测推理技术的未来发展方向。


脚注

[1] TPU v6e (Trillium) 有一个 256x256 MXU,而所有前几代使用 128x128。
[2] TPU,特别是它们的脉动阵列,之所以是如此强大的硬件加速器,是因为矩阵乘法是为数不多的使用 $O(n^3)$ 计算对应 $O(n^2)$ 字节的算法之一。这使得普通 ALU 很容易受限于计算而不是内存带宽。
[3] 我们有时谈论 VMEM 预取,指的是提前在 VMEM 中加载权重,这样我们可以掩盖 matmul 的加载成本。例如,在正常的 Transformer 中,我们有时可以在 attention 期间将大的前馈权重加载到 VMEM 中,如果我们是内存带宽受限的,这可以隐藏权重加载的成本。这要求我们的权重足够小或分片足够多,以适应单个层进入 VMEM 且有剩余空间。
[4] 在 Cloud TPU VMs 上,每个托盘作为单独 VM 的一部分暴露,因此再次可见 4 个核心。
[5] 光学开关只是一个具有相同 ICI 带宽的可重构连接。它只是让我们连接立方体同时保留环绕链路。
[6] 注意 2x2x4 将没有任何环绕,因为它们由光学开关提供,而光学开关仅在完整立方体上可用。然而,TPU v5e 8x16 将在长轴上有环绕,因为它不使用可重构光学网络。
[7] 上面的页面列出了 100 GB/s 的带宽,这与这里列出的略有不同。TPU ICI 链路具有略微不同的带宽,取决于正在执行的操作。你通常可以放心地使用本文档中的数字。
[8] TPU v6e 有 12.5e9 bytes/s,v5e 有 3.125e9 bytes/s。
[9] 通过双向 (bidi) 带宽,我们指的是可以在单个链路上双向发送的总字节数,或者同样地,从单个 TPU 沿特定轴发出的总字节数,假设我们可以有效地使用两个链路。当我们有一个功能正常的环时,即我们在特定轴上有一个环绕连接时,这是真的。这发生在推理芯片上有完整的 16 轴时,或在训练芯片 (v*p) 上有 4 的倍数的轴时。我们更喜欢使用双向带宽,因为它经常出现在涉及双向通信的计算中。
[10] 如果你不熟悉这个符号,它的意思是:将一个带有 bfloat16 元素的 8x128 矩阵与一个带有 bfloat16 元素的 128x128 矩阵相乘,并将结果存储在一个带有 float32 元素的 8x128 矩阵中。

![[Pasted image 20251130095355.png]]

Footnotes