# Source code for mindcv.models.edgenext

"""
MindSpore implementation of `edgenext`.
Refer to EdgeNeXt: Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision Applications.
"""

import numpy as np
import math
from typing import Tuple

import mindspore as ms
from mindspore import nn, Tensor, Parameter, ops
import mindspore.common.initializer as init

from .registry import register_model
from .layers.drop_path import DropPath
from .layers.identity import Identity
from .utils import load_pretrained

# Names exported by `from ... import *`: the model class and its factory function.
__all__ = [
    'EdgeNeXt',
    'edgenext_small',
]


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000,
        'first_conv': 'conv_0.conv',
        'classifier': 'last_linear',
        **kwargs
    }


# Pretrained-checkpoint configs, keyed by the name of the factory function that uses them.
default_cfgs = dict(
    edgenext_small=_cfg(url='https://download.mindspore.cn/toolkits/mindcv/edgenext/edgenext_small.ckpt'),
)


def ssplit(x: Tensor, dim, width):
    """Split a 4-D tensor into channel groups of size ``width``.

    When the size of axis 1 is a multiple of ``width`` the work is delegated
    to ``ops.split(x, dim, channels // width)``. Otherwise the tensor is
    sliced by hand into full-width groups followed by one narrower group
    holding the remaining channels.

    Args:
        x: tensor whose shape unpacks into exactly four dimensions.
        dim: axis handed to ``ops.split``. NOTE(review): the manual fallback
            always slices axis 1 and ignores this value - the only caller in
            this file passes ``dim=1``.
        width: number of channels per group.

    Returns:
        The result of ``ops.split`` in the divisible case, otherwise a list
        of tensor slices.
    """
    _, channels, _, _ = x.shape
    if channels % width != 0:
        full_groups = channels // width
        pieces = []
        for k in range(full_groups):
            start = k * width
            pieces.append(x[:, start:start + width, :, :])
        # Whatever is left after the full-width groups forms the last, narrower group.
        pieces.append(x[:, full_groups * width:, :, :])
        return pieces
    return ops.split(x, dim, channels // width)


class LayerNorm(nn.LayerNorm):
    r"""Layer normalization usable on channels-last or channels-first 4-D tensors.

    With ``norm_axis=-1`` (the default) the input is normalized as given.
    With ``norm_axis=1`` the input is permuted with ``(0, 2, 3, 1)`` before
    normalization and permuted back with ``(0, 3, 1, 2)`` afterwards, i.e. an
    (N, C, H, W) tensor is normalized over C and returned as (N, C, H, W).

    Args:
        normalized_shape: shape forwarded to ``nn.LayerNorm``.
        epsilon: epsilon forwarded to ``nn.LayerNorm``.
        norm_axis: ``-1`` or ``1``, see above. Default: -1
    """
    def __init__(self, normalized_shape: Tuple[int], epsilon: float, norm_axis: int = -1) -> None:
        super().__init__(normalized_shape=normalized_shape, epsilon=epsilon)
        assert norm_axis in (-1, 1), "ConvNextLayerNorm's norm_axis must be 1 or -1."
        self.norm_axis = norm_axis

    def construct(self, input_x: Tensor) -> Tensor:
        """Normalize ``input_x``; the output has the same layout as the input."""
        channels_first = self.norm_axis != -1
        if channels_first:
            # Move channels to the last axis so the inherited primitive can normalize them.
            input_x = ops.transpose(input_x, (0, 2, 3, 1))
        y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
        if channels_first:
            y = ops.transpose(y, (0, 3, 1, 2))
        return y


class PositionalEncodingFourier(nn.Cell):
    """Sine/cosine 2-D positional encoding followed by a learned 1x1 projection.

    Args:
        hidden_dim: number of sinusoidal channels produced per spatial axis. Default: 32
        dim: number of output channels of the projection. Default: 768
        temperature: base of the per-channel frequency divisor. Default: 10000
    """
    def __init__(self, hidden_dim=32, dim=768, temperature=10000):
        super().__init__()
        # 1x1 conv mapping the 2 * hidden_dim sinusoidal channels (y half + x half) to `dim` channels.
        self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1, has_bias=True)
        self.scale = 2 * math.pi
        self.temperature = temperature
        self.hidden_dim = hidden_dim
        self.dim = dim

    def construct(self, B, H, W):
        """Return a positional encoding of shape (B, dim, H, W).

        Args:
            B: batch size.
            H: height of the feature map.
            W: width of the feature map.
        """
        # The mask is all False, so `not_mask` is all True and the cumulative sums
        # below count 1..H down axis 1 and 1..W along axis 2.
        mask = Tensor(np.zeros((B, H, W))).astype(ms.bool_)
        not_mask = ~mask

        y_embed = not_mask.cumsum(1, dtype=ms.float32)
        x_embed = not_mask.cumsum(2, dtype=ms.float32)

        # Divide by the last (largest) coordinate and scale by 2*pi; eps guards the division.
        eps = 1e-6
        y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
        x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # Per-channel divisor; `dim_t // 2` makes each consecutive channel pair share one exponent.
        dim_t = ms.numpy.arange(self.hidden_dim, dtype=ms.float32)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.hidden_dim)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t

        # sin of the even channels and cos of the odd channels are stacked on a new
        # trailing axis and flattened, which interleaves them as sin, cos, sin, cos, ...
        pos_x = ops.stack((ops.sin(pos_x[:, :, :, 0::2]),
                           ops.cos(pos_x[:, :, :, 1::2])), axis=4)
        s1, s2, s3, _, _ = pos_x.shape
        pos_x = ops.reshape(pos_x, (s1, s2, s3, -1))
        pos_y = ops.stack((ops.sin(pos_y[:, :, :, 0::2]),
                           ops.cos(pos_y[:, :, :, 1::2])), axis=4)
        s1, s2, s3, _, _ = pos_y.shape
        pos_y = ops.reshape(pos_y, (s1, s2, s3, -1))
        # Concatenate the y and x halves on the channel axis, then move channels to axis 1.
        pos = ops.transpose(ops.concat((pos_y, pos_x), axis=3), (0, 3, 1, 2))
        pos = self.token_projection(pos)
        return pos


class ConvEncoder(nn.Cell):
    """Residual convolutional block: depthwise conv, LayerNorm, then a two-layer MLP.

    Args:
        dim: number of input and output channels.
        drop_path: drop-path rate; ``Identity`` is used when it is not positive. Default: 0.
        layer_scale_init_value: initial value of the learnable per-channel scale
            applied to the MLP output; no scale is created when it is not positive.
            Default: 1e-6
        expan_ratio: hidden width of the MLP as a multiple of ``dim``. Default: 4
        kernel_size: kernel size of the depthwise convolution. Default: 7
    """
    def __init__(self,
                 dim,
                 drop_path=0.,
                 layer_scale_init_value=1e-6,
                 expan_ratio=4,
                 kernel_size=7):
        super().__init__()
        hidden_dim = expan_ratio * dim
        # group=dim makes the convolution depthwise; padding keeps H and W unchanged.
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, pad_mode="pad", padding=kernel_size // 2,
                                group=dim, has_bias=True)
        self.norm = LayerNorm((dim,), epsilon=1e-6)
        self.pwconv1 = nn.Dense(dim, hidden_dim)
        self.act = nn.GELU(approximate=False)
        self.pwconv2 = nn.Dense(hidden_dim, dim)

        if layer_scale_init_value > 0.:
            self.gamma1 = Parameter(Tensor(layer_scale_init_value * np.ones(dim), ms.float32), requires_grad=True)
        else:
            self.gamma1 = None
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = Identity()

    def construct(self, x: Tensor) -> Tensor:
        """Apply the block to an (N, C, H, W) tensor and return a tensor of the same shape."""
        shortcut = x
        x = self.dwconv(x)

        # The Dense layers and LayerNorm act on the last axis, so run them channels-last.
        x = ops.transpose(x, (0, 2, 3, 1))
        x = self.pwconv2(self.act(self.pwconv1(self.norm(x))))
        if self.gamma1 is not None:
            x = self.gamma1 * x
        x = ops.transpose(x, (0, 3, 1, 2))
        return shortcut + self.drop_path(x)


class SDTAEncoder(nn.Cell):
    """Split depthwise transpose attention (SDTA) encoder block.

    The input channels are split into groups of ``width`` channels. The first
    ``nums`` groups go through cascaded depthwise 3x3 convolutions (each group
    is added to the previous group's output before its own convolution), and
    any remaining group is passed through untouched. The result is fed to
    cross-covariance attention (XCA) and then to an inverted-bottleneck MLP,
    with a residual connection around the whole block.

    Args:
        dim: number of input and output channels.
        drop_path: drop-path rate; ``Identity`` is used when it is not positive. Default: 0.
        layer_scale_init_value: initial value of the two learnable per-channel
            scales (attention branch and MLP branch); no scale is created when
            it is not positive. Default: 1e-6
        expan_ratio: hidden width of the MLP as a multiple of ``dim``. Default: 4
        use_pos_emb: add a Fourier positional encoding before attention. Default: True
        num_heads: number of attention heads in XCA. Default: 8
        qkv_bias: whether the XCA qkv projection has a bias. Default: True
        attn_drop: forwarded to XCA as ``attn_drop``. Default: 0.
        drop: forwarded to XCA as ``proj_drop``. Default: 0.
        scales: number of channel groups to split into. Default: 1
    """
    def __init__(self,
                 dim, drop_path=0.,
                 layer_scale_init_value=1e-6,
                 expan_ratio=4,
                 use_pos_emb=True,
                 num_heads=8,
                 qkv_bias=True,
                 attn_drop=0.,
                 drop=0.,
                 scales=1):
        super().__init__()
        width = max(int(math.ceil(dim / scales)), int(math.floor(dim // scales)))
        self.width = width
        # Number of channel groups that get a depthwise conv.
        if scales == 1:
            self.nums = 1
        else:
            self.nums = scales - 1
        convs = []
        for _ in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, pad_mode="pad", padding=1, group=width, has_bias=True))
        self.convs = nn.CellList(convs)

        self.pos_embd = None
        if use_pos_emb:
            self.pos_embd = PositionalEncodingFourier(dim=dim)
        self.norm_xca = LayerNorm((dim,), epsilon=1e-6)
        self.gamma_xca = Parameter(Tensor(layer_scale_init_value * np.ones(dim), ms.float32),
                                   requires_grad=True) if layer_scale_init_value > 0. else None
        self.xca = XCA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.norm = LayerNorm((dim,), epsilon=1e-6)
        self.pwconv1 = nn.Dense(dim, expan_ratio * dim)
        self.act = nn.GELU(approximate=False)
        self.pwconv2 = nn.Dense(expan_ratio * dim, dim)
        self.gamma = Parameter(Tensor(layer_scale_init_value * np.ones((dim)), ms.float32),
                               requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()

    def construct(self, x: Tensor) -> Tensor:
        """Apply the block to an (N, C, H, W) tensor and return a tensor of the same shape."""
        shortcut = x

        spx = ssplit(x, 1, self.width)
        sp = None
        out = None
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            if i == 0:
                out = sp
            else:
                out = ops.concat((out, sp), 1)
        # A pass-through group only exists when the split produced more than
        # `nums` pieces. With scales == 1 (or when dim divides into exactly
        # `nums` groups) every piece was already convolved, and indexing
        # spx[self.nums] would be out of range.
        if len(spx) > self.nums:
            x = ops.concat((out, spx[self.nums]), 1)
        else:
            x = out
        # XCA
        B, C, H, W = x.shape
        x = ops.reshape(x, (B, C, H * W))
        x = ops.transpose(x, (0, 2, 1))  # (N, C, H*W) -> (N, H*W, C)
        if self.pos_embd is not None:
            pos_encoding = ops.transpose(ops.reshape(self.pos_embd(B, H, W), (B, -1, x.shape[1])), (0, 2, 1))
            x = x + pos_encoding
        attn_out = self.xca(self.norm_xca(x))
        # The layer scale is optional (None when layer_scale_init_value <= 0),
        # so guard it the same way as `self.gamma` below.
        if self.gamma_xca is not None:
            attn_out = self.gamma_xca * attn_out
        x = x + self.drop_path(attn_out)
        x = x.astype(ms.float32)
        x = ops.reshape(x, (B, H, W, C))
        # Inverted Bottleneck
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = ops.transpose(x, (0, 3, 1, 2))  # (N, H, W, C) -> (N, C, H, W)

        x = shortcut + self.drop_path(x)
        return x


class XCA(nn.Cell):
    """Cross-covariance attention.

    Queries and keys are transposed so that the attention map is computed
    between the channels of each head (a head_dim x head_dim matrix) rather
    than between tokens. Queries and keys are L2-normalized along the token
    axis and the map is scaled by a learnable per-head temperature.

    Args:
        dim: number of input and output channels.
        num_heads: number of attention heads. Default: 8
        qkv_bias: whether the qkv projection has a bias. Default: False
        attn_drop: drop rate for the attention map. Default: 0.
        proj_drop: drop rate after the output projection. Default: 0.
    """
    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = Parameter(Tensor(np.ones((num_heads, 1, 1)), ms.float32))

        self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias)
        # NOTE(review): Dropout receives `1 - rate`, i.e. presumably a keep
        # probability - confirm against the MindSpore version in use.
        self.attn_drop = nn.Dropout(1 - attn_drop)
        self.proj = nn.Dense(dim, dim)
        self.proj_drop = nn.Dropout(1 - proj_drop)

    def construct(self, x: Tensor) -> Tensor:
        """Apply attention to a (B, N, C) tensor and return a (B, N, C) tensor."""
        B, N, C = x.shape
        head_dim = C // self.num_heads

        qkv = self.qkv(x)
        qkv = ops.reshape(qkv, (B, N, 3, self.num_heads, head_dim))
        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))  # (3, B, heads, N, head_dim)

        # Put the token axis last: (B, heads, head_dim, N).
        normalize = ops.L2Normalize(-1)
        q = normalize(ops.transpose(qkv[0], (0, 1, 3, 2)))
        k = normalize(ops.transpose(qkv[1], (0, 1, 3, 2)))
        v = ops.transpose(qkv[2], (0, 1, 3, 2))

        # Channel-to-channel attention map: (B, heads, head_dim, head_dim).
        k_t = ops.transpose(k, (0, 1, 3, 2))
        attn = ops.matmul(q, k_t) * self.temperature
        attn = ops.Softmax(-1)(attn)
        attn = self.attn_drop(attn)

        out = ops.matmul(attn, v)               # (B, heads, head_dim, N)
        out = ops.transpose(out, (0, 3, 1, 2))  # (B, N, heads, head_dim)
        out = ops.reshape(out, (B, N, C))

        out = self.proj(out)
        return self.proj_drop(out)


class EdgeNeXt(nn.Cell):
    r"""EdgeNeXt model class, based on
    `"Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision" <https://arxiv.org/abs/2206.10589>`_

    The network is a stem plus three downsampling layers, each followed by a
    stage of residual blocks. Within a stage the last ``global_block[i]``
    blocks are SDTA encoders and the preceding ones are convolutional
    encoders. Features are globally average-pooled, layer-normalized and fed
    to a linear classifier.

    Args:
        in_chans: number of input channels. Default: 3
        num_classes: number of classification classes. Default: 1000
        depths: number of blocks in each stage. Default: [3, 3, 9, 3]
        dims: number of channels in each stage. Default: [24, 48, 88, 168]
        global_block: number of global (SDTA) blocks at the end of each stage.
            Default: [0, 0, 0, 3]
        global_block_type: type of global block per stage, ``'None'`` or
            ``'SDTA'``. Default: ['None', 'None', 'None', 'SDTA']
        drop_path_rate: maximum stochastic-depth rate; per-block rates grow
            linearly from 0 to this value. Default: 0.
        layer_scale_init_value: value of layer scale initialization, passed
            to the convolutional encoders. Default: 1e-6
        head_init_scale: factor applied to the classifier weight and bias
            after initialization. Default: 1.
        expan_ratio: ratio of expansion. Default: 4
        kernel_sizes: depthwise kernel size of each stage. Default: [7, 7, 7, 7]
        heads: number of attention heads of each stage. Default: [8, 8, 8, 8]
        use_pos_embd_xca: use position embedding in the SDTA blocks of each
            stage or not. Default: [False, False, False, False]
        use_pos_embd_global: add a position embedding after the first stage
            or not. Default: False
        d2_scales: number of channel splits in the SDTA blocks of each stage.
            Default: [2, 3, 4, 5]
        **kwargs: accepted for call compatibility; not used by this constructor.
    """

    # The list defaults below are only read, never mutated, so sharing them
    # between calls is harmless.
    def __init__(self,
                 in_chans=3,
                 num_classes=1000,
                 depths=[3, 3, 9, 3],
                 dims=[24, 48, 88, 168],
                 global_block=[0, 0, 0, 3],
                 global_block_type=['None', 'None', 'None', 'SDTA'],
                 drop_path_rate=0.,
                 layer_scale_init_value=1e-6,
                 head_init_scale=1.,
                 expan_ratio=4,
                 kernel_sizes=[7, 7, 7, 7],
                 heads=[8, 8, 8, 8],
                 use_pos_embd_xca=[False, False, False, False],
                 use_pos_embd_global=False,
                 d2_scales=[2, 3, 4, 5],
                 **kwargs):
        super().__init__()
        for g in global_block_type:
            assert g in ['None', 'SDTA']
        if use_pos_embd_global:
            self.pos_embd = PositionalEncodingFourier(dim=dims[0])
        else:
            self.pos_embd = None
        self.downsample_layers = nn.CellList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.SequentialCell(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4, has_bias=True),
            LayerNorm((dims[0],), epsilon=1e-6, norm_axis=1)
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.SequentialCell(
                LayerNorm((dims[i],), epsilon=1e-6, norm_axis=1),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2, has_bias=True),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.CellList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates = list(np.linspace(0, drop_path_rate, sum(depths)))
        cur = 0
        for i in range(4):
            stage_blocks = []
            for j in range(depths[i]):
                # The last `global_block[i]` blocks of the stage are global blocks.
                if j > depths[i] - global_block[i] - 1:
                    if global_block_type[i] == 'SDTA':
                        stage_blocks.append(SDTAEncoder(dim=dims[i], drop_path=dp_rates[cur + j],
                                                        expan_ratio=expan_ratio, scales=d2_scales[i],
                                                        use_pos_emb=use_pos_embd_xca[i], num_heads=heads[i]))
                    else:
                        raise NotImplementedError
                else:
                    stage_blocks.append(ConvEncoder(dim=dims[i], drop_path=dp_rates[cur + j],
                                                    layer_scale_init_value=layer_scale_init_value,
                                                    expan_ratio=expan_ratio, kernel_size=kernel_sizes[i]))

            self.stages.append(nn.SequentialCell(*stage_blocks))
            cur += depths[i]
        self.norm = nn.LayerNorm((dims[-1],), epsilon=1e-6)  # Final norm layer
        self.head = nn.Dense(dims[-1], num_classes)
        # NOTE(review): the Dropout argument is fixed at 1.0. Elsewhere in this
        # file Dropout is given `1 - rate`, so this presumably means "keep
        # everything"; no classifier-dropout rate is read from kwargs.
        self.head_dropout = nn.Dropout(1.0)
        self.head_init_scale = head_init_scale
        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Initialize weights for cells."""
        for _, cell in self.cells_and_names():
            if isinstance(cell, (nn.Dense, nn.Conv2d)):
                cell.weight.set_data(init.initializer(init.TruncatedNormal(sigma=0.02),
                                                      cell.weight.shape, cell.weight.dtype))
                if isinstance(cell, nn.Dense) and cell.bias is not None:
                    cell.bias.set_data(init.initializer(init.Zero(), cell.bias.shape, cell.bias.dtype))
            elif isinstance(cell, (nn.LayerNorm)):
                cell.gamma.set_data(init.initializer(init.One(), cell.gamma.shape, cell.gamma.dtype))
                cell.beta.set_data(init.initializer(init.Zero(), cell.beta.shape, cell.beta.dtype))
        self.head.weight.set_data(self.head.weight * self.head_init_scale)
        self.head.bias.set_data(self.head.bias * self.head_init_scale)

    def forward_features(self, x):
        """Run the backbone and return the pooled, normalized (N, C) feature."""
        x = self.downsample_layers[0](x)
        x = self.stages[0](x)
        if self.pos_embd is not None:
            B, C, H, W = x.shape
            x = x + self.pos_embd(B, H, W)
        for i in range(1, 4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([-2, -1]))  # Global average pooling, (N, C, H, W) -> (N, C)

    def construct(self, x):
        """Return the classifier output for a batch of images."""
        x = self.forward_features(x)
        x = self.head(self.head_dropout(x))
        return x
@register_model
def edgenext_small(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs) -> EdgeNeXt:
    """Get edgenext_small model.

    Refer to the base class `models.EdgeNeXt` for more details.

    Args:
        pretrained: load the pretrained checkpoint described by
            ``default_cfgs['edgenext_small']``. Default: False
        num_classes: number of classification classes. Default: 1000
        in_channels: number of input channels. Default: 3
        **kwargs: extra keyword arguments forwarded to ``EdgeNeXt``.
    """
    default_cfg = default_cfgs['edgenext_small']
    # EdgeNeXt names its input-channel argument `in_chans`. Forward
    # `in_channels` to it so the stem is built for the requested channel
    # count; an explicit `in_chans` already present in kwargs still wins.
    kwargs.setdefault('in_chans', in_channels)
    model = EdgeNeXt(depths=[3, 3, 9, 3],
                     dims=[48, 96, 160, 304],
                     expan_ratio=4,
                     num_classes=num_classes,
                     global_block=[0, 1, 1, 1],
                     global_block_type=['None', 'SDTA', 'SDTA', 'SDTA'],
                     use_pos_embd_xca=[False, True, False, False],
                     kernel_sizes=[3, 5, 7, 9],
                     d2_scales=[2, 2, 3, 4],
                     **kwargs)

    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)

    return model