# Source code for pywick.models.segmentation.frrn1

# Source: https://github.com/meetshah1995/pytorch-semseg (MIT)

"""
Implementation of `Full Resolution Residual Networks for Semantic Segmentation in Street Scenes <https://arxiv.org/abs/1611.08323>`_
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['frrn']

# Architecture specs for the two FRRN variants from the paper.
# Each stage entry is [n_blocks, channels, scale], where `scale` is the
# downsampling factor of the pooling stream relative to full resolution.
# "B" adds one deeper encoder/decoder stage (scale 32) compared to "A".
frrn_specs_dic = {
    "A": {
        "encoder": [[3, 96, 2], [4, 192, 4], [2, 384, 8], [2, 384, 16]],
        "decoder": [[2, 192, 8], [2, 192, 4], [2, 96, 2]],
    },
    "B": {
        "encoder": [[3, 96, 2], [4, 192, 4], [2, 384, 8], [2, 384, 16], [2, 384, 32]],
        "decoder": [[2, 192, 16], [2, 192, 8], [2, 192, 4], [2, 96, 2]],
    },
}

class conv2DBatchNorm(nn.Module):
    """Conv2d followed by an optional BatchNorm2d (no activation).

    Args:
        in_channels: number of input channels.
        n_filters: number of output channels.
        k_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
        bias: whether the convolution has a bias term.
        dilation: convolution dilation; values <= 1 are treated as 1.
        with_bn: if False, the BatchNorm2d layer is omitted.
    """

    def __init__(
        self,
        in_channels,
        n_filters,
        k_size,
        stride,
        padding,
        bias=True,
        dilation=1,
        with_bn=True,
    ):
        super(conv2DBatchNorm, self).__init__()

        # The original if/else duplicated the whole nn.Conv2d construction just
        # to switch the dilation kwarg; a single call with a clamped dilation
        # is equivalent (dilation <= 1 was coerced to 1).
        conv_mod = nn.Conv2d(
            int(in_channels),
            int(n_filters),
            kernel_size=k_size,
            padding=padding,
            stride=stride,
            bias=bias,
            dilation=dilation if dilation > 1 else 1,
        )

        if with_bn:
            self.cb_unit = nn.Sequential(conv_mod, nn.BatchNorm2d(int(n_filters)))
        else:
            self.cb_unit = nn.Sequential(conv_mod)

    def forward(self, inputs):
        outputs = self.cb_unit(inputs)
        return outputs


class deconv2DBatchNorm(nn.Module):
    """ConvTranspose2d followed by BatchNorm2d (no activation)."""

    def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True):
        super(deconv2DBatchNorm, self).__init__()

        deconv = nn.ConvTranspose2d(
            int(in_channels),
            int(n_filters),
            kernel_size=k_size,
            padding=padding,
            stride=stride,
            bias=bias,
        )
        norm = nn.BatchNorm2d(int(n_filters))
        self.dcb_unit = nn.Sequential(deconv, norm)

    def forward(self, inputs):
        return self.dcb_unit(inputs)


class conv2DBatchNormRelu(nn.Module):
    """Conv2d -> optional BatchNorm2d -> ReLU.

    Args:
        in_channels: number of input channels.
        n_filters: number of output channels.
        k_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
        bias: whether the convolution has a bias term.
        dilation: convolution dilation; values <= 1 are treated as 1.
        with_bn: if False, the BatchNorm2d layer is omitted.
    """

    def __init__(
        self,
        in_channels,
        n_filters,
        k_size,
        stride,
        padding,
        bias=True,
        dilation=1,
        with_bn=True,
    ):
        super(conv2DBatchNormRelu, self).__init__()

        # The original if/else duplicated the whole nn.Conv2d construction just
        # to switch the dilation kwarg; a single call with a clamped dilation
        # is equivalent (dilation <= 1 was coerced to 1).
        conv_mod = nn.Conv2d(
            int(in_channels),
            int(n_filters),
            kernel_size=k_size,
            padding=padding,
            stride=stride,
            bias=bias,
            dilation=dilation if dilation > 1 else 1,
        )

        if with_bn:
            self.cbr_unit = nn.Sequential(
                conv_mod, nn.BatchNorm2d(int(n_filters)), nn.ReLU(inplace=True)
            )
        else:
            self.cbr_unit = nn.Sequential(conv_mod, nn.ReLU(inplace=True))

    def forward(self, inputs):
        outputs = self.cbr_unit(inputs)
        return outputs


class deconv2DBatchNormRelu(nn.Module):
    """ConvTranspose2d -> BatchNorm2d -> ReLU."""

    def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True):
        super(deconv2DBatchNormRelu, self).__init__()

        deconv = nn.ConvTranspose2d(
            int(in_channels),
            int(n_filters),
            kernel_size=k_size,
            padding=padding,
            stride=stride,
            bias=bias,
        )
        self.dcbr_unit = nn.Sequential(
            deconv,
            nn.BatchNorm2d(int(n_filters)),
            nn.ReLU(inplace=True),
        )

    def forward(self, inputs):
        return self.dcbr_unit(inputs)


class RU(nn.Module):
    """
    Residual Unit for FRRN: two same-channel conv blocks with an identity skip.
    """

    def __init__(self, channels, kernel_size=3, strides=1):
        super(RU, self).__init__()

        # conv-bn-relu followed by conv-bn; the ReLU is omitted on the second
        # block so the residual sum is added to a pre-activation output.
        self.conv1 = conv2DBatchNormRelu(
            channels, channels, k_size=kernel_size, stride=strides, padding=1
        )
        self.conv2 = conv2DBatchNorm(
            channels, channels, k_size=kernel_size, stride=strides, padding=1
        )

    def forward(self, x):
        residual = self.conv2(self.conv1(x))
        return residual + x


class FRRU(nn.Module):
    """
    Full Resolution Residual Unit for FRRN.

    Operates on two streams: ``y`` (pooling stream at 1/scale resolution) and
    ``z`` (full-resolution residual stream, fixed at 32 channels).

    Args:
        prev_channels: channels of the incoming pooling-stream tensor ``y``.
        out_channels: channels of the outgoing pooling-stream tensor.
        scale: downsampling factor of ``y`` relative to ``z``.
    """

    def __init__(self, prev_channels, out_channels, scale):
        super(FRRU, self).__init__()
        self.scale = scale
        self.prev_channels = prev_channels
        self.out_channels = out_channels

        # +32: the residual stream z contributes a fixed 32 channels after pooling.
        self.conv1 = conv2DBatchNormRelu(
            prev_channels + 32, out_channels, k_size=3, stride=1, padding=1
        )
        self.conv2 = conv2DBatchNormRelu(
            out_channels, out_channels, k_size=3, stride=1, padding=1
        )
        # 1x1 projection back to the residual stream's 32 channels.
        self.conv_res = nn.Conv2d(out_channels, 32, kernel_size=1, stride=1, padding=0)

    def forward(self, y, z):
        """Return ``(y_prime, z_prime)``: updated pooling and residual streams."""
        # Pool z down to y's resolution and process the concatenated streams.
        # Functional max_pool2d avoids constructing an nn.MaxPool2d module on
        # every forward call (behavior is identical).
        x = torch.cat([y, F.max_pool2d(z, self.scale, self.scale)], dim=1)
        y_prime = self.conv1(x)
        y_prime = self.conv2(y_prime)

        # Project to 32 channels, upsample back to full resolution and add the
        # residual. F.interpolate replaces the deprecated F.upsample (same
        # result for mode="nearest").
        x = self.conv_res(y_prime)
        upsample_size = torch.Size([_s * self.scale for _s in y_prime.shape[-2:]])
        x = F.interpolate(x, size=upsample_size, mode="nearest")
        z_prime = z + x

        return y_prime, z_prime


class frrn(nn.Module):
    """
    Full Resolution Residual Networks for Semantic Segmentation
    URL: https://arxiv.org/abs/1611.08323

    References:
    1) Original Author's code: https://github.com/TobyPDE/FRRN
    2) TF implementation by @kiwonjoon: https://github.com/hiwonjoon/tf-frrn

    Args:
        num_classes: number of segmentation classes.
        model_type: "A" or "B", selecting a spec from ``frrn_specs_dic``.
    """

    def __init__(self, num_classes=21, model_type=None, **kwargs):
        super(frrn, self).__init__()
        self.n_classes = num_classes
        self.model_type = model_type
        # NOTE(review): K is unused in this module; presumably kept from the
        # original FRRN bootstrap-loss implementation — confirm before removing.
        self.K = 64 * 512

        # Fail fast with a descriptive message instead of a bare KeyError(None)
        # when model_type is missing or unknown.
        if self.model_type not in frrn_specs_dic:
            raise KeyError(
                "Unknown FRRN model_type: {!r} (expected one of {})".format(
                    self.model_type, sorted(frrn_specs_dic)
                )
            )

        self.conv1 = conv2DBatchNormRelu(3, 48, 5, 1, 2)

        self.up_residual_units = []
        self.down_residual_units = []
        for i in range(3):
            self.up_residual_units.append(RU(channels=48, kernel_size=3, strides=1))
            self.down_residual_units.append(RU(channels=48, kernel_size=3, strides=1))

        self.up_residual_units = nn.ModuleList(self.up_residual_units)
        self.down_residual_units = nn.ModuleList(self.down_residual_units)

        # 1x1 conv that splits off the 32-channel full-resolution residual stream.
        self.split_conv = nn.Conv2d(
            48, 32, kernel_size=1, padding=0, stride=1, bias=True
        )

        # each spec is as (n_blocks, channels, scale)
        self.encoder_frru_specs = frrn_specs_dic[self.model_type]["encoder"]
        self.decoder_frru_specs = frrn_specs_dic[self.model_type]["decoder"]

        # encoding: every block in a stage takes prev_channels (the previous
        # stage's width) because forward feeds the same pooled tensor to each.
        prev_channels = 48
        self.encoding_frrus = {}
        for n_blocks, channels, scale in self.encoder_frru_specs:
            for block in range(n_blocks):
                key = "_".join(
                    map(str, ["encoding_frru", n_blocks, channels, scale, block])
                )
                setattr(
                    self,
                    key,
                    FRRU(
                        prev_channels=prev_channels, out_channels=channels, scale=scale
                    ),
                )
            prev_channels = channels

        # decoding
        self.decoding_frrus = {}
        for n_blocks, channels, scale in self.decoder_frru_specs:
            # pass through decoding FRRUs
            for block in range(n_blocks):
                key = "_".join(
                    map(str, ["decoding_frru", n_blocks, channels, scale, block])
                )
                setattr(
                    self,
                    key,
                    FRRU(
                        prev_channels=prev_channels, out_channels=channels, scale=scale
                    ),
                )
            prev_channels = channels

        # merges the upsampled pooling stream with the 32-channel residual stream.
        self.merge_conv = nn.Conv2d(
            prev_channels + 32, 48, kernel_size=1, padding=0, stride=1, bias=True
        )

        self.classif_conv = nn.Conv2d(
            48, self.n_classes, kernel_size=1, padding=0, stride=1, bias=True
        )

    def forward(self, x):
        """Return per-pixel class logits of shape (N, n_classes, H, W)."""
        # pass to initial conv
        x = self.conv1(x)

        # pass through residual units
        for i in range(3):
            x = self.up_residual_units[i](x)

        # divide stream: y is the pooling stream, z the full-resolution stream
        y = x
        z = self.split_conv(x)

        prev_channels = 48

        # encoding
        for n_blocks, channels, scale in self.encoder_frru_specs:
            # maxpool bigger feature map
            y_pooled = F.max_pool2d(y, stride=2, kernel_size=2, padding=0)
            # pass through encoding FRRUs; each block of the stage consumes the
            # same pooled tensor while z accumulates across blocks
            for block in range(n_blocks):
                key = "_".join(
                    map(str, ["encoding_frru", n_blocks, channels, scale, block])
                )
                y, z = getattr(self, key)(y_pooled, z)
            prev_channels = channels

        # decoding
        for n_blocks, channels, scale in self.decoder_frru_specs:
            # bilinear upsample smaller feature map. F.interpolate replaces the
            # deprecated F.upsample; align_corners=False matches its default.
            upsample_size = torch.Size([_s * 2 for _s in y.size()[-2:]])
            y_upsampled = F.interpolate(
                y, size=upsample_size, mode="bilinear", align_corners=False
            )
            # pass through decoding FRRUs
            for block in range(n_blocks):
                key = "_".join(
                    map(str, ["decoding_frru", n_blocks, channels, scale, block])
                )
                y, z = getattr(self, key)(y_upsampled, z)
            prev_channels = channels

        # merge streams
        x = torch.cat(
            [
                F.interpolate(y, scale_factor=2, mode="bilinear", align_corners=False),
                z,
            ],
            dim=1,
        )
        x = self.merge_conv(x)

        # pass through residual units
        for i in range(3):
            x = self.down_residual_units[i](x)

        # final 1x1 conv to get classification
        x = self.classif_conv(x)

        return x