Confident Adaptive(下)

本文首發(fā)于網(wǎng)站?機(jī)器翻譯學(xué)堂
轉(zhuǎn)載事宜請后臺詢問哦
作者 | 蒙龍
單位 | 東北大學(xué)自然語言處理實(shí)驗(yàn)室

保形校準(zhǔn)
保形預(yù)測是基于假設(shè)檢驗(yàn),我們首先來復(fù)習(xí)一下假設(shè)檢驗(yàn)的一些基礎(chǔ)概念,假設(shè)檢驗(yàn)他是一個統(tǒng)計(jì)決策的過程,對于一個統(tǒng)計(jì)模型,我們提出一個假設(shè),根據(jù)抽取到的樣本,來作出是接受還是拒絕這個假設(shè)。它的一個基本的流程是提出一個統(tǒng)計(jì)假設(shè);選取一個合適的檢驗(yàn)統(tǒng)計(jì)量;利用零假設(shè)成立時檢驗(yàn)統(tǒng)計(jì)量的分布構(gòu)造出一個小概率事件;代入樣本觀察值,如果使得這個小概率事件發(fā)生,就否定零假設(shè)而去接受對立假設(shè),否則說明樣本沒有提供否定零假設(shè)的顯著性證據(jù),因此應(yīng)該接受零假設(shè)。
而在這里,文章認(rèn)為零假設(shè)是第層是不一致的,自然對立假設(shè)是第
層是一致的,
這里的我們選擇??來作為我們的檢驗(yàn)統(tǒng)計(jì)量, 因?yàn)?
?是用來訓(xùn)練預(yù)測?
?的,?
?越大,在一定程度上代表該層是一致的概率越大。因此它的拒絕域就應(yīng)該是?
?偏大,?
?, 功效函數(shù)是否定零假設(shè)的概率, 要求我們犯第一類錯誤的概率不超過?
?,
所謂的犯第一類錯誤的概率指的拒真的概率,也就是當(dāng)零假設(shè)成立時, 我們拒絕零假設(shè)的概率, 這里就是第??層實(shí)際是不一致層, 但是我們認(rèn)為它一致。
一般的檢驗(yàn)統(tǒng)計(jì)量我們要求它是當(dāng)零假設(shè)成立的時候, 它的分布是完全已知的, 但是現(xiàn)在??的分布是末知的。那怎么辦呢? 我們可以用經(jīng)驗(yàn)分布來近似地代替。可以知道經(jīng)驗(yàn)分布函數(shù)依概率,甚至幾平處處收斂到真實(shí)的分布函數(shù)。
經(jīng)驗(yàn)分布函數(shù)實(shí)際上是用樣本觀測值小于??的頻率去估計(jì)概率。我們現(xiàn)在需要的是零假設(shè)成立的時候,?
?的分布, 所以樣本的觀測值應(yīng)該是不一致時?
?的值, 我們把校準(zhǔn)集中的數(shù)據(jù)輸入到模型里面, 直接去取就可以了。
這樣我們就可以通過分位點(diǎn)把??求出來, 也就是我們需要校準(zhǔn)的?
到此??的保形預(yù)測?
?我們就構(gòu)建出來了,
最后還有一個小問題就是??的取值, 因?yàn)槲覀円还灿?
?個層, 我們需要做?
?次獨(dú)立的假設(shè)檢驗(yàn),?
?次假設(shè)檢驗(yàn)中, 假設(shè)?
?和?
?, 假設(shè)檢驗(yàn)之間相互獨(dú)立, 不犯錯誤的概率為?(1?0.01)100=36.6% , 而至少犯一次錯誤的概率高達(dá)P=1?(1?0.01)100=1?0.366=63.4?%, 因此為了讓最終犯錯誤的概率不超過???, 我們需要對?
?進(jìn)行校正。這里我們使用的是 Bonferroni Correction, 也就是讓所有的?
?加起來等于??
最后我們整體來看一下它的一個流程。首先我們需要給定一個原始模型??, 構(gòu)建能夠早退的模型?
?; 訓(xùn)練中間分類器?
?以及給退出信號的?
?; 通過校準(zhǔn)得到?
?, 最后就可以用模型?
?來做推理了, 當(dāng)?
?的值大于?
?時進(jìn)行早退。
CATs實(shí)驗(yàn)
CATs在三個分類任務(wù)以及一個回歸任務(wù)上進(jìn)行了實(shí)驗(yàn),分別是情感分析任務(wù)(IMDB)、事實(shí)驗(yàn)證任務(wù)(VitaminC)、新聞主題分類任務(wù)(AG)以及語義文本相似度回歸(STS-B)。Baseline模型使用的是24層的Albert-xlarge。分類任務(wù)的評價指標(biāo)使用的是正確率,回歸任務(wù)的使用的是Pearson-correlation。分類實(shí)驗(yàn)的結(jié)果如下所示

其中Static為固定在模型的某一層退出。SM指的是使用Softmax的值作為退出信號,Meta指的是使用中間分類器。Thres.指的是非CP的方法,對于SM使用的Temperature Scaling進(jìn)行校準(zhǔn),對于Meta則是當(dāng)它超過1-??時退出。與Indep.和Shared使用了CP的方法進(jìn)行對照,分別是獨(dú)立校準(zhǔn)和共享校準(zhǔn)。括號中的值是保證的理論性能的下界,也就是原始模型性能的(1-?)。
雖然在某些水平下非CP的方法也表現(xiàn)出很強(qiáng)的競爭力量,但是它缺乏正式的性能保證。對于比較嚴(yán)格的一致性水平下,Shared的方法在設(shè)定的范圍內(nèi)獲得了顯著的性能提升。Meta的方法比SM的方法在分類標(biāo)簽比較多的時候性能更好,因?yàn)榇藭rSM的和一致性的相關(guān)程度會下降?;貧w任務(wù)的結(jié)果如下,

在這里, Meta 的一個優(yōu)勢是它可以適用于多種輸出類型的任務(wù)。需要注意的是, 不管??的輸出是什么形式,?
?的事件空間始終存在。也就是說這種方法可以 很好地適應(yīng)分類之外的任務(wù), 例如 CALM。
CALM需要回答的問題
CALM在CATs的基礎(chǔ)上對問題做了進(jìn)一步的擴(kuò)展和討論,在原本的CATs中,文章主要集中在分類以及回歸任務(wù)上,CALM拓展到了翻譯,新聞總結(jié)和QA等文本生成任務(wù)上。我們可以試想一下,如果直接把CATs的校準(zhǔn)的流程直接用在生成任務(wù)上會有什么問題。我認(rèn)為主要有兩個,一個Auto-Regressive的生成方式,另外一個是如何判斷一致。
當(dāng)我們給定一個輸入(可能是一個待翻譯的句子,或者是一個需要總結(jié)的一段話),此時輸出的token是一個一個產(chǎn)生的,后一個token的生成依賴于前面的token,它和分類任務(wù)中那種主要看[CLS]token并行的結(jié)構(gòu)(Encoder)不同。當(dāng)前面一個token提前退出之后,它后面層的隱藏狀態(tài)沒有了,那么后面的token在計(jì)算attention的時候如果直接進(jìn)行隱藏層拷貝的時候會不會有問題?當(dāng)前面的token由于提前退出而出現(xiàn)擾動的時候,會不會影響整個句子的生成。
我們應(yīng)該怎么去判斷當(dāng)前的輸出和最后的一層輸出是否一致?如果我們直接還是用之前那種比較尖銳的是否相等的這種方法是否恰當(dāng)?當(dāng)我們是分類任務(wù)的時候,標(biāo)簽的數(shù)目比較少的,但是生成任務(wù)詞表的維度要比這個大得多。并且,我們應(yīng)該是讓早退輸出的每一個token都和完整模型的輸出對應(yīng)相等(local consistent)是不是太嚴(yán)格了,我們?nèi)绾巫寖蓚€模型的輸出保持一個全局上的一致性。
CALM模型結(jié)構(gòu)
CALM 的模型結(jié)構(gòu)和 CATs 基本一樣, 當(dāng)模型的信心足夠的時候, 可以用模型中間層的隱藏狀態(tài)??進(jìn)行輸出?
,
?代表在模型當(dāng)前的局部的信心的分?jǐn)?shù), 分?jǐn)?shù)的值越高, 越傾向提前退出。
?代表某個局部提前退出的閾值。
隱層拷貝
對于隱層拷貝的問題,文章做了實(shí)驗(yàn)來進(jìn)行探索。通過控制輸出的token一樣來檢驗(yàn)隱藏拷貝的問題,像是控制變量的過程。文章用了這樣的一個oracle,當(dāng)中間層預(yù)測的結(jié)果和最后一層一樣的時候退出,這樣,唯一可能導(dǎo)致生成差異的因素是skipped layers隱層拷貝問題。完整的模型的ROUGE-L的分?jǐn)?shù)是38.32,而這個oracle的分?jǐn)?shù)是38.24,平均退出層數(shù)只要有1.53,完整的模型是8層。文章還嘗試了一個一直用第一層的隱藏層的oracle,達(dá)到了38.31的分?jǐn)?shù)。
各位讀者可以看到兩個oracle的分?jǐn)?shù)都很高。筆者的理解是這里的實(shí)驗(yàn)過程和實(shí)際的推理過程不同,這里的目的是為了探索實(shí)驗(yàn),因此我們可以不計(jì)代價地先將最后一層的token算出來之后,再從前面的層,去選和最后一層一樣的層。
這個實(shí)驗(yàn)說明了兩個事情,一個是模型對于隱層拷貝的魯棒性很高,另外一個是模型的早退潛力很大。文章還做了一個實(shí)驗(yàn)是,如果在計(jì)算attention時用skipped layer的K和V,那么模型的分?jǐn)?shù)會下降很大。
局部錯誤敏感性
文章檢驗(yàn)了局部單個token的擾動對整個句子生成的影響。文章實(shí)驗(yàn)了兩種擾動的方法,一個Sample-Based,在解碼時間步t時,在采樣時我們用分?jǐn)?shù)排在第10的token。另外一個是Layer-Based,在解碼時間步t時使用第一層的預(yù)測的token。結(jié)果如下圖所示,越早的擾動會讓整個輸出的句子分?jǐn)?shù)更低,因?yàn)橛懈嗪罄m(xù)的token受到了影響,并且Layer-Based的影響要比Sample-Based的影響更小,因?yàn)樵趯?shí)際解碼時,很多早退的預(yù)測是準(zhǔn)確的。

根據(jù)上述觀察,文章提出了一個隨著時間步衰減的閾值,隨著解碼的不斷進(jìn)行,它會更加容易早退。閾值建模成一個具有用戶定義溫度系數(shù)的一個函數(shù),函數(shù)圖像如下

本質(zhì)上,這個函數(shù)提供了一個折中方案,即我們不想簡單地對所有的解碼時間步使用同樣的閾值, 也不需要在每個時間步這么巨大空間上搜索。
一致性判斷
對于怎么去判斷是否是一致的問題, 文章提出了兩種方法, Textually Consistent 和 Risk Consistent。給定一個無標(biāo)簽的校驗(yàn)集?,
和
分別是完整的模型的輸出和早退的輸出(相當(dāng)與 CATs 中的原始模型?
?以及早退的模型?
?)。
Textually Consistent 通過一個文本距離函數(shù)??來衡量, 當(dāng)?
和
距離小于用戶設(shè)定的一個容忍度?
?的時候我們認(rèn)為它是一致的,
然而, 對于某些任務(wù), 強(qiáng)制要求與
文本一致可能是不必要的, 特別是在可以接受多個的 Reference 的情況下。在這種情況下,我們可以在校驗(yàn)集
中加入 Target Reference, 這個 Reference 可以是不止一個,?
。 通過一個風(fēng)險函數(shù)
來計(jì)算輸出與 Reference 的分?jǐn)?shù), 當(dāng)兩者的分?jǐn)?shù)小于容忍度
時, 我們論為它是一致的
在實(shí)際操作中這個和
可以用 F 1 或者 ROUGE 等指標(biāo)。
Confident度量
在怎么給出模型應(yīng)該早退的這個信號上, CATs 是額外訓(xùn)練了一個分類器, 而 CALM里面還用了三種不同的度量方法。
第一個是 Softmax Response, 取前兩個 token 的之間的差值。在詞表較大的時候, 會導(dǎo)致比較大的運(yùn)算, 不過, 下一層可以并行啟動計(jì)算, 從而避免額外的運(yùn)行時間;
第二個是 Hidden-State Saturation 連續(xù)兩層隱藏狀態(tài)的余弦距離, 它是一種簡單的無參數(shù)和快速的方法。根據(jù)定義, 最早的可以退出的層是第二層 (除非), 這個度量試圖找到隱藏狀態(tài)的飽和狀態(tài);
第三個是 Early Exit Classifier, 和 CATs 一樣, 訓(xùn)練一個專門分類器來預(yù)測在當(dāng)前隱藏狀態(tài)下以局部一致性退出的可能性。
筆者認(rèn)為, 這里我們需要的 Confident 度量方法, 其實(shí)是找到一個信號, 這個信號可以比較好地捕捉到當(dāng)前層和最后一層一致的信息。因此我們并不一定需要是去訓(xùn)練一個分類器, 這種早退的信息也可能會隱藏在其他的信號里面。并且其實(shí)分類器的隱含的一個前提假設(shè)是這個分類器給出的置信度很大程度上代表了概率, 但是實(shí)際上分類器的預(yù)測的值是高度不可靠的, 雖然它隱含了一些信息, 但是它和真實(shí)分布之間可能還是有很大的差距。
CALM 校準(zhǔn)
CALM 的校準(zhǔn)過程和 CATs 類似, 同樣是把調(diào)參問題建模成一個多重假設(shè)檢驗(yàn)問題,使 用了一種叫 Learn then Test (LTT)?[4]?的方法, 并且 CALM 在不同層之間使用同樣的一個閾值。 具體的做法如下, 首先我們構(gòu)造閾值的一個集合?可能滿足約束也可能不滿足, 它可以盡量包含
可能的范圍, 例如它是一個 0-1 之間的等差數(shù)列。LTT 可以找到一個子集
, 滿足
對里面所有的值進(jìn)行如下的一個假設(shè)檢驗(yàn),零假設(shè)是模型是不一致的, 當(dāng)拒絕零假設(shè) 時, 我們就相當(dāng)于是找到了我們需要的一個閾值, 就可以把它放入到
里面
and
are not constant.
通過 Holding's inequality 我們可以直接把這個假設(shè)檢驗(yàn)?pp?值求出來,
其中是隨機(jī)變量
在檢驗(yàn)集上的算術(shù)平均。對于不同的一致性 判斷方法,?
的計(jì)算方式不同
固定序列測試
雖然早退LLM的性能與之間的確切依賴關(guān)系是未知的,但在實(shí)踐中,我們發(fā)現(xiàn)它往往是相當(dāng)平滑的和大致單調(diào)的。也就是說,相近的閾值傾向于執(zhí)行類似的性能,而更大的閾值傾向于更一致。
這樣我們就不需要去檢驗(yàn)中的每一個值,我們可以將這個數(shù)列從大到小排列,然后按順序去做檢驗(yàn),當(dāng)p值小于?時我們拒絕零假設(shè)。直到我們第一次接受零假設(shè)我們就找到了最小的一個滿足約束的閾值了。
CALM實(shí)驗(yàn)設(shè)置
CALM在三個文本生成任務(wù)上對模型進(jìn)行了評估,一個是新聞總結(jié)任務(wù)(CNN/DM)、英法翻譯(WMT15 EN-FR)以及QA(Open-book SQUAD 1.1)。Baseline使用的是8層的T5 1.1,用T5X框架來實(shí)現(xiàn),不同層之間共享同一個分類器參數(shù)和輸出Embedding。評價指標(biāo)在新聞總結(jié)任務(wù)上用的是ROUGE,在WMT上用BLUE,在QA上用的F1。除了BLEU之外,都使用相同的度量來計(jì)算文本距離和風(fēng)險,因?yàn)锽LUE是一個預(yù)料級的評價指標(biāo),作者這里認(rèn)為不太合適,用了BLUERT來代替。模型的效率評估用的是平均退出層數(shù)、FLOS以及加速比,其中加速比是直接在TPU v3上用兩百的樣本跑出來的。
性能和效率的權(quán)衡
文章首先在不同Confident度量方法上考察性能和效率的權(quán)衡。對于每個任務(wù)和度量方法,在驗(yàn)證集上評估,設(shè)置步長為0.05。

Softmax Response的度量方法效果比較好,只用了較小的性能損失,就在三個任務(wù)中都減少了超過一半的層。使用額外分類器的方法也不錯,并且FLOS比較小。但是狀態(tài)飽和度的這種方法的性能相對差一點(diǎn)。但是三者相比oracle還是有很大的差距,筆者認(rèn)為這里的oracle相當(dāng)于是一個上帝視角,可以代表性能效率權(quán)衡的上界。這說明了CALM潛力還是很大的,還有很大的一個優(yōu)化空間。
校準(zhǔn)性能
CALM檢驗(yàn)了校準(zhǔn)過程的結(jié)果。因?yàn)槟P偷男阅鼙槐WC是有效的(即至少95%的樣本≤?),因此主要考察不同選擇的
的效率收益。

總體來說Softmax Response的平均退出層數(shù)的加速比是最好的,額外分類器的方法相比于Softmax Response有時候雖然用的層數(shù)更多但是實(shí)際的加速比更好。因?yàn)楫?dāng)詞表很大的時候,Softmax的計(jì)算量可能會很大。如果我們放松約束?,或收緊置信區(qū)間(使用更大的校準(zhǔn)集),可以進(jìn)一步提高效率收益。
總結(jié)
能夠在不過度降低性能的情況下加快預(yù)測,對工業(yè)級的機(jī)器學(xué)習(xí)系統(tǒng)至關(guān)重要。并且,能夠量化預(yù)測中的不確定性,決定什么時候需要更深的計(jì)算量,對任何智能系統(tǒng)來說都是一個很重要的問題。CATs和CALM分別在分類任務(wù)和文本生成式任務(wù)上使用了Confident Adaptive的方法,認(rèn)為當(dāng)早期分類器的預(yù)測已經(jīng)與完整模型的預(yù)測一致時進(jìn)行退出,在早退信號上嘗試了專用分類器以及隱藏層特征等方法,設(shè)計(jì)實(shí)驗(yàn)對隱藏層拷貝,局部擾動對整體生成的影響等問題進(jìn)行了探索和分析,并且運(yùn)用了統(tǒng)計(jì)理論的方法來校準(zhǔn)模型,構(gòu)建了一個易于實(shí)現(xiàn)的、快速的、具有統(tǒng)計(jì)性能保證的模型結(jié)構(gòu)。

hi,這里是小牛翻譯~
想要看到更多我們的文章,可以關(guān)注下
機(jī)器翻譯學(xué)堂(公號或網(wǎng)站)
筆芯~

往期精彩文章


