# pywick.models.segmentation.unet_dilated

# Source: https://github.com/Hsuxu/carvana-pytorch-uNet/blob/master/model.py

"""
Implementation of `U-Net: Convolutional networks for biomedical image segmentation <https://arxiv.org/pdf/1505.04597>`_ with dilated convolutions.
"""

import torch
import torch.nn as nn

__all__ = ['UNetDilated']


class Conv_transition(nn.Module):
    '''
    Inception-style transition block: three parallel convolutions with
    different kernel sizes whose outputs are concatenated and fused by a
    final 3x3 convolution.
    '''

    def __init__(self, kernel_size, in_channels, out_channels):
        super(Conv_transition, self).__init__()
        if not kernel_size:
            kernel_size = [1, 3, 5]
        paddings = [k // 2 for k in kernel_size]  # 'same' padding for odd kernel sizes
        # self.Conv0=nn.Conv2d(in_channels,out_channels,3,stride=1,padding=1)
        self.Conv1 = nn.Conv2d(in_channels, out_channels, kernel_size[0], stride=1, padding=paddings[0])
        self.Conv2 = nn.Conv2d(in_channels, out_channels, kernel_size[1], stride=1, padding=paddings[1])
        self.Conv3 = nn.Conv2d(in_channels, out_channels, kernel_size[2], stride=1, padding=paddings[2])
        self.Conv_f = nn.Conv2d(3 * out_channels, out_channels, 3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.PReLU()

    def forward(self, x):
        # x = self.Conv0(x)
        x1 = self.act(self.Conv1(x))
        x2 = self.act(self.Conv2(x))
        x3 = self.act(self.Conv3(x))
        x = torch.cat([x1, x2, x3], dim=1)
        return self.act(self.bn(self.Conv_f(x)))

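# A minimal shape sketch for Conv_transition (illustrative, not part of the
# original source; assumes an NCHW float tensor):
#   >>> block = Conv_transition([1, 3, 5], 3, 8)
#   >>> block(torch.randn(1, 3, 64, 64)).shape   # each branch pads k // 2, so H and W are kept
#   torch.Size([1, 8, 64, 64])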

class Dense_layer(nn.Module):
    """
    an two-layer
    """

    def __init__(self, in_channels, growth_rate):
        super(Dense_layer, self).__init__()
        # self.bn0=nn.BatchNorm2d(in_channels)
        self.Conv0 = nn.Conv2d(in_channels, in_channels + growth_rate, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(in_channels + growth_rate)
        self.Conv1 = nn.Conv2d(in_channels + growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)

        self.bn2 = nn.BatchNorm2d(in_channels + growth_rate)
        self.Conv2 = nn.Conv2d(in_channels + growth_rate, in_channels, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(in_channels)

        # self.Conv1=nn.Conv2d(in_channels+growth_rate,growth_rate,kernel_size=3,stride=1,padding=1,bias=False)

        self.act = nn.PReLU()

    def forward(self, x):
        x1 = self.act(self.bn1(self.Conv0(x)))
        x1 = self.act(self.bn2(torch.cat([self.Conv1(x1), x], dim=1)))

        return self.act(self.bn3(self.Conv2(x1)))

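# A minimal shape sketch for Dense_layer (illustrative, not part of the
# original source): channel count and spatial size are both preserved.
#   >>> layer = Dense_layer(16, growth_rate=8)
#   >>> layer(torch.randn(1, 16, 32, 32)).shape
#   torch.Size([1, 16, 32, 32])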

class Fire_Down(nn.Module):
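    """
    Down-sampling block: three parallel convolutions with dilations 1/3/5
    (padding equal to the dilation keeps H and W for 3x3 kernels), fused by a
    stride-2 convolution that halves the spatial resolution.
    """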
    def __init__(self, kernel_size, in_channels, inner_channels, out_channels):
        super(Fire_Down, self).__init__()
        dilations = [1, 3, 5]
        self.Conv1 = nn.Conv2d(in_channels, inner_channels, kernel_size=kernel_size, stride=1, padding=dilations[0],
                               dilation=dilations[0])
        self.Conv4 = nn.Conv2d(in_channels, inner_channels, kernel_size=kernel_size, stride=1, padding=dilations[1],
                               dilation=dilations[1])
        self.Conv8 = nn.Conv2d(in_channels, inner_channels, kernel_size=kernel_size, stride=1, padding=dilations[2],
                               dilation=dilations[2])
        self.Conv_f3 = nn.Conv2d(3 * inner_channels, out_channels, kernel_size=kernel_size, stride=2, padding=1)
        self.Conv_f1 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.act = nn.PReLU()

    def forward(self, x):
        x1 = self.act(self.Conv1(x))
        x2 = self.act(self.Conv4(x))
        x3 = self.act(self.Conv8(x))
        x = torch.cat([x1, x2, x3], dim=1)
        x = self.act(self.Conv_f3(x))
        return self.act(self.bn1(self.Conv_f1(x)))

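# A minimal shape sketch for Fire_Down (illustrative, not part of the
# original source): the stride-2 fusion convolution halves H and W.
#   >>> down = Fire_Down(3, 8, 16, 16)
#   >>> down(torch.randn(1, 8, 64, 64)).shape
#   torch.Size([1, 16, 32, 32])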

class Fire_Up(nn.Module):
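    """
    Up-sampling block: a 3x3 convolution, then a stride-2 transposed
    convolution (``output_padding`` fine-tunes the output size) that doubles
    the spatial resolution, followed by a 3x3 convolution with batch norm.
    """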
    def __init__(self, kernel_size, in_channels, inner_channels, out_channels, out_padding=(1, 1)):
        super(Fire_Up, self).__init__()
        padds = kernel_size // 2
        self.Conv1 = nn.Conv2d(in_channels, inner_channels, kernel_size=3, stride=1, padding=1)
        if not out_padding:
            out_padding = (1, 1)
        # self.ConvT1=nn.ConvTranspose2d(inner_channels,out_channels,kernel_size=1,stride=2,padding=0,output_padding=out_padding)
        self.ConvT4 = nn.ConvTranspose2d(inner_channels, out_channels, kernel_size=kernel_size, stride=2, padding=padds,
                                         output_padding=out_padding)
        # self.ConvT8=nn.ConvTranspose2d(inner_channels,out_channels,kernel_size=5,stride=2,padding=2,output_padding=out_padding)
        self.Conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, stride=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.act = nn.PReLU()

    def forward(self, x):
        x = self.act(self.Conv1(x))
        # x1=self.act(self.ConvT1(x))
        x = self.act(self.ConvT4(x))
        # x8=self.act(self.ConvT8(x))
        # x=torch.cat([x1,x4],dim=1)
        x = self.act(self.bn1(self.Conv2(x)))
        return x

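# A minimal shape sketch for Fire_Up (illustrative, not part of the original
# source): the stride-2 transposed convolution doubles H and W.
#   >>> up = Fire_Up(3, 128, 96, 96)
#   >>> up(torch.randn(1, 128, 32, 32)).shape
#   torch.Size([1, 96, 64, 64])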

class UNetDilated(nn.Module):
    """
    U-Net variant utilizing dilated convolutions.
    """

    def __init__(self, num_classes, **_):
        super(UNetDilated, self).__init__()
        self.Conv0 = self._transition(3, 8)  # 1918
        self.down1 = self._down_block(8, 16, 16)  # 959
        self.down2 = self._down_block(16, 16, 32)  # 480
        self.down3 = self._down_block(32, 32, 64)  # 240
        self.down4 = self._down_block(64, 64, 96)  # 120
        self.down5 = self._down_block(96, 96, 128)  # 60
        self.tran0 = self._transition(128, 256)
        self.db0 = self._dense_block(256, 32)

        self.up1 = self._up_block(256, 96, 96)  # 120
        self.db1 = self._dense_block(96, 32)
        self.conv1 = nn.Conv2d(96 * 2, 96, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(96)

        self.up2 = self._up_block(96, 64, 64)  # 240
        self.db2 = self._dense_block(64, 24)
        self.conv2 = nn.Conv2d(64 * 2, 64, 3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        self.up3 = self._up_block(64, 32, 32)  # 480
        self.db3 = self._dense_block(32, 10)
        self.conv3 = nn.Conv2d(32 * 2, 32, 3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(32)

        self.up4 = self._up_block(32, 16, 16)  # ,output_padding=(1, 0))  # 959
        self.db4 = self._dense_block(16, 8)
        self.conv4 = nn.Conv2d(16 * 2, 16, 3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(16)

        self.up5 = self._up_block(16, 16, 16)  # 1918
        self.db5 = self._dense_block(16, 4)
        self.conv5 = nn.Conv2d(16, num_classes, 3, stride=1, padding=1)
        self.clss = nn.LogSoftmax(dim=1)
        self.act = nn.PReLU()

    def forward(self, x):
        x1 = self.Conv0(x)
        down1 = self.down1(x1)
        down2 = self.down2(down1)
        down3 = self.down3(down2)
        down4 = self.down4(down3)
        down5 = self.down5(down4)
        down5 = self.tran0(down5)
        down5 = self.db0(down5)

        # TODO Problem here:
        # self.up1(down5).data.shape  =>  torch.Size([2, 96, 64, 44])
        #     -- MISMATCH WITH --
        # down4.data.shape            =>  torch.Size([2, 96, 64, 43])
        up1 = self.act(self.bn1(self.conv1(torch.cat([self.db1(self.up1(down5)), down4], dim=1))))
        up2 = self.act(self.bn2(self.conv2(torch.cat([self.db2(self.up2(up1)), down3], dim=1))))
        up3 = self.act(self.bn3(self.conv3(torch.cat([self.db3(self.up3(up2)), down2], dim=1))))
        up4 = self.act(self.bn4(self.conv4(torch.cat([self.db4(self.up4(up3)), down1], dim=1))))
        up5 = self.up5(up4)
        # up5 = self.conv5(up5)
        # return self.clss(self.conv5(up5))
        return self.conv5(up5)

    @staticmethod
    def _transition(in_channels, out_channels):
        return nn.Sequential(Conv_transition([1, 3, 5], in_channels, out_channels))

    @staticmethod
    def _down_block(in_channels, inner_channels, out_channels):
        return nn.Sequential(Fire_Down(3, in_channels, inner_channels, out_channels))

    @staticmethod
    def _up_block(in_channels, inner_channels, out_channels, output_padding=(1, 1)):
        return nn.Sequential(Fire_Up(3, in_channels, inner_channels, out_channels, output_padding))

    @staticmethod
    def _dense_block(in_channels, growth_rate):
        return nn.Sequential(Dense_layer(in_channels, growth_rate))
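
# A minimal smoke test (illustrative, not part of the original source). The
# five stride-2 down blocks and five stride-2 up blocks round-trip cleanly
# when H and W are divisible by 2**5 = 32; other sizes can hit the
# skip-connection shape mismatch noted in the TODO inside forward().
if __name__ == '__main__':
    model = UNetDilated(num_classes=2)
    dummy = torch.randn(1, 3, 64, 64)  # NCHW input, 64 divisible by 32
    out = model(dummy)
    print(out.shape)  # expected: torch.Size([1, 2, 64, 64])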