# Source: https://github.com/thomasjpfan/pytorch_refinenet (License: MIT)
"""
Implementation of `RefineNet: Multi-Path Refinement Networks for High-Resolution Semantic Segmentation <https://arxiv.org/abs/1611.06612>`_.
"""
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
from .blocks import (RefineNetBlock, ResidualConvUnit,
RefineNetBlockImprovedPooling)
# Public API of this module: the two concrete model classes. The shared
# BaseRefineNet4Cascade base class is not exported.
__all__ = ['RefineNet4Cascade', 'RefineNet4CascadePoolingImproved']
class BaseRefineNet4Cascade(nn.Module):
    """Shared implementation of the 4-cascaded RefineNet segmentation model.

    A ResNet backbone is split into four stages; each stage output is
    projected by a 3x3 convolution and the projections are merged from the
    deepest stage to the shallowest one through four refinement blocks whose
    type is supplied by the caller.
    """

    def __init__(self,
                 input_shape,
                 refinenet_block,
                 num_classes=1,
                 features=256,
                 resnet_factory=models.resnet101,
                 pretrained=True,
                 freeze_resnet=False,
                 **kwargs):
        """Multi-path 4-Cascaded RefineNet for image segmentation

        Args:
            input_shape ((int, int)): (channel, size) assumes input has
                equal height and width. Only ``size`` is used here; the
                channel entry is unpacked but otherwise ignored.
            refinenet_block (block): RefineNet block class used for all four
                cascade stages. It is called as
                ``refinenet_block(features, (channels, size), ...)``.
            num_classes (int, optional): number of classes
                Default: 1
            features (int, optional): number of features in refinenet
                Default: 256
            resnet_factory (func, optional): A Resnet model from torchvision.
                It is called as ``resnet_factory(pretrained=pretrained)``.
                Default: models.resnet101
            pretrained (bool, optional): Use pretrained version of resnet
                Default: True
            freeze_resnet (bool, optional): Freeze resnet model
                Default: False
            **kwargs: accepted for signature compatibility and ignored.

        Raises:
            ValueError: size of input_shape not divisible by 32
        """
        super().__init__()

        # Unpacking still enforces that input_shape is a 2-item sequence,
        # even though the channel count itself is not needed below.
        _, input_size = input_shape
        if input_size % 32 != 0:
            raise ValueError("{} not divisble by 32".format(input_shape))

        resnet = resnet_factory(pretrained=pretrained)

        # Backbone stages. The stem (conv1/bn1/relu/maxpool) is folded into
        # the first stage so that forward() only deals with four stages.
        self.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
                                    resnet.maxpool, resnet.layer1)
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        if freeze_resnet:
            backbone = (self.layer1, self.layer2, self.layer3, self.layer4)
            for stage in backbone:
                for param in stage.parameters():
                    param.requires_grad = False

        # 3x3 projections from backbone channels to the refinement width.
        # NOTE(review): the 256/512/1024/2048 input channels are hard-coded;
        # they look like the stage widths of torchvision's bottleneck
        # ResNets (resnet50/101/152) - confirm before passing a different
        # resnet_factory.
        self.layer1_rn = nn.Conv2d(
            256, features, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer2_rn = nn.Conv2d(
            512, features, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer3_rn = nn.Conv2d(
            1024, features, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer4_rn = nn.Conv2d(
            2048, 2 * features, kernel_size=3, stride=1, padding=1,
            bias=False)

        # Build the cascade from the block type the caller asked for.
        # Previously RefineNetBlock was hard-coded here, which made the
        # refinenet_block argument a no-op for every subclass.
        # NOTE(review): assumes every block class passed in shares the
        # (features, *shapes) constructor used below - confirm in .blocks.
        self.refinenet4 = refinenet_block(
            2 * features,
            (2 * features, input_size // 32))
        self.refinenet3 = refinenet_block(
            features,
            (2 * features, input_size // 32),
            (features, input_size // 16))
        self.refinenet2 = refinenet_block(
            features,
            (features, input_size // 16),
            (features, input_size // 8))
        self.refinenet1 = refinenet_block(
            features,
            (features, input_size // 8),
            (features, input_size // 4))

        # Two residual units followed by a 1x1 classifier convolution.
        self.output_conv = nn.Sequential(
            ResidualConvUnit(features),
            ResidualConvUnit(features),
            nn.Conv2d(
                features,
                num_classes,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True))

    def forward(self, x):
        """Run the backbone, the refinement cascade and the classifier.

        Args:
            x (Tensor): input batch; its dimensions from index 2 onwards are
                taken as the spatial size of the output (presumably an
                (N, C, H, W) image batch - confirm against callers).

        Returns:
            Tensor: output of ``output_conv`` bilinearly resized (with
            ``align_corners=True``) to the spatial size of ``x``.
        """
        size = x.size()[2:]

        # Backbone features, shallowest to deepest.
        layer_1 = self.layer1(x)
        layer_2 = self.layer2(layer_1)
        layer_3 = self.layer3(layer_2)
        layer_4 = self.layer4(layer_3)

        layer_1_rn = self.layer1_rn(layer_1)
        layer_2_rn = self.layer2_rn(layer_2)
        layer_3_rn = self.layer3_rn(layer_3)
        layer_4_rn = self.layer4_rn(layer_4)

        # Refinement runs deepest-first; each stage receives the previous
        # path together with the projected features of the next stage up.
        path_4 = self.refinenet4(layer_4_rn)
        path_3 = self.refinenet3(path_4, layer_3_rn)
        path_2 = self.refinenet2(path_3, layer_2_rn)
        path_1 = self.refinenet1(path_2, layer_1_rn)

        out_conv = self.output_conv(path_1)
        out = F.interpolate(
            out_conv, size, mode='bilinear', align_corners=True)
        return out
class RefineNet4CascadePoolingImproved(BaseRefineNet4Cascade):
    def __init__(self,
                 num_classes=1,
                 pretrained=True,
                 input_shape=(1, 512),
                 features=256,
                 resnet_factory=models.resnet101,
                 freeze_resnet=False,
                 **kwargs):
        """Multi-path 4-Cascaded RefineNet for image segmentation with improved pooling

        Passes ``RefineNetBlockImprovedPooling`` to ``BaseRefineNet4Cascade``
        as the refinement block type; every other argument is forwarded
        unchanged.

        Args:
            num_classes (int, optional): number of classes
                Default: 1
            pretrained (bool, optional): Use pretrained version of resnet
                Default: True
            input_shape ((int, int), optional): (channel, size) assumes
                input has equal height and width
                Default: (1, 512)
            features (int, optional): number of features in refinenet
                Default: 256
            resnet_factory (func, optional): A Resnet model from torchvision.
                Default: models.resnet101
            freeze_resnet (bool, optional): Freeze resnet model
                Default: False
            **kwargs: forwarded to ``BaseRefineNet4Cascade.__init__``.

        Raises:
            ValueError: size of input_shape not divisible by 32
        """
        super().__init__(
            input_shape,
            RefineNetBlockImprovedPooling,
            num_classes=num_classes,
            features=features,
            resnet_factory=resnet_factory,
            pretrained=pretrained,
            freeze_resnet=freeze_resnet,
            **kwargs)
class RefineNet4Cascade(BaseRefineNet4Cascade):
    def __init__(self,
                 num_classes=1,
                 pretrained=True,
                 input_shape=(1, 512),
                 features=256,
                 resnet_factory=models.resnet101,
                 freeze_resnet=False,
                 **kwargs):
        """Multi-path 4-Cascaded RefineNet for image segmentation

        Passes ``RefineNetBlock`` to ``BaseRefineNet4Cascade`` as the
        refinement block type; every other argument is forwarded unchanged.

        Args:
            num_classes (int, optional): number of classes
                Default: 1
            pretrained (bool, optional): Use pretrained version of resnet
                Default: True
            input_shape ((int, int), optional): (channel, size) assumes
                input has equal height and width
                Default: (1, 512)
            features (int, optional): number of features in refinenet
                Default: 256
            resnet_factory (func, optional): A Resnet model from torchvision.
                Default: models.resnet101
            freeze_resnet (bool, optional): Freeze resnet model
                Default: False
            **kwargs: forwarded to ``BaseRefineNet4Cascade.__init__``.

        Raises:
            ValueError: size of input_shape not divisible by 32
        """
        super().__init__(
            input_shape,
            RefineNetBlock,
            num_classes=num_classes,
            features=features,
            resnet_factory=resnet_factory,
            pretrained=pretrained,
            freeze_resnet=freeze_resnet,
            **kwargs)