

MMNet micro-expression recognition (CASME2 dataset)

2023-06-24 18:22 Author: 感覺__站不如油管

Original GitHub repository: https://github.com/muse1998/MMNet

Both the code and the dataset have some problems; the code only runs after the modifications below. The files involved are:

main.py

CA_block.py

PC_module.py


main.py

# -*- coding: utf-8 -*-
import torch
import math
import numpy as np
import torchvision.models
import torch.utils.data as data
from torchvision import transforms
import cv2
import pandas as pd
import os
import torch.nn as nn
#import image_utils
import argparse, random
from functools import partial

from CA_block import resnet18_pos_attention  # CA_block.py is assumed to sit next to main.py, like PC_module.py

from PC_module import VisionTransformer_POS

from torchvision.transforms import Resize
torch.set_printoptions(precision=3, edgeitems=14, linewidth=350)



def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--raf_path', type=str, default='D:/CASME2/', help='CASME2 dataset root path.')
    parser.add_argument('--checkpoint', type=str, default='D:/CASME2/',
                        help='PyTorch checkpoint file path')
    parser.add_argument('--pretrained', type=str, default=None,
                        help='Pretrained weights')
    parser.add_argument('--beta', type=float, default=0.7, help='Ratio of high importance group in one mini-batch.')
    parser.add_argument('--relabel_epoch', type=int, default=1000,
                        help='Relabel samples in each mini-batch after this many epochs.')
    parser.add_argument('--batch_size', type=int, default=34, help='Batch size.')
    parser.add_argument('--optimizer', type=str, default="adam", help='Optimizer, adam or sgd.')
    parser.add_argument('--lr', type=float, default=0.0001, help='Initial learning rate for sgd.')
    parser.add_argument('--momentum', default=0.9, type=float, help='Momentum for sgd.')
    parser.add_argument('--workers', default=0, type=int, help='Number of data loading workers.')
    parser.add_argument('--epochs', type=int, default=1000, help='Total training epochs.')
    parser.add_argument('--drop_rate', type=float, default=0, help='Drop out rate.')
    return parser.parse_args()






class RafDataSet(data.Dataset):
? ?def __init__(self, raf_path, phase,num_loso, transform = None, basic_aug = False, transform_norm=None):
? ? ? ?self.phase = phase
? ? ? ?self.transform = transform
? ? ? ?self.raf_path = raf_path
? ? ? ?self.transform_norm = transform_norm
? ? ? ?SUBJECT_COLUMN =0
? ? ? ?NAME_COLUMN = 1
? ? ? ?ONSET_COLUMN = 2
? ? ? ?APEX_COLUMN = 3
? ? ? ?OFF_COLUMN = 4
? ? ? ?LABEL_AU_COLUMN = 5
? ? ? ?LABEL_ALL_COLUMN = 6


? ? ? ?df = pd.read_excel(os.path.join(self.raf_path, 'CASME2-coding-20140508.xlsx'),usecols=[0,1,3,4,5,7,8])
? ? ? ?df['Subject'] = df['Subject'].apply(str)

? ? ? ?if phase == 'train':
? ? ? ? ? ?dataset = df.loc[df['Subject']!=num_loso]
? ? ? ?else:
? ? ? ? ? ?dataset = df.loc[df['Subject'] == num_loso]

? ? ? ?Subject = dataset.iloc[:, SUBJECT_COLUMN].values
? ? ? ?File_names = dataset.iloc[:, NAME_COLUMN].values
? ? ? ?Label_all = dataset.iloc[:, LABEL_ALL_COLUMN].values ?# 0:Surprise, 1:Fear, 2:Disgust, 3:Happiness, 4:Sadness, 5:Anger, 6:Neutral
? ? ? ?Onset_num = dataset.iloc[:, ONSET_COLUMN].values
? ? ? ?Apex_num = dataset.iloc[:, APEX_COLUMN].values
? ? ? ?Offset_num = dataset.iloc[:, OFF_COLUMN].values
? ? ? ?Label_au = dataset.iloc[:, LABEL_AU_COLUMN].values
? ? ? ?self.file_paths_on = []
? ? ? ?self.file_paths_off = []
? ? ? ?self.file_paths_apex = []
? ? ? ?self.label_all = []
? ? ? ?self.label_au = []
? ? ? ?self.sub= []
? ? ? ?self.file_names =[]
? ? ? ?a=0
? ? ? ?b=0
? ? ? ?c=0
? ? ? ?d=0
? ? ? ?e=0
? ? ? ?# use aligned images for training/testing
? ? ? ?for (f,sub,onset,apex,offset,label_all,label_au) in zip(File_names,Subject,Onset_num,Apex_num,Offset_num,Label_all,Label_au):


? ? ? ? ? ?if label_all == 'happiness' or label_all == 'repression' or label_all == 'disgust' or label_all == 'surprise' or label_all == 'fear' or label_all == 'sadness':

? ? ? ? ? ? ? ?self.file_paths_on.append(onset)
? ? ? ? ? ? ? ?self.file_paths_off.append(offset)
? ? ? ? ? ? ? ?self.file_paths_apex.append(apex)
? ? ? ? ? ? ? ?self.sub.append(sub)
? ? ? ? ? ? ? ?self.file_names.append(f)
? ? ? ? ? ? ? ?if label_all == 'happiness':
? ? ? ? ? ? ? ? ? ?self.label_all.append(0)
? ? ? ? ? ? ? ? ? ?a=a+1
? ? ? ? ? ? ? ?elif label_all == 'surprise':
? ? ? ? ? ? ? ? ? ?self.label_all.append(1)
? ? ? ? ? ? ? ? ? ?b=b+1
? ? ? ? ? ? ? ?else:
? ? ? ? ? ? ? ? ? ?self.label_all.append(2)
? ? ? ? ? ? ? ? ? ?c=c+1

? ? ? ? ? ?# label_au =label_au.split("+")
? ? ? ? ? ? ? ?if isinstance(label_au, int):
? ? ? ? ? ? ? ? ? ?self.label_au.append([label_au])
? ? ? ? ? ? ? ?else:
? ? ? ? ? ? ? ? ? ?label_au = label_au.split("+")
? ? ? ? ? ? ? ? ? ?self.label_au.append(label_au)






? ? ? ? ? ?##label

? ? ? ?self.basic_aug = basic_aug
? ? ? ?#self.aug_func = [image_utils.flip_image,image_utils.add_gaussian_noise]

? ?def __len__(self):
? ? ? ?return len(self.file_paths_on)

? ?def __getitem__(self, idx):
? ? ? ?##sampling strategy for training set
? ? ? ?if self.phase == 'train':
? ? ? ? ? ?onset = self.file_paths_on[idx]
? ? ? ? ? ?#onset = onset.astype('int64')
? ? ? ? ? ?apex = self.file_paths_apex[idx]
? ? ? ? ? ?#apex = apex.astype('int64')
? ? ? ? ? ?offset =self.file_paths_off[idx]
? ? ? ? ? ?#offset = offset.astype('int64')

? ? ? ? ? ?on0 = str(random.randint(int(onset), int(onset + int(0.2* (int(apex) - int(onset)) / 4))))
? ? ? ? ? ?# on0 = str(int(onset))
? ? ? ? ? ?on1 = str(
? ? ? ? ? ? ? ?random.randint(int(onset + int(0.9 * (apex - onset) / 4)), int(onset + int(1.1 * (apex - onset) / 4))))
? ? ? ? ? ?on2 = str(
? ? ? ? ? ? ? ?random.randint(int(onset + int(1.8 * (apex - onset) / 4)), int(onset + int(2.2 * (apex - onset) / 4))))
? ? ? ? ? ?on3 = str(random.randint(int(onset + int(2.7 * (apex - onset) / 4)), onset + int(3.3 * (apex - onset) / 4)))
? ? ? ? ? ?# apex0 = str(apex)
? ? ? ? ? ?apex0 = str(
? ? ? ? ? ? ? ?random.randint(int(apex - int(0.15* (apex - onset) / 4)), apex + int(0.15 * (offset - apex) / 4)))
? ? ? ? ? ?off0 = str(
? ? ? ? ? ? ? ?random.randint(int(apex + int(0.9 * (offset - apex) / 4)), int(apex + int(1.1 * (offset - apex) / 4))))
? ? ? ? ? ?off1 = str(
? ? ? ? ? ? ? ?random.randint(int(apex + int(1.8 * (offset - apex) / 4)), int(apex + int(2.2 * (offset - apex) / 4))))
? ? ? ? ? ?off2 = str(
? ? ? ? ? ? ? ?random.randint(int(apex + int(2.9 * (offset - apex) / 4)), int(apex + int(3.1 * (offset - apex) / 4))))
? ? ? ? ? ?off3 = str(random.randint(int(apex + int(3.8 * (offset - apex) / 4)), offset))



? ? ? ? ? ?sub =str(self.sub[idx])
? ? ? ? ? ?f = str(self.file_names[idx])
? ? ? ?else:##sampling strategy for testing set
? ? ? ? ? ?onset = self.file_paths_on[idx]
? ? ? ? ? ?apex = self.file_paths_apex[idx]
? ? ? ? ? ?offset = self.file_paths_off[idx]

? ? ? ? ? ?on0 = str(onset)
? ? ? ? ? ?on1 = str(int(onset + int((apex - onset) / 4)))
? ? ? ? ? ?on2 = str(int(onset + int(2 * (apex - onset) / 4)))
? ? ? ? ? ?on3 = str(int(onset + int(3 * (apex - onset) / 4)))
? ? ? ? ? ?apex0 = str(apex)
? ? ? ? ? ?off0 = str(int(apex + int((offset - apex) / 4)))
? ? ? ? ? ?off1 = str(int(apex + int(2 * (offset - apex) / 4)))
? ? ? ? ? ?off2 = str(int(apex + int(3 * (offset - apex) / 4)))
? ? ? ? ? ?off3 = str(offset)

? ? ? ? ? ?sub = str(self.sub[idx])
? ? ? ? ? ?f = str(self.file_names[idx])
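        # Frame-sampling note (added for clarity): the training branch above jitters the sampled
        # frame indices around four roughly evenly spaced points between onset and apex and four
        # between apex and offset, which acts as temporal data augmentation; the test branch uses
        # the deterministic quarter points instead. For example, with onset=10, apex=30, offset=70
        # the test-time indices are 10, 15, 20, 25, 30, 40, 50, 60, 70.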


? ? ? ?on0 ='reg_img' + on0 + '.jpg'
? ? ? ?on1 = 'reg_img' + on1 + '.jpg'
? ? ? ?on2 = 'reg_img' + on2 + '.jpg'
? ? ? ?on3 = 'reg_img' + on3 + '.jpg'
? ? ? ?apex0 ='reg_img' + apex0 + '.jpg'
? ? ? ?off0 ='reg_img' + off0 + '.jpg'
? ? ? ?off1='reg_img' + off1 + '.jpg'
? ? ? ?off2 ='reg_img' + off2 + '.jpg'
? ? ? ?off3 = 'reg_img' + off3 + '.jpg'
? ? ? ?path_on0 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, on0).replace('\\', '/')
? ? ? ?path_on1 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, on1).replace('\\', '/')
? ? ? ?path_on2 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, on2).replace('\\', '/')
? ? ? ?path_on3 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, on3).replace('\\', '/')
? ? ? ?path_apex0 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, apex0).replace('\\', '/')
? ? ? ?path_off0 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, off0).replace('\\', '/')
? ? ? ?path_off1 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, off1).replace('\\', '/')
? ? ? ?path_off2 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, off2).replace('\\', '/')
? ? ? ?path_off3 = os.path.join(self.raf_path, 'Cropped-updated/Cropped/', 'sub'+sub, f, off3).replace('\\', '/')
? ? ? ?"""
? ? ? ?print(path_on0)
? ? ? ?print(path_on1)
? ? ? ?print(path_on2)
? ? ? ?print(path_on3)
? ? ? ?print(path_apex0)
? ? ? ?print(path_off0)
? ? ? ?print(path_off1)
? ? ? ?print(path_off2)
? ? ? ?print(path_off3)
? ? ? ?"""

        image_on0 = cv2.imread(path_on0)
        image_on1 = cv2.imread(path_on1)
        image_on2 = cv2.imread(path_on2)
        image_on3 = cv2.imread(path_on3)
        image_apex0 = cv2.imread(path_apex0)
        image_off0 = cv2.imread(path_off0)
        image_off1 = cv2.imread(path_off1)
        image_off2 = cv2.imread(path_off2)
        image_off3 = cv2.imread(path_off3)

? ? ? ?image_on0 = image_on0[:, :, ::-1] # BGR to RGB
? ? ? ?image_on1 = image_on1[:, :, ::-1]
? ? ? ?image_on2 = image_on2[:, :, ::-1]
? ? ? ?image_on3 = image_on3[:, :, ::-1]
? ? ? ?image_off0 = image_off0[:, :, ::-1]
? ? ? ?image_off1 = image_off1[:, :, ::-1]
? ? ? ?image_off2 = image_off2[:, :, ::-1]
? ? ? ?image_off3 = image_off3[:, :, ::-1]
? ? ? ?image_apex0 = image_apex0[:, :, ::-1]

? ? ? ?label_all = self.label_all[idx]
? ? ? ?label_au = self.label_au[idx]

? ? ? ?# normalization for testing and training
? ? ? ?if self.transform is not None:
? ? ? ? ? ?image_on0 = self.transform(image_on0)
? ? ? ? ? ?image_on1 = self.transform(image_on1)
? ? ? ? ? ?image_on2 = self.transform(image_on2)
? ? ? ? ? ?image_on3 = self.transform(image_on3)
? ? ? ? ? ?image_off0 = self.transform(image_off0)
? ? ? ? ? ?image_off1 = self.transform(image_off1)
? ? ? ? ? ?image_off2 = self.transform(image_off2)
? ? ? ? ? ?image_off3 = self.transform(image_off3)
? ? ? ? ? ?image_apex0 = self.transform(image_apex0)
? ? ? ? ? ?ALL = torch.cat(
? ? ? ? ? ? ? ?(image_on0, image_on1, image_on2, image_on3, image_apex0, image_off0, image_off1, image_off2,
? ? ? ? ? ? ? ? image_off3), dim=0)
? ? ? ? ? ?## data augmentation for training only
? ? ? ? ? ?if self.transform_norm is not None and self.phase == 'train':
? ? ? ? ? ? ? ?ALL = self.transform_norm(ALL)
? ? ? ? ? ?image_on0 = ALL[0:3, :, :]
? ? ? ? ? ?image_on1 = ALL[3:6, :, :]
? ? ? ? ? ?image_on2 = ALL[6:9, :, :]
? ? ? ? ? ?image_on3 = ALL[9:12, :, :]
? ? ? ? ? ?image_apex0 = ALL[12:15, :, :]
? ? ? ? ? ?image_off0 = ALL[15:18, :, :]
? ? ? ? ? ?image_off1 = ALL[18:21, :, :]
? ? ? ? ? ?image_off2 = ALL[21:24, :, :]
? ? ? ? ? ?image_off3 = ALL[24:27, :, :]


? ? ? ? ? ?temp = torch.zeros(38)
? ? ? ? ? ?for i in label_au:
? ? ? ? ? ? ? ?#print(i)
? ? ? ? ? ? ? ?temp[int(i) - 1] = 1

? ? ? ? ? ?return image_on0, image_on1, image_on2, image_on3, image_apex0, image_off0, image_off1, image_off2, image_off3, label_all, temp
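
# AU-label note (added for clarity): `temp` above is a 38-dimensional multi-hot vector over action
# units; an Excel entry such as "4+7" is split on "+" and sets temp[3] and temp[6] to 1. The AU
# vector is returned alongside the 3-class emotion label but is not used by the loss in this script.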


def initialize_weight_goog(m, n=''):
? ?if isinstance(m, nn.Conv2d):
? ? ? ?fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
? ? ? ?m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
? ? ? ?if m.bias is not None:
? ? ? ? ? ?m.bias.data.zero_()
? ?elif isinstance(m, nn.BatchNorm2d):
? ? ? ?m.weight.data.fill_(1.0)
? ? ? ?m.bias.data.zero_()
? ?elif isinstance(m, nn.Linear):
? ? ? ?fan_out = m.weight.size(0) ?# fan-out
? ? ? ?fan_in = 0
? ? ? ?if 'routing_fn' in n:
? ? ? ? ? ?fan_in = m.weight.size(1)
? ? ? ?init_range = 1.0 / math.sqrt(fan_in + fan_out)
? ? ? ?m.weight.data.uniform_(-init_range, init_range)
? ? ? ?m.bias.data.zero_()


def criterion2(y_pred, y_true):
? ?y_pred = (1 - 2 * y_true) * y_pred
? ?y_pred_neg = y_pred - y_true * 1e12
? ?y_pred_pos = y_pred - (1 - y_true) * 1e12
? ?zeros = torch.zeros_like(y_pred[..., :1])
? ?y_pred_neg = torch.cat((y_pred_neg, zeros), dim=-1)
? ?y_pred_pos = torch.cat((y_pred_pos, zeros), dim=-1)
? ?neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
? ?pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
? ?return torch.mean(neg_loss + pos_loss)
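
# criterion2 is a multi-label softmax-style loss in the logsumexp form (sometimes referred to as
# ZLPR): positive-class logits are pushed above 0 and negative-class logits below 0 without a
# per-class threshold. It expects raw logits and a multi-hot target of the same shape, e.g.
#   logits = torch.randn(8, 38); target = torch.zeros(8, 38); target[:, [3, 6]] = 1
#   loss = criterion2(logits, target)
# It is defined here for AU supervision but is never called in the training loop below.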


class MMNet(nn.Module):
? ?def __init__(self):
? ? ? ?super(MMNet, self).__init__()


? ? ? ?self.conv_act = nn.Sequential(
? ? ? ? ? ?nn.Conv2d(in_channels=3, out_channels=90*2, kernel_size=3, stride=2,padding=1, bias=False,groups=1),#group=2
? ? ? ? ? ?nn.BatchNorm2d(180),
? ? ? ? ? ?nn.ReLU(inplace=True),
? ? ? ? ? ?)
? ? ? ?self.pos =nn.Sequential(
? ? ? ? ? ?nn.Conv2d(in_channels=3, out_channels=512, kernel_size=1, stride=1, bias=False),
? ? ? ? ? ?nn.BatchNorm2d(512),
? ? ? ? ? ?nn.ReLU(inplace=True),

? ? ? ? ? ?)
? ? ? ?##Position Calibration Module(subbranch)
? ? ? ?self.vit_pos=VisionTransformer_POS(img_size=14,
? ? ? ?patch_size=1, embed_dim=512, depth=3, num_heads=4, mlp_ratio=2, qkv_bias=True,norm_layer=partial(nn.LayerNorm, eps=1e-6),drop_path_rate=0.3)
? ? ? ?self.resize=Resize([14,14])
? ? ? ?##main branch consisting of CA blocks
? ? ? ?self.main_branch =resnet18_pos_attention()
? ? ? ?self.head1 = nn.Sequential(
? ? ? ? ? ?nn.Dropout(p=0.5),
? ? ? ? ? ?nn.Linear(1 * 112 *112, 38,bias=False),

? ? ? ?)

? ? ? ?self.timeembed = nn.Parameter(torch.zeros(1, 4, 111, 111))

? ? ? ?self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
? ?def forward(self, x1, x2, x3, x4, x5, x6, x7, x8, x9, if_shuffle):
? ? ? ?##onset:x1 apex:x5
? ? ? ?B = x1.shape[0]

? ? ? ?#Position Calibration Module (subbranch)
? ? ? ?POS =self.vit_pos(self.resize(x1)).transpose(1,2).view(B,512,14,14)
? ? ? ?act = x5 -x1
? ? ? ?act=self.conv_act(act)
? ? ? ?#main branch and fusion
? ? ? ?out,_=self.main_branch(act,POS)

? ? ? ?return out
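
    # Shape walk-through (added for clarity, assuming the 224x224 inputs produced by the transforms):
    #   x1 (onset) and x5 (apex) are (B, 3, 224, 224); the remaining frames and `if_shuffle`
    #   are accepted but unused in this forward pass.
    #   POS: x1 resized to 14x14 -> VisionTransformer_POS -> (B, 196, 512) -> reshaped to (B, 512, 14, 14)
    #   act: (x5 - x1) -> conv_act (stride 2) -> (B, 180, 112, 112)
    #   main_branch downsamples 112 -> 14, adds POS, and applies its final linear head.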





def run_training():

? ?args = parse_args()
    imagenet_pretrained = True  # whether to load pretrained weights

    if not imagenet_pretrained:
        # NOTE: `res18` is never defined in this script (a leftover from the original repo);
        # keep imagenet_pretrained=True, or define that backbone before using this branch.
        for m in res18.modules():
            initialize_weight_goog(m)

? ?if args.pretrained:
? ? ? ?print("Loading pretrained weights...", args.pretrained)
? ? ? ?pretrained = torch.load(args.pretrained)
? ? ? ?pretrained_state_dict = pretrained['state_dict']
? ? ? ?model_state_dict = res18.state_dict()
? ? ? ?loaded_keys = 0
? ? ? ?total_keys = 0
? ? ? ?for key in pretrained_state_dict:
? ? ? ? ? ?if ((key == 'module.fc.weight') | (key == 'module.fc.bias')):
? ? ? ? ? ? ? ?pass
? ? ? ? ? ?else:
? ? ? ? ? ? ? ?model_state_dict[key] = pretrained_state_dict[key]
? ? ? ? ? ? ? ?total_keys += 1
? ? ? ? ? ? ? ?if key in model_state_dict:
? ? ? ? ? ? ? ? ? ?loaded_keys += 1
? ? ? ?print("Loaded params num:", loaded_keys)
? ? ? ?print("Total params num:", total_keys)
? ? ? ?res18.load_state_dict(model_state_dict, strict=False)
    ### data normalization for the training set
? ?data_transforms = transforms.Compose([
? ? ? ?transforms.ToPILImage(),
? ? ? ?transforms.Resize((224, 224)),

? ? ? ?transforms.ToTensor(),
? ? ? ?transforms.Normalize(mean=[0.485, 0.456, 0.406],
? ? ? ? ? ? ? ? ? ? ? ? ? ? std=[0.229, 0.224, 0.225]),

? ?])
? ?### data augmentation for training set only
? ?data_transforms_norm = transforms.Compose([

? ? ? ?transforms.RandomHorizontalFlip(p=0.5),
? ? ? ?transforms.RandomRotation(4),
? ? ? ?transforms.RandomCrop(224, padding=4),


? ?])


    ### data normalization for the testing set
? ?data_transforms_val = transforms.Compose([
? ? ? ?transforms.ToPILImage(),
? ? ? ?transforms.Resize((224, 224)),
? ? ? ?transforms.ToTensor(),
? ? ? ?transforms.Normalize(mean=[0.485, 0.456, 0.406],
? ? ? ? ? ? ? ? ? ? ? ? ? ? std=[0.229, 0.224, 0.225])])



? ?criterion = torch.nn.CrossEntropyLoss()
    # leave-one-subject-out (LOSO) protocol
? ?LOSO = ['17', '26', '16', '9', '5', '24', '2', '13', '4', '23', '11', '12', '8', '14', '3', '19', '1', '10',
? ? ? ? ? ?'20', '21', '22', '15', '6', '25', '7']

? ?val_now = 0
? ?num_sum = 0
? ?pos_pred_ALL = torch.zeros(3)
? ?pos_label_ALL = torch.zeros(3)
? ?TP_ALL = torch.zeros(3)

? ?for subj in LOSO:
? ? ? ?train_dataset = RafDataSet(args.raf_path, phase='train', num_loso=subj, transform=data_transforms,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? basic_aug=True, transform_norm=data_transforms_norm)
? ? ? ?val_dataset = RafDataSet(args.raf_path, phase='test', num_loso=subj, transform=data_transforms_val)
? ? ? ?train_loader = torch.utils.data.DataLoader(train_dataset,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? batch_size=24,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? num_workers=args.workers,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? shuffle=True,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? pin_memory=True)
? ? ? ?val_loader = torch.utils.data.DataLoader(val_dataset,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? batch_size=24,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? num_workers=args.workers,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? shuffle=False,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? pin_memory=True)
? ? ? ?print('num_sub', subj)
? ? ? ?print('Train set size:', train_dataset.__len__())
? ? ? ?print('Validation set size:', val_dataset.__len__())

? ? ? ?max_corr = 0
? ? ? ?max_f1 = 0
? ? ? ?max_pos_pred = torch.zeros(3)
? ? ? ?max_pos_label = torch.zeros(3)
? ? ? ?max_TP = torch.zeros(3)
? ? ? ?##model initialization
? ? ? ?net_all = MMNet()

? ? ? ?params_all = net_all.parameters()

        if args.optimizer == 'adam':
            optimizer_all = torch.optim.AdamW(params_all, lr=0.0008, weight_decay=0.7)
            ## optimizer for MMNet
        elif args.optimizer == 'sgd':
            # the original branch used undefined names (params, optimizer); use the MMNet parameters here
            optimizer_all = torch.optim.SGD(params_all, args.lr,
                                            momentum=args.momentum,
                                            weight_decay=1e-4)
        else:
            raise ValueError("Optimizer not supported.")
? ? ? ?##lr_decay
? ? ? ?scheduler_all = torch.optim.lr_scheduler.ExponentialLR(optimizer_all, gamma=0.987)

? ? ? ?net_all = net_all.cuda()

? ? ? ?for i in range(1, 100):
? ? ? ? ? ?running_loss = 0.0
? ? ? ? ? ?correct_sum = 0
? ? ? ? ? ?running_loss_MASK = 0.0
? ? ? ? ? ?correct_sum_MASK = 0
? ? ? ? ? ?iter_cnt = 0

? ? ? ? ? ?net_all.train()


? ? ? ? ? ?for batch_i, (
? ? ? ? ? ?image_on0, image_on1, image_on2, image_on3, image_apex0, image_off0, image_off1, image_off2, image_off3,
? ? ? ? ? ?label_all,
? ? ? ? ? ?label_au) in enumerate(train_loader):
? ? ? ? ? ? ? ?batch_sz = image_on0.size(0)
? ? ? ? ? ? ? ?b, c, h, w = image_on0.shape
? ? ? ? ? ? ? ?iter_cnt += 1

? ? ? ? ? ? ? ?image_on0 = image_on0.cuda()
? ? ? ? ? ? ? ?image_on1 = image_on1.cuda()
? ? ? ? ? ? ? ?image_on2 = image_on2.cuda()
? ? ? ? ? ? ? ?image_on3 = image_on3.cuda()
? ? ? ? ? ? ? ?image_apex0 = image_apex0.cuda()
? ? ? ? ? ? ? ?image_off0 = image_off0.cuda()
? ? ? ? ? ? ? ?image_off1 = image_off1.cuda()
? ? ? ? ? ? ? ?image_off2 = image_off2.cuda()
? ? ? ? ? ? ? ?image_off3 = image_off3.cuda()
? ? ? ? ? ? ? ?label_all = label_all.cuda()
? ? ? ? ? ? ? ?label_au = label_au.cuda()


? ? ? ? ? ? ? ?##train MMNet
? ? ? ? ? ? ? ?ALL = net_all(image_on0, image_on1, image_on2, image_on3, image_apex0, image_off0, image_off1,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? image_off2, image_off3, False)

? ? ? ? ? ? ? ?loss_all = criterion(ALL, label_all)

? ? ? ? ? ? ? ?optimizer_all.zero_grad()

? ? ? ? ? ? ? ?loss_all.backward()

? ? ? ? ? ? ? ?optimizer_all.step()
? ? ? ? ? ? ? ?running_loss += loss_all
? ? ? ? ? ? ? ?_, predicts = torch.max(ALL, 1)
? ? ? ? ? ? ? ?correct_num = torch.eq(predicts, label_all).sum()
? ? ? ? ? ? ? ?correct_sum += correct_num






? ? ? ? ? ?## lr decay
? ? ? ? ? ?if i <= 50:

? ? ? ? ? ? ? ?scheduler_all.step()
? ? ? ? ? ?if i>=0:
? ? ? ? ? ? ? ?acc = correct_sum.float() / float(train_dataset.__len__())

? ? ? ? ? ? ? ?running_loss = running_loss / iter_cnt

? ? ? ? ? ? ? ?print('[Epoch %d] Training accuracy: %.4f. Loss: %.3f' % (i, acc, running_loss))


? ? ? ? ? ?pos_label = torch.zeros(3)
? ? ? ? ? ?pos_pred = torch.zeros(3)
? ? ? ? ? ?TP = torch.zeros(3)
? ? ? ? ? ?##test
? ? ? ? ? ?with torch.no_grad():
? ? ? ? ? ? ? ?running_loss = 0.0
? ? ? ? ? ? ? ?iter_cnt = 0
? ? ? ? ? ? ? ?bingo_cnt = 0
? ? ? ? ? ? ? ?sample_cnt = 0
? ? ? ? ? ? ? ?pre_lab_all = []
? ? ? ? ? ? ? ?Y_test_all = []
? ? ? ? ? ? ? ?net_all.eval()
? ? ? ? ? ? ? ?# net_au.eval()
? ? ? ? ? ? ? ?for batch_i, (
? ? ? ? ? ? ? ?image_on0, image_on1, image_on2, image_on3, image_apex0, image_off0, image_off1, image_off2,
? ? ? ? ? ? ? ?image_off3, label_all,
? ? ? ? ? ? ? ?label_au) in enumerate(val_loader):
? ? ? ? ? ? ? ? ? ?batch_sz = image_on0.size(0)
? ? ? ? ? ? ? ? ? ?b, c, h, w = image_on0.shape

? ? ? ? ? ? ? ? ? ?image_on0 = image_on0.cuda()
? ? ? ? ? ? ? ? ? ?image_on1 = image_on1.cuda()
? ? ? ? ? ? ? ? ? ?image_on2 = image_on2.cuda()
? ? ? ? ? ? ? ? ? ?image_on3 = image_on3.cuda()
? ? ? ? ? ? ? ? ? ?image_apex0 = image_apex0.cuda()
? ? ? ? ? ? ? ? ? ?image_off0 = image_off0.cuda()
? ? ? ? ? ? ? ? ? ?image_off1 = image_off1.cuda()
? ? ? ? ? ? ? ? ? ?image_off2 = image_off2.cuda()
? ? ? ? ? ? ? ? ? ?image_off3 = image_off3.cuda()
? ? ? ? ? ? ? ? ? ?label_all = label_all.cuda()
? ? ? ? ? ? ? ? ? ?label_au = label_au.cuda()

? ? ? ? ? ? ? ? ? ?##test
? ? ? ? ? ? ? ? ? ?ALL = net_all(image_on0, image_on1, image_on2, image_on3, image_apex0, image_off0, image_off1, image_off2, image_off3, False)


? ? ? ? ? ? ? ? ? ?loss = criterion(ALL, label_all)
? ? ? ? ? ? ? ? ? ?running_loss += loss
? ? ? ? ? ? ? ? ? ?iter_cnt += 1
? ? ? ? ? ? ? ? ? ?_, predicts = torch.max(ALL, 1)
? ? ? ? ? ? ? ? ? ?correct_num = torch.eq(predicts, label_all)
? ? ? ? ? ? ? ? ? ?bingo_cnt += correct_num.sum().cpu()
? ? ? ? ? ? ? ? ? ?sample_cnt += ALL.size(0)

? ? ? ? ? ? ? ? ? ?for cls in range(3):

? ? ? ? ? ? ? ? ? ? ? ?for element in predicts:
? ? ? ? ? ? ? ? ? ? ? ? ? ?if element == cls:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?pos_label[cls] = pos_label[cls] + 1
? ? ? ? ? ? ? ? ? ? ? ?for element in label_all:
? ? ? ? ? ? ? ? ? ? ? ? ? ?if element == cls:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?pos_pred[cls] = pos_pred[cls] + 1
? ? ? ? ? ? ? ? ? ? ? ?for elementp, elementl in zip(predicts, label_all):
? ? ? ? ? ? ? ? ? ? ? ? ? ?if elementp == elementl and elementp == cls:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?TP[cls] = TP[cls] + 1

? ? ? ? ? ? ? ? ? ?count = 0
? ? ? ? ? ? ? ? ? ?SUM_F1 = 0
? ? ? ? ? ? ? ? ? ?for index in range(3):
? ? ? ? ? ? ? ? ? ? ? ?if pos_label[index] != 0 or pos_pred[index] != 0:
? ? ? ? ? ? ? ? ? ? ? ? ? ?count = count + 1
? ? ? ? ? ? ? ? ? ? ? ? ? ?SUM_F1 = SUM_F1 + 2 * TP[index] / (pos_pred[index] + pos_label[index])

? ? ? ? ? ? ? ? ? ?AVG_F1 = SUM_F1 / count


? ? ? ? ? ? ? ?running_loss = running_loss / iter_cnt
? ? ? ? ? ? ? ?acc = bingo_cnt.float() / float(sample_cnt)
? ? ? ? ? ? ? ?acc = np.around(acc.numpy(), 4)
? ? ? ? ? ? ? ?if bingo_cnt > max_corr:
? ? ? ? ? ? ? ? ? ?max_corr = bingo_cnt
? ? ? ? ? ? ? ?if AVG_F1 >= max_f1:
? ? ? ? ? ? ? ? ? ?max_f1 = AVG_F1
? ? ? ? ? ? ? ? ? ?max_pos_label = pos_label
? ? ? ? ? ? ? ? ? ?max_pos_pred = pos_pred
? ? ? ? ? ? ? ? ? ?max_TP = TP
? ? ? ? ? ? ? ?print("[Epoch %d] Validation accuracy:%.4f. Loss:%.3f, F1-score:%.3f" % (i, acc, running_loss, AVG_F1))
? ? ? ?num_sum = num_sum + max_corr
? ? ? ?pos_label_ALL = pos_label_ALL + max_pos_label
? ? ? ?pos_pred_ALL = pos_pred_ALL + max_pos_pred
? ? ? ?TP_ALL = TP_ALL + max_TP
? ? ? ?count = 0
? ? ? ?SUM_F1 = 0
? ? ? ?for index in range(3):
? ? ? ? ? ?if pos_label_ALL[index] != 0 or pos_pred_ALL[index] != 0:
? ? ? ? ? ? ? ?count = count + 1
? ? ? ? ? ? ? ?SUM_F1 = SUM_F1 + 2 * TP_ALL[index] / (pos_pred_ALL[index] + pos_label_ALL[index])

? ? ? ?F1_ALL = SUM_F1 / count
? ? ? ?val_now = val_now + val_dataset.__len__()
        print("[..........%s] correct num: %d  total: %d" % (subj, max_corr, val_dataset.__len__()))
? ? ? ?print("[ALL_corr]: %d [ALL_val]: %d" % (num_sum, val_now))
? ? ? ?print("[F1_now]: %.4f [F1_ALL]: %.4f" % (max_f1, F1_ALL))


if __name__ == "__main__":
? ?run_training()
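
A quick way to sanity-check the model before launching the full LOSO loop is to run it once on random tensors. The snippet below is my own addition rather than part of the original repo; it assumes main.py, CA_block.py and PC_module.py sit in the same folder and that torch, timm, cv2 and pandas are installed (they are imported at the top of main.py), and it runs on the CPU.

# smoke_test.py -- minimal sketch (my addition), CPU only
import torch
from main import MMNet

model = MMNet().eval()
frames = [torch.randn(2, 3, 224, 224) for _ in range(9)]   # onset ... offset placeholder frames
with torch.no_grad():
    out = model(*frames, False)   # only the 1st (onset) and 5th (apex) tensors are actually used
print(out.shape)                  # expected: torch.Size([2, 5]); the fc head in CA_block.py has 5 outputs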

CA_block.py


# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url  # needed by _resnet() when pretrained=True

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
? ? ? ? ? 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
? ? ? ? ? 'wide_resnet50_2', 'wide_resnet101_2']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
? ?"""3x3 convolution with padding"""
? ?return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
? ? ? ? ? ? ? ? ? ? padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1, groups=1):
? ?"""1x1 convolution"""
? ?return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False,groups=groups)

##CA BLOCK
class CABlock(nn.Module):
? ?expansion = 1

? ?def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
? ? ? ? ? ? ? ? base_width=64, dilation=1, norm_layer=None):
? ? ? ?super(CABlock, self).__init__()
? ? ? ?if norm_layer is None:
? ? ? ? ? ?norm_layer = nn.BatchNorm2d
? ? ? ?# if groups != 1 or base_width != 64:
? ? ? ?# ? ? raise ValueError('BasicBlock only supports groups=1 and base_width=64')
? ? ? ?if dilation > 1:
? ? ? ? ? ?raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
? ? ? ?# Both self.conv1 and self.downsample layers downsample the input when stride != 1
? ? ? ?self.conv1 = conv3x3(inplanes, planes, stride,groups=groups)
? ? ? ?self.bn1 = norm_layer(planes)
? ? ? ?self.relu = nn.ReLU(inplace=True)
? ? ? ?self.conv2 = conv1x1(planes, planes,groups=groups)
? ? ? ?self.bn2 = norm_layer(planes)
? ? ? ?self.attn = nn.Sequential(
? ? ? ? ? ?nn.Conv2d(2, 1, kernel_size=1, stride=1,bias=False), ?# 32*33*33
? ? ? ? ? ?nn.BatchNorm2d(1),
? ? ? ? ? ?nn.Sigmoid(),
? ? ? ?)
? ? ? ?self.downsample = downsample
? ? ? ?self.stride = stride
? ? ? ?self.planes=planes

? ?def forward(self, x):
? ? ? ?x, attn_last,if_attn =x##attn_last: downsampled attention maps from last layer as a prior knowledge
? ? ? ?identity = x

? ? ? ?out = self.conv1(x)
? ? ? ?out = self.bn1(out)

? ? ? ?out = self.relu(out)

? ? ? ?out = self.conv2(out)
? ? ? ?out = self.bn2(out)
? ? ? ?if self.downsample is not None:
? ? ? ? ? ?identity = self.downsample(identity)

? ? ? ?out = self.relu(out+identity)
? ? ? ?avg_out = torch.mean(out, dim=1, keepdim=True)
? ? ? ?max_out, _ = torch.max(out, dim=1, keepdim=True)
? ? ? ?attn = torch.cat((avg_out, max_out), dim=1)
? ? ? ?attn = self.attn(attn)
? ? ? ?if attn_last is not None:
? ? ? ? ? ?attn = attn_last * attn

? ? ? ?attn = attn.repeat(1, self.planes, 1, 1)
? ? ? ?if if_attn:
? ? ? ? ? ?out = out *attn


? ? ? ?return out,attn[:, 0, :, :].unsqueeze(1),True
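
# CA-block attention note (added for clarity): the block builds a 1-channel spatial map by
# concatenating the channel-wise mean and max of `out` (shape (B, 2, H, W)), passing them through
# a 1x1 conv + sigmoid, and multiplying by `attn_last`, the downsampled map inherited from the
# previous layer, so earlier layers act as a spatial prior for later ones. The map is then
# broadcast over all channels and applied to `out` when `if_attn` is True.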





class ResNet(nn.Module):

? ?def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
? ? ? ? ? ? ? ? groups=4, width_per_group=64, replace_stride_with_dilation=None,
? ? ? ? ? ? ? ? norm_layer=None):
? ? ? ?super(ResNet, self).__init__()
? ? ? ?if norm_layer is None:
? ? ? ? ? ?norm_layer = nn.BatchNorm2d
? ? ? ?self._norm_layer = norm_layer

? ? ? ?self.inplanes = 128
? ? ? ?self.dilation = 1
? ? ? ?if replace_stride_with_dilation is None:
? ? ? ? ? ?# each element in the tuple indicates if we should replace
? ? ? ? ? ?# the 2x2 stride with a dilated convolution instead
? ? ? ? ? ?replace_stride_with_dilation = [False, False, False]
? ? ? ?if len(replace_stride_with_dilation) != 3:
? ? ? ? ? ?raise ValueError("replace_stride_with_dilation should be None "
? ? ? ? ? ? ? ? ? ? ? ? ? ? "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
? ? ? ?self.groups = groups
? ? ? ?self.base_width = width_per_group
? ? ? ?self.conv1 = nn.Conv2d(90*2, self.inplanes, kernel_size=3, stride=1,padding=1,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? bias=False,groups=1)
? ? ? ?self.bn1 = norm_layer(self.inplanes)
? ? ? ?self.relu = nn.ReLU(inplace=True)
? ? ? ?self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2,padding=1)
? ? ? ?self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
? ? ? ?self.layer1 = self._make_layer(block, 128, layers[0],groups=1)
? ? ? ?self.inplanes = int(self.inplanes*1)
? ? ? ?self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? dilate=replace_stride_with_dilation[0],groups=1)
? ? ? ?self.inplanes = int(self.inplanes * 1)

? ? ? ?self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? dilate=replace_stride_with_dilation[1],groups=1)
? ? ? ?self.inplanes = int(self.inplanes * 1)

? ? ? ?self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? dilate=replace_stride_with_dilation[2],groups=1)
? ? ? ?self.inplanes = int(self.inplanes * 1)





? ? ? ?self.fc = nn.Linear(512* block.expansion*196, 5)
? ? ? ?self.drop = nn.Dropout(p=0.1)
? ? ? ?for m in self.modules():
? ? ? ? ? ?if isinstance(m, nn.Conv2d):
? ? ? ? ? ? ? ?nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
? ? ? ? ? ?elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
? ? ? ? ? ? ? ?nn.init.constant_(m.weight, 1)
? ? ? ? ? ? ? ?nn.init.constant_(m.bias, 0)

? ? ? ?# Zero-initialize the last BN in each residual branch,
? ? ? ?# so that the residual branch starts with zeros, and each residual block behaves like an identity.
? ? ? ?# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
? ? ? ?if zero_init_residual:
? ? ? ? ? ?for m in self.modules():
? ? ? ? ? ? ? ?if isinstance(m, Bottleneck):
? ? ? ? ? ? ? ? ? ?nn.init.constant_(m.bn3.weight, 0)
? ? ? ? ? ? ? ?elif isinstance(m, BasicBlock):
? ? ? ? ? ? ? ? ? ?nn.init.constant_(m.bn2.weight, 0)

? ?def _make_layer(self, block, planes, blocks, stride=1, dilate=False,groups=1):
? ? ? ?norm_layer = self._norm_layer
? ? ? ?downsample = None
? ? ? ?previous_dilation = self.dilation
? ? ? ?if dilate:
? ? ? ? ? ?self.dilation *= stride
? ? ? ? ? ?stride = 1
? ? ? ?if stride != 1 or self.inplanes != planes * block.expansion:
? ? ? ? ? ?downsample = nn.Sequential(
? ? ? ? ? ? ? ?conv1x1(self.inplanes, planes * block.expansion, stride),
? ? ? ? ? ? ? ?norm_layer(planes * block.expansion),
? ? ? ? ? ?)

? ? ? ?layers = []
? ? ? ?layers.append(block(self.inplanes, planes, stride, downsample, groups,
? ? ? ? ? ? ? ? ? ? ? ? ? ?self.base_width, previous_dilation, norm_layer))
? ? ? ?self.inplanes = planes * block.expansion
? ? ? ?for _ in range(1, blocks):
? ? ? ? ? ?layers.append(block(self.inplanes, planes, groups=self.groups,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?base_width=self.base_width, dilation=self.dilation,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?norm_layer=norm_layer))

? ? ? ?return nn.Sequential(*layers)

? ?def _forward_impl(self, x,POS):##x->input of main branch; POS->position embeddings generated by sub branch

? ? ? ?x = self.conv1(x)
? ? ? ?x = self.bn1(x)
? ? ? ?x = self.relu(x)
? ? ? ?##main branch
? ? ? ?x,attn1,_ = self.layer1((x,None,True))
? ? ? ?temp = attn1
? ? ? ?attn1 = self.maxpool(attn1)

? ? ? ?x ,attn2,_= self.layer2((x,attn1,True))


? ? ? ?attn2=self.maxpool(attn2)

? ? ? ?x ,attn3,_= self.layer3((x,attn2,True))
? ? ? ?#
? ? ? ?attn3 = self.maxpool(attn3)
? ? ? ?x,attn4,_ = self.layer4((x,attn3,True))

? ? ? ?x=x+POS#fusion of motion pattern feature and position embeddings

? ? ? ?x = torch.flatten(x, 1)

? ? ? ?x = self.fc(x)

? ? ? ?return x,temp.view(x.size(0),-1)

? ?def forward(self, x,POS):
? ? ? ?return self._forward_impl(x,POS)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
? ?model = ResNet(block, layers, **kwargs)
? ?if pretrained:
? ? ? ?state_dict = load_state_dict_from_url(model_urls[arch],
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?progress=progress)
? ? ? ?model.load_state_dict(state_dict)
? ?return model

##main branch consisting of CA blocks
def resnet18_pos_attention(pretrained=False, progress=True, **kwargs):
? ?r"""ResNet-18 model from
? ?`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

? ?Args:
? ? ? ?pretrained (bool): If True, returns a model pre-trained on ImageNet
? ? ? ?progress (bool): If True, displays a progress bar of the download to stderr
? ?"""
? ?return _resnet('resnet18', CABlock, [1, 1, 1, 1], pretrained, progress,
? ? ? ? ? ? ? ? ? **kwargs)
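
A matching shape check for the main branch alone (again my own sketch, not from the repo): feed it the motion feature that MMNet.conv_act would produce, plus a position embedding of the expected size.

# sketch (my addition): tensor contract of resnet18_pos_attention
import torch
from CA_block import resnet18_pos_attention

branch = resnet18_pos_attention().eval()
act = torch.randn(2, 180, 112, 112)   # output of MMNet.conv_act (apex - onset difference)
pos = torch.randn(2, 512, 14, 14)     # position embeddings from the PC sub-branch
with torch.no_grad():
    logits, attn = branch(act, pos)
print(logits.shape, attn.shape)       # expected: (2, 5) and (2, 12544), i.e. a flattened 112x112 attention map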


PC_module.py

# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import logging
import collections.abc
from functools import partial
from collections import OrderedDict
from itertools import repeat

from timm.models.vision_transformer import _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

def drop_path(x, drop_prob: float = 0., training: bool = False):
? ?"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

? ?This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
? ?the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
? ?See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
? ?changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
? ?'survival rate' as the argument.

? ?"""
? ?if drop_prob == 0. or not training:
? ? ? ?return x
? ?keep_prob = 1 - drop_prob
? ?shape = (x.shape[0],) + (1,) * (x.ndim - 1) ?# work with diff dim tensors, not just 2D ConvNets
? ?random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
? ?random_tensor.floor_() ?# binarize
? ?output = x.div(keep_prob) * random_tensor
? ?return output
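
# Stochastic-depth note (added for clarity): with drop_prob=0.3 each sample's residual branch is
# zeroed with probability 0.3 during training and the surviving samples are rescaled by 1/0.7,
# so the expected output matches evaluation mode, where drop_path is a no-op.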


class DropPath(nn.Module):
? ?"""Drop paths (Stochastic Depth) per sample ?(when applied in main path of residual blocks).
? ?"""
? ?def __init__(self, drop_prob=None):
? ? ? ?super(DropPath, self).__init__()
? ? ? ?self.drop_prob = drop_prob

? ?def forward(self, x):
? ? ? ?return drop_path(x, self.drop_prob, self.training)
def _ntuple(n):
? ?def parse(x):
? ? ? ?if isinstance(x, collections.abc.Iterable):
? ? ? ? ? ?return x
? ? ? ?return tuple(repeat(x, n))
? ?return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
__all__ = [
? ?'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
? ?'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
? ?'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
? ?'deit_base_distilled_patch16_384',
]

class Mlp(nn.Module):
? ?def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
? ? ? ?super().__init__()
? ? ? ?out_features = out_features or in_features
? ? ? ?hidden_features = hidden_features or in_features
? ? ? ?self.fc1 = nn.Linear(in_features, hidden_features)
? ? ? ?self.act = act_layer()
? ? ? ?self.fc2 = nn.Linear(hidden_features, out_features)
? ? ? ?self.drop = nn.Dropout(drop)

? ?def forward(self, x):
? ? ? ?x = self.fc1(x)
? ? ? ?x = self.act(x)
? ? ? ?x = self.drop(x)
? ? ? ?x = self.fc2(x)
? ? ? ?x = self.drop(x)
? ? ? ?return x


class Attention(nn.Module):
? ?def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
? ? ? ?super().__init__()
? ? ? ?self.num_heads = num_heads
? ? ? ?head_dim = dim // num_heads
? ? ? ?# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
? ? ? ?self.scale = qk_scale or head_dim ** -0.5

? ? ? ?self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
? ? ? ?self.attn_drop = nn.Dropout(attn_drop)
? ? ? ?self.proj = nn.Linear(dim, dim)
? ? ? ?self.proj_drop = nn.Dropout(proj_drop)

? ?def forward(self, x):
? ? ? ?B, N, C = x.shape
? ? ? ?qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
? ? ? ?q, k, v = qkv[0], qkv[1], qkv[2] ? # make torchscript happy (cannot use tensor as tuple)
? ? ? ?varq = torch.var(q, dim=2).sum(dim=2).sum()/B/N
? ? ? ?vark = torch.var(k, dim=2).sum(dim=2).sum()/B/N
? ? ? ?varv = torch.var(v, dim=2).sum(dim=2).sum()/B/N
? ? ? ?attn = (q @ k.transpose(-2, -1)) * self.scale
? ? ? ?attn = attn.softmax(dim=-1)
? ? ? ?attn = self.attn_drop(attn)

? ? ? ?x = (attn @ v).transpose(1, 2).reshape(B, N, C)
? ? ? ?x = self.proj(x)
? ? ? ?x = self.proj_drop(x)
? ? ? ?return x
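
# Note (added for clarity): varq / vark / varv above are per-head variance statistics that are
# computed but never returned or used, so they do not affect the attention output.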


class Block(nn.Module):

? ?def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
? ? ? ? ? ? ? ? drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
? ? ? ?super().__init__()
? ? ? ?self.norm1 = norm_layer(dim)
? ? ? ?self.attn = Attention(
? ? ? ? ? ?dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
? ? ? ?# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
? ? ? ?self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
? ? ? ?self.norm2 = norm_layer(dim)
? ? ? ?mlp_hidden_dim = int(dim * mlp_ratio)
? ? ? ?self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

? ?def forward(self, x):
? ? ? ?x = x + self.drop_path(self.attn(self.norm1(x)))
? ? ? ?x = x + self.drop_path(self.mlp(self.norm2(x)))
? ? ? ?return x


class PatchEmbed(nn.Module):
? ?""" Image to Patch Embedding
? ?"""
? ?def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
? ? ? ?super().__init__()
? ? ? ?img_size = to_2tuple(img_size)
? ? ? ?patch_size = to_2tuple(patch_size)
? ? ? ?num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
? ? ? ?self.img_size = img_size
? ? ? ?self.patch_size = patch_size
? ? ? ?self.num_patches = num_patches

? ? ? ?self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

? ?def forward(self, x):
? ? ? ?B, C, H, W = x.shape
? ? ? ?# FIXME look at relaxing size constraints
? ? ? ?assert H == self.img_size[0] and W == self.img_size[1], \
? ? ? ? ? ?f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
? ? ? ?x = self.proj(x).flatten(2).transpose(1, 2)
? ? ? ?return x
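
# PatchEmbed note (added for clarity): with img_size=14 and patch_size=1, as used by
# VisionTransformer_POS inside MMNet, the projection is a 1x1 conv and the output is a sequence of
# 14*14 = 196 tokens of dimension embed_dim, matching the fixed (1, 196, embed_dim) pos_embed below.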


class HybridEmbed(nn.Module):
? ?""" CNN Feature Map Embedding
? ?Extract feature map from CNN, flatten, project to embedding dim.
? ?"""
? ?def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
? ? ? ?super().__init__()
? ? ? ?assert isinstance(backbone, nn.Module)
? ? ? ?img_size = to_2tuple(img_size)
? ? ? ?self.img_size = img_size
? ? ? ?self.backbone = backbone
? ? ? ?if feature_size is None:
? ? ? ? ? ?with torch.no_grad():
? ? ? ? ? ? ? ?# FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
? ? ? ? ? ? ? ?# map for all networks, the feature metadata has reliable channel and stride info, but using
? ? ? ? ? ? ? ?# stride to calc feature dim requires info about padding of each stage that isn't captured.
? ? ? ? ? ? ? ?training = backbone.training
? ? ? ? ? ? ? ?if training:
? ? ? ? ? ? ? ? ? ?backbone.eval()
? ? ? ? ? ? ? ?o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
? ? ? ? ? ? ? ?if isinstance(o, (list, tuple)):
? ? ? ? ? ? ? ? ? ?o = o[-1] ?# last feature if backbone outputs list/tuple of features
? ? ? ? ? ? ? ?feature_size = o.shape[-2:]
? ? ? ? ? ? ? ?feature_dim = o.shape[1]
? ? ? ? ? ? ? ?backbone.train(training)
? ? ? ?else:
? ? ? ? ? ?feature_size = to_2tuple(feature_size)
? ? ? ? ? ?if hasattr(self.backbone, 'feature_info'):
? ? ? ? ? ? ? ?feature_dim = self.backbone.feature_info.channels()[-1]
? ? ? ? ? ?else:
? ? ? ? ? ? ? ?feature_dim = self.backbone.num_features
? ? ? ?self.num_patches = feature_size[0] * feature_size[1]
? ? ? ?self.proj = nn.Conv2d(feature_dim, embed_dim, 1)

? ?def forward(self, x):
? ? ? ?x = self.backbone(x)
? ? ? ?if isinstance(x, (list, tuple)):
? ? ? ? ? ?x = x[-1] ?# last feature if backbone outputs list/tuple of features
? ? ? ?x = self.proj(x).flatten(2).transpose(1, 2)
? ? ? ?return x


###Position Calibration Module
class VisionTransformer_POS(nn.Module):
? ?""" Vision Transformer

? ?A torch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` ?-
? ? ? ?https://arxiv.org/abs/2010.11929
? ?"""
? ?def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
? ? ? ? ? ? ? ? num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
? ? ? ? ? ? ? ? drop_rate=0., attn_drop_rate=0., drop_path_rate=0.15, hybrid_backbone=None, norm_layer=None):
? ? ? ?"""
? ? ? ?Args:
? ? ? ? ? ?img_size (int, tuple): input image size
? ? ? ? ? ?patch_size (int, tuple): patch size
? ? ? ? ? ?in_chans (int): number of input channels
? ? ? ? ? ?num_classes (int): number of classes for classification head
? ? ? ? ? ?embed_dim (int): embedding dimension
? ? ? ? ? ?depth (int): depth of transformer
? ? ? ? ? ?num_heads (int): number of attention heads
? ? ? ? ? ?mlp_ratio (int): ratio of mlp hidden dim to embedding dim
? ? ? ? ? ?qkv_bias (bool): enable bias for qkv if True
? ? ? ? ? ?qk_scale (float): override default qk scale of head_dim ** -0.5 if set
? ? ? ? ? ?representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
? ? ? ? ? ?drop_rate (float): dropout rate
? ? ? ? ? ?attn_drop_rate (float): attention dropout rate
? ? ? ? ? ?drop_path_rate (float): stochastic depth rate
? ? ? ? ? ?hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
? ? ? ? ? ?norm_layer: (nn.Module): normalization layer
? ? ? ?"""
? ? ? ?super().__init__()
? ? ? ?norm_layer=partial(nn.LayerNorm, eps=1e-6)
? ? ? ?self.num_classes = num_classes
? ? ? ?self.num_features = self.embed_dim = embed_dim ?# num_features for consistency with other models
? ? ? ?norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

? ? ? ?if hybrid_backbone is not None:
? ? ? ? ? ?self.patch_embed = HybridEmbed(
? ? ? ? ? ? ? ?hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
? ? ? ?else:
? ? ? ? ? ?self.patch_embed = PatchEmbed(
? ? ? ? ? ? ? ?img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
? ? ? ?num_patches = self.patch_embed.num_patches

? ? ? ?self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
? ? ? ?self.pos_embed = nn.Parameter(torch.zeros(1, 196, embed_dim))
? ? ? ?self.pos_drop = nn.Dropout(p=drop_rate)

? ? ? ?dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] ?# stochastic depth decay rule
? ? ? ?self.blocks = nn.ModuleList([
? ? ? ? ? ?Block(
? ? ? ? ? ? ? ?dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
? ? ? ? ? ? ? ?drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
? ? ? ? ? ?for i in range(depth)])
? ? ? ?self.norm = norm_layer(embed_dim)

? ? ? ?# Representation layer
? ? ? ?if representation_size:
? ? ? ? ? ?self.num_features = representation_size
? ? ? ? ? ?self.pre_logits = nn.Sequential(OrderedDict([
? ? ? ? ? ? ? ?('fc', nn.Linear(embed_dim, representation_size)),
? ? ? ? ? ? ? ?('act', nn.Tanh())
? ? ? ? ? ?]))
? ? ? ?else:
? ? ? ? ? ?self.pre_logits = nn.Identity()

? ? ? ?# Classifier head
? ? ? ?self.head = nn.Linear(self.num_features, 5) if num_classes > 0 else nn.Identity()
? ? ? ?# self.to_Mask = nn.Sequential(nn.Conv2d(in_channels=self.num_features,out_channels=1,kernel_size=3,padding=1),
? ? ? ?# ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?nn.Hardsigmoid(),
? ? ? ?# ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?)
? ? ? ?# self.to_Mask = nn.Linear(self.num_features,1)
? ? ? ?self.to_Mask = nn.Sequential(nn.Linear(self.num_features,1),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.Sigmoid(),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? )
? ? ? ?trunc_normal_(self.pos_embed, std=.02)
? ? ? ?trunc_normal_(self.cls_token, std=.02)
? ? ? ?self.apply(self._init_weights)

? ?def _init_weights(self, m):
? ? ? ?if isinstance(m, nn.Linear):
? ? ? ? ? ?trunc_normal_(m.weight, std=.02)
? ? ? ? ? ?if isinstance(m, nn.Linear) and m.bias is not None:
? ? ? ? ? ? ? ?nn.init.constant_(m.bias, 0)
? ? ? ?elif isinstance(m, nn.LayerNorm):
? ? ? ? ? ?nn.init.constant_(m.bias, 0)
? ? ? ? ? ?nn.init.constant_(m.weight, 1.0)

? ?@torch.jit.ignore
? ?def no_weight_decay(self):
? ? ? ?return {'pos_embed', 'cls_token'}

? ?def get_classifier(self):
? ? ? ?return self.head

? ?def reset_classifier(self, num_classes, global_pool=''):
? ? ? ?self.num_classes = num_classes
? ? ? ?self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

? ?def forward_features(self, x):
? ? ? ?B = x.shape[0]
? ? ? ?x = self.patch_embed(x)


? ? ? ?x = x + self.pos_embed
? ? ? ?x = self.pos_drop(x)


? ? ? ?for blk in self.blocks:
? ? ? ? ? ?x = blk(x)


? ? ? ?x = self.norm(x)
? ? ? ?x = self.pre_logits(x)
? ? ? ?return x

? ?def forward(self, x):
? ? ? ?x = self.forward_features(x)

? ? ? ?return x
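
A minimal usage sketch of the Position Calibration sub-branch (my own addition; it mirrors how MMNet calls it in main.py and assumes timm is installed):

import torch
from functools import partial
from torch import nn
from PC_module import VisionTransformer_POS

vit_pos = VisionTransformer_POS(img_size=14, patch_size=1, embed_dim=512, depth=3, num_heads=4,
                                mlp_ratio=2, qkv_bias=True,
                                norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path_rate=0.3).eval()
x = torch.randn(2, 3, 14, 14)                       # onset frame resized to 14x14
with torch.no_grad():
    tokens = vit_pos(x)                             # (2, 196, 512) token sequence
pos = tokens.transpose(1, 2).view(2, 512, 14, 14)   # reshaped into position embeddings, as in MMNet.forward
print(pos.shape)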

CASME2數(shù)據(jù)集中的問題請自行修改


