Confident Adaptive(上)

本文首發(fā)于網(wǎng)站 機器翻譯學(xué)堂
轉(zhuǎn)載事宜請后臺詢問哦
作者 | 蒙龍
單位 | 東北大學(xué)自然語言處理實驗室

論文概況
大模型在許多任務(wù)中獲得了顯著的性能提升,這些收益往往伴隨著模型規(guī)模的急劇增加,導(dǎo)致模型在推理時的緩慢和昂貴。條件計算是一種動態(tài)模型推理加速方法,根據(jù)輸入的難易程度不同來分配不同的計算量。這次的分享通過兩篇論文來為大家分享條件計算中的Confident Adaptive方法。這兩篇論文分別是Consistent Accelerated via Confident Adaptive Transformers[1]以及Confident Adaptive Language Modeling[2]。這兩者一脈相承,都用了Confident Adaptive的方法,后者在前者的基礎(chǔ)上,對問題進(jìn)行進(jìn)一步的泛化和討論。第一篇論文收錄在2021的EMNLP,第二篇是在2022年的NIPS上。


條件計算
在正式介紹這兩篇文章之前,筆者先給大家介紹一下什么是條件計算以及條件計算中的一個基本范式是什么。條件計算是一種動態(tài)的模型推理加速方法,它的一個基本的假設(shè)是輸入的難易程度不同,它所需要的計算量是不一樣的,這樣我們就可以根據(jù)難易程度的不同來動態(tài)分配不同的計算量。條件計算中比較常用的兩個方式是早退以及MoE,它們分別是在模型的深度和寬度上來調(diào)整計算量。
自適應(yīng)早期退出
對于Transformers這種多層架構(gòu),一種流行的方法是自適應(yīng)早期退出,也就是早退。舉一個簡單的例子來說,給定一個比較深的三十二層的一個網(wǎng)絡(luò),我們可以在中間的某一層例如第20層退出,這樣就不需要完整地計算三十二層這么深的一個網(wǎng)絡(luò)。但是它和簡單地把模型砍到為20層不同,它依然保留了在深層輸出的一個能力,讓模型自己去為不同的輸入選擇最適合它的那一層進(jìn)行輸出。因此,從這里我們可以看出來,早退里面有兩個比較基本的問題。一個是模型如何再中間層輸出結(jié)果。
第一個問題相對比較直接,目前比較主流的一種方式就是隨時結(jié)構(gòu)化預(yù)測,也就是我們在每一個中間層的后面都接上一個輸出的分類器,這個分類器和最后一層的分類器一樣,輸入中間層的隱藏狀態(tài)表示,輸出一個分類類別維度的分布。然后使用對齊訓(xùn)練的方式同時優(yōu)化這幾個分類器。
第二個問題是給定一個輸入,我們應(yīng)該在第幾層退出,怎么去找到這一個最適合它的層。對于第二個問題相對比較開放,百家爭鳴。這個問題目前沒有一個說百分百ground truth的標(biāo)簽,不同的文章提出不同的假設(shè),也就是不同的oracle,再根據(jù)不同的oracle構(gòu)建出不同的偽標(biāo)簽。例如有的人認(rèn)為我們應(yīng)該在分?jǐn)?shù)最高的那一層退出,有的人認(rèn)為我們應(yīng)該正確token數(shù)最多的那一層退出,有的人認(rèn)為應(yīng)該用互信息表示來衡量,有的人認(rèn)為用語言模型的重建損失進(jìn)行衡量等等。
Confident Adaptive
我們今天要介紹的Confident Adaptive的方式也是這樣的模式,文章的作者認(rèn)為,我們應(yīng)該在哪里退出呢,當(dāng)中間某一層的結(jié)果和最后一層結(jié)果一致的時候,我們就可以進(jìn)行早退了。我們?nèi)绾稳ダ斫馑@個中間層的結(jié)果和最后一層一致呢。

我們不妨來看這個圖,這是一個Vitamin C的數(shù)據(jù)集,它給一個Claim和Evidence,然后模型需要判斷這個Evidence有沒有支持這個Claim,他一共是有三個標(biāo)簽分別是Support,Refuse和Not Enough Info,橫坐標(biāo)是模型的層數(shù),對于圖片中例子2,我們可以看到最后一層的輸出結(jié)果是Refuse,所謂的一致層就是和最后一層結(jié)果一樣的層,例如這里的第10層和第17層,其余的就是不一致層。當(dāng)我們在進(jìn)行早退時,只需要找到最早的那一個一致層就可以了。它這種做法其實還是比較直觀比較好理解的。然后就是我們要用什么辦法去找到這一個層。
模型一致性
給定一個固定的、深層的原始模型, 我們創(chuàng)建了一個可以早退的模型
,
里面包括早退的中間分類器
。然后, 我們以任意高的概率 (如 95%的樣本) 保證
與原始模型
一致。
怎么去理解這一個公式呢, 簡單來說, 給定個樣本,?
?如果誤差頻率
不超過
?, 那么我們就認(rèn)為這個模型
是
的。通過這樣的設(shè)計, 確保了
至少保留了
的
原始性能, 就可以保證模型的性能的一個穩(wěn)定性。在這些約束條件下, 剩下的問題是如何使
相對高效。例如, 一個肯定一致的, 但沒有實際加速的做法, 就是讓
恒等
。
這里有一個比較重要的點需要注意一下,目前在早退中比較重要的一個問題是模型的效果不穩(wěn)定。筆者現(xiàn)在做的一些實驗里面也會有這種問題,簡單隨意地決定什么時候進(jìn)行早退, 可能會導(dǎo)致模型精度的不可預(yù)測的下降。因此如何去量化模型預(yù)測中的這種不穩(wěn)定, 這對于在不過度犧牲性能的情況下, 同時能夠加快預(yù)測是至關(guān)重要的。
CATs 模型結(jié)構(gòu)
我們首先來看 Confident Adaptive Transformers (CATs) 模型結(jié)構(gòu)的一個形式化表示, 具體來說, 給定一個模型, 在預(yù)測
之前,?
將輸入
映射到一系列的特征表示,?
在這里就是一個
層的 Transformer。CATs 做的是分類和回歸任務(wù)。一個基本的模式就是, 對于下游任務(wù), 我們假設(shè)輸入中包含一個[CLS]token, 專門表示用于預(yù)測。產(chǎn)生一系列[CLS]token 的隱藏狀態(tài)表示, 每一個對應(yīng)一層的隱藏層表示
在每一層的后面我們接上一個分類器,對于分類任務(wù)我們使用的分類器如下,
最后一層的分類器和原始模型
的最后一層分類器保持一致, 額外的產(chǎn)生的參數(shù)一共是
, 在原來的訓(xùn)練數(shù)據(jù)上可以比較快速的微調(diào)。
為了找到一個高效的, 我們需要一個可靠的信號來告訴模型當(dāng)前的預(yù)測是否有已經(jīng)是和最后一層的預(yù)測一致
。這里和之前的很多工作一樣, 使用了一個額外的比較小的一個專用分類器
?。
然后我們在另一個無標(biāo)簽的數(shù)據(jù)集上來訓(xùn)練這個,當(dāng)前的 “早期” 的隱藏狀態(tài)以及其他幾個已處理過的特征作為輸入,
用交叉熵來訓(xùn)練,目標(biāo)函數(shù)是當(dāng)前層輸出和原始模型輸出一致的示性函數(shù)
有了中間分類器和給出早退信號的
這兩個零件之后, 我們就可以將
完整的表示出來
其中,?是置信度閾值。關(guān)鍵的挑戰(zhàn)是如何校準(zhǔn)
, 使
保證是???-consistent 的。
校準(zhǔn)預(yù)熱
一個比較簡單的校準(zhǔn)的做法是在校驗集上優(yōu)化,但是需要滿足如下的經(jīng)驗一致性約束,

其中 exit(.) 指的是模型在第幾層退出,指的是在校驗集上的算術(shù)平均, 但是這種校準(zhǔn)的方法效率較低。因此文章使用了一種叫 Conformal Prediction 保形預(yù)測的方法用來校準(zhǔn)
。
保形預(yù)測
保形預(yù)測是由Vovk,Gammerman,Shafer(2005)[3]提出的。并且它統(tǒng)計的理論由Lei, Robins and Wasserman (2013), Lei and Wasserman (2014), Lei, G’Sell, Rinaldo, Tibshirani and Wasserman (2017), Sadinle, Lei and Wasserman (2018)等人不斷發(fā)展。

Conformal Prediction(CP)將區(qū)間估計的思想用在預(yù)測問題上。在進(jìn)行點估計時,我們給位置參數(shù)只給出一個點的估計值,而區(qū)間估計是給出一段區(qū)間,這時我們就有更大的把握讓未知參數(shù)落在這個區(qū)間里面。對預(yù)測也有同樣的概念,相比于只給一個點的預(yù)測,我們可以給出一個預(yù)測的集合。
CP?的一個基本的模式是, 給定個數(shù)據(jù)輸入和標(biāo)簽的數(shù)據(jù)對,?
,CP?根據(jù)這?
個數(shù)據(jù)構(gòu)造一個集值函數(shù)?
?, 這個集值函數(shù)
需要滿足, 再來一個
時,?
落在我們估計的區(qū)間 (也就是
的輸出) 的概率要大于
它具體是怎么使用的呢, 大家不要忘了, 我們校準(zhǔn)的目的是為了找個一個高效的, 也就說我們需要給定一個輸入后, 我們要找到最早的那一個一致層。
我們假設(shè)集合是與原始模型最后一層預(yù)測不一致的層的索引。為了保證???-consistent, 我們應(yīng)該盡量避免在這些層退出,
同樣, 假設(shè)我們現(xiàn)在從訓(xùn)練數(shù)據(jù)里面拿了個樣本?
出來, 我們現(xiàn)在把這?
?個樣本輸入到模型
中, 我們就可以得到這些樣本各自的一個?
, 如就是相當(dāng)于?
?, 我們 將?
與保形程序配對, 通校準(zhǔn)的閾值?
, 得到了?
的保形預(yù)測,?
?, 使得
現(xiàn)在先不看為什么保形預(yù)測會是這種形式, 然后我們對??取補集?
?, 因為?
?是不一致層 的集合, 我們?nèi)⊙a集之后就得等到了一致層的集合, 然后我們?nèi)≌覀€補集中最小的值就作為?
?選擇退出的層, 就可以保證模型
?是???-consistent。
我們現(xiàn)在回過頭來看為什么不一致層的保形預(yù)測??會是這樣的一種形式。
參考文獻(xiàn):
[1] Schuster T, Fisch A, Jaakkola T, et al. Consistent accelerated inference via confident adaptive transformers[J]. arXiv preprint arXiv:2104.08803, 2021.
[2] Schuster T, Fisch A, Gupta J, et al. Confident adaptive language modeling[J]. Advances in Neural Information Processing Systems, 2022, 35: 17456-17472.
[3] Vovk V, Gammerman A, Shafer G. Algorithmic learning in a random world[M]. New York: Springer, 2005.
[4] Angelopoulos A N, Bates S, Candès E J, et al. Learn then test: Calibrating predictive algorithms to achieve risk control[J]. arXiv preprint arXiv:2110.01052, 2021.

hi,這里是小牛翻譯~
想要看到更多我們的文章,可以關(guān)注下
機器翻譯學(xué)堂(公號或網(wǎng)站)
筆芯~

往期精彩文章


