# Source code for mindcv.models.repmlp

"""
MindSpore implementation of `RepMLP`.
Refer to RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality.
"""

import numpy as np
from collections import OrderedDict

from mindspore import nn, ops, Tensor
import mindspore.common.initializer as init

from .registry import register_model
from .utils import load_pretrained

# Public API of this module: the generic network class plus one factory per preset variant.
__all__ = [
    'RepMLPNet',
    'RepMLPNet_T224',
    'RepMLPNet_T256',
    'RepMLPNet_B224',
    'RepMLPNet_B256',
    'RepMLPNet_D256',
    'RepMLPNet_L256',
]


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000,
        'first_conv': 'features.0', 'classifier': 'classifier',
        **kwargs
    }


# Default configuration of every registered variant, keyed by factory-function name.
# Only RepMLPNet_T224 has a checkpoint URL; the *256 variants declare 256x256 inputs.
default_cfgs = dict(
    RepMLPNet_T224=_cfg(url='https://download.mindspore.cn/toolkits/mindcv/repmlp/RepMLPNet_T224-8dbedd00.ckpt'),
    RepMLPNet_T256=_cfg(url='', input_size=(3, 256, 256)),
    RepMLPNet_B224=_cfg(url=''),
    RepMLPNet_B256=_cfg(url='', input_size=(3, 256, 256)),
    RepMLPNet_D256=_cfg(url='', input_size=(3, 256, 256)),
    RepMLPNet_L256=_cfg(url='', input_size=(3, 256, 256)),
)


def conv_bn(in_channels, out_channels, kernel_size, stride, padding, group=1, has_bias=False):
    """Return a ``SequentialCell`` of an explicitly padded Conv2d followed by BatchNorm2d.

    The two sub-cells are registered under the names ``conv`` and ``bn``; the BatchNorm has
    ``out_channels`` features. All other arguments are passed straight to ``nn.Conv2d``
    (with ``pad_mode="pad"`` so that ``padding`` is honoured).
    """
    layers = OrderedDict()
    layers['conv'] = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        pad_mode="pad",
        padding=padding,
        group=group,
        has_bias=has_bias,
    )
    layers['bn'] = nn.BatchNorm2d(num_features=out_channels)
    return nn.SequentialCell(layers)


def conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding, group=1, has_bias=False):
    """Return conv -> BatchNorm -> ReLU as a ``SequentialCell``.

    The conv+BN pair (built by ``conv_bn``) is registered as ``conv`` and the activation as
    ``relu``, so the convolution is reachable as ``<cell>.conv.conv``.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution / BN features.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: explicit padding of the convolution.
        group: number of convolution groups. Default: 1.
        has_bias: whether the convolution has a bias term. Default: False.
    """
    layers = OrderedDict()
    # Forward has_bias instead of hard-coding False, which silently ignored the argument.
    layers['conv'] = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                             stride=stride, padding=padding, group=group, has_bias=has_bias)
    layers['relu'] = nn.ReLU()
    return nn.SequentialCell(layers)


def fuse_bn(conv_or_fc, bn):
    """Fold a BatchNorm layer (moving statistics) into the preceding conv / fc layer.

    Reads the MindSpore BatchNorm attributes ``gamma``, ``beta``, ``moving_mean``,
    ``moving_variance`` and ``eps`` (the PyTorch names ``running_var`` / ``weight`` / ``bias``
    do not exist on MindSpore BatchNorm cells).

    When the layer has more output channels than BN features, each BN feature is applied to
    ``out_channels // num_features`` consecutive outputs (its factors are repeated, numpy
    ``repeat`` semantics).

    Args:
        conv_or_fc: object with a ``weight`` whose first axis is the output channel, and an
            optional ``bias`` (``None`` or absent when the layer has no bias).
        bn: BatchNorm cell providing the attributes listed above.

    Returns:
        ``(fused_weight, fused_bias)`` such that ``bn(layer(x)) == layer'(x)`` with the
        moving statistics.
    """
    weight = conv_or_fc.weight
    out_channels = weight.shape[0]

    std = (bn.moving_variance + bn.eps) ** 0.5
    scale = bn.gamma / std
    shift = bn.beta - bn.moving_mean * scale

    num_features = scale.shape[0]
    if num_features != out_channels:
        # One BN feature covers a run of consecutive output channels (e.g. fc3 of RepMLPBlock).
        repeat_times = out_channels // num_features
        scale = scale.repeat(repeat_times, axis=0)
        shift = shift.repeat(repeat_times, axis=0)

    conv_bias = getattr(conv_or_fc, 'bias', None)
    if conv_bias is not None:
        # bn(W x + b) = scale * W x + (beta - mean * scale) + scale * b
        shift = shift + conv_bias * scale

    # Broadcast the per-output-channel scale over the remaining weight axes (4-D conv or 2-D fc).
    broadcast_shape = (-1,) + (1,) * (len(weight.shape) - 1)
    return weight * scale.reshape(broadcast_shape), shift


class GlobalPerceptron(nn.Cell):
    """Global Perceptron: channel-wise gates computed from globally pooled features.

    One of the three components of RepMLPBlock. The input is averaged over its spatial axes,
    passed through a 1x1-conv bottleneck (``fc1`` -> ReLU -> ``fc2``) and a sigmoid, and
    returned with shape ``(-1, input_channels, 1, 1)``.

    Args:
        input_channels: number of input (and output) channels.
        internal_neurons: number of channels between ``fc1`` and ``fc2``.
    """

    def __init__(self, input_channels, internal_neurons):
        super(GlobalPerceptron, self).__init__()
        self.fc1 = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons, kernel_size=(1, 1), stride=1,
                             has_bias=True)
        self.fc2 = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels, kernel_size=(1, 1), stride=1,
                             has_bias=True)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.input_channels = input_channels
        # Global average pooling as a mean over H and W (keeping them as size-1 axes). This is
        # numerically what a full-size AvgPool2d gives, without building a new pooling cell
        # on every forward call.
        self.reduce_mean = ops.ReduceMean(keep_dims=True)

    def construct(self, x):
        x = self.reduce_mean(x, (2, 3))
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        x = x.view(-1, self.input_channels, 1, 1)
        return x


class RepMLPBlock(nn.Cell):
    """Basic RepMLPBlock layer (Global, Channel and Local Perceptron).

    The input is split into non-overlapping ``h x w`` patches (``h_parts = H // h``,
    ``w_parts = W // w``; H and W are presumably multiples of h and w — TODO confirm).
    ``fc3`` (a grouped 1x1 conv over the flattened ``num_sharesets * h * w`` patch values)
    followed by ``fc3_bn`` is the Channel Perceptron. The optional ``repconv{k}`` conv+BN
    branches are the Local Perceptron and are added to it. The result is multiplied by the
    Global Perceptron output.

    ``local_inject`` folds ``fc3_bn`` and every local branch into a single biased ``fc3``
    (structural re-parameterization) and switches the block to deploy mode.

    Args:
        in_channels: input channels; must equal ``out_channels``.
        out_channels: output channels.
        h: patch height.
        w: patch width.
        reparam_conv_k: kernel sizes of the Local Perceptron branches, or None for none.
        globalperceptron_reduce: channel reduction ratio of the Global Perceptron bottleneck.
        num_sharesets: number of share sets (groups of ``fc3`` and of the local convs).
        deploy: build the re-parameterized structure (biased ``fc3``, no BN, no branches).
    """

    def __init__(self, in_channels, out_channels,
                 h, w,
                 reparam_conv_k=None,
                 globalperceptron_reduce=4,
                 num_sharesets=1,
                 deploy=False):
        super().__init__()

        self.C = in_channels
        self.O = out_channels
        self.S = num_sharesets

        self.h, self.w = h, w

        self.deploy = deploy
        self.transpose = ops.Transpose()
        self.shape = ops.Shape()
        self.reshape = ops.Reshape()

        assert in_channels == out_channels
        self.gp = GlobalPerceptron(input_channels=in_channels, internal_neurons=in_channels // globalperceptron_reduce)

        self.fc3 = nn.Conv2d(in_channels=self.h * self.w * num_sharesets, out_channels=self.h * self.w * num_sharesets,
                             kernel_size=(1, 1), stride=1, padding=0, has_bias=deploy, group=num_sharesets)
        if deploy:
            self.fc3_bn = ops.Identity()
        else:
            self.fc3_bn = nn.BatchNorm2d(num_sharesets).set_train()

        self.reparam_conv_k = reparam_conv_k
        self.conv_branch_k = []
        if not deploy and reparam_conv_k is not None:
            for k in reparam_conv_k:
                conv_branch = conv_bn(num_sharesets, num_sharesets, kernel_size=k, stride=1, padding=k // 2,
                                      group=num_sharesets, has_bias=False)
                # Registered as a named sub-cell (repconv{k}) so its parameters are tracked; the
                # list below is only used to iterate the branches in construct.
                self.__setattr__('repconv{}'.format(k), conv_branch)
                self.conv_branch_k.append(conv_branch)

    def partition(self, x, h_parts, w_parts):
        """Reshape (N, C, H, W) into (N, h_parts, w_parts, C, h, w) patches."""
        x = x.reshape(-1, self.C, h_parts, self.h, w_parts, self.w)
        input_perm = (0, 2, 4, 1, 3, 5)
        x = self.transpose(x, input_perm)
        return x

    def partition_affine(self, x, h_parts, w_parts):
        """Apply fc3 (+ fc3_bn) to every patch; returns (-1, h_parts, w_parts, S, h, w)."""
        fc_inputs = x.reshape(-1, self.S * self.h * self.w, 1, 1)
        out = self.fc3(fc_inputs)
        out = out.reshape(-1, self.S, self.h, self.w)
        out = self.fc3_bn(out)
        out = out.reshape(-1, h_parts, w_parts, self.S, self.h, self.w)
        return out

    def construct(self, inputs):
        # Global Perceptron
        global_vec = self.gp(inputs)

        origin_shape = self.shape(inputs)

        h_parts = origin_shape[2] // self.h
        w_parts = origin_shape[3] // self.w

        partitions = self.partition(inputs, h_parts, w_parts)

        #   Channel Perceptron
        fc3_out = self.partition_affine(partitions, h_parts, w_parts)

        #   Local Perceptron
        if self.reparam_conv_k is not None and not self.deploy:
            conv_inputs = self.reshape(partitions, (-1, self.S, self.h, self.w))
            conv_out = 0
            for branch in self.conv_branch_k:
                conv_out += branch(conv_inputs)
            conv_out = self.reshape(conv_out, (-1, h_parts, w_parts, self.S, self.h, self.w))
            fc3_out += conv_out

        input_perm = (0, 3, 1, 4, 2, 5)
        fc3_out = self.transpose(fc3_out, input_perm)  # N, O, h_parts, out_h, w_parts, out_w
        out = fc3_out.reshape(*origin_shape)
        out = out * global_vec
        return out

    @staticmethod
    def _to_numpy(value):
        """Return ``value`` as a numpy array (MindSpore tensors/parameters via ``asnumpy``)."""
        if hasattr(value, 'asnumpy'):
            return value.asnumpy()
        return np.asarray(value)

    @staticmethod
    def _fuse_bn_numpy(conv, bn):
        """Fold ``bn`` (moving statistics) into ``conv``; return numpy ``(weight, bias)``.

        When ``conv`` has more output channels than ``bn`` has features, each BN feature
        covers a run of consecutive outputs (fc3: one BN channel per h*w outputs).
        """
        to_np = RepMLPBlock._to_numpy
        weight = to_np(conv.weight)
        scale = to_np(bn.gamma) / np.sqrt(to_np(bn.moving_variance) + bn.eps)
        shift = to_np(bn.beta) - to_np(bn.moving_mean) * scale
        repeat_times = weight.shape[0] // scale.shape[0]
        scale = np.repeat(scale, repeat_times)
        shift = np.repeat(shift, repeat_times)
        conv_bias = getattr(conv, 'bias', None)
        if conv_bias is not None:
            shift = shift + to_np(conv_bias) * scale
        return weight * scale.reshape((-1,) + (1,) * (weight.ndim - 1)), shift

    def get_equivalent_fc3(self):
        """Return ``(weight, bias)`` of one fc3 equivalent to fc3 + fc3_bn + local branches.

        BN layers are folded with their moving statistics. In deploy mode the current fc3
        parameters are returned unchanged.

        Returns:
            Tensors of shapes ``(S*h*w, h*w, 1, 1)`` and ``(S*h*w,)``.
        """
        if self.deploy:
            return self.fc3.weight, self.fc3.bias
        fc_weight, fc_bias = self._fuse_bn_numpy(self.fc3, self.fc3_bn)
        if self.reparam_conv_k is not None:
            for k in self.reparam_conv_k:
                branch = getattr(self, 'repconv{}'.format(k))
                kernel, bias = self._fuse_bn_numpy(branch.conv, branch.bn)
                # The conv -> fc mapping is linear, so converting each branch and summing equals
                # padding every (odd) kernel to the largest size first and converting once.
                rep_weight, rep_bias = self._convert_conv_to_fc(kernel, bias)
                fc_weight = fc_weight + rep_weight.reshape(fc_weight.shape)
                fc_bias = fc_bias + rep_bias
        return Tensor(fc_weight), Tensor(fc_bias)

    def local_inject(self):
        """Re-parameterize this block in place for inference (Locality Injection).

        fc3_bn and every local branch are folded into a new biased ``fc3``. The branch cells
        and the BN are then removed, and the block is switched to deploy mode. Does nothing
        if the block is already in deploy mode.
        """
        if self.deploy:
            return
        fc3_weight, fc3_bias = self.get_equivalent_fc3()

        # Keep the full parameter-name prefix of the old fc3 (e.g. '<...>.repmlp_block.fc3.').
        # Without this, the replacement cell's parameters would only be named 'fc3.*'.
        weight_name = self.fc3.weight.name
        prefix = weight_name[:-len('weight')] if weight_name.endswith('.weight') else 'fc3.'

        #   Remove Local Perceptron
        if self.reparam_conv_k is not None:
            for k in self.reparam_conv_k:
                delattr(self, 'repconv{}'.format(k))
        self.conv_branch_k = []
        delattr(self, 'fc3')
        delattr(self, 'fc3_bn')

        channels = self.S * self.h * self.w
        # Keyword arguments matter: nn.Conv2d's 5th positional parameter is pad_mode, not padding.
        self.fc3 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1), stride=1,
                             padding=0, has_bias=True, group=self.S)
        self.fc3_bn = ops.Identity()
        self.fc3.weight.name = prefix + 'weight'
        self.fc3.bias.name = prefix + 'bias'
        self.fc3.weight.set_data(fc3_weight)
        self.fc3.bias.set_data(fc3_bias)
        self.deploy = True

    def _convert_conv_to_fc(self, conv_kernel, conv_bias):
        """Express a grouped ``kh x kw`` conv over one ``S x h x w`` patch as fc3 weights.

        Mirrors the local branches: one input channel per group, zero padding ``kh // 2`` /
        ``kw // 2``, cross-correlation.

        Args:
            conv_kernel: kernel of shape ``(S, 1, kh, kw)``.
            conv_bias: bias of shape ``(S,)``.

        Returns:
            numpy ``(weight, bias)`` of shapes ``(S*h*w, h*w)`` and ``(S*h*w,)``. Row
            ``s*h*w + i*w + j`` holds the weights of output position (i, j) of share set s.
        """
        kernel = self._to_numpy(conv_kernel)
        bias = self._to_numpy(conv_bias)
        hw = self.h * self.w
        kh, kw = kernel.shape[2], kernel.shape[3]
        out_i, out_j = np.meshgrid(np.arange(self.h), np.arange(self.w), indexing='ij')
        out_pos = (out_i * self.w + out_j).ravel()
        fc_k = np.zeros((self.S, hw, hw), dtype=kernel.dtype)
        for di in range(kh):
            for dj in range(kw):
                # Output (i, j) reads input (i + di - kh // 2, j + dj - kw // 2); taps that fall
                # into the zero padding contribute nothing.
                in_i = out_i + di - kh // 2
                in_j = out_j + dj - kw // 2
                valid = ((in_i >= 0) & (in_i < self.h) & (in_j >= 0) & (in_j < self.w)).ravel()
                in_pos = (in_i * self.w + in_j).ravel()
                fc_k[:, out_pos[valid], in_pos[valid]] = kernel[:, 0, di, dj][:, None]
        fc_k = fc_k.reshape(self.S * hw, hw)
        fc_bias = np.repeat(bias, hw)
        return fc_k, fc_bias


class FFNBlock(nn.Cell):
    """Feed-forward block: 1x1 conv+BN -> activation -> 1x1 conv+BN.

    Args:
        in_channels: input channels.
        hidden_channels: width of the hidden layer; falls back to ``in_channels`` when falsy.
        out_channels: output channels; falls back to ``in_channels`` when falsy.
        act_layer: activation cell class, instantiated without arguments. Default: nn.GELU.
    """

    def __init__(self, in_channels, hidden_channels=None, out_channels=None, act_layer=nn.GELU):
        super().__init__()
        width_out = out_channels if out_channels else in_channels
        width_hidden = hidden_channels if hidden_channels else in_channels
        self.ffn_fc1 = conv_bn(in_channels, width_hidden, 1, 1, 0, has_bias=False)
        self.ffn_fc2 = conv_bn(width_hidden, width_out, 1, 1, 0, has_bias=False)
        self.act = act_layer()

    def construct(self, inputs):
        hidden = self.act(self.ffn_fc1(inputs))
        return self.ffn_fc2(hidden)


class RepMLPNetUnit(nn.Cell):
    """One RepMLPNet unit: a residual RepMLPBlock followed by a residual FFN.

    Each branch is preceded by its own BatchNorm (``prebn1`` / ``prebn2``).

    Args:
        channels: input and output channels.
        h: patch height of the RepMLPBlock.
        w: patch width of the RepMLPBlock.
        reparam_conv_k: Local Perceptron kernel sizes passed to RepMLPBlock.
        globalperceptron_reduce: Global Perceptron reduction ratio passed to RepMLPBlock.
        ffn_expand: hidden-width multiplier of the FFN. Default: 4.
        num_sharesets: number of share sets of the RepMLPBlock. Default: 1.
        deploy: build the re-parameterized RepMLPBlock. Default: False.
    """

    def __init__(self, channels, h, w, reparam_conv_k, globalperceptron_reduce, ffn_expand=4,
                 num_sharesets=1, deploy=False):
        super().__init__()
        self.repmlp_block = RepMLPBlock(
            in_channels=channels,
            out_channels=channels,
            h=h,
            w=w,
            reparam_conv_k=reparam_conv_k,
            globalperceptron_reduce=globalperceptron_reduce,
            num_sharesets=num_sharesets,
            deploy=deploy,
        )
        self.ffn_block = FFNBlock(channels, channels * ffn_expand)
        self.prebn1 = nn.BatchNorm2d(channels).set_train()
        self.prebn2 = nn.BatchNorm2d(channels).set_train()

    def construct(self, x):
        out = x + self.repmlp_block(self.prebn1(x))
        out = out + self.ffn_block(self.prebn2(out))
        return out


class RepMLPNet(nn.Cell):
    r"""RepMLPNet model class, based on
    `"RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality" <https://arxiv.org/pdf/2112.11081v2.pdf>`_

    Args:
        in_channels: number of input channels. Default: 3.
        num_class: number of classification classes. Default: 1000.
        patch_size: kernel size and stride of the patch-embedding conv. Default: (4, 4)
        num_blocks: number of blocks per stage. Default: (2, 2, 6, 2)
        channels: number of channels per stage; between stages a 2x2 stride-2 conv maps
            channels[stage_idx] to channels[stage_idx + 1]. Default: (192, 384, 768, 1536)
        hs: patch height of the RepMLPBlocks in each stage. Default: (64, 32, 16, 8)
        ws: patch width of the RepMLPBlocks in each stage. Default: (64, 32, 16, 8)
        sharesets_nums: number of share sets per stage. Default: (4, 8, 16, 32)
        reparam_conv_k: kernel sizes of the Local Perceptron convs. Default: (3,)
        globalperceptron_reduce: channel reduction ratio of the Global Perceptron bottleneck
            (internal width = channels // globalperceptron_reduce). Default: 4
        use_checkpoint: stored as ``self.use_checkpoint``; not used by this implementation.
            Default: False
        deploy: build the re-parameterized (inference) structure: biased fc3 without BN and
            without Local Perceptron branches. Default: False
    """

    def __init__(self, in_channels=3, num_class=1000, patch_size=(4, 4), num_blocks=(2, 2, 6, 2),
                 channels=(192, 384, 768, 1536), hs=(64, 32, 16, 8), ws=(64, 32, 16, 8),
                 sharesets_nums=(4, 8, 16, 32), reparam_conv_k=(3,), globalperceptron_reduce=4,
                 use_checkpoint=False, deploy=False):
        super().__init__()
        num_stages = len(num_blocks)
        assert num_stages == len(channels)
        assert num_stages == len(hs)
        assert num_stages == len(ws)
        assert num_stages == len(sharesets_nums)

        self.conv_embedding = conv_bn_relu(in_channels, channels[0], kernel_size=patch_size, stride=patch_size,
                                           padding=0, has_bias=False)

        stages = []
        embeds = []
        for stage_idx in range(num_stages):
            stage_blocks = [RepMLPNetUnit(channels=channels[stage_idx], h=hs[stage_idx], w=ws[stage_idx],
                                          reparam_conv_k=reparam_conv_k,
                                          globalperceptron_reduce=globalperceptron_reduce, ffn_expand=4,
                                          num_sharesets=sharesets_nums[stage_idx], deploy=deploy)
                            for _ in range(num_blocks[stage_idx])]
            stages.append(nn.CellList(stage_blocks))
            if stage_idx < num_stages - 1:
                # Down-sampling between stages: 2x2 stride-2 conv that also changes the width.
                embeds.append(conv_bn_relu(in_channels=channels[stage_idx], out_channels=channels[stage_idx + 1],
                                           kernel_size=2, stride=2, padding=0))
        self.stages = nn.CellList(stages)
        self.embeds = nn.CellList(embeds)
        self.head_norm = nn.BatchNorm2d(channels[-1]).set_train()
        self.head = nn.Dense(channels[-1], num_class)

        self.use_checkpoint = use_checkpoint
        # Global average pooling over H and W, producing (N, C) directly.
        self.reduce_mean = ops.ReduceMean(keep_dims=False)
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize weights for cells.

        Conv2d weights/biases are drawn from ``init.Uniform(k)`` with
        ``k = sqrt(group / (in_channels * kh * kw))``. Dense weights/biases use
        ``k = sqrt(1 / in_channels)``.
        """
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                k = cell.group / (cell.in_channels * cell.kernel_size[0] * cell.kernel_size[1])
                k = k ** 0.5
                cell.weight.set_data(
                    init.initializer(init.Uniform(k), cell.weight.shape, cell.weight.dtype))
                if cell.bias is not None:
                    cell.bias.set_data(
                        init.initializer(init.Uniform(k), cell.bias.shape, cell.bias.dtype))
            elif isinstance(cell, nn.Dense):
                k = 1 / cell.in_channels
                k = k ** 0.5
                cell.weight.set_data(
                    init.initializer(init.Uniform(k), cell.weight.shape, cell.weight.dtype))
                if cell.bias is not None:
                    cell.bias.set_data(
                        init.initializer(init.Uniform(k), cell.bias.shape, cell.bias.dtype))

    def forward_features(self, x: Tensor) -> Tensor:
        """Patch embedding, all stages, final BN and global average pooling -> (N, C)."""
        x = self.conv_embedding(x)
        for i, stage in enumerate(self.stages):
            for block in stage:
                x = block(x)
            if i < len(self.stages) - 1:
                embed = self.embeds[i]
                x = embed(x)
        x = self.head_norm(x)
        # Mean over H and W. This equals the former full-size AvgPool2d + flatten, without
        # building a new pooling cell on every call.
        return self.reduce_mean(x, (2, 3))

    def forward_head(self, x: Tensor) -> Tensor:
        return self.head(x)

    def construct(self, x: Tensor) -> Tensor:
        x = self.forward_features(x)
        return self.forward_head(x)

    def locality_injection(self):
        """Re-parameterize every sub-cell that defines ``local_inject`` (the RepMLPBlocks)."""
        # MindSpore cells have no PyTorch-style ``modules()``; walk the tree with
        # ``cells_and_names()``. Snapshot it first because local_inject removes child cells.
        for _, cell in list(self.cells_and_names()):
            if hasattr(cell, 'local_inject'):
                cell.local_inject()


@register_model
def RepMLPNet_T224(pretrained: bool = False, image_size: int = 224, num_classes: int = 1000, in_channels=3,
                   deploy=False, **kwargs):
    """Get RepMLPNet_T224 model.
    Refer to the base class `models.RepMLPNet` for more details.
    Extra ``kwargs`` are accepted and ignored."""
    default_cfg = default_cfgs['RepMLPNet_T224']
    model = RepMLPNet(in_channels=in_channels, num_class=num_classes, channels=(64, 128, 256, 512),
                      hs=(56, 28, 14, 7), ws=(56, 28, 14, 7), num_blocks=(2, 2, 6, 2), reparam_conv_k=(1, 3),
                      sharesets_nums=(1, 4, 16, 128), deploy=deploy)
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    model.image_size = image_size
    return model


@register_model
def RepMLPNet_T256(pretrained: bool = False, image_size: int = 256, num_classes: int = 1000, in_channels=3,
                   deploy=False, **kwargs):
    """Get RepMLPNet_T256 model.
    Refer to the base class `models.RepMLPNet` for more details.
    Extra ``kwargs`` are accepted and ignored."""
    default_cfg = default_cfgs['RepMLPNet_T256']
    model = RepMLPNet(in_channels=in_channels, num_class=num_classes, channels=(64, 128, 256, 512),
                      hs=(64, 32, 16, 8), ws=(64, 32, 16, 8), num_blocks=(2, 2, 6, 2), reparam_conv_k=(1, 3),
                      sharesets_nums=(1, 4, 16, 128), deploy=deploy)
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    model.image_size = image_size
    return model


@register_model
def RepMLPNet_B224(pretrained: bool = False, image_size: int = 224, num_classes: int = 1000, in_channels=3,
                   deploy=False, **kwargs):
    """Get RepMLPNet_B224 model.
    Refer to the base class `models.RepMLPNet` for more details.
    Extra ``kwargs`` are accepted and ignored."""
    default_cfg = default_cfgs['RepMLPNet_B224']
    model = RepMLPNet(in_channels=in_channels, num_class=num_classes, channels=(96, 192, 384, 768),
                      hs=(56, 28, 14, 7), ws=(56, 28, 14, 7), num_blocks=(2, 2, 12, 2), reparam_conv_k=(1, 3),
                      sharesets_nums=(1, 4, 32, 128), deploy=deploy)
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    model.image_size = image_size
    return model


@register_model
def RepMLPNet_B256(pretrained: bool = False, image_size: int = 256, num_classes: int = 1000, in_channels=3,
                   deploy=False, **kwargs):
    """Get RepMLPNet_B256 model.
    Refer to the base class `models.RepMLPNet` for more details.
    Extra ``kwargs`` are accepted and ignored."""
    default_cfg = default_cfgs['RepMLPNet_B256']
    model = RepMLPNet(in_channels=in_channels, num_class=num_classes, channels=(96, 192, 384, 768),
                      hs=(64, 32, 16, 8), ws=(64, 32, 16, 8), num_blocks=(2, 2, 12, 2), reparam_conv_k=(1, 3),
                      sharesets_nums=(1, 4, 32, 128), deploy=deploy)
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    model.image_size = image_size
    return model


@register_model
def RepMLPNet_D256(pretrained: bool = False, image_size: int = 256, num_classes: int = 1000, in_channels=3,
                   deploy=False, **kwargs):
    """Get RepMLPNet_D256 model.
    Refer to the base class `models.RepMLPNet` for more details.
    Extra ``kwargs`` are accepted and ignored."""
    default_cfg = default_cfgs['RepMLPNet_D256']
    model = RepMLPNet(in_channels=in_channels, num_class=num_classes, channels=(80, 160, 320, 640),
                      hs=(64, 32, 16, 8), ws=(64, 32, 16, 8), num_blocks=(2, 2, 18, 2), reparam_conv_k=(1, 3),
                      sharesets_nums=(1, 4, 16, 128), deploy=deploy)
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    model.image_size = image_size
    return model


@register_model
def RepMLPNet_L256(pretrained: bool = False, image_size: int = 256, num_classes: int = 1000, in_channels=3,
                   deploy=False, **kwargs):
    """Get RepMLPNet_L256 model.
    Refer to the base class `models.RepMLPNet` for more details.
    Extra ``kwargs`` are accepted and ignored."""
    default_cfg = default_cfgs['RepMLPNet_L256']
    model = RepMLPNet(in_channels=in_channels, num_class=num_classes, channels=(96, 192, 384, 768),
                      hs=(64, 32, 16, 8), ws=(64, 32, 16, 8), num_blocks=(2, 2, 18, 2), reparam_conv_k=(1, 3),
                      sharesets_nums=(1, 4, 32, 256), deploy=deploy)
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    model.image_size = image_size
    return model


if __name__ == '__main__':
    # Smoke test: build RepMLPNet_B256, run one forward pass on an all-ones 256x256 image,
    # then print the network and its output.
    dummy_input = Tensor(np.ones([1, 3, 256, 256]).astype(np.float32))
    model = RepMLPNet_B256()
    origin_y = model(dummy_input)
    print(model)
    print(origin_y)