python實(shí)現(xiàn)CylcleGAN人臉轉(zhuǎn)卡通圖【文末源碼】
CycleGan人臉轉(zhuǎn)為卡通圖
引言:近幾天一個(gè)GitHub項(xiàng)目火遍了朋友圈,那就是卡通頭像AI生成小程序。如下圖所見(jiàn):
而這個(gè)項(xiàng)目的基本原理是用python搭建的GAN算法模型,進(jìn)行訓(xùn)練得出。
而所謂的GAN就是指生成對(duì)抗網(wǎng)絡(luò)深度學(xué)習(xí)模型。網(wǎng)絡(luò)中有生成器G(generator)和鑒別器(Discriminator)。
有兩個(gè)數(shù)據(jù)域分別為X,Y。G
負(fù)責(zé)把X域中的數(shù)據(jù)拿過(guò)來(lái)拼命地模仿成真實(shí)數(shù)據(jù)并把它們藏在真實(shí)數(shù)據(jù)中,而
D
就拼命地要把偽造數(shù)據(jù)和真實(shí)數(shù)據(jù)分開(kāi)。經(jīng)過(guò)二者的博弈以后,G
的偽造技術(shù)越來(lái)越厲害,D
的鑒別技術(shù)也越來(lái)越厲害。直到
D
再也分不出數(shù)據(jù)是真實(shí)的還是
G
生成的數(shù)據(jù)的時(shí)候,這個(gè)對(duì)抗的過(guò)程達(dá)到一個(gè)動(dòng)態(tài)的平衡。
而CycleGAN本質(zhì)上是兩個(gè)鏡像對(duì)稱的GAN,構(gòu)成了一個(gè)環(huán)形網(wǎng)絡(luò)。
兩個(gè)GAN共享兩個(gè)生成器,并各自帶一個(gè)判別器,即共有兩個(gè)判別器和兩個(gè)生成器。一個(gè)單向GAN兩個(gè)loss,兩個(gè)即共四個(gè)loss。
可以實(shí)現(xiàn)無(wú)配對(duì)的兩個(gè)圖片集的訓(xùn)練是CycleGAN與Pixel2Pixel相比的一個(gè)典型優(yōu)點(diǎn)。但是我們?nèi)匀恍枰ㄟ^(guò)訓(xùn)練創(chuàng)建這個(gè)映射來(lái)確保輸入圖像和生成圖像間存在有意義的關(guān)聯(lián),即輸入輸出共享一些特征。
簡(jiǎn)而言之,該模型通過(guò)從域DA獲取輸入圖像,該輸入圖像被傳遞到第一個(gè)生成器GeneratorA→B,其任務(wù)是將來(lái)自域DA的給定圖像轉(zhuǎn)換到目標(biāo)域DB中的圖像。然后這個(gè)新生成的圖像被傳遞到另一個(gè)生成器GeneratorB→A,其任務(wù)是在原始域DA轉(zhuǎn)換回圖像,這里可與自動(dòng)編碼器作對(duì)比。這個(gè)輸出圖像必須與原始輸入圖像相似,用來(lái)定義非配對(duì)數(shù)據(jù)集中原來(lái)不存在的有意義映射。
在本次的項(xiàng)目中就是利用了CycleGAN進(jìn)行搭建模型。模型訓(xùn)練數(shù)據(jù)集如下:
一、實(shí)驗(yàn)前的準(zhǔn)備:
首先我們使用的python版本是3.6.5所用到的庫(kù)有pytorch和TensorFlow,用來(lái)訓(xùn)練和加載神經(jīng)網(wǎng)絡(luò)常見(jiàn)的框架;face-alignment用來(lái)是用來(lái)提取人臉特征的常用庫(kù);
dlib是一個(gè)機(jī)器學(xué)習(xí)的開(kāi)源庫(kù),包含了機(jī)器學(xué)習(xí)的很多算法,使用起來(lái)很方便,直接包含頭文件即可,并且不依賴于其他庫(kù)(自帶圖像編解碼庫(kù)源碼)。Dlib可以幫助您創(chuàng)建很多復(fù)雜的機(jī)器學(xué)習(xí)方面的軟件來(lái)幫助解決實(shí)際問(wèn)題。目前Dlib已經(jīng)被廣泛的用在行業(yè)和學(xué)術(shù)領(lǐng)域,包括機(jī)器人,嵌入式設(shè)備,移動(dòng)電話和大型高性能計(jì)算環(huán)境。
二、模型的訓(xùn)練
1、數(shù)據(jù)集處理和準(zhǔn)備:
訓(xùn)練數(shù)據(jù)包括真實(shí)照片和卡通畫(huà)像,為降低訓(xùn)練復(fù)雜度,我們對(duì)兩類數(shù)據(jù)進(jìn)行了如下預(yù)處理:
檢測(cè)人臉及關(guān)鍵點(diǎn)。
根據(jù)關(guān)鍵點(diǎn)旋轉(zhuǎn)校正人臉。
將關(guān)鍵點(diǎn)邊界框按固定的比例擴(kuò)張并裁剪出人臉區(qū)域。
使用人像分割模型將背景置白。
為了形成匹配效果,需要準(zhǔn)備一些卡通人物圖片和真實(shí)的人臉圖片進(jìn)行訓(xùn)練
2、模型的訓(xùn)練:
模型的訓(xùn)練使用python train.py --dataset photo2cartoon進(jìn)行訓(xùn)練即可。
3、神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)搭建:
整個(gè)算法的搭建正如上面可見(jiàn),需要有生成器和判別器。使用論文提出的一種Soft-AdaLIN(Soft Adaptive Layer-Instance Normalization)歸一化方法,在反規(guī)范化時(shí)將編碼器的均值方差(照片特征)與解碼器的均值方差(卡通特征)相融合。
模型結(jié)構(gòu)方面,在U-GAT-IT的基礎(chǔ)上,在編碼器之前和解碼器之后各增加了2個(gè)hourglass模塊,漸進(jìn)地提升模型特征抽象和重建能力。
部分代碼如下:
class ResnetGenerator(nn.Module):
def __init__(self, ngf=64, img_size=256, light=False):
super(ResnetGenerator, self).__init__()
self.light = light
self.ConvBlock1 = nn.Sequential(nn.ReflectionPad2d(3),
? ? ? ? ? ? ? ? ?nn.Conv2d(3, ngf, kernel_size=7, stride=1, padding=0, bias=False),
? ? ? ? ? ? ? ? ?nn.InstanceNorm2d(ngf),
? ? ? ? ? ? ? ? ?nn.ReLU(True))
self.HourGlass1 = HourGlass(ngf, ngf)
self.HourGlass2 = HourGlass(ngf, ngf)
# Down-Sampling
self.DownBlock1 = nn.Sequential(nn.ReflectionPad2d(1),
? ? ? ? ? ? ? ? ? nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=2, padding=0, bias=False),
? ? ? ? ? ? ? ? ? nn.InstanceNorm2d(ngf * 2),
? ? ? ? ? ? ? ? ? nn.ReLU(True))
self.DownBlock2 = nn.Sequential(nn.ReflectionPad2d(1),
? ? ? ? ? ? ? ? ? nn.Conv2d(ngf*2, ngf*4, kernel_size=3, stride=2, padding=0, bias=False),
? ? ? ? ? ? ? ? ? nn.InstanceNorm2d(ngf*4),
? ? ? ? ? ? ? ? ? nn.ReLU(True))
# Encoder Bottleneck
self.EncodeBlock1 = ResnetBlock(ngf*4)
self.EncodeBlock2 = ResnetBlock(ngf*4)
self.EncodeBlock3 = ResnetBlock(ngf*4)
self.EncodeBlock4 = ResnetBlock(ngf*4)
# Class Activation Map
self.gap_fc = nn.Linear(ngf*4, 1)
self.gmp_fc = nn.Linear(ngf*4, 1)
self.conv1x1 = nn.Conv2d(ngf*8, ngf*4, kernel_size=1, stride=1)
self.relu = nn.ReLU(True)
# Gamma, Beta block
if self.light:
self.FC = nn.Sequential(nn.Linear(ngf*4, ngf*4),
? ? ? ? ? ? ? nn.ReLU(True),
? ? ? ? ? ? ? nn.Linear(ngf*4, ngf*4),
? ? ? ? ? ? ? nn.ReLU(True))
else:
self.FC = nn.Sequential(nn.Linear(img_size//4*img_size//4*ngf*4, ngf*4),
? ? ? ? ? ? ? nn.ReLU(True),
? ? ? ? ? ? ? nn.Linear(ngf*4, ngf*4),
? ? ? ? ? ? ? nn.ReLU(True))
# Decoder Bottleneck
self.DecodeBlock1 = ResnetSoftAdaLINBlock(ngf*4)
self.DecodeBlock2 = ResnetSoftAdaLINBlock(ngf*4)
self.DecodeBlock3 = ResnetSoftAdaLINBlock(ngf*4)
self.DecodeBlock4 = ResnetSoftAdaLINBlock(ngf*4)
# Up-Sampling
self.UpBlock1 = nn.Sequential(nn.Upsample(scale_factor=2),
? ? ? ? ? ? ? ? nn.ReflectionPad2d(1),
? ? ? ? ? ? ? ? nn.Conv2d(ngf*4, ngf*2, kernel_size=3, stride=1, padding=0, bias=False),
? ? ? ? ? ? ? ? LIN(ngf*2),
? ? ? ? ? ? ? ? nn.ReLU(True))
self.UpBlock2 = nn.Sequential(nn.Upsample(scale_factor=2),
? ? ? ? ? ? ? ? nn.ReflectionPad2d(1),
? ? ? ? ? ? ? ? nn.Conv2d(ngf*2, ngf, kernel_size=3, stride=1, padding=0, bias=False),
? ? ? ? ? ? ? ? LIN(ngf),
? ? ? ? ? ? ? ? nn.ReLU(True))
self.HourGlass3 = HourGlass(ngf, ngf)
self.HourGlass4 = HourGlass(ngf, ngf, False)
self.ConvBlock2 = nn.Sequential(nn.ReflectionPad2d(3),
? ? ? ? ? ? ? ? ? nn.Conv2d(3, 3, kernel_size=7, stride=1, padding=0, bias=False),
? ? ? ? ? ? ? ? ? nn.Tanh())
def forward(self, x):
x = self.ConvBlock1(x)
x = self.HourGlass1(x)
x = self.HourGlass2(x)
x = self.DownBlock1(x)
x = self.DownBlock2(x)
x = self.EncodeBlock1(x)
content_features1 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
x = self.EncodeBlock2(x)
content_features2 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
x = self.EncodeBlock3(x)
content_features3 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
x = self.EncodeBlock4(x)
content_features4 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
gap = F.adaptive_avg_pool2d(x, 1)
gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
gap_weight = list(self.gap_fc.parameters())[0]
gap = x * gap_weight.unsqueeze(2).unsqueeze(3)
gmp = F.adaptive_max_pool2d(x, 1)
gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
gmp_weight = list(self.gmp_fc.parameters())[0]
gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)
cam_logit = torch.cat([gap_logit, gmp_logit], 1)
x = torch.cat([gap, gmp], 1)
x = self.relu(self.conv1x1(x))
heatmap = torch.sum(x, dim=1, keepdim=True)
if self.light:
x_ = F.adaptive_avg_pool2d(x, 1)
style_features = self.FC(x_.view(x_.shape[0], -1))
else:
style_features = self.FC(x.view(x.shape[0], -1))
x = self.DecodeBlock1(x, content_features4, style_features)
x = self.DecodeBlock2(x, content_features3, style_features)
x = self.DecodeBlock3(x, content_features2, style_features)
x = self.DecodeBlock4(x, content_features1, style_features)
x = self.UpBlock1(x)
x = self.UpBlock2(x)
x = self.HourGlass3(x)
x = self.HourGlass4(x)
out = self.ConvBlock2(x)
return out, cam_logit, heatmap
4、提取人臉特征:
為了提取人臉特征以達(dá)到加載到網(wǎng)絡(luò)中的目的,我們需要正確框出人臉同時(shí)計(jì)算特征距離,以方便后面訓(xùn)練模型師損失函數(shù)的調(diào)用。
代碼如下:
class FaceFeatures(object):
def __init__(self, weights_path, device):
self.device = device
self.model = MobileFaceNet(512).to(device)
self.model.load_state_dict(torch.load(weights_path))
self.model.eval()
def infer(self, batch_tensor):
# crop face
h, w = batch_tensor.shape[2:]
top = int(h / 2.1 * (0.8 - 0.33))
bottom = int(h - (h / 2.1 * 0.3))
size = bottom - top
left = int(w / 2 - size / 2)
right = left + size
batch_tensor = batch_tensor[:, :, top: bottom, left: right]
batch_tensor = F.interpolate(batch_tensor, size=[112, 112], mode='bilinear', align_corners=True)
features = self.model(batch_tensor)
return features
def cosine_distance(self, batch_tensor1, batch_tensor2):
feature1 = self.infer(batch_tensor1)
feature2 = self.infer(batch_tensor2)
return 1 - torch.cosine_similarity(feature1, feature2)
三、模型測(cè)試
在訓(xùn)練好模型后,我們使用python test.py --photo_path ./images/1.jpg --save_path ./images/2.png測(cè)試生成圖片。其中1.jpg是原始圖片,最終會(huì)生成2.jpg圖片。
使用python data_process.py --data_path YourPhotoFolderPath --save_path YourSaveFolderPath批量生成
1、調(diào)用模型:
調(diào)用模型首先要使用torch進(jìn)行加載模型,讀取神經(jīng)網(wǎng)絡(luò)參數(shù)。在對(duì)原始圖片提取人臉特征的基礎(chǔ)上,加載進(jìn)網(wǎng)絡(luò)進(jìn)行生成即可。因?yàn)檫@里我們還需要對(duì)生成的數(shù)據(jù)進(jìn)行轉(zhuǎn)換成圖片,我們這里還需要使用numpy和opencv進(jìn)行圖片的轉(zhuǎn)化。因?yàn)榧虞d如模型和模型生成的必然是數(shù)據(jù),而我們需要將生成器產(chǎn)生的數(shù)據(jù)再轉(zhuǎn)換為圖片,就用到了這兩個(gè)庫(kù)。
代碼如下:
class Photo2Cartoon:
def __init__(self):
self.pre = Preprocess()
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.net = ResnetGenerator(ngf=32, img_size=256, light=True).to(self.device)
params = torch.load('./models/photo2cartoon_weights.pt', map_location=self.device)
self.net.load_state_dict(params['genA2B'])
def inference(self, img):
# face alignment and segmentation
face_rgba = self.pre.process(img)
if face_rgba is None:
print('can not detect face!!!')
return None
face_rgba = CV2.resize(face_rgba, (256, 256), interpolation=CV2.INTER_AREA)
face = face_rgba[:, :, :3].copy()
mask = face_rgba[:, :, 3][:, :, np.newaxis].copy() / 255.
face = (face*mask + (1-mask)*255) / 127.5 - 1
face = np.transpose(face[np.newaxis, :, :, :], (0, 3, 1, 2)).astype(np.float32)
face = torch.from_numpy(face).to(self.device)
# inference
with torch.no_grad():
cartoon = self.net(face)[0][0]
# post-process
cartoon = np.transpose(cartoon.cpu().numpy(), (1, 2, 0))
cartoon = (cartoon + 1) * 127.5
cartoon = (cartoon * mask + 255 * (1 - mask)).astype(np.uint8)
cartoon = CV2.cvtColor(cartoon, CV2.COLOR_RGB2BGR)
return cartoon
if __name__ == '__main__':
img = CV2.cvtColor(CV2.imread(args.photo_path), CV2.COLOR_BGR2RGB)
c2p = Photo2Cartoon()
cartoon = c2p.inference(img)
if cartoon is not None:
CV2.imwrite(args.save_path, cartoon)
到這里,我們整體的程序就搭建完成,下面為我們程序的運(yùn)行結(jié)果:
源碼地址:https://gitcode.net/qq_42279468/python-cylclegan.git