不止于ZeRO:BMTrain技術(shù)原理淺析

與現(xiàn)有的大模型訓(xùn)練使用百余張顯卡相比,我們發(fā)起的CPM-Live?開源大模型直播訓(xùn)練實(shí)現(xiàn)了?8 張 A100 顯卡?訓(xùn)練百億大模型。這優(yōu)異效果的背后基于的是?大模型高效訓(xùn)練工具 BMTrain?和?模型倉庫?ModelCenter。與現(xiàn)有框架相比,BMTrain 能夠?qū)崿F(xiàn)大模型的低資源、高效訓(xùn)練,并且簡單易用,便于開發(fā)者上手。
支撐起 BMTrain 優(yōu)異性能表現(xiàn)的是其采用的多項(xiàng)分布式訓(xùn)練優(yōu)化技術(shù),它們共同解決了大模型訓(xùn)練過程中的?顯存占用?問題。為了深刻理解這一關(guān)鍵問題,我們不妨分析一下模型訓(xùn)練過程中的顯存占用情況。
模型訓(xùn)練中的顯存占用主要包括:模型參數(shù)、模型梯度、優(yōu)化器狀態(tài)、運(yùn)算中間變量。以下圖為例,訓(xùn)練過程中的顯存占用包括一份模型參數(shù)以及對(duì)應(yīng)的一份梯度,比較常用的 Adam 會(huì)保留兩倍參數(shù)量的優(yōu)化器參數(shù),除此之外還有一些運(yùn)算的中間變量。

根據(jù)上述分析,對(duì)于一個(gè)百億參數(shù)大模型,模型參數(shù)約 20G,訓(xùn)練過程中需要占用的顯存就會(huì)超過 80G,在每一張顯卡中都完整地維護(hù)這些內(nèi)容,顯存是遠(yuǎn)遠(yuǎn)不夠的。這就需要我們采用相關(guān)分布式訓(xùn)練技術(shù),進(jìn)行模型訓(xùn)練的顯存優(yōu)化。
為解決這一關(guān)鍵問題,在 BMTrain 中,我們通過?數(shù)據(jù)并行?降低運(yùn)算中間變量顯存占比、增大吞吐量,通過?ZeRO?降低模型參數(shù)、模型梯度、優(yōu)化器狀態(tài)的顯存占比,通過?Optimizer Offload?將優(yōu)化器狀態(tài)卸載到內(nèi)存上,通過?Checkpointing?和?算子融合?避免儲(chǔ)存運(yùn)算的中間變量,最后使用?通信計(jì)算重疊?進(jìn)一步降低整套系統(tǒng)時(shí)間花費(fèi)。
綜合使用這些技術(shù),BMTrain?可以實(shí)現(xiàn)?單張消費(fèi)級(jí)顯卡全參數(shù)微調(diào) BERT-Large,8 臺(tái) A100 小集群訓(xùn)練 GPT-3,在超大規(guī)模模型訓(xùn)練場景下與 DeepSpeed 等框架相比最多可節(jié)省?90%?的算力成本。想了解這些技術(shù)具體的細(xì)節(jié)嗎?本文來帶你一探究竟!

01?背景知識(shí)
分布式訓(xùn)練的核心精神是切割,將數(shù)據(jù)、參數(shù)等諸多要素切割到不同計(jì)算節(jié)點(diǎn)上進(jìn)行運(yùn)算。有切割就有合并,不同節(jié)點(diǎn)之間會(huì)頻繁通信以同步及匯總計(jì)算結(jié)果。
這里我們簡單介紹 5 個(gè)基本通信算子,這是分布式訓(xùn)練框架的重要基礎(chǔ)(以四張顯卡為例,由 rank0?到 rank3 表示):
? ?01??Broadcast
張量位于某張顯卡中,廣播后,每張顯卡都會(huì)獲得一個(gè)同樣的張量。

? ?02??Reduce
每張顯卡中存有一個(gè)張量,將這些張量進(jìn)行如求和、取max等計(jì)算后,其結(jié)果被置于指定的某張顯卡上。

? ?03??All Reduce
每張顯卡中存有一個(gè)張量,使用它們進(jìn)行相關(guān)計(jì)算后的結(jié)果被置于所有的顯卡上,各張顯卡上得到的結(jié)果相同。

? ?04 ?Reduce Scatter
每張顯卡中存有一個(gè)大小為 4d 的張量,張量之間進(jìn)行計(jì)算后的結(jié)果被平均切分為 4 份,每份的大小為 d,分別置于 4 張顯卡上。

? ?05??All Gather
每張顯卡中存有一個(gè)大小為 d 的張量,收集后,張量拼接的結(jié)果 (大小為 4d) 被置于所有的顯卡上,各張顯卡上得到的結(jié)果相同。


02 分布式訓(xùn)練
一種典型的分布式訓(xùn)練方法是使用數(shù)據(jù)并行,然而對(duì)于大模型來說,僅通過數(shù)據(jù)并行進(jìn)行顯存優(yōu)化是遠(yuǎn)遠(yuǎn)不夠的,我們需要更進(jìn)一步地進(jìn)行切割。進(jìn)一步優(yōu)化的技術(shù)主要來自兩大技術(shù)路線:在算子層面進(jìn)行切割的?模型并行、流水線并行技術(shù)?以及在顯存上進(jìn)行切割的?ZeRO技術(shù)。在BMTrain中,我們采用了?數(shù)據(jù)并行?和?ZeRO技術(shù)?來進(jìn)行模型的分布式訓(xùn)練,并將陸續(xù)支持模型并行與流水線并行。
? ?數(shù)據(jù)并行
數(shù)據(jù)并行通過減小每張顯卡上需要處理的 batch 大小來減少模型的運(yùn)行中間變量。具體來說,假設(shè)有 ?張顯卡,那么每張顯卡可以只去處理?
?的數(shù)據(jù),最后將各張顯卡計(jì)算得到的梯度進(jìn)行求和 ( all-reduce ) 即可。在這種方式中,每張顯卡都會(huì)獲得完整的梯度信息,最后每一張顯卡上分別執(zhí)行優(yōu)化器的 step。

???模型并行
模型并行技術(shù)嘗試將模型計(jì)算進(jìn)行切割。以全連接層為例,對(duì)于計(jì)算??

通過將參數(shù)矩陣分解為n個(gè)小矩陣

?每張顯卡上計(jì)算?

?然后通過 all-gather 通信即可獲得完整的結(jié)果??。在這種方法中,各張顯卡均處理同一批次的數(shù)據(jù),在計(jì)算時(shí)進(jìn)行合作。

與模型并行類似的一種解決思路是流水線并行,也是嘗試對(duì)訓(xùn)練計(jì)算進(jìn)行切分。相比于模型并行中對(duì) transformer 模型進(jìn)行縱向的計(jì)算切分,流水線并行則將不同層的 transformer block 計(jì)算劃分到不同的顯卡上。
? ?ZeRO
在實(shí)際訓(xùn)練中,優(yōu)化器 ( 如 Adam ) 狀態(tài)占用的顯存要比參數(shù)和梯度二者加起來還要多,因此 ZeRO(Zero Redundancy Optimizer,零冗余優(yōu)化器)技術(shù)首次提出對(duì)優(yōu)化器狀態(tài)進(jìn)行切分,每張顯卡上只負(fù)責(zé)優(yōu)化器狀態(tài)對(duì)應(yīng)的部分參數(shù)的更新。訓(xùn)練策略上,ZeRO 基于數(shù)據(jù)并行,不同的數(shù)據(jù)被劃分到不同的顯卡上進(jìn)行計(jì)算。根據(jù)對(duì)優(yōu)化器狀態(tài)、梯度、參數(shù)劃分程度的不同,ZeRO 技術(shù)包含 ZeRO-1/2/3 三個(gè)層次。
? ?ZeRO-1
因?yàn)?ZeRO 基于數(shù)據(jù)并行,首先需要通過 all-gather 操作獲取完整的模型參數(shù)更新結(jié)果,隨后每張顯卡根據(jù)自己的數(shù)據(jù)和模型參數(shù)完成對(duì)應(yīng)的前向傳播和反向傳播。在整個(gè)過程中,梯度和參數(shù)均完整地保留在每張卡上,隨后對(duì)梯度進(jìn)行 reduce-scatter,每張卡根據(jù)自己所劃分的優(yōu)化器狀態(tài)和梯度來計(jì)算對(duì)應(yīng)部分的模型參數(shù)。

? ?ZeRO-2
ZeRO-2 在 ZeRO-1 的基礎(chǔ)上進(jìn)一步對(duì)梯度進(jìn)行劃分。注意,由于在反傳的過程中,不需要始終保留完整的梯度,在計(jì)算當(dāng)前層梯度時(shí),只需要后一層輸入的梯度。因此在反傳的過程中,對(duì)于不參與后續(xù)反傳計(jì)算的梯度,可以立即 reduce-scatter 劃分到多塊卡上,這樣在訓(xùn)練過程中,梯度在每塊卡上的顯存占用,就變?yōu)樵鹊??了。反傳結(jié)束后,每塊卡再根據(jù)部分的梯度和優(yōu)化器狀態(tài),計(jì)算得到更新后的模型參數(shù),最后再將更新后的參數(shù)使用 all-gather 同步到其他的顯卡上。

? ?ZeRO-3
而 ZeRO-3 技術(shù),則是更進(jìn)一步將模型參數(shù)部分進(jìn)行切分。由于每張顯卡只有一部分的優(yōu)化器狀態(tài),只更新一部分的參數(shù),一個(gè)很直觀的思路就是每張顯卡上只維護(hù)優(yōu)化器需要更新的那一部分參數(shù)。然而,在模型的計(jì)算過程中,還是需要完整的模型參數(shù)。因而在 ZeRO-3 中,模型中的每個(gè)模塊在計(jì)算之前,都需要通過一次 all-gather 操作將參數(shù)恢復(fù)完整,并在前向計(jì)算結(jié)束后再將模型參數(shù)釋放掉。進(jìn)行反傳時(shí),再重新使用 all-gather 獲取參數(shù)計(jì)算梯度并使用 reduce-scatter 劃分梯度,如下圖。

通過使用 ZeRO-3 優(yōu)化,訓(xùn)練相關(guān)的所有信息均被切碎分散到不同的顯卡上,讓每張顯卡上的顯存占用都被降低到極致,使得每張顯卡上可以容下更大的 batch_size,更充分地利用計(jì)算核心,帶來更大的模型吞吐,同時(shí)將訓(xùn)練模型所需的顯卡數(shù)量降至最低。

不過在 ZeRO 的原論文中指出, ZeRO-3 增加了額外的一次參數(shù)通信時(shí)間(即反向傳播時(shí)的 all-gather ),因此會(huì)引入額外的通信開銷,在部分場景下性能不及 ZeRO-2 和模型并行。為了減少額外通信量帶來的效率損失,我們還額外引入了通信計(jì)算重疊的策略,這將在后面被介紹到。根據(jù)我們的實(shí)現(xiàn),實(shí)驗(yàn)結(jié)果表明 ZeRO-3 在 NVLink+IB 的環(huán)境下訓(xùn)練超大規(guī)模模型較聯(lián)合使用 ZeRO-2 和模型并行的方案會(huì)帶來更大的計(jì)算吞吐量提升。

03 顯存優(yōu)化
除了上述分布式訓(xùn)練方法外,BMTrain還通過 Optimizer Offload 和 Checkpointing 技術(shù)進(jìn)一步減少冗余的顯存占用,并以犧牲最少的通信代價(jià)為前提,做到了在極致顯存優(yōu)化下仍然能高效率地訓(xùn)練。
? ?Optimizer Offload
Optimizer Offload 是指將優(yōu)化器狀態(tài)從 GPU 卸載到 CPU 上,從而進(jìn)一步節(jié)省顯存。我們以 Adam 優(yōu)化器為例介紹為什么需要將優(yōu)化器的參數(shù)卸載。
在 Adam 中,優(yōu)化器需要維護(hù)梯度的移動(dòng)平均以及梯度平方的移動(dòng)平均:

正如前文所示,與模型參數(shù)相比, Adam 優(yōu)化器需要至少兩份的顯存占用量,這在混合精度訓(xùn)練中是一筆非常大的開銷。通過使用 ZeRO-3 的梯度切分,每張計(jì)算卡上的需要處理的梯度信息大幅減少,將這一部分 GPU 計(jì)算卸載至 CPU 上產(chǎn)生的通信需求較小,同時(shí) CPU 處理這樣切分后的梯度也不會(huì)特別吃力。據(jù)此,我們付出了極小量的額外開銷就將顯存開銷降低至原本的一半左右。

? ?Checkpointing
Checkpointing 技術(shù)是一項(xiàng)很早就被提出,用于優(yōu)化神經(jīng)網(wǎng)絡(luò)模型訓(xùn)練時(shí)計(jì)算圖開銷的方法。這種方法在 Transformers 等結(jié)構(gòu)的模型訓(xùn)練中,能夠起到非常明顯的作用。目前主流的 Transformers 模型由大量的全連接層組成,我們以全連接層為例進(jìn)行計(jì)算圖的顯存分析。

為了能夠在反向傳播中計(jì)算梯度,需要在正向傳播時(shí)記錄下參數(shù)矩陣??與輸入?
,這兩部分參數(shù)隨著正向傳播逐層累積,消耗了非常多的顯存。
因此,我們使用 Checkpointing 技術(shù)(也稱為亞線性內(nèi)存優(yōu)化),其核心方式是通過時(shí)間換空間,我們在模型各層之間設(shè)置檢查點(diǎn),只記錄每一層模型的輸入向量。在反向傳播時(shí),根據(jù)最近的 checkpoint 重新計(jì)算該層的局部計(jì)算圖。


04 框架實(shí)現(xiàn)的優(yōu)化
除了上述顯存優(yōu)化技術(shù)外,BMTrain 還在具體實(shí)現(xiàn)上進(jìn)行優(yōu)化,以期得到更好的加速效果。
? ?混合精度
傳統(tǒng)模型使用單精度參數(shù)進(jìn)行訓(xùn)練,在大模型訓(xùn)練中,我們可以通過使用半精度參數(shù)來降低參數(shù)量并節(jié)省運(yùn)算時(shí)間。具體實(shí)現(xiàn)上,BMTrain 在正向傳播和反向傳播的過程中均使用半精度進(jìn)行計(jì)算,并在優(yōu)化器中維護(hù)單精度的模型參數(shù)和優(yōu)化器參數(shù)。
使用混合精度的另一個(gè)好處在于能夠更好地利用顯卡中的 tensor core。較新的顯卡在 CUDA core 之外,還設(shè)置了專門用于張量運(yùn)算的核心 tensor core,利用 tensor core 將為程序帶來進(jìn)一步的性能提升。使用混合精度訓(xùn)練能夠更好地利用 tensor core 特性,從而為訓(xùn)練過程進(jìn)一步加速。
? ?算子融合
為了進(jìn)一步提升性能,我們在 CPU 和 GPU 層面均進(jìn)行了算子層面的實(shí)現(xiàn)優(yōu)化。在 CPU 上,我們使用多線程 + SIMD(單指令流多數(shù)據(jù)流) 的 CPU 編程方式,對(duì) Offload 至 CPU 計(jì)算的 Adam 優(yōu)化器進(jìn)行 CPU 上的計(jì)算加速,使其不會(huì)成為系統(tǒng)的性能瓶頸。在 GPU 上,我們使用算子融合的方式,將 Softmax 與 NLLLoss 算子合二為一,減小了中間結(jié)果的顯存占用。
? ?通信計(jì)算重疊
上文中提到,ZeRO3 技術(shù)將引入額外的通信時(shí)間,我們采用通信計(jì)算策略來進(jìn)行通信時(shí)間的優(yōu)化。以反向傳播為例,由于使用了 ZeRO-3 技術(shù),需要將切碎至各個(gè)計(jì)算卡上的模型進(jìn)行臨時(shí)的重組裝(對(duì)應(yīng)圖中的 Gather );而在反向傳播 ( 對(duì)應(yīng)圖中的 Calculate ) 之后,我們還需要將得到的局部梯度重新切碎至不同的計(jì)算卡上(對(duì)應(yīng)圖中的 Scatter )。我們通過不同的 CUDA stream 區(qū)分不同的操作,讓運(yùn)算和通信得以同時(shí)運(yùn)行,通過大量的計(jì)算時(shí)間隱藏通信的時(shí)間開銷。


05?性能展示
綜合使用上述技術(shù),BMTrain 在大模型訓(xùn)練上效果出色,在不同規(guī)模的算力條件下均有較好的性能表現(xiàn)。
? ?在單卡 2080Ti 上,BMTrain 可以實(shí)現(xiàn) transformers 庫無法實(shí)現(xiàn)的?3 億參數(shù) BERT-Large?微調(diào)。
? ?在單卡 V100 上,BMTrain 訓(xùn)練 3 億參數(shù) BERT-Large 較 transformers 實(shí)現(xiàn)能夠提高約 20 倍 batch size,2.5 倍吞吐量。
? ?在單機(jī) 8 卡 A100 環(huán)境下,BMTrain 訓(xùn)練 130 億參數(shù)的 GPT 較 Deepspeed / veGiantModel 實(shí)現(xiàn)能夠提高約 4 倍 batch size,1.6 倍吞吐量。
? ?在多機(jī) 8 卡 A100 環(huán)境下,BMTrain 可以使用較少 GPU 訓(xùn)練 1750 億參數(shù)的 GPT-3,性能詳見下表:

使用 BMTrain,64 張 A100 跑完 GPT-3 的 300B token 大概需要 2 年,服務(wù)器與顯卡租金大約 900 萬人民幣左右。根據(jù)我們的實(shí)驗(yàn)估算,使用 128 張 A100 時(shí),單卡吞吐量可以提升 2.5 倍以上,6 個(gè)月可以跑完 GPT-3,服務(wù)器租金大約 500 萬人民幣左右。雖然訓(xùn)練出 GPT-3 的成本依然高昂,但與 GPT-3 的 1200 萬美元相比,成本仍然?節(jié)約了 90%?以上。

06?未來展望
這篇文章原載于我們的微信公眾號(hào)“OpenBMB開源社區(qū)”,主要介紹 BMTrain 中的基礎(chǔ)加速算法,BMTrain 將持續(xù)關(guān)注大模型的高效訓(xùn)練和性能優(yōu)化,不斷優(yōu)化與升級(jí)。我們誠摯歡迎感興趣的研究人員與開發(fā)者加入我們的開源社區(qū),參與相關(guān)的研究交流、技術(shù)研討與工具開發(fā),共同為大模型的落地與應(yīng)用添磚加瓦!

附錄?參考文獻(xiàn)
1.?Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He.?ZeRO: Memory Optimizations Toward Training Trillion Parameter Models.
2.?Zhengda Bian, Hongxin Liu, Boxiang Wang, et al.?Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training.
3.?Adam Paszke, Sam Gross, Francisco Massa, et al.?PyTorch: An Imperative Style, High-Performance Deep Learning Library.
4.?Zhengyan Zhang, Xu Han, Hao Zhou, et al.?CPM: A Large-scale Generative Chinese Pre-trained Language Model.
5.?Zhengyan Zhang, Yuxian Gu, Xu Han, et al.?CPM-2: Large-scale Cost-efficient Pre-trained Language Models.
6.?Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.?BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.
7.?Colin Raffel, Noam Shazeer, Adam Roberts, et al.?T5: Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer.
8.?Alec Radford, Jeffrey Wu, Rewon Child, et al.?GPT2: Language Models are Unsupervised Multitask Learners.
9.?Ben Wang and Aran Komatsuzaki, et al.?GPT-J from EleutherAI released in the repo mesh-transformer-jax.
10.?Diederik P. Kingma, Jimmy Ba.?Adam: A Method for Stochastic Optimization.
11.?Yang You, Jing Li, Sashank Reddi, et al.?Large Batch Optimization for Deep Learning: Training BERT in 76 minutes.
12.?Hanlin Tang, Shaoduo Gan, Ammar Ahmad Awan, et al.?1-bit Adam: Communication Efficient Large-Scale Training with Adam's Convergence Speed.
13.?NCCL:?https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html

關(guān)注我們
微信搜索關(guān)注 “OpenBMB開源社區(qū)”
加入社群或獲取更多大模型干貨知識(shí)和前沿資訊!?

??傳送門|相關(guān)鏈接
?? ?官方網(wǎng)站:https://www.openbmb.org
?? ?GitHub:https://github.com/OpenBMB
?? ?交流QQ群:735930538
?? ?啟智社區(qū):https://git.openi.org.cn/OpenBMB
?? ?微博:http://weibo.cn/OpenBMB
?? ?知乎:https://www.zhihu.com/people/OpenBMB
?? ?Twitter:https://twitter.com/OpenBMB
