# Source code for mindcv.models.BiT_ResNet

"""
MindSpore implementation of `BiT_ResNet`.
Refer to Big Transfer (BiT): General Visual Representation Learning.
"""

from typing import Optional, Type, List, Union

import mindspore
from mindspore import nn, Tensor, ops



from .layers.pooling import GlobalAvgPooling
from .utils import load_pretrained
from .registry import register_model

__all__ = [
    'BiT_ResNet',
    'BiTresnet50',
]


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000,
        'first_conv': 'conv1', 'classifier': 'classifier',
        **kwargs
    }


class StdConv2d(nn.Conv2d):
    r"""Conv2d with Weight Standardization.

    The constructor is inherited unchanged from ``nn.Conv2d``; only the
    forward pass differs: the kernel is shifted to zero mean and scaled to
    unit variance (over axes 1, 2 and 3) before it is used for the convolution.

    Args:
        in_channels: The channel number of the input tensor of the Conv2d layer.
        out_channels: The channel number of the output tensor of the Conv2d layer.
        kernel_size: Specifies the height and width of the 2D convolution kernel.
        stride: The movement stride of the 2D convolution kernel. Default: 1.
        pad_mode: Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding: The number of padding on the height and width directions of the input. Default: 0.
        group: Splits filter into groups. Default: 1.

    NOTE(review): ``construct`` returns the raw convolution result and never
    adds a bias term here, so a bias enabled on the parent layer would appear
    to be ignored -- confirm that callers never rely on one.
    """

    def construct(self, x):
        """Convolve ``x`` with the standardized version of ``self.weight``."""
        kernel = self.weight
        # Statistics are reduced over axes 1-3 and kept as size-1 dimensions
        # so they broadcast against the kernel.
        reduce_mean = ops.ReduceMean(keep_dims=True)
        kernel_mean = reduce_mean(kernel, [1, 2, 3])
        kernel_var = kernel.var((1, 2, 3), keepdims=True)
        # The small constant keeps the division finite for a constant kernel.
        standardized = (kernel - kernel_mean) / ops.sqrt(kernel_var + 1e-10)
        return self.conv2d(x, standardized)

class Bottleneck(nn.Cell):
    """Pre-activation bottleneck block of BiT.

    The block runs three weight-standardized convolutions (1x1, 3x3, 1x1),
    each preceded by a normalization layer and a ReLU, and adds a skip
    connection to the result.

    Args:
        in_channels: number of channels of the block input.
        channels: base channel count; the block outputs
            ``channels * expansion`` channels.
        stride: stride of the 3x3 convolution. Default: 1.
        groups: number of groups of the 3x3 convolution. Default: 1.
        base_width: scales the hidden width relative to 64. Default: 64.
        norm: normalization layer class, called as ``norm(32, num_channels)``.
            ``nn.GroupNorm`` is used when None. Default: None.
        down_sample: optional cell applied on the skip path. Default: None.
    """
    expansion: int = 4

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 stride: int = 1,
                 groups: int = 1,
                 base_width: int = 64,
                 norm: Optional[nn.Cell] = None,
                 down_sample: Optional[nn.Cell] = None
                 ) -> None:
        super().__init__()
        norm = nn.GroupNorm if norm is None else norm

        hidden = int(channels * (base_width / 64.0)) * groups
        out_channels = channels * self.expansion

        self.gn1 = norm(32, in_channels)
        self.conv1 = StdConv2d(in_channels, hidden, kernel_size=1, stride=1)
        self.gn2 = norm(32, hidden)
        self.conv2 = StdConv2d(hidden, hidden, kernel_size=3, stride=stride,
                               padding=1, pad_mode='pad', group=groups)
        self.gn3 = norm(32, hidden)
        self.conv3 = StdConv2d(hidden, out_channels, kernel_size=1, stride=1)

        self.relu = nn.ReLU()
        self.down_sample = down_sample

    def construct(self, x: Tensor) -> Tensor:
        """Apply the block to ``x`` and return the sum of main and skip paths."""
        # Normalized + activated input; it feeds both the first convolution
        # and, when present, the down-sampling projection.
        pre_act = self.relu(self.gn1(x))

        out = self.conv1(pre_act)
        out = self.conv2(self.relu(self.gn2(out)))
        out = self.conv3(self.relu(self.gn3(out)))

        # Skip path: the raw input, unless a projection is configured.
        shortcut = x
        if self.down_sample is not None:
            shortcut = self.down_sample(pre_act)

        # No activation after the addition: every block normalizes and
        # activates its own input first.
        return out + shortcut


class BiT_ResNet(nn.Cell):
    r"""BiT_ResNet model class, based on
    `"Big Transfer (BiT): General Visual Representation Learning" <https://arxiv.org/abs/1912.11370>`_

    Args:
        block: block class used to build every stage; it must expose an
            ``expansion`` class attribute.
        layers: number of blocks of each of the four stages.
        wf: width factor multiplying the channel count of the stem and of
            every stage. Default: 1.
        num_classes: number of classification classes. Default: 1000.
        in_channels: number the channels of the input. Default: 3.
        groups: number of groups for group conv in blocks. Default: 1.
        base_width: base width of pre group hidden channel in blocks. Default: 64.
        norm: normalization layer class, called as ``norm(32, num_channels)``.
            ``nn.GroupNorm`` is used when None. Default: None.
    """

    def __init__(self,
                 block: Type[Union[Bottleneck]],
                 layers: List[int],
                 wf: int = 1,
                 num_classes: int = 1000,
                 in_channels: int = 3,
                 groups: int = 1,
                 base_width: int = 64,
                 norm: Optional[nn.Cell] = None
                 ) -> None:
        super().__init__()
        if norm is None:
            norm = nn.GroupNorm
        self.norm: nn.Cell = norm  # add type hints to make pylint happy
        # Channel count fed to the next block; updated by `_make_layer`.
        self.input_channels = 64 * wf
        self.groups = groups
        # Attribute name (including its spelling) is kept so existing code
        # reading it keeps working.
        self.base_with = base_width

        # Stem: 7x7 stride-2 convolution, then explicit padding by 1 and a
        # 'valid' 3x3 stride-2 max pooling.
        self.conv1 = StdConv2d(in_channels, self.input_channels, kernel_size=7,
                               stride=2, pad_mode='pad', padding=3)
        self.pad = nn.ConstantPad2d(1, 0)
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid')

        self.layer1 = self._make_layer(block, 64 * wf, layers[0])
        self.layer2 = self._make_layer(block, 128 * wf, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256 * wf, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512 * wf, layers[3], stride=2)

        # Channel count produced by the last stage. Deriving it from
        # `block.expansion` keeps the head norm and the classifier in
        # agreement (it equals 2048 * wf for an expansion of 4).
        head_channels = 512 * block.expansion * wf
        self.gn = norm(32, head_channels)
        self.relu = nn.ReLU()
        self.pool = GlobalAvgPooling(keep_dims=True)
        # The classifier is a 1x1 convolution applied to the pooled features.
        self.classifier = nn.Conv2d(head_channels, num_classes, kernel_size=1, has_bias=True)

    def _make_layer(self,
                    block: Type[Union[Bottleneck]],
                    channels: int,
                    block_nums: int,
                    stride: int = 1
                    ) -> nn.SequentialCell:
        """Build one stage made of ``block_nums`` blocks.

        Only the first block receives ``stride`` and the optional
        down-sampling projection. ``self.input_channels`` is updated to the
        stage's output channel count as a side effect.
        """
        out_channels = channels * block.expansion

        down_sample = None
        # A projection on the skip path is needed whenever the first block
        # changes the spatial size or the channel count.
        if stride != 1 or self.input_channels != out_channels:
            down_sample = nn.SequentialCell([
                StdConv2d(self.input_channels, out_channels, kernel_size=1, stride=stride),
            ])

        layers = [
            block(
                self.input_channels,
                channels,
                stride=stride,
                down_sample=down_sample,
                groups=self.groups,
                base_width=self.base_with,
                norm=self.norm
            )
        ]

        self.input_channels = out_channels
        for _ in range(1, block_nums):
            layers.append(
                block(
                    self.input_channels,
                    channels,
                    groups=self.groups,
                    base_width=self.base_with,
                    norm=self.norm
                )
            )

        return nn.SequentialCell(layers)

    def root(self, x: Tensor) -> Tensor:
        """Network stem: convolution, padding and max pooling."""
        x = self.conv1(x)
        x = self.pad(x)
        x = self.max_pool(x)
        return x

    def forward_features(self, x: Tensor) -> Tensor:
        """Network forward feature extraction."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward_head(self, x: Tensor) -> Tensor:
        """Network head: normalization, ReLU, global pooling and classifier."""
        x = self.gn(x)
        x = self.relu(x)
        x = self.pool(x)
        x = self.classifier(x)
        return x

    def construct(self, x: Tensor) -> Tensor:
        """Run stem, stages and head, then drop the two trailing size-1 axes."""
        x = self.root(x)
        x = self.forward_features(x)
        x = self.forward_head(x)
        assert x.shape[-2:] == (1, 1)  # We should have no spatial shape left.
        return x[..., 0, 0]
@register_model
def BiTresnet50(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
    """Get the 50-layer BiT_ResNet model (Bottleneck blocks, stage depths 3-4-6-3).

    Refer to the base class `models.BiT_Resnet` for more details.

    Args:
        pretrained: NOTE(review): this argument is not read by the function,
            so no weights are loaded here regardless of its value.
        num_classes: number of classification classes. Default: 1000.
        in_channels: number of channels of the input. Default: 3.
        **kwargs: forwarded unchanged to ``BiT_ResNet``.

    Returns:
        The constructed ``BiT_ResNet`` instance.
    """
    stage_depths = [3, 4, 6, 3]
    return BiT_ResNet(
        Bottleneck,
        stage_depths,
        num_classes=num_classes,
        in_channels=in_channels,
        **kwargs
    )