斯坦福開源FlashAttention,大模型速度翻倍
一年時(shí)間,斯坦福大學(xué)提出的新型 Attention 算法 ——FlashAttention 完成了進(jìn)化。這次在算法、并行化和工作分區(qū)等方面都有了顯著改進(jìn),對(duì)大模型的適用性也更強(qiáng)了。
近來,幾種長(zhǎng)上下文語(yǔ)言模型陸續(xù)問世,包括 GPT-4(上下文長(zhǎng)度為 32k)、MosaicML 的 MPT(上下文長(zhǎng)度為 65k)Anthropic 的 Claude(上下文長(zhǎng)度為 100k)。長(zhǎng)文檔查詢和故事寫作等新興用例已經(jīng)表明擴(kuò)展語(yǔ)言模型上下文窗口是非常必要的。
然而,擴(kuò)大 Transformer 的上下文長(zhǎng)度是一個(gè)挑戰(zhàn),因?yàn)槠浜诵牡淖⒁饬釉跁r(shí)間復(fù)雜度和空間復(fù)雜度與輸入序列長(zhǎng)度的平方成正比。
一年前,來自斯坦福大學(xué)、紐約州立大學(xué)布法羅分校的研究者共同提出一種快速、內(nèi)存高效的注意力算法 ——FlashAttention。該算法無需任何近似即可加速注意力并減少內(nèi)存占用?,F(xiàn)在,已經(jīng)有許多機(jī)構(gòu)和研究實(shí)驗(yàn)室采用 FlashAttention 來加速訓(xùn)練和推理。
FlashAttention 示意圖
盡管 FlashAttention 的速度已經(jīng)是優(yōu)化基線的 2-4 倍,但它仍然有相當(dāng)大的改進(jìn)空間。FlashAttention 仍然不如優(yōu)化過的矩陣乘法 (GEMM) 運(yùn)算快,僅達(dá)到理論最大 FLOPs/s 的 25-40%。
現(xiàn)在,研究團(tuán)隊(duì)宣布推出 FlashAttention-2。FlashAttention-2 完全從頭開始重寫,使用 Nvidia 的 CUTLASS 3.x 及其核心庫(kù) CuTe 的原語(yǔ)(primitive)。
FlashAttention-2 開發(fā)者 Tri Dao。他是斯坦福大學(xué)博士生,還是 Together.AI 首席科學(xué)家,并將于 2024 年 9 月開始任職普林斯頓大學(xué)計(jì)算機(jī)科學(xué)助理教授。
FlashAttention-2 的速度是 FlashAttention 的 2 倍,在 A100 GPU 上達(dá)到 230 TFLOPs/s。在端到端訓(xùn)練 GPT 類語(yǔ)言模型時(shí),F(xiàn)lashAttention-2 可讓訓(xùn)練速度高達(dá) 225 TFLOPs/s(模型 FLOP 利用率為 72%)。
FlashAttention-2 將加速現(xiàn)有模型的訓(xùn)練、微調(diào)和推理。這意味著我們可以用相同成本訓(xùn)練 2 倍上下文長(zhǎng)度的語(yǔ)言模型。這將有助于語(yǔ)言模型理解長(zhǎng)篇書籍和報(bào)告、高分辨率圖像、音頻和視頻。
FlashAttention 是什么?
FlashAttention 是一種重新排序注意力計(jì)算的算法,它利用平鋪、重計(jì)算等經(jīng)典技術(shù)來顯著提升計(jì)算速度,并將序列長(zhǎng)度中的內(nèi)存使用實(shí)現(xiàn)從二次到線性減少。其中平鋪意味著將輸入塊從 HBM(GPU 內(nèi)存)加載到 SRAM(快速緩存),并對(duì)該塊執(zhí)行注意力操作,更新 HBM 中的輸出。
此外通過不將大型中間注意力矩陣寫入 HBM,內(nèi)存讀寫量減少,帶來了 2-4 倍的時(shí)鐘時(shí)間加速。
下圖為 FlashAttention 的前向傳遞圖:通過平鋪和 softmax 重新縮放,研究者按塊進(jìn)行操作,避免從 HBM 中讀取 / 寫入,同時(shí)獲得正確的輸出,無需近似操作。
然而,F(xiàn)lashAttention 仍然存在一些低效率問題,原因在于不同線程塊之間的工作分區(qū)不理想以及 GPU 上的 warp。這些導(dǎo)致低占用率或不必要的共享內(nèi)存讀寫。
FlashAttention-2
更好的算法、并行化和工作分區(qū)
更少的非矩陣乘法 Flops
研究者調(diào)整了 FlashAttention 的算法,從而減少了非矩陣乘法(non-matmul)的 Flops 數(shù)量。這點(diǎn)很重要,因?yàn)楝F(xiàn)代 GPU 具有專門的計(jì)算單元(例如 Nvidia GPU 上的張量核心),使得矩陣乘法速度更快。
舉例而言,A100 GPU 的 FP16/BF16 矩陣乘法的最大理論吞吐量為 312 TFLOPs/s,但非矩陣乘法 FP32 的理論吞吐量?jī)H為 19.5 TFLOPs/s。
換一種思考方式,每個(gè)非矩陣乘法 FLOP 比矩陣乘法 FLOP 的代價(jià)高 16 倍。為了保持高吞吐量,研究者希望在矩陣乘法 FLOP 上花費(fèi)盡可能多的時(shí)間。因此他們重寫了 FlashAttention 中使用的在線 softmax 技巧,以減少重新縮放操作、邊界檢查和因果掩碼操作的數(shù)量,而無需更改輸出。
更好的并行化
FlashAttention v1 在批大小和頭(head)數(shù)量上進(jìn)行并行化。研究者使用 1 個(gè)線程塊來處理一個(gè)注意力頭,總共有(批大小 * 頭數(shù)量)個(gè)線程塊。每個(gè)線程塊都計(jì)劃在流式多處理器(SM)上運(yùn)行,例如 A100 GPU 上有 108 個(gè)這樣的 SM。當(dāng)這個(gè)數(shù)字非常大(如 >= 80)時(shí),這種調(diào)度是有效的,這時(shí)可以高效地使用 GPU 上幾乎所有計(jì)算資源。
在長(zhǎng)序列的情況下(通常意味著小批量或少量頭),為了更好地利用 GPU 上的多處理器,現(xiàn)在研究者在序列長(zhǎng)度維數(shù)上額外地進(jìn)行并行化,使該機(jī)制顯著加速。
更好的工作分區(qū)
即使在每個(gè)線程塊內(nèi),研究者也必須決定如何在不同的 warp 之間劃分工作(一組 32 個(gè)線程一起工作)。通常情況下,每個(gè)線程塊使用 4 或 8 個(gè) warp,分區(qū)方案如下圖所述。
研究者改進(jìn)了 FlashAttention-2 中的這種分區(qū),減少不同 warp 之間的同步和通信量,進(jìn)而減少共享內(nèi)存讀寫。
對(duì)于每個(gè)塊,F(xiàn)lashAttention 將 K 和 V 分割到 4 個(gè) warp 上,同時(shí)保持 Q 可被所有 warp 訪問。這被稱為「sliced-K」方案。不過,這種方案是低效的,原因在于所有 warp 都需要將它們的中間結(jié)果寫入共享內(nèi)存,并同步,然后將中間結(jié)果相加。這些共享內(nèi)存讀寫會(huì)減慢 FlashAttention 中的前向傳遞速度。
在 FlashAttention-2 中,研究者將 Q 分割在 4 個(gè) warp 上,同時(shí)保持 K 和 V 可被所有的 warp 訪問。每個(gè) warp 執(zhí)行矩陣乘法以獲得 Q K^T 的切片,然后只需與 V 的共享切片相乘就能獲得相應(yīng)的輸出切片。warp 之間不需要通信。共享內(nèi)存讀寫的減少也可以提升速度。
新特性:頭維數(shù)高達(dá) 256、多查詢注意力
我們知道,F(xiàn)lashAttention 僅支持最高 128 的頭維數(shù),這適用于大多數(shù)模型,但有一些模型被遺漏了。
因此,F(xiàn)lashAttention-2 支持了高達(dá) 256 的頭維數(shù),這意味著 GPT-J、CodeGen 和 CodeGen2、StableDiffusion 1.x 等模型可以使用 FlashAttention-2 來獲得加速和節(jié)省內(nèi)存。
此外,F(xiàn)lashAttention-2 還支持了多查詢注意力(multi-query attention, MQA)以及分組查詢注意力(grouped-query attention, GQA)。它們是注意力的變體,其中多個(gè)查詢頭關(guān)注相同的鍵和值頭,以減少推理過程中 KV 緩存的大小,并可以顯著提高推理吞吐量。
注意力基準(zhǔn)結(jié)果
研究者在 A100 80GB SXM4 GPU 上,測(cè)量不同設(shè)置(無 / 有因果掩碼、頭維數(shù) 64 或 128)下不同注意力方法的運(yùn)行時(shí)。
結(jié)果發(fā)現(xiàn), FlashAttention-2 的速度是 FlashAttention(以及 xformers 庫(kù)和 Triton 中的其他實(shí)現(xiàn))的 2 倍。與 PyTorch 中的標(biāo)準(zhǔn)注意力實(shí)現(xiàn)相比,F(xiàn)lashAttention-2 的速度最高是它們的 9 倍。
A100 GPU 上的注意力前向 + 后向速度。
此外只需要在 H100 GPU 上 運(yùn)行相同的實(shí)現(xiàn)(不使用特殊指令來利用 TMA 和第四代 Tensor Core 等新硬件功能),研究者最高獲得了 335 TFLOPs/s。
H100 GPU 上的注意力前向 + 后向速度。
當(dāng)用于端到端 GPT 類模型訓(xùn)練時(shí),F(xiàn)lashAttention-2 有助于在 A100 GPU 上實(shí)現(xiàn)最高 225 TFLOPs/s(模型 FLOPs 利用率為 72%)。與優(yōu)化良好的 FlashAttention 模型相比,端到端實(shí)現(xiàn) 1.3 倍加速。
這里的基線是不使用 FlashAttention 的 Megatron-LM,它現(xiàn)在也可以選擇使用 FlashAttention 了。不久的將來,F(xiàn)lashAttention-2 也將集成到 Megatron-LM 中。
研究團(tuán)隊(duì)表示:下一步將針對(duì) H100 GPU 優(yōu)化 FlashAttention-2,以使用新的硬件功能。