# PyTorch Tutorial 10 - Dataset Transforms

# 教程Python代碼如下 (the tutorial's Python code follows):
"""
epoch = 1 forward and backward pass of ALL training samples
batch_size = number of training samples in one forward & backward pass
number of iterations = number of passes, each pass using [batch_size] number of samples
e.g. 100 samples, batch_size=20 --> 100/20 = 5 iterations for 1 epoch
"""
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math
# 實現自己的自定義數據集 (implement our own custom dataset)
class WineDataset(Dataset):
    """Map-style dataset loaded from the wine CSV file.

    The first CSV column is the target; the remaining columns are the input
    features. Samples are returned as NumPy ``float32`` arrays unless the
    ``tranform`` callable converts them (e.g. to tensors).

    Args:
        tranform: Optional callable applied to each ``(features, target)``
            pair. The misspelled name is kept so existing callers still work.
        path: CSV file to load. Defaults to the original hard-coded location
            ``./Data/wine.csv``.
    """

    def __init__(self, tranform=None, path='./Data/wine.csv'):
        # Data loading. skiprows=1 skips the header line. ndmin=2 keeps a
        # (rows, columns) array even when the file has a single data row;
        # otherwise loadtxt returns 1-D and the column slicing below fails.
        xy = np.loadtxt(path, delimiter=",", dtype=np.float32, skiprows=1, ndmin=2)
        self.n_samples = xy.shape[0]
        # Split into features and target. We deliberately do not convert to
        # tensors here; that is left to the transform.
        self.x = xy[:, 1:]   # every column except the first -> features
        self.y = xy[:, [0]]  # first column only, kept 2-D: (n_samples, 1)
        self.transform = tranform

    def __getitem__(self, index):
        """Return the ``(features, target)`` pair at ``index``, transformed if configured."""
        sample = self.x[index], self.y[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __len__(self):
        """Return the number of data rows loaded from the CSV (``len(dataset)``)."""
        return self.n_samples
# 類方法對類屬性進行的處理是有記憶性的 (in-place changes a method makes to an object's stored data persist across calls)
class ToTensor:
    """Turn an ``(inputs, targets)`` pair of NumPy arrays into torch tensors.

    ``torch.from_numpy`` does not copy: each returned tensor shares memory
    with the array it was built from.
    """

    def __call__(self, sample):
        features, labels = sample
        as_tensor = torch.from_numpy
        return (as_tensor(features), as_tensor(labels))
class MulTransform:
    """Scale the features of an ``(inputs, target)`` pair by a constant factor.

    Args:
        factor: Multiplier applied to ``inputs``; ``target`` is passed
            through unchanged.
    """

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, sample):
        inputs, target = sample
        # Multiply out of place. ``inputs`` may share storage with the
        # dataset: a row of a 2-D NumPy array is a view, and
        # torch.from_numpy reuses that memory. An in-place ``*=`` would
        # silently rescale the stored data, compounding on every access.
        return inputs * self.factor, target
def _inspect_first(ds):
    """Fetch ``ds[0]``, print its features and the types of both parts, and return it."""
    pair = ds[0]
    x, y = pair
    print(x)
    print(type(x), type(y))
    return pair


# ToTensor only: the sample comes back as a pair of torch tensors.
first_data = _inspect_first(WineDataset(tranform=ToTensor()))

# No transform: the sample stays a pair of NumPy arrays.
first_data = _inspect_first(WineDataset(tranform=None))

print("\n" + "composed")
# ToTensor followed by MulTransform(4): tensors with the features scaled by 4.
composed = torchvision.transforms.Compose([ToTensor(), MulTransform(4)])
dataset = WineDataset(tranform=composed)
first_data = _inspect_first(dataset)
feautres, labels = first_data