6.5 小批量梯度下降法
前面幾節(jié)介紹了批量梯度下降法和隨機(jī)梯度下降法,各有問題。前者計(jì)算量大速度慢,后者容易出現(xiàn)搜索中的震蕩情況。于是科學(xué)家們很自然的就想找個(gè)折中方案,這就是小批量隨機(jī)梯度下降法(Mini-Batch Stochastic Gradient Descent)。本節(jié)咱們就來具體看看它的基本原理和各種特點(diǎn)。
6.5.1 基本思想
小批量隨機(jī)梯度下降法的基本思路是在每一次迭代中,使用一小部分的隨機(jī)樣本來計(jì)算梯度,然后根據(jù)梯度來更新參數(shù)的值。它的基本流程如下:

6.5.2 優(yōu)缺點(diǎn)

6.5.3 代碼比較

我們用一個(gè)例子來演示三種梯度下降法在代碼實(shí)現(xiàn)上的異同。在下面這個(gè)例子中,我們使用?PyTorch 中的 torch.optim.SGD 函數(shù)來創(chuàng)建優(yōu)化器。通過設(shè)置 momentum 參數(shù)為?
?來使用隨機(jī)梯度下降法,設(shè)置為非?
?值則使用常規(guī)的梯度下降法。批量大小設(shè)置為?
,這意味著每次更新模型參數(shù)時(shí)使用了?
?個(gè)樣本。當(dāng)然可以根據(jù)需要調(diào)整批量的大小,以獲得最優(yōu)的訓(xùn)練效果。特別注意的是,下面這個(gè)例子重點(diǎn)是為了演示三種算法在代碼實(shí)現(xiàn)上的區(qū)別,簡(jiǎn)便起見,我們使用了隨機(jī)生成的數(shù)據(jù)來訓(xùn)練模型。損失值的變化趨勢(shì)的結(jié)果可能會(huì)因?yàn)橛?xùn)練數(shù)據(jù)的不同而有所差異。
import?torch
import?torch.nn?as?nn
import?numpy?as?np
import?matplotlib.pyplot?as?plt
from?tqdm?import?*
#?定義模型和損失函數(shù)
class?Model(nn.Module):
????def?__init__(self):
????????super().__init__()
????????self.hidden1?=?nn.Linear(1,?32)
????????self.hidden2?=?nn.Linear(32,?32)
????????self.output?=?nn.Linear(32,?1)
????def?forward(self, x):
????????x?=?torch.relu(self.hidden1(x))
????????x?=?torch.relu(self.hidden2(x))
????????return?self.output(x)
loss_fn?=?nn.MSELoss()
#?生成隨機(jī)數(shù)據(jù)
np.random.seed(0)
n_samples?=?1000
x?=?np.linspace(-5,?5, n_samples)
y?=?0.3?*?(x?**?2)?+?np.random.randn(n_samples)
#?轉(zhuǎn)換為Tensor
x?=?torch.unsqueeze(torch.from_numpy(x).float(),?1)
y?=?torch.unsqueeze(torch.from_numpy(y).float(),?1)
#?將數(shù)據(jù)封裝為數(shù)據(jù)集
dataset?=?torch.utils.data.TensorDataset(x, y)
names?=?["Batch",?"Stochastic",?"Minibatch"]?#?批量梯度下降法、隨機(jī)梯度下降法、小批量梯度下降法
batch_size?=?[n_samples,?1,?128]
momentum=?[1,0,1]
losses?=?[[], [], []]
#?超參數(shù)
learning_rate?=?0.0001
n_epochs?=?1000
#?分別訓(xùn)練
for?i?in?range(3):
????model?=?Model()
????optimizer?=?torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum[i])
????dataloader?=?torch.utils.data.DataLoader(dataset, batch_size=batch_size[i], shuffle=True)
????for?epoch?in?tqdm(range(n_epochs), desc=names[i], leave=True, unit=' epoch'):
????????x, y?=?next(iter(dataloader))
????????optimizer.zero_grad()
????????out?=?model(x)
????????loss?=?loss_fn(out, y)
????????loss.backward()
????????optimizer.step()
????????losses[i].append(loss.item())
#?使用?Matplotlib?繪制損失值的變化趨勢(shì)
for?i, loss_list?in?enumerate(losses):
????plt.figure(figsize=(12,?4))
????plt.plot(loss_list)
????plt.ylim((0,?15))
????plt.xlabel('Epoch')
????plt.ylabel('Loss')
????plt.title(names[i])
????plt.show()
Batch: 100%|██████████| 1000/1000 [00:07<00:00, 129.91 epoch/s]
Stochastic: 100%|██████████| 1000/1000 [00:00<00:00, 2397.32 epoch/s]
Minibatch: 100%|██████████| 1000/1000 [00:01<00:00, 780.15 epoch/s]



梗直哥提示:你可以試著修改不同方法的學(xué)習(xí)率,小批量梯度下降法的批量大小等參數(shù),以及換用真實(shí)的數(shù)據(jù)集訓(xùn)練來觀察結(jié)果的不同,從而對(duì)三種方法的優(yōu)缺點(diǎn)有更加深刻的認(rèn)識(shí)。也歡迎你入群學(xué)習(xí),參與討論。微信:gengzhige99。
深度學(xué)習(xí)必修課首期名額僅剩30個(gè),有需要的同學(xué)抓緊時(shí)間訂閱。
