ICML 2023 | 如何通過(guò)主動(dòng)干預(yù)實(shí)現(xiàn)魯棒性?
導(dǎo) 讀
本文是對(duì)發(fā)表于機(jī)器學(xué)習(xí)領(lǐng)域頂級(jí)會(huì)議 ICML 2023 的論文 Which Invariance Should We Transfer? A Causal Minimax Approach 的解讀。該論文由北京大學(xué)王亦洲課題組與復(fù)旦大學(xué)孫鑫偉助理教授合作完成,第一作者為北京大學(xué)計(jì)算機(jī)學(xué)院博士生劉鳴洲。
本文提出了一種基于主動(dòng)干預(yù)(Intervention)的分布外泛化算法。該算法具有完備的最優(yōu)性保證,并且可以通過(guò)等價(jià)類搜索的方式大幅降低計(jì)算復(fù)雜度。在阿爾茨海默疾病診斷中,該算法的泛化性能超越已有方法15%,計(jì)算代價(jià)降低99%以上,顯示出強(qiáng)大的威力。
論文鏈接:https://proceedings.mlr.press/v202/liu23bc/liu23bc.pdf
項(xiàng)目代碼:https://github.com/lmz123321/which_invariance
視頻介紹:https://youtu.be/5hmyl1hP-6k
01
方法概覽
當(dāng)前的機(jī)器學(xué)習(xí)系統(tǒng)普遍依賴于獨(dú)立同分布假設(shè)(Independent Identical Distribution, IID)。當(dāng)訓(xùn)練環(huán)境與部署環(huán)境的分布出現(xiàn)偏移時(shí),這些模型的預(yù)測(cè)將不再可靠,從而可能導(dǎo)致嚴(yán)重后果。為了解決這一問(wèn)題,研究人員普遍認(rèn)為可泛化的機(jī)器學(xué)習(xí)系統(tǒng)應(yīng)具備兩個(gè)特征,即穩(wěn)定性(Stability)和魯棒性(Robustness)。前者是指預(yù)測(cè)行為對(duì)分布偏移的不敏感性,而后者則是描述泛化誤差的可控性。
為了實(shí)現(xiàn)這一目標(biāo),已有工作提出挖掘并遷移數(shù)據(jù)中的不變性(Invariance),如 Peters Jones 等人提出的 ICP 算法利用目標(biāo)變量的穩(wěn)定父節(jié)點(diǎn)進(jìn)行預(yù)測(cè)。然而,這些方法只能被動(dòng)利用觀測(cè)數(shù)據(jù)中不變性,無(wú)法對(duì)變化的環(huán)境進(jìn)行主動(dòng)適應(yīng),限制了它們的應(yīng)用潛力。
為此,本文提出一種基于主動(dòng)干預(yù)的分布外泛化方法。該方法首先通過(guò)因果發(fā)現(xiàn)自動(dòng)識(shí)別系統(tǒng)中的不穩(wěn)定成分。進(jìn)而,通過(guò)對(duì)這些不穩(wěn)定成分進(jìn)行干預(yù),得到了一個(gè)穩(wěn)定的干預(yù)分布。與被動(dòng)觀測(cè)的條件分布相比,該干預(yù)分布具有更優(yōu)的不變性質(zhì)。最后,本文在干預(yù)分布上定義了一族穩(wěn)定的預(yù)測(cè)模型,并通過(guò)等價(jià)類搜索的方式識(shí)別出其中最魯棒者,從而實(shí)現(xiàn)了穩(wěn)定性和魯棒性的兼優(yōu)。
02
背景介紹
干預(yù)
干預(yù)是對(duì)人類主動(dòng)行為(Action)的數(shù)學(xué)抽象。這一概念最早由 Judea Pearl 等人[1]在結(jié)構(gòu)因果模型框架中提出。具體來(lái)說(shuō),對(duì)某個(gè)變量X的干預(yù),通常被記作do(X=x),是指將X從它原有的因果機(jī)制中抽離出來(lái),并強(qiáng)行賦予其新的狀態(tài)x。從因果圖上看,對(duì)X的干預(yù)就是刪除所有指向X 的邊。
以開(kāi)關(guān)燈為例,do(開(kāi)關(guān)=開(kāi))就表示不管開(kāi)關(guān)的原有狀態(tài)如何、受何種因素(聲音、觸摸)的影響,強(qiáng)行打開(kāi)開(kāi)關(guān)。
分布偏移的因果解釋
根據(jù) Scholkopf 等人[2]提出的稀疏機(jī)制偏移理論(Sparse Mechanism Shift Hypothesis),數(shù)據(jù)分布中的分布偏移(e 表示不同的環(huán)境),是由于部分變量的因果機(jī)制發(fā)生變化導(dǎo)致的。
這也就是說(shuō),只有部分變量的因果機(jī)制會(huì)隨著環(huán)境的改變而發(fā)生變化,而其余變量的機(jī)制則保持穩(wěn)定。相應(yīng)的,我們將前者稱為不穩(wěn)定變量(Mutable Variables,),它們是數(shù)據(jù)分布出現(xiàn)偏移的根本原因;后者則稱為穩(wěn)定變量(Stable Variables,
)。
03
通過(guò)干預(yù)實(shí)現(xiàn)不變性
如前文所指出的,不穩(wěn)定變量的因果機(jī)制是數(shù)據(jù)分布發(fā)生偏移的根本原因。因此,對(duì)不穩(wěn)定變量進(jìn)行干預(yù),刪除它們隨環(huán)境變化的因果機(jī)制,就能去除系統(tǒng)中隨環(huán)境發(fā)生偏移的成分,從而得到對(duì)不同環(huán)境具有不變性的穩(wěn)定分布。
具體來(lái)說(shuō),本文有以下結(jié)論:
【命題-1】干預(yù)分布

對(duì)于不同環(huán)境e保持不變。
值得注意的是,數(shù)據(jù)中的不穩(wěn)定變量可以由因果發(fā)現(xiàn)算法[3]自動(dòng)識(shí)別。因此,上述命題中給出的干預(yù)分布是可計(jì)算的。
基于命題-1,我們推導(dǎo)出一族具備穩(wěn)定性的預(yù)測(cè)器,該族中的每一個(gè)成員對(duì)應(yīng)穩(wěn)定變量集合S的一個(gè)各個(gè)子集S':

04
最優(yōu)性理論
針對(duì)前文介紹的穩(wěn)定預(yù)測(cè)器族,一個(gè)自然的問(wèn)題是:該族中的哪一個(gè)成員是最魯棒的?換言之,在所有成員中哪一個(gè)預(yù)測(cè)器的泛化誤差最小?
在本文中,我們用最差情況風(fēng)險(xiǎn)(Worst-case Risk)- 即所有部署環(huán)境中的最差預(yù)測(cè)誤差 - 來(lái)衡量預(yù)測(cè)器的魯棒性。因此,最魯棒的預(yù)測(cè)器f*應(yīng)該具有以下的極大極小最優(yōu)性(Minimax Optimum):

為了識(shí)別上述最魯棒預(yù)測(cè)器,我們提出利用訓(xùn)練環(huán)境估計(jì)每個(gè)預(yù)測(cè)器的最差情況風(fēng)險(xiǎn),從而通過(guò)比較選出最優(yōu)者。具體來(lái)說(shuō),我們?cè)O(shè)計(jì)了一個(gè)仿真分布族,

其中h是從不穩(wěn)定變量的父節(jié)點(diǎn)到不穩(wěn)定變量
的一個(gè)映射函數(shù)。這一分布族保持了原有分布的穩(wěn)定成分,同時(shí)允許
基于它們的父節(jié)點(diǎn)任意變化,從而可以模擬潛在部署環(huán)境中的分布偏移行為。
理論分析表明,從該仿真分布族中測(cè)得的最差情況風(fēng)險(xiǎn)與實(shí)際部署中的最差情況風(fēng)險(xiǎn)完全相同:
【定理-1】令

為仿真分布族上測(cè)得的最壞情況風(fēng)險(xiǎn),令
為實(shí)際部署中的最差情況風(fēng)險(xiǎn),則對(duì)于任何
,均有
。
05
圖等價(jià)類
根據(jù)定理-1,我們需要逐個(gè)估計(jì)F中各個(gè)預(yù)測(cè)器的最差情況風(fēng)險(xiǎn),其計(jì)算復(fù)雜度與穩(wěn)定變量個(gè)數(shù)成指數(shù)關(guān)系。
為了降低這一復(fù)雜度,我們提出了圖等價(jià)類的概念。具體來(lái)說(shuō),我們發(fā)現(xiàn)F中的存在許多相互等價(jià)的預(yù)測(cè)器,由相互等價(jià)的預(yù)測(cè)器所構(gòu)成的集合就是一個(gè)等價(jià)類。進(jìn)而,搜索范圍可以由F中所有的預(yù)測(cè)器,減少到F中所有的等價(jià)類。同時(shí),我們發(fā)現(xiàn)F中所有的等價(jià)類均可從其因果圖中識(shí)別出來(lái),這就為圖等價(jià)類搜索提供了實(shí)現(xiàn)算法。
理論分析表明,圖等價(jià)類可以將搜索復(fù)雜度由指數(shù)級(jí)降低為多項(xiàng)式級(jí)。
此外,值得注意的是,上述圖等價(jià)類搜索算法適用于任何因果圖模型(如有向無(wú)環(huán)圖 Directed Acyclic Graph DAG,極大祖先圖 Maximal Ancestral Graph MAG),因此有廣泛的應(yīng)用價(jià)值。
06
實(shí)驗(yàn)結(jié)論
為了驗(yàn)證本文理論的有效性,我們?cè)诎柶澓D膊。ˋlzheimer's Disease, AD)診斷任務(wù)上進(jìn)行了實(shí)驗(yàn)。
實(shí)驗(yàn)的預(yù)測(cè)目標(biāo)Y是患者的活動(dòng)功能得分(Functional Activity Questionnaire, FAQ),該得分是對(duì)患者患病程度的常見(jiàn)度量指標(biāo)。預(yù)測(cè)變量X是25個(gè)主要腦區(qū)的體積,這些體積是從結(jié)構(gòu)核磁共振(sMRI)中測(cè)量得到的。實(shí)驗(yàn)數(shù)據(jù)來(lái)源于 ADNI 數(shù)據(jù)集。我們根據(jù)患者的年齡劃分了7個(gè)環(huán)境 (<60, 60-65, 65-70, 70-75, 75-80, 80-85, >85),各個(gè)環(huán)境中分別包含27, 59, 90, 240, 182, 117, 42個(gè)樣本。我們重復(fù)了多個(gè)隨機(jī)種子,每次隨機(jī)選取4個(gè)環(huán)境作為訓(xùn)練環(huán)境,剩余3個(gè)環(huán)境作為測(cè)試。
在 AD 中識(shí)別的因果圖如圖一所示。如圖所示,AD 導(dǎo)致的腦萎縮首先出現(xiàn)在海馬區(qū)(HP)和顳葉中回(TML),進(jìn)而傳播到其他腦區(qū)。這一發(fā)現(xiàn)與臨床研究中發(fā)現(xiàn)的海馬區(qū)、顳葉區(qū)是早萎縮腦區(qū)這一現(xiàn)象高度吻合,從一個(gè)側(cè)面驗(yàn)證了所識(shí)別因果圖的可靠性。此外,我們還發(fā)現(xiàn),尾狀核(CAU)、蒼白球(PAL)和海馬區(qū)(HP)是不穩(wěn)定腦區(qū),即={CAU, PAL, HP}。

圖等價(jià)類識(shí)別的結(jié)果表明,圖1中的等價(jià)類的個(gè)數(shù)為25307個(gè),是全部穩(wěn)定預(yù)測(cè)器個(gè)數(shù)2^{22}的0.075%。這一結(jié)果說(shuō)明,對(duì)圖等價(jià)類的搜索能大幅降低復(fù)雜度。
圖2展示了不同方法泛化性能的對(duì)比。可以發(fā)現(xiàn),本文方法較已有方法的提升達(dá)到15%以上,這充分驗(yàn)證了本文方法的有效性。

07
總 結(jié)
本文提出了一種基于主動(dòng)干預(yù)實(shí)現(xiàn)泛化性的理論框架。該框架具有完備的最優(yōu)性保證、高效的計(jì)算算法和較強(qiáng)的可拓展性,在智能醫(yī)學(xué)、具身智能等關(guān)鍵領(lǐng)域有很好的應(yīng)用潛力。
相關(guān)問(wèn)題歡迎聯(lián)系作者:
liumingzhou@stu.pku.edu.cn sunxinwei@fudan.edu.cn
參考文獻(xiàn)
[1] Pearl, J. Causality. Cambridge University Press, 2009.
[2] Scholkopf, B., Locatello, F., Bauer, S., Ke, N. R., Kalchbrenner,N., Goyal, A., and Bengio, Y. Toward causal representation learning. Proceedings of the IEEE, 109(5): 612–634, 2021.
[3] Huang, B., Zhang, K., Zhang, J., Ramsey, J., Sanchez-Romero, R., Glymour, C., and Sch¨olkopf, B. Causal discovery from heterogeneous/nonstationary data. Journal of Machine Learning Research, 21(89):1–53, 2020.

Computer Vision and Digital Art (CVDA)