"""
MindSpore implementation of `edgenext`.
Refer to EdgeNeXt: Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision Applications.
"""
import numpy as np
import math
from typing import Tuple
import mindspore as ms
from mindspore import nn, Tensor, Parameter, ops
import mindspore.common.initializer as init
from .registry import register_model
from .layers.drop_path import DropPath
from .layers.identity import Identity
from .utils import load_pretrained
# Names exported by a star-import of this module: the model class and its
# registered factory function.
__all__ = [
    'EdgeNeXt',
    'edgenext_small',
]
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000,
'first_conv': 'conv_0.conv',
'classifier': 'last_linear',
**kwargs
}
# Pretrained-checkpoint configurations keyed by model name; the factory
# function of a model looks up its own entry here.
default_cfgs = {
    'edgenext_small': _cfg(url='https://download.mindspore.cn/toolkits/mindcv/edgenext/edgenext_small.ckpt'),
}
def ssplit(x: Tensor, dim, width):
    """Split a 4-D tensor into consecutive chunks of ``width`` channels.

    The channel count is read from ``x.shape[1]``.

    Args:
        x: Tensor whose shape unpacks into exactly four dimensions.
        dim: Axis handed to ``ops.split`` when the channel count divides
            evenly. It is not consulted in the uneven case, which always
            slices axis 1.
        width: Number of channels per chunk.

    Returns:
        The result of ``ops.split`` when the channel count is a multiple of
        ``width``; otherwise a list of slices where every slice but the last
        has ``width`` channels and the last holds the remainder.
    """
    _, channels, _, _ = x.shape
    if channels % width == 0:
        return ops.split(x, dim, channels // width)

    # Uneven case: cut full-width slices while a full slice still leaves a
    # remainder, then append whatever is left as the final, shorter chunk.
    pieces = []
    start = 0
    stop = width
    while stop < channels:
        pieces.append(x[:, start:stop, :, :])
        start, stop = stop, stop + width
    pieces.append(x[:, start:, :, :])
    return pieces
class LayerNorm(nn.LayerNorm):
    r"""Layer normalization for 4-D tensors in either channel layout.

    With ``norm_axis=-1`` the input is normalized as given. With
    ``norm_axis=1`` the input is permuted from (N, C, H, W) to (N, H, W, C),
    normalized, and permuted back to (N, C, H, W).

    Args:
        normalized_shape: Shape forwarded to ``nn.LayerNorm``.
        epsilon: Epsilon forwarded to ``nn.LayerNorm``.
        norm_axis: ``-1`` or ``1``; selects the layout handling described
            above. Default: -1.
    """

    def __init__(
        self,
        normalized_shape: Tuple[int],
        epsilon: float,
        norm_axis: int = -1,
    ) -> None:
        super().__init__(normalized_shape=normalized_shape, epsilon=epsilon)
        assert norm_axis in (-1, 1), "ConvNextLayerNorm's norm_axis must be 1 or -1."
        self.norm_axis = norm_axis

    def construct(self, input_x: Tensor) -> Tensor:
        """Normalize ``input_x``; the output has the same layout as the input."""
        channels_first = self.norm_axis != -1
        if channels_first:
            # Move channels to the last axis so the inherited op sees them there.
            input_x = ops.transpose(input_x, (0, 2, 3, 1))
        normed, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
        if channels_first:
            normed = ops.transpose(normed, (0, 3, 1, 2))
        return normed
class PositionalEncodingFourier(nn.Cell):
    """Fourier-feature positional encoding projected by a 1x1 convolution.

    Args:
        hidden_dim: Number of sin/cos features generated per spatial axis.
            Default: 32.
        dim: Number of output channels of the 1x1 projection. Default: 768.
        temperature: Base of the per-feature frequency divisor.
            Default: 10000.
    """

    def __init__(self, hidden_dim=32, dim=768, temperature=10000):
        super().__init__()
        # The row and column encodings are concatenated, hence 2 * hidden_dim inputs.
        self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1, has_bias=True)
        self.scale = 2 * math.pi
        self.temperature = temperature
        self.hidden_dim = hidden_dim
        self.dim = dim

    def _sin_cos(self, embed, dim_t):
        """Turn a (B, H, W) coordinate map into (B, H, W, F) sin/cos features."""
        scaled = embed[:, :, :, None] / dim_t
        # sin on even feature indices, cos on odd ones, paired along a new last axis.
        paired = ops.stack((ops.sin(scaled[:, :, :, 0::2]),
                            ops.cos(scaled[:, :, :, 1::2])), axis=4)
        d0, d1, d2, _, _ = paired.shape
        return ops.reshape(paired, (d0, d1, d2, -1))

    def construct(self, B, H, W):
        """Return the projected positional encoding for a B x H x W grid.

        The row and column features are concatenated on the channel axis,
        moved to channels-first layout and passed through ``token_projection``.
        """
        all_false = Tensor(np.zeros((B, H, W))).astype(ms.bool_)
        valid = ~all_false
        # Running counts along rows (axis 1) and columns (axis 2).
        rows = valid.cumsum(1, dtype=ms.float32)
        cols = valid.cumsum(2, dtype=ms.float32)

        # Normalize by the last count on each axis, then scale by 2*pi.
        eps = 1e-6
        rows = rows / (rows[:, -1:, :] + eps) * self.scale
        cols = cols / (cols[:, :, -1:] + eps) * self.scale

        freq = ms.numpy.arange(self.hidden_dim, dtype=ms.float32)
        freq = self.temperature ** (2 * (freq // 2) / self.hidden_dim)

        col_feat = self._sin_cos(cols, freq)
        row_feat = self._sin_cos(rows, freq)

        stacked = ops.concat((row_feat, col_feat), axis=3)
        encoding = ops.transpose(stacked, (0, 3, 1, 2))
        return self.token_projection(encoding)
class ConvEncoder(nn.Cell):
    """Convolutional encoder block with a residual connection.

    The block applies a depth-wise convolution (``group=dim``), layer
    normalization in channels-last layout, a Dense -> GELU -> Dense expansion,
    an optional learnable per-channel scale and drop-path, and finally adds
    the block input back.

    Args:
        dim: Number of input and output channels.
        drop_path: Drop-path rate; ``Identity`` is used when it is not > 0.
            Default: 0.
        layer_scale_init_value: Initial value of the per-channel scale; the
            scale is omitted when this is not > 0. Default: 1e-6.
        expan_ratio: Width multiplier of the hidden Dense layer. Default: 4.
        kernel_size: Kernel size of the depth-wise convolution. Default: 7.
    """

    def __init__(self,
                 dim,
                 drop_path=0.,
                 layer_scale_init_value=1e-6,
                 expan_ratio=4,
                 kernel_size=7):
        super().__init__()
        hidden = expan_ratio * dim
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, pad_mode="pad",
                                padding=kernel_size // 2, group=dim, has_bias=True)
        self.norm = LayerNorm((dim,), epsilon=1e-6)
        self.pwconv1 = nn.Dense(dim, hidden)
        self.act = nn.GELU(approximate=False)
        self.pwconv2 = nn.Dense(hidden, dim)
        if layer_scale_init_value > 0.:
            self.gamma1 = Parameter(Tensor(layer_scale_init_value * np.ones(dim), ms.float32),
                                    requires_grad=True)
        else:
            self.gamma1 = None
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = Identity()

    def construct(self, x: Tensor) -> Tensor:
        """Apply the block to a channels-first tensor and return the same layout."""
        shortcut = x
        y = self.dwconv(x)
        # (N, C, H, W) -> (N, H, W, C) so norm and Dense act on the channel axis.
        y = ops.transpose(y, (0, 2, 3, 1))
        y = self.pwconv2(self.act(self.pwconv1(self.norm(y))))
        if self.gamma1 is not None:
            y = self.gamma1 * y
        # (N, H, W, C) -> (N, C, H, W)
        y = ops.transpose(y, (0, 3, 1, 2))
        return shortcut + self.drop_path(y)
class SDTAEncoder(nn.Cell):
    """Split depth-wise transpose attention (SDTA) encoder block.

    The block has three parts:

    1. The input channels are split into chunks of ``width`` channels. The
       first ``nums`` chunks go through a cascade of depth-wise 3x3
       convolutions, each chunk being added to the previous conv output first.
       A remaining chunk, if any, is passed through untouched.
    2. Cross-covariance attention (``XCA``) over the flattened spatial
       positions, with an optional Fourier positional encoding and an
       optional learnable scale ``gamma_xca``.
    3. LayerNorm -> Dense -> GELU -> Dense with an optional learnable scale
       ``gamma``, followed by a residual connection to the block input.

    Args:
        dim: Number of input and output channels.
        drop_path: Drop-path rate; ``Identity`` is used when it is not > 0.
            Default: 0.
        layer_scale_init_value: Initial value of both learnable scales; they
            are omitted when this is not > 0. Default: 1e-6.
        expan_ratio: Width multiplier of the hidden Dense layer. Default: 4.
        use_pos_emb: Whether to add a positional encoding before attention.
            Default: True.
        num_heads: Number of attention heads. Default: 8.
        qkv_bias: Whether the qkv projection has a bias. Default: True.
        attn_drop: Attention drop rate forwarded to ``XCA``. Default: 0.
        drop: Projection drop rate forwarded to ``XCA``. Default: 0.
        scales: Number of channel groups to aim for. Default: 1.
    """

    def __init__(self,
                 dim, drop_path=0.,
                 layer_scale_init_value=1e-6,
                 expan_ratio=4,
                 use_pos_emb=True,
                 num_heads=8,
                 qkv_bias=True,
                 attn_drop=0.,
                 drop=0.,
                 scales=1):
        super().__init__()
        width = max(int(math.ceil(dim / scales)), int(math.floor(dim // scales)))
        self.width = width
        # Number of chunks that get a convolution; at least one.
        if scales == 1:
            self.nums = 1
        else:
            self.nums = scales - 1
        convs = [
            nn.Conv2d(width, width, kernel_size=3, pad_mode="pad", padding=1, group=width, has_bias=True)
            for _ in range(self.nums)
        ]
        self.convs = nn.CellList(convs)

        self.pos_embd = None
        if use_pos_emb:
            self.pos_embd = PositionalEncodingFourier(dim=dim)
        self.norm_xca = LayerNorm((dim,), epsilon=1e-6)
        self.gamma_xca = Parameter(Tensor(layer_scale_init_value * np.ones(dim), ms.float32),
                                   requires_grad=True) if layer_scale_init_value > 0. else None
        self.xca = XCA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)

        self.norm = LayerNorm((dim,), epsilon=1e-6)
        self.pwconv1 = nn.Dense(dim, expan_ratio * dim)
        self.act = nn.GELU(approximate=False)
        self.pwconv2 = nn.Dense(expan_ratio * dim, dim)
        self.gamma = Parameter(Tensor(layer_scale_init_value * np.ones((dim)), ms.float32),
                               requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()

    def construct(self, x: Tensor) -> Tensor:
        """Apply the block to an (N, C, H, W) tensor and return the same shape."""
        shortcut = x

        # Cascaded depth-wise convolutions over the channel chunks.
        spx = ssplit(x, 1, self.width)
        sp = None
        out = None
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            if i == 0:
                out = sp
            else:
                out = ops.concat((out, sp), 1)
        # Append the untouched trailing chunk only when one exists. With
        # scales == 1 every channel already went through the convolution and
        # indexing spx[self.nums] would be out of range.
        if len(spx) > self.nums:
            x = ops.concat((out, spx[self.nums]), 1)
        else:
            x = out

        # XCA over flattened spatial positions: (B, C, H, W) -> (B, H*W, C)
        B, C, H, W = x.shape
        x = ops.reshape(x, (B, C, H * W))
        x = ops.transpose(x, (0, 2, 1))
        if self.pos_embd is not None:
            pos_encoding = ops.transpose(ops.reshape(self.pos_embd(B, H, W), (B, -1, x.shape[1])), (0, 2, 1))
            x = x + pos_encoding
        attn = self.xca(self.norm_xca(x))
        # gamma_xca is None when layer scaling is disabled; skip the scaling
        # instead of multiplying by None (same guard as for self.gamma below).
        if self.gamma_xca is not None:
            attn = self.gamma_xca * attn
        x = x + self.drop_path(attn)
        x = x.astype(ms.float32)
        x = ops.reshape(x, (B, H, W, C))

        # Inverted bottleneck
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = ops.transpose(x, (0, 3, 1, 2))  # (N, H, W, C) -> (N, C, H, W)

        x = shortcut + self.drop_path(x)
        return x
class XCA(nn.Cell):
    """Cross-covariance attention: attention computed across channels.

    Queries and keys are L2-normalized along the token axis, so the attention
    map has shape (head_dim, head_dim) per head instead of (tokens, tokens).

    Args:
        dim: Number of input and output channels.
        num_heads: Number of attention heads. Default: 8.
        qkv_bias: Whether the qkv projection has a bias. Default: False.
        attn_drop: Attention drop rate; ``1 - attn_drop`` is passed as the
            first positional argument of ``nn.Dropout``. Default: 0.
        proj_drop: Output drop rate; ``1 - proj_drop`` is passed as the first
            positional argument of ``nn.Dropout``. Default: 0.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        # One learnable temperature per head, broadcast over the attention map.
        self.temperature = Parameter(Tensor(np.ones((num_heads, 1, 1)), ms.float32))
        self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias)
        self.attn_drop = nn.Dropout(1 - attn_drop)
        self.proj = nn.Dense(dim, dim)
        self.proj_drop = nn.Dropout(1 - proj_drop)

    def construct(self, x: Tensor) -> Tensor:
        """Apply attention to a (B, N, C) tensor and return the same shape."""
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        fused = self.qkv(x)
        fused = ops.reshape(fused, (batch, tokens, 3, self.num_heads, head_dim))
        # -> (3, B, heads, N, head_dim)
        fused = ops.transpose(fused, (2, 0, 3, 1, 4))

        # Each of q, k, v: (B, heads, N, head_dim) -> (B, heads, head_dim, N)
        query = ops.transpose(fused[0], (0, 1, 3, 2))
        key = ops.transpose(fused[1], (0, 1, 3, 2))
        value = ops.transpose(fused[2], (0, 1, 3, 2))

        normalize = ops.L2Normalize(-1)
        query = normalize(query)
        key = normalize(key)

        # (B, heads, head_dim, head_dim), scaled by the per-head temperature.
        key_t = ops.transpose(key, (0, 1, 3, 2))
        scores = ops.matmul(query, key_t) * self.temperature
        scores = ops.Softmax(-1)(scores)
        scores = self.attn_drop(scores)

        # (B, heads, head_dim, N) -> (B, N, heads, head_dim) -> (B, N, C)
        mixed = ops.matmul(scores, value)
        mixed = ops.transpose(mixed, (0, 3, 1, 2))
        mixed = ops.reshape(mixed, (batch, tokens, channels))

        return self.proj_drop(self.proj(mixed))
class EdgeNeXt(nn.Cell):
    r"""EdgeNeXt model class, based on
    `"Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision" <https://arxiv.org/abs/2206.10589>`_

    The network has a stem plus three down-sampling layers, each followed by
    a stage of residual blocks. Within a stage the last ``global_block[i]``
    blocks are global (SDTA) blocks and the rest are ``ConvEncoder`` blocks.
    The final feature map is averaged over its spatial axes, normalized and
    fed to a dense classifier.

    Args:
        in_chans: number of input channels. Default: 3
        num_classes: number of classification classes. Default: 1000
        depths: number of blocks in each stage. Default: (3, 3, 9, 3)
        dims: the channel width of each stage. Default: (24, 48, 88, 168)
        global_block: number of global blocks at the end of each stage.
            Default: (0, 0, 0, 3)
        global_block_type: type of global block per stage, 'None' or 'SDTA'.
            Default: ('None', 'None', 'None', 'SDTA')
        drop_path_rate: maximum stochastic-depth rate; per-block rates grow
            linearly from 0 to this value. Default: 0.
        layer_scale_init_value: value of layer scale initialization passed to
            the ``ConvEncoder`` blocks. Default: 1e-6
        head_init_scale: factor applied to the classifier weight and bias
            after initialization. Default: 1.
        expan_ratio: ratio of expansion. Default: 4
        kernel_sizes: kernel sizes of different stages. Default: (7, 7, 7, 7)
        heads: number of attention heads. Default: (8, 8, 8, 8)
        use_pos_embd_xca: use position embedding in xca or not.
            Default: (False, False, False, False)
        use_pos_embd_global: add a position embedding after the first stage
            or not. Default: False
        d2_scales: scales of splitting channels. Default: (2, 3, 4, 5)
        **kwargs: accepted for interface compatibility; not used.
    """

    # Sequence defaults are tuples rather than lists so that no mutable object
    # is shared between calls; they are only indexed, iterated and summed.
    def __init__(self, in_chans=3, num_classes=1000,
                 depths=(3, 3, 9, 3), dims=(24, 48, 88, 168),
                 global_block=(0, 0, 0, 3), global_block_type=('None', 'None', 'None', 'SDTA'),
                 drop_path_rate=0., layer_scale_init_value=1e-6, head_init_scale=1., expan_ratio=4,
                 kernel_sizes=(7, 7, 7, 7), heads=(8, 8, 8, 8),
                 use_pos_embd_xca=(False, False, False, False),
                 use_pos_embd_global=False, d2_scales=(2, 3, 4, 5), **kwargs):
        super().__init__()
        for g in global_block_type:
            assert g in ['None', 'SDTA']

        if use_pos_embd_global:
            self.pos_embd = PositionalEncodingFourier(dim=dims[0])
        else:
            self.pos_embd = None

        self.downsample_layers = nn.CellList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.SequentialCell(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4, has_bias=True),
            LayerNorm((dims[0],), epsilon=1e-6, norm_axis=1)
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.SequentialCell(
                LayerNorm((dims[i],), epsilon=1e-6, norm_axis=1),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2, has_bias=True),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.CellList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        # One drop-path rate per block, increasing linearly over the whole network.
        dp_rates = list(np.linspace(0, drop_path_rate, sum(depths)))
        cur = 0
        for i in range(4):
            stage_blocks = []
            for j in range(depths[i]):
                # The last global_block[i] positions of the stage are global blocks.
                if j > depths[i] - global_block[i] - 1:
                    if global_block_type[i] == 'SDTA':
                        # NOTE(review): layer_scale_init_value is not forwarded
                        # here, so SDTAEncoder uses its own default -- confirm
                        # this is intended.
                        stage_blocks.append(SDTAEncoder(dim=dims[i], drop_path=dp_rates[cur + j],
                                                        expan_ratio=expan_ratio, scales=d2_scales[i],
                                                        use_pos_emb=use_pos_embd_xca[i], num_heads=heads[i]))
                    else:
                        raise NotImplementedError
                else:
                    stage_blocks.append(ConvEncoder(dim=dims[i], drop_path=dp_rates[cur + j],
                                                    layer_scale_init_value=layer_scale_init_value,
                                                    expan_ratio=expan_ratio, kernel_size=kernel_sizes[i]))

            self.stages.append(nn.SequentialCell(*stage_blocks))
            cur += depths[i]

        self.norm = nn.LayerNorm((dims[-1],), epsilon=1e-6)  # Final norm layer
        self.head = nn.Dense(dims[-1], num_classes)

        # NOTE(review): 1.0 is passed as the first positional argument of
        # nn.Dropout; presumably a keep probability, i.e. dropout is
        # effectively disabled -- confirm against the MindSpore version in use.
        self.head_dropout = nn.Dropout(1.0)
        self.head_init_scale = head_init_scale
        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Initialize weights for cells."""
        for _, cell in self.cells_and_names():
            if isinstance(cell, (nn.Dense, nn.Conv2d)):
                cell.weight.set_data(init.initializer(init.TruncatedNormal(sigma=0.02),
                                                      cell.weight.shape,
                                                      cell.weight.dtype))
                if isinstance(cell, nn.Dense) and cell.bias is not None:
                    cell.bias.set_data(init.initializer(init.Zero(),
                                                        cell.bias.shape,
                                                        cell.bias.dtype))
            elif isinstance(cell, nn.LayerNorm):
                cell.gamma.set_data(init.initializer(init.One(),
                                                     cell.gamma.shape,
                                                     cell.gamma.dtype))
                cell.beta.set_data(init.initializer(init.Zero(),
                                                    cell.beta.shape,
                                                    cell.beta.dtype))
        # Rescale the classifier after the generic initialization above.
        self.head.weight.set_data(self.head.weight * self.head_init_scale)
        self.head.bias.set_data(self.head.bias * self.head_init_scale)

    def forward_features(self, x):
        """Run the stem and all stages, then pool and normalize the features."""
        x = self.downsample_layers[0](x)
        x = self.stages[0](x)
        if self.pos_embd is not None:
            B, _, H, W = x.shape
            x = x + self.pos_embd(B, H, W)
        for i in range(1, 4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([-2, -1]))  # Global average pooling, (N, C, H, W) -> (N, C)

    def construct(self, x):
        """Return the classifier output for an input batch."""
        x = self.forward_features(x)
        x = self.head(self.head_dropout(x))
        return x
@register_model
def edgenext_small(pretrained: bool = False,
                   num_classes: int = 1000,
                   in_channels: int = 3,
                   **kwargs) -> EdgeNeXt:
    """Get edgenext_small model.

    Refer to the base class `models.EdgeNeXt` for more details.

    Args:
        pretrained: Whether to load the pretrained checkpoint described by
            ``default_cfgs['edgenext_small']``. Default: False.
        num_classes: Number of classification classes. Default: 1000.
        in_channels: Number of input channels of the network. Default: 3.
        **kwargs: Extra keyword arguments forwarded to ``EdgeNeXt``. An
            explicit ``in_chans`` entry takes precedence over ``in_channels``
            when building the network.

    Returns:
        The constructed ``EdgeNeXt`` instance.
    """
    default_cfg = default_cfgs['edgenext_small']
    # Forward the requested channel count to the network; previously it was
    # accepted but ignored, so the stem was always built for 3 channels.
    # ``in_chans`` passed through kwargs keeps working and wins if given.
    in_chans = kwargs.pop('in_chans', in_channels)
    model = EdgeNeXt(in_chans=in_chans,
                     depths=[3, 3, 9, 3], dims=[48, 96, 160, 304], expan_ratio=4,
                     num_classes=num_classes,
                     global_block=[0, 1, 1, 1],
                     global_block_type=['None', 'SDTA', 'SDTA', 'SDTA'],
                     use_pos_embd_xca=[False, True, False, False],
                     kernel_sizes=[3, 5, 7, 9],
                     d2_scales=[2, 2, 3, 4],
                     **kwargs)
    if pretrained:
        load_pretrained(model,
                        default_cfg,
                        num_classes=num_classes,
                        in_channels=in_channels)
    return model