# PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

'''
# -*-coding:utf-8 -*-
# author:sakia
from torch.utils.data import Dataset
import CV2
from PIL import Image
import os
class MyData(Dataset):
    """Map-style dataset whose images and labels live in sibling folders.

    Expected layout under ``root_dir``::

        root_dir/img_dir/     image files (opened with PIL)
        root_dir/label_dir/   one text file per image, holding its label

    Example: ``MyData("dataset/train", "ants_image", "ants_label")``.

    Images and labels are paired purely by position: index ``i`` returns the
    i-th image together with the i-th label file. ``os.listdir`` returns
    entries in arbitrary, filesystem-dependent order, so both listings are
    sorted to make that pairing deterministic. The pairing is only correct
    when each image has exactly one label file whose name sorts to the same
    position (e.g. ``0013035.jpg`` / ``0013035.txt``) -- presumably the
    dataset is prepared that way; verify when adding new data.

    Args:
        root_dir: Folder containing both ``img_dir`` and ``label_dir``
            (relative paths resolve against the current working directory).
        img_dir: Name of the image sub-folder inside ``root_dir``.
        label_dir: Name of the label sub-folder inside ``root_dir``.
    """

    def __init__(self, root_dir, img_dir, label_dir):
        self.root_dir = root_dir
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.img_dir)
        self.lpath = os.path.join(self.root_dir, self.label_dir)
        # Sort both listings: raw os.listdir order is arbitrary, and an
        # unsorted pair of lists would silently attach labels to the wrong
        # images.
        self.img_path = sorted(os.listdir(self.path))
        self.label_path = sorted(os.listdir(self.lpath))

    def __getitem__(self, idx):
        """Return ``(img, label)`` for position ``idx``.

        ``img`` is the result of ``Image.open`` on the idx-th image file;
        ``label`` is the idx-th label file's text with surrounding
        whitespace stripped.
        """
        img_name = self.img_path[idx]
        img = Image.open(os.path.join(self.path, img_name))
        label_name = self.label_path[idx]
        # Explicit encoding: the platform default (e.g. cp936 on a Chinese
        # Windows install) would otherwise decide how label text is decoded.
        with open(os.path.join(self.lpath, label_name), "r", encoding="utf-8") as f:
            label = f.read().strip()
        return img, label

    def __len__(self):
        """Return the number of entries listed in the image folder."""
        return len(self.img_path)
# Dataset root, relative to the working directory. It is expected to contain
# train/ants_image, train/ants_label, train/bees_image and train/bees_label.
# os.path.join keeps the path portable instead of hard-coding a backslash.
root_dir = os.path.join("练手数据集", "train")
ants_img_dir = "ants_image"
ants_label_dir = "ants_label"
bees_img_dir = "bees_image"
bees_label_dir = "bees_label"

# One dataset per class; images and labels are stored in separate folders.
ants_dataset = MyData(root_dir, ants_img_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_img_dir, bees_label_dir)

# Dataset.__add__ returns a ConcatDataset, so "+" chains the per-class
# datasets into a single training set (all ants first, then all bees).
# Handy when assembling a dataset by hand from several parts.
train_dataset = ants_dataset + bees_dataset
img, label = train_dataset[1]
p7 标签和图像分开的代码
'''