Pytorch參數(shù)更新筆記
雖然寫過很多Pytorch有關(guān)的代碼,但是之前一直沒有怎么細(xì)究過一些細(xì)節(jié),這一次就pytorch中的參數(shù)更新過程進(jìn)行研究。
在有數(shù)據(jù)的情況下,準(zhǔn)備好criterion和optimizer以及network之后,開始訓(xùn)練。
更新參數(shù)總共可以分為四步
前向傳播:將數(shù)據(jù)扔進(jìn)網(wǎng)絡(luò)里,根據(jù)輸入得到輸出。
前向傳播是之后反向傳播的基礎(chǔ),鏈?zhǔn)椒▌t嘛
計(jì)算Loss:使用criterion,傳入predicted結(jié)果和label,計(jì)算得到Loss。
在這一步中需要闡明的是,pytorch中的Loss function返回的不只是數(shù)值,還有一個grad_fn,用于計(jì)算梯度。
調(diào)用Loss.backward(),計(jì)算參數(shù)梯度方向
調(diào)用optimizer.step(),更新模型參數(shù)
這里有一個疑問是,Loss和optimizer好像在表面上的代碼沒有任何交集(如圖1),但是卻存在交互并更新了參數(shù),經(jīng)過查看文檔發(fā)現(xiàn)是這樣的:
loss.backward()獲得所有parameter的gradient
optimizer存了這些parameter的指針,step()根據(jù)這些parameter的gradient對parameter的值進(jìn)行更新

# 上述中compute_loss函數(shù)最后是返回了nn.Lossfunction的結(jié)果,和criterion(input,label)一樣
在這四步后,網(wǎng)絡(luò)的參數(shù)就進(jìn)行更新了。
最后再注明一個點(diǎn),就是在每一次傳入數(shù)據(jù)進(jìn)行訓(xùn)練前,記得使用optimizer.zero_grad()清空梯度置0,防止每一次計(jì)算的梯度都疊加了。(當(dāng)然這部分也可以故意有一些操作,一般不涉及)
https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.step.html