【TF/Guide筆記】 03. Automatic differentiation
? ? 由于我是半路出家,沒(méi)有學(xué)過(guò)基礎(chǔ)課程,結(jié)果在梯度的這個(gè)問(wèn)題上繞了好久才想明白。
? ? 給理解帶來(lái)最大障礙的反而是我們以前實(shí)現(xiàn)自動(dòng)求導(dǎo)的方式,從前我們的每一個(gè)Variable內(nèi)部包含了兩個(gè)Tensor,一個(gè)數(shù)據(jù)data一個(gè)梯度diff,在一個(gè)op實(shí)現(xiàn)了從上游data轉(zhuǎn)換為下游data的同時(shí),附加實(shí)現(xiàn)了下游diff轉(zhuǎn)為上游diff的操作,這樣用戶創(chuàng)建的一個(gè)op在計(jì)算圖實(shí)際上被展開(kāi)為了兩組計(jì)算節(jié)點(diǎn)和輸入輸出,一個(gè)正向一個(gè)反向的操作直到loss那里才連成了通路。
? ? 正是受這個(gè)先入為主的概念的影響,在看tf是怎么實(shí)現(xiàn)的時(shí)候一下就繞不明白了,對(duì)于y=x^2來(lái)說(shuō),我從數(shù)學(xué)上可以理解求導(dǎo)是2x,也可以理解當(dāng)x=3時(shí)梯度是6,但這個(gè)梯度跟以前的diff值怎么都對(duì)不上。
? ? 繞了半天才想明白,tf里求解的梯度就是真正的梯度,也就是函數(shù)在某一點(diǎn)的斜率,數(shù)學(xué)上應(yīng)該是dy/dx,而我們以前實(shí)現(xiàn)的梯度,其實(shí)得到的是dx,它是一個(gè)跟loss有關(guān)的具體數(shù)值,把loss當(dāng)成最終的y的話,我們實(shí)際上隱式的假設(shè)了loss的目標(biāo)必須是0,所以dy其實(shí)就是loss-0,然后使用這個(gè)dy一步一步的往回推到出dx,這個(gè)dx就是我算出來(lái)的梯度下降的值,已經(jīng)并非梯度本身了。
? ? 困惑了我半天的是,tf求出來(lái)的loss是怎么作用到梯度上的,答案就是他根本沒(méi)作用,一個(gè)標(biāo)準(zhǔn)的流程應(yīng)該是先求出loss對(duì)于每個(gè)變量的梯度dy/dx,然后設(shè)定dy(也就是梯度下降的步長(zhǎng))再計(jì)算得出dx,tf求出來(lái)的就是純粹的梯度,至于你要怎么用,那你再用代碼去寫(xiě),我們以前的框架給訓(xùn)練預(yù)設(shè)了太多前提。
? ? 基本原理理順了,做法一下就清晰了。
? ? 對(duì)于一個(gè)op y=f(x),dy/dx直接對(duì)函數(shù)f求導(dǎo)就行了,這個(gè)結(jié)果只與x的值有關(guān),跟下游變量是無(wú)關(guān)的,每一個(gè)op都可以針對(duì)輸入求出一個(gè)dy/dx,然后根據(jù)求導(dǎo)的鏈?zhǔn)椒▌t把一連串op的結(jié)果乘起來(lái),就能得到d loss/dx了。
? ? 聲明tape則是為了記錄每個(gè)op的梯度計(jì)算式,又或許tape里直接存放了梯度值本身,畢竟你都用tape了不可能最后不求導(dǎo)吧,這樣tape里就存放了一張圖,節(jié)點(diǎn)上是梯度,邊則是變量,tape.gradient(y, x)就是由用戶指定了鏈?zhǔn)椒▌t的起點(diǎn)和終點(diǎn),這條鏈路上相乘就得到了dy/dx。
? ? 使用tape的好處是,你顯示的告訴了框架那些操作是需要算梯度而哪些不要的,這樣diff tensor就始終只存在于tape的體系里,Variable內(nèi)部并不需要兩個(gè)tensor,這樣的確非常合理,因?yàn)楫?dāng)我們假定了Variable里必須包含data和diff的時(shí)候,很多很多op其實(shí)壓根兒沒(méi)有算diff的意義(比如格式轉(zhuǎn)換),但是為了統(tǒng)一性還必須要寫(xiě)這段空邏輯,造成了大量的代碼冗余和難以理解。
? ? 這種方式可能對(duì)精度也比較友好,雖然理論上我們的算法和tf的算法得出的結(jié)果是一樣的,但很常見(jiàn)的一種情況是由于初始參數(shù)不好,頭一輪的loss是非常爆炸的,用這個(gè)超大的dy去一步一步往前推導(dǎo)dx,可以想象它的精度會(huì)在中間損失多少,而如果先算清楚dy/dx的話,由于這里涉及的計(jì)算使用的都是輸入,噪點(diǎn)肯定是在早期處理掉的,所以剃度的精度會(huì)算的很高,最后再把乘完alpha的loss帶進(jìn)去,會(huì)讓結(jié)果有較大的保障。也難怪我們以前小數(shù)據(jù)去跟tf對(duì)拍出來(lái)都是對(duì)的,但經(jīng)常用著用著就覺(jué)得自家的精度莫名其妙的沒(méi)了。
? ? 通常來(lái)講,需要算梯度的部分和其他操作是可以比較容易的劃分開(kāi)的,但要是真有不太好劃分出來(lái)的Variable,就可以用到之前提過(guò)的trainable參數(shù),所以trainable只是標(biāo)記了這個(gè)變量是否要進(jìn)tape,跟他本身有沒(méi)有diff是沒(méi)關(guān)系的。
? ? Variable內(nèi)部只有一個(gè)tensor,又支持所有tensor的操作,所以Variable和Tensor在本質(zhì)上的區(qū)別就是是否自動(dòng)被tape記錄。不過(guò)tf支持了手動(dòng)記錄的接口,tape.watch(tensor),這樣看來(lái)tape內(nèi)部記錄和操作的還是tensor,Variable并沒(méi)有做太多封裝。
? ? tape同時(shí)支持關(guān)閉自動(dòng)記錄變量,改為手動(dòng)指定想要計(jì)算梯度的Variable,不過(guò)只要某個(gè)x處于watch狀態(tài),顯然由它計(jì)算得到的那些變量就都會(huì)進(jìn)入tape內(nèi)部的鏈?zhǔn)綀D,相當(dāng)于下游的中間變量是一定被watch的。
? ? 根據(jù)文檔說(shuō)的,tape如果不開(kāi)persistent=True的話,他就只能調(diào)用一次,這其實(shí)間接說(shuō)明了tape內(nèi)部記錄的是每個(gè)op里dy/dx的結(jié)果,而不是只記錄了一個(gè)function用的時(shí)候現(xiàn)算,如果是用時(shí)現(xiàn)算的話是沒(méi)有什么臨時(shí)變量值得釋放的。這樣設(shè)計(jì)的目的大概在于,在forward這個(gè)過(guò)程里執(zhí)行盡可能多的計(jì)算,這樣就可以實(shí)現(xiàn)更多的優(yōu)化,比如計(jì)算y=f(x)之后緊接著計(jì)算y=f'(x),那么x的數(shù)據(jù)就可以只進(jìn)一次高級(jí)緩存,應(yīng)該會(huì)有明顯的速度差異。
? ? 雖然后面提了很多不同target的情況下梯度的計(jì)算結(jié)果,但只要按鏈?zhǔn)椒▌t去理解應(yīng)該就行了。不過(guò)鏈?zhǔn)降闹虚g肯定要涉及tensor的升降維,所以也不是完全連乘那么簡(jiǎn)單。
? ? 同樣的,利用鏈?zhǔn)綀D就可以很容易理解,為什么控制流里的變量無(wú)法求導(dǎo),為什么必須所有操作都用tf操作,為什么重新賦值會(huì)導(dǎo)致無(wú)法求導(dǎo),為什么帶狀態(tài)的變量無(wú)法參與求導(dǎo),因?yàn)檫@些操作都會(huì)讓鏈?zhǔn)綀D斷裂或者無(wú)法定義。
? ? 不得不感慨,即便強(qiáng)大如tf,也會(huì)對(duì)用戶級(jí)的代碼有諸多限制,我們以前沒(méi)幾個(gè)人力還想做各種兼容真是太不自量力了。