"""
MindSpore implementation of `ConViT`.
Refer to ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases
"""
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.ops import constexpr
from mindspore import Parameter, Tensor
import mindspore.common.initializer as init
from .layers.identity import Identity
from .layers.patch_embed import PatchEmbed
from .layers.drop_path import DropPath
from .layers.mlp import Mlp
from .utils import load_pretrained
from .registry import register_model
# Public API of this module: the model class and its registered factory functions.
__all__ = [
    "ConViT",
    "convit_tiny",
    "convit_tiny_plus",
    "convit_small",
    "convit_small_plus",
    "convit_base",
    "convit_base_plus",
]
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000,
'input_size': (3, 224, 224),
'first_conv': 'patch_embed.proj', 'classifier': 'classifier',
**kwargs
}
# Default configuration per registered variant; every variant currently uses an empty url.
default_cfgs = {
    variant: _cfg(url='')
    for variant in (
        'convit_tiny',
        'convit_tiny_plus',
        'convit_small',
        'convit_small_plus',
        'convit_base',
        'convit_base_plus',
    )
}
@constexpr
def get_rel_indices(num_patches: int = 196) -> Tensor:
    """Build the relative-position table consumed by GPSA.

    The patches are treated as a square grid of side ``sqrt(num_patches)``
    flattened row by row. For every ordered pair of patches (i, j) the table
    stores three values along the last axis:

    * channel 0: difference of the in-row positions, ``(j % side) - (i % side)``
    * channel 1: difference of the row numbers, ``(j // side) - (i // side)``
    * channel 2: squared distance, ``channel0 ** 2 + channel1 ** 2``

    Args:
        num_patches: number of patches; must be a perfect square. Default: 196.

    Returns:
        float32 Tensor of shape (1, num_patches, num_patches, 3).

    Raises:
        ValueError: if ``num_patches`` is not a perfect square.
    """
    img_size = int(round(num_patches ** .5))
    if img_size * img_size != num_patches:
        # Without this check the failure surfaces later as an opaque shape
        # mismatch when the (side**2, side**2) maps are written into the table.
        raise ValueError(
            f"num_patches must be a perfect square, but got {num_patches}.")
    rel_indices = ops.Zeros()((1, num_patches, num_patches, 3), ms.float32)
    # ind[r, c] = c - r for positions along one side of the grid.
    ind = ms.numpy.arange(img_size).view(1, -1) - ms.numpy.arange(img_size).view(-1, 1)
    # Tiling repeats the small map as a whole -> offset inside a row.
    indx = ms.numpy.tile(ind, (img_size, img_size))
    # Repeating each element along both axes -> offset between rows.
    indy_ = ops.repeat_elements(ind, rep=img_size, axis=0)
    indy = ops.repeat_elements(indy_, rep=img_size, axis=1)
    indd = indx ** 2 + indy ** 2
    rel_indices[:, :, :, 2] = ops.expand_dims(indd, 0)
    rel_indices[:, :, :, 1] = ops.expand_dims(indy, 0)
    rel_indices[:, :, :, 0] = ops.expand_dims(indx, 0)
    return rel_indices
class GPSA(nn.Cell):
    """Gated positional self-attention (GPSA) layer.

    Blends a content-based attention map (softmax of the scaled q-k product)
    with a position-based attention map (softmax of a linear projection of the
    relative-position table) through one learnable gate per head.

    Args:
        dim (int): total embedding dimension of the input tokens.
        num_heads (int): number of attention heads; each head works on
            ``dim // num_heads`` channels.
        qkv_bias (bool): whether the q/k/v projections have a bias. Default: False.
        attn_drop (float): dropout rate applied to the attention map. Default: 0.
        proj_drop (float): dropout rate applied after the output projection. Default: 0.
        num_patches (int): number of tokens the relative-position table is
            built for. It is expected to equal the sequence length passed to
            ``construct``, because the positional map is added element-wise to
            the (N, N) content map. Default: 196.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 qkv_bias: bool = False,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 num_patches: int = 196) -> None:
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.q = nn.Dense(in_channels=dim, out_channels=dim, has_bias=qkv_bias)
        self.k = nn.Dense(in_channels=dim, out_channels=dim, has_bias=qkv_bias)
        self.v = nn.Dense(in_channels=dim, out_channels=dim, has_bias=qkv_bias)

        self.attn_drop = nn.Dropout(keep_prob=1.0 - attn_drop)
        self.proj = nn.Dense(in_channels=dim, out_channels=dim)
        # Maps the 3 relative-position channels to one score per head.
        self.pos_proj = nn.Dense(in_channels=3, out_channels=num_heads)
        self.proj_drop = nn.Dropout(keep_prob=1.0 - proj_drop)
        # One gate per head, initialised to 1 (sigmoid is applied in get_attention).
        self.gating_param = Parameter(ops.ones(num_heads, ms.float32))
        self.softmax = nn.Softmax(axis=-1)
        self.batch_matmul = ops.BatchMatMul()
        # Previously fixed to the 196-patch default; now follows `num_patches`.
        self.rel_indices = get_rel_indices(num_patches)

    def construct(self, x: Tensor) -> Tensor:
        """Apply gated attention to ``x`` of shape (B, N, C); returns the same shape."""
        B, N, C = x.shape
        attn = self.get_attention(x)
        v = ops.reshape(self.v(x), (B, N, self.num_heads, C // self.num_heads))
        v = ops.transpose(v, (0, 2, 1, 3))
        x = ops.transpose(self.batch_matmul(attn, v), (0, 2, 1, 3))
        x = ops.reshape(x, (B, N, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def get_attention(self, x: Tensor) -> Tensor:
        """Return the gated attention map of shape (B, num_heads, N, N) for ``x``."""
        B, N, C = x.shape
        q = ops.reshape(self.q(x), (B, N, self.num_heads, C // self.num_heads))
        q = ops.transpose(q, (0, 2, 1, 3))
        k = ops.reshape(self.k(x), (B, N, self.num_heads, C // self.num_heads))
        # Keys are laid out as (B, heads, head_dim, N) so q @ k yields (N, N) scores.
        k = ops.transpose(k, (0, 2, 3, 1))

        # Positional branch: (1, P, P, 3) -> (1, P, P, heads) -> (1, heads, P, P).
        pos_score = self.pos_proj(self.rel_indices)
        pos_score = ops.transpose(pos_score, (0, 3, 1, 2))
        pos_score = self.softmax(pos_score)

        # Content branch: scaled dot-product attention.
        patch_score = self.batch_matmul(q, k)
        patch_score = ops.mul(patch_score, self.scale)
        patch_score = self.softmax(patch_score)

        # Per-head gate in (0, 1), broadcast over batch and both token axes.
        gating = ops.reshape(self.gating_param, (1, -1, 1, 1))
        gating = ops.Sigmoid()(gating)
        attn = (1. - gating) * patch_score + gating * pos_score
        attn = self.attn_drop(attn)
        return attn
class MHSA(nn.Cell):
    """Plain multi-head self-attention layer.

    Args:
        dim (int): total embedding dimension of the input tokens.
        num_heads (int): number of attention heads.
        qkv_bias (bool): whether the q/k/v projections have a bias. Default: False.
        attn_drop (float): dropout rate applied to the attention map. Default: 0.
        proj_drop (float): dropout rate applied after the output projection. Default: 0.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 qkv_bias: bool = False,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.) -> None:
        super().__init__()
        self.num_heads = num_heads
        # Scores are scaled by 1 / sqrt(channels per head).
        self.scale = (dim // num_heads) ** -0.5

        self.q = nn.Dense(in_channels=dim, out_channels=dim, has_bias=qkv_bias)
        self.k = nn.Dense(in_channels=dim, out_channels=dim, has_bias=qkv_bias)
        self.v = nn.Dense(in_channels=dim, out_channels=dim, has_bias=qkv_bias)

        self.attn_drop = nn.Dropout(keep_prob=1.0 - attn_drop)
        self.proj = nn.Dense(in_channels=dim, out_channels=dim)
        self.proj_drop = nn.Dropout(keep_prob=1.0 - proj_drop)
        self.softmax = nn.Softmax(axis=-1)
        self.batch_matmul = ops.BatchMatMul()

    def construct(self, x: Tensor) -> Tensor:
        """Apply self-attention to ``x`` of shape (B, N, C); returns the same shape."""
        batch, tokens, channels = x.shape
        per_head = (batch, tokens, self.num_heads, channels // self.num_heads)

        # (B, N, C) -> (B, heads, N, head_dim); keys end up as (B, heads, head_dim, N).
        query = ops.transpose(ops.reshape(self.q(x), per_head), (0, 2, 1, 3))
        key = ops.transpose(ops.reshape(self.k(x), per_head), (0, 2, 3, 1))
        value = ops.transpose(ops.reshape(self.v(x), per_head), (0, 2, 1, 3))

        scores = ops.mul(self.batch_matmul(query, key), self.scale)
        weights = self.attn_drop(self.softmax(scores))

        # Merge the heads back: (B, heads, N, head_dim) -> (B, N, C).
        merged = ops.transpose(self.batch_matmul(weights, value), (0, 2, 1, 3))
        merged = ops.reshape(merged, (batch, tokens, channels))
        return self.proj_drop(self.proj(merged))
class Block(nn.Cell):
    """Basic module of ConViT: pre-norm attention and pre-norm MLP, each with a residual.

    Args:
        dim (int): embedding dimension of the tokens.
        num_heads (int): number of attention heads.
        mlp_ratio (float): hidden size of the MLP as a multiple of ``dim``.
        qkv_bias (bool): whether the q/k/v projections have a bias. Default: False.
        drop (float): dropout rate for the attention projection and the MLP. Default: 0.
        attn_drop (float): dropout rate for the attention map. Default: 0.
        drop_path (float): drop-path rate; an Identity is used when it is not > 0. Default: 0.
        use_gpsa (bool): use GPSA when True, MHSA otherwise. Default: True.
        **kwargs: forwarded unchanged to the chosen attention layer.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 mlp_ratio: float,
                 qkv_bias: bool = False,
                 drop: float = 0.,
                 attn_drop: float = 0.,
                 drop_path: float = 0.,
                 use_gpsa: bool = True,
                 **kwargs) -> None:
        super().__init__()

        self.norm1 = nn.LayerNorm((dim,))
        attention_cls = GPSA if use_gpsa else MHSA
        self.attn = attention_cls(dim, num_heads=num_heads, qkv_bias=qkv_bias,
                                  attn_drop=attn_drop, proj_drop=drop, **kwargs)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = Identity()
        self.norm2 = nn.LayerNorm((dim,))
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=nn.GELU, drop=drop)

    def construct(self, x: Tensor) -> Tensor:
        """Run the attention and MLP sub-layers with residual connections."""
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        x = x + self.drop_path(mlp_branch)
        return x
class ConViT(nn.Cell):
    r"""ConViT model class, based on
    '"Improving Vision Transformers with Soft Convolutional Inductive Biases"
    <https://arxiv.org/pdf/2103.10697.pdf>'

    Args:
        in_channels (int): number of channels of the input. Default: 3.
        num_classes (int): number of classification classes; when it is not
            greater than 0 the classifier is replaced by an Identity. Default: 1000.
        image_size (int): input image size. Default: 224.
        patch_size (int): image patch size. Default: 16.
        embed_dim (int): total embedding dimension, split across all heads
            (each head gets ``embed_dim // num_heads`` channels). Default: 48.
        num_heads (int): number of heads. Default: 12.
        drop_rate (float): dropout rate. Default: 0.
        drop_path_rate (float): maximum drop path rate; the per-block rate
            grows linearly from 0 to this value. Default: 0.1.
        depth (int): number of blocks. Default: 12.
        mlp_ratio (float): ratio of hidden features in Mlp. Default: 4.
        qkv_bias (bool): have bias in qkv layers or not. Default: False.
        attn_drop_rate (float): attention layers dropout rate. Default: 0.
        local_up_to_layer (int): number of leading blocks that use GPSA; the
            remaining blocks use MHSA. Default: 10.
        use_pos_embed (bool): add a learnable position embedding to the patch
            tokens. Default: True.
        locality_strength (float): determines how focused each head is around
            its attention center. Default: 1.

    Note:
        The GPSA layers created here use a relative-position table built for
        196 patches (a 14 x 14 grid), which matches the default
        ``image_size``/``patch_size``. Other patch counts look unsupported by
        the GPSA blocks - confirm before changing those arguments.
    """

    def __init__(self,
                 in_channels: int = 3,
                 num_classes: int = 1000,
                 image_size: int = 224,
                 patch_size: int = 16,
                 embed_dim: int = 48,
                 num_heads: int = 12,
                 drop_rate: float = 0.,
                 drop_path_rate: float = 0.1,
                 depth: int = 12,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = False,
                 attn_drop_rate: float = 0.,
                 local_up_to_layer: int = 10,
                 use_pos_embed: bool = True,
                 locality_strength: float = 1.) -> None:
        super().__init__()

        self.local_up_to_layer = local_up_to_layer
        self.use_pos_embed = use_pos_embed
        self.num_heads = num_heads
        self.locality_strength = locality_strength
        self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(
            image_size=image_size, patch_size=patch_size, in_chans=in_channels, embed_dim=embed_dim)
        self.num_patches = self.patch_embed.num_patches

        self.cls_token = Parameter(ops.Zeros()((1, 1, embed_dim), ms.float32))
        self.pos_drop = nn.Dropout(keep_prob=1.0 - drop_rate)

        if self.use_pos_embed:
            self.pos_embed = Parameter(ops.Zeros()((1, self.num_patches, embed_dim), ms.float32))
            self.pos_embed.set_data(init.initializer(init.TruncatedNormal(sigma=.02), self.pos_embed.data.shape))

        # Stochastic-depth rate increases linearly with the block index.
        dpr = [x.item() for x in np.linspace(0, drop_path_rate, depth)]
        # The first `local_up_to_layer` blocks use GPSA, the rest use MHSA.
        self.blocks = nn.CellList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
                use_gpsa=i < local_up_to_layer)
            for i in range(depth)])
        self.norm = nn.LayerNorm((embed_dim,))

        self.classifier = nn.Dense(in_channels=embed_dim, out_channels=num_classes) if num_classes > 0 else Identity()
        self.cls_token.set_data(init.initializer(init.TruncatedNormal(sigma=.02), self.cls_token.data.shape))
        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Initialise Dense/LayerNorm cells, then apply the local init to the GPSA blocks."""
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Dense):
                cell.weight.set_data(init.initializer(init.TruncatedNormal(sigma=.02), cell.weight.data.shape))
                if cell.bias is not None:
                    cell.bias.set_data(init.initializer(init.Constant(0), cell.bias.shape))
            elif isinstance(cell, nn.LayerNorm):
                cell.gamma.set_data(init.initializer(init.Constant(1), cell.gamma.shape))
                cell.beta.set_data(init.initializer(init.Constant(0), cell.beta.shape))

        # local init
        # These values do not depend on the block, so compute them once.
        locality_distance = 1
        kernel_size = int(self.num_heads ** .5)
        center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2
        # Only blocks that exist can be GPSA blocks; clamping avoids an
        # IndexError when local_up_to_layer exceeds the number of blocks.
        num_local_blocks = min(self.local_up_to_layer, len(self.blocks))
        for i in range(num_local_blocks):
            attn = self.blocks[i].attn
            # The value projection starts as the identity matrix.
            attn.v.weight.set_data(ops.eye(self.embed_dim, self.embed_dim, ms.float32), slice_shape=True)
            pos_weight_data = attn.pos_proj.weight.data
            # Heads are arranged on a kernel_size x kernel_size grid; each one gets
            # weights on the (dx, dy, squared distance) channels. Heads beyond
            # kernel_size ** 2 keep the generic initialisation from above.
            for h1 in range(kernel_size):
                for h2 in range(kernel_size):
                    position = h1 + kernel_size * h2
                    pos_weight_data[position, 2] = -1
                    pos_weight_data[position, 1] = 2 * (h1 - center) * locality_distance
                    pos_weight_data[position, 0] = 2 * (h2 - center) * locality_distance
            pos_weight_data = pos_weight_data * self.locality_strength
            attn.pos_proj.weight.set_data(pos_weight_data)

    def forward_features(self, x: Tensor) -> Tensor:
        """Embed the image, run all blocks and return the normalised token at index 0."""
        x = self.patch_embed(x)
        if self.use_pos_embed:
            x = x + self.pos_embed
        x = self.pos_drop(x)
        cls_tokens = ops.tile(self.cls_token, (x.shape[0], 1, 1))
        for u, blk in enumerate(self.blocks):
            # The class token is prepended right before the first MHSA block,
            # so the GPSA blocks only ever see the patch tokens.
            # NOTE: when local_up_to_layer >= depth this branch never runs and
            # x[:, 0] below is the first patch token instead of a class token.
            if u == self.local_up_to_layer:
                x = ops.Cast()(x, cls_tokens.dtype)
                x = ops.concat((cls_tokens, x), 1)
            x = blk(x)
        x = self.norm(x)
        return x[:, 0]

    def forward_head(self, x: Tensor) -> Tensor:
        """Apply the classifier to the pooled features."""
        x = self.classifier(x)
        return x

    def construct(self, x: Tensor) -> Tensor:
        """Run feature extraction followed by the classification head."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
@register_model
def convit_tiny(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs) -> ConViT:
    """Get ConViT tiny model (4 heads, embedding dimension 192).

    Refer to the base class "models.ConViT" for more details.

    Args:
        pretrained: load weights described by ``default_cfgs['convit_tiny']``.
        num_classes: number of classification classes.
        in_channels: number of input channels.
        **kwargs: forwarded to the ``ConViT`` constructor.
    """
    default_cfg = default_cfgs['convit_tiny']
    model = ConViT(in_channels=in_channels, num_classes=num_classes,
                   num_heads=4, embed_dim=192, **kwargs)
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model
@register_model
def convit_tiny_plus(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs) -> ConViT:
    """Get ConViT tiny+ model (4 heads, embedding dimension 256).

    Refer to the base class "models.ConViT" for more details.

    Args:
        pretrained: load weights described by ``default_cfgs['convit_tiny_plus']``.
        num_classes: number of classification classes.
        in_channels: number of input channels.
        **kwargs: forwarded to the ``ConViT`` constructor.
    """
    default_cfg = default_cfgs['convit_tiny_plus']
    model = ConViT(in_channels=in_channels, num_classes=num_classes,
                   num_heads=4, embed_dim=256, **kwargs)
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model
@register_model
def convit_small(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs) -> ConViT:
    """Get ConViT small model (9 heads, embedding dimension 432).

    Refer to the base class "models.ConViT" for more details.

    Args:
        pretrained: load weights described by ``default_cfgs['convit_small']``.
        num_classes: number of classification classes.
        in_channels: number of input channels.
        **kwargs: forwarded to the ``ConViT`` constructor.
    """
    default_cfg = default_cfgs['convit_small']
    model = ConViT(in_channels=in_channels, num_classes=num_classes,
                   num_heads=9, embed_dim=432, **kwargs)
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model
@register_model
def convit_small_plus(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs) -> ConViT:
    """Get ConViT small+ model (9 heads, embedding dimension 576).

    Refer to the base class "models.ConViT" for more details.

    Args:
        pretrained: load weights described by ``default_cfgs['convit_small_plus']``.
        num_classes: number of classification classes.
        in_channels: number of input channels.
        **kwargs: forwarded to the ``ConViT`` constructor.
    """
    default_cfg = default_cfgs['convit_small_plus']
    model = ConViT(in_channels=in_channels, num_classes=num_classes,
                   num_heads=9, embed_dim=576, **kwargs)
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model
@register_model
def convit_base(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs) -> ConViT:
    """Get ConViT base model (16 heads, embedding dimension 768).

    Refer to the base class "models.ConViT" for more details.

    Args:
        pretrained: load weights described by ``default_cfgs['convit_base']``.
        num_classes: number of classification classes.
        in_channels: number of input channels.
        **kwargs: forwarded to the ``ConViT`` constructor.
    """
    default_cfg = default_cfgs['convit_base']
    model = ConViT(in_channels=in_channels, num_classes=num_classes,
                   num_heads=16, embed_dim=768, **kwargs)
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model
@register_model
def convit_base_plus(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs) -> ConViT:
    """Get ConViT base+ model (16 heads, embedding dimension 1024).

    Refer to the base class "models.ConViT" for more details.

    Args:
        pretrained: load weights described by ``default_cfgs['convit_base_plus']``.
        num_classes: number of classification classes.
        in_channels: number of input channels.
        **kwargs: forwarded to the ``ConViT`` constructor.
    """
    default_cfg = default_cfgs['convit_base_plus']
    model = ConViT(in_channels=in_channels, num_classes=num_classes,
                   num_heads=16, embed_dim=1024, **kwargs)
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model