6.10 Adam算法
話說“天下大勢(shì)合久必分,分久必合”,深度學(xué)習(xí)優(yōu)化算法的比拼歷程就精彩的上演著同樣的故事。前面幾節(jié),我們從梯度下降開始,介紹了它的兩種變體,隨機(jī)梯度下降、小批量隨機(jī)梯度下降,然后有介紹了改進(jìn)梯度計(jì)算的動(dòng)量法,以及改善學(xué)習(xí)率的Adagrad算法、RMSProp算法和Adadelta算法。有沒有人琢磨著把它們的優(yōu)點(diǎn)結(jié)合起來搞個(gè)大雜燴呢?必然的,這就是由Diederik Kingma和Jimmy Ba在2014年提出的Adam算法了。本節(jié)咱們就來詳細(xì)介紹。
6.10.1?基本原理
Adam算法是在RMSProp算法的基礎(chǔ)上提出的,并且使用了指數(shù)加權(quán)平均數(shù)來調(diào)整學(xué)習(xí)率。Adam算法被廣泛用于神經(jīng)網(wǎng)絡(luò)的訓(xùn)練過程中,因?yàn)樗軌蜃赃m應(yīng)學(xué)習(xí)率,使得訓(xùn)練過程更加順暢。Adam算法在傳統(tǒng)梯度下降算法的基礎(chǔ)上具體是怎么改進(jìn)的呢,咱們來看它的數(shù)學(xué)公式:

6.10.2?算法流程
Adam算法訓(xùn)練神經(jīng)網(wǎng)絡(luò)時(shí),我們需要設(shè)置一些參數(shù),包括學(xué)習(xí)率,指數(shù)加權(quán)平均數(shù)的衰減率
和
,以及一個(gè)很小的正數(shù)
。我們可以使用如下的步驟來訓(xùn)練一個(gè)神經(jīng)網(wǎng)絡(luò):
初始化網(wǎng)絡(luò)的權(quán)重和偏置,并定義損失函數(shù)和Adam優(yōu)化器。
在訓(xùn)練數(shù)據(jù)上進(jìn)行前向傳播。
計(jì)算損失。
進(jìn)行反向傳播。
使用Adam優(yōu)化器更新權(quán)重和偏置。
重復(fù)步驟?2-5,直到達(dá)到規(guī)定的訓(xùn)練步數(shù)
6.10.3?代碼示例
我們用一個(gè)例子演示在?PyTorch 中實(shí)現(xiàn) Adam 算法并可視化訓(xùn)練過程:
import?matplotlib.pyplot?as?plt
import?numpy?as?np
import?torch
#?首先,我們定義一個(gè)隨機(jī)訓(xùn)練數(shù)據(jù)
np.random.seed(0)
x?=?np.random.uniform(0,?2,?100)
y?=?x?*?3?+?1?+?np.random.normal(0,?0.5,?100)
#?將訓(xùn)練數(shù)據(jù)轉(zhuǎn)換為?PyTorch Tensor
x?=?torch.from_numpy(x).float().view(-1,?1)
y?=?torch.from_numpy(y).float().view(-1,?1)
#?然后,我們定義一個(gè)線性模型和損失函數(shù)
model?=?torch.nn.Linear(1,?1)
loss_fn?=?torch.nn.MSELoss()
#?接下來,我們使用?Adam?優(yōu)化器來訓(xùn)練模型
optimizer?=?torch.optim.Adam(model.parameters(), lr=0.1)
#?初始化用于可視化訓(xùn)練過程的列表
losses?=?[]
#?開始訓(xùn)練循環(huán)
for?i?in?range(100):
????#?進(jìn)行前向傳遞,計(jì)算損失
????y_pred?=?model(x)
????loss?=?loss_fn(y_pred, y)
????#?將損失存儲(chǔ)到列表中,以便我們可視化
????losses.append(loss.item())
????#?進(jìn)行反向傳遞,更新參數(shù)
????optimizer.zero_grad()
????loss.backward()
????optimizer.step()
#?可視化訓(xùn)練過程
plt.plot(losses)
plt.ylim((0,?15))
plt.show()

梗直哥提示:adam算法在使用過程中還是有不少細(xì)節(jié)需要注意的,也需要調(diào)參實(shí)戰(zhàn)經(jīng)驗(yàn)的總結(jié)。多動(dòng)手多體會(huì)。當(dāng)然,也可以歡迎選修進(jìn)階課程幫你加快這個(gè)過程《梗直哥的深度學(xué)習(xí)必修課:python實(shí)戰(zhàn)》
