# Source code for pywick.models.segmentation.duc_hdc

# Source: (MIT)

"""Implementation of: `Understanding Convolution for Semantic Segmentation <https://arxiv.org/abs/1702.08502>`_"""
import torch
from torch import nn
from torchvision import models

from .config import res152_path

# Public models exported by ``from ... import *``.
__all__ = ["ResNetDUC", "ResNetDUCHDC"]

class _DenseUpsamplingConvModule(nn.Module):
    """Dense Upsampling Convolution (DUC) head.

    A 3x3 convolution predicts ``down_factor ** 2 * num_classes`` channels at
    the incoming (low) resolution. ``nn.PixelShuffle`` then rearranges them
    into ``num_classes`` channels at ``down_factor`` times the spatial size.

    Args:
        down_factor: Spatial upsampling factor applied by the pixel shuffle.
        in_dim: Number of channels of the incoming feature map.
        num_classes: Number of output channels (segmentation classes).
    """

    def __init__(self, down_factor, in_dim, num_classes):
        super(_DenseUpsamplingConvModule, self).__init__()
        upsample_dim = (down_factor ** 2) * num_classes
        # kernel_size=3 with padding=1 keeps the spatial size unchanged.
        self.conv = nn.Conv2d(in_dim, upsample_dim, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(upsample_dim)
        self.relu = nn.ReLU(inplace=True)
        # (N, C * r**2, H, W) -> (N, C, H * r, W * r) with r = down_factor.
        self.pixel_shuffle = nn.PixelShuffle(down_factor)

    def forward(self, x):
        """Return per-class maps upsampled by ``down_factor``.

        NOTE(review): the ReLU runs before the pixel shuffle, so every output
        value is non-negative.
        """
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.pixel_shuffle(x)
        return x

class ResNetDUC(nn.Module):
    """ResNet-152 backbone with dilated layer3/layer4 and a DUC head.

    The downsampling in ``layer3`` and ``layer4`` is removed (stride 1). Their
    3x3 convolutions are dilated by 2 and 4 respectively. A
    ``_DenseUpsamplingConvModule`` with factor 8 maps the 2048-channel
    features to ``num_classes`` output maps.

    The input height and width should be multiples of 8.

    Args:
        num_classes: Number of output segmentation classes.
        pretrained: If True, load ResNet-152 weights from ``res152_path``.
        **kwargs: Accepted for interface compatibility; unused.
    """

    def __init__(self, num_classes, pretrained=True, **kwargs):
        super(ResNetDUC, self).__init__()
        resnet = models.resnet152()
        if pretrained:
            # map_location='cpu' lets a checkpoint saved from GPU load on a
            # CPU-only machine; the model parameters live on CPU either way.
            resnet.load_state_dict(torch.load(res152_path, map_location='cpu'))
        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        # Replace striding with dilation. In torchvision's Bottleneck the
        # strided 3x3 conv is named 'conv2' and the shortcut conv
        # 'downsample.0'.
        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation = (2, 2)
                m.padding = (2, 2)
                m.stride = (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation = (4, 4)
                m.padding = (4, 4)
                m.stride = (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)

        self.duc = _DenseUpsamplingConvModule(8, 2048, num_classes)

    def forward(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.duc(x)
        return x


class ResNetDUCHDC(nn.Module):
    """ResNet-152 with Hybrid Dilated Convolution (HDC) and a DUC head.

    The downsampling in ``layer3`` and ``layer4`` is removed (stride 1).
    Instead of one fixed dilation per stage, the bottleneck blocks use varying
    rates:

    * ``layer3`` blocks cycle through dilation rates 1, 2, 5, 9.
    * ``layer4`` blocks use rates 5, 9, 17 (one per block).

    A ``_DenseUpsamplingConvModule`` with factor 8 maps the 2048-channel
    features to ``num_classes`` output maps.

    The input height and width should be multiples of 8.

    Args:
        num_classes: Number of output segmentation classes.
        pretrained: If True, load ResNet-152 weights from ``res152_path``.
        **kwargs: Accepted for interface compatibility; unused.
    """

    def __init__(self, num_classes, pretrained=True, **kwargs):
        super(ResNetDUCHDC, self).__init__()
        resnet = models.resnet152()
        if pretrained:
            # map_location='cpu' lets a checkpoint saved from GPU load on a
            # CPU-only machine; the model parameters live on CPU either way.
            resnet.load_state_dict(torch.load(res152_path, map_location='cpu'))
        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        # Remove striding; the padding is set below to match each dilation.
        for n, m in self.layer3.named_modules():
            if 'conv2' in n or 'downsample.0' in n:
                m.stride = (1, 1)
        for n, m in self.layer4.named_modules():
            if 'conv2' in n or 'downsample.0' in n:
                m.stride = (1, 1)

        # padding == dilation keeps the 3x3 conv output size unchanged.
        layer3_group_config = [1, 2, 5, 9]
        for idx, bottleneck in enumerate(self.layer3):
            rate = layer3_group_config[idx % len(layer3_group_config)]
            bottleneck.conv2.dilation = (rate, rate)
            bottleneck.conv2.padding = (rate, rate)
        # Indexed directly: raises IndexError if layer4 has more than 3 blocks.
        layer4_group_config = [5, 9, 17]
        for idx, bottleneck in enumerate(self.layer4):
            rate = layer4_group_config[idx]
            bottleneck.conv2.dilation = (rate, rate)
            bottleneck.conv2.padding = (rate, rate)

        self.duc = _DenseUpsamplingConvModule(8, 2048, num_classes)

    def forward(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.duc(x)
        return x