"""
MindSpore implementation of `PVTv2`.
Refer to the paper "PVTv2: Improved Baselines with Pyramid Vision Transformer".
"""
import math
from functools import partial
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.common import initializer as weight_init
from .layers import DropPath
from .layers import Identity
from .utils import load_pretrained
from .registry import register_model
# Public names exported by this module: the model class and its six factory functions.
__all__ = [
'PyramidVisionTransformerV2',
'pvt_v2_b0',
'pvt_v2_b1',
'pvt_v2_b2',
'pvt_v2_b3',
'pvt_v2_b4',
'pvt_v2_b5'
]
def _cfg(url='', **kwargs):
    """Assemble a default configuration dictionary for a model variant.

    Args:
        url: Value stored under the ``'url'`` key.
        **kwargs: Extra entries merged in last; a key that matches one of the
            defaults replaces that default's value.

    Returns:
        dict: The default entries (``url``, ``num_classes``, ``first_conv``,
        ``classifier``) updated with ``kwargs``.
    """
    config = dict(
        url=url,
        num_classes=1000,
        first_conv='patch_embed_list.0.proj',
        classifier='head',
    )
    # Caller-supplied entries take precedence over the defaults above.
    config.update(kwargs)
    return config
# Default configuration per registered variant. Every entry is its own dict
# produced by ``_cfg`` with an empty ``url``.
default_cfgs = {
    variant: _cfg(url='')
    for variant in (
        'pvt_v2_b0',
        'pvt_v2_b1',
        'pvt_v2_b2',
        'pvt_v2_b3',
        'pvt_v2_b4',
        'pvt_v2_b5',
    )
}
class DWConv(nn.Cell):
    """3x3 depthwise convolution applied to a flattened token sequence.

    The input ``(B, N, C)`` tokens are laid out as a ``(B, C, H, W)`` map,
    convolved, and flattened back to ``(B, H * W, C)``.

    Args:
        dim: Number of channels; also used as the convolution group count.
    """

    def __init__(self, dim=768):
        super().__init__()
        # group == channel count, so every channel is convolved on its own.
        self.dwconv = nn.Conv2d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=3,
            stride=1,
            has_bias=True,
            group=dim,
        )

    def construct(self, x, H, W):
        """Run the convolution over tokens ``x`` viewed as an ``H`` x ``W`` grid."""
        batch, _, channels = x.shape
        # (B, N, C) -> (B, C, N) -> (B, C, H, W)
        feature_map = ops.transpose(x, (0, 2, 1)).view((batch, channels, H, W))
        feature_map = self.dwconv(feature_map)
        # Back to a token sequence: (B, C, H * W) -> (B, H * W, C)
        tokens = feature_map.view((batch, channels, H * W))
        return ops.transpose(tokens, (0, 2, 1))
class Mlp(nn.Cell):
    """Feed-forward block with a depthwise convolution between two dense layers.

    Args:
        in_features: Size of the input channel dimension.
        hidden_features: Width of the hidden layer; falls back to ``in_features``
            when falsy.
        out_features: Size of the output channel dimension; falls back to
            ``in_features`` when falsy.
        act_layer: Callable returning the activation cell applied after the
            depthwise convolution.
        drop: Dropout amount; the dropout cell is built with ``1 - drop``.
        linear: When true, a ReLU is applied right after the first dense layer.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0., linear=False):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features

        self.fc1 = nn.Dense(in_features, hidden_width)
        self.dwconv = DWConv(hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Dense(hidden_width, output_width)
        # NOTE(review): nn.Dropout is given 1 - drop, i.e. presumably a keep
        # probability; confirm against the installed MindSpore version.
        self.drop = nn.Dropout(1 - drop)
        self.linear = linear
        if self.linear:
            self.relu = nn.ReLU()

    def construct(self, x, H, W):
        """Transform tokens ``x``; ``H`` and ``W`` give the spatial grid for the conv."""
        hidden = self.fc1(x)
        if self.linear:
            hidden = self.relu(hidden)
        hidden = self.act(self.dwconv(hidden, H, W))
        hidden = self.drop(hidden)
        projected = self.fc2(hidden)
        # The same dropout cell is reused after the output projection.
        return self.drop(projected)
class Attention(nn.Cell):
    """Multi-head attention whose keys/values may be spatially reduced.

    Three modes are selected by ``linear`` and ``sr_ratio``:

    * ``linear=False, sr_ratio > 1``: keys/values come from the input after a
      strided convolution (kernel == stride == ``sr_ratio``) and a LayerNorm.
    * ``linear=False, sr_ratio <= 1``: keys/values come from the input as-is.
    * ``linear=True``: the input is pooled to a 7x7 map, passed through a 1x1
      convolution, LayerNorm and GELU before the key/value projection.

    Args:
        dim: Channel dimension of the tokens; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query and key/value projections have a bias.
        qk_scale: Scale applied to the attention logits; when falsy,
            ``head_dim ** -0.5`` is used.
        attn_drop: Dropout amount for the attention weights (cell built with
            ``1 - attn_drop``).
        proj_drop: Dropout amount after the output projection (cell built with
            ``1 - proj_drop``).
        sr_ratio: Spatial reduction ratio used when ``linear`` is false.
        linear: Enables the pooled ("linear") key/value path.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1,
                 linear=False):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Dense(dim, dim, has_bias=qkv_bias)
        # Keys and values are produced by one projection and split afterwards.
        self.kv = nn.Dense(dim, dim * 2, has_bias=qkv_bias)
        self.attn_drop = nn.Dropout(1 - attn_drop)
        self.proj = nn.Dense(dim, dim)
        self.proj_drop = nn.Dropout(1 - proj_drop)
        self.qk_batmatmul = ops.BatchMatMul(transpose_b=True)
        self.batmatmul = ops.BatchMatMul()
        self.softmax = nn.Softmax(axis=-1)

        self.linear = linear
        self.sr_ratio = sr_ratio
        if not linear:
            if sr_ratio > 1:
                self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio, has_bias=True)
                self.norm = nn.LayerNorm([dim])
        else:
            self.pool = nn.AdaptiveAvgPool2d(7)
            self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1, has_bias=True)
            self.norm = nn.LayerNorm([dim])
            self.act = nn.GELU()

    def construct(self, x, H, W):
        """Attend over tokens ``x`` of shape ``(B, N, C)`` laid out on an ``H`` x ``W`` grid.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        head_dim = C // self.num_heads

        # Queries: (B, N, C) -> (B, heads, N, head_dim)
        q = ops.reshape(self.q(x), (B, N, self.num_heads, head_dim))
        q = ops.transpose(q, (0, 2, 1, 3))

        # Pick the token sequence the keys/values are computed from.
        if self.linear:
            x_ = ops.reshape(ops.transpose(x, (0, 2, 1)), (B, C, H, W))
            x_ = self.sr(self.pool(x_))
            # Flatten the spatial axes first, then move channels last so that
            # LayerNorm and the kv projection see (B, L, C). The previous code
            # applied the 3-axis transpose to the 4-D map before reshaping.
            x_ = ops.transpose(ops.reshape(x_, (B, C, -1)), (0, 2, 1))
            x_ = self.norm(x_)
            x_ = self.act(x_)
        elif self.sr_ratio > 1:
            x_ = ops.reshape(ops.transpose(x, (0, 2, 1)), (B, C, H, W))
            x_ = self.sr(x_)
            x_ = ops.transpose(ops.reshape(x_, (B, C, -1)), (0, 2, 1))
            x_ = self.norm(x_)
        else:
            x_ = x

        # (B, L, 2C) -> (B, L, 2, heads, head_dim) -> (2, B, heads, L, head_dim)
        kv = ops.reshape(self.kv(x_), (B, -1, 2, self.num_heads, head_dim))
        kv = ops.transpose(kv, (2, 0, 3, 1, 4))
        k, v = kv[0], kv[1]

        attn = self.qk_batmatmul(q, k) * self.scale
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        # Merge heads back: (B, heads, N, head_dim) -> (B, N, C)
        x = self.batmatmul(attn, v)
        x = ops.reshape(ops.transpose(x, (0, 2, 1, 3)), (B, N, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Block(nn.Cell):
    """Pre-norm transformer block: attention and MLP, each on a residual path.

    Args:
        dim: Channel dimension of the tokens.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Forwarded to ``Attention``.
        qk_scale: Forwarded to ``Attention``.
        drop: Dropout amount for the attention output projection and the MLP.
        attn_drop: Dropout amount for the attention weights.
        drop_path: Stochastic-depth rate; ``Identity`` is used when not positive.
        act_layer: Activation factory forwarded to ``Mlp``.
        norm_layer: Factory called with ``[dim]`` to build both norm cells.
        sr_ratio: Forwarded to ``Attention``.
        linear: Forwarded to both ``Attention`` and ``Mlp``.
        block_id: Accepted but not used by this class.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False, block_id=0):
        super().__init__()
        self.norm1 = norm_layer([dim])
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio,
            linear=linear,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = Identity()
        self.norm2 = norm_layer([dim])
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
            linear=linear,
        )

    def construct(self, x, H, W):
        """Apply both residual branches to tokens ``x`` on an ``H`` x ``W`` grid."""
        attended = self.attn(self.norm1(x), H, W)
        x = x + self.drop_path(attended)
        transformed = self.mlp(self.norm2(x), H, W)
        return x + self.drop_path(transformed)
class OverlapPatchEmbed(nn.Cell):
    """Convolutional patch embedding whose kernel is larger than its stride.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of the (square) convolution kernel; must exceed
            ``stride``.
        stride: Convolution stride.
        in_chans: Number of input channels.
        embed_dim: Number of output channels per token.
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        image_hw = (img_size, img_size)
        kernel_hw = (patch_size, patch_size)
        assert max(kernel_hw) > stride, "Set larger patch_size than stride"

        self.img_size = image_hw
        self.patch_size = kernel_hw
        # Nominal grid size derived from the configured image size and stride.
        self.H = image_hw[0] // stride
        self.W = image_hw[1] // stride
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_hw, stride=stride, has_bias=True)
        self.norm = nn.LayerNorm([embed_dim])

    def construct(self, x):
        """Embed image batch ``x``.

        Returns:
            A tuple ``(tokens, H, W)`` where ``tokens`` has shape
            ``(B, H * W, C)`` and ``H``/``W`` are the spatial sizes of the
            convolution output.
        """
        feature_map = self.proj(x)
        batch, channels, height, width = feature_map.shape
        # (B, C, H, W) -> (B, C, H * W) -> (B, H * W, C)
        tokens = ops.reshape(feature_map, (batch, channels, height * width))
        tokens = self.norm(ops.transpose(tokens, (0, 2, 1)))
        return tokens, height, width
@register_model
def pvt_v2_b0(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3,
              **kwargs) -> PyramidVisionTransformerV2:
    """Build the PVTv2-b0 variant.

    Args:
        pretrained: When true, ``load_pretrained`` is invoked on the new model
            with the ``'pvt_v2_b0'`` entry of ``default_cfgs``.
        num_classes: Passed to the model constructor and to ``load_pretrained``.
        in_channels: Passed to the constructor as ``in_chans`` and to
            ``load_pretrained`` as ``in_channels``.
        **kwargs: Additional keyword arguments for ``PyramidVisionTransformerV2``.

    Returns:
        The constructed ``PyramidVisionTransformerV2`` instance.
    """
    cfg = default_cfgs['pvt_v2_b0']
    arch = dict(
        patch_size=4,
        embed_dims=[32, 64, 160, 256],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[8, 8, 4, 4],
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6),
        depths=[2, 2, 2, 2],
        sr_ratios=[8, 4, 2, 1],
    )
    net = PyramidVisionTransformerV2(in_chans=in_channels, num_classes=num_classes, **arch, **kwargs)
    if pretrained:
        load_pretrained(net, cfg, num_classes=num_classes, in_channels=in_channels)
    return net
@register_model
def pvt_v2_b1(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3,
              **kwargs) -> PyramidVisionTransformerV2:
    """Build the PVTv2-b1 variant.

    Args:
        pretrained: When true, ``load_pretrained`` is invoked on the new model
            with the ``'pvt_v2_b1'`` entry of ``default_cfgs``.
        num_classes: Passed to the model constructor and to ``load_pretrained``.
        in_channels: Passed to the constructor as ``in_chans`` and to
            ``load_pretrained`` as ``in_channels``.
        **kwargs: Additional keyword arguments for ``PyramidVisionTransformerV2``.

    Returns:
        The constructed ``PyramidVisionTransformerV2`` instance.
    """
    cfg = default_cfgs['pvt_v2_b1']
    arch = dict(
        patch_size=4,
        embed_dims=[64, 128, 320, 512],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[8, 8, 4, 4],
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6),
        depths=[2, 2, 2, 2],
        sr_ratios=[8, 4, 2, 1],
    )
    net = PyramidVisionTransformerV2(in_chans=in_channels, num_classes=num_classes, **arch, **kwargs)
    if pretrained:
        load_pretrained(net, cfg, num_classes=num_classes, in_channels=in_channels)
    return net
@register_model
def pvt_v2_b2(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3,
              **kwargs) -> PyramidVisionTransformerV2:
    """Build the PVTv2-b2 variant.

    Args:
        pretrained: When true, ``load_pretrained`` is invoked on the new model
            with the ``'pvt_v2_b2'`` entry of ``default_cfgs``.
        num_classes: Passed to the model constructor and to ``load_pretrained``.
        in_channels: Passed to the constructor as ``in_chans`` and to
            ``load_pretrained`` as ``in_channels``.
        **kwargs: Additional keyword arguments for ``PyramidVisionTransformerV2``.

    Returns:
        The constructed ``PyramidVisionTransformerV2`` instance.
    """
    cfg = default_cfgs['pvt_v2_b2']
    arch = dict(
        patch_size=4,
        embed_dims=[64, 128, 320, 512],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[8, 8, 4, 4],
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6),
        depths=[3, 4, 6, 3],
        sr_ratios=[8, 4, 2, 1],
    )
    net = PyramidVisionTransformerV2(in_chans=in_channels, num_classes=num_classes, **arch, **kwargs)
    if pretrained:
        load_pretrained(net, cfg, num_classes=num_classes, in_channels=in_channels)
    return net
@register_model
def pvt_v2_b3(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3,
              **kwargs) -> PyramidVisionTransformerV2:
    """Build the PVTv2-b3 variant.

    Args:
        pretrained: When true, ``load_pretrained`` is invoked on the new model
            with the ``'pvt_v2_b3'`` entry of ``default_cfgs``.
        num_classes: Passed to the model constructor and to ``load_pretrained``.
        in_channels: Passed to the constructor as ``in_chans`` and to
            ``load_pretrained`` as ``in_channels``.
        **kwargs: Additional keyword arguments for ``PyramidVisionTransformerV2``.

    Returns:
        The constructed ``PyramidVisionTransformerV2`` instance.
    """
    cfg = default_cfgs['pvt_v2_b3']
    arch = dict(
        patch_size=4,
        embed_dims=[64, 128, 320, 512],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[8, 8, 4, 4],
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6),
        depths=[3, 4, 18, 3],
        sr_ratios=[8, 4, 2, 1],
    )
    net = PyramidVisionTransformerV2(in_chans=in_channels, num_classes=num_classes, **arch, **kwargs)
    if pretrained:
        load_pretrained(net, cfg, num_classes=num_classes, in_channels=in_channels)
    return net
@register_model
def pvt_v2_b4(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3,
              **kwargs) -> PyramidVisionTransformerV2:
    """Build the PVTv2-b4 variant.

    Args:
        pretrained: When true, ``load_pretrained`` is invoked on the new model
            with the ``'pvt_v2_b4'`` entry of ``default_cfgs``.
        num_classes: Passed to the model constructor and to ``load_pretrained``.
        in_channels: Passed to the constructor as ``in_chans`` and to
            ``load_pretrained`` as ``in_channels``.
        **kwargs: Additional keyword arguments for ``PyramidVisionTransformerV2``.

    Returns:
        The constructed ``PyramidVisionTransformerV2`` instance.
    """
    cfg = default_cfgs['pvt_v2_b4']
    arch = dict(
        patch_size=4,
        embed_dims=[64, 128, 320, 512],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[8, 8, 4, 4],
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6),
        depths=[3, 8, 27, 3],
        sr_ratios=[8, 4, 2, 1],
    )
    net = PyramidVisionTransformerV2(in_chans=in_channels, num_classes=num_classes, **arch, **kwargs)
    if pretrained:
        load_pretrained(net, cfg, num_classes=num_classes, in_channels=in_channels)
    return net
@register_model
def pvt_v2_b5(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3,
              **kwargs) -> PyramidVisionTransformerV2:
    """Build the PVTv2-b5 variant.

    Args:
        pretrained: When true, ``load_pretrained`` is invoked on the new model
            with the ``'pvt_v2_b5'`` entry of ``default_cfgs``.
        num_classes: Passed to the model constructor and to ``load_pretrained``.
        in_channels: Passed to the constructor as ``in_chans`` and to
            ``load_pretrained`` as ``in_channels``.
        **kwargs: Additional keyword arguments for ``PyramidVisionTransformerV2``.

    Returns:
        The constructed ``PyramidVisionTransformerV2`` instance.
    """
    cfg = default_cfgs['pvt_v2_b5']
    arch = dict(
        patch_size=4,
        embed_dims=[64, 128, 320, 512],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6),
        depths=[3, 6, 40, 3],
        sr_ratios=[8, 4, 2, 1],
    )
    net = PyramidVisionTransformerV2(in_chans=in_channels, num_classes=num_classes, **arch, **kwargs)
    if pretrained:
        load_pretrained(net, cfg, num_classes=num_classes, in_channels=in_channels)
    return net