最美情侣中文字幕电影,在线麻豆精品传媒,在线网站高清黄,久久黄色视频

歡迎光臨散文網(wǎng) 會員登陸 & 注冊

Pytorch將數(shù)據(jù)集劃分為訓練集、驗證集和測試集

2020-08-26 16:46 作者:肆十二-  | 我要投稿

Pytorch將數(shù)據(jù)集劃分為訓練集、驗證集和測試集

我們可以借助Pytorch從文件夾中讀取數(shù)據(jù)集,十分方便,但是Pytorch中沒有提供數(shù)據(jù)集劃分的操作,需要手動將原始的數(shù)據(jù)集劃分為訓練集、驗證集和測試集,廢話不多說,這里我寫了一個工具類,幫助大家將數(shù)據(jù)集自動劃分為訓練集、驗證集和測試集,還可以指定比例,代碼如下。

# 工具類
import os
import random
import shutil
from shutil import copy2


def data_set_split(src_data_folder, target_data_folder, train_scale=0.8, val_scale=0.1, test_scale=0.1):
? ?'''
? ?讀取源數(shù)據(jù)文件夾,生成劃分好的文件夾,分為trian、val、test三個文件夾進行
? ?:param src_data_folder: 源文件夾 E:/biye/gogogo/note_book/torch_note/data/utils_test/data_split/src_data
? ?:param target_data_folder: 目標文件夾 E:/biye/gogogo/note_book/torch_note/data/utils_test/data_split/target_data
? ?:param train_scale: 訓練集比例
? ?:param val_scale: 驗證集比例
? ?:param test_scale: 測試集比例
? ?:return:
? ?'''
? ?print("開始數(shù)據(jù)集劃分")
? ?class_names = os.listdir(src_data_folder)
? ?# 在目標目錄下創(chuàng)建文件夾
? ?split_names = ['train', 'val', 'test']
? ?for split_name in split_names:
? ? ? ?split_path = os.path.join(target_data_folder, split_name)
? ? ? ?if os.path.isdir(split_path):
? ? ? ? ? ?pass
? ? ? ?else:
? ? ? ? ? ?os.mkdir(split_path)
? ? ? ?# 然后在split_path的目錄下創(chuàng)建類別文件夾
? ? ? ?for class_name in class_names:
? ? ? ? ? ?class_split_path = os.path.join(split_path, class_name)
? ? ? ? ? ?if os.path.isdir(class_split_path):
? ? ? ? ? ? ? ?pass
? ? ? ? ? ?else:
? ? ? ? ? ? ? ?os.mkdir(class_split_path)

? ?# 按照比例劃分數(shù)據(jù)集,并進行數(shù)據(jù)圖片的復制
? ?# 首先進行分類遍歷
? ?for class_name in class_names:
? ? ? ?current_class_data_path = os.path.join(src_data_folder, class_name)
? ? ? ?current_all_data = os.listdir(current_class_data_path)
? ? ? ?current_data_length = len(current_all_data)
? ? ? ?current_data_index_list = list(range(current_data_length))
? ? ? ?random.shuffle(current_data_index_list)

? ? ? ?train_folder = os.path.join(os.path.join(target_data_folder, 'train'), class_name)
? ? ? ?val_folder = os.path.join(os.path.join(target_data_folder, 'val'), class_name)
? ? ? ?test_folder = os.path.join(os.path.join(target_data_folder, 'test'), class_name)
? ? ? ?train_stop_flag = current_data_length * train_scale
? ? ? ?val_stop_flag = current_data_length * (train_scale + val_scale)
? ? ? ?current_idx = 0
? ? ? ?train_num = 0
? ? ? ?val_num = 0
? ? ? ?test_num = 0
? ? ? ?for i in current_data_index_list:
? ? ? ? ? ?src_img_path = os.path.join(current_class_data_path, current_all_data[i])
? ? ? ? ? ?if current_idx <= train_stop_flag:
? ? ? ? ? ? ? ?copy2(src_img_path, train_folder)
? ? ? ? ? ? ? ?# print("{}復制到了{}".format(src_img_path, train_folder))
? ? ? ? ? ? ? ?train_num = train_num + 1
? ? ? ? ? ?elif (current_idx > train_stop_flag) and (current_idx <= val_stop_flag):
? ? ? ? ? ? ? ?copy2(src_img_path, val_folder)
? ? ? ? ? ? ? ?# print("{}復制到了{}".format(src_img_path, val_folder))
? ? ? ? ? ? ? ?val_num = val_num + 1
? ? ? ? ? ?else:
? ? ? ? ? ? ? ?copy2(src_img_path, test_folder)
? ? ? ? ? ? ? ?# print("{}復制到了{}".format(src_img_path, test_folder))
? ? ? ? ? ? ? ?test_num = test_num + 1

? ? ? ? ? ?current_idx = current_idx + 1

? ? ? ?print("*********************************{}*************************************".format(class_name))
? ? ? ?print(
? ? ? ? ? ?"{}類按照{(diào)}:{}:{}的比例劃分完成,一共{}張圖片".format(class_name, train_scale, val_scale, test_scale, current_data_length))
? ? ? ?print("訓練集{}:{}張".format(train_folder, train_num))
? ? ? ?print("驗證集{}:{}張".format(val_folder, val_num))
? ? ? ?print("測試集{}:{}張".format(test_folder, test_num))


if __name__ == '__main__':
? ?src_data_folder = "E:/biye/gogogo/note_book/torch_note/data/utils_test/data_split/src_data"
? ?target_data_folder = "E:/biye/gogogo/note_book/torch_note/data/utils_test/data_split/target_data"
? ?data_set_split(src_data_folder, target_data_folder)

** 注意 **

劃分前你得文件夾結(jié)構(gòu)應該是這樣的

image-20200826160553697

劃分結(jié)果

data_split

tensorflow2.3 加載數(shù)據(jù)集的方式

from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential


def load_data_from_folder(batch_size, target_img_height, target_img_width, data_dir="F:/datas/massmass/fer2013+/ccc/"):
? ?train_datagen = ImageDataGenerator(
? ? ? ?rescale=1. / 255, ?# 重放縮因子,數(shù)值乘以1.0/255(歸一化)
? ? ? ?shear_range=0.2, ?# 剪切強度(逆時針方向的剪切變換角度)
? ? ? ?zoom_range=0.2, ?# 隨機縮放的幅度
? ? ? ?# 進行隨機水平翻轉(zhuǎn)
? ? ? ?horizontal_flip=True)
? ?val_datagen = ImageDataGenerator(
? ? ? ?rescale=1. / 255)

? ?train_generator = train_datagen.flow_from_directory(
? ? ? ?data_dir + '/train', ?# dictory參數(shù),該路徑下的所有子文件夾的圖片都會被生成器使用,無限產(chǎn)生batch數(shù)據(jù)
? ? ? ?target_size=(target_img_height, target_img_width), ?# 圖片將被resize成該尺寸
? ? ? ?color_mode='grayscale', ?# 顏色模式,graycsale或rgb(默認rgb)
? ? ? ?batch_size=batch_size, ?# batch數(shù)據(jù)的大小,默認為32
? ? ? ?class_mode='sparse') ?# 返回的標簽形式,默認為‘category’,返回2D的獨熱碼標簽
? ?val_generator = val_datagen.flow_from_directory(
? ? ? ?data_dir + '/val', ?# 同上
? ? ? ?target_size=(target_img_height, target_img_width),
? ? ? ?color_mode='grayscale',
? ? ? ?batch_size=batch_size,
? ? ? ?class_mode='sparse')
? ?num_class = train_generator.num_classes
? ?return train_generator, val_generator, num_class

tensorflow2.0 加載數(shù)據(jù)集的方式

from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

def load_data_from_folder(batch_size, target_img_height, target_img_width, data_dir="data/"):
? ?emotion_classification_train_datagen = ImageDataGenerator(
? ? ? ?rescale=1. / 255, ?# 重放縮因子,數(shù)值乘以1.0/255(歸一化)
? ? ? ?shear_range=0.2, ?# 剪切強度(逆時針方向的剪切變換角度)
? ? ? ?zoom_range=0.2, ?# 隨機縮放的幅度
? ? ? ?# 進行隨機水平翻轉(zhuǎn)
? ? ? ?horizontal_flip=True)
? ?emotion_classification_val_datagen = ImageDataGenerator(
? ? ? ?rescale=1. / 255)

? ?emotion_classification_train_generator = emotion_classification_train_datagen.flow_from_directory(
? ? ? ?data_dir + '/train', ?# dictory參數(shù),該路徑下的所有子文件夾的圖片都會被生成器使用,無限產(chǎn)生batch數(shù)據(jù)
? ? ? ?target_size=(target_img_height, target_img_width), ?# 圖片將被resize成該尺寸
? ? ? ?color_mode='grayscale', ?# 顏色模式,graycsale或rgb(默認rgb)
? ? ? ?batch_size=batch_size, ?# batch數(shù)據(jù)的大小,默認為32
? ? ? ?class_mode='sparse') ?# 返回的標簽形式,默認為‘category’,返回2D的獨熱碼標簽
? ?emotion_classification_val_generator = emotion_classification_val_datagen.flow_from_directory(
? ? ? ?data_dir + '/val', ?# 同上
? ? ? ?target_size=(target_img_height, target_img_width),
? ? ? ?color_mode='grayscale',
? ? ? ?batch_size=batch_size,
? ? ? ?class_mode='sparse')
? ?num_class = emotion_classification_train_generator.num_classes
? ?return emotion_classification_train_generator, emotion_classification_val_generator, num_class


pytorch加載數(shù)據(jù)集的方式

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
? ?'train': transforms.Compose([
? ? ? ?transforms.RandomResizedCrop(224),
? ? ? ?transforms.RandomHorizontalFlip(),
? ? ? ?transforms.ToTensor(),
? ? ? ?transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
? ?]),
? ?'val': transforms.Compose([
? ? ? ?transforms.Resize(256),
? ? ? ?transforms.CenterCrop(224),
? ? ? ?transforms.ToTensor(),
? ? ? ?transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
? ?]),
}

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?data_transforms[x])
? ? ? ? ? ? ? ? ?for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? shuffle=True, num_workers=4)
? ? ? ? ? ? ?for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


最后附上github地址

https://github.com/cmFighting/mnist_demo_torch1.6



Pytorch將數(shù)據(jù)集劃分為訓練集、驗證集和測試集的評論 (共 條)

分享到微博請遵守國家法律
石狮市| 邯郸市| 遵义县| 玉山县| 潞城市| 太仓市| 双桥区| 隆尧县| 寻乌县| 南丰县| 科技| 太谷县| 岑巩县| 洱源县| 固阳县| 高安市| 横山县| 朝阳市| 新干县| 门头沟区| 嘉荫县| 积石山| 道孚县| 临夏市| 绥宁县| 蕉岭县| 秭归县| 商水县| 方正县| 蒙自县| 昭平县| 海阳市| 陆河县| 利津县| 米脂县| 灵台县| 土默特右旗| 瓦房店市| 贡山| 增城市| 磐安县|