# Source code for pywick.models.segmentation.unet_res

"""
Implementation of `U-net: Convolutional networks for biomedical image segmentation <https://arxiv.org/pdf/1505.04597>`_
"""

# Source: https://github.com/saeedizadi/binseg_pytoch (Apache-2.0)

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import numpy as np

# Public API: only UNetRes is exported by a star-import of this module.
# UNet and the building-block modules remain importable by explicit name.
__all__ = ['UNetRes']

def initialize_weights(method='kaiming', *models):
    """Initialize the conv / transposed-conv / linear layers of ``models`` in place.

    For every ``nn.Conv2d``, ``nn.ConvTranspose2d`` and ``nn.Linear`` found via
    ``model.modules()``, the weight is re-drawn according to ``method`` and the
    bias (when present) is set to zero. Other module types are left untouched.

    Args:
        method: One of ``'kaiming'``, ``'xavier'``, ``'orthogonal'`` or
            ``'normal'``. Any other string leaves the weights as they are and
            only zeroes the biases. For convenience an ``nn.Module`` may be
            passed here, in which case it is treated as the first model and
            ``'kaiming'`` is used.
        *models: The modules whose sub-modules should be initialized.
    """
    # ``initialize_weights(model)`` binds the model to ``method`` positionally and
    # leaves ``models`` empty, which used to make the call a silent no-op.
    # Accept that call form by moving the module into ``models``.
    if isinstance(method, nn.Module):
        models = (method,) + models
        method = 'kaiming'

    for model in models:
        for module in model.modules():
            if not isinstance(module, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                continue

            weight = module.weight.data
            if method == 'kaiming':
                # NOTE(review): the second positional argument of kaiming_normal_
                # is `a` (leaky-ReLU negative slope), not a gain. Kept as in the
                # original to preserve numerics -- confirm this is intended.
                init.kaiming_normal_(weight, np.sqrt(2.0))
            elif method == 'xavier':
                init.xavier_normal_(weight, np.sqrt(2.0))
            elif method == 'orthogonal':
                init.orthogonal_(weight, np.sqrt(2.0))
            elif method == 'normal':
                init.normal_(weight, mean=0, std=0.02)

            if module.bias is not None:
                init.constant_(module.bias.data, 0)


class UnetEncoder(nn.Module):
    """Encoder stage: two consecutive 3x3 conv -> batch-norm -> ReLU units.

    Both convolutions use ``padding=1`` with the default stride, so the
    spatial size of the input is preserved; only the channel count changes
    from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # The first unit maps in -> out channels, the second keeps out -> out.
        units = []
        for unit_in in (in_channels, out_channels):
            units.append(nn.Conv2d(unit_in, out_channels, kernel_size=3, padding=1))
            units.append(nn.BatchNorm2d(out_channels))
            units.append(nn.ReLU())
        self.layer = nn.Sequential(*units)

    def forward(self, x):
        """Apply the two conv units to ``x`` and return the result."""
        features = self.layer(x)
        return features


class UnetDecoder(nn.Module):
    """Decoder stage: two 3x3 conv units followed by a 2x transposed-conv upsample.

    Each of the three units is followed by batch-norm and ReLU. The two 3x3
    convolutions keep the spatial size (``padding=1``); the final
    ``ConvTranspose2d`` with ``kernel_size=2, stride=2`` doubles height and
    width and maps ``featrures`` channels to ``out_channels``.

    Args:
        in_channels: Channels of the incoming tensor.
        featrures: Channel width of the two intermediate conv units (the
            misspelled parameter name is kept for caller compatibility).
        out_channels: Channels produced by the upsampling unit.
    """

    def __init__(self, in_channels, featrures, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.features = featrures
        self.out_channels = out_channels

        width = self.features
        stack = [
            # refine at the incoming resolution
            nn.Conv2d(in_channels, width, kernel_size=3, padding=1),
            nn.BatchNorm2d(width),
            nn.ReLU(),
            nn.Conv2d(width, width, kernel_size=3, padding=1),
            nn.BatchNorm2d(width),
            nn.ReLU(),
            # learned 2x upsampling
            nn.ConvTranspose2d(width, out_channels, kernel_size=2, stride=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        self.layer = nn.Sequential(*stack)

    def forward(self, x):
        """Refine ``x`` and return it upsampled by a factor of two."""
        upsampled = self.layer(x)
        return upsampled


class UNet(nn.Module):
    """U-Net with batch-norm encoders/decoders and a sigmoid output.

    The network expects a 3-channel input (``down1`` is built with 3 input
    channels) and returns ``num_classes`` channels passed through a sigmoid.
    Four encoder stages, each followed by a 2x max-pool with
    ``ceil_mode=True``, feed a 1024-channel center block; four decoder stages
    then upsample back, each consuming the concatenation of the previous
    decoder output and the matching encoder output.

    Because the pools round up, an encoder output can differ in size from the
    decoder tensor it is merged with; the encoder output is therefore
    bilinearly resized to the decoder tensor's spatial size before
    concatenation. The output spatial size follows the decoder path, which
    equals the input size when height and width are multiples of 16.

    Args:
        num_classes: Number of output channels.
    """

    def __init__(self, num_classes):
        super(UNet, self).__init__()
        self.num_classes = num_classes
        self.down1 = UnetEncoder(3, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.down2 = UnetEncoder(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.down3 = UnetEncoder(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.down4 = UnetEncoder(256, 512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.center = nn.Sequential(nn.Conv2d(512, 1024, kernel_size=3, padding=1),
                                    nn.BatchNorm2d(1024),
                                    nn.ReLU(),
                                    nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
                                    nn.BatchNorm2d(1024),
                                    nn.ReLU(),
                                    nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2),
                                    nn.BatchNorm2d(512),
                                    nn.ReLU())

        # Decoder inputs are twice as wide as the previous output because the
        # skip tensor is concatenated along the channel axis.
        self.up1 = UnetDecoder(1024, 512, 256)
        self.up2 = UnetDecoder(512, 256, 128)
        self.up3 = UnetDecoder(256, 128, 64)

        # Last stage: two plain 3x3 convolutions, no normalization/activation
        # and no further upsampling.
        self.up4 = nn.Sequential(nn.Conv2d(128, 64, 3, padding=1),
                                 nn.Conv2d(64, 64, 3, padding=1))

        self.output = nn.Conv2d(64, self.num_classes, kernel_size=1, stride=1)
        self.final = nn.Sigmoid()

        # `method` is the first positional parameter of initialize_weights, so
        # it is passed explicitly to make sure the model lands in `*models`.
        initialize_weights('kaiming', self)

    @staticmethod
    def _merge(decoded, skip):
        """Resize ``skip`` to ``decoded``'s spatial size and concatenate on channels."""
        # align_corners=False is the default for bilinear mode; stating it
        # explicitly keeps the numerics and avoids the ambiguity warning.
        resized = F.interpolate(skip, decoded.size()[2:], mode="bilinear", align_corners=False)
        return torch.cat([decoded, resized], 1)

    def forward(self, x):
        """Run the encoder/decoder and return per-pixel sigmoid scores."""
        en1 = self.down1(x)
        po1 = self.pool1(en1)
        en2 = self.down2(po1)
        po2 = self.pool2(en2)
        en3 = self.down3(po2)
        po3 = self.pool3(en3)
        en4 = self.down4(po3)
        po4 = self.pool4(en4)

        c1 = self.center(po4)

        dec1 = self.up1(self._merge(c1, en4))
        dec2 = self.up2(self._merge(dec1, en3))
        dec3 = self.up3(self._merge(dec2, en2))
        dec4 = self.up4(self._merge(dec3, en1))

        out = self.output(dec4)
        return self.final(out)


# An improved version of the UNet model: all poolings are replaced with strided convolutions, skip connections pass through convolutions, and the conv blocks are residual.
class Conv2dX2_Res(nn.Module):
    """Residual block: conv -> batch-norm -> ReLU -> conv, plus the input, then ReLU.

    The input is broadcast to the shape of the convolution branch with
    ``expand_as`` before the addition, so ``in_channels`` has to equal
    ``out_channels`` (or be 1) and ``kernel_size``/``padding`` must keep the
    spatial size unchanged for the sum to be valid.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()

        conv_in = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        norm = nn.BatchNorm2d(out_channels)
        conv_out = nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.layer = nn.Sequential(conv_in, norm, nn.ReLU(inplace=True), conv_out)

    def forward(self, x):
        """Return ``relu(x + layer(x))`` with ``x`` expanded to the branch shape."""
        branch = self.layer(x)
        shortcut = x.expand_as(branch)
        return F.relu(shortcut + branch)


class PassConv(nn.Module):
    """A single convolution followed by an in-place ReLU.

    With the default arguments this is a 1x1, stride-1, unpadded convolution,
    i.e. a per-pixel channel projection that keeps the spatial size.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
        super().__init__()

        projection = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.layer = nn.Sequential(projection, nn.ReLU(inplace=True))

    def forward(self, x):
        """Return the rectified convolution of ``x``."""
        projected = self.layer(x)
        return projected


class DeconvX2_Res(nn.Module):
    """Residual conv block followed by a transposed-conv upsample and ReLU.

    The residual block keeps ``in_channels`` channels (3x3, ``padding=1``);
    the ``ConvTranspose2d`` then maps to ``out_channels``. With the default
    ``kernel_size=2, stride=2`` it doubles height and width.
    """

    def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
        super().__init__()

        self.convx2_res = Conv2dX2_Res(in_channels, in_channels, kernel_size=3, padding=1)

        deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride)
        self.upsample = nn.Sequential(deconv, nn.ReLU(inplace=True))

    def forward(self, x):
        """Refine ``x`` with the residual block, then upsample it."""
        refined = self.convx2_res(x)
        return self.upsample(refined)


class UNetRes(nn.Module):
    """Residual U-Net variant with strided-conv downsampling and additive skips.

    Compared with a plain U-Net, pooling is replaced by 2x2 stride-2
    convolutions, the encoder/decoder stages are residual conv blocks, and
    encoder features are *added* to (not concatenated with) decoder features.
    The network expects a 3-channel input and returns ``num_classes``
    channels passed through a sigmoid.

    The additive skips need matching shapes on both sides; with four
    unpadded stride-2 downsamplings and four 2x upsamplings this holds when
    the input height and width are multiples of 16.

    Args:
        num_classes: Number of output channels.
        **kwargs: Accepted for interface compatibility and ignored.
    """

    def __init__(self, num_classes, **kwargs):
        super(UNetRes, self).__init__()

        # Shape comments throughout assume a 240x320x3 input.
        self.enc1 = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1),
                                  nn.ReLU(inplace=True),
                                  nn.Conv2d(64, 64, 3, padding=1),
                                  nn.ReLU(inplace=True))
        self.pool1 = nn.Conv2d(64, 128, kernel_size=2, stride=2)  # Conv as Pool

        self.enc2 = Conv2dX2_Res(128, 128, 3, padding=1)
        self.pool2 = nn.Conv2d(128, 256, kernel_size=2, stride=2)  # Conv as Pool

        self.enc3 = Conv2dX2_Res(256, 256, 3, padding=1)
        self.pool3 = nn.Conv2d(256, 512, kernel_size=2, stride=2)  # Conv as Pool

        self.enc4 = Conv2dX2_Res(512, 512, 3, padding=1)
        self.pool4 = nn.Conv2d(512, 1024, kernel_size=2, stride=2)  # Conv as Pool

        self.middle = nn.Sequential(Conv2dX2_Res(1024, 1024, 3, padding=1),
                                    nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2),
                                    nn.ReLU(inplace=True))

        # The pass_enc* projections are not used in forward(), which adds the
        # raw encoder outputs directly. They are still constructed so the
        # module's parameters / state_dict keys stay unchanged.
        self.pass_enc4 = PassConv(512, 512)
        self.pass_enc3 = PassConv(256, 256)
        self.pass_enc2 = PassConv(128, 128)
        self.pass_enc1 = PassConv(64, 64)

        self.dec1 = DeconvX2_Res(512, 256, 2, stride=2)
        self.dec2 = DeconvX2_Res(256, 128, 2, stride=2)
        self.dec3 = DeconvX2_Res(128, 64, 2, stride=2)
        self.dec4 = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1),
                                  nn.Conv2d(64, num_classes, kernel_size=1, stride=1))

        self.activation = nn.Sigmoid()

        # `method` is the first positional parameter of initialize_weights, so
        # it is passed explicitly to make sure the model lands in `*models`.
        initialize_weights('kaiming', self)

    def forward(self, x):
        """Run the network on ``x`` and return per-pixel sigmoid scores."""
        en1 = self.enc1(x)                       # 240x320x64
        en2 = self.enc2(self.pool1(en1))         # 120x160x128
        en3 = self.enc3(self.pool2(en2))         # 60x80x256
        en4 = self.enc4(self.pool3(en3))         # 30x40x512

        middle = self.middle(self.pool4(en4))    # 30x40x512

        # Additive skip connections: encoder output + decoder input.
        dec1 = self.dec1(en4 + middle)           # 60x80x256
        dec2 = self.dec2(en3 + dec1)             # 120x160x128
        dec3 = self.dec3(en2 + dec2)             # 240x320x64
        dec4 = self.dec4(en1 + dec3)             # 240x320x num_classes

        return self.activation(dec4)