# Source code for mindcv.models.bit

"""
MindSpore implementation of `BiT`.
Refer to Big Transfer (BiT): General Visual Representation Learning.
"""

from typing import Optional, Type, List, Union

import mindspore
from mindspore import nn, Tensor, ops

from .layers.pooling import GlobalAvgPooling
from .utils import load_pretrained
from .registry import register_model

# Names exported by a star-import: the network class and its model factory.
__all__ = ['BiT', 'BiTresnet50']


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000,
        'first_conv': 'conv1', 'classifier': 'classifier',
        **kwargs
    }


# Default configuration of every model variant in this module, keyed by model name.
default_cfgs = dict(
    BiTresnet50=_cfg(url='https://download.mindspore.cn/toolkits/mindcv/bit/BiTresnet50.ckpt'),
)


class StdConv2d(nn.Conv2d):
    r"""Conv2d with Weight Standardization.

    On every forward pass the kernel is normalized to zero mean and unit
    variance (statistics taken over all axes except the first) before the
    convolution is applied.

    Args:
        in_channels(int): The channel number of the input tensor of the Conv2d layer.
        out_channels(int): The channel number of the output tensor of the Conv2d layer.
        kernel_size(int): Specifies the height and width of the 2D convolution kernel.
        stride(int): The movement stride of the 2D convolution kernel. Default: 1.
        pad_mode(str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding(int): The number of padding on the height and width directions of the input. Default: 0.
        group(int): Splits filter into groups. Default: 1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 group=1
                 ) -> None:
        # Forward the optional arguments by keyword. In nn.Conv2d the
        # positional slot after `padding` is `dilation`, so passing `group`
        # positionally would set the dilation and leave the conv ungrouped.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            pad_mode=pad_mode,
            padding=padding,
            group=group)
        self.mean_op = ops.ReduceMean(keep_dims=True)

    def construct(self, x):
        """Convolve `x` with the standardized version of the layer's weight."""
        w = self.weight
        # Mean and variance over axes 1-3, kept as size-1 dims so they broadcast against `w`.
        m = self.mean_op(w, [1, 2, 3])
        v = w.var((1, 2, 3), keepdims=True)
        # The small epsilon guards against division by zero for constant kernels.
        w = (w - m) / mindspore.ops.sqrt(v + 1e-10)
        # Only the convolution primitive is invoked here; no bias term is added.
        output = self.conv2d(x, w)
        return output


class Bottleneck(nn.Cell):
    """Pre-activation bottleneck block of BiT.

    Each of the three weight-standardized convolutions (1x1, 3x3, 1x1) is
    preceded by a normalization layer and a ReLU. The block output is the sum
    of the last convolution and the shortcut; no activation follows the sum.

    Args:
        in_channels(int): The channel number of the input tensor of the Conv2d layer.
        channels(int): The channel number of the output tensor of the middle Conv2d layer.
        stride(int): The movement stride of the 2D convolution kernel. Default: 1.
        groups(int): Number of groups for group conv in blocks. Default: 1.
        base_width(int): Base width of pre group hidden channel in blocks. Default: 64.
        norm(nn.Cell): Normalization layer in blocks. Default: None.
        down_sample(nn.Cell): Down sample in blocks. Default: None.
    """

    expansion: int = 4

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 stride: int = 1,
                 groups: int = 1,
                 base_width: int = 64,
                 norm: Optional[nn.Cell] = None,
                 down_sample: Optional[nn.Cell] = None
                 ) -> None:
        super().__init__()
        # Fall back to group normalization when no layer type is supplied.
        norm_layer = nn.GroupNorm if norm is None else norm

        mid_channels = int(channels * (base_width / 64.0)) * groups
        out_channels = channels * self.expansion

        # 1x1 convolution: in_channels -> mid_channels.
        self.gn1 = norm_layer(32, in_channels)
        self.conv1 = StdConv2d(in_channels, mid_channels, kernel_size=1, stride=1)

        # 3x3 convolution; this is the only place the block's stride is applied.
        self.gn2 = norm_layer(32, mid_channels)
        self.conv2 = StdConv2d(mid_channels, mid_channels, kernel_size=3, stride=stride,
                               padding=1, pad_mode='pad', group=groups)

        # 1x1 convolution: mid_channels -> channels * expansion.
        self.gn3 = norm_layer(32, mid_channels)
        self.conv3 = StdConv2d(mid_channels, out_channels, kernel_size=1, stride=1)

        self.relu = nn.ReLU()
        self.down_sample = down_sample

    def construct(self, x: Tensor) -> Tensor:
        """Run the block on `x` and return the residual sum."""
        pre_act = self.relu(self.gn1(x))

        # The shortcut is the raw input unless a projection is configured, in
        # which case the projection consumes the normalized + activated input.
        if self.down_sample is not None:
            shortcut = self.down_sample(pre_act)
        else:
            shortcut = x

        y = self.conv1(pre_act)
        y = self.conv2(self.relu(self.gn2(y)))
        y = self.conv3(self.relu(self.gn3(y)))

        return y + shortcut


class BiT(nn.Cell):
    r"""BiT model class, based on
    `"Big Transfer (BiT): General Visual Representation Learning" <https://arxiv.org/abs/1912.11370>`_

    The network is a stem (weight-standardized 7x7 conv, padding, max-pool),
    four stages of `block`, and a head (norm, ReLU, global average pooling,
    1x1 conv classifier).

    Args:
        block(Union[Bottleneck]): block of BiT.
        layers(List[int]): number of blocks in each of the four stages.
        wf(int): width multiplier applied to the channel count of every stage. Default: 1.
        num_classes(int): number of classification classes. Default: 1000.
        in_channels(int): number the channels of the input. Default: 3.
        groups(int): number of groups for group conv in blocks. Default: 1.
        base_width(int): base width of pre group hidden channel in blocks. Default: 64.
        norm(nn.Cell): normalization layer in blocks. Default: None.
    """

    def __init__(self,
                 block: Type[Union[Bottleneck]],
                 layers: List[int],
                 wf: int = 1,
                 num_classes: int = 1000,
                 in_channels: int = 3,
                 groups: int = 1,
                 base_width: int = 64,
                 norm: Optional[nn.Cell] = None
                 ) -> None:
        super().__init__()

        if norm is None:
            norm = nn.GroupNorm

        self.norm: nn.Cell = norm  # add type hints to make pylint happy
        # Running channel count; `_make_layer` updates it after each stage.
        self.input_channels = 64 * wf
        self.groups = groups
        # (sic) attribute spelling kept so existing attribute access keeps working.
        self.base_with = base_width

        # Stem: 7x7 stride-2 conv, then explicit padding + 'valid' 3x3 stride-2 max-pool.
        self.conv1 = StdConv2d(in_channels, self.input_channels, kernel_size=7,
                               stride=2, pad_mode='pad', padding=3)
        self.pad = nn.ConstantPad2d(1, 0)
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid')

        self.layer1 = self._make_layer(block, 64 * wf, layers[0])
        self.layer2 = self._make_layer(block, 128 * wf, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256 * wf, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512 * wf, layers[3], stride=2)

        # Channel count produced by the last stage. Deriving it from
        # `block.expansion` keeps the head norm and the classifier in agreement
        # for any block type instead of assuming an expansion of 4.
        num_features = 512 * block.expansion * wf
        self.gn = norm(32, num_features)
        self.relu = nn.ReLU()
        self.pool = GlobalAvgPooling(keep_dims=True)
        # The classifier is a 1x1 convolution applied to the pooled feature map.
        self.classifier = nn.Conv2d(num_features, num_classes, kernel_size=1, has_bias=True)

    def _make_layer(self,
                    block: Type[Union[Bottleneck]],
                    channels: int,
                    block_nums: int,
                    stride: int = 1
                    ) -> nn.SequentialCell:
        """Build one stage made of `block_nums` blocks.

        Only the first block receives `stride` and, when needed, a projection
        shortcut; the remaining blocks use the block's default stride and no
        projection.

        Args:
            block: block class to instantiate.
            channels(int): middle channel count handed to each block.
            block_nums(int): number of blocks in the stage.
            stride(int): stride of the first block. Default: 1.

        Returns:
            nn.SequentialCell containing the stage's blocks in order.
        """
        out_channels = channels * block.expansion
        down_sample = None

        # A projection is required whenever the first block changes the
        # spatial size or the channel count.
        if stride != 1 or self.input_channels != out_channels:
            down_sample = nn.SequentialCell([
                StdConv2d(self.input_channels, out_channels, kernel_size=1, stride=stride),
            ])

        layers = [
            block(
                self.input_channels,
                channels,
                stride=stride,
                down_sample=down_sample,
                groups=self.groups,
                base_width=self.base_with,
                norm=self.norm
            )
        ]
        # Every following block (and the next stage) sees the expanded width.
        self.input_channels = out_channels

        for _ in range(1, block_nums):
            layers.append(
                block(
                    self.input_channels,
                    channels,
                    groups=self.groups,
                    base_width=self.base_with,
                    norm=self.norm
                )
            )

        return nn.SequentialCell(layers)

    def root(self, x: Tensor) -> Tensor:
        """Apply the stem: convolution, constant padding, then max-pooling."""
        x = self.conv1(x)
        x = self.pad(x)
        x = self.max_pool(x)
        return x

    def forward_features(self, x: Tensor) -> Tensor:
        """Network forward feature extraction."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward_head(self, x: Tensor) -> Tensor:
        """Apply the head: norm, ReLU, global average pooling and the 1x1 conv classifier."""
        x = self.gn(x)
        x = self.relu(x)
        x = self.pool(x)
        x = self.classifier(x)
        return x

    def construct(self, x: Tensor) -> Tensor:
        """Run stem, stages and head, then drop the two trailing size-1 axes."""
        x = self.root(x)
        x = self.forward_features(x)
        x = self.forward_head(x)
        assert x.shape[-2:] == (1, 1)  # We should have no spatial shape left.
        return x[..., 0, 0]


@register_model
def BiTresnet50(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    """Get 50 layers ResNet model.

    Refer to the base class `models.BiT` for more details.

    Args:
        pretrained(bool): If True, the model and its default config are handed
            to `load_pretrained` after construction. Default: False.
        num_classes(int): number of classification classes. Default: 1000.
        in_channels(int): number the channels of the input. Default: 3.
        **kwargs: extra keyword arguments forwarded to `BiT`.

    Returns:
        The constructed `BiT` model with stage depths [3, 4, 6, 3].
    """
    default_cfg = default_cfgs['BiTresnet50']
    model = BiT(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_channels=in_channels, **kwargs)

    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)

    return model