K折交叉驗(yàn)證 | 深度學(xué)習(xí)入門必讀系列5
如果你是自學(xué)的深度學(xué)習(xí),沒有系統(tǒng)的梳理過整個(gè)學(xué)習(xí)步驟會感覺很混亂,銜接不上。學(xué)姐整理的“用pytorch構(gòu)建多種類型模型來幫助學(xué)習(xí)深度學(xué)習(xí)”系列教程就是為了讓大家打好基礎(chǔ),每一個(gè)知識點(diǎn)都能銜接上。(你們有什么想看的評論區(qū)告訴我?。。?/strong>
今天是系列教程的第五節(jié)《K折交叉驗(yàn)證》,大家要繼續(xù)保持積極性嗷!

深度學(xué)習(xí)入門必讀系列前4篇傳送門
01?K折交叉驗(yàn)證介紹
K fold Cross Validation(K折交叉驗(yàn)證)是一種用于以穩(wěn)健的方式評估機(jī)器學(xué)習(xí)或深度學(xué)習(xí)模型的性能的技術(shù)。
它將數(shù)據(jù)集分成大小大致相同的k個(gè)部分/折疊(parts/folds)。依次選擇每個(gè)folds進(jìn)行測試,其余parts進(jìn)行訓(xùn)練。
這個(gè)過程重復(fù)k次,然后將性能作為所有測試集的平均值進(jìn)行測量。
02 使用Pytorch和sklearn實(shí)現(xiàn)步驟
K折交叉驗(yàn)證用于評估CNN模型在MNIST數(shù)據(jù)集上的性能。該方法使用sklearn庫實(shí)現(xiàn),而模型使用Pytorch進(jìn)行訓(xùn)練。
導(dǎo)入庫和數(shù)據(jù)集
我們定義了具有2個(gè)卷積層和1個(gè)全連接層的卷積神經(jīng)網(wǎng)絡(luò)架構(gòu),以將圖像分類為十個(gè)類別之一。我們在模型中添加了兩個(gè)Dropout層,以限制過度擬合的風(fēng)險(xiǎn)。
為分類初始化交叉熵?fù)p失函數(shù),在代碼上使用GPU并設(shè)置一個(gè)固定的隨機(jī)數(shù)種子。
然后將訓(xùn)練集和測試集連接成一個(gè)應(yīng)用該ConcatDataset函數(shù)的唯一數(shù)據(jù)集。
我們使用該Kfold函數(shù)生成10折,其中我們有隨機(jī)拆分和可復(fù)制的結(jié)果random_state=42。因此,它將數(shù)據(jù)集分為9部分用于訓(xùn)練,其余部分用于測試。
在應(yīng)用K折交叉驗(yàn)證之前,定義了用于訓(xùn)練和評估模型的函數(shù)。特別是在訓(xùn)練函數(shù)時(shí),執(zhí)行前向傳遞和后向傳遞。
現(xiàn)在,通過迭代折疊來構(gòu)建k折疊驗(yàn)證過程。
在第一個(gè)for循環(huán)中,從train_idx和val_idx中采樣元素,然后將這些采樣器轉(zhuǎn)換為批大小等于128的DataLoader對象,初始化模型并將其傳遞給GPU,最后以0.002作為學(xué)習(xí)率來初始化Adam優(yōu)化器。
在第二個(gè)循環(huán)中,我們通過之前定義的函數(shù)訓(xùn)練和評估CNN模型,這些函數(shù)將返回所選訓(xùn)練集和測試集的損失和準(zhǔn)確度。
我們把所有的執(zhí)行都保存到命名為history的字典里。在模型的訓(xùn)練和評估結(jié)束后,特定折疊(進(jìn)入history字典)的所有分?jǐn)?shù)都存儲在字典中foldperf。
使用torch.save函數(shù)存儲模型

我們可以看到最后兩折的表現(xiàn),結(jié)果看起來相當(dāng)不錯(cuò)——在訓(xùn)練和測試集中都有99%的準(zhǔn)確率。此外,訓(xùn)練和測試精度之間沒有明顯差異,證明沒有過度擬合。
我們可以通過兩個(gè)步驟計(jì)算平均性能以獲得更全面的概述:
計(jì)算每個(gè)折疊的平均分?jǐn)?shù)。
獲得每個(gè)折疊的平均分?jǐn)?shù)后,計(jì)算所有折疊的平均分?jǐn)?shù)。

正如之前的一樣,這里沒有進(jìn)行任何平均計(jì)算就獲得了很好的結(jié)果。
為了進(jìn)一步確認(rèn),我們可以繪制CNN模型十個(gè)交叉驗(yàn)證折疊的平均損失/準(zhǔn)確度曲線。在前4個(gè)epoch中,準(zhǔn)確度增加非???,而損失函數(shù)達(dá)到非常低的值。兩條曲線在 10 個(gè) epoch 后匯合。


03 總結(jié)
當(dāng)你學(xué)完這篇教程,你可能就可以使用Pytorch定義的模型實(shí)現(xiàn)K折交叉驗(yàn)證。本篇教程的實(shí)現(xiàn)步驟還是相對簡單的,問題的類型,部分?jǐn)?shù)據(jù),也有一些變化。例如,其他一些方法是分層K折交叉驗(yàn)證和時(shí)間序列交叉驗(yàn)證。
本教程部分代碼:
https://github.com/eugeniaring/Pytorch-tutorial/blob/main/KCV_mnist.ipynb
原文鏈接:
https://medium.com/dataseries/k-fold-cross-validation-with-pytorch-and-sklearn-d094aa00105f
下期預(yù)告:
深入研究k折交叉驗(yàn)證(K fold Cross Validation)
關(guān)注公眾號提前看好文!
