上篇描述了朴素数据并行(DP)和分布式数据并行(DDP)策略。两者的总通信量相同,然而 DP 的单 Server 机制导致了通信负载不均的问题(Server 通信量和总 GPU 数量成线性关系),导致 DP 适用于单机多卡场景;DDP 通过 Ring-AllReduce 这一 NCCL 操作,使得通信量均衡分布到所有 GPU 上(单个 GPU 通信量不受到总 GPU 数量的影响),利于多机多卡场景。

解决通信量负载不均的问题后,开始考虑显存开销:每个 DP 分片都复制了一份完整的模型参数、梯度和优化器状态,造成极大的显存负担,如何优化呢?

在优化显存前,需要先理清两个关键问题:

  1. 训练过程中 GPU 上存储了哪些内容?
  2. 分别在哪些阶段需要使用?

先看第一个问题。

显存开销计算

当训练一个神经网络时,需要在内存中存储以下内容:模型参数 W、模型梯度 G、优化器状态、(用于计算梯度的)激活值

以上数据以tensor的形式存储在内存中。在训练各阶段的内存占用率如下图:

  1. 前向传播时,激活值迅速增加;
  2. 反向传播时,梯度逐渐积累,且计算梯度的激活值会逐步被清除;
  3. 最后,执行优化步骤,此时需要所有的梯度,并更新优化器状态;然后才开始下一次的前向传播。

第一步和后续步骤明显不同的原因:激活值快速增加,再保持一段时间的平稳。在第一步中,torch 的缓存分配器进行大量准备工作,预先分配内存;后续步骤不再需要寻找空闲内存块,从而加速)

传统 fp32 训练

假设模型参数 W 的大小是\(\psi\),以byte为单位,内存开销如下:

  • 模型参数:\(4\psi\)
  • 梯度:\(4\psi\)
  • Adam 优化器的动量和方差:\((4+4)\psi\)

混合精度训练

下图采用混合精度训练

过程如下:

  1. 存储一份 fp32 的 parameter, momentum 和 variance(统称 model states);在 forward 开始之前,额外开辟一块存储空间,将 fp32 parameter 减半到 fp16 parameter;
    • fp32 parameter 由优化器维护,只做更新;fp16 parameter 参与 FWD 计算
  2. forward/backward:产生精度为 fp16 的 activations 和 gradients;
  3. update:使用 fp16 gradients 更新 fp32 parameter;

模型收敛后,fp32 parameter 是最终的参数输出。

显存开销如下:

  • 模型参数:\(4\psi+2\psi\)
  • 梯度:\(2\psi\)
  • Adam 优化器的动量和方差:\((4+4)\psi\)

传统 fp32 训练和混合精度训练的显存占用量均为\(16\psi\),不会节省整体内存;只是重新分配。那混合精度训练有哪些优点呢?

  1. GPU 上使用 fp16 的 parameter,加速前向/反向传播过程;
  2. 减少前向/反向传播中产生的激活值的显存占用量。

为了通用表示,记模型必存的数据大小为\(k\psi\)(包括 fp32 parameter, fp32 动量和方差);显存开销为\(2\psi+2\psi+K\psi\)

DataFlow

现在考虑第2个问题:数据分别在哪些阶段使用呢?

  1. 模型参数 W:FWD/BWD 阶段;
  2. 梯度 G:FWD/BWD 结束后的 AllReduce,和参数更新阶段;
  3. Adam 优化下的优化器状态:参数更新阶段。

一个直观的策略是:我只需要在使用到数据时加载,而非全程保存在所有 GPU 上,就可以省下一笔存储空间了。

ZeRO 系列

假设显卡数量为 \(N\),提出以下3种 ZeRO 算法: 显存占用量分析:

  1. ZeRO-1:只对优化器状态进行分片,每张卡保存\(\frac{1}{N}\)​的状态量。此时,每张卡所需显存是\(4\psi+\frac{12\psi}{N}\)字节。当\(N\)较大时,趋向于\(4\psi\),记为\(P_{os}\)
  2. ZeRO-2:对优化器状态和梯度进行分片,此时,每张卡所需显存是\(2\psi+\frac{2\psi+12\psi}{N}\)字节。当\(N\)较大时,趋向于\(2\psi\),记为\(P_{os+g}\)
  3. ZeRO-3:将模型参数、梯度、优化器状态三者都进行分片,此时,每张卡所需显存是\(\frac{16\psi}{N}\)字节。当\(N\)较大时,趋向于\(0\),记为\(P_{os+g+p}\).
    • ZeRO-3 对应 Pytorch FSDP

通信量分析:(这里不再换算为 byte)

朴素 DDP

聚合并下发梯度时,采用 Ring-AllReduce:

  1. Reduce-Scatter 阶段:通信量为\((N-1)\frac{\psi}{N}\)
  2. All-Gather 阶段:通信量为\((N-1)\frac{\psi}{N}\)

单卡通信量近似为\(2\psi\)

ZeRO-1:\(P_{os}\)

将优化器状态(fp32 parameter + fp32 momentum + fp32 variance)分片存储在各个 GPU 上:

训练流程如下:

  1. 每张卡上存储一份完整的模型参数 W(fp16 parameter),将一个 batch 切分为多个 micro-batch,在每张卡上分别完成 FWD/BWD 计算,得到梯度 G(fp16 gradients);

  2. 聚合梯度 Reduce-Scatter:相比于朴素 DDP 使用 AllReduce,这里优化为 Reduce-Scatter,单卡通信量为\(\psi\),拿到更新自身\(\frac{1}{N}\)参数所需的 gradients 即可;

  3. 参数更新:

    1. 本地更新:使用完整梯度 G 和它负责的那部分 fp32 parameter,执行优化器步骤,更新它自己负责的那一部分 fp32 parameter
    2. fp16._copy(fp32)将更新后的 fp32 parameter 原位拷贝回本卡上完整 fp16 parameter 中对应的部分
      • 例如:负责第0~25%参数的 GPU,用它更新后的 fp32 parameter 覆盖掉自己 fp16 parameter 中第0~25%的参数。
    3. All-Gather:由于每张卡只更新了模型的一部分,需要通过 All-Gather 通信操作(单卡通信量为\(\psi\);使得所有卡同步到完整的、更新后的 fp16 parameter,用于下一次迭代的 FWD/BWD。

假设完整的 fp32 parameter 大小为 1000 MB;

fp16:ZeRO 系列将原始的 fp32 复制一份并转为 fp16;这个大小为500 MB 的完整 fp16 副本会保存在每张卡上fp16 用于 FWD/BWD 计算

  • 注意:fp16 是从更新后的 fp32 cast 得到并用于下一轮 FWD/BWD 计算(不会直接使用 fp16 gradients 更新)

fp32:每张卡上保存 \(\frac{1}{N}\) 的完整 fp32 parameter;操作为:fp32 = partiton_fp16.copy().float().detach()

  1. partiton_fp16:从本卡的完整 fp16 parameter 中,取出自己负责的那一部分(例如第0张卡负责0~25%的参数);
  2. .copy().float():将这一部分 fp16 parameter 复制出来,并转换为 fp32,得到了分区后的 fp32 parameter;
  3. .detach():将 fp32 parameter 从当前计算图中分离出来,不参与 FWD/BWD 计算,只用于权重更新

步骤如下图:

与朴素 DP 相比,ZeRO-1 在通信量相同的情况下,将单卡上优化器状态的显存占用降低到了\(\frac{1}{N}\)

ZeRO-2:\(P_{os+g}\)

在 ZeRO-1 的基础上,将 fp16 gradients 分片存储在各个 GPU 上。

在 gradient 的 Reduce-Scatter 操作中,只需要保持每张卡上自己维护的梯度是聚合梯度。例如对于 GPU 1,负责维护 G1,其他 GPU 只需要将 G1 对应位置的梯度发给它即可。 在上图中,聚合梯度后白色块无用,可从显存中移除;绿色块是最终的聚合梯度。

与 ZeRO-1 的通信步骤和通信量相同,优化点是进一步将 fp16 parameter 分片。

ZeRO-3:\(P_{os+g+p}\)

在 ZeRO-2 的基础上,将 fp16 parameter 也分片存储。

  1. 每张卡上只保存模型的部分参数 W(fp16 parameter),将一个 batch 切分为多个 micro-batch 输入到各张卡上;
  2. FWD:依次经过每个 layer。先对 W 做一次 All-Gather,聚合分布在其他卡上的 W,得到一份完整的 W(单卡通信量\(\psi\);做完 FWD,立刻从显存中清除不由自己维护的 W。
  3. BWD:对 W 做一次 All-Gather 取回完整的 W(单卡通信量\(\psi\)); 做完 BWD,对计算得到的梯度 G 做一次 Reduce-Scatter,从其他卡上聚合自己维护的那部分梯度(单卡通信量\(\psi\)。聚合操作结束后,立刻从显存中清除不由自己维护的 G。
  4. 每张卡用本地的 fp2 parameter 和梯度 G 执行优化器步骤,更新它自己负责的那一部分 fp32 parameter

单卡总通信量为 \(3\psi\)。与朴素 DP 相比,用 1.5 倍的开销,换回了单卡总显存占用量降低到\(\frac{1}{N}\)

与 ZeRO-2 相比,需要额外执行\(2*numLayers-1\)次 All-Gather,每次操作带来一个小的基础延迟开销。

显存占用量的对比图如下:

ZeRO-R:优化 Actications 的存储

以上为针对必存部分 model states 的显存优化,现在来看 activations.

不像模型参数、梯度、优化器状态对于模型更新是必须的,激活值只是起到加速梯度计算的作用,因此其存储策略支持灵活设置。

ZeRO-R 的做法是:将所有的激活值分片存储,即只对activation checkpoints分片。

正常情况:FWD 过程中保存每一个激活值,用于 BWD 时计算梯度,计算完对应梯度节点后才能释放激活值;

  • 缺点:需要保存大量的中间激活值,显存占用量随着层数线性增长的。

优化一:丢弃所有的中间激活值, BWD 需要时重新计算;

  • 缺点:训练速度慢,每个 BWD 原本只需要计算一次,现在最多需要计算n次!

折中做法:选取一些前向节点作为 checkpoint;训练时,这些checkpoint 的激活值会一直保存在显存中;而其他节点的激活值会被丢弃。

  • 优点:计算反向梯度节点时,只需要从离它最近的checkpoint节点开始计算(不用把每个节点都重新计算一遍)。

ZeRO-Offload:显存不够,内存来凑

GPU 显存不够用,则:将一部分计算和存储下放到 CPU 和内存;并且不让 CP 和 GPU 之间的通信成为瓶颈,也不让 CPU 参与过多计算,避免 CPU 计算成为瓶颈。

将数据流图切分成 CPU 和 GPU 两部分。ZeRO-Offload 策略如下:将计算量高的 FWD 和 BWD 放在 GPU 上;计算量低的参数更新和 float2half 这两个计算操作放在CPU上。因此,优化器状态也放在内存中。

上述方法仅仅针对单卡场景。在多卡场景下,ZeRO-Offload 采取 ZeRO-2 策略:将分片后的\(\frac{1}{N}\)个​优化器状态和梯度都下放到内存,只在CPU上进行参数更新

更多内容参考:大模型并行训练技术(一)—— ZeRO系列

Pytorch-FSDP

FSDP Engine 默认采取 ZeRO-3 策略。详细用法参考:FullyShardedDataParallel — PyTorch 2.2 documentation

FSDP 在 DeepSpeed ZeRO-DP 的基础上进行了拓展,分片策略包括:NO_SGARD(对标朴素 DDP);SHARD_GRAD_OP(对标ZeRO2);FULL_SHARD(默认策略,对标 ZeRO3);HYBRID_SHARD(节点内 shard,节点间 replicate,对标ZeRO++ stage3),_HYBRID_SHARD_ZeRO2(节点内 ZeRO-2 shard,节点间replicate)。

训练流程如下图:

细节详见下一篇。

ZeRO vs TP

做一个小结:ZeRO 系列的关键思路是“通信换显存”。

看上去 ZeRO 对 parameter 做了拆分,为什么不算 TP 呢?

因为在 FWD/BWD 过程中,需要通过 All-Gather 聚合每张卡上维护的 model weights 才能计算,即本质上还是在用完整的 model weights 计算,所以依然是 DP;与之对比的 TP 在 FWD/BWD 中只需要使用自己维护的部分 model weights 进行计算。

参考

图解大模型训练之:数据并行下篇( DeepSpeed ZeRO,零冗余优化)