最美情侣中文字幕电影,在线麻豆精品传媒,在线网站高清黄,久久黄色视频

歡迎光臨散文網(wǎng) 會員登陸 & 注冊

【模型+代碼/保姆級教程】使用Pytorch實(shí)現(xiàn)手寫漢字識別

2023-03-20 19:43 作者:蟈總  | 我要投稿

前言

參考文章:

最初參考的兩篇:

【Pytorch】基于CNN手寫漢字的識別:https://blog.csdn.net/weixin_44403922/article/details/104451698

「Pytorch」CNN實(shí)現(xiàn)手寫漢字識別(數(shù)據(jù)集制作,網(wǎng)絡(luò)搭建,訓(xùn)練驗證測試全部代碼):https://blog.csdn.net/qq_31417941/article/details/97915035

模型:

EfficientNetV2網(wǎng)絡(luò)詳解:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/pytorch_classification/Test11_efficientnetV2

數(shù)據(jù)集(不必從這里下載,可以看一下它的介紹):

CASIA Online and Offline Chinese Handwriting Databases:http://www.nlpr.ia.ac.cn/databases/handwriting/Home.html

鑒于已經(jīng)3202年了,GPT4都出來了,網(wǎng)上還是缺乏漢字識別這種“底層”基礎(chǔ)神經(jīng)網(wǎng)絡(luò)的能讓新手直接上手跑通的手把手教程,我就斗膽自己寫一篇好了。

本文的主要特點(diǎn):

  1. 使用EfficientNetV2模型真正實(shí)現(xiàn)3755類漢字識別

  2. 項目開源

  3. 預(yù)訓(xùn)練模型公開

  4. 預(yù)制數(shù)據(jù)集,無需處理直接使用


數(shù)據(jù)集

使用中科院制作的手寫漢字?jǐn)?shù)據(jù)集,鏈接直達(dá)官網(wǎng),所以我這里不多介紹,只有滿腔敬意。

上面參考的博客可能要你自己下載之后按照它的辦法再預(yù)處理一下,但是在這個環(huán)節(jié)出現(xiàn)問題的朋友挺多,本著保姆級教程教程的原則,我把預(yù)處理的數(shù)據(jù)已經(jīng)傳到北航云盤(https://bhpan.buaa.edu.cn:443/link/C2E69919DF187EB23C26653A0483D34D)了,速度應(yīng)該比百度網(wǎng)盤快吧,大概…

預(yù)訓(xùn)練模型已經(jīng)上傳了(后面有鏈接),但是如果想自己訓(xùn)一下,就需要下載這個數(shù)據(jù)集,解壓到項目結(jié)構(gòu)里的data文件夾如下所示

data文件夾和log文件夾需要自己建。

項目結(jié)構(gòu)

完整源代碼:【項目源碼】https://github.com/Katock-Cricket/Chinese_Character_Rec

目錄結(jié)構(gòu)

重點(diǎn)注意data文件夾的結(jié)構(gòu),不要把數(shù)據(jù)集放錯位置了或者多嵌套了文件夾


├─Chinese_Character_Rec

│? ?├─asserts

│? ?│? ?├─*.png

│? ?├─char_dict

│? ?├─Data.py

│? ?├─EfficientNetV2

│? ?│? ?├─demo.py

│? ?│? ?├─EffNetV2.py

│? ?│? ?├─Evaluate.py

│? ?│? ?├─model.py

│? ?│? ?└─Train.py

│? ?├─Utils.py

│? ?├─VGG19

│? ?│? ?├─demo.py

│? ?│? ?├─Evaluate.py

│? ?│? ?├─model.py

│? ?│? ?├─Train.py

│? ?│? ?└─VGG19.py

|? ?└─README.md

├─data

│? ?├─test

│? ?│? ?├─00000

│? ?│? ?├─00001

│? ?│? ?├─00002

│? ?│? ?├─00003

│? ? |? ? └─...

│? ?├─test.txt

│? ?├─train

│? ?│? ?├─00000

│? ?│? ?├─00001

│? ?│? ?├─00002

│? ?│? ?├─00003

|? ? |? ?└─ ...

│? ?└─train.txt

├─log

│? ?├─log1.pth

│? ?└─…



神經(jīng)網(wǎng)絡(luò)模型

預(yù)訓(xùn)練模型參數(shù)鏈接(包含vgg19和efficientnetv2)https://bhpan.buaa.edu.cn:443/link/719865B23D5DA304FC491A0A65FE24A3

請將.pth文件重命名為log+數(shù)字.pth的格式,例如log1.pth,放入log文件夾。方便識別和retrain。

VGG19

這里先后用了兩種神經(jīng)網(wǎng)絡(luò),我先用VGG19試了一下,分類前1000種漢字。訓(xùn)得有點(diǎn)慢,主要還是這模型有點(diǎn)老了,參數(shù)量也不小。而且要改到3755類的話還用原參數(shù)的話就很難收斂,也不知道該怎么調(diào)參數(shù)了,估計調(diào)好了也會規(guī)模很大,所以這里VGG19模型的版本只能分類1000種,就是數(shù)據(jù)集的前1000種(準(zhǔn)確率>92%)。

EfficientNetV2

這個模型很不錯,主要是卷積層的部分非常有效,參數(shù)量也很少。直接用small版本去分類3755個漢字,半小時就收斂得差不多了。所以本文用來實(shí)現(xiàn)3755類漢字的模型就是EfficientNetV2(準(zhǔn)確率>89%),后面的教程都是基于這個,VGG19就不管了,在源碼里感興趣的自己看吧。


以下代碼不用自己寫,前面已經(jīng)給出完整源代碼了,下面的教程是結(jié)合源碼的講解而已。

運(yùn)行環(huán)境

顯存>=4G(與batchSize有關(guān),batchSize=512時顯存占用4.8G;如果是256或者128,應(yīng)該會低于4G,雖然會導(dǎo)致訓(xùn)得慢一點(diǎn))

內(nèi)存>=16G(訓(xùn)練時不太占內(nèi)存,但是剛開始加載的時候會突然占一下,如果小于16G還是怕爆)

如果你沒有安裝過Pytorch,啊,我也不知道怎么辦,你要不就看看安裝Pytorch的教程吧。(總體步驟是,有一個不太老的N卡,先去驅(qū)動里看看cuda版本,安裝合適的CUDA,然后根據(jù)CUDA版本去pytorch.org找到合適的安裝指令,然后在本地pip install)

以下是項目運(yùn)行環(huán)境,我是3060 6G,CUDA版本11.6

這個約等號不用在意,可以都安裝最新版本,反正我這里應(yīng)該沒用什么特殊的API

torch~=1.12.1+cu116
torchvision~=0.13.1+cu116
Pillow~=9.3.0


數(shù)據(jù)集準(zhǔn)備

首先定義classes_txt方法在Utils.py中(不是我寫的,是CSDN那兩篇博客的,下同):

生成每張圖片的路徑,存儲到train.txt或test.txt。方便訓(xùn)練或評估時讀取數(shù)據(jù)

def classes_txt(root, out_path, num_class=None):
? ?dirs = os.listdir(root)
? ?if not num_class:
? ? ? ?num_class = len(dirs)

? ?with open(out_path, 'w') as f:
? ? ? ?end = 0
? ? ? ?if end < num_class - 1:
? ? ? ? ? ?dirs.sort()
? ? ? ? ? ?dirs = dirs[end:num_class]
? ? ? ? ? ?for dir1 in dirs:
? ? ? ? ? ? ? ?files = os.listdir(os.path.join(root, dir1))
? ? ? ? ? ? ? ?for file in files:
? ? ? ? ? ? ? ? ? ?f.write(os.path.join(root, dir1, file) + '\n')

定義Dataset類,用于制作數(shù)據(jù)集,為每個圖片加上對應(yīng)的標(biāo)簽,即圖片所在文件夾的代號

class MyDataset(Dataset):
? ?def __init__(self, txt_path, num_class, transforms=None):
? ? ? ?super(MyDataset, self).__init__()
? ? ? ?images = []
? ? ? ?labels = []
? ? ? ?with open(txt_path, 'r') as f:
? ? ? ? ? ?for line in f:
? ? ? ? ? ? ? ?if int(line.split('\\')[1]) >= num_class: # 超出規(guī)定的類,就不添加,例如VGG19只添加了1000類
? ? ? ? ? ? ? ? ? ?break
? ? ? ? ? ? ? ?line = line.strip('\n')
? ? ? ? ? ? ? ?images.append(line)
? ? ? ? ? ? ? ?labels.append(int(line.split('\\')[1]))
? ? ? ?self.images = images
? ? ? ?self.labels = labels
? ? ? ?self.transforms = transforms

? ?def __getitem__(self, index):
? ? ? ?image = Image.open(self.images[index]).convert('RGB')
? ? ? ?label = self.labels[index]
? ? ? ?if self.transforms is not None:
? ? ? ? ? ?image = self.transforms(image)
? ? ? ?return image, label

? ?def __len__(self):
? ? ? ?return len(self.labels)


入口

我把各種超參都放在了args里方便改,請根據(jù)實(shí)際情況自行調(diào)整。這套defaults就是我訓(xùn)練這個模型時使用的超參,圖片size默認(rèn)32是因為我顯存太小辣??!但是數(shù)據(jù)集給的圖片大小普遍不超過64,如果想訓(xùn)得更精確,可以試試64*64的大小。

如果你訓(xùn)練時爆mem,請調(diào)小batch_size,試試256,128,64,32

parser = argparse.ArgumentParser(description='EfficientNetV2 arguments')
parser.add_argument('--mode', dest='mode', type=str, default='demo', help='Mode of net')
parser.add_argument('--epoch', dest='epoch', type=int, default=50, help='Epoch number of training')
parser.add_argument('--batch_size', dest='batch_size', type=int, default=512, help='Value of batch size')
parser.add_argument('--lr', dest='lr', type=float, default=0.0001, help='Value of lr')
parser.add_argument('--img_size', dest='img_size', type=int, default=32, help='reSize of input image')
parser.add_argument('--data_root', dest='data_root', type=str, default='../../data/', help='Path to data')
parser.add_argument('--log_root', dest='log_root', type=str, default='../../log/', help='Path to model.pth')
parser.add_argument('--num_classes', dest='num_classes', type=int, default=3755, help='Classes of character')
parser.add_argument('--demo_img', dest='demo_img', type=str, default='../asserts/fo2.png', help='Path to demo image')
args = parser.parse_args()


if __name__ == '__main__':
? ?if not os.path.exists(args.data_root + 'train.txt'): # 只生成一次
? ? ? ?classes_txt(args.data_root + 'train', args.data_root + 'train.txt', args.num_classes)
? ?if not os.path.exists(args.data_root + 'test.txt'): # 只生成一次
? ? ? ?classes_txt(args.data_root + 'test', args.data_root + 'test.txt', args.num_classes)

? ?if args.mode == 'train':
? ? ? ?train(args)
? ?elif args.mode == 'evaluate':
? ? ? ?evaluate(args)
? ?elif args.mode == 'demo':
? ? ? ?demo(args)
? ?else:
? ? ? ?print('Unknown mode')


訓(xùn)練

在前面CSDN博客的基礎(chǔ)上,增加了lr_scheduler自行調(diào)整學(xué)習(xí)率(如果連續(xù)2個epoch無改進(jìn),就調(diào)小lr到一半),增加了連續(xù)訓(xùn)練的功能:

先在log文件夾下尋找是否存在參數(shù)文件,如果沒有,就認(rèn)為是初次訓(xùn)練;如果有,就找到后綴數(shù)字最大的log.pth,在這個基礎(chǔ)上繼續(xù)訓(xùn)練,并且每訓(xùn)練完一個epoch,就保存最新的log.pth,代號是上一次的+1。這樣可以多次訓(xùn)練,防止訓(xùn)練過程中出錯,參數(shù)文件損壞前功盡棄。

其中has_log_filefind_max_log在Utils.py中有定義。

def train(args):
? ?print("===Train EffNetV2===")
? ?# 歸一化處理,不一定要這樣做,看自己的需求,只是預(yù)訓(xùn)練模型的訓(xùn)練是這樣設(shè)置的
? ?transform = transforms.Compose(
? ? ? ?[transforms.Resize((args.img_size, args.img_size)), transforms.ToTensor(),
? ? ? ? transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
? ? ? ? transforms.ColorJitter()]) ?

? ?train_set = MyDataset(args.data_root + 'train.txt', num_class=args.num_classes, transforms=transform)
? ?train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
? ?device = torch.device('cuda:0')
? ?# 加載模型
? ?model = efficientnetv2_s(num_classes=args.num_classes)
? ?model.to(device)
? ?model.train()
? ?criterion = nn.CrossEntropyLoss()
? ?optimizer = optim.Adam(model.parameters(), lr=args.lr)
? ?# 學(xué)習(xí)率調(diào)整函數(shù),不一定要這樣做,可以自定義
? ?scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)
? ?print("load model...")
? ?
# 加載最近保存了的參數(shù)
? ?if has_log_file(args.log_root):
? ? ? ?max_log = find_max_log(args.log_root)
? ? ? ?print("continue training with " + max_log + "...")
? ? ? ?checkpoint = torch.load(max_log)
? ? ? ?model.load_state_dict(checkpoint['model_state_dict'])
? ? ? ?optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
? ? ? ?loss = checkpoint['loss']
? ? ? ?epoch = checkpoint['epoch'] + 1
? ?else:
? ? ? ?print("train for the first time...")
? ? ? ?loss = 0.0
? ? ? ?epoch = 0

? ?while epoch < args.epoch:
? ? ? ?running_loss = 0.0
? ? ? ?for i, data in enumerate(train_loader):
? ? ? ? ? ?inputs, labels = data[0].to(device), data[1].to(device)
? ? ? ? ? ?optimizer.zero_grad()
? ? ? ? ? ?outs = model(inputs)
? ? ? ? ? ?loss = criterion(outs, labels)
? ? ? ? ? ?loss.backward()
? ? ? ? ? ?optimizer.step()
? ? ? ? ? ?running_loss += loss.item()
? ? ? ? ? ?if i % 200 == 199:
? ? ? ? ? ? ? ?print('epoch %5d: batch: %5d, loss: %8f, lr: %f' % (
? ? ? ? ? ? ? ? ? ?epoch + 1, i + 1, running_loss / 200, optimizer.state_dict()['param_groups'][0]['lr']))
? ? ? ? ? ? ? ?running_loss = 0.0

? ? ? ?scheduler.step(loss)
? ? ? ?# 每個epoch結(jié)束后就保存最新的參數(shù)
? ? ? ?print('Save checkpoint...')
? ? ? ?torch.save({'epoch': epoch,
? ? ? ? ? ? ? ? ? ?'model_state_dict': model.state_dict(),
? ? ? ? ? ? ? ? ? ?'optimizer_state_dict': optimizer.state_dict(),
? ? ? ? ? ? ? ? ? ?'loss': loss},
? ? ? ? ? ? ? ? ? args.log_root + 'log' + str(epoch) + '.pth')
? ? ? ?print('Saved')
? ? ? ?epoch += 1

? ?print('Finish training')


評估

沒什么好說的,就是跑測試集,算總體準(zhǔn)確率。但是有一點(diǎn)不完善,就是看不到每一個類具體的準(zhǔn)確率。我的預(yù)訓(xùn)練模型其實(shí)感覺有幾類是過擬合的,但是我懶得調(diào)整了。

def evaluate(args):
? ?print("===Evaluate EffNetV2===")
? ?# 這個地方要和train一致,不過colorJitter可有可無
? ?transform = transforms.Compose(
? ? ? ?[transforms.Resize((args.img_size, args.img_size)), transforms.ToTensor(),
? ? ? ? transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
? ? ? ? transforms.ColorJitter()])

? ?model = efficientnetv2_s(num_classes=args.num_classes)
? ?model.eval()
? ?if has_log_file(args.log_root):
? ? ? ?file = find_max_log(args.log_root)
? ? ? ?print("Using log file: ", file)
? ? ? ?checkpoint = torch.load(file)
? ? ? ?model.load_state_dict(checkpoint['model_state_dict'])
? ?else:
? ? ? ?print("Warning: No log file")

? ?model.to(torch.device('cuda:0'))
? ?test_loader = DataLoader(MyDataset(args.data_root + 'test.txt', num_class=args.num_classes, transforms=transform),batch_size=args.batch_size, shuffle=False)
? ?total = 0.0
? ?correct = 0.0
? ?print("Evaluating...")
? ?with torch.no_grad():
? ? ? ?for i, data in enumerate(test_loader):
? ? ? ? ? ?inputs, labels = data[0].cuda(), data[1].cuda()
? ? ? ? ? ?outputs = model(inputs)
? ? ? ? ? ?_, predict = torch.max(outputs.data, 1)
? ? ? ? ? ?total += labels.size(0)
? ? ? ? ? ?correct += (predict == labels).sum().item()
? ?acc = correct / total * 100
? ?print('Accuracy'': ', acc, '%')


推理

輸入文字圖片,輸出識別結(jié)果:

其中char_dict就是每個漢字在數(shù)據(jù)集里的代號對應(yīng)的gb2312編碼,這個模型的輸出結(jié)果是它在數(shù)據(jù)集里的代號,所以要查這個char_dict來獲取它對應(yīng)的漢字。

def demo(args):
? ?print('==Demo EfficientNetV2===')
? ?print('Input Image: ', args.demo_img)
? ?# 這個地方要和train一致,不過colorJitter可有可無
? ?transform = transforms.Compose(
? ? ? ?[transforms.Resize((args.img_size, args.img_size)), transforms.ToTensor(),
? ? ? ? transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
? ?img = Image.open(args.demo_img)
? ?img = transform(img)
? ?img = img.unsqueeze(0) # 增維
? ?model = efficientnetv2_s(num_classes=args.num_classes)
? ?model.eval()
? ?if has_log_file(args.log_root):
? ? ? ?file = find_max_log(args.log_root)
? ? ? ?print("Using log file: ", file)
? ? ? ?checkpoint = torch.load(file)
? ? ? ?model.load_state_dict(checkpoint['model_state_dict'])
? ?else:
? ? ? ?print("Warning: No log file")

? ?with torch.no_grad():
? ? ? ?output = model(img)
? ?_, pred = torch.max(output.data, 1)
? ?f = open('../char_dict', 'rb')
? ?dic = pickle.load(f)
? ?for cha in dic:
? ? ? ?if dic[cha] == int(pred):
? ? ? ? ? ?print('predict: ', cha)
? ?f.close()

例如輸入圖片為:

程序運(yùn)行結(jié)果:


其他說明

這個模型我正在嘗試移植到安卓應(yīng)用,因為Pytorch有一套Pytorch for Android,但是現(xiàn)在遇到一個問題,它的bitmap2Tensor函數(shù)內(nèi)部實(shí)現(xiàn)與Pytorch的toTensor()+Normalize()不一樣,導(dǎo)致輸入相同的圖片,轉(zhuǎn)出來的張量是不一樣的,比如我輸入的圖片是白底黑字,白底的部分輸出一樣,但是黑色的部分的數(shù)值出現(xiàn)了偏移,我用的是同一套歸一化參數(shù),不知道這是為什么。然后這個張量的差異就導(dǎo)致安卓端表現(xiàn)很不好,目前正在尋找解決辦法,灰階處理可能是出路?

另外,這個模型對于太細(xì)太黑的字體,準(zhǔn)確度貌似不是很好,可能還是有點(diǎn)過擬合了。建議輸入的圖片與數(shù)據(jù)集的風(fēng)格靠攏,黑色盡量淺一點(diǎn),線不要太細(xì)。

如果還存在疑問可以打在B站專欄的評論區(qū)。差不多就是這些了,傳統(tǒng)功夫宜點(diǎn)到為止,謝謝大家。


【模型+代碼/保姆級教程】使用Pytorch實(shí)現(xiàn)手寫漢字識別的評論 (共 條)

分享到微博請遵守國家法律
茌平县| 临沧市| 彭阳县| 锡林浩特市| 富平县| 博爱县| 德令哈市| 清流县| 龙泉市| 盐山县| 铜陵市| 吉木萨尔县| 镇原县| 佛山市| 天峨县| 镇原县| 海宁市| 读书| 洮南市| 灵台县| 古蔺县| 宣恩县| 曲松县| 静海县| 武陟县| 长宁县| 洛扎县| 广南县| 安溪县| 疏勒县| 张家川| 松原市| 萨迦县| 兴和县| 祁东县| 衡南县| 尉犁县| 吕梁市| 昆明市| 西乌珠穆沁旗| 马关县|