mindcv.models.pvt 源代码

"""
MindSpore implementation of `PVT`.
Refer to PVT: Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions
"""
import math
from typing import Optional

import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.common import initializer as weight_init
from functools import partial

from .layers.drop_path import DropPath
from .layers.mlp import Mlp
from .layers.identity import Identity
from .utils import load_pretrained
from .registry import register_model

__all__ = [
    'PyramidVisionTransformer',
    'pvt_tiny',
    'pvt_small',
    'pvt_medium',
    'pvt_large',
]


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000,
        'first_conv': 'patch_embed1.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = {
    'pvt_tiny': _cfg(url='https://download.mindspore.cn/toolkits/mindcv/pvt/pvt_tiny_224.ckpt'),
    'pvt_small': _cfg(url='https://download.mindspore.cn/toolkits/mindcv/pvt/pvt_small_224.ckpt'),
    'pvt_medium': _cfg(url='https://download.mindspore.cn/toolkits/mindcv/pvt/pvt_medium_224.ckpt'),
    'pvt_large': _cfg(url='https://download.mindspore.cn/toolkits/mindcv/pvt/pvt_large_224.ckpt'),

}


class Attention(nn.Cell):
    """spatial-reduction attention (SRA)"""
    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 qkv_bias: bool = False,
                 qk_scale: Optional[float] = None,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 sr_ratio: int = 1):
        super(Attention, self).__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Dense(dim, dim, has_bias=qkv_bias)
        self.kv = nn.Dense(dim, dim * 2, has_bias=qkv_bias)
        self.attn_drop = nn.Dropout(1 - attn_drop)
        self.proj = nn.Dense(dim, dim)
        self.proj_drop = nn.Dropout(1 - proj_drop)
        self.qk_batmatmul = ops.BatchMatMul(transpose_b=True)
        self.batmatmul = ops.BatchMatMul()
        self.softmax = nn.Softmax(axis=-1)
        self.reshape = ops.reshape
        self.transpose = ops.transpose

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio, has_bias=True)
            self.norm = nn.LayerNorm([dim])

    def construct(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x)
        q = self.reshape(q, (B, N, self.num_heads, C // self.num_heads))
        q = self.transpose(q, (0, 2, 1, 3))
        if self.sr_ratio > 1:

            x_ = self.reshape(self.transpose(x, (0, 2, 1)), (B, C, H, W))

            x_ = self.transpose(self.reshape(self.sr(x_), (B, C, -1)), (0, 2, 1))
            x_ = self.norm(x_)
            kv = self.kv(x_)

            kv = self.transpose(self.reshape(kv, (B, -1, 2, self.num_heads, C // self.num_heads)), (2, 0, 3, 1, 4))
        else:
            kv = self.kv(x)
            kv = self.transpose(self.reshape(kv, (B, -1, 2, self.num_heads, C // self.num_heads)), (2, 0, 3, 1, 4))
        k, v = kv[0], kv[1]
        attn = self.qk_batmatmul(q, k) * self.scale
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = self.batmatmul(attn, v)
        x = self.reshape(self.transpose(x, (0, 2, 1, 3)), (B, N, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Cell):
    """ Block with spatial-reduction attention (SRA) and feed forward"""
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
        super(Block, self).__init__()
        self.norm1 = norm_layer([dim], epsilon=1e-5)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
        self.norm2 = norm_layer([dim])
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def construct(self, x, H, W):
        x1 = self.norm1(x)
        x1 = self.attn(x1, H, W)
        x = x + self.drop_path(x1)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class PatchEmbed(nn.Cell):
    """ Image to Patch Embedding"""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()

        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)

        self.img_size = img_size
        self.patch_size = patch_size

        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)
        self.norm = nn.LayerNorm([embed_dim], epsilon=1e-5)
        self.reshape = ops.reshape
        self.transpose = ops.transpose

    def construct(self, x):
        B, C, H, W = x.shape

        x = self.proj(x)
        b, c, h, w = x.shape
        x = self.reshape(x, (b, c, h * w))
        x = self.transpose(x, (0, 2, 1))
        x = self.norm(x)
        H, W = H // self.patch_size[0], W // self.patch_size[1]

        return x, (H, W)


[文档]class PyramidVisionTransformer(nn.Cell): r"""Pyramid Vision Transformer model class, based on `"Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions" <https://arxiv.org/abs/2102.12122>`_ Args: img_size(int) : size of a input image. patch_size (int) : size of a single image patch. in_chans (int) : number the channels of the input. Default: 3. num_classes (int) : number of classification classes. Default: 1000. embed_dims (list) : how many hidden dim in each PatchEmbed. num_heads (list) : number of attention head in each stage. mlp_ratios (list): ratios of MLP hidden dims in each stage. qkv_bias(bool) : use bias in attention. qk_scale(float) : Scale multiplied by qk in attention(if not none), otherwise head_dim ** -0.5. drop_rate(float) : The drop rate for each block. Default: 0.0. attn_drop_rate(float) : The drop rate for attention. Default: 0.0. drop_path_rate(float) : The drop rate for drop path. Default: 0.0. norm_layer(nn.Cell) : Norm layer that will be used in blocks. Default: nn.LayerNorm. depths (list) : number of Blocks. sr_ratios(list) : stride and kernel size of each attention. num_stages(int) : number of stage. Default: 4. """ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], num_stages=4): super(PyramidVisionTransformer, self).__init__() self.num_classes = num_classes self.depths = depths self.num_stages = num_stages start = Tensor(0, mindspore.float32) stop = Tensor(drop_path_rate, mindspore.float32) dpr = [float(x) for x in ops.linspace(start, stop, sum(depths))] # stochastic depth decay rule cur = 0 b_list = [] self.pos_embed = [] self.pos_drop = nn.Dropout(1 - drop_rate) for i in range(num_stages): block = nn.CellList( [Block(dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer, sr_ratio=sr_ratios[i]) for j in range(depths[i]) ]) b_list.append(block) cur += depths[0] self.patch_embed1 = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0]) num_patches = self.patch_embed1.num_patches self.pos_embed1 = mindspore.Parameter(ops.zeros((1, num_patches, embed_dims[0]), mindspore.float16)) self.pos_drop1 = nn.Dropout(1 - drop_rate) self.patch_embed2 = PatchEmbed(img_size=img_size // (2 ** (1 + 1)), patch_size=2, in_chans=embed_dims[1 - 1], embed_dim=embed_dims[1]) num_patches = self.patch_embed2.num_patches self.pos_embed2 = mindspore.Parameter(ops.zeros((1, num_patches, embed_dims[1]), mindspore.float16)) self.pos_drop2 = nn.Dropout(1 - drop_rate) self.patch_embed3 = PatchEmbed(img_size=img_size // (2 ** (2 + 1)), patch_size=2, in_chans=embed_dims[2 - 1], embed_dim=embed_dims[2]) num_patches = self.patch_embed3.num_patches self.pos_embed3 = mindspore.Parameter(ops.zeros((1, num_patches, embed_dims[2]), mindspore.float16)) self.pos_drop3 = nn.Dropout(1 - drop_rate) self.patch_embed4 = PatchEmbed(img_size // (2 ** (3 + 1)), patch_size=2, in_chans=embed_dims[3 - 1], embed_dim=embed_dims[3]) num_patches = self.patch_embed4.num_patches + 1 self.pos_embed4 = mindspore.Parameter(ops.zeros((1, num_patches, embed_dims[3]), mindspore.float16)) self.pos_drop4 = nn.Dropout(1 - drop_rate) self.Blocks = nn.CellList(b_list) self.norm = norm_layer([embed_dims[3]]) # cls_token self.cls_token = mindspore.Parameter(ops.zeros((1, 1, embed_dims[3]), mindspore.float32)) # classification head self.head = nn.Dense(embed_dims[3], num_classes) if num_classes > 0 else Identity() self.reshape = ops.reshape self.transpose = ops.transpose self.tile = ops.Tile() self.Concat = ops.Concat(axis=1) self._initialize_weights() def _initialize_weights(self): for _, cell in self.cells_and_names(): if isinstance(cell, nn.Dense): cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(sigma=0.02), cell.weight.shape, cell.weight.dtype)) if isinstance(cell, nn.Dense) and cell.bias is not None: cell.bias.set_data(weight_init.initializer(weight_init.Zero(), cell.bias.shape, cell.bias.dtype)) elif isinstance(cell, nn.LayerNorm): cell.gamma.set_data(weight_init.initializer(weight_init.One(), cell.gamma.shape, cell.gamma.dtype)) cell.beta.set_data(weight_init.initializer(weight_init.Zero(), cell.beta.shape, cell.beta.dtype)) elif isinstance(cell, nn.Conv2d): fan_out = cell.kernel_size[0] * cell.kernel_size[1] * cell.out_channels fan_out //= cell.group cell.weight.set_data(weight_init.initializer(weight_init.Normal(sigma=math.sqrt(2.0 / fan_out)), cell.weight.shape, cell.weight.dtype)) if cell.bias is not None: cell.bias.set_data(weight_init.initializer(weight_init.Zero(), cell.bias.shape, cell.bias.dtype)) def get_classifier(self): return self.head def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes self.head = nn.Dense(self.embed_dim, num_classes) if num_classes > 0 else Identity() def _get_pos_embed(self, pos_embed, ph, pw, H, W): if H * W == self.patch_embed1.num_patches: return pos_embed else: ResizeBilinear = nn.ResizeBilinear() pos_embed = self.transpose(self.reshape(pos_embed, (1, ph, pw, -1)), (0, 3, 1, 2)) pos_embed = ResizeBilinear(pos_embed, (H, W)) pos_embed = self.transpose(self.reshape(pos_embed, (1, -1, H * W)), (0, 2, 1)) return pos_embed def forward_features(self, x): B = x.shape[0] x, (H, W) = self.patch_embed1(x) pos_embed = self.pos_embed1 x = self.pos_drop1(x + pos_embed) for blk in self.Blocks[0]: x = blk(x, H, W) x = self.transpose(self.reshape(x, (B, H, W, -1)), (0, 3, 1, 2)) x, (H, W) = self.patch_embed2(x) ph, pw = self.patch_embed2.H, self.patch_embed2.W pos_embed = self._get_pos_embed(self.pos_embed2, ph, pw, H, W) x = self.pos_drop2(x + pos_embed) for blk in self.Blocks[1]: x = blk(x, H, W) x = self.transpose(self.reshape(x, (B, H, W, -1)), (0, 3, 1, 2)) x, (H, W) = self.patch_embed3(x) ph, pw = self.patch_embed3.H, self.patch_embed3.W pos_embed = self._get_pos_embed(self.pos_embed3, ph, pw, H, W) x = self.pos_drop3(x + pos_embed) for blk in self.Blocks[2]: x = blk(x, H, W) x = self.transpose(self.reshape(x, (B, H, W, -1)), (0, 3, 1, 2)) x, (H, W) = self.patch_embed4(x) cls_tokens = self.tile(self.cls_token, (B, 1, 1)) x = self.Concat((cls_tokens, x)) ph, pw = self.patch_embed4.H, self.patch_embed4.W pos_embed_ = self._get_pos_embed(self.pos_embed4[:, 1:], ph, pw, H, W) pos_embed = self.Concat((self.pos_embed4[:, 0:1], pos_embed_)) x = self.pos_drop4(x + pos_embed) for blk in self.Blocks[3]: x = blk(x, H, W) x = self.norm(x) return x[:, 0] def forward_head(self , x: Tensor)->Tensor: return self.head(x) def construct(self, x): x = self.forward_features(x) x = self.forward_head(x) return x
@register_model def pvt_tiny(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs) -> PyramidVisionTransformer: """Get PVT tiny model Refer to the base class "models.PVT" for more details. """ default_cfg = default_cfgs['pvt_tiny'] model = PyramidVisionTransformer(in_chans=in_channels, num_classes=num_classes, patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, norm_layer=partial(nn.LayerNorm, epsilon=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], **kwargs) if pretrained: load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) return model @register_model def pvt_small(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs) -> PyramidVisionTransformer: """Get PVT small model Refer to the base class "models.PVT" for more details. """ default_cfg = default_cfgs['pvt_small'] model = PyramidVisionTransformer(in_chans=in_channels, num_classes=num_classes, patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, norm_layer=partial(nn.LayerNorm, epsilon=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs) if pretrained: load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) return model @register_model def pvt_medium(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs) -> PyramidVisionTransformer: """Get PVT medium model Refer to the base class "models.PVT" for more details. """ default_cfg = default_cfgs['pvt_medium'] model = PyramidVisionTransformer(in_chans=in_channels, num_classes=num_classes, patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, norm_layer=partial(nn.LayerNorm, epsilon=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], **kwargs) if pretrained: load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) return model @register_model def pvt_large(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs) -> PyramidVisionTransformer: """Get PVT large model Refer to the base class "models.PVT" for more details. """ default_cfg = default_cfgs['pvt_large'] model = PyramidVisionTransformer(in_chans=in_channels, num_classes=num_classes, patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, norm_layer=partial(nn.LayerNorm, epsilon=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], **kwargs) if pretrained: load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) return model