GPU 集群训练优化 阅读笔记(二)

溴化锂 溴化锂 Views -- #知识

https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=tensor_parallelism

上一章讲到了普通的数据层面的并行,这一章来看看Tensor Parallelism和Context Parallelism。

Tensor Parallelism(张量并行)

用到了一点神奇的线性代数知识。

根据上面的说法,我们可以把在其中完整的张量拆分。分为行拆分和列拆分。

原始表达
原始表达
列拆分(我们可以把W放在不同的GPU上进行计算)
列拆分(我们可以把W放在不同的GPU上进行计算)
行拆分,甚至还能把X拆出去
行拆分,甚至还能把X拆出去

在Transformer中,同样是有两种内容:sequence和parameters,那么对于两种内容我们也可以做出拆分。sequence是输入值,对应上图的X,parameters则是上面linear module中的W和B。

Transformer块中parameters的Tensor Parallelism

好了,知识点补充完成,我们继续看这个parameters的并行。

  • 在注意力部分:Q,K,V采用列分片,每个GPU负责部分注意力头。输出投影可用行分片。最后合并到一起计算
    • 限制:TP并行度不超过头数。
  • 在前馈部分,“列分片 + 行分片”效率更高

这里,通信效率仍然可能是瓶颈。单节点(8卡)内TP通信较快,TP度数越高,单卡吞吐下降越多,但是可以支持更大的batch size和模型规模。

内存节省方面,显著降低,可以让大模型在有限GPU上训练,但是后续操作(Dropout,LayerNorm)仍需全量激活,还可以优化!

可以看到,模型的Parameters内存占用显著降低。但是随着Sequence的增加,内存占用还是在增加。
可以看到,模型的Parameters内存占用显著降低。但是随着Sequence的增加,内存占用还是在增加。

Transformer的序列并行

根据上面的内容我们知道,sequence的内存增加是平方级别的。(因为对于一个长度为L的序列,在注意力机制中需要L^2个位置来存放注意力分数(对于每个token我们需要评估他和其他所有token的相关性)所以提出了序列并行。

核心思想是把序列分段。主要需要解决的难点是如何在序列被切分的情况下,高效地完成全局注意力计算。

在注意力层中,考虑序列分段到各个GPU上的场景,数据科学家们提出的做法如下:

  • 计算局部的Q,K,V。
  • All-Gather来获取全局的K和V。
  • 局部的Q和全局的K,V计算得到局部序列注意力输出,然后把全局K,V drop掉。
  • Reduce-Scatter把局部序列注意力相加,然后再次分割发回GPU。

而其他层(Dropout、LayerNorm),需要解决TP留下的”历史遗留“:LayerNorm和Dropout没有需要分割的部分,并且计算冗余。

完美的特性——计算是局部的:对序列中的第 i 个 Token 进行 LayerNorm,只需要这个 Token 自身的 h 维向量。它完全不需要序列中第 j 个 Token 的任何信息。这意味着: 一旦输入张量沿着序列维度被切分,每张 GPU 就可以在自己的数据分片上独立、完整地执行 LayerNorm 操作,而完全不需要和其他 GPU 进行任何通信。

效果:

对比上一张图,内存占用也显著降低了。
对比上一张图,内存占用也显著降低了。

局限性和优化:通信瓶颈仍然存在,实现逻辑较为复杂,部分层不适用。

Context Parallelism(上下文并行)

当上下文长度超长(128K + )时,Sequence Parallelism也无法处理Activation Value等的增加。所以引入新策略——Ring Attention

Ring Attention是一种高效的通信方式:每个GPU异步发送自己的key/value到下一个GPU,同时计算本地部分的注意力分数,循环进行,最终完成全序列的注意力计算。

这种方式虽然高效,但在因果注意力(causal attention)下,计算负载可能不均衡,需要进一步优化(如Zig-Zag Ring Attention)。

Zig-Zag Ring Attention 不是简单地顺序分配token到各GPU,而是将早期和晚期token交错分配,使每个GPU都能处理不同位置的token,从而均衡了各GPU的计算量。

这种分配方式让注意力掩码(attention mask)下的计算任务在所有GPU间分布更均匀,避免了某些GPU计算负载过重、某些过轻的问题。

处理流程:

  • 序列切分:长度切分
  • 局部计算:每个GPU接收到自己的子序列后,独立地、并行地计算该子序列对应的查询(Query)、键(Key)和值(Value)。
  • 全局信息同步:为了让每个GPU都能计算完整的注意力(即每个Query都能注意到全部的Key),需要通信:all-gather,将自己计算出的局部KV Cache广播给所有其他GPU。执行完 all-gather 后,每个GPU上就都有了完整的、来自全部序列的KV Cache。
  • 注意力计算:局部Query与全局的Key/Value进行注意力计算。

Comments

0 comments
?