# Source code for pywick.models.segmentation.unet_stack

# Source: https://github.com/doodledood/carvana-image-masking-challenge/models (MIT)

"""
Implementation of stacked `U-net: Convolutional networks for biomedical image segmentation <https://arxiv.org/pdf/1505.04597>`_
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

# Names exported by ``from ... import *``: only the two complete networks.
# The ConvBNReluStack / UNetDownStack / UNetUpStack building blocks are not listed.
__all__ = ['UNet960', 'UNet_stack']

class ConvBNReluStack(nn.Module):
    """Conv2d -> BatchNorm2d -> PReLU building block.

    Despite the "Relu" in the name, the activation is a learnable ``nn.PReLU``.

    :param in_dim: number of input channels (coerced with ``int``).
    :param out_dim: number of output channels (coerced with ``int``).
    :param kernel_size: convolution kernel size.
    :param stride: convolution stride.
    :param padding: zero padding applied by the convolution.
    :param kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, in_dim, out_dim, kernel_size=3, stride=1, padding=1, **kwargs):
        super().__init__()
        n_in, n_out = int(in_dim), int(out_dim)

        # The attribute names below are state_dict keys; creation order also fixes
        # the order in which parameters are randomly initialised.
        self.conv = nn.Conv2d(n_in, n_out, kernel_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(n_out)
        self.activation = nn.PReLU()

    def forward(self, inputs_):
        """Apply convolution, batch normalisation and PReLU, in that order."""
        return self.activation(self.bn(self.conv(inputs_)))


class UNetDownStack(nn.Module):
    """Multi-branch encoder block with an optional 2x2 max-pool.

    Four parallel branches are applied to the input, all preserving its spatial
    size:

    - 1x1, 3x3 and 5x5 ``ConvBNReluStack`` convolutions, each with ``filters``
      output channels;
    - a 3x3 stride-1 average pool, which keeps ``input_dim`` channels.

    Their outputs are concatenated along dim 1, the channel axis of the NCHW
    tensors that Conv2d/BatchNorm2d operate on. That gives
    ``3 * filters + input_dim`` channels, which a 1x1 ``ConvBNReluStack`` reduces
    back to ``filters`` channels.

    :param input_dim: number of input channels.
    :param filters: number of output channels.
    :param pool: if True, ``forward`` also returns a 2x2/stride-2 max-pooled
        copy of the block output.
    """

    def __init__(self, input_dim, filters, pool=True):
        super(UNetDownStack, self).__init__()

        self.stack1 = ConvBNReluStack(input_dim, filters, 1, stride=1, padding=0)
        self.stack3 = ConvBNReluStack(input_dim, filters, 3, stride=1, padding=1)
        self.stack5 = ConvBNReluStack(input_dim, filters, 5, stride=1, padding=2)
        self.stack_pool = nn.AvgPool2d(3, stride=1, padding=1)
        self.reducer = ConvBNReluStack(filters * 3 + input_dim, filters, kernel_size=1, stride=1, padding=0)

        self.pool = nn.MaxPool2d(2, stride=2) if pool else None

    def forward(self, inputs_):
        """Run the four branches, merge them and optionally downsample.

        :param inputs_: tensor with ``input_dim`` channels.
        :return: ``(features, pooled_features)`` when built with ``pool=True``,
            otherwise just ``features`` (``filters`` channels, same spatial size
            as ``inputs_``).
        """
        x1 = self.stack1(inputs_)
        x3 = self.stack3(inputs_)
        x5 = self.stack5(inputs_)
        x_pool = self.stack_pool(inputs_)

        x = torch.cat([x1, x3, x5, x_pool], dim=1)
        x = self.reducer(x)

        # Explicit None check: the truthiness of an nn.Module is not a reliable
        # "is set" test, because container modules such as nn.Sequential define
        # __len__ and are falsy when empty.
        if self.pool is not None:
            return x, self.pool(x)

        return x


class UNetUpStack(nn.Module):
    """Multi-branch decoder block.

    ``forward`` upsamples ``inputs_`` to the skip tensor's resolution, then
    concatenates it with the skip tensor ``down`` along dim 1. The result goes
    through four size-preserving branches: 1x1, 3x3 and 5x5
    ``ConvBNReluStack`` convolutions plus a 3x3 stride-1 average pool. The
    concatenated branch outputs are reduced to ``filters`` channels by a 1x1
    ``ConvBNReluStack``.

    :param input_dim: channels of ``inputs_`` plus channels of ``down``.
    :param filters: number of output channels.
    :param kernel_size: accepted for interface compatibility but NOT used; the
        branch kernel sizes are fixed at 1/3/5.
    """

    def __init__(self, input_dim, filters, kernel_size=3):
        super(UNetUpStack, self).__init__()

        # Nominal upsampling factor (see forward for the odd-size fallback).
        self.scale_factor = 2
        self.stack1 = ConvBNReluStack(input_dim, filters, 1, stride=1, padding=0)
        self.stack3 = ConvBNReluStack(input_dim, filters, 3, stride=1, padding=1)
        self.stack5 = ConvBNReluStack(input_dim, filters, 5, stride=1, padding=2)
        self.stack_pool = nn.AvgPool2d(3, stride=1, padding=1)
        self.reducer = ConvBNReluStack(filters * 3 + input_dim, filters, kernel_size=1, stride=1, padding=0)

    def forward(self, inputs_, down):
        """Upsample ``inputs_``, fuse it with the skip tensor ``down`` and reduce.

        :param inputs_: deeper (lower resolution) feature map.
        :param down: skip-connection feature map from the matching encoder stage.
        :return: tensor with ``filters`` channels at ``down``'s spatial size.
        """
        target_size = down.shape[-2:]
        x = F.interpolate(inputs_, scale_factor=self.scale_factor)
        if x.shape[-2:] != target_size:
            # The encoder's 2x2 max-pool floors odd sizes (e.g. 15 -> 7), so
            # upsampling by 2 can miss the skip size by one pixel, which would
            # make torch.cat fail. Resize straight to the skip's size instead.
            x = F.interpolate(inputs_, size=target_size)
        x = torch.cat([x, down], dim=1)

        x1 = self.stack1(x)
        x3 = self.stack3(x)
        x5 = self.stack5(x)
        x_pool = self.stack_pool(x)

        x = torch.cat([x1, x3, x5, x_pool], dim=1)
        x = self.reducer(x)

        return x


class UNet_stack(nn.Module):
    """Stacked U-Net whose depth is derived from the expected input size.

    The number of encoder/decoder stages is the number of times ``input_size``
    can be halved exactly, capped at ``max_stacks``.

    - ``down1`` maps a 3-channel input to ``filters`` channels.
    - Each further down stage doubles the channels.
    - A non-pooling centre block doubles them once more.
    - Up stages mirror the encoder using the stored skip connections.
    - A final Conv2d produces a single-channel map; no activation is applied.

    NOTE: if the computed depth is 0 (odd ``input_size`` or ``max_stacks=0``),
    no up stages are built. Because ``down1`` always pools, the output is then
    at half the input resolution.

    :param input_size: expected (square) input side length.
    :param filters: channel count of the first stage.
    :param kernel_size: kernel size of the final classification conv (also
        forwarded to every ``UNetUpStack``).
    :param max_stacks: upper bound on the number of down/up stages.
    :param _: extra keyword arguments are accepted and ignored.
    """

    @staticmethod
    def get_n_stacks(input_size, **_):
        """Return how many times ``input_size`` can be halved exactly.

        :param input_size: a single int (square input) or a ``(width, height)``
            pair.
        :return: the count of exact halvings of both dimensions.
        :raises ValueError: if a dimension is not positive; zero would otherwise
            loop forever since ``0 % 2 == 0``.
        """
        # Accept both forms: __init__ passes a pair, while external callers may
        # pass a plain int. Previously only the int form worked, so __init__
        # always crashed.
        if isinstance(input_size, (tuple, list)):
            width, height = input_size
        else:
            width = height = input_size
        if width <= 0 or height <= 0:
            raise ValueError('input_size dimensions must be positive, got {!r}'.format(input_size))

        n_stacks = 0
        while width % 2 == 0 and height % 2 == 0:
            n_stacks += 1
            width //= 2
            height //= 2
        return n_stacks

    def __init__(self, input_size=512, filters=12, kernel_size=3, max_stacks=6, **_):
        super(UNet_stack, self).__init__()

        self.n_stacks = min(self.get_n_stacks((input_size, input_size)), max_stacks)

        # Encoder: dynamically create down1..down<n_stacks>, doubling channels
        # after the first stage.
        channels = filters
        self.down1 = UNetDownStack(3, channels)
        for n in range(2, self.n_stacks + 1):
            setattr(self, 'down' + str(n), UNetDownStack(channels, channels * 2))
            channels *= 2

        self.center = UNetDownStack(channels, channels * 2, pool=False)
        out_channels = channels * 2  # channels currently produced (centre output)

        # Decoder: each up stage receives the upsampled deeper features (2c) plus
        # the skip connection (c), i.e. 3c channels, and emits c channels.
        up_in = channels * 3
        for n in range(self.n_stacks, 0, -1):
            out_channels = up_in // 3
            setattr(self, 'up' + str(n), UNetUpStack(up_in, out_channels, kernel_size))
            up_in //= 2

        # Use the channel count actually produced by the last stage. The old
        # formula `prev * 2 // 3` was wrong for odd `filters` (e.g. 5 -> 4).
        # Padding kernel_size // 2 keeps the spatial size for odd kernels; it
        # equals the former hard-coded 1 for the default kernel_size=3.
        self.classify = nn.Conv2d(out_channels, 1, kernel_size, stride=1, padding=kernel_size // 2)

    def forward(self, inputs_):
        """Run the encoder, centre, decoder and final classification conv."""
        skip, pooled = self.down1(inputs_)
        skips = [skip]

        # Encoder: keep each stage's un-pooled output for the skip connections.
        for n in range(2, self.n_stacks + 1):
            skip, pooled = getattr(self, 'down' + str(n))(pooled)
            skips.append(skip)

        x = self.center(pooled)

        # Decoder: consume skip connections deepest-first.
        for n in range(self.n_stacks, 0, -1):
            x = getattr(self, 'up' + str(n))(x, skips.pop())

        return self.classify(x)
class UNet960(nn.Module):
    """Fixed-depth stacked U-Net with six down/up stages.

    Shape summary:

    - Input: 3 channels.
    - Stage ``k`` (1..6) uses ``filters * 2**(k-1)`` channels.
    - The centre block uses ``filters * 64`` channels.
    - Output: a single-channel map from a final Conv2d; no activation is applied.

    The spatial sizes in the comments below assume a 960x960 input, the size the
    class is named after. Inputs whose height and width are divisible by
    2**6 = 64 keep every skip connection aligned with its upsampled counterpart.

    :param filters: channel count of the first stage.
    :param kernel_size: kernel size of the final classification conv (also
        forwarded to every ``UNetUpStack``).
    :param _: extra keyword arguments are accepted and ignored.
    """

    def __init__(self, filters=12, kernel_size=3, **_):
        super(UNet960, self).__init__()

        # Encoder (comment = spatial size of the stage input)
        # 960
        self.down1 = UNetDownStack(3, filters)
        # 480
        self.down2 = UNetDownStack(filters, filters * 2)
        # 240
        self.down3 = UNetDownStack(filters * 2, filters * 4)
        # 120
        self.down4 = UNetDownStack(filters * 4, filters * 8)
        # 60
        self.down5 = UNetDownStack(filters * 8, filters * 16)
        # 30
        self.down6 = UNetDownStack(filters * 16, filters * 32)
        # 15
        self.center = UNetDownStack(filters * 32, filters * 64, pool=False)

        # Decoder: input channels = upsampled deeper features + skip connection,
        # e.g. up6 gets 64f (centre) + 32f (down6) = 96f.
        # 15
        self.up6 = UNetUpStack(filters * 96, filters * 32, kernel_size)
        # 30
        self.up5 = UNetUpStack(filters * 48, filters * 16, kernel_size)
        # 60
        self.up4 = UNetUpStack(filters * 24, filters * 8, kernel_size)
        # 120
        self.up3 = UNetUpStack(filters * 12, filters * 4, kernel_size)
        # 240
        self.up2 = UNetUpStack(filters * 6, filters * 2, kernel_size)
        # 480
        self.up1 = UNetUpStack(filters * 3, filters, kernel_size)

        # 960
        # Padding kernel_size // 2 keeps the output at the input's spatial size
        # for odd kernels. It equals the former hard-coded 1 for the default
        # kernel_size=3; with the old value, other kernels shrank or grew the output.
        self.classify = nn.Conv2d(filters, 1, kernel_size, stride=1, padding=kernel_size // 2)

    def forward(self, inputs_):
        """Run the six-stage encoder, the centre block, the decoder and the classifier."""
        down1, down1_pool = self.down1(inputs_)
        down2, down2_pool = self.down2(down1_pool)
        down3, down3_pool = self.down3(down2_pool)
        down4, down4_pool = self.down4(down3_pool)
        down5, down5_pool = self.down5(down4_pool)
        down6, down6_pool = self.down6(down5_pool)

        center = self.center(down6_pool)

        up6 = self.up6(center, down6)
        up5 = self.up5(up6, down5)
        up4 = self.up4(up5, down4)
        up3 = self.up3(up4, down3)
        up2 = self.up2(up3, down2)
        up1 = self.up1(up2, down1)

        x = self.classify(up1)
        return x