# Source code for pywick.models.segmentation.carvana_unet

"""
Implementation of `U-net: Convolutional networks for biomedical image segmentation <https://arxiv.org/pdf/1505.04597>`_
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['UNet128', 'UNet256', 'UNet512', 'UNet1024']

BN_EPS = 1e-4

## ==== EXPERIMENTAL - worth trying ===== ##
# Source: https://github.com/qbit-/unet/blob/master/model/unet_work.py
class Conv3BN(nn.Module):
    """A module which applies the following actions:
        - convolution with 3x3 kernel;
        - batch normalization (if enabled);
        - ELU.
    Attributes:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        bn: A boolean indicating if Batch Normalization is enabled or not.
    """

    def __init__(self, in_ch: int, out_ch: int, bn=True):
        super(Conv3BN, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.bn = nn.BatchNorm2d(out_ch) if bn else None
        self.activation = nn.ELU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        x = self.activation(x)
        return x
## == END == ##
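
# Illustrative sketch (not part of the original module): Conv3BN keeps the
# spatial resolution unchanged because its 3x3 convolution uses padding=1, so
# only the channel count changes. The helper name `_demo_conv3bn` is
# hypothetical and exists purely for this demonstration.
def _demo_conv3bn():
    block = Conv3BN(in_ch=3, out_ch=16)
    x = torch.randn(1, 3, 64, 64)
    y = block(x)
    assert y.shape == (1, 16, 64, 64)   # same H and W, new channel count
    return y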

class ConvBnRelu2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dilation=1, stride=1, groups=1, is_bn=True, is_relu=True):
        super(ConvBnRelu2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, stride=stride, dilation=dilation, groups=groups, bias=False)
        self.bn   = nn.BatchNorm2d(out_channels, eps=BN_EPS) if is_bn else None
        self.relu = nn.ReLU(inplace=True) if is_relu else None


    def forward(self,x):
        x = self.conv(x)
        if self.bn   is not None: x = self.bn(x)
        if self.relu is not None: x = self.relu(x)
        return x


    def merge_bn(self):
        if self.bn is None: return

        if self.conv.bias is not None:
            raise AssertionError('merge_bn expects a convolution without a bias term')
        conv_weight     = self.conv.weight.data
        bn_weight       = self.bn.weight.data
        bn_bias         = self.bn.bias.data
        bn_running_mean = self.bn.running_mean
        bn_running_var  = self.bn.running_var
        bn_eps          = self.bn.eps

        #https://github.com/sanghoon/pva-faster-rcnn/issues/5
        #https://github.com/sanghoon/pva-faster-rcnn/commit/39570aab8c6513f0e76e5ab5dba8dfbf63e9c68c

        N,C,KH,KW = conv_weight.size()
        std = 1/(torch.sqrt(bn_running_var+bn_eps))
        std_bn_weight =(std*bn_weight).repeat(C*KH*KW,1).t().contiguous().view(N,C,KH,KW )
        conv_weight_hat = std_bn_weight*conv_weight
        conv_bias_hat   = (bn_bias - bn_weight*std*bn_running_mean)

        self.bn   = None
        self.conv = nn.Conv2d(in_channels=self.conv.in_channels, out_channels=self.conv.out_channels, kernel_size=self.conv.kernel_size,
                              padding=self.conv.padding, stride=self.conv.stride, dilation=self.conv.dilation, groups=self.conv.groups,
                              bias=True)
        self.conv.weight.data = conv_weight_hat #fill in
        self.conv.bias.data   = conv_bias_hat
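
# Illustrative check (not part of the original module): after merge_bn() the
# folded convolution should reproduce the Conv+BN output in eval mode, because
# the running statistics used in the folding above are exactly the ones BN
# applies at inference time. `_check_merge_bn` is a hypothetical helper added
# only as a sketch of how the folding can be verified.
def _check_merge_bn():
    block = ConvBnRelu2d(3, 8, kernel_size=3, padding=1)
    block.eval()                       # BN uses running_mean / running_var
    x = torch.randn(2, 3, 32, 32)
    with torch.no_grad():
        y_before = block(x)
        block.merge_bn()               # fold BN into the conv weight and bias
        y_after = block(x)
    return torch.allclose(y_before, y_after, atol=1e-5)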



class ConvResidual (nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ConvResidual, self).__init__()

        self.block = nn.Sequential(
            ConvBnRelu2d(in_channels,  out_channels, kernel_size=3, padding=1,  stride=1 ),
            ConvBnRelu2d(out_channels, out_channels, kernel_size=3, padding=1,  stride=1, is_relu=False),
        )
        self.shortcut = None
        if in_channels!=out_channels or stride!=1:
           self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, stride=stride,  bias=True)

    def forward(self, x):
        r = x if self.shortcut is None else self.shortcut(x)
        x = self.block(x)
        x = F.relu(x+r, inplace=True)
        return x
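
# Illustrative sketch (not part of the original module): ConvResidual preserves
# the spatial size and adds a 1x1 shortcut only when the channel counts differ,
# so the residual sum is always well defined for the default stride of 1.
# `_demo_conv_residual` is a hypothetical name used only for this example.
def _demo_conv_residual():
    block = ConvResidual(in_channels=24, out_channels=64)
    x = torch.randn(1, 24, 32, 32)
    y = block(x)
    assert y.shape == (1, 64, 32, 32)   # channels adapted via the 1x1 shortcut
    return y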



## -----------------------------------------------------------------------------------------------------------

## original 3x3 stack filters used in UNet
class StackEncoder (nn.Module):
    def __init__(self, x_channels, y_channels, kernel_size=3):
        super(StackEncoder, self).__init__()
        padding=(kernel_size-1)//2
        self.encode = nn.Sequential(
            ConvBnRelu2d(x_channels, y_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=1, groups=1),
            ConvBnRelu2d(y_channels, y_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=1, groups=1),
        )

    def forward(self,x):
        y = self.encode(x)
        y_small = F.max_pool2d(y, kernel_size=2, stride=2)
        return y, y_small


class StackDecoder (nn.Module):
    def __init__(self, x_big_channels, x_channels, y_channels, kernel_size=3):
        super(StackDecoder, self).__init__()
        padding=(kernel_size-1)//2

        self.decode = nn.Sequential(
            ConvBnRelu2d(x_big_channels+x_channels, y_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=1, groups=1),
            ConvBnRelu2d(y_channels, y_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=1, groups=1),
            ConvBnRelu2d(y_channels, y_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=1, groups=1),
        )

    def forward(self, x_big, x):
        N,C,H,W = x_big.size()
        y = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False)
        y = torch.cat([y,x_big],1)
        y = self.decode(y)
        return  y
##---------------------------------------------------------------
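
# Illustrative sketch (not part of the original module): one encoder/decoder
# pair. StackEncoder returns the full-resolution feature map (kept for the skip
# connection) together with a 2x max-pooled copy; StackDecoder upsamples its
# second input back to the skip's resolution, concatenates the two along the
# channel axis and convolves. `_demo_stack_pair` is a hypothetical helper.
def _demo_stack_pair():
    enc = StackEncoder(3, 24, kernel_size=3)
    dec = StackDecoder(x_big_channels=24, x_channels=24, y_channels=16, kernel_size=3)
    x = torch.randn(1, 3, 64, 64)
    skip, down = enc(x)                # skip: (1, 24, 64, 64), down: (1, 24, 32, 32)
    out = dec(skip, down)              # upsample `down` to 64x64, concat with `skip`
    assert out.shape == (1, 16, 64, 64)
    return out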


## residual variant of the 3x3 stack filters used in UNet
class ResStackEncoder (nn.Module):
    def __init__(self, x_channels, y_channels):
        super(ResStackEncoder, self).__init__()
        self.encode = ConvResidual(x_channels, y_channels)

    def forward(self,x):
        y = self.encode(x)
        y_small = F.max_pool2d(y, kernel_size=2, stride=2)
        return y, y_small


class ResStackDecoder (nn.Module):
    def __init__(self, x_big_channels, x_channels, y_channels, kernel_size=3):
        super(ResStackDecoder, self).__init__()
        padding=(kernel_size-1)//2

        self.decode = nn.Sequential(
            ConvBnRelu2d(x_big_channels+x_channels, y_channels, kernel_size=kernel_size, padding=padding, dilation=1, stride=1, groups=1),
            ConvResidual(y_channels, y_channels)
        )

    def forward(self, x_big, x):
        N,C,H,W = x_big.size()
        y = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False)
        #y = F.interpolate(x, scale_factor=2,mode='bilinear')
        y = torch.cat([y,x_big],1)
        y = self.decode(y)
        return  y


##---------------------------------------------------------------


# baseline 128x128, 256x256, 512x512, 1024x1024 for experiments -----------------------------------------------

# 1024x1024
class UNet1024 (nn.Module):
    def __init__(self, in_shape=(3, 1024, 1024), **kwargs):
        super(UNet1024, self).__init__()
        C, H, W = in_shape
        # assert(C==3)

        # 1024
        self.down1 = StackEncoder(  C,   24, kernel_size=3)   # 512
        self.down2 = StackEncoder( 24,   64, kernel_size=3)   # 256
        self.down3 = StackEncoder( 64,  128, kernel_size=3)   # 128
        self.down4 = StackEncoder(128,  256, kernel_size=3)   # 64
        self.down5 = StackEncoder(256,  512, kernel_size=3)   # 32
        self.down6 = StackEncoder(512,  768, kernel_size=3)   # 16

        self.center = nn.Sequential(
            ConvBnRelu2d(768, 768, kernel_size=3, padding=1, stride=1),
        )

        # 8
        # x_big_channels, x_channels, y_channels
        self.up6 = StackDecoder(768, 768, 512, kernel_size=3)   # 16
        self.up5 = StackDecoder(512, 512, 256, kernel_size=3)   # 32
        self.up4 = StackDecoder(256, 256, 128, kernel_size=3)   # 64
        self.up3 = StackDecoder(128, 128,  64, kernel_size=3)   # 128
        self.up2 = StackDecoder( 64,  64,  24, kernel_size=3)   # 256
        self.up1 = StackDecoder( 24,  24,  24, kernel_size=3)   # 512
        self.classify = nn.Conv2d(24, 1, kernel_size=1, padding=0, stride=1, bias=True)

    def forward(self, x):
        out = x                        # ;print('x    ', x.size())

        down1, out = self.down1(out)   # ;print('down1', down1.size())  # 256
        down2, out = self.down2(out)   # ;print('down2', down2.size())  # 128
        down3, out = self.down3(out)   # ;print('down3', down3.size())  # 64
        down4, out = self.down4(out)   # ;print('down4', down4.size())  # 32
        down5, out = self.down5(out)   # ;print('down5', down5.size())  # 16
        down6, out = self.down6(out)   # ;print('down6', down6.size())  # 8
        # ;print('out  ', out.size())

        out = self.center(out)
        out = self.up6(down6, out)
        out = self.up5(down5, out)
        out = self.up4(down4, out)
        out = self.up3(down3, out)
        out = self.up2(down2, out)
        out = self.up1(down1, out)
        # 1024

        out = self.classify(out)
        # out = torch.squeeze(out, dim=1)
        return out
# 512x512
class UNet512 (nn.Module):
    def __init__(self, in_shape=(3, 512, 512), **kwargs):
        super(UNet512, self).__init__()
        C, H, W = in_shape
        # assert(C==3)

        # 512
        self.down2 = StackEncoder(  C,   64, kernel_size=3)   # 256
        self.down3 = StackEncoder( 64,  128, kernel_size=3)   # 128
        self.down4 = StackEncoder(128,  256, kernel_size=3)   # 64
        self.down5 = StackEncoder(256,  512, kernel_size=3)   # 32
        self.down6 = StackEncoder(512, 1024, kernel_size=3)   # 16

        self.center = nn.Sequential(
            ConvBnRelu2d(1024, 1024, kernel_size=3, padding=1, stride=1),
            # ConvBnRelu2d(2048, 1024, kernel_size=3, padding=1, stride=1),
        )

        # 16
        # x_big_channels, x_channels, y_channels
        self.up6 = StackDecoder(1024, 1024, 512, kernel_size=3)   # 16
        self.up5 = StackDecoder( 512,  512, 256, kernel_size=3)   # 32
        self.up4 = StackDecoder( 256,  256, 128, kernel_size=3)   # 64
        self.up3 = StackDecoder( 128,  128,  64, kernel_size=3)   # 128
        self.up2 = StackDecoder(  64,   64,  32, kernel_size=3)   # 256
        self.classify = nn.Conv2d(32, 1, kernel_size=1, padding=0, stride=1, bias=True)

    def forward(self, x):
        out = x                        # ;print('x    ', x.size())

        down2, out = self.down2(out)   # ;print('down2', down2.size())
        down3, out = self.down3(out)   # ;print('down3', down3.size())
        down4, out = self.down4(out)   # ;print('down4', down4.size())
        down5, out = self.down5(out)   # ;print('down5', down5.size())
        down6, out = self.down6(out)   # ;print('down6', down6.size())
        # ;print('out  ', out.size())

        out = self.center(out)
        out = self.up6(down6, out)
        out = self.up5(down5, out)
        out = self.up4(down4, out)
        out = self.up3(down3, out)
        out = self.up2(down2, out)

        out = self.classify(out)
        # out = torch.squeeze(out, dim=1)
        return out
# 256x256
class UNet256 (nn.Module):
    def __init__(self, in_shape=(3, 256, 256), **kwargs):
        super(UNet256, self).__init__()
        C, H, W = in_shape
        # assert(C==3)

        # 256
        self.down2 = StackEncoder(  C,   64, kernel_size=3)   # 128
        self.down3 = StackEncoder( 64,  128, kernel_size=3)   # 64
        self.down4 = StackEncoder(128,  256, kernel_size=3)   # 32
        self.down5 = StackEncoder(256,  512, kernel_size=3)   # 16
        self.down6 = StackEncoder(512, 1024, kernel_size=3)   # 8

        self.center = nn.Sequential(
            # ConvBnRelu2d( 512, 1024, kernel_size=3, padding=1, stride=1),
            ConvBnRelu2d(1024, 1024, kernel_size=3, padding=1, stride=1),
        )

        # 8
        # x_big_channels, x_channels, y_channels
        self.up6 = StackDecoder(1024, 1024, 512, kernel_size=3)   # 16
        self.up5 = StackDecoder( 512,  512, 256, kernel_size=3)   # 32
        self.up4 = StackDecoder( 256,  256, 128, kernel_size=3)   # 64
        self.up3 = StackDecoder( 128,  128,  64, kernel_size=3)   # 128
        self.up2 = StackDecoder(  64,   64,  32, kernel_size=3)   # 256
        self.classify = nn.Conv2d(32, 1, kernel_size=1, padding=0, stride=1, bias=True)

    def forward(self, x):
        out = x                        # ;print('x    ', x.size())

        down2, out = self.down2(out)   # ;print('down2', down2.size())  # 128
        down3, out = self.down3(out)   # ;print('down3', down3.size())  # 64
        down4, out = self.down4(out)   # ;print('down4', down4.size())  # 32
        down5, out = self.down5(out)   # ;print('down5', down5.size())  # 16
        down6, out = self.down6(out)   # ;print('down6', down6.size())  # 8
        # ;print('out  ', out.size())

        out = self.center(out)
        out = self.up6(down6, out)
        out = self.up5(down5, out)
        out = self.up4(down4, out)
        out = self.up3(down3, out)
        out = self.up2(down2, out)

        out = self.classify(out)
        # out = torch.squeeze(out, dim=1)
        return out
# 128x128
class UNet128 (nn.Module):
    def __init__(self, in_shape=(3, 128, 128), **kwargs):
        super(UNet128, self).__init__()
        C, H, W = in_shape
        # assert(C==3)

        # 128
        self.down3 = StackEncoder(  C,  128, kernel_size=3)   # 64
        self.down4 = StackEncoder(128,  256, kernel_size=3)   # 32
        self.down5 = StackEncoder(256,  512, kernel_size=3)   # 16
        self.down6 = StackEncoder(512, 1024, kernel_size=3)   # 8

        self.center = nn.Sequential(
            ConvBnRelu2d(1024, 1024, kernel_size=3, padding=1, stride=1),
        )

        # 8
        # x_big_channels, x_channels, y_channels
        self.up6 = StackDecoder(1024, 1024, 512, kernel_size=3)   # 16
        self.up5 = StackDecoder( 512,  512, 256, kernel_size=3)   # 32
        self.up4 = StackDecoder( 256,  256, 128, kernel_size=3)   # 64
        self.up3 = StackDecoder( 128,  128,  64, kernel_size=3)   # 128
        self.classify = nn.Conv2d(64, 1, kernel_size=1, padding=0, stride=1, bias=True)

    def forward(self, x):
        out = x                        # ;print('x    ', x.size())

        down3, out = self.down3(out)   # ;print('down3', down3.size())  # 64
        down4, out = self.down4(out)   # ;print('down4', down4.size())  # 32
        down5, out = self.down5(out)   # ;print('down5', down5.size())  # 16
        down6, out = self.down6(out)   # ;print('down6', down6.size())  # 8
        # ;print('out  ', out.size())

        out = self.center(out)
        out = self.up6(down6, out)
        out = self.up5(down5, out)
        out = self.up4(down4, out)
        out = self.up3(down3, out)

        out = self.classify(out)
        # out = torch.squeeze(out, dim=1)
        return out
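

# Minimal usage sketch (not part of the original module, guarded so it only
# runs when this file is executed directly): UNet128 maps a 3x128x128 image to
# a single-channel logit map at the same spatial resolution.
if __name__ == '__main__':
    net = UNet128(in_shape=(3, 128, 128))
    net.eval()
    with torch.no_grad():
        dummy = torch.randn(1, 3, 128, 128)
        logits = net(dummy)
    print(logits.shape)   # expected: torch.Size([1, 1, 128, 128])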