一種關(guān)于Neural ODE的新架構(gòu)的思考
Neural ODE的基本形式為:
dX/dt = f(X(t), t, W)
其中f代表任意一個(gè)神經(jīng)網(wǎng)絡(luò),W為網(wǎng)絡(luò)f的權(quán)重。
在上一篇我們討論過(guò),Neural ODE既是將離散網(wǎng)絡(luò)連續(xù)化又是一種ResNet。
現(xiàn)在的問(wèn)題在于,上面形式的Neural ODE的參數(shù)W,并不會(huì)隨其層數(shù)t而變化,這就使Neural ODE更像是一種連續(xù)化的RNN,在不同層數(shù)t上共享的參數(shù)可能會(huì)使其擬合效果并不一定更好。
為了解決這個(gè)可能存在的問(wèn)題,從Transformer網(wǎng)絡(luò)架構(gòu)中得到靈感,我們可以讓參數(shù)或者權(quán)重W,隨t而變化,即W(t)。
那么如何組織W(t)的形式使它變得可訓(xùn)練就成了一個(gè)問(wèn)題。一種想法是泰勒展開或傅里葉展開,如:
W(t) = a0+a1*t+a2*t^2+...
然后訓(xùn)練aj (j = 0, 1, 2, ...)。但是我們還有更好的想法。當(dāng)已知t與W之間存在函數(shù)關(guān)系,但是并不知道具體的函數(shù)關(guān)系時(shí),我們充分利用統(tǒng)計(jì)學(xué),讓神經(jīng)網(wǎng)絡(luò)代勞,即:
W(t) = net(t, θ)
net(t, θ)是一個(gè)輸入為t,參數(shù)為θ,輸出為W的神經(jīng)網(wǎng)絡(luò),是一個(gè)升維的過(guò)程,將標(biāo)量t映射為向量/矩陣W。
這樣,我們可以得到一種關(guān)于Neural ODE的新架構(gòu):
dX/dt = f1(X(t), W(t))
f1為任意神經(jīng)網(wǎng)絡(luò),W為權(quán)重,Neural ODE作為一個(gè)“連續(xù)的”神經(jīng)網(wǎng)絡(luò)有t層,而f1作為一個(gè)“離散的”神經(jīng)網(wǎng)絡(luò)有層數(shù)i, i∈N。神經(jīng)網(wǎng)絡(luò)f1每層的權(quán)重Wi也分別用神經(jīng)網(wǎng)絡(luò)f2i表示:
Wi(t)?=?f2i(t,?θi)
這樣,Neural ODE依托于網(wǎng)絡(luò)f1,而網(wǎng)絡(luò)f1依托于i個(gè)網(wǎng)絡(luò)f2i,訓(xùn)練Neural ODE就相當(dāng)于訓(xùn)練參數(shù)θi, i∈N。相當(dāng)于Neural ODE每層(t取不同值)的權(quán)重W都不一樣,隨t變化。
另外,由于f1是任意的神經(jīng)網(wǎng)絡(luò),所以f1可以取Transformer,也許效果會(huì)更佳。