'''
MindSpore implementation of pnasnet.
Refer to Progressive Neural Architecture Search.
'''
from collections import OrderedDict
import math
from mindspore import nn, ops, Tensor
import mindspore.common.initializer as init
from .layers import GlobalAvgPooling
from .registry import register_model
from .utils import load_pretrained
# Public API of this module: the model class and its registered factory function.
__all__ = ["Pnasnet", "pnasnet"]
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000,
'first_conv': 'conv_0.conv', 'classifier': 'last_linear',
**kwargs
}
# Default configuration per model name. The URL is the .ckpt file that pnasnet()
# passes (inside this config) to load_pretrained() when pretrained=True.
default_cfgs = dict(
    pnasnet=_cfg(url='https://download.mindspore.cn/toolkits/mindcv/pnasnet/pnasnet_224.ckpt'),
)
class MaxPool(nn.Cell):
    """
    MaxPool: ``nn.MaxPool2d`` (``pad_mode='same'``) with optional asymmetric zero padding.

    Args:
        kernel_size: pooling window size.
        stride: pooling stride. Default: 1.
        zero_pad: if True, the input is padded by one zero element at the start of
            its last two axes before pooling, and the first row/column of the pooled
            result is then cropped away. Default: False.
    """

    def __init__(self,
                 kernel_size: int,
                 stride: int = 1,
                 zero_pad: bool = False) -> None:
        super().__init__()
        self.pad = zero_pad
        if zero_pad:
            # Leading-edge padding on the last two axes only (assumed H and W, NCHW).
            self.zero_pad = nn.Pad(paddings=((0, 0), (0, 0), (1, 0), (1, 0)))
        self.pool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, pad_mode='same')

    def construct(self, x: Tensor) -> Tensor:
        if self.pad:
            # pad -> pool -> drop the extra leading row/column
            x = self.pool(self.zero_pad(x))[:, :, 1:, 1:]
        else:
            x = self.pool(x)
        return x
class SeparableConv2d(nn.Cell):
    """
    SeparableConv2d: a depthwise convolution followed by a 1x1 pointwise convolution.

    Args:
        in_channels: input channels; the depthwise conv keeps this channel count
            and uses ``group=in_channels``.
        out_channels: channels produced by the 1x1 pointwise conv.
        dw_kernel_size: kernel size of the depthwise conv.
        dw_stride: stride of the depthwise conv.
        dw_padding: explicit padding of the depthwise conv (``pad_mode='pad'``).

    Neither convolution has a bias.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 dw_kernel_size: int,
                 dw_stride: int,
                 dw_padding: int) -> None:
        super().__init__()
        self.depthwise_conv2d = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=dw_kernel_size,
            stride=dw_stride,
            pad_mode='pad',
            padding=dw_padding,
            group=in_channels,
            has_bias=False,
        )
        self.pointwise_conv2d = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            pad_mode='pad',
            has_bias=False,
        )

    def construct(self, x: Tensor) -> Tensor:
        return self.pointwise_conv2d(self.depthwise_conv2d(x))
class BranchSeparables(nn.Cell):
    """
    BranchSeparables: two stacked ReLU + SeparableConv2d + BatchNorm2d units.

    Pipeline: ReLU -> [zero pad] -> SeparableConv2d(stride) -> [crop] -> BN ->
    ReLU -> SeparableConv2d(stride 1) -> BN.

    Args:
        in_channels: input channels.
        out_channels: output channels.
        kernel_size: depthwise kernel size of both separable convs
            (padding is ``kernel_size // 2``).
        stride: stride of the first separable conv only. Default: 1.
        stem_cell: if True, the first separable conv already maps to
            ``out_channels``; otherwise it keeps ``in_channels``. Default: False.
        zero_pad: if True, pad one zero element at the start of the last two axes
            before the first separable conv and crop the first row/column after it.
            Default: False.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int = 1,
                 stem_cell: bool = False,
                 zero_pad: bool = False) -> None:
        super().__init__()
        # With stride 1 this keeps the spatial size for the odd kernels used here.
        half_kernel = kernel_size // 2
        mid_channels = out_channels if stem_cell else in_channels
        self.pad = zero_pad
        if self.pad:
            self.zero_pad = nn.Pad(paddings=((0, 0), (0, 0), (1, 0), (1, 0)))
        self.relu_1 = nn.ReLU()
        self.separable_1 = SeparableConv2d(in_channels, mid_channels, kernel_size,
                                           dw_stride=stride, dw_padding=half_kernel)
        self.bn_sep_1 = nn.BatchNorm2d(num_features=mid_channels, eps=0.001, momentum=0.9)
        self.relu_2 = nn.ReLU()
        self.separable_2 = SeparableConv2d(mid_channels, out_channels, kernel_size,
                                           dw_stride=1, dw_padding=half_kernel)
        self.bn_sep_2 = nn.BatchNorm2d(num_features=out_channels, eps=0.001, momentum=0.9)

    def construct(self, x: Tensor) -> Tensor:
        x = self.relu_1(x)
        if self.pad:
            x = self.separable_1(self.zero_pad(x))[:, :, 1:, 1:]
        else:
            x = self.separable_1(x)
        x = self.relu_2(self.bn_sep_1(x))
        return self.bn_sep_2(self.separable_2(x))
class ReluConvBn(nn.Cell):
    """
    ReluConvBn: ReLU -> Conv2d (no bias) -> BatchNorm2d.

    The conv uses ``pad_mode='pad'`` with the default padding (0), so it only
    preserves spatial size for ``kernel_size=1``. Every use in this file passes
    ``kernel_size=1``.

    Args:
        in_channels: input channels.
        out_channels: output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride. Default: 1.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int = 1) -> None:
        super().__init__()
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            pad_mode='pad',
            has_bias=False,
        )
        self.bn = nn.BatchNorm2d(num_features=out_channels, eps=0.001, momentum=0.9)

    def construct(self, x: Tensor) -> Tensor:
        return self.bn(self.conv(self.relu(x)))
class FactorizedReduction(nn.Cell):
    """
    FactorizedReduction: reduce the spatial size of a cell's left input
    approximately by a factor of 2.

    Two stride-2 paths (1x1 average pool + 1x1 conv) run on the ReLU'd input.
    The second path is shifted by one element: it pads the end of the last two
    axes and crops their first row/column. The two outputs are concatenated
    along axis 1 (``out_channels // 2`` + the remaining channels =
    ``out_channels``) and batch-normalised.

    Args:
        in_channels: input channels.
        out_channels: output channels.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int) -> None:
        super().__init__()
        half = out_channels // 2
        self.relu = nn.ReLU()
        self.path_1 = nn.SequentialCell(OrderedDict([
            ('avgpool', nn.AvgPool2d(kernel_size=1, stride=2, pad_mode='valid')),
            ('conv', nn.Conv2d(in_channels=in_channels, out_channels=half, kernel_size=1,
                               pad_mode='pad', has_bias=False)),
        ]))
        # construct() addresses these by index: [0] pad, [1] avgpool, [2] conv.
        self.path_2 = nn.CellList([
            nn.Pad(paddings=((0, 0), (0, 0), (0, 1), (0, 1)), mode="CONSTANT"),
            nn.AvgPool2d(kernel_size=1, stride=2, pad_mode='valid'),
            # out_channels - half == out_channels // 2 + out_channels % 2
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels - half,
                      kernel_size=1, stride=1, pad_mode='pad', has_bias=False),
        ])
        self.final_path_bn = nn.BatchNorm2d(num_features=out_channels, eps=0.001, momentum=0.9)

    def construct(self, x: Tensor) -> Tensor:
        x = self.relu(x)
        left = self.path_1(x)
        # one-element shift: pad at the end, then drop the first row/column
        shifted = self.path_2[0](x)[:, :, 1:, 1:]
        right = self.path_2[2](self.path_2[1](shifted))
        return self.final_path_bn(ops.concat((left, right), axis=1))
class CellBase(nn.Cell):
    """
    CellBase: shared forward logic of the PNASNet cells.

    Subclasses must define the sub-cells ``comb_iter_{0..4}_left`` and
    ``comb_iter_{0..4}_right``. ``comb_iter_4_right`` may be None, in which case
    the right input is used unchanged in combination 4.
    """

    def cell_forward(self, x_left: Tensor, x_right: Tensor) -> Tensor:
        """
        cell_forward: compute five summed branch pairs and concatenate them on axis 1.

        - combination 0: both branches read ``x_left``;
        - combinations 1 and 2: both branches read ``x_right``;
        - combination 3: the left branch reads combination 2's result, the right
          branch reads ``x_right``;
        - combination 4: the left branch reads ``x_left``; the right branch reads
          ``x_right`` (or is ``x_right`` itself when ``comb_iter_4_right`` is None).
        """
        comb_0 = self.comb_iter_0_left(x_left) + self.comb_iter_0_right(x_left)
        comb_1 = self.comb_iter_1_left(x_right) + self.comb_iter_1_right(x_right)
        comb_2 = self.comb_iter_2_left(x_right) + self.comb_iter_2_right(x_right)
        # combination 3 is chained on the output of combination 2
        comb_3 = self.comb_iter_3_left(comb_2) + self.comb_iter_3_right(x_right)
        left_4 = self.comb_iter_4_left(x_left)
        if self.comb_iter_4_right is None:
            right_4 = x_right
        else:
            right_4 = self.comb_iter_4_right(x_right)
        comb_4 = left_4 + right_4
        return ops.concat((comb_0, comb_1, comb_2, comb_3, comb_4), axis=1)
class CellStem0(CellBase):
    """
    CellStem0: PNASNet Stem0 unit (the first stem cell).

    Unlike ``Cell`` it takes a single input: ``conv_1x1`` derives the "right"
    input from ``x_left``, and both are passed to ``CellBase.cell_forward``.
    Every combination downsamples with stride 2, except ``comb_iter_3_left``.
    That branch runs at stride 1 on the already-downsampled output of
    combination 2. The output concatenates five combinations, giving
    ``out_channels_left + 4 * out_channels_right`` channels.

    Args:
        in_channels_left: input channels expected by ``comb_iter_0_left`` and
            ``comb_iter_0_right``.
        out_channels_left: output channels of combination 0.
        in_channels_right: input channels expected by ``conv_1x1`` and
            ``comb_iter_4_left``.
        out_channels_right: output channels of ``conv_1x1`` and of combinations 1-4.

    NOTE(review): ``conv_1x1`` and ``comb_iter_4_left`` are sized with
    ``in_channels_right`` but both are applied to ``x_left``. Therefore
    ``in_channels_left`` and ``in_channels_right`` must both equal the input's
    channel count (both are 32 in ``Pnasnet``).
    """
    def __init__(self,
                 in_channels_left: int,
                 out_channels_left: int,
                 in_channels_right: int,
                 out_channels_right: int) -> None:
        super().__init__()
        # NOTE: attribute names and creation order define parameter/checkpoint
        # names and the order in which weights are initialised — keep them stable.
        self.conv_1x1 = ReluConvBn(in_channels_right, out_channels_right,
                                   kernel_size=1)
        # stem_cell=True: the first separable conv already maps to out_channels_left.
        self.comb_iter_0_left = BranchSeparables(in_channels_left,
                                                 out_channels_left,
                                                 kernel_size=5, stride=2,
                                                 stem_cell=True)
        # Max-pool, then a 1x1 conv + BN projecting in_channels_left -> out_channels_left
        # so that it can be summed with comb_iter_0_left.
        comb_iter_0_right = OrderedDict([
            ('max_pool', MaxPool(3, stride=2)),
            ('conv', nn.Conv2d(in_channels_left, out_channels_left,
                               kernel_size=1, has_bias=False)),
            ('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.9))
        ])
        self.comb_iter_0_right = nn.SequentialCell(comb_iter_0_right)
        self.comb_iter_1_left = BranchSeparables(out_channels_right,
                                                 out_channels_right,
                                                 kernel_size=7, stride=2)
        self.comb_iter_1_right = MaxPool(3, stride=2)
        self.comb_iter_2_left = BranchSeparables(out_channels_right,
                                                 out_channels_right,
                                                 kernel_size=5, stride=2)
        self.comb_iter_2_right = BranchSeparables(out_channels_right,
                                                  out_channels_right,
                                                  kernel_size=3, stride=2)
        # Stride 1: its input (combination 2's output) is already downsampled.
        self.comb_iter_3_left = BranchSeparables(out_channels_right,
                                                 out_channels_right,
                                                 kernel_size=3)
        self.comb_iter_3_right = MaxPool(3, stride=2)
        self.comb_iter_4_left = BranchSeparables(in_channels_right,
                                                 out_channels_right,
                                                 kernel_size=3, stride=2,
                                                 stem_cell=True)
        self.comb_iter_4_right = ReluConvBn(out_channels_right,
                                            out_channels_right,
                                            kernel_size=1, stride=2)
    def construct(self, x_left: Tensor) -> Tensor:
        # The single input serves as the left input; the right one is its 1x1 projection.
        x_right = self.conv_1x1(x_left)
        x_out = self.cell_forward(x_left, x_right)
        return x_out
class Cell(CellBase):
    """
    Cell class that is used as a 'layer' in image architectures.

    A two-input PNASNet cell. The left input is projected by ``conv_prev_1x1``
    and the right input by ``conv_1x1``. ``CellBase.cell_forward`` then sums
    the five branch pairs and concatenates them. The output therefore has
    ``2 * out_channels_left + 3 * out_channels_right`` channels: combinations
    0 and 4 produce ``out_channels_left``, and combinations 1-3 produce
    ``out_channels_right``.

    Args:
        in_channels_left: channels of the left input.
        out_channels_left: channels after ``conv_prev_1x1`` and of combinations 0 and 4.
        in_channels_right: channels of the right input.
        out_channels_right: channels after ``conv_1x1`` and of combinations 1-3.
        is_reduction: if True, all branches except ``comb_iter_3_left`` use
            stride 2, and combination 4's right branch is a stride-2 1x1
            ReluConvBn. Otherwise that right branch is the identity (None).
            Default: False.
        zero_pad: forwarded to the BranchSeparables/MaxPool branches other than
            ``comb_iter_3_left``, enabling their pad-then-crop behaviour.
            Default: False.
        match_prev_layer_dimensions: if True, the left input is reduced with
            ``FactorizedReduction`` (stride-2 paths) instead of a 1x1
            ReluConvBn. This lets it match a right input that has already been
            downsampled. Default: False.

    NOTE(review): combination 4 adds an ``out_channels_left`` tensor
    (``comb_iter_4_left``) to an ``out_channels_right`` tensor. The two channel
    counts must therefore be equal, as they are for every cell in ``Pnasnet``.
    """
    def __init__(self,
                 in_channels_left: int,
                 out_channels_left: int,
                 in_channels_right: int,
                 out_channels_right: int,
                 is_reduction: bool = False,
                 zero_pad: bool = False,
                 match_prev_layer_dimensions: bool = False) -> None:
        super().__init__()
        # NOTE: attribute names and creation order define parameter/checkpoint
        # names and the order in which weights are initialised — keep them stable.
        stride = 2 if is_reduction else 1
        self.match_prev_layer_dimensions = match_prev_layer_dimensions
        if match_prev_layer_dimensions:
            self.conv_prev_1x1 = FactorizedReduction(in_channels_left, out_channels_left)
        else:
            self.conv_prev_1x1 = ReluConvBn(in_channels_left, out_channels_left, kernel_size=1)
        self.conv_1x1 = ReluConvBn(in_channels_right, out_channels_right, kernel_size=1)
        self.comb_iter_0_left = BranchSeparables(out_channels_left,
                                                 out_channels_left,
                                                 kernel_size=5, stride=stride,
                                                 zero_pad=zero_pad)
        self.comb_iter_0_right = MaxPool(3, stride=stride, zero_pad=zero_pad)
        self.comb_iter_1_left = BranchSeparables(out_channels_right,
                                                 out_channels_right,
                                                 kernel_size=7, stride=stride,
                                                 zero_pad=zero_pad)
        self.comb_iter_1_right = MaxPool(3, stride=stride, zero_pad=zero_pad)
        self.comb_iter_2_left = BranchSeparables(out_channels_right,
                                                 out_channels_right,
                                                 kernel_size=5, stride=stride,
                                                 zero_pad=zero_pad)
        self.comb_iter_2_right = BranchSeparables(out_channels_right,
                                                  out_channels_right,
                                                  kernel_size=3, stride=stride,
                                                  zero_pad=zero_pad)
        # Stride 1 and no zero_pad: its input (combination 2's output) already
        # has the cell's output resolution.
        self.comb_iter_3_left = BranchSeparables(out_channels_right,
                                                 out_channels_right,
                                                 kernel_size=3)
        self.comb_iter_3_right = MaxPool(3, stride=stride, zero_pad=zero_pad)
        self.comb_iter_4_left = BranchSeparables(out_channels_left,
                                                 out_channels_left,
                                                 kernel_size=3, stride=stride,
                                                 zero_pad=zero_pad)
        if is_reduction:
            # The right input must be downsampled to match comb_iter_4_left.
            self.comb_iter_4_right = ReluConvBn(out_channels_right,
                                                out_channels_right,
                                                kernel_size=1, stride=stride)
        else:
            # None => cell_forward uses x_right unchanged (identity branch).
            self.comb_iter_4_right = None
    def construct(self, x_left: Tensor, x_right: Tensor) -> Tensor:
        x_left = self.conv_prev_1x1(x_left)
        x_right = self.conv_1x1(x_right)
        x_out = self.cell_forward(x_left, x_right)
        return x_out
class Pnasnet(nn.Cell):
    r"""PNasNet model class, based on
    `"Progressive Neural Architecture Search" <https://arxiv.org/pdf/1712.00559.pdf>`_

    Network layout:
        conv_0 (3x3 stride-2 conv + BN) -> cell_stem_0 -> cell_stem_1 ->
        cell_0 ... cell_8 -> ReLU -> global pooling -> Dropout -> Dense.
    Each cell receives the outputs of the two preceding stages.

    Args:
        in_channels: number of input channels. Default: 3.
        num_classes: number of classification classes. Default: 1000.
    """

    def __init__(self,
                 in_channels: int = 3,
                 num_classes: int = 1000) -> None:
        super().__init__()
        # NOTE: attribute names and creation order define parameter/checkpoint
        # names (e.g. 'conv_0.conv', 'last_linear' in default_cfgs) — keep them stable.
        self.num_classes = num_classes
        # Stem: 3x3 stride-2 conv to 32 channels (pad_mode='pad' with the default
        # padding of 0), then BatchNorm.
        self.conv_0 = nn.SequentialCell(OrderedDict([
            ('conv', nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, stride=2,
                               pad_mode='pad', has_bias=False)),
            ('bn', nn.BatchNorm2d(num_features=32, eps=0.001, momentum=0.9))
        ]))
        # Every cell concatenates five combinations:
        #   CellStem0 -> out_left + 4 * out_right channels (13 + 4 * 13 = 65),
        #   Cell      -> 2 * out_left + 3 * out_right channels (5 * 27 = 135, 5 * 54 = 270, ...).
        # Downsampling stages: cell_stem_0, cell_stem_1, cell_3 and cell_6.
        # match_prev_layer_dimensions=True is set on the cells whose left input
        # predates the latest downsampling (cell_stem_1, cell_0, cell_4, cell_7).
        self.cell_stem_0 = CellStem0(in_channels_left=32, out_channels_left=13,
                                     in_channels_right=32, out_channels_right=13)
        self.cell_stem_1 = Cell(in_channels_left=32, out_channels_left=27,
                                in_channels_right=65, out_channels_right=27,
                                match_prev_layer_dimensions=True,
                                is_reduction=True)
        self.cell_0 = Cell(in_channels_left=65, out_channels_left=54,
                           in_channels_right=135, out_channels_right=54,
                           match_prev_layer_dimensions=True)
        self.cell_1 = Cell(in_channels_left=135, out_channels_left=54,
                           in_channels_right=270, out_channels_right=54)
        self.cell_2 = Cell(in_channels_left=270, out_channels_left=54,
                           in_channels_right=270, out_channels_right=54)
        self.cell_3 = Cell(in_channels_left=270, out_channels_left=108,
                           in_channels_right=270, out_channels_right=108,
                           is_reduction=True, zero_pad=True)
        self.cell_4 = Cell(in_channels_left=270, out_channels_left=108,
                           in_channels_right=540, out_channels_right=108,
                           match_prev_layer_dimensions=True)
        self.cell_5 = Cell(in_channels_left=540, out_channels_left=108,
                           in_channels_right=540, out_channels_right=108)
        self.cell_6 = Cell(in_channels_left=540, out_channels_left=216,
                           in_channels_right=540, out_channels_right=216,
                           is_reduction=True)
        self.cell_7 = Cell(in_channels_left=540, out_channels_left=216,
                           in_channels_right=1080, out_channels_right=216,
                           match_prev_layer_dimensions=True)
        self.cell_8 = Cell(in_channels_left=1080, out_channels_left=216,
                           in_channels_right=1080, out_channels_right=216)
        # Classification head; 1080 = 5 * 216 channels from cell_8.
        self.relu = nn.ReLU()
        # NOTE(review): its output feeds nn.Dense directly, so GlobalAvgPooling
        # presumably returns a flattened (N, C) tensor — confirm in .layers.
        self.pool = GlobalAvgPooling()
        self.dropout = nn.Dropout(keep_prob=0.5)
        self.last_linear = nn.Dense(in_channels=1080, out_channels=num_classes)
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize weights for cells.

        Parameters are first materialised with ``init_parameters_data()`` and
        then overwritten in ``cells_and_names()`` order:
        - Conv2d: weight ~ Normal(mean 0, std sqrt(2 / (k_h * k_w * out_channels)))
          (He-style, fan-out); bias, if any, = 0.
        - BatchNorm2d: gamma = 1, beta = 0.
        - Dense: weight ~ Normal(mean 0, std 0.01); bias, if any, = 0.
        """
        self.init_parameters_data()
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                n = cell.kernel_size[0] * cell.kernel_size[1] * cell.out_channels
                # init.Normal takes (sigma, mean)
                cell.weight.set_data(init.initializer(init.Normal(math.sqrt(2. / n), 0),
                                                      cell.weight.shape, cell.weight.dtype))
                if cell.bias is not None:
                    cell.bias.set_data(init.initializer(init.Zero(),
                                                        cell.bias.shape, cell.bias.dtype))
            elif isinstance(cell, nn.BatchNorm2d):
                cell.gamma.set_data(init.initializer(init.One(),
                                                     cell.gamma.shape, cell.gamma.dtype))
                cell.beta.set_data(init.initializer(init.Zero(),
                                                    cell.beta.shape, cell.beta.dtype))
            elif isinstance(cell, nn.Dense):
                cell.weight.set_data(init.initializer(init.Normal(0.01, 0),
                                                      cell.weight.shape, cell.weight.dtype))
                if cell.bias is not None:
                    cell.bias.set_data(init.initializer(init.Zero(),
                                                        cell.bias.shape, cell.bias.dtype))

    def forward_features(self, x: Tensor) -> Tensor:
        """Run the stem and all cells; each cell gets (output two stages back, previous output)."""
        x_conv_0 = self.conv_0(x)
        x_stem_0 = self.cell_stem_0(x_conv_0)
        x_stem_1 = self.cell_stem_1(x_conv_0, x_stem_0)
        x_cell_0 = self.cell_0(x_stem_0, x_stem_1)
        x_cell_1 = self.cell_1(x_stem_1, x_cell_0)
        x_cell_2 = self.cell_2(x_cell_0, x_cell_1)
        x_cell_3 = self.cell_3(x_cell_1, x_cell_2)
        x_cell_4 = self.cell_4(x_cell_2, x_cell_3)
        x_cell_5 = self.cell_5(x_cell_3, x_cell_4)
        x_cell_6 = self.cell_6(x_cell_4, x_cell_5)
        x_cell_7 = self.cell_7(x_cell_5, x_cell_6)
        x_cell_8 = self.cell_8(x_cell_6, x_cell_7)
        return x_cell_8

    def forward_head(self, x: Tensor) -> Tensor:
        """ReLU -> global pooling -> dropout -> linear classifier."""
        x = self.relu(x)
        x = self.pool(x)
        x = self.dropout(x)
        x = self.last_linear(x)
        return x

    def construct(self, x: Tensor) -> Tensor:
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
@register_model
def pnasnet(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs) -> Pnasnet:
    """Build a ``Pnasnet`` model.

    Args:
        pretrained: if True, load weights using ``default_cfgs['pnasnet']`` via
            ``load_pretrained``.
        num_classes: number of classification classes.
        in_channels: number of input channels.
        **kwargs: forwarded unchanged to ``Pnasnet``.

    NOTE(review): ``Pnasnet.__init__`` only accepts ``in_channels`` and
    ``num_classes``, so any extra keyword argument here raises TypeError.
    """
    cfg = default_cfgs['pnasnet']
    net = Pnasnet(in_channels=in_channels, num_classes=num_classes, **kwargs)
    if not pretrained:
        return net
    load_pretrained(net, cfg, num_classes=num_classes, in_channels=in_channels)
    return net