Swin-Unet與Unet的菜雞分析與復現(xiàn)



Unet
網(wǎng)絡(luò)結(jié)構(gòu)

左側(cè)相當于編碼器,右側(cè)相當于解碼器。左右各四個Stage。編碼器進行四輪卷積(RELU)-池化操作,解碼器進行四輪卷積-上采樣操作。其中左側(cè)每進行一次池化得到的特征圖,與右側(cè)對應的特征圖進行拼接。
深層網(wǎng)絡(luò)通過拼接的方式,有助于找回前面丟失的邊緣特征。
代碼復現(xiàn)
為了與Swin-Unet在相同數(shù)據(jù)集和條件下訓練,適配輸入224*224單通道圖片。修改部分參數(shù),Pytorch代碼如下:
class Unet(nn.Module):
? ?def __init__(self, num_classes=9):
? ? ? ?super(Unet, self).__init__()
? ? ? ?self.encoder = Encoder(in_channels=1)
? ? ? ?self.decoder = Decoder(num_classes)
? ?def forward(self, inputs):
? ? ? ?[feat1, feat2, feat3, feat4, feat5] = self.encoder(inputs)
? ? ? ?output = self.decoder(feat1, feat2, feat3, feat4, feat5)
? ? ? ?return output
? ? ? ?
# A VGG-like network, pure CNN
class Encoder(nn.Module):
? ?def __init__(self, in_channels):
? ? ? ?super(Encoder, self).__init__()
? ? ? ?self.features = nn.Sequential(
? ? ? ? ? ?nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(64),
? ? ? ? ? ?nn.ReLU(inplace=True),
? ? ? ? ? ?nn.Conv2d(64, 64, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(64),
? ? ? ? ? ?nn.ReLU(inplace=True),
? ? ? ? ? ?nn.MaxPool2d(kernel_size=2, stride=2),
? ? ? ? ? ?nn.Conv2d(64, 128, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(128),
? ? ? ? ? ?nn.ReLU(inplace=True),
? ? ? ? ? ?nn.Conv2d(128, 128, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(128),
? ? ? ? ? ?nn.ReLU(inplace=True),
? ? ? ? ? ?nn.MaxPool2d(kernel_size=2, stride=2),
? ? ? ? ? ?nn.Conv2d(128, 256, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(256),
? ? ? ? ? ?nn.ReLU(inplace=True),
? ? ? ? ? ?nn.Conv2d(256, 256, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(256),
? ? ? ? ? ?nn.ReLU(inplace=True),
? ? ? ? ? ?nn.Conv2d(256, 256, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(256),
? ? ? ? ? ?nn.ReLU(inplace=True),
? ? ? ? ? ?nn.MaxPool2d(kernel_size=2, stride=2),
? ? ? ? ? ?nn.Conv2d(256, 512, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(512),
? ? ? ? ? ?nn.ReLU(inplace=True),
? ? ? ? ? ?nn.Conv2d(512, 512, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(512),
? ? ? ? ? ?nn.ReLU(inplace=True),
? ? ? ? ? ?nn.Conv2d(512, 512, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(512),
? ? ? ? ? ?nn.ReLU(inplace=True),
? ? ? ? ? ?nn.MaxPool2d(kernel_size=2, stride=2),
? ? ? ? ? ?nn.Conv2d(512, 512, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(512),
? ? ? ? ? ?nn.ReLU(inplace=True),
? ? ? ? ? ?nn.Conv2d(512, 512, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(512),
? ? ? ? ? ?nn.ReLU(inplace=True),
? ? ? ? ? ?nn.Conv2d(512, 512, kernel_size=3, padding=1),
? ? ? ? ? ?nn.BatchNorm2d(512),
? ? ? ? ? ?nn.ReLU(inplace=True),
? ? ? ? ? ?nn.MaxPool2d(kernel_size=2, stride=2)
? ? ? ?)
? ? ? ?self._initialize_weights()
? ?def forward(self, x):
? ? ? ?# feature 1-4 is for copy-crop later, feature5 is the output
? ? ? ?feat1 = self.features[:6](x)
? ? ? ?feat2 = self.features[6:13](feat1)
? ? ? ?feat3 = self.features[13:23](feat2)
? ? ? ?feat4 = self.features[23:33](feat3)
? ? ? ?feat5 = self.features[33:-1](feat4)
? ? ? ?return [feat1, feat2, feat3, feat4, feat5]
? ?
class DecoderLayer(nn.Module):
? ?def __init__(self, in_size, out_size):
? ? ? ?super(DecoderLayer, self).__init__()
? ? ? ?self.conv1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1)
? ? ? ?self.conv2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1)
? ? ? ?self.up = nn.UpsamplingBilinear2d(scale_factor=2)
? ? ? ?self.relu = nn.ReLU(inplace=True)
? ?def forward(self, inputs1, inputs2):
? ? ? ?outputs = torch.cat([inputs1, self.up(inputs2)], 1)
? ? ? ?outputs = self.conv1(outputs)
? ? ? ?outputs = self.relu(outputs)
? ? ? ?outputs = self.conv2(outputs)
? ? ? ?outputs = self.relu(outputs)
? ? ? ?return outputs
class Decoder(nn.Module):
? ?def __init__(self, num_classes):
? ? ? ?super(Decoder, self).__init__()
? ? ? ?# upSampling
? ? ? ?# 64,64,512
? ? ? ?self.up_concat4 = DecoderLayer(1024, 512)
? ? ? ?# 128,128,256
? ? ? ?self.up_concat3 = DecoderLayer(768, 256)
? ? ? ?# 256,256,128
? ? ? ?self.up_concat2 = DecoderLayer(384, 128)
? ? ? ?# 512,512,64
? ? ? ?self.up_concat1 = DecoderLayer(192, 64)
? ? ? ?self.final = nn.Conv2d(64, num_classes, 1)
? ?def forward(self, feat1, feat2, feat3, feat4, feat5):
? ? ? ?up4 = self.up_concat4(feat4, feat5)
? ? ? ?up3 = self.up_concat3(feat3, up4)
? ? ? ?up2 = self.up_concat2(feat2, up3)
? ? ? ?up1 = self.up_concat1(feat1, up2)
? ? ? ?final = self.final(up1)
? ? ? ?return final
Swin-Unet
網(wǎng)絡(luò)結(jié)構(gòu)

編碼器
先對input圖像進行Patch Partition
,對應于代碼里PatchEmbed
類的實現(xiàn),將圖片切成patch_size*patch_size的圖塊,嵌入到Embedding
接三個Stage,每次先經(jīng)過Patch Merging
下采樣,在行方向和列方向上間隔2選取元素(對應Unet里的卷積用來降低分辨率),然后經(jīng)過兩個Swin Transformer
塊。但是代碼里的實現(xiàn)并不是按照圖示這樣分割的,它是將Patch Partition
, Linear Embedding
放在了前面開頭,然后以“兩個Swin Transformer
塊+Patch Merging
”作為一個BasicLayer,將Bottleneck
看做沒有后接Patch Merging
的BasicLayer。要是按著圖示去復現(xiàn),那Encoder跟官方代碼肯定是長得非常不一樣了。
最后再進行一次Patch Merging
。連接編解碼器的是Bottleneck
,其實就是兩個Swin Transformer
塊,此層圖片大小不變。
解碼器
上采樣操作由Patch Expanding
完成,每個Stage由上采樣和兩個Swin Tranformer
塊組成。但是代碼里的實現(xiàn)并不是按照圖示這樣分割的,從下往上走,它將第一個Patch Expanding單獨分出來,然后以“兩個Swin Transformer
塊+Patch Expanding
”作為一個BasicLayer_up,最后的Patch Expanding
稱為FinalPatchExpand_X4
,是因為只有它是以4倍上采樣。
左右每個Stage之間有跳躍連接,圖示位置不明確;具體講,是把編碼器每個Stage進入BasicLayer前的輸入,連接到解碼器進入每個Stage前的輸入。不過關(guān)于在哪里進行跳躍連接,作者的意思是“可以調(diào)整,不影響整體架構(gòu)”。
Swin Transformer

Swin-Unet的編碼器部分就是Swin Transformer的結(jié)構(gòu),每兩個Swin Tranformer塊的結(jié)構(gòu)如右b圖。
LN:LayerNorm
MLP:帶有GELU非線性的2層MLP
W-MSA:Window Attention
,Transformer是基于全局來計算注意力的。而Swin Transformer將注意力的計算限制在每個窗口內(nèi),進而減少了計算量
SW-MSA:Shifted Window Attention
,對特征圖移位,并給Attention設(shè)置mask來實現(xiàn)和Window Attention
相同的計算結(jié)果
我認為Swin-Unet是把Unet結(jié)構(gòu)和Swin Transformer放在一起時自然而然產(chǎn)生的想法。因為Unet的主干網(wǎng)絡(luò)本身就不一定是原本論文里的樣子,可以是Resnet、VGG,可以是TransUnet這種CNN和Transformer結(jié)合的形態(tài)。關(guān)鍵還是在于有跳躍連接,而Swin Transformer又更加高效輕量,長程注意力也有很大優(yōu)勢。
代碼復現(xiàn)
,結(jié)合自己的理解,寫一下核心部分的代碼:
import torch
from torch import nn
def no_weight_decay():
? ?return {'absolute_pos_embed'}
def no_weight_decay_keywords():
? ?return {'relative_position_bias_table'}
def _init_weights(m):
? ?if isinstance(m, nn.Linear):
? ? ? ?trunc_normal_(m.weight, std=.02)
? ? ? ?if isinstance(m, nn.Linear) and m.bias is not None:
? ? ? ? ? ?nn.init.constant_(m.bias, 0)
? ?elif isinstance(m, nn.LayerNorm):
? ? ? ?nn.init.constant_(m.bias, 0)
? ? ? ?nn.init.constant_(m.weight, 1.0)
class Swin_Unet(nn.Module):
? ?def __init__(self, img_size, patch_size, in_channels, num_classes,
? ? ? ? ? ? ? ? embed_dim, depths, num_heads,
? ? ? ? ? ? ? ? window_size, mlp_ratio, qkv_bias, qk_scale,
? ? ? ? ? ? ? ? drop_rate, attn_drop_rate, drop_path_rate):
? ? ? ?super().__init__()
? ? ? ?self.num_classes = num_classes
? ? ? ?self.num_layers = len(depths)
? ? ? ?self.embed_dim = embed_dim
? ? ? ?self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
? ? ? ?self.num_features_up = int(embed_dim * 2)
? ? ? ?self.mlp_ratio = mlp_ratio
? ? ? ?# patch partition 和 linear embedding
? ? ? ?# split image into non-overlapping patches
? ? ? ?self.patch_embed = PatchEmbed(
? ? ? ? ? ?img_size=img_size, patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
? ? ? ? ? ?norm_layer=nn.LayerNorm)
? ? ? ?patches_resolution = self.patch_embed.patches_resolution
? ? ? ?self.patches_resolution = patches_resolution
? ? ? ?# absolute position embedding,
? ? ? ?self.pos_drop = nn.Dropout(p=drop_rate)
? ? ? ?# stochastic depth
? ? ? ?dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
? ? ? ?# build encoderStages and bottleneck layers,每個BasicLayer包含兩個Swin Transformer Block和一個下采樣
? ? ? ?self.layers = nn.ModuleList()
? ? ? ?for i_layer in range(self.num_layers):
? ? ? ? ? ?layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? input_resolution=(patches_resolution[0] // (2 ** i_layer),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? patches_resolution[1] // (2 ** i_layer)),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? depth=depths[i_layer],
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? num_heads=num_heads[i_layer],
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? window_size=window_size,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? mlp_ratio=self.mlp_ratio,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? qkv_bias=qkv_bias, qk_scale=qk_scale,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? drop=drop_rate, attn_drop=attn_drop_rate,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? norm_layer=nn.LayerNorm,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? downsample=PatchMerging if (i_layer < self.num_layers - 1) else None) # bottleneck沒有下采樣
? ? ? ? ? ?self.layers.append(layer)
? ? ? ?# build decoder layers,解碼器每個Stage
? ? ? ?self.layers_up = nn.ModuleList()
? ? ? ?self.concat_back_dim = nn.ModuleList()
? ? ? ?for i_layer in range(self.num_layers):
? ? ? ? ? ?concat_linear = nn.Linear(2 * int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?int(embed_dim * 2 ** (
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?self.num_layers - 1 - i_layer))) if i_layer > 0 else nn.Identity()
? ? ? ? ? ?if i_layer == 0:
? ? ? ? ? ? ? ?layer_up = PatchExpand(
? ? ? ? ? ? ? ? ? ?input_resolution=(patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))),
? ? ? ? ? ? ? ? ? ?dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), dim_scale=2, norm_layer=norm_layer)
? ? ? ? ? ?else:
? ? ? ? ? ? ? ?layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? input_resolution=(
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? depth=depths[(self.num_layers - 1 - i_layer)],
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? num_heads=num_heads[(self.num_layers - 1 - i_layer)],
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? window_size=window_size,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? mlp_ratio=self.mlp_ratio,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? qkv_bias=qkv_bias, qk_scale=qk_scale,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? drop=drop_rate, attn_drop=attn_drop_rate,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? drop_path=dpr[sum(depths[:(self.num_layers - 1 - i_layer)]):sum(
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? depths[:(self.num_layers - 1 - i_layer) + 1])],
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? norm_layer=nn.LayerNorm,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? upsample=PatchExpand if (i_layer < self.num_layers - 1) else None)
? ? ? ? ? ?self.layers_up.append(layer_up)
? ? ? ? ? ?self.concat_back_dim.append(concat_linear)
? ? ? ?self.norm = nn.LayerNorm(self.num_features)
? ? ? ?self.norm_up = nn.LayerNorm(self.embed_dim)
? ? ? ?self.up = FinalPatchExpand_X4(input_resolution=(img_size // patch_size, img_size // patch_size),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?dim_scale=4, dim=embed_dim)
? ? ? ?self.output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False)
? ? ? ?self.apply(_init_weights)
? ?# Encoder and Bottleneck
? ?def forward_features(self, x):
? ? ? ?x = self.patch_embed(x)
? ? ? ?if self.ape:
? ? ? ? ? ?x = x + self.absolute_pos_embed
? ? ? ?x = self.pos_drop(x)
? ? ? ?x_down_sample = []
? ? ? ?for layer in self.layers:
? ? ? ? ? ?x_down_sample.append(x)
? ? ? ? ? ?x = layer(x)
? ? ? ?x = self.norm(x) ?# B L C
? ? ? ?return x, x_down_sample
? ?# Decoder and Skip connection
? ?def forward_up_features(self, x, x_down_sample):
? ? ? ?for inx, layer_up in enumerate(self.layers_up):
? ? ? ? ? ?if inx == 0:
? ? ? ? ? ? ? ?x = layer_up(x)
? ? ? ? ? ?else:
? ? ? ? ? ? ? ?x = torch.cat([x, x_down_sample[3 - inx]], -1)
? ? ? ? ? ? ? ?x = self.concat_back_dim[inx](x)
? ? ? ? ? ? ? ?x = layer_up(x)
? ? ? ?x = self.norm_up(x) ?# B L C
? ? ? ?return x
? ?def up_x4(self, x):
? ? ? ?H, W = self.patches_resolution
? ? ? ?B, L, C = x.shape
? ? ? ?assert L == H * W, "input features has wrong size"
? ? ? ?x = self.up(x)
? ? ? ?x = x.view(B, 4 * H, 4 * W, -1)
? ? ? ?x = x.permute(0, 3, 1, 2) ?# B,C,H,W
? ? ? ?x = self.output(x)
? ? ? ?return x
? ?def forward(self, x):
? ? ? ?x, x_down_sample = self.forward_features(x)
? ? ? ?x = self.forward_up_features(x, x_down_sample)
? ? ? ?x = self.up_x4(x)
? ? ? ?return x
? ?def flops(self):
? ? ? ?flops = 0
? ? ? ?flops += self.patch_embed.flops()
? ? ? ?for i, layer in enumerate(self.layers):
? ? ? ? ? ?flops += layer.flops()
? ? ? ?flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
? ? ? ?flops += self.num_features * self.num_classes
? ? ? ?return flops
訓練與推理
PyTorch 1.11.0
Python 3.8(ubuntu20.04)
Cuda 11.3
GPU:RTX 4090(24GB) * 1
CPU:22 vCPU AMD EPYC 7T83 64-Core Processor
內(nèi)存:90GB
數(shù)據(jù)集:Synapse多器官分割
num_classes=9(0背景,1-8器官)
Unet:epoch=150,batch_size=16
Swin-Unet:epoch=150,batch_size=24(原論文)
很玄學,我這邊換成2冪的batch_size,模型表現(xiàn)反而下降。
Run the train script on synapse dataset. The batch size we used is 24. If you do not have enough GPU memory, the bacth size can be reduced to 12 or 6 to save memory.
測試集表現(xiàn):使用Swin-Unet論文相同的指標Dice-Similarity coefficient (DSC↑) 和 average Hausdorff Distance(HD↓)

復現(xiàn)結(jié)果來看,Swin-Unet的效果稍弱于論文,確實不太明白作者是怎么做到79.13的,如果顯卡和隨機種子不是關(guān)鍵因素的話,那最可能是因為論文實際上用了不一樣的學習率策略。(我嘗試了更多epoch到500,但并沒有明顯的提升,也沒有明顯的過擬合,所以學習率可能是影響因素)但是4090到期了,沒錢繼續(xù)煉了。深度學習,富人的游戲。
In our code, we carefully set the random seed, so the results should be consistent when trained multiple times on the same type of GPU. If the training does not give the same segmentation results as in the paper, it is recommended to adjust the learning rate.
而Unet的效果稍強于論文。
不過相對而言,Swin-Unet的優(yōu)勢是顯而易見的。與論文相似,HD的下降了14%,在后面幾個器官分類上,都有不小改進。