"""
Implementation of `U-net: Convolutional networks for biomedical image segmentation <https://arxiv.org/pdf/1505.04597>`_
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# Public API of this module: the four U-Net variants, named after the default
# square input size each one declares in its constructor.
__all__ = ['UNet128', 'UNet256', 'UNet512', 'UNet1024']
# Epsilon passed to the nn.BatchNorm2d layer created inside ConvBnRelu2d.
BN_EPS = 1e-4
## ==== EXPERIMENTAL - ought to try ===== ##
# Source: https://github.com/qbit-/unet/blob/master/model/unet_work.py
class Conv3BN(nn.Module):
    """3x3 convolution, optionally followed by batch normalization, then ELU.

    The convolution is built with ``padding=1`` and the default stride.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        bn: When true, a ``BatchNorm2d`` layer is inserted after the
            convolution.

    Attributes:
        conv: The 3x3 ``nn.Conv2d`` layer.
        bn: The ``nn.BatchNorm2d`` layer, or ``None`` when disabled.
        activation: The in-place ``nn.ELU`` non-linearity.
    """

    def __init__(self, in_ch: int, out_ch: int, bn=True):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        if bn:
            self.bn = nn.BatchNorm2d(out_ch)
        else:
            self.bn = None
        self.activation = nn.ELU(inplace=True)

    def forward(self, x):
        """Apply conv -> (batch norm) -> ELU to ``x`` and return the result."""
        features = self.conv(x)
        if self.bn is None:
            return self.activation(features)
        return self.activation(self.bn(features))
## == END == ##
class ConvBnRelu2d(nn.Module):
    """Bias-free 2-D convolution followed by optional batch norm and ReLU.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution.
        kernel_size: Convolution kernel size.
        padding: Convolution padding.
        dilation: Convolution dilation.
        stride: Convolution stride.
        groups: Convolution groups.
        is_bn: Falsy to omit the batch-normalization layer.
        is_relu: Falsy to omit the ReLU.

    Attributes:
        conv: The ``nn.Conv2d`` layer (created with ``bias=False``; it gains a
            bias only after :meth:`merge_bn`).
        bn: The ``nn.BatchNorm2d`` layer, or ``None`` when disabled or merged.
        relu: The in-place ``nn.ReLU``, or ``None`` when disabled.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dilation=1, stride=1, groups=1, is_bn=True, is_relu=True):
        super(ConvBnRelu2d, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=False,
        )
        # Truthiness (rather than an `is False` identity test) so that falsy
        # flags such as 0, None or numpy.bool_(False) also disable the layer.
        self.bn = nn.BatchNorm2d(out_channels, eps=BN_EPS) if is_bn else None
        self.relu = nn.ReLU(inplace=True) if is_relu else None

    def forward(self, x):
        """Apply conv, then batch norm and ReLU when they are present."""
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

    def merge_bn(self):
        """Fold the batch-norm layer into the convolution.

        The fold uses the batch-norm running statistics (the values
        ``BatchNorm2d`` applies in eval mode). Afterwards ``self.bn`` is
        ``None`` and ``self.conv`` is a new ``nn.Conv2d`` with a bias that
        carries the folded scale and shift. Does nothing when there is no
        batch-norm layer.

        Raises:
            AssertionError: If the convolution already has a bias.
        """
        if self.bn is None:
            return
        if self.conv.bias is not None:
            raise AssertionError(
                "merge_bn expects a bias-free convolution, but self.conv already has a bias"
            )

        conv_weight = self.conv.weight.data
        bn_weight = self.bn.weight.data
        bn_bias = self.bn.bias.data
        running_mean = self.bn.running_mean
        running_var = self.bn.running_var

        # https://github.com/sanghoon/pva-faster-rcnn/issues/5
        # https://github.com/sanghoon/pva-faster-rcnn/commit/39570aab8c6513f0e76e5ab5dba8dfbf63e9c68c
        # BN(y) = bn_weight * (y - mean) / sqrt(var + eps) + bn_bias, so each
        # output channel of the conv is rescaled by `scale` and shifted.
        inv_std = 1 / torch.sqrt(running_var + self.bn.eps)
        scale = inv_std * bn_weight
        # One scale per output channel, broadcast over (in, kh, kw).
        fused_weight = scale.view(-1, 1, 1, 1) * conv_weight
        fused_bias = bn_bias - scale * running_mean

        old_conv = self.conv
        self.bn = None
        self.conv = nn.Conv2d(
            in_channels=old_conv.in_channels,
            out_channels=old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            padding=old_conv.padding,
            stride=old_conv.stride,
            dilation=old_conv.dilation,
            groups=old_conv.groups,
            bias=True,
        )
        # Overwrite the freshly initialised parameters with the folded values.
        self.conv.weight.data = fused_weight
        self.conv.bias.data = fused_bias
class ConvResidual(nn.Module):
    """Residual block: two 3x3 ConvBnRelu2d layers plus a shortcut, then ReLU.

    The second 3x3 layer has no ReLU of its own; the ReLU is applied after the
    shortcut has been added.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Stride of the block. It is applied to the first 3x3
            convolution and to the 1x1 shortcut convolution so that both
            branches produce tensors of the same spatial size.

    Attributes:
        block: The residual branch (``nn.Sequential`` of two ConvBnRelu2d).
        shortcut: A 1x1 ``nn.Conv2d`` projection, or ``None`` when the input
            can be added unchanged (same channel count and ``stride == 1``).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(ConvResidual, self).__init__()
        self.block = nn.Sequential(
            # The stride must be applied here as well as on the shortcut;
            # otherwise the two branches disagree on spatial size whenever
            # stride != 1 and the addition in forward() is invalid.
            ConvBnRelu2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride),
            ConvBnRelu2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, is_relu=False),
        )
        self.shortcut = None
        if in_channels != out_channels or stride != 1:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, stride=stride, bias=True)

    def forward(self, x):
        """Return ``relu(block(x) + shortcut(x))``."""
        r = x if self.shortcut is None else self.shortcut(x)
        x = self.block(x)
        x = F.relu(x + r, inplace=True)
        return x
## -----------------------------------------------------------------------------------------------------------
## original 3x3 stack filters used in UNet
class StackEncoder(nn.Module):
    """Encoder stage: two ConvBnRelu2d layers followed by 2x2 max pooling.

    Args:
        x_channels: Number of input channels.
        y_channels: Number of output channels of both convolutions.
        kernel_size: Kernel size of both convolutions; the padding is derived
            from it as ``(kernel_size - 1) // 2``.
    """

    def __init__(self, x_channels, y_channels, kernel_size=3):
        super().__init__()
        pad = (kernel_size - 1) // 2
        conv_kwargs = dict(kernel_size=kernel_size, padding=pad, dilation=1, stride=1, groups=1)
        first = ConvBnRelu2d(x_channels, y_channels, **conv_kwargs)
        second = ConvBnRelu2d(y_channels, y_channels, **conv_kwargs)
        self.encode = nn.Sequential(first, second)

    def forward(self, x):
        """Return ``(features, pooled)``.

        ``features`` is the output of the convolution stack and ``pooled`` is
        the same tensor after max pooling with kernel 2 and stride 2.
        """
        features = self.encode(x)
        pooled = F.max_pool2d(features, kernel_size=2, stride=2)
        return features, pooled
class StackDecoder(nn.Module):
    """Decoder stage: upsample, concatenate the skip tensor, three convs.

    Args:
        x_big_channels: Number of channels of the skip tensor ``x_big``.
        x_channels: Number of channels of the lower-resolution tensor ``x``.
        y_channels: Number of output channels.
        kernel_size: Kernel size of the three convolutions; the padding is
            derived from it as ``(kernel_size - 1) // 2``.
    """

    def __init__(self, x_big_channels, x_channels, y_channels, kernel_size=3):
        super(StackDecoder, self).__init__()
        padding = (kernel_size - 1) // 2
        self.decode = nn.Sequential(
            ConvBnRelu2d(x_big_channels + x_channels, y_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=1, groups=1),
            ConvBnRelu2d(y_channels, y_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=1, groups=1),
            ConvBnRelu2d(y_channels, y_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=1, groups=1),
        )

    def forward(self, x_big, x):
        """Upsample ``x`` to the size of ``x_big``, concatenate and decode.

        Args:
            x_big: Skip tensor whose last two dimensions give the target size.
            x: Tensor to upsample bilinearly to that size.

        Returns:
            The output of the convolution stack applied to the channel-wise
            concatenation ``[upsampled x, x_big]``.
        """
        height, width = x_big.shape[-2:]
        # align_corners=False is what bilinear mode already uses when the
        # argument is omitted; passing it explicitly keeps the result the same
        # and avoids the "default behavior changed" warning on older PyTorch.
        y = F.interpolate(x, size=(height, width), mode='bilinear', align_corners=False)
        y = torch.cat([y, x_big], 1)
        return self.decode(y)
##---------------------------------------------------------------
## residual variants of the stack encoder/decoder (built on ConvResidual)
class ResStackEncoder(nn.Module):
    """Encoder stage made of one ConvResidual block and 2x2 max pooling.

    Args:
        x_channels: Number of input channels.
        y_channels: Number of output channels of the residual block.
    """

    def __init__(self, x_channels, y_channels):
        super().__init__()
        self.encode = ConvResidual(x_channels, y_channels)

    def forward(self, x):
        """Return ``(features, pooled)``.

        ``features`` is the residual block output and ``pooled`` is the same
        tensor after max pooling with kernel 2 and stride 2.
        """
        features = self.encode(x)
        pooled = F.max_pool2d(features, kernel_size=2, stride=2)
        return features, pooled
class ResStackDecoder(nn.Module):
    """Decoder stage: upsample, concatenate the skip tensor, conv + residual.

    Args:
        x_big_channels: Number of channels of the skip tensor ``x_big``.
        x_channels: Number of channels of the lower-resolution tensor ``x``.
        y_channels: Number of output channels.
        kernel_size: Kernel size of the first convolution; its padding is
            derived as ``(kernel_size - 1) // 2``.
    """

    def __init__(self, x_big_channels, x_channels, y_channels, kernel_size=3):
        super(ResStackDecoder, self).__init__()
        padding = (kernel_size - 1) // 2
        self.decode = nn.Sequential(
            ConvBnRelu2d(x_big_channels + x_channels, y_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=1, groups=1),
            ConvResidual(y_channels, y_channels),
        )

    def forward(self, x_big, x):
        """Upsample ``x`` to the size of ``x_big``, concatenate and decode.

        Args:
            x_big: Skip tensor whose last two dimensions give the target size.
            x: Tensor to upsample bilinearly to that size.

        Returns:
            The output of ``self.decode`` applied to the channel-wise
            concatenation ``[upsampled x, x_big]``.
        """
        height, width = x_big.shape[-2:]
        # align_corners=False is what bilinear mode already uses when the
        # argument is omitted; passing it explicitly keeps the result the same
        # and avoids the "default behavior changed" warning on older PyTorch.
        y = F.interpolate(x, size=(height, width), mode='bilinear', align_corners=False)
        y = torch.cat([y, x_big], 1)
        return self.decode(y)
##---------------------------------------------------------------
# baseline 128x128, 256x256, 512x512, 1024x1024 for experiments -----------------------------------------------
# 1024x1024
class UNet1024(nn.Module):
    """U-Net with six encoder/decoder stages (default input 1024x1024).

    Args:
        in_shape: ``(C, H, W)`` of the input; only ``C`` (the number of input
            channels) is used to build the network.
        **kwargs: Accepted and ignored.

    The forward pass returns the un-activated output of a 1x1 convolution
    with a single output channel.
    """

    def __init__(self, in_shape=(3, 1024, 1024), **kwargs):
        super(UNet1024, self).__init__()
        # Unpacking all three entries keeps the check that in_shape has
        # exactly three elements, even though H and W are not needed.
        C, _, _ = in_shape

        # Encoder. The trailing comment is the spatial size after pooling,
        # for the default 1024x1024 input.
        self.down1 = StackEncoder(C, 24, kernel_size=3)     # 512
        self.down2 = StackEncoder(24, 64, kernel_size=3)    # 256
        self.down3 = StackEncoder(64, 128, kernel_size=3)   # 128
        self.down4 = StackEncoder(128, 256, kernel_size=3)  # 64
        self.down5 = StackEncoder(256, 512, kernel_size=3)  # 32
        self.down6 = StackEncoder(512, 768, kernel_size=3)  # 16

        self.center = nn.Sequential(
            ConvBnRelu2d(768, 768, kernel_size=3, padding=1, stride=1),
        )

        # Decoder arguments: x_big_channels, x_channels, y_channels.
        # The trailing comment is the spatial size of the stage output.
        self.up6 = StackDecoder(768, 768, 512, kernel_size=3)  # 32
        self.up5 = StackDecoder(512, 512, 256, kernel_size=3)  # 64
        self.up4 = StackDecoder(256, 256, 128, kernel_size=3)  # 128
        self.up3 = StackDecoder(128, 128, 64, kernel_size=3)   # 256
        self.up2 = StackDecoder(64, 64, 24, kernel_size=3)     # 512
        self.up1 = StackDecoder(24, 24, 24, kernel_size=3)     # 1024
        self.classify = nn.Conv2d(24, 1, kernel_size=1, padding=0, stride=1, bias=True)

    def forward(self, x):
        """Run the encoder, center block, decoder and 1x1 classifier on ``x``."""
        out = x
        # Each encoder returns (skip tensor, pooled tensor).
        down1, out = self.down1(out)
        down2, out = self.down2(out)
        down3, out = self.down3(out)
        down4, out = self.down4(out)
        down5, out = self.down5(out)
        down6, out = self.down6(out)

        out = self.center(out)

        # Decoders consume the skip tensors in reverse order.
        out = self.up6(down6, out)
        out = self.up5(down5, out)
        out = self.up4(down4, out)
        out = self.up3(down3, out)
        out = self.up2(down2, out)
        out = self.up1(down1, out)

        out = self.classify(out)
        return out
# 512x512
class UNet512(nn.Module):
    """U-Net with five encoder/decoder stages (default input 512x512).

    Args:
        in_shape: ``(C, H, W)`` of the input; only ``C`` (the number of input
            channels) is used to build the network.
        **kwargs: Accepted and ignored.

    The forward pass returns the un-activated output of a 1x1 convolution
    with a single output channel.
    """

    def __init__(self, in_shape=(3, 512, 512), **kwargs):
        super(UNet512, self).__init__()
        # Unpacking all three entries keeps the check that in_shape has
        # exactly three elements, even though H and W are not needed.
        C, _, _ = in_shape

        # Encoder. The trailing comment is the spatial size after pooling,
        # for the default 512x512 input.
        self.down2 = StackEncoder(C, 64, kernel_size=3)      # 256
        self.down3 = StackEncoder(64, 128, kernel_size=3)    # 128
        self.down4 = StackEncoder(128, 256, kernel_size=3)   # 64
        self.down5 = StackEncoder(256, 512, kernel_size=3)   # 32
        self.down6 = StackEncoder(512, 1024, kernel_size=3)  # 16

        self.center = nn.Sequential(
            ConvBnRelu2d(1024, 1024, kernel_size=3, padding=1, stride=1),
        )

        # Decoder arguments: x_big_channels, x_channels, y_channels.
        # The trailing comment is the spatial size of the stage output.
        self.up6 = StackDecoder(1024, 1024, 512, kernel_size=3)  # 32
        self.up5 = StackDecoder(512, 512, 256, kernel_size=3)    # 64
        self.up4 = StackDecoder(256, 256, 128, kernel_size=3)    # 128
        self.up3 = StackDecoder(128, 128, 64, kernel_size=3)     # 256
        self.up2 = StackDecoder(64, 64, 32, kernel_size=3)       # 512
        self.classify = nn.Conv2d(32, 1, kernel_size=1, padding=0, stride=1, bias=True)

    def forward(self, x):
        """Run the encoder, center block, decoder and 1x1 classifier on ``x``."""
        out = x
        # Each encoder returns (skip tensor, pooled tensor).
        down2, out = self.down2(out)
        down3, out = self.down3(out)
        down4, out = self.down4(out)
        down5, out = self.down5(out)
        down6, out = self.down6(out)

        out = self.center(out)

        # Decoders consume the skip tensors in reverse order.
        out = self.up6(down6, out)
        out = self.up5(down5, out)
        out = self.up4(down4, out)
        out = self.up3(down3, out)
        out = self.up2(down2, out)

        out = self.classify(out)
        return out
# 256x256
class UNet256(nn.Module):
    """U-Net with five encoder/decoder stages (default input 256x256).

    Args:
        in_shape: ``(C, H, W)`` of the input; only ``C`` (the number of input
            channels) is used to build the network.
        **kwargs: Accepted and ignored.

    The forward pass returns the un-activated output of a 1x1 convolution
    with a single output channel.
    """

    def __init__(self, in_shape=(3, 256, 256), **kwargs):
        super(UNet256, self).__init__()
        # Unpacking all three entries keeps the check that in_shape has
        # exactly three elements, even though H and W are not needed.
        C, _, _ = in_shape

        # Encoder. The trailing comment is the spatial size after pooling,
        # for the default 256x256 input.
        self.down2 = StackEncoder(C, 64, kernel_size=3)      # 128
        self.down3 = StackEncoder(64, 128, kernel_size=3)    # 64
        self.down4 = StackEncoder(128, 256, kernel_size=3)   # 32
        self.down5 = StackEncoder(256, 512, kernel_size=3)   # 16
        self.down6 = StackEncoder(512, 1024, kernel_size=3)  # 8

        self.center = nn.Sequential(
            ConvBnRelu2d(1024, 1024, kernel_size=3, padding=1, stride=1),
        )

        # Decoder arguments: x_big_channels, x_channels, y_channels.
        # The trailing comment is the spatial size of the stage output.
        self.up6 = StackDecoder(1024, 1024, 512, kernel_size=3)  # 16
        self.up5 = StackDecoder(512, 512, 256, kernel_size=3)    # 32
        self.up4 = StackDecoder(256, 256, 128, kernel_size=3)    # 64
        self.up3 = StackDecoder(128, 128, 64, kernel_size=3)     # 128
        self.up2 = StackDecoder(64, 64, 32, kernel_size=3)       # 256
        self.classify = nn.Conv2d(32, 1, kernel_size=1, padding=0, stride=1, bias=True)

    def forward(self, x):
        """Run the encoder, center block, decoder and 1x1 classifier on ``x``."""
        out = x
        # Each encoder returns (skip tensor, pooled tensor).
        down2, out = self.down2(out)
        down3, out = self.down3(out)
        down4, out = self.down4(out)
        down5, out = self.down5(out)
        down6, out = self.down6(out)

        out = self.center(out)

        # Decoders consume the skip tensors in reverse order.
        out = self.up6(down6, out)
        out = self.up5(down5, out)
        out = self.up4(down4, out)
        out = self.up3(down3, out)
        out = self.up2(down2, out)

        out = self.classify(out)
        return out
# 128x128
class UNet128(nn.Module):
    """U-Net with four encoder/decoder stages (default input 128x128).

    Args:
        in_shape: ``(C, H, W)`` of the input; only ``C`` (the number of input
            channels) is used to build the network.
        **kwargs: Accepted and ignored.

    The forward pass returns the un-activated output of a 1x1 convolution
    with a single output channel.
    """

    def __init__(self, in_shape=(3, 128, 128), **kwargs):
        super(UNet128, self).__init__()
        # Unpacking all three entries keeps the check that in_shape has
        # exactly three elements, even though H and W are not needed.
        C, _, _ = in_shape

        # Encoder. The trailing comment is the spatial size after pooling,
        # for the default 128x128 input.
        self.down3 = StackEncoder(C, 128, kernel_size=3)     # 64
        self.down4 = StackEncoder(128, 256, kernel_size=3)   # 32
        self.down5 = StackEncoder(256, 512, kernel_size=3)   # 16
        self.down6 = StackEncoder(512, 1024, kernel_size=3)  # 8

        self.center = nn.Sequential(
            ConvBnRelu2d(1024, 1024, kernel_size=3, padding=1, stride=1),
        )

        # Decoder arguments: x_big_channels, x_channels, y_channels.
        # The trailing comment is the spatial size of the stage output.
        self.up6 = StackDecoder(1024, 1024, 512, kernel_size=3)  # 16
        self.up5 = StackDecoder(512, 512, 256, kernel_size=3)    # 32
        self.up4 = StackDecoder(256, 256, 128, kernel_size=3)    # 64
        self.up3 = StackDecoder(128, 128, 64, kernel_size=3)     # 128
        self.classify = nn.Conv2d(64, 1, kernel_size=1, padding=0, stride=1, bias=True)

    def forward(self, x):
        """Run the encoder, center block, decoder and 1x1 classifier on ``x``."""
        out = x
        # Each encoder returns (skip tensor, pooled tensor).
        down3, out = self.down3(out)
        down4, out = self.down4(out)
        down5, out = self.down5(out)
        down6, out = self.down6(out)

        out = self.center(out)

        # Decoders consume the skip tensors in reverse order.
        out = self.up6(down6, out)
        out = self.up5(down5, out)
        out = self.up4(down4, out)
        out = self.up3(down3, out)

        out = self.classify(out)
        return out