# Source code for pywick.models.segmentation.refinenet.refinenet

# Source: https://github.com/thomasjpfan/pytorch_refinenet (License: MIT)

"""
Implementation of `RefineNet: Multi-Path Refinement Networks for High-Resolution Semantic Segmentation <https://arxiv.org/abs/1611.06612>`_.
"""

import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

from .blocks import (RefineNetBlock, ResidualConvUnit,
                      RefineNetBlockImprovedPooling)

# Public API of this module: the two concrete 4-cascade RefineNet classes.
__all__ = ['RefineNet4Cascade', 'RefineNet4CascadePoolingImproved']

class BaseRefineNet4Cascade(nn.Module):
    """Shared implementation of the multi-path 4-cascaded RefineNet.

    A ResNet built by ``resnet_factory`` is split into four stages. Each
    stage output is projected by a 3x3 convolution, the projections are
    merged from the deepest stage to the shallowest by four refinement
    blocks, and a small convolutional head produces ``num_classes`` maps that
    are resized back to the spatial size of the input.
    """

    def __init__(self,
                 input_shape,
                 refinenet_block,
                 num_classes=1,
                 features=256,
                 resnet_factory=models.resnet101,
                 pretrained=True,
                 freeze_resnet=False,
                 **kwargs):
        """Multi-path 4-Cascaded RefineNet for image segmentation

        Args:
            input_shape ((int, int)): (channel, size) assumes input has
                equal height and width. Only ``size`` is used here; the
                first convolution is taken unchanged from the resnet.
            refinenet_block (block): RefineNet Block class; it is called as
                ``refinenet_block(features, *shapes)`` for each of the four
                refinement stages.
            num_classes (int, optional): number of classes
            features (int, optional): number of features in refinenet
            resnet_factory (func, optional): A Resnet model from torchvision.
                Default: models.resnet101
            pretrained (bool, optional): Use pretrained version of resnet
                Default: True
            freeze_resnet (bool, optional): Freeze resnet model
                Default: False
            **kwargs: accepted for call compatibility and ignored.

        Raises:
            ValueError: size of input_shape not divisible by 32
        """
        super().__init__()

        # The channel entry is unpacked (so a malformed shape still fails
        # loudly) but is not otherwise used.
        _, input_size = input_shape

        if input_size % 32 != 0:
            raise ValueError("{} not divisble by 32".format(input_shape))

        resnet = resnet_factory(pretrained=pretrained)

        # Stem (conv/bn/relu/maxpool) is folded into the first stage.
        self.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
                                    resnet.maxpool, resnet.layer1)

        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        if freeze_resnet:
            # Exclude every backbone parameter from gradient updates.
            layers = [self.layer1, self.layer2, self.layer3, self.layer4]
            for layer in layers:
                for param in layer.parameters():
                    param.requires_grad = False

        # Projection convolutions. The input channel counts are fixed at
        # 256/512/1024/2048, so the backbone stages must emit exactly those
        # widths.
        self.layer1_rn = nn.Conv2d(
            256, features, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer2_rn = nn.Conv2d(
            512, features, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer3_rn = nn.Conv2d(
            1024, features, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer4_rn = nn.Conv2d(
            2048, 2 * features, kernel_size=3, stride=1, padding=1, bias=False)

        # Build the cascade from the caller-supplied block class. The
        # previous code hard-coded RefineNetBlock here, which made the
        # ``refinenet_block`` argument a no-op.
        # NOTE(review): confirm every block class passed in accepts
        # ``(features, *shapes)`` with ``shape = (channels, size)``.
        self.refinenet4 = refinenet_block(2 * features,
                                          (2 * features, input_size // 32))
        self.refinenet3 = refinenet_block(features,
                                          (2 * features, input_size // 32),
                                          (features, input_size // 16))
        self.refinenet2 = refinenet_block(features,
                                          (features, input_size // 16),
                                          (features, input_size // 8))
        self.refinenet1 = refinenet_block(features, (features, input_size // 8),
                                          (features, input_size // 4))

        # Prediction head: two residual units followed by a 1x1 classifier.
        self.output_conv = nn.Sequential(
            ResidualConvUnit(features), ResidualConvUnit(features),
            nn.Conv2d(
                features,
                num_classes,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True))

    def forward(self, x):
        """Run the network and return maps resized to the input's size.

        Args:
            x: input tensor; everything after its first two dimensions is
                taken as the target output size (assumes an (N, C, H, W)
                batch -- TODO confirm against callers).

        Returns:
            The head output bilinearly interpolated (``align_corners=True``)
            to the spatial size of ``x``.
        """
        size = x.size()[2:]

        # Backbone features, shallow to deep.
        layer_1 = self.layer1(x)
        layer_2 = self.layer2(layer_1)
        layer_3 = self.layer3(layer_2)
        layer_4 = self.layer4(layer_3)

        # Per-stage projections.
        layer_1_rn = self.layer1_rn(layer_1)
        layer_2_rn = self.layer2_rn(layer_2)
        layer_3_rn = self.layer3_rn(layer_3)
        layer_4_rn = self.layer4_rn(layer_4)

        # Refinement cascade, deep to shallow: each block receives the
        # previous path plus the next shallower projection.
        path_4 = self.refinenet4(layer_4_rn)
        path_3 = self.refinenet3(path_4, layer_3_rn)
        path_2 = self.refinenet2(path_3, layer_2_rn)
        path_1 = self.refinenet1(path_2, layer_1_rn)

        out_conv = self.output_conv(path_1)
        out = F.interpolate(out_conv, size, mode='bilinear', align_corners=True)
        return out


class RefineNet4CascadePoolingImproved(BaseRefineNet4Cascade):
    """4-cascaded RefineNet configured with ``RefineNetBlockImprovedPooling``."""

    def __init__(self,
                 num_classes=1,
                 pretrained=True,
                 input_shape=(1, 512),
                 features=256,
                 resnet_factory=models.resnet101,
                 freeze_resnet=False,
                 **kwargs):
        """Multi-path 4-Cascaded RefineNet for image segmentation with improved pooling

        Args:
            num_classes (int, optional): number of classes
            pretrained (bool, optional): Use pretrained version of resnet
                Default: True
            input_shape ((int, int)): (channel, size) assumes input has
                equal height and width
                Default: (1, 512)
            features (int, optional): number of features in refinenet
            resnet_factory (func, optional): A Resnet model from torchvision.
                Default: models.resnet101
            freeze_resnet (bool, optional): Freeze resnet model
                Default: False
            **kwargs: forwarded unchanged to the base class constructor.

        Raises:
            ValueError: size of input_shape not divisible by 32
        """
        # All arguments are passed straight through; this subclass only
        # selects the refinement block class handed to the base constructor.
        super().__init__(
            input_shape,
            RefineNetBlockImprovedPooling,
            num_classes=num_classes,
            features=features,
            resnet_factory=resnet_factory,
            pretrained=pretrained,
            freeze_resnet=freeze_resnet,
            **kwargs)


class RefineNet4Cascade(BaseRefineNet4Cascade):
    """4-cascaded RefineNet configured with the standard ``RefineNetBlock``."""

    def __init__(self,
                 num_classes=1,
                 pretrained=True,
                 input_shape=(1, 512),
                 features=256,
                 resnet_factory=models.resnet101,
                 freeze_resnet=False,
                 **kwargs):
        """Multi-path 4-Cascaded RefineNet for image segmentation

        Args:
            num_classes (int, optional): number of classes
            pretrained (bool, optional): Use pretrained version of resnet
                Default: True
            input_shape ((int, int)): (channel, size) assumes input has
                equal height and width
                Default: (1, 512)
            features (int, optional): number of features in refinenet
            resnet_factory (func, optional): A Resnet model from torchvision.
                Default: models.resnet101
            freeze_resnet (bool, optional): Freeze resnet model
                Default: False
            **kwargs: forwarded unchanged to the base class constructor.

        Raises:
            ValueError: size of input_shape not divisible by 32
        """
        # All arguments are passed straight through; this subclass only
        # selects the refinement block class handed to the base constructor.
        super().__init__(
            input_shape,
            RefineNetBlock,
            num_classes=num_classes,
            features=features,
            resnet_factory=resnet_factory,
            pretrained=pretrained,
            freeze_resnet=freeze_resnet,
            **kwargs)