MATLAB Handwritten Digit Recognition (MNIST)

Building the Dataset
The basics of the MNIST handwritten-digit dataset were covered in the previous post (PyTorch handwritten digit recognition on MNIST), so only a brief summary is given here.
Official site: yann.lecun.com/exdb/mnist/

The downloaded files are in a binary format that cannot be opened and previewed directly.

This section covers downloading the dataset, decompressing it, and saving it as a standard .mat file.
Download URLs
url1 = "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz";  % training set images
url2 = "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz";  % training set labels
url3 = "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz";   % test set images
url4 = "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz";   % test set labels
Save the files locally
filepath1 = websave("train-images-idx3-ubyte.gz", url1);
filepath2 = websave("train-labels-idx1-ubyte.gz", url2);
filepath3 = websave("t10k-images-idx3-ubyte.gz", url3);
filepath4 = websave("t10k-labels-idx1-ubyte.gz",url4);
Unzip all the archives
files = gunzip('*.gz');  % decompress the .gz files; the code below assumes the alphabetical
                         % order t10k-images, t10k-labels, train-images, train-labels

To make the file reads below easier, each 4-byte field in the file headers (1 byte = 8 bits) needs to be converted to a decimal integer.

function y = Byte2Dec(data)
    bin8 = dec2bin(data,8);   % each byte as an 8-bit binary string
    byte = [bin8(1,:),bin8(2,:),bin8(3,:),bin8(4,:)];  % concatenate in big-endian order
    y = bin2dec(byte);        % interpret the 32 bits as one decimal value
end
%% The conversion function above turns 4 bytes into one big-endian integer
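As a quick cross-check of what Byte2Dec computes (this Python snippet is only an illustration, not part of the MATLAB workflow): interpreting 4 bytes as a big-endian integer should turn the image-file magic bytes 00 00 08 03 into 2051, and the label-file magic 00 00 08 01 into 2049.

```python
# Mirror of Byte2Dec: interpret 4 bytes as one big-endian integer
def byte2dec(b: bytes) -> int:
    return int.from_bytes(b, byteorder="big")

print(byte2dec(b"\x00\x00\x08\x03"))  # MNIST image-file magic number: 2051
print(byte2dec(b"\x00\x00\x08\x01"))  # MNIST label-file magic number: 2049
```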
Building the test set
fid1 = fopen(files{1});     % t10k-images-idx3-ubyte
m1 = fread(fid1,4);         % magic number (4 bytes)
n1 = fread(fid1,4);         % number of images
r1 = fread(fid1,4);         % rows per image
c1 = fread(fid1,4);         % columns per image
m1 = Byte2Dec(m1);
n1 = Byte2Dec(n1);
r1 = Byte2Dec(r1);
c1 = Byte2Dec(c1);
test_imgs = cell(n1,1);
for i = 1:n1
   temp = fread(fid1,r1*c1);
   temp = reshape(temp,[r1,c1]);
   test_imgs{i} = temp';    % transpose: IDX stores pixels row-major
end
fclose(fid1);
fid2 = fopen(files{2});     % t10k-labels-idx1-ubyte
m2 = fread(fid2,4);         % magic number
n2 = fread(fid2,4);         % number of labels
m2 = Byte2Dec(m2);
n2 = Byte2Dec(n2);
test_labels = zeros(n2,1);
for i = 1:n2
   test_labels(i) = fread(fid2,1);
end
fclose(fid2);
for index = 1:n2            % 10000 test images
   img = uint8(test_imgs{index});   % cast to uint8 so imwrite treats values as 0-255
   label = num2str(test_labels(index));
   folder = fullfile('D:\mnist','testdata',label);
   if ~exist(folder,'dir'), mkdir(folder); end   % create the class folder if needed
   imwrite(img,fullfile(folder,['img',label,num2str(index),'.png']));
end
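The layout the reads above walk through is the IDX format: a 4-byte magic number, a 4-byte image count, 4-byte row and column counts, then rows×cols pixel bytes per image, all big-endian. A minimal sketch that builds and re-parses a tiny synthetic file of this shape (dummy data, not real MNIST):

```python
import struct

# Synthetic IDX image file in memory: 2 images of 3x3 pixels
header = struct.pack(">IIII", 2051, 2, 3, 3)  # magic, count, rows, cols (big-endian)
pixels = bytes(range(2 * 3 * 3))              # dummy pixel values 0..17
blob = header + pixels

# Parse it back the same way the MATLAB loop does
magic, n, rows, cols = struct.unpack(">IIII", blob[:16])
images = [blob[16 + i*rows*cols : 16 + (i+1)*rows*cols] for i in range(n)]
print(magic, n, rows, cols)   # 2051 2 3 3
print(list(images[1]))        # second image: bytes 9..17
```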
Building the training set
fid3 = fopen(files{3});     % train-images-idx3-ubyte
m1 = fread(fid3,4);         % magic number
n1 = fread(fid3,4);         % number of images
r1 = fread(fid3,4);         % rows per image
c1 = fread(fid3,4);         % columns per image
m1 = Byte2Dec(m1);
n1 = Byte2Dec(n1);
r1 = Byte2Dec(r1);
c1 = Byte2Dec(c1);
train_imgs = cell(n1,1);
for i = 1:n1
   temp = fread(fid3,r1*c1);
   temp = reshape(temp,[r1,c1]);
   train_imgs{i} = temp';   % transpose: IDX stores pixels row-major
end
fclose(fid3);
fid4 = fopen(files{4});     % train-labels-idx1-ubyte
m2 = fread(fid4,4);         % magic number
n2 = fread(fid4,4);         % number of labels
m2 = Byte2Dec(m2);
n2 = Byte2Dec(n2);
train_labels = zeros(n2,1);
for i = 1:n2
   train_labels(i) = fread(fid4,1);
end
fclose(fid4);
for index = 1:n2            % 60000 training images
   img = uint8(train_imgs{index});  % cast to uint8 so imwrite treats values as 0-255
   label = num2str(train_labels(index));
   folder = fullfile('D:\mnist','traindata',label);
   if ~exist(folder,'dir'), mkdir(folder); end   % create the class folder if needed
   imwrite(img,fullfile(folder,['img',label,num2str(index),'.png']));
end
Save as a standard .mat file (variables)
train_labels = categorical(train_labels);
test_labels = categorical(test_labels);
save minist.mat train_imgs train_labels test_imgs test_labels
That is how to import a ready-made dataset file into MATLAB; it is not dataset creation in the usual sense of running experiments and labeling your own data.
Note: imwrite errors if a destination folder does not exist, so make sure the class folders under D:\mnist exist before writing — create them in advance (manually or with mkdir), or adjust the paths to a location on MATLAB's search path.
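The class folders used by imwrite above (one per digit under D:\mnist) can be created up front with a loop over the ten labels; a small Python sketch of the same idea (the root path here is a temporary stand-in):

```python
import os
import tempfile

root = tempfile.mkdtemp()  # stand-in for D:\mnist
for split in ("traindata", "testdata"):
    for label in range(10):  # one folder per digit class 0-9
        os.makedirs(os.path.join(root, split, str(label)), exist_ok=True)

print(os.path.isdir(os.path.join(root, "testdata", "7")))  # True
```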

Importing the dataset (it is already in the workspace from the steps above; here we assume we are starting fresh from just the .mat file)
load minist.mat
traindata = table(train_imgs,train_labels);
testdata = table(test_imgs,test_labels);
Building the network, training the model, and making predictions
layers = [
? ? imageInputLayer([28,28,1])
? ? convolution2dLayer(3,16,'Padding','same')
? ? batchNormalizationLayer
? ? reluLayer
? ? maxPooling2dLayer(2,'Stride',2)
? ? fullyConnectedLayer(10)
? ? softmaxLayer
? ? classificationLayer
];
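For reference, the shapes through this stack: 'same' padding keeps the 3×3 convolution output at 28×28×16, and the 2×2 max pooling with stride 2 halves the spatial size to 14×14×16, so the fully connected layer sees 14·14·16 = 3136 features. A quick arithmetic check:

```python
h = w = 28         # input image size
channels = 16      # filters in the convolution layer
# 'same' padding leaves the spatial size unchanged by the 3x3 convolution
h, w = h // 2, w // 2   # 2x2 max pooling with stride 2 halves each dimension
features = h * w * channels
print(features)  # 3136
```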
options = trainingOptions('adam',...
    'ExecutionEnvironment', 'auto', ...   % uses a GPU automatically if one is available
? ? 'InitialLearnRate',0.01,...
? ? 'MiniBatchSize',100,...
? ? 'MaxEpochs',2,...
? ? 'Shuffle','every-epoch',...
? ? 'ValidationData',testdata,...
? ? 'ValidationFrequency',50,...
? ? 'Verbose',false,...
? ? 'Plots','training-progress');
net = trainNetwork(traindata,layers,options);

pred_labels = classify(net,table(test_imgs));
accuracy = sum(pred_labels == test_labels)/length(test_labels)

plotconfusion(test_labels,pred_labels)  % works, but not recommended for categorical labels

figure
cm = confusionchart(test_labels,pred_labels);  % recommended instead
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';
cm.Title = 'MNIST Confusion Matrix';

As the chart above shows, even this simple network reaches about 97% classification accuracy. That is decent but not outstanding: typical CNNs exceed 99% on image classification tasks like this one, and the CNN trained with PyTorch in the previous post is an example.