Source code for pywick.models.segmentation.resnet_gcn

# Another implementation of GCN
# Source: https://github.com/saeedizadi/binseg_pytoch/tree/master/models (Apache-2.0)

"""
Implementation of `Large Kernel Matters <https://arxiv.org/pdf/1703.02719>`_ with Resnet backend.
"""

import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torchvision.models as models
import numpy as np

__all__ = ['ResnetGCN']

def initialize_weights(*models, method='kaiming'):
    # `method` is keyword-only so that every positional argument is treated as a
    # model to initialize (e.g. initialize_weights(net_a, net_b, method='xavier')).
    for model in models:
        for module in model.modules():

            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                if method == 'kaiming':
                    init.kaiming_normal_(module.weight, np.sqrt(2.0))
                elif method == 'xavier':
                    init.xavier_normal_(module.weight, gain=np.sqrt(2.0))
                elif method == 'orthogonal':
                    init.orthogonal_(module.weight, gain=np.sqrt(2.0))
                elif method == 'normal':
                    init.normal_(module.weight, mean=0, std=0.02)
                if module.bias is not None:
                    init.constant_(module.bias, 0)
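
# Example call (illustrative only; the module names are hypothetical):
#   initialize_weights(decoder, seg_head, method='xavier')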


class GlobalConvolutionBlock(nn.Module):
    def __init__(self, in_channels, out_channels, k):
        super(GlobalConvolutionBlock, self).__init__()
        # Factorized large kernel from "Large Kernel Matters": a dense k[0] x k[1]
        # convolution is approximated by two separable branches whose outputs are summed.
        # Left branch: (k[0] x 1) followed by (1 x k[1]).
        self.left = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=(k[0], 1), padding=(k[0] // 2, 0)),
                                  nn.Conv2d(out_channels, out_channels, kernel_size=(1, k[1]), padding=(0, k[1] // 2)))
        # Right branch: (1 x k[1]) followed by (k[0] x 1).
        self.right = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=(1, k[1]), padding=(0, k[1] // 2)),
                                   nn.Conv2d(out_channels, out_channels, kernel_size=(k[0], 1), padding=(k[0] // 2, 0)))

    def forward(self, x):
        left = self.left(x)
        right = self.right(x)
        return left + right  # sum of the two separable branches
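
# Sizing note: the factorization above makes the parameter count grow as
# O(k[0] + k[1]) instead of O(k[0] * k[1]); e.g. for the (7, 9) kernel used on the
# deepest stage, that is 7 + 9 = 16 weights per channel pair rather than 7 * 9 = 63,
# which is what makes such large effective kernels affordable.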



class BoundaryRefine(nn.Module):
    def __init__(self, in_channels):
        super(BoundaryRefine, self).__init__()
        self.layer = nn.Sequential(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
                                   nn.BatchNorm2d(in_channels),
                                   nn.ReLU(inplace=True),
                                   nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
                                   nn.BatchNorm2d(in_channels))

    def forward(self, x):
        convs = self.layer(x)
        return x + convs  # residual refinement: convs has the same shape as x

class ResnetGCN(nn.Module):
    def __init__(self, num_classes, pretrained=True, **kwargs):
        super(ResnetGCN, self).__init__()
        resnet = models.resnet101(pretrained=pretrained)

        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        # GCN kernel sizes below assume an input of size 240x320.
        self.gcn256 = GlobalConvolutionBlock(256, num_classes, (59, 79))
        self.br256 = BoundaryRefine(num_classes)
        self.gcn512 = GlobalConvolutionBlock(512, num_classes, (29, 39))
        self.br512 = BoundaryRefine(num_classes)
        self.gcn1024 = GlobalConvolutionBlock(1024, num_classes, (13, 19))
        self.br1024 = BoundaryRefine(num_classes)
        self.gcn2048 = GlobalConvolutionBlock(2048, num_classes, (7, 9))
        self.br2048 = BoundaryRefine(num_classes)

        self.br1 = BoundaryRefine(num_classes)
        self.br2 = BoundaryRefine(num_classes)
        self.br3 = BoundaryRefine(num_classes)
        self.br4 = BoundaryRefine(num_classes)
        self.br5 = BoundaryRefine(num_classes)

        self.activation = nn.Sigmoid()
        # The single-channel transposed convolutions (and the sigmoid above) assume
        # num_classes == 1, i.e. binary segmentation.
        self.deconv1 = nn.ConvTranspose2d(1, 1, 2, stride=2)
        self.deconv2 = nn.ConvTranspose2d(1, 1, 2, stride=2)

        initialize_weights(self.gcn256, self.gcn512, self.gcn1024, self.gcn2048,
                           self.br5, self.br4, self.br3, self.br2, self.br1,
                           self.br256, self.br512, self.br1024, self.br2048,
                           self.deconv1, self.deconv2)

    def forward(self, x):
        # Spatial sizes below assume an input of size 240x320.
        x = self.layer0(x)            ## 120x160x64
        layer1 = self.layer1(x)       ## 60x80x256
        layer2 = self.layer2(layer1)  ## 30x40x512
        layer3 = self.layer3(layer2)  ## 15x20x1024
        layer4 = self.layer4(layer3)  ## 8x10x2048

        # Per-stage global convolution + boundary refinement, all at num_classes channels.
        enc1 = self.br256(self.gcn256(layer1))
        enc2 = self.br512(self.gcn512(layer2))
        enc3 = self.br1024(self.gcn1024(layer3))
        enc4 = self.br2048(self.gcn2048(layer4))  ## 8x10 x num_classes

        # Decoder: upsample, add the next-finer encoder map, then refine the boundary.
        dec1 = self.br1(F.interpolate(enc4, size=enc3.size()[2:], mode='bilinear', align_corners=False) + enc3)
        dec2 = self.br2(F.interpolate(dec1, enc2.size()[2:], mode='bilinear', align_corners=False) + enc2)
        dec3 = self.br3(F.interpolate(dec2, enc1.size()[2:], mode='bilinear', align_corners=False) + enc1)
        dec4 = self.br4(self.deconv1(dec3))
        score_map = self.br5(self.deconv2(dec4))

        return self.activation(score_map)

    @staticmethod
    def _do_upsample(num_classes=1, kernel_size=2, stride=2):
        return nn.ConvTranspose2d(num_classes, num_classes, kernel_size=kernel_size, stride=stride)
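

# Minimal usage sketch (illustrative, not part of the pywick API): assumes the binary
# segmentation setting implied by the 1-channel deconvolutions above (num_classes=1)
# and the 240x320 input size the GCN kernel sizes were chosen for.
if __name__ == '__main__':
    import torch

    model = ResnetGCN(num_classes=1, pretrained=False)  # pretrained=False avoids a weight download
    model.eval()
    dummy = torch.randn(1, 3, 240, 320)                 # one RGB image, NCHW
    with torch.no_grad():
        probs = model(dummy)                            # per-pixel sigmoid probabilities
    print(probs.shape)                                  # expected: torch.Size([1, 1, 240, 320])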