最美情侣中文字幕电影,在线麻豆精品传媒,在线网站高清黄,久久黄色视频

歡迎光臨散文網(wǎng) 會員登陸 & 注冊

Swin-Unet與Unet的菜雞分析與復現(xiàn)

2023-07-14 19:14 作者:蟈總  | 我要投稿

Unet和Swin-Unet都是語義分割模型,網(wǎng)絡(luò)結(jié)構(gòu)都是一個類似于U型的編碼器-解碼器結(jié)構(gòu)。前者是2015年提出的經(jīng)典模型,全使用了卷積/反卷積操作;后者將這些操作全部改為Transformer。


Unet

網(wǎng)絡(luò)結(jié)構(gòu)



左側(cè)相當于編碼器,右側(cè)相當于解碼器。左右各四個Stage。編碼器進行四輪卷積(RELU)-池化操作,解碼器進行四輪卷積-上采樣操作。其中左側(cè)每進行一次池化得到的特征圖,與右側(cè)對應的特征圖進行拼接。

深層網(wǎng)絡(luò)通過拼接的方式,有助于找回前面丟失的邊緣特征。

代碼復現(xiàn)

為了與Swin-Unet在相同數(shù)據(jù)集和條件下訓練,適配輸入224*224單通道圖片。修改部分參數(shù),Pytorch代碼如下:

class Unet(nn.Module):
? ?def __init__(self, num_classes=9):
? ? ? ?super(Unet, self).__init__()
? ? ? ?self.encoder = Encoder(in_channels=1)
? ? ? ?self.decoder = Decoder(num_classes)

? ?def forward(self, inputs):
? ? ? ?[feat1, feat2, feat3, feat4, feat5] = self.encoder(inputs)
? ? ? ?output = self.decoder(feat1, feat2, feat3, feat4, feat5)

? ? ? ?return output
? ? ? ?
# A VGG-like network, pure CNN
class Encoder(nn.Module):
? ?def __init__(self, in_channels):
? ? ? ?super(Encoder, self).__init__()
? ? ? ?self.features = nn.Sequential(
? ? ? ? ? ?nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(64),
? ? ? ? ? ?nn.ReLU(inplace=True),

? ? ? ? ? ?nn.Conv2d(64, 64, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(64),
? ? ? ? ? ?nn.ReLU(inplace=True),
? ? ? ? ? ?nn.MaxPool2d(kernel_size=2, stride=2),

? ? ? ? ? ?nn.Conv2d(64, 128, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(128),
? ? ? ? ? ?nn.ReLU(inplace=True),

? ? ? ? ? ?nn.Conv2d(128, 128, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(128),
? ? ? ? ? ?nn.ReLU(inplace=True),
? ? ? ? ? ?nn.MaxPool2d(kernel_size=2, stride=2),

? ? ? ? ? ?nn.Conv2d(128, 256, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(256),
? ? ? ? ? ?nn.ReLU(inplace=True),

? ? ? ? ? ?nn.Conv2d(256, 256, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(256),
? ? ? ? ? ?nn.ReLU(inplace=True),

? ? ? ? ? ?nn.Conv2d(256, 256, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(256),
? ? ? ? ? ?nn.ReLU(inplace=True),
? ? ? ? ? ?nn.MaxPool2d(kernel_size=2, stride=2),

? ? ? ? ? ?nn.Conv2d(256, 512, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(512),
? ? ? ? ? ?nn.ReLU(inplace=True),

? ? ? ? ? ?nn.Conv2d(512, 512, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(512),
? ? ? ? ? ?nn.ReLU(inplace=True),

? ? ? ? ? ?nn.Conv2d(512, 512, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(512),
? ? ? ? ? ?nn.ReLU(inplace=True),
? ? ? ? ? ?nn.MaxPool2d(kernel_size=2, stride=2),

? ? ? ? ? ?nn.Conv2d(512, 512, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(512),
? ? ? ? ? ?nn.ReLU(inplace=True),

? ? ? ? ? ?nn.Conv2d(512, 512, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(512),
? ? ? ? ? ?nn.ReLU(inplace=True),

? ? ? ? ? ?nn.Conv2d(512, 512, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(512),
? ? ? ? ? ?nn.ReLU(inplace=True),
? ? ? ? ? ?nn.MaxPool2d(kernel_size=2, stride=2)
? ? ? ?)
? ? ? ?self._initialize_weights()

? ?def forward(self, x):
? ? ? ?# feature 1-4 is for copy-crop later, feature5 is the output
? ? ? ?feat1 = self.features[:6](x)
? ? ? ?feat2 = self.features[6:13](feat1)
? ? ? ?feat3 = self.features[13:23](feat2)
? ? ? ?feat4 = self.features[23:33](feat3)
? ? ? ?feat5 = self.features[33:-1](feat4)
? ? ? ?return [feat1, feat2, feat3, feat4, feat5]
? ?
class DecoderLayer(nn.Module):
? ?def __init__(self, in_size, out_size):
? ? ? ?super(DecoderLayer, self).__init__()
? ? ? ?self.conv1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1)
? ? ? ?self.conv2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1)
? ? ? ?self.up = nn.UpsamplingBilinear2d(scale_factor=2)
? ? ? ?self.relu = nn.ReLU(inplace=True)

? ?def forward(self, inputs1, inputs2):
? ? ? ?outputs = torch.cat([inputs1, self.up(inputs2)], 1)
? ? ? ?outputs = self.conv1(outputs)
? ? ? ?outputs = self.relu(outputs)
? ? ? ?outputs = self.conv2(outputs)
? ? ? ?outputs = self.relu(outputs)
? ? ? ?return outputs


class Decoder(nn.Module):
? ?def __init__(self, num_classes):
? ? ? ?super(Decoder, self).__init__()

? ? ? ?# upSampling
? ? ? ?# 64,64,512
? ? ? ?self.up_concat4 = DecoderLayer(1024, 512)
? ? ? ?# 128,128,256
? ? ? ?self.up_concat3 = DecoderLayer(768, 256)
? ? ? ?# 256,256,128
? ? ? ?self.up_concat2 = DecoderLayer(384, 128)
? ? ? ?# 512,512,64
? ? ? ?self.up_concat1 = DecoderLayer(192, 64)

? ? ? ?self.final = nn.Conv2d(64, num_classes, 1)

? ?def forward(self, feat1, feat2, feat3, feat4, feat5):
? ? ? ?up4 = self.up_concat4(feat4, feat5)
? ? ? ?up3 = self.up_concat3(feat3, up4)
? ? ? ?up2 = self.up_concat2(feat2, up3)
? ? ? ?up1 = self.up_concat1(feat1, up2)
? ? ? ?final = self.final(up1)

? ? ? ?return final

Swin-Unet

網(wǎng)絡(luò)結(jié)構(gòu)


編碼器

先對input圖像進行Patch Partition,對應于代碼里PatchEmbed類的實現(xiàn),將圖片切成patch_size*patch_size的圖塊,嵌入到Embedding

接三個Stage,每次先經(jīng)過Patch Merging下采樣,在行方向和列方向上間隔2選取元素(對應Unet里的卷積用來降低分辨率),然后經(jīng)過兩個Swin Transformer塊。但是代碼里的實現(xiàn)并不是按照圖示這樣分割的,它是將Patch Partition, Linear Embedding放在了前面開頭,然后以“兩個Swin Transformer塊+Patch Merging”作為一個BasicLayer,將Bottleneck看做沒有后接Patch Merging的BasicLayer。要是按著圖示去復現(xiàn),那Encoder跟官方代碼肯定是長得非常不一樣了。

最后再進行一次Patch Merging。連接編解碼器的是Bottleneck,其實就是兩個Swin Transformer塊,此層圖片大小不變。

解碼器

上采樣操作由Patch Expanding完成,每個Stage由上采樣和兩個Swin Tranformer塊組成。但是代碼里的實現(xiàn)并不是按照圖示這樣分割的,從下往上走,它將第一個Patch Expanding單獨分出來,然后以“兩個Swin Transformer塊+Patch Expanding”作為一個BasicLayer_up,最后的Patch Expanding稱為FinalPatchExpand_X4,是因為只有它是以4倍上采樣。

左右每個Stage之間有跳躍連接,圖示位置不明確;具體講,是把編碼器每個Stage進入BasicLayer前的輸入,連接到解碼器進入每個Stage前的輸入。不過關(guān)于在哪里進行跳躍連接,作者的意思是“可以調(diào)整,不影響整體架構(gòu)”。

Swin Transformer


Swin-Unet的編碼器部分就是Swin Transformer的結(jié)構(gòu),每兩個Swin Tranformer塊的結(jié)構(gòu)如右b圖。

LNLayerNorm

MLP:帶有GELU非線性的2層MLP

W-MSAWindow Attention,Transformer是基于全局來計算注意力的。而Swin Transformer將注意力的計算限制在每個窗口內(nèi),進而減少了計算量

SW-MSAShifted Window Attention,對特征圖移位,并給Attention設(shè)置mask來實現(xiàn)和Window Attention相同的計算結(jié)果

我認為Swin-Unet是把Unet結(jié)構(gòu)和Swin Transformer放在一起時自然而然產(chǎn)生的想法。因為Unet的主干網(wǎng)絡(luò)本身就不一定是原本論文里的樣子,可以是Resnet、VGG,可以是TransUnet這種CNN和Transformer結(jié)合的形態(tài)。關(guān)鍵還是在于有跳躍連接,而Swin Transformer又更加高效輕量,長程注意力也有很大優(yōu)勢。

代碼復現(xiàn)

引用自論文源倉庫,結(jié)合自己的理解,寫一下核心部分的代碼:

import torch
from torch import nn


def no_weight_decay():
? ?return {'absolute_pos_embed'}


def no_weight_decay_keywords():
? ?return {'relative_position_bias_table'}


def _init_weights(m):
? ?if isinstance(m, nn.Linear):
? ? ? ?trunc_normal_(m.weight, std=.02)
? ? ? ?if isinstance(m, nn.Linear) and m.bias is not None:
? ? ? ? ? ?nn.init.constant_(m.bias, 0)
? ?elif isinstance(m, nn.LayerNorm):
? ? ? ?nn.init.constant_(m.bias, 0)
? ? ? ?nn.init.constant_(m.weight, 1.0)


class Swin_Unet(nn.Module):
? ?def __init__(self, img_size, patch_size, in_channels, num_classes,
? ? ? ? ? ? ? ? embed_dim, depths, num_heads,
? ? ? ? ? ? ? ? window_size, mlp_ratio, qkv_bias, qk_scale,
? ? ? ? ? ? ? ? drop_rate, attn_drop_rate, drop_path_rate):
? ? ? ?super().__init__()

? ? ? ?self.num_classes = num_classes
? ? ? ?self.num_layers = len(depths)
? ? ? ?self.embed_dim = embed_dim
? ? ? ?self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
? ? ? ?self.num_features_up = int(embed_dim * 2)
? ? ? ?self.mlp_ratio = mlp_ratio

? ? ? ?# patch partition 和 linear embedding

? ? ? ?# split image into non-overlapping patches
? ? ? ?self.patch_embed = PatchEmbed(
? ? ? ? ? ?img_size=img_size, patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
? ? ? ? ? ?norm_layer=nn.LayerNorm)
? ? ? ?patches_resolution = self.patch_embed.patches_resolution
? ? ? ?self.patches_resolution = patches_resolution
? ? ? ?# absolute position embedding,
? ? ? ?self.pos_drop = nn.Dropout(p=drop_rate)
? ? ? ?# stochastic depth
? ? ? ?dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

? ? ? ?# build encoderStages and bottleneck layers,每個BasicLayer包含兩個Swin Transformer Block和一個下采樣
? ? ? ?self.layers = nn.ModuleList()
? ? ? ?for i_layer in range(self.num_layers):
? ? ? ? ? ?layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? input_resolution=(patches_resolution[0] // (2 ** i_layer),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? patches_resolution[1] // (2 ** i_layer)),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? depth=depths[i_layer],
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? num_heads=num_heads[i_layer],
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? window_size=window_size,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? mlp_ratio=self.mlp_ratio,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? qkv_bias=qkv_bias, qk_scale=qk_scale,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? drop=drop_rate, attn_drop=attn_drop_rate,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? norm_layer=nn.LayerNorm,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? downsample=PatchMerging if (i_layer < self.num_layers - 1) else None) # bottleneck沒有下采樣
? ? ? ? ? ?self.layers.append(layer)

? ? ? ?# build decoder layers,解碼器每個Stage
? ? ? ?self.layers_up = nn.ModuleList()
? ? ? ?self.concat_back_dim = nn.ModuleList()
? ? ? ?for i_layer in range(self.num_layers):
? ? ? ? ? ?concat_linear = nn.Linear(2 * int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?int(embed_dim * 2 ** (
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?self.num_layers - 1 - i_layer))) if i_layer > 0 else nn.Identity()
? ? ? ? ? ?if i_layer == 0:
? ? ? ? ? ? ? ?layer_up = PatchExpand(
? ? ? ? ? ? ? ? ? ?input_resolution=(patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))),
? ? ? ? ? ? ? ? ? ?dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), dim_scale=2, norm_layer=norm_layer)
? ? ? ? ? ?else:
? ? ? ? ? ? ? ?layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? input_resolution=(
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? depth=depths[(self.num_layers - 1 - i_layer)],
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? num_heads=num_heads[(self.num_layers - 1 - i_layer)],
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? window_size=window_size,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? mlp_ratio=self.mlp_ratio,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? qkv_bias=qkv_bias, qk_scale=qk_scale,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? drop=drop_rate, attn_drop=attn_drop_rate,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? drop_path=dpr[sum(depths[:(self.num_layers - 1 - i_layer)]):sum(
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? depths[:(self.num_layers - 1 - i_layer) + 1])],
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? norm_layer=nn.LayerNorm,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? upsample=PatchExpand if (i_layer < self.num_layers - 1) else None)
? ? ? ? ? ?self.layers_up.append(layer_up)
? ? ? ? ? ?self.concat_back_dim.append(concat_linear)

? ? ? ?self.norm = nn.LayerNorm(self.num_features)
? ? ? ?self.norm_up = nn.LayerNorm(self.embed_dim)

? ? ? ?self.up = FinalPatchExpand_X4(input_resolution=(img_size // patch_size, img_size // patch_size),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?dim_scale=4, dim=embed_dim)
? ? ? ?self.output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False)

? ? ? ?self.apply(_init_weights)

? ?# Encoder and Bottleneck
? ?def forward_features(self, x):
? ? ? ?x = self.patch_embed(x)
? ? ? ?if self.ape:
? ? ? ? ? ?x = x + self.absolute_pos_embed
? ? ? ?x = self.pos_drop(x)
? ? ? ?x_down_sample = []

? ? ? ?for layer in self.layers:
? ? ? ? ? ?x_down_sample.append(x)
? ? ? ? ? ?x = layer(x)

? ? ? ?x = self.norm(x) ?# B L C

? ? ? ?return x, x_down_sample

? ?# Decoder and Skip connection
? ?def forward_up_features(self, x, x_down_sample):
? ? ? ?for inx, layer_up in enumerate(self.layers_up):
? ? ? ? ? ?if inx == 0:
? ? ? ? ? ? ? ?x = layer_up(x)
? ? ? ? ? ?else:
? ? ? ? ? ? ? ?x = torch.cat([x, x_down_sample[3 - inx]], -1)
? ? ? ? ? ? ? ?x = self.concat_back_dim[inx](x)
? ? ? ? ? ? ? ?x = layer_up(x)

? ? ? ?x = self.norm_up(x) ?# B L C

? ? ? ?return x

? ?def up_x4(self, x):
? ? ? ?H, W = self.patches_resolution
? ? ? ?B, L, C = x.shape
? ? ? ?assert L == H * W, "input features has wrong size"

? ? ? ?x = self.up(x)
? ? ? ?x = x.view(B, 4 * H, 4 * W, -1)
? ? ? ?x = x.permute(0, 3, 1, 2) ?# B,C,H,W
? ? ? ?x = self.output(x)

? ? ? ?return x

? ?def forward(self, x):
? ? ? ?x, x_down_sample = self.forward_features(x)
? ? ? ?x = self.forward_up_features(x, x_down_sample)
? ? ? ?x = self.up_x4(x)

? ? ? ?return x

? ?def flops(self):
? ? ? ?flops = 0
? ? ? ?flops += self.patch_embed.flops()
? ? ? ?for i, layer in enumerate(self.layers):
? ? ? ? ? ?flops += layer.flops()
? ? ? ?flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
? ? ? ?flops += self.num_features * self.num_classes
? ? ? ?return flops

訓練與推理

PyTorch 1.11.0

Python 3.8(ubuntu20.04)

Cuda 11.3

GPU:RTX 4090(24GB) * 1

CPU:22 vCPU AMD EPYC 7T83 64-Core Processor

內(nèi)存:90GB

數(shù)據(jù)集:Synapse多器官分割

num_classes=9(0背景,1-8器官)

Unet:epoch=150,batch_size=16

Swin-Unet:epoch=150,batch_size=24(原論文)

很玄學,我這邊換成2冪的batch_size,模型表現(xiàn)反而下降。

Run the train script on synapse dataset. The batch size we used is 24. If you do not have enough GPU memory, the bacth size can be reduced to 12 or 6 to save memory.

測試集表現(xiàn):使用Swin-Unet論文相同的指標Dice-Similarity coefficient (DSC↑) 和 average Hausdorff Distance(HD↓)

復現(xiàn)結(jié)果來看,Swin-Unet的效果稍弱于論文,確實不太明白作者是怎么做到79.13的,如果顯卡和隨機種子不是關(guān)鍵因素的話,那最可能是因為論文實際上用了不一樣的學習率策略。(我嘗試了更多epoch到500,但并沒有明顯的提升,也沒有明顯的過擬合,所以學習率可能是影響因素)但是4090到期了,沒錢繼續(xù)煉了。深度學習,富人的游戲。

In our code, we carefully set the random seed, so the results should be consistent when trained multiple times on the same type of GPU. If the training does not give the same segmentation results as in the paper, it is recommended to adjust the learning rate.

Unet的效果稍強于論文。

不過相對而言,Swin-Unet的優(yōu)勢是顯而易見的。與論文相似,HD的下降了14%,在后面幾個器官分類上,都有不小改進。

參考文獻

  1. Cao, Hu et al. “Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation.” ECCV Workshops (2021).

  2. Ronneberger, Olaf et al. “U-Net: Convolutional Networks for Biomedical Image Segmentation.” ArXiv abs/1505.04597 (2015): n. pag.

  3. Liu, Ze et al. “Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.” 2021 IEEE/CVF International Conference on Computer Vision (ICCV) (2021): 9992-10002.


Swin-Unet與Unet的菜雞分析與復現(xiàn)的評論 (共 條)

分享到微博請遵守國家法律
开原市| 宜宾市| 云南省| 长寿区| 青州市| 苏尼特左旗| 普宁市| 洪洞县| 开江县| 太仆寺旗| 赤壁市| 凉城县| 道孚县| 泸溪县| 白河县| 靖边县| 龙里县| 岳普湖县| 阜南县| 泰和县| 诸城市| 凉城县| 元氏县| 兴和县| 扎囊县| 资源县| 遂溪县| 民权县| 永丰县| 灵石县| 南投市| 泸定县| 高密市| 高雄县| 布拖县| 永康市| 佛学| 江津市| 东乌| 梁平县| 遵义市|