6.8 RMSProp算法
Adagrad算法優(yōu)化了學習率的調(diào)整,但仍然存在一些問題。例如:它很多時候存在學習率過小的問題,收斂速度較慢。為此,科學家們又有針對性的提出了更多改進。RMSProp 算法就是其中比較典型的一種變體。本節(jié)咱們具體介紹它。
6.8.1?基本思想
RMSProp 算法是由 Geoffrey Hinton 提出的一種優(yōu)化算法,主要用于解決梯度下降中的學習率調(diào)整問題。
在梯度下降中,每個參數(shù)的學習率是固定的。但在實際應用中,每個參數(shù)的最優(yōu)學習率可能是不同的。如果學習率過大,則模型可能會跳出最優(yōu)值;如果學習率過小,則模型的收斂速度可能會變慢。
RMSProp 算法通過自動調(diào)整每個參數(shù)的學習率來解決這個問題。它在每次迭代中維護一個指數(shù)加權平均值,用于調(diào)整每個參數(shù)的學習率。如果某個參數(shù)的梯度較大,則RMSProp算法會自動減小它的學習率;如果梯度較小,則會增加學習率。這樣可以使得模型的收斂速度更快。
這種算法的計算公式如下:


6.8.2?算法優(yōu)缺點
任何事物都是有兩面性的,深度學習的每種算法都不例外。明確它們的優(yōu)缺點才能更好的理解它們。對RMSProp算法而言,盡管它在Adagrad算法的基礎上進行了改進,但依然優(yōu)缺點都很突出。
優(yōu)點方面,RMSProp算法能夠自動調(diào)整學習率,使得模型的收斂速度更快。它可以避免學習率過大或過小的問題,能夠更好地解決學習率調(diào)整問題。實現(xiàn)上看它較為簡單,適用于各種優(yōu)化問題。
缺點方面,它在處理稀疏特征時可能不夠優(yōu)秀。此外,它需要調(diào)整超參數(shù),如衰減率和學習率,這需要一定的經(jīng)驗。還有,收斂速度可能不如其他我們后面會介紹的優(yōu)化算法,例如 Adam算法。
不過,瑕不掩瑜。RMSProp算法還是一種優(yōu)化算法發(fā)展進程中非常優(yōu)秀的算法。
6.8.3?代碼示例
我們用一個簡單的線性回歸的例子來演示RMSProp算法的pytorch代碼實現(xiàn),方便你的理解。
import?os
os.environ["KMP_DUPLICATE_LIB_OK"]?=?"TRUE"
import?torch
import?matplotlib.pyplot?as?plt
#?假設我們有一個簡單的線性回歸模型
# y = w * x + b
#?其中?w?和?b?是需要學習的參數(shù)
#?定義超參數(shù)
learning_rate?=?0.01
num_epochs?=?100
#?隨機生成訓練數(shù)據(jù)
X?=?torch.randn(100,?1)
y?=?2?*?X?+?3?+?torch.randn(100,?1)
#?初始化參數(shù)
w?=?torch.zeros(1, requires_grad=True)
b?=?torch.zeros(1, requires_grad=True)
#?創(chuàng)建?RMSProp optimizer
optimizer?=?torch.optim.RMSprop([w, b], lr=learning_rate)
#?記錄每次迭代的?loss
losses?=?[]
#?訓練模型
for?epoch?in?range(num_epochs):
??#?計算預測值
??y_pred?=?w?*?X?+?b
??#?計算?loss
??loss?=?torch.mean((y_pred?-?y)?**?2)
??#?記錄?loss
??losses.append(loss.item())
??#?清空上一步的梯度
??optimizer.zero_grad()
??#?計算梯度
??loss.backward()
??#?更新參數(shù)
??optimizer.step()
#?可視化訓練過程
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()

RMSProp 算法在訓練過程中是通過自動調(diào)整學習率來優(yōu)化模型參數(shù)的。因此,它并沒有明顯的搜索過程。但是,我們可以通過觀察訓練過程中的損失,了解模型的訓練情況。如果損失在不斷降低,則說明模型的訓練效果較好;如果損失不再降低,則說明模型可能已經(jīng)達到了最優(yōu)解或者出現(xiàn)了過擬合。 這個過程中我們可以判斷模型的訓練效果,并適當調(diào)整超參數(shù),以提高模型的泛化能力。
梗直哥提示:如果簡單了解RMSProp算法的基本原理并不難,但這樣其實只掌握了其功力的30%。最精華的其實是為什么提出這種算法,人家怎么就能提出這種算法,當時是怎么想的等等一系列更加深入的問題。只有這樣才能不光知其然,還能知其所以然,充分提升自己深度學習的境界。歡迎來到哥的課堂,就這些進階問題,幫你武裝自己。更多了解加V:gengzhige99