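"""
MindSpore implementation of `BiT`: a pre-activation ResNet that uses Group Normalization and Weight-Standardized convolutions.
Reference: "Big Transfer (BiT): General Visual Representation Learning" (https://arxiv.org/abs/1912.11370).
"""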
from typing import Optional, Type, List, Union
import mindspore
from mindspore import nn, Tensor, ops
from .layers.pooling import GlobalAvgPooling
from .utils import load_pretrained
from .registry import register_model
# Public names exported by ``from <this module> import *``.
__all__ = ["BiT", "BiTresnet50"]
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000,
'first_conv': 'conv1', 'classifier': 'classifier',
**kwargs
}
# Default config per model variant, keyed by the name of the model factory function.
default_cfgs = dict(
    BiTresnet50=_cfg(url='https://download.mindspore.cn/toolkits/mindcv/bit/BiTresnet50.ckpt'),
)
class StdConv2d(nn.Conv2d):
    r"""Conv2d with Weight Standardization.

    On every forward pass the kernel is standardized before use. Its mean and
    variance are taken over axes (1, 2, 3), which for MindSpore's
    (C_out, C_in / group, kH, kW) kernel layout means per output filter. The
    kernel is then replaced by ``(w - mean) / sqrt(var + 1e-10)``.

    Args:
        in_channels(int): The channel number of the input tensor of the Conv2d layer.
        out_channels(int): The channel number of the output tensor of the Conv2d layer.
        kernel_size(int): Specifies the height and width of the 2D convolution kernel.
        stride(int): The movement stride of the 2D convolution kernel. Default: 1.
        pad_mode(str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding(int): The number of padding on the height and width directions of the input. Default: 0.
        group(int): Splits filter into groups. Default: 1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 group=1
                 ) -> None:
        # Pass the optional arguments by keyword. In ``nn.Conv2d`` the 7th
        # positional parameter is ``dilation``, not ``group``. Passing ``group``
        # positionally silently built a dilated, ungrouped conv whenever group > 1.
        super(StdConv2d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            pad_mode=pad_mode,
            padding=padding,
            group=group)
        self.mean_op = ops.ReduceMean(keep_dims=True)

    def construct(self, x):
        """Standardize the kernel, then convolve ``x`` with it (no bias is added)."""
        w = self.weight
        m = self.mean_op(w, [1, 2, 3])
        v = w.var((1, 2, 3), keepdims=True)
        # The tiny epsilon keeps the division finite for constant filters (zero variance).
        w = (w - m) / mindspore.ops.sqrt(v + 1e-10)
        output = self.conv2d(x, w)
        return output
class Bottleneck(nn.Cell):
    """Pre-activation bottleneck block used by BiT.

    Each of the 1x1 -> 3x3 -> 1x1 convolutions is preceded by normalization
    and ReLU. The residual sum is returned with no trailing activation. When
    ``down_sample`` is given, it is applied to the pre-activated input (the
    output of ``gn1`` + ReLU). Otherwise the raw input is added unchanged.

    Args:
        in_channels(int): Number of channels of the block input.
        channels(int): Base channel count. The block outputs
            ``channels * expansion`` channels.
        stride(int): Stride of the 3x3 convolution. Default: 1.
        groups(int): Number of groups of the 3x3 convolution. Default: 1.
        base_width(int): Base width per group of the hidden channels. Default: 64.
        norm(nn.Cell): Normalization layer class, called as ``norm(32, num_channels)``.
            Default: None, which selects ``nn.GroupNorm``.
        down_sample(nn.Cell): Projection applied on the shortcut branch. Default: None.
    """

    expansion: int = 4

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 stride: int = 1,
                 groups: int = 1,
                 base_width: int = 64,
                 norm: Optional[nn.Cell] = None,
                 down_sample: Optional[nn.Cell] = None
                 ) -> None:
        super().__init__()
        norm_layer = nn.GroupNorm if norm is None else norm
        # Hidden width scales with base_width and is replicated once per group.
        hidden = int(channels * (base_width / 64.0)) * groups
        out_channels = channels * self.expansion

        # NOTE(review): MindSpore derives parameter names from these attribute
        # names, so presumably renaming them would break pretrained checkpoints.
        self.gn1 = norm_layer(32, in_channels)
        self.conv1 = StdConv2d(in_channels, hidden, kernel_size=1, stride=1)
        self.gn2 = norm_layer(32, hidden)
        self.conv2 = StdConv2d(hidden, hidden, kernel_size=3, stride=stride,
                               padding=1, pad_mode='pad', group=groups)
        self.gn3 = norm_layer(32, hidden)
        self.conv3 = StdConv2d(hidden, out_channels, kernel_size=1, stride=1)
        self.relu = nn.ReLU()
        self.down_sample = down_sample

    def construct(self, x: Tensor) -> Tensor:
        # The pre-activation is shared by the main path and the projection shortcut.
        preact = self.relu(self.gn1(x))

        shortcut = x
        if self.down_sample is not None:
            shortcut = self.down_sample(preact)

        out = self.conv1(preact)
        out = self.conv2(self.relu(self.gn2(out)))
        out = self.conv3(self.relu(self.gn3(out)))
        # Pre-activation design: no ReLU after the residual addition.
        return out + shortcut
class BiT(nn.Cell):
    r"""BiT model class, based on
    `"Big Transfer (BiT): General Visual Representation Learning" <https://arxiv.org/abs/1912.11370>`_

    Args:
        block(Union[Bottleneck]): block of BiT.
        layers(tuple(int)): number of blocks in each of the four stages (needs 4 entries).
        wf(int): width multiplier applied to the stem and to every stage. Default: 1.
        num_classes(int): number of classification classes. Default: 1000.
        in_channels(int): number of channels of the input. Default: 3.
        groups(int): number of groups for group conv in blocks. Default: 1.
        base_width(int): base width of pre group hidden channel in blocks. Default: 64.
        norm(nn.Cell): normalization layer class, called as ``norm(32, num_channels)``.
            Default: None, which selects ``nn.GroupNorm``.
    """

    def __init__(self,
                 block: Type[Union[Bottleneck]],
                 layers: List[int],
                 wf: int = 1,
                 num_classes: int = 1000,
                 in_channels: int = 3,
                 groups: int = 1,
                 base_width: int = 64,
                 norm: Optional[Type[nn.Cell]] = None
                 ) -> None:
        super().__init__()
        if norm is None:
            norm = nn.GroupNorm
        self.norm: nn.Cell = norm  # add type hints to make pylint happy
        # Running channel count, advanced by _make_layer as stages are built.
        self.input_channels = 64 * wf
        self.groups = groups
        # NOTE: attribute name is misspelled ("with"); kept unchanged so any
        # existing reader of ``self.base_with`` keeps working.
        self.base_with = base_width
        self.conv1 = StdConv2d(in_channels, self.input_channels, kernel_size=7,
                               stride=2, pad_mode='pad', padding=3)
        self.pad = nn.ConstantPad2d(1, 0)
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid')
        self.layer1 = self._make_layer(block, 64 * wf, layers[0])
        self.layer2 = self._make_layer(block, 128 * wf, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256 * wf, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512 * wf, layers[3], stride=2)
        # Channels produced by the last stage. Derive them from the block's
        # expansion instead of hard-coding 2048, so the head norm and the
        # classifier always agree.
        num_features = 512 * block.expansion * wf
        self.gn = norm(32, num_features)
        self.relu = nn.ReLU()
        self.pool = GlobalAvgPooling(keep_dims=True)
        self.classifier = nn.Conv2d(num_features, num_classes, kernel_size=1, has_bias=True)

    def _make_layer(self,
                    block: Type[Union[Bottleneck]],
                    channels: int,
                    block_nums: int,
                    stride: int = 1
                    ) -> nn.SequentialCell:
        """Build one stage of ``block_nums`` blocks.

        Only the first block may downsample (``stride``) and change the channel
        count. It gets a 1x1 ``StdConv2d`` projection shortcut when either
        differs from the running ``self.input_channels``. Updates
        ``self.input_channels`` to ``channels * block.expansion``.
        """
        down_sample = None
        if stride != 1 or self.input_channels != channels * block.expansion:
            down_sample = nn.SequentialCell([
                StdConv2d(self.input_channels, channels * block.expansion, kernel_size=1, stride=stride),
            ])
        layers = []
        layers.append(
            block(
                self.input_channels,
                channels,
                stride=stride,
                down_sample=down_sample,
                groups=self.groups,
                base_width=self.base_with,
                norm=self.norm
            )
        )
        self.input_channels = channels * block.expansion
        for _ in range(1, block_nums):
            layers.append(
                block(
                    self.input_channels,
                    channels,
                    groups=self.groups,
                    base_width=self.base_with,
                    norm=self.norm
                )
            )
        return nn.SequentialCell(layers)

    def root(self, x: Tensor) -> Tensor:
        """Stem: 7x7 stride-2 conv, 1-pixel zero padding, then 3x3 stride-2 max pool."""
        x = self.conv1(x)
        x = self.pad(x)
        x = self.max_pool(x)
        return x

    def forward_features(self, x: Tensor) -> Tensor:
        """Network forward feature extraction."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward_head(self, x: Tensor) -> Tensor:
        """Final norm + ReLU, global average pooling and a 1x1 conv classifier.

        The norm + ReLU is needed here because the blocks end without an
        activation (pre-activation design).
        """
        x = self.gn(x)
        x = self.relu(x)
        x = self.pool(x)
        x = self.classifier(x)
        return x

    def construct(self, x: Tensor) -> Tensor:
        x = self.root(x)
        x = self.forward_features(x)
        x = self.forward_head(x)
        assert x.shape[-2:] == (1, 1)  # We should have no spatial shape left.
        # Drop the trailing 1x1 spatial dims, leaving one logit per class.
        return x[..., 0, 0]
@register_model
def BiTresnet50(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    """Get 50 layers ResNet model.

    Builds a ``BiT`` with ``Bottleneck`` blocks and stage depths [3, 4, 6, 3].
    Refer to the base class `models.BiT` for more details.

    Args:
        pretrained: If True, load weights from the URL in
            ``default_cfgs['BiTresnet50']`` via ``load_pretrained``. Default: False.
        num_classes: Number of classification classes. Default: 1000.
        in_channels: Number of input channels. Default: 3.
        **kwargs: Forwarded to ``BiT`` (e.g. ``wf``, ``groups``, ``base_width``, ``norm``).

    Returns:
        The constructed ``BiT`` model.
    """
    default_cfg = default_cfgs['BiTresnet50']
    model = BiT(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_channels=in_channels, **kwargs)
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model