# Source code for pywick.models.segmentation.gcnnets.gcn_nasnet

# Source: https://github.com/flixpar/VisDa/tree/master/models

"""
Implementation of `Large Kernel Matters <https://arxiv.org/pdf/1703.02719>`_ with NASNet backend
"""

from math import floor
import numpy as np

import torch
import torch.nn as nn

# Only the segmentation model is exported; the GCN and NASNet building blocks
# defined in this module are not part of ``__all__``.
__all__ = ['GCN_NASNet']

################### GCN ######################

class _GlobalConvModule(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size):
        super(_GlobalConvModule, self).__init__()

        pad0 = floor((kernel_size[0] - 1) / 2)
        pad1 = floor((kernel_size[1] - 1) / 2)

        self.conv_l1 = nn.Conv2d(in_dim, out_dim, kernel_size=(kernel_size[0], 1), padding=(pad0, 0))
        self.conv_l2 = nn.Conv2d(out_dim, out_dim, kernel_size=(1, kernel_size[1]), padding=(0, pad1))
        self.conv_r1 = nn.Conv2d(in_dim, out_dim, kernel_size=(1, kernel_size[1]), padding=(0, pad1))
        self.conv_r2 = nn.Conv2d(out_dim, out_dim, kernel_size=(kernel_size[0], 1), padding=(pad0, 0))

    def forward(self, x):
        x_l = self.conv_l1(x)
        x_l = self.conv_l2(x_l)
        x_r = self.conv_r1(x)
        x_r = self.conv_r2(x_r)
        x = x_l + x_r
        return x


class _BoundaryRefineModule(nn.Module):
    def __init__(self, dim):
        super(_BoundaryRefineModule, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.relu(residual)
        residual = self.conv2(residual)
        out = x + residual
        return out


class _LearnedBilinearDeconvModule(nn.Module):
    def __init__(self, channels):
        super(_LearnedBilinearDeconvModule, self).__init__()
        self.deconv = nn.ConvTranspose2d(channels, channels, kernel_size=4, stride=2, padding=1)
        self.deconv.weight.data = self.make_bilinear_weights(4, channels)
        self.deconv.bias.data.zero_()

    def forward(self, x):
        out = self.deconv(x)
        return out

    @staticmethod
    def make_bilinear_weights(size, num_channels):
        factor = (size + 1) // 2
        if size % 2 == 1:
            center = factor - 1
        else:
            center = factor - 0.5
        og = np.ogrid[:size, :size]
        filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
        filt = torch.from_numpy(filt)
        w = torch.zeros(num_channels, num_channels, size, size)
        for i in range(num_channels):
            w[i, i] = filt
        return w


class GCN_NASNet(nn.Module):
    """GCN ("Large Kernel Matters") segmentation head on a NASNet-A Large backbone.

    Features are tapped from three backbone stages: ``cell_5``, ``cell_11`` and
    ``cell_17``, with 1008, 2016 and 4032 channels respectively. Each is
    projected to ``num_classes`` channels by a global-convolution block and
    refined. The maps are then fused top-down with 2x learned bilinear
    upsampling. Two more 2x upsampling steps follow; the last is preceded by a
    fourth global-convolution block.

    Args:
        num_classes: number of output channels (one score map per class).
        pretrained: if True, download and load the NASNet-A Large checkpoint
            from the Cadene ``pretrainedmodels`` URL below into the backbone.
            The checkpoint is built with 1001 classes (presumably ImageNet).
        k: kernel length used by every global-convolution block.

    For example, a 512x512 input yields a 512x512 output. Other input sizes can
    produce stage resolutions that do not line up, in which case the fusion
    sums fail with a shape mismatch.
    """

    def __init__(self, num_classes, pretrained=True, k=7):
        super(GCN_NASNet, self).__init__()

        self.K = k

        model = NASNetALarge(num_classes=1001)
        if pretrained:
            # Import the submodule explicitly rather than relying on a bare
            # ``import torch`` having loaded ``torch.utils.model_zoo``.
            from torch.utils import model_zoo
            model.load_state_dict(model_zoo.load_url(
                'http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth'))
        self.nasnet = model

        # Global-convolution projections of the tapped stages (deepest first),
        # plus one extra block applied before the final upsampling.
        self.gcm1 = _GlobalConvModule(4032, num_classes, (self.K, self.K))
        self.gcm2 = _GlobalConvModule(2016, num_classes, (self.K, self.K))
        self.gcm3 = _GlobalConvModule(1008, num_classes, (self.K, self.K))
        self.gcm4 = _GlobalConvModule(num_classes, num_classes, (self.K, self.K))

        self.brm1 = _BoundaryRefineModule(num_classes)
        self.brm2 = _BoundaryRefineModule(num_classes)
        self.brm3 = _BoundaryRefineModule(num_classes)
        self.brm4 = _BoundaryRefineModule(num_classes)
        self.brm5 = _BoundaryRefineModule(num_classes)
        self.brm6 = _BoundaryRefineModule(num_classes)
        self.brm7 = _BoundaryRefineModule(num_classes)
        self.brm8 = _BoundaryRefineModule(num_classes)

        # One upsampling module is reused for all five 2x steps in forward(),
        # so its weights are shared across every stage.
        self.deconv = _LearnedBilinearDeconvModule(num_classes)

        # self.deconv is not passed here, so it keeps its bilinear initialisation.
        initialize_weights(self.gcm1, self.gcm2, self.gcm3, self.gcm4,
                           self.brm1, self.brm2, self.brm3, self.brm4,
                           self.brm5, self.brm6, self.brm7, self.brm8)

    def _backbone_features(self, x):
        """Run the NASNet cells on ``x``; return the (cell_5, cell_11, cell_17) outputs."""
        nas = self.nasnet
        x_conv0 = nas.conv0(x)
        x_stem_0 = nas.cell_stem_0(x_conv0)
        x_stem_1 = nas.cell_stem_1(x_conv0, x_stem_0)

        x_cell_0 = nas.cell_0(x_stem_1, x_stem_0)
        x_cell_1 = nas.cell_1(x_cell_0, x_stem_1)
        x_cell_2 = nas.cell_2(x_cell_1, x_cell_0)
        x_cell_3 = nas.cell_3(x_cell_2, x_cell_1)
        x_cell_4 = nas.cell_4(x_cell_3, x_cell_2)
        x_cell_5 = nas.cell_5(x_cell_4, x_cell_3)

        x_reduction_cell_0 = nas.reduction_cell_0(x_cell_5, x_cell_4)

        x_cell_6 = nas.cell_6(x_reduction_cell_0, x_cell_4)
        x_cell_7 = nas.cell_7(x_cell_6, x_reduction_cell_0)
        x_cell_8 = nas.cell_8(x_cell_7, x_cell_6)
        x_cell_9 = nas.cell_9(x_cell_8, x_cell_7)
        x_cell_10 = nas.cell_10(x_cell_9, x_cell_8)
        x_cell_11 = nas.cell_11(x_cell_10, x_cell_9)

        x_reduction_cell_1 = nas.reduction_cell_1(x_cell_11, x_cell_10)

        x_cell_12 = nas.cell_12(x_reduction_cell_1, x_cell_10)
        x_cell_13 = nas.cell_13(x_cell_12, x_reduction_cell_1)
        x_cell_14 = nas.cell_14(x_cell_13, x_cell_12)
        x_cell_15 = nas.cell_15(x_cell_14, x_cell_13)
        x_cell_16 = nas.cell_16(x_cell_15, x_cell_14)
        x_cell_17 = nas.cell_17(x_cell_16, x_cell_15)

        return x_cell_5, x_cell_11, x_cell_17

    def forward(self, x):
        x_cell_5, x_cell_11, x_cell_17 = self._backbone_features(x)

        gcfm1 = self.brm1(self.gcm1(x_cell_17))
        gcfm2 = self.brm2(self.gcm2(x_cell_11))
        gcfm3 = self.brm3(self.gcm3(x_cell_5))

        # Top-down fusion: upsample the coarser map and add the next stage.
        fs1 = self.brm4(self.deconv(gcfm1) + gcfm2)
        fs2 = self.brm5(self.deconv(fs1) + gcfm3)
        fs3 = self.brm6(self.deconv(fs2))
        fs4 = self.brm7(self.deconv(fs3))
        out = self.brm8(self.deconv(self.gcm4(fs4)))
        return out


def initialize_weights(*models):
    """Re-initialise Conv2d/Linear/BatchNorm2d layers inside the given modules.

    Every submodule reached via ``.modules()`` is visited:

    * ``nn.Conv2d`` / ``nn.Linear``: Kaiming-normal weight; bias (if any) set to 0.
    * ``nn.BatchNorm2d``: weight set to 1 and bias set to 0 when they exist.
      ``affine=False`` layers have neither and are skipped.

    Other module types are left untouched.

    Args:
        *models: ``nn.Module`` instances to re-initialise in place.
    """
    for model in models:
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                if module.weight is not None:
                    nn.init.ones_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)


################## NASNet ############################
# NOTE: submodule attribute names below form the state_dict keys; keep them
# stable so existing NASNet-A Large checkpoints continue to load.


class MaxPoolPad(nn.Module):
    """Max pool on a top/left zero-padded input, then crop the first row/column.

    The input is padded with one row of zeros on top and one column on the
    left. A 3x3 stride-2 max pool (padding 1) is applied, and the first output
    row and column are dropped.
    """

    def __init__(self):
        super(MaxPoolPad, self).__init__()
        self.pad = nn.ZeroPad2d((1, 0, 1, 0))
        self.pool = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x):
        x = self.pad(x)
        x = self.pool(x)
        x = x[:, :, 1:, 1:]
        return x


class AvgPoolPad(nn.Module):
    """Average-pool counterpart of ``MaxPoolPad``.

    The input gets a one-pixel top/left zero pad, then a 3x3 average pool
    (``count_include_pad=False``) is applied, and the first output row and
    column are cropped.
    """

    def __init__(self, stride=2, padding=1):
        super(AvgPoolPad, self).__init__()
        self.pad = nn.ZeroPad2d((1, 0, 1, 0))
        self.pool = nn.AvgPool2d(3, stride=stride, padding=padding, count_include_pad=False)

    def forward(self, x):
        x = self.pad(x)
        x = self.pool(x)
        x = x[:, :, 1:, 1:]
        return x


class SeparableConv2d(nn.Module):
    """Depthwise convolution (groups=in_channels) followed by a 1x1 pointwise convolution."""

    def __init__(self, in_channels, out_channels, dw_kernel, dw_stride, dw_padding, bias=False):
        super(SeparableConv2d, self).__init__()
        self.depthwise_conv2d = nn.Conv2d(in_channels, in_channels, dw_kernel,
                                          stride=dw_stride, padding=dw_padding,
                                          bias=bias, groups=in_channels)
        self.pointwise_conv2d = nn.Conv2d(in_channels, out_channels, 1, stride=1, bias=bias)

    def forward(self, x):
        x = self.depthwise_conv2d(x)
        x = self.pointwise_conv2d(x)
        return x


class BranchSeparables(nn.Module):
    """ReLU -> SepConv -> BN -> ReLU -> SepConv -> BN.

    Only the first separable conv uses ``stride``. The first stage keeps
    ``in_channels``, and the second maps to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False):
        super(BranchSeparables, self).__init__()
        self.relu = nn.ReLU()
        self.separable_1 = SeparableConv2d(in_channels, in_channels, kernel_size, stride, padding, bias=bias)
        self.bn_sep_1 = nn.BatchNorm2d(in_channels, eps=0.001, momentum=0.1, affine=True)
        self.relu1 = nn.ReLU()
        self.separable_2 = SeparableConv2d(in_channels, out_channels, kernel_size, 1, padding, bias=bias)
        self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)

    def forward(self, x):
        x = self.relu(x)
        x = self.separable_1(x)
        x = self.bn_sep_1(x)
        x = self.relu1(x)
        x = self.separable_2(x)
        x = self.bn_sep_2(x)
        return x


class BranchSeparablesStem(nn.Module):
    """Like ``BranchSeparables``, but the first separable conv already maps to ``out_channels``."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False):
        super(BranchSeparablesStem, self).__init__()
        self.relu = nn.ReLU()
        self.separable_1 = SeparableConv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        self.bn_sep_1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)
        self.relu1 = nn.ReLU()
        self.separable_2 = SeparableConv2d(out_channels, out_channels, kernel_size, 1, padding, bias=bias)
        self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)

    def forward(self, x):
        x = self.relu(x)
        x = self.separable_1(x)
        x = self.bn_sep_1(x)
        x = self.relu1(x)
        x = self.separable_2(x)
        x = self.bn_sep_2(x)
        return x


class BranchSeparablesReduction(BranchSeparables):
    """``BranchSeparables`` with a top/left zero pad before the first separable
    conv and a crop of its first output row and column afterwards.

    NOTE(review): the crop is always one row/column, independent of
    ``z_padding``. Only the default ``z_padding=1`` is used in this module.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, z_padding=1, bias=False):
        super(BranchSeparablesReduction, self).__init__(in_channels, out_channels, kernel_size,
                                                        stride, padding, bias=bias)
        self.padding = nn.ZeroPad2d((z_padding, 0, z_padding, 0))

    def forward(self, x):
        x = self.relu(x)
        x = self.padding(x)
        x = self.separable_1(x)
        x = x[:, :, 1:, 1:].contiguous()
        x = self.bn_sep_1(x)
        x = self.relu1(x)
        x = self.separable_2(x)
        x = self.bn_sep_2(x)
        return x


class CellStem0(nn.Module):
    """First stem cell.

    Takes the ``conv0`` output (``stem_filters`` channels) and returns
    ``4 * num_filters`` channels. Every output branch includes a stride-2
    operation, so the spatial size is roughly halved.
    """

    def __init__(self, stem_filters, num_filters=42):
        super(CellStem0, self).__init__()
        self.num_filters = num_filters
        self.stem_filters = stem_filters
        self.conv_1x1 = nn.Sequential()
        self.conv_1x1.add_module('relu', nn.ReLU())
        self.conv_1x1.add_module('conv', nn.Conv2d(self.stem_filters, self.num_filters, 1, stride=1, bias=False))
        self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_filters, eps=0.001, momentum=0.1, affine=True))

        self.comb_iter_0_left = BranchSeparables(self.num_filters, self.num_filters, 5, 2, 2)
        self.comb_iter_0_right = BranchSeparablesStem(self.stem_filters, self.num_filters, 7, 2, 3, bias=False)

        self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
        self.comb_iter_1_right = BranchSeparablesStem(self.stem_filters, self.num_filters, 7, 2, 3, bias=False)

        self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
        self.comb_iter_2_right = BranchSeparablesStem(self.stem_filters, self.num_filters, 5, 2, 2, bias=False)

        self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)

        self.comb_iter_4_left = BranchSeparables(self.num_filters, self.num_filters, 3, 1, 1, bias=False)
        self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x):
        x1 = self.conv_1x1(x)

        x_comb_iter_0_left = self.comb_iter_0_left(x1)
        x_comb_iter_0_right = self.comb_iter_0_right(x)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x1)
        x_comb_iter_1_right = self.comb_iter_1_right(x)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x1)
        x_comb_iter_2_right = self.comb_iter_2_right(x)
        x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right

        x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
        x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1

        x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
        x_comb_iter_4_right = self.comb_iter_4_right(x1)
        x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right

        x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
        return x_out


class CellStem1(nn.Module):
    """Second stem cell.

    ``x_stem_0`` (``2 * num_filters`` channels) is projected by ``conv_1x1`` to
    form the left input. ``x_conv0`` (``stem_filters`` channels) feeds two
    stride-2 1x1 paths, the second on the input shifted by one pixel; their
    concatenation forms the right input. Returns ``4 * num_filters`` channels.
    """

    def __init__(self, stem_filters, num_filters):
        super(CellStem1, self).__init__()
        self.num_filters = num_filters
        self.stem_filters = stem_filters
        self.conv_1x1 = nn.Sequential()
        self.conv_1x1.add_module('relu', nn.ReLU())
        self.conv_1x1.add_module('conv', nn.Conv2d(2*self.num_filters, self.num_filters, 1, stride=1, bias=False))
        self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_filters, eps=0.001, momentum=0.1, affine=True))

        self.relu = nn.ReLU()
        self.path_1 = nn.Sequential()
        self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
        self.path_1.add_module('conv', nn.Conv2d(self.stem_filters, self.num_filters//2, 1, stride=1, bias=False))
        # ModuleList used as a named container; its children are accessed by
        # attribute (pad/avgpool/conv) in forward() rather than iterated.
        self.path_2 = nn.ModuleList()
        self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1)))
        self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
        self.path_2.add_module('conv', nn.Conv2d(self.stem_filters, self.num_filters//2, 1, stride=1, bias=False))

        self.final_path_bn = nn.BatchNorm2d(self.num_filters, eps=0.001, momentum=0.1, affine=True)

        self.comb_iter_0_left = BranchSeparables(self.num_filters, self.num_filters, 5, 2, 2, bias=False)
        self.comb_iter_0_right = BranchSeparables(self.num_filters, self.num_filters, 7, 2, 3, bias=False)

        self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
        self.comb_iter_1_right = BranchSeparables(self.num_filters, self.num_filters, 7, 2, 3, bias=False)

        self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
        self.comb_iter_2_right = BranchSeparables(self.num_filters, self.num_filters, 5, 2, 2, bias=False)

        self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)

        self.comb_iter_4_left = BranchSeparables(self.num_filters, self.num_filters, 3, 1, 1, bias=False)
        self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x_conv0, x_stem_0):
        x_left = self.conv_1x1(x_stem_0)

        x_relu = self.relu(x_conv0)
        # path 1
        x_path1 = self.path_1(x_relu)
        # path 2: pad right/bottom, then crop top/left -> input shifted by one pixel
        x_path2 = self.path_2.pad(x_relu)
        x_path2 = x_path2[:, :, 1:, 1:]
        x_path2 = self.path_2.avgpool(x_path2)
        x_path2 = self.path_2.conv(x_path2)
        # final path
        x_right = self.final_path_bn(torch.cat([x_path1, x_path2], 1))

        x_comb_iter_0_left = self.comb_iter_0_left(x_left)
        x_comb_iter_0_right = self.comb_iter_0_right(x_right)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x_left)
        x_comb_iter_1_right = self.comb_iter_1_right(x_right)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x_left)
        x_comb_iter_2_right = self.comb_iter_2_right(x_right)
        x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right

        x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
        x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1

        x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
        x_comb_iter_4_right = self.comb_iter_4_right(x_left)
        x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right

        x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
        return x_out


class FirstCell(nn.Module):
    """Normal cell whose previous input ``x_prev`` is downsampled by two stride-2 1x1 paths.

    The concatenated paths give ``2 * out_channels_left`` channels. Several
    branches built with ``out_channels_right`` channels are applied to that
    tensor, so ``2 * out_channels_left == out_channels_right`` is required.
    Output has ``2 * out_channels_left + 5 * out_channels_right`` channels.
    """

    def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
        super(FirstCell, self).__init__()
        self.conv_1x1 = nn.Sequential()
        self.conv_1x1.add_module('relu', nn.ReLU())
        self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
        self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))

        self.relu = nn.ReLU()
        self.path_1 = nn.Sequential()
        self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
        self.path_1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
        # Named container; children are accessed by attribute in forward().
        self.path_2 = nn.ModuleList()
        self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1)))
        self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
        self.path_2.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))

        self.final_path_bn = nn.BatchNorm2d(out_channels_left * 2, eps=0.001, momentum=0.1, affine=True)

        self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
        self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)

        self.comb_iter_1_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
        self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)

        self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)

        self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
        self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)

        self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)

    def forward(self, x, x_prev):
        x_relu = self.relu(x_prev)
        # path 1
        x_path1 = self.path_1(x_relu)
        # path 2: pad right/bottom, then crop top/left -> input shifted by one pixel
        x_path2 = self.path_2.pad(x_relu)
        x_path2 = x_path2[:, :, 1:, 1:]
        x_path2 = self.path_2.avgpool(x_path2)
        x_path2 = self.path_2.conv(x_path2)
        # final path
        x_left = self.final_path_bn(torch.cat([x_path1, x_path2], 1))

        x_right = self.conv_1x1(x)

        x_comb_iter_0_left = self.comb_iter_0_left(x_right)
        x_comb_iter_0_right = self.comb_iter_0_right(x_left)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x_left)
        x_comb_iter_1_right = self.comb_iter_1_right(x_left)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x_right)
        x_comb_iter_2 = x_comb_iter_2_left + x_left

        x_comb_iter_3_left = self.comb_iter_3_left(x_left)
        x_comb_iter_3_right = self.comb_iter_3_right(x_left)
        x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right

        x_comb_iter_4_left = self.comb_iter_4_left(x_right)
        x_comb_iter_4 = x_comb_iter_4_left + x_right

        x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
        return x_out


class NormalCell(nn.Module):
    """Stride-1 NASNet cell combining the current input ``x`` and the previous one ``x_prev``.

    Left and right branches are summed, so ``out_channels_left`` must equal
    ``out_channels_right``. Output has
    ``out_channels_left + 5 * out_channels_right`` channels.
    """

    def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
        super(NormalCell, self).__init__()
        self.conv_prev_1x1 = nn.Sequential()
        self.conv_prev_1x1.add_module('relu', nn.ReLU())
        self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
        self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True))

        self.conv_1x1 = nn.Sequential()
        self.conv_1x1.add_module('relu', nn.ReLU())
        self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
        self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))

        self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
        self.comb_iter_0_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False)

        self.comb_iter_1_left = BranchSeparables(out_channels_left, out_channels_left, 5, 1, 2, bias=False)
        self.comb_iter_1_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False)

        self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)

        self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
        self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)

        self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)

    def forward(self, x, x_prev):
        x_left = self.conv_prev_1x1(x_prev)
        x_right = self.conv_1x1(x)

        x_comb_iter_0_left = self.comb_iter_0_left(x_right)
        x_comb_iter_0_right = self.comb_iter_0_right(x_left)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x_left)
        x_comb_iter_1_right = self.comb_iter_1_right(x_left)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x_right)
        x_comb_iter_2 = x_comb_iter_2_left + x_left

        x_comb_iter_3_left = self.comb_iter_3_left(x_left)
        x_comb_iter_3_right = self.comb_iter_3_right(x_left)
        x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right

        x_comb_iter_4_left = self.comb_iter_4_left(x_right)
        x_comb_iter_4 = x_comb_iter_4_left + x_right

        x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
        return x_out


class ReductionCell0(nn.Module):
    """First reduction cell.

    Its stride-2 branches use the pad-and-crop variants (``BranchSeparablesReduction``,
    ``MaxPoolPad``, ``AvgPoolPad``). Output has ``4 * out_channels_right`` channels.
    """

    def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
        super(ReductionCell0, self).__init__()
        self.conv_prev_1x1 = nn.Sequential()
        self.conv_prev_1x1.add_module('relu', nn.ReLU())
        self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
        self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True))

        self.conv_1x1 = nn.Sequential()
        self.conv_1x1.add_module('relu', nn.ReLU())
        self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
        self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))

        self.comb_iter_0_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
        self.comb_iter_0_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False)

        self.comb_iter_1_left = MaxPoolPad()
        self.comb_iter_1_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False)

        self.comb_iter_2_left = AvgPoolPad()
        self.comb_iter_2_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False)

        self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)

        self.comb_iter_4_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
        self.comb_iter_4_right = MaxPoolPad()

    def forward(self, x, x_prev):
        x_left = self.conv_prev_1x1(x_prev)
        x_right = self.conv_1x1(x)

        x_comb_iter_0_left = self.comb_iter_0_left(x_right)
        x_comb_iter_0_right = self.comb_iter_0_right(x_left)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x_right)
        x_comb_iter_1_right = self.comb_iter_1_right(x_left)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x_right)
        x_comb_iter_2_right = self.comb_iter_2_right(x_left)
        x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right

        x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
        x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1

        x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
        x_comb_iter_4_right = self.comb_iter_4_right(x_right)
        x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right

        x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
        return x_out


class ReductionCell1(nn.Module):
    """Second reduction cell.

    Same wiring as ``ReductionCell0``, but built from plain ``BranchSeparables``
    and standard pooling layers. Output has ``4 * out_channels_right`` channels.
    """

    def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
        super(ReductionCell1, self).__init__()
        self.conv_prev_1x1 = nn.Sequential()
        self.conv_prev_1x1.add_module('relu', nn.ReLU())
        self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
        self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True))

        self.conv_1x1 = nn.Sequential()
        self.conv_1x1.add_module('relu', nn.ReLU())
        self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
        self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))

        self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
        self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, bias=False)

        self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
        self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, bias=False)

        self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
        self.comb_iter_2_right = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, bias=False)

        self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)

        self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
        self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x, x_prev):
        x_left = self.conv_prev_1x1(x_prev)
        x_right = self.conv_1x1(x)

        x_comb_iter_0_left = self.comb_iter_0_left(x_right)
        x_comb_iter_0_right = self.comb_iter_0_right(x_left)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x_right)
        x_comb_iter_1_right = self.comb_iter_1_right(x_left)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x_right)
        x_comb_iter_2_right = self.comb_iter_2_right(x_left)
        x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right

        x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
        x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1

        x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
        x_comb_iter_4_right = self.comb_iter_4_right(x_right)
        x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right

        x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
        return x_out


class NASNetALarge(nn.Module):
    """NASNetALarge (6 @ 4032)

    Three stacks of six cells each (one ``FirstCell`` followed by five
    ``NormalCell``): cells 0-5, 6-11 and 12-17. The stacks are separated by
    ``reduction_cell_0`` and ``reduction_cell_1``, and preceded by ``conv0``
    and two stem cells.

    Args:
        num_classes: output size of ``last_linear``.
        stem_filters: output channels of ``conv0``.
        penultimate_filters: sets ``filters = penultimate_filters // 24``;
            ``features()`` returns ``24 * filters`` channels.
        filters_multiplier: divisor applied to ``filters`` for the stem cell
            widths (``filters // multiplier**2`` and ``filters // multiplier``).
    """

    def __init__(self, num_classes=1001, stem_filters=96, penultimate_filters=4032, filters_multiplier=2):
        super(NASNetALarge, self).__init__()
        self.num_classes = num_classes
        self.stem_filters = stem_filters
        self.penultimate_filters = penultimate_filters
        self.filters_multiplier = filters_multiplier

        filters = self.penultimate_filters // 24
        # 24 is default value for the architecture

        self.conv0 = nn.Sequential()
        self.conv0.add_module('conv', nn.Conv2d(in_channels=3, out_channels=self.stem_filters, kernel_size=3,
                                                padding=0, stride=2, bias=False))
        self.conv0.add_module('bn', nn.BatchNorm2d(self.stem_filters, eps=0.001, momentum=0.1, affine=True))

        self.cell_stem_0 = CellStem0(self.stem_filters, num_filters=filters // (filters_multiplier ** 2))
        self.cell_stem_1 = CellStem1(self.stem_filters, num_filters=filters // filters_multiplier)

        self.cell_0 = FirstCell(in_channels_left=filters, out_channels_left=filters//2,
                                in_channels_right=2*filters, out_channels_right=filters)
        self.cell_1 = NormalCell(in_channels_left=2*filters, out_channels_left=filters,
                                 in_channels_right=6*filters, out_channels_right=filters)
        self.cell_2 = NormalCell(in_channels_left=6*filters, out_channels_left=filters,
                                 in_channels_right=6*filters, out_channels_right=filters)
        self.cell_3 = NormalCell(in_channels_left=6*filters, out_channels_left=filters,
                                 in_channels_right=6*filters, out_channels_right=filters)
        self.cell_4 = NormalCell(in_channels_left=6*filters, out_channels_left=filters,
                                 in_channels_right=6*filters, out_channels_right=filters)
        self.cell_5 = NormalCell(in_channels_left=6*filters, out_channels_left=filters,
                                 in_channels_right=6*filters, out_channels_right=filters)

        self.reduction_cell_0 = ReductionCell0(in_channels_left=6*filters, out_channels_left=2*filters,
                                               in_channels_right=6*filters, out_channels_right=2*filters)

        self.cell_6 = FirstCell(in_channels_left=6*filters, out_channels_left=filters,
                                in_channels_right=8*filters, out_channels_right=2*filters)
        self.cell_7 = NormalCell(in_channels_left=8*filters, out_channels_left=2*filters,
                                 in_channels_right=12*filters, out_channels_right=2*filters)
        self.cell_8 = NormalCell(in_channels_left=12*filters, out_channels_left=2*filters,
                                 in_channels_right=12*filters, out_channels_right=2*filters)
        self.cell_9 = NormalCell(in_channels_left=12*filters, out_channels_left=2*filters,
                                 in_channels_right=12*filters, out_channels_right=2*filters)
        self.cell_10 = NormalCell(in_channels_left=12*filters, out_channels_left=2*filters,
                                  in_channels_right=12*filters, out_channels_right=2*filters)
        self.cell_11 = NormalCell(in_channels_left=12*filters, out_channels_left=2*filters,
                                  in_channels_right=12*filters, out_channels_right=2*filters)

        self.reduction_cell_1 = ReductionCell1(in_channels_left=12*filters, out_channels_left=4*filters,
                                               in_channels_right=12*filters, out_channels_right=4*filters)

        self.cell_12 = FirstCell(in_channels_left=12*filters, out_channels_left=2*filters,
                                 in_channels_right=16*filters, out_channels_right=4*filters)
        self.cell_13 = NormalCell(in_channels_left=16*filters, out_channels_left=4*filters,
                                  in_channels_right=24*filters, out_channels_right=4*filters)
        self.cell_14 = NormalCell(in_channels_left=24*filters, out_channels_left=4*filters,
                                  in_channels_right=24*filters, out_channels_right=4*filters)
        self.cell_15 = NormalCell(in_channels_left=24*filters, out_channels_left=4*filters,
                                  in_channels_right=24*filters, out_channels_right=4*filters)
        self.cell_16 = NormalCell(in_channels_left=24*filters, out_channels_left=4*filters,
                                  in_channels_right=24*filters, out_channels_right=4*filters)
        self.cell_17 = NormalCell(in_channels_left=24*filters, out_channels_left=4*filters,
                                  in_channels_right=24*filters, out_channels_right=4*filters)

        self.relu = nn.ReLU()
        self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0)
        self.dropout = nn.Dropout()
        self.last_linear = nn.Linear(24*filters, self.num_classes)

    def features(self, input_):
        """Run the stem and all cells; return the ``cell_17`` output (``24 * filters`` channels)."""
        x_conv0 = self.conv0(input_)
        x_stem_0 = self.cell_stem_0(x_conv0)
        x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0)

        x_cell_0 = self.cell_0(x_stem_1, x_stem_0)
        x_cell_1 = self.cell_1(x_cell_0, x_stem_1)
        x_cell_2 = self.cell_2(x_cell_1, x_cell_0)
        x_cell_3 = self.cell_3(x_cell_2, x_cell_1)
        x_cell_4 = self.cell_4(x_cell_3, x_cell_2)
        x_cell_5 = self.cell_5(x_cell_4, x_cell_3)

        x_reduction_cell_0 = self.reduction_cell_0(x_cell_5, x_cell_4)

        x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_4)
        x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0)
        x_cell_8 = self.cell_8(x_cell_7, x_cell_6)
        x_cell_9 = self.cell_9(x_cell_8, x_cell_7)
        x_cell_10 = self.cell_10(x_cell_9, x_cell_8)
        x_cell_11 = self.cell_11(x_cell_10, x_cell_9)

        x_reduction_cell_1 = self.reduction_cell_1(x_cell_11, x_cell_10)

        x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_10)
        x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1)
        x_cell_14 = self.cell_14(x_cell_13, x_cell_12)
        x_cell_15 = self.cell_15(x_cell_14, x_cell_13)
        x_cell_16 = self.cell_16(x_cell_15, x_cell_14)
        x_cell_17 = self.cell_17(x_cell_16, x_cell_15)
        return x_cell_17

    def logits(self, features):
        """Apply ReLU, an 11x11 stride-1 average pool, flatten, dropout and ``last_linear``.

        The flattened size only matches ``last_linear`` when ``features`` is
        exactly 11x11 spatially (a 331x331 input produces that).
        """
        x = self.relu(features)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = self.last_linear(x)
        return x

    def forward(self, input_):
        x = self.features(input_)
        x = self.logits(x)
        return x