first commit

This commit is contained in:
YYJ-aaaa 2024-11-13 11:19:16 +08:00
parent 8a151fcb32
commit 5de8f3e15b
7 changed files with 2777 additions and 0 deletions

63
LICENSE Normal file
View File

@ -0,0 +1,63 @@
Copyright (c) 2024, NVIDIA Corporation. All rights reserved.
Nvidia Source Code License-NC
1. Definitions
“Licensor” means any person or entity that distributes its Work.
“Work” means (a) the original work of authorship made available under this license, which may include software, documentation,
or other files, and (b) any additions to or derivative works thereof that are made available under this license.
The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S.
copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that
remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing
the applicability of this license to the Work, or (b) a copy of this license.
2. License Grant
2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual,
worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly
display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
3. Limitations
3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a
complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent,
trademark, or attribution notices that are present in the Work.
3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution
of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3
applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms.
Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply
to the Work itself.
3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially.
Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially.
As used herein, “non-commercially” means for research or evaluation purposes only.
3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim
or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under
this license from such Licensor (including the grant in Section 2.1) will terminate immediately.
3.5 Trademarks. This license does not grant any rights to use any Licensors or its affiliates names, logos, or trademarks,
except as necessary to reproduce the notices described in this license.
3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1)
will terminate immediately.
4. Disclaimer of Warranty.
THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES
OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING
ANY ACTIVITIES UNDER THIS LICENSE.
5. Limitation of Liability.
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT,
OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL
DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL,
BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR
HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.

1051
config.json Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,28 @@
from transformers import PretrainedConfig
class MambaVisionConfig(PretrainedConfig):
model_type = "mambavision"
def __init__(
self,
depths=[3, 3, 12, 5],
num_heads=[4, 8, 16, 32],
window_size=[8, 8, 14, 7],
dim=196,
in_dim=64,
mlp_ratio=4,
drop_path_rate=0.3,
layer_scale=1e-5,
layer_scale_conv=None,
**kwargs,
):
self.depths = depths
self.num_heads = num_heads
self.window_size = window_size
self.dim = dim
self.in_dim = in_dim
self.mlp_ratio = mlp_ratio
self.drop_path_rate = drop_path_rate
self.layer_scale=layer_scale
self.layer_scale_conv=layer_scale_conv
super().__init__(**kwargs)

865
mamba_vision.py Normal file
View File

@ -0,0 +1,865 @@
#!/usr/bin/env python3
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import torch
import torch.nn as nn
from timm.models.registry import register_model
import math
from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
from timm.models._builder import resolve_pretrained_cfg
try:
from timm.models._builder import _update_default_kwargs as update_args
except:
from timm.models._builder import _update_default_model_kwargs as update_args
from timm.models.vision_transformer import Mlp, PatchEmbed
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
import torch.nn.functional as F
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from einops import rearrange, repeat
from pathlib import Path
from huggingface_hub import PyTorchModelHubMixin
def _cfg(url='', **kwargs):
return {'url': url,
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': None,
'crop_pct': 0.875,
'interpolation': 'bicubic',
'fixed_input_size': True,
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
**kwargs
}
default_cfgs = {
'mamba_vision_T': _cfg(url='https://huggingface.co/nvidia/MambaVision-T-1K/resolve/main/mambavision_tiny_1k.pth.tar',
crop_pct=1.0,
input_size=(3, 224, 224),
crop_mode='center'),
'mamba_vision_T2': _cfg(url='https://huggingface.co/nvidia/MambaVision-T2-1K/resolve/main/mambavision_tiny2_1k.pth.tar',
crop_pct=0.98,
input_size=(3, 224, 224),
crop_mode='center'),
'mamba_vision_S': _cfg(url='https://huggingface.co/nvidia/MambaVision-S-1K/resolve/main/mambavision_small_1k.pth.tar',
crop_pct=0.93,
input_size=(3, 224, 224),
crop_mode='center'),
'mamba_vision_B': _cfg(url='https://huggingface.co/nvidia/MambaVision-B-1K/resolve/main/mambavision_base_1k.pth.tar',
crop_pct=1.0,
input_size=(3, 224, 224),
crop_mode='center'),
'mamba_vision_L': _cfg(url='https://huggingface.co/nvidia/MambaVision-L-1K/resolve/main/mambavision_large_1k.pth.tar',
crop_pct=1.0,
input_size=(3, 224, 224),
crop_mode='center'),
'mamba_vision_L2': _cfg(url='https://huggingface.co/nvidia/MambaVision-L2-1K/resolve/main/mambavision_large2_1k.pth.tar',
crop_pct=1.0,
input_size=(3, 224, 224),
crop_mode='center')
}
def window_partition(x, window_size):
"""
Args:
x: (B, C, H, W)
window_size: window size
h_w: Height of window
w_w: Width of window
Returns:
local window features (num_windows*B, window_size*window_size, C)
"""
B, C, H, W = x.shape
x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: local window features (num_windows*B, window_size, window_size, C)
window_size: Window size
H: Height of image
W: Width of image
Returns:
x: (B, C, H, W)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 5, 1, 3, 2, 4).reshape(B,windows.shape[2], H, W)
return x
def _load_state_dict(module, state_dict, strict=False, logger=None):
"""Load state_dict to a module.
This method is modified from :meth:`torch.nn.Module.load_state_dict`.
Default value for ``strict`` is set to ``False`` and the message for
param mismatch will be shown even if strict is False.
Args:
module (Module): Module that receives the state_dict.
state_dict (OrderedDict): Weights.
strict (bool): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
logger (:obj:`logging.Logger`, optional): Logger to log the error
message. If not specified, print function will be used.
"""
unexpected_keys = []
all_missing_keys = []
err_msg = []
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(
prefix[:-1], {})
module._load_from_state_dict(state_dict, prefix, local_metadata, True,
all_missing_keys, unexpected_keys,
err_msg)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(module)
load = None
missing_keys = [
key for key in all_missing_keys if 'num_batches_tracked' not in key
]
if unexpected_keys:
err_msg.append('unexpected key in source '
f'state_dict: {", ".join(unexpected_keys)}\n')
if missing_keys:
err_msg.append(
f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
if len(err_msg) > 0:
err_msg.insert(
0, 'The model and loaded state dict do not match exactly\n')
err_msg = '\n'.join(err_msg)
if strict:
raise RuntimeError(err_msg)
elif logger is not None:
logger.warning(err_msg)
else:
print(err_msg)
def _load_checkpoint(model,
filename,
map_location='cpu',
strict=False,
logger=None):
"""Load checkpoint from a file or URI.
Args:
model (Module): Module to load checkpoint.
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and
checkpoint.
logger (:mod:`logging.Logger` or None): The logger for error message.
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
checkpoint = torch.load(filename, map_location=map_location)
if not isinstance(checkpoint, dict):
raise RuntimeError(
f'No state_dict found in checkpoint file {filename}')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in state_dict.items()}
if sorted(list(state_dict.keys()))[0].startswith('encoder'):
state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
_load_state_dict(model, state_dict, strict, logger)
return checkpoint
class Downsample(nn.Module):
"""
Down-sampling block"
"""
def __init__(self,
dim,
keep_dim=False,
):
"""
Args:
dim: feature size dimension.
norm_layer: normalization layer.
keep_dim: bool argument for maintaining the resolution.
"""
super().__init__()
if keep_dim:
dim_out = dim
else:
dim_out = 2 * dim
self.reduction = nn.Sequential(
nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False),
)
def forward(self, x):
x = self.reduction(x)
return x
class PatchEmbed(nn.Module):
"""
Patch embedding block"
"""
def __init__(self, in_chans=3, in_dim=64, dim=96):
"""
Args:
in_chans: number of input channels.
dim: feature size dimension.
"""
# in_dim = 1
super().__init__()
self.proj = nn.Identity()
self.conv_down = nn.Sequential(
nn.Conv2d(in_chans, in_dim, 3, 2, 1, bias=False),
nn.BatchNorm2d(in_dim, eps=1e-4),
nn.ReLU(),
nn.Conv2d(in_dim, dim, 3, 2, 1, bias=False),
nn.BatchNorm2d(dim, eps=1e-4),
nn.ReLU()
)
def forward(self, x):
x = self.proj(x)
x = self.conv_down(x)
return x
class ConvBlock(nn.Module):
def __init__(self, dim,
drop_path=0.,
layer_scale=None,
kernel_size=3):
super().__init__()
self.conv1 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
self.norm1 = nn.BatchNorm2d(dim, eps=1e-5)
self.act1 = nn.GELU(approximate= 'tanh')
self.conv2 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
self.norm2 = nn.BatchNorm2d(dim, eps=1e-5)
self.layer_scale = layer_scale
if layer_scale is not None and type(layer_scale) in [int, float]:
self.gamma = nn.Parameter(layer_scale * torch.ones(dim))
self.layer_scale = True
else:
self.layer_scale = False
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
input = x
x = self.conv1(x)
x = self.norm1(x)
x = self.act1(x)
x = self.conv2(x)
x = self.norm2(x)
if self.layer_scale:
x = x * self.gamma.view(1, -1, 1, 1)
x = input + self.drop_path(x)
return x
class MambaVisionMixer(nn.Module):
def __init__(
self,
d_model,
d_state=16,
d_conv=4,
expand=2,
dt_rank="auto",
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
conv_bias=True,
bias=False,
use_fast_path=True,
layer_idx=None,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
self.use_fast_path = use_fast_path
self.layer_idx = layer_idx
self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs)
self.x_proj = nn.Linear(
self.d_inner//2, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
)
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner//2, bias=True, **factory_kwargs)
dt_init_std = self.dt_rank**-0.5 * dt_scale
if dt_init == "constant":
nn.init.constant_(self.dt_proj.weight, dt_init_std)
elif dt_init == "random":
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
dt = torch.exp(
torch.rand(self.d_inner//2, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
self.dt_proj.bias.copy_(inv_dt)
self.dt_proj.bias._no_reinit = True
A = repeat(
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
"n -> d n",
d=self.d_inner//2,
).contiguous()
A_log = torch.log(A)
self.A_log = nn.Parameter(A_log)
self.A_log._no_weight_decay = True
self.D = nn.Parameter(torch.ones(self.d_inner//2, device=device))
self.D._no_weight_decay = True
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
self.conv1d_x = nn.Conv1d(
in_channels=self.d_inner//2,
out_channels=self.d_inner//2,
bias=conv_bias//2,
kernel_size=d_conv,
groups=self.d_inner//2,
**factory_kwargs,
)
self.conv1d_z = nn.Conv1d(
in_channels=self.d_inner//2,
out_channels=self.d_inner//2,
bias=conv_bias//2,
kernel_size=d_conv,
groups=self.d_inner//2,
**factory_kwargs,
)
def forward(self, hidden_states):
"""
hidden_states: (B, L, D)
Returns: same shape as hidden_states
"""
_, seqlen, _ = hidden_states.shape
xz = self.in_proj(hidden_states)
xz = rearrange(xz, "b l d -> b d l")
x, z = xz.chunk(2, dim=1)
A = -torch.exp(self.A_log.float())
x = F.silu(F.conv1d(input=x, weight=self.conv1d_x.weight, bias=self.conv1d_x.bias, padding='same', groups=self.d_inner//2))
z = F.silu(F.conv1d(input=z, weight=self.conv1d_z.weight, bias=self.conv1d_z.bias, padding='same', groups=self.d_inner//2))
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = rearrange(self.dt_proj(dt), "(b l) d -> b d l", l=seqlen)
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
y = selective_scan_fn(x,
dt,
A,
B,
C,
self.D.float(),
z=None,
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
return_last_state=None)
y = torch.cat([y, z], dim=1)
y = rearrange(y, "b d l -> b l d")
out = self.out_proj(y)
return out
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_norm=False,
attn_drop=0.,
proj_drop=0.,
norm_layer=nn.LayerNorm,
):
super().__init__()
assert dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.fused_attn = True
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if self.fused_attn:
x = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self,
dim,
num_heads,
counter,
transformer_blocks,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=False,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
Mlp_block=Mlp,
layer_scale=None,
):
super().__init__()
self.norm1 = norm_layer(dim)
if counter in transformer_blocks:
self.mixer = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
norm_layer=norm_layer,
)
else:
self.mixer = MambaVisionMixer(d_model=dim,
d_state=8,
d_conv=3,
expand=1
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
self.gamma_1 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
self.gamma_2 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
def forward(self, x):
x = x + self.drop_path(self.gamma_1 * self.mixer(self.norm1(x)))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
class MambaVisionLayer(nn.Module):
"""
MambaVision layer"
"""
def __init__(self,
dim,
depth,
num_heads,
window_size,
conv=False,
downsample=True,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
layer_scale=None,
layer_scale_conv=None,
transformer_blocks = [],
):
"""
Args:
dim: feature size dimension.
depth: number of layers in each stage.
window_size: window size in each stage.
conv: bool argument for conv stage flag.
downsample: bool argument for down-sampling.
mlp_ratio: MLP ratio.
num_heads: number of heads in each stage.
qkv_bias: bool argument for query, key, value learnable bias.
qk_scale: bool argument to scaling query, key.
drop: dropout rate.
attn_drop: attention dropout rate.
drop_path: drop path rate.
norm_layer: normalization layer.
layer_scale: layer scaling coefficient.
layer_scale_conv: conv layer scaling coefficient.
transformer_blocks: list of transformer blocks.
"""
super().__init__()
self.conv = conv
self.transformer_block = False
if conv:
self.blocks = nn.ModuleList([ConvBlock(dim=dim,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
layer_scale=layer_scale_conv)
for i in range(depth)])
self.transformer_block = False
else:
self.transformer_block = True
self.blocks = nn.ModuleList([Block(dim=dim,
counter=i,
transformer_blocks=transformer_blocks,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
layer_scale=layer_scale)
for i in range(depth)])
self.transformer_block = True
self.downsample = None if not downsample else Downsample(dim=dim)
self.do_gt = False
self.window_size = window_size
def forward(self, x):
_, _, H, W = x.shape
if self.transformer_block:
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
if pad_r > 0 or pad_b > 0:
x = torch.nn.functional.pad(x, (0,pad_r,0,pad_b))
_, _, Hp, Wp = x.shape
else:
Hp, Wp = H, W
x = window_partition(x, self.window_size)
for _, blk in enumerate(self.blocks):
x = blk(x)
if self.transformer_block:
x = window_reverse(x, self.window_size, Hp, Wp)
if pad_r > 0 or pad_b > 0:
x = x[:, :, :H, :W].contiguous()
if self.downsample is None:
return x
return self.downsample(x)
class MambaVision(nn.Module, PyTorchModelHubMixin):
"""
MambaVision,
"""
def __init__(self,
dim,
in_dim,
depths,
window_size,
mlp_ratio,
num_heads,
drop_path_rate=0.2,
in_chans=3,
num_classes=1000,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
layer_scale=None,
layer_scale_conv=None,
**kwargs):
"""
Args:
dim: feature size dimension.
depths: number of layers in each stage.
window_size: window size in each stage.
mlp_ratio: MLP ratio.
num_heads: number of heads in each stage.
drop_path_rate: drop path rate.
in_chans: number of input channels.
num_classes: number of classes.
qkv_bias: bool argument for query, key, value learnable bias.
qk_scale: bool argument to scaling query, key.
drop_rate: dropout rate.
attn_drop_rate: attention dropout rate.
norm_layer: normalization layer.
layer_scale: layer scaling coefficient.
layer_scale_conv: conv layer scaling coefficient.
"""
super().__init__()
num_features = int(dim * 2 ** (len(depths) - 1))
self.num_classes = num_classes
self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
self.levels = nn.ModuleList()
for i in range(len(depths)):
conv = True if (i == 0 or i == 1) else False
level = MambaVisionLayer(dim=int(dim * 2 ** i),
depth=depths[i],
num_heads=num_heads[i],
window_size=window_size[i],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
conv=conv,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
downsample=(i < 3),
layer_scale=layer_scale,
layer_scale_conv=layer_scale_conv,
transformer_blocks=list(range(depths[i]//2+1, depths[i])) if depths[i]%2!=0 else list(range(depths[i]//2, depths[i])),
)
self.levels.append(level)
self.norm = nn.BatchNorm2d(num_features)
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, LayerNorm2d):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'rpb'}
def forward_features(self, x):
x = self.patch_embed(x)
for level in self.levels:
x = level(x)
x = self.norm(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
def _load_state_dict(self,
pretrained,
strict: bool = False):
_load_checkpoint(self,
pretrained,
strict=strict)
@register_model
def mamba_vision_T(pretrained=False, **kwargs):
model_path = kwargs.pop("model_path", "/tmp/mamba_vision_T.pth.tar")
pretrained_cfg = resolve_pretrained_cfg('mamba_vision_T').to_dict()
update_args(pretrained_cfg, kwargs, kwargs_filter=None)
model = MambaVision(depths=[1, 3, 8, 4],
num_heads=[2, 4, 8, 16],
window_size=[8, 8, 14, 7],
dim=80,
in_dim=32,
mlp_ratio=4,
resolution=224,
drop_path_rate=0.2,
**kwargs)
model.pretrained_cfg = pretrained_cfg
model.default_cfg = model.pretrained_cfg
if pretrained:
if not Path(model_path).is_file():
url = model.default_cfg['url']
torch.hub.download_url_to_file(url=url, dst=model_path)
model._load_state_dict(model_path)
return model
@register_model
def mamba_vision_T2(pretrained=False, **kwargs):
model_path = kwargs.pop("model_path", "/tmp/mamba_vision_T2.pth.tar")
pretrained_cfg = resolve_pretrained_cfg('mamba_vision_T2').to_dict()
update_args(pretrained_cfg, kwargs, kwargs_filter=None)
model = MambaVision(depths=[1, 3, 11, 4],
num_heads=[2, 4, 8, 16],
window_size=[8, 8, 14, 7],
dim=80,
in_dim=32,
mlp_ratio=4,
resolution=224,
drop_path_rate=0.2,
**kwargs)
model.pretrained_cfg = pretrained_cfg
model.default_cfg = model.pretrained_cfg
if pretrained:
if not Path(model_path).is_file():
url = model.default_cfg['url']
torch.hub.download_url_to_file(url=url, dst=model_path)
model._load_state_dict(model_path)
return model
@register_model
def mamba_vision_S(pretrained=False, **kwargs):
model_path = kwargs.pop("model_path", "/tmp/mamba_vision_S.pth.tar")
pretrained_cfg = resolve_pretrained_cfg('mamba_vision_S').to_dict()
update_args(pretrained_cfg, kwargs, kwargs_filter=None)
model = MambaVision(depths=[3, 3, 7, 5],
num_heads=[2, 4, 8, 16],
window_size=[8, 8, 14, 7],
dim=96,
in_dim=64,
mlp_ratio=4,
resolution=224,
drop_path_rate=0.2,
**kwargs)
model.pretrained_cfg = pretrained_cfg
model.default_cfg = model.pretrained_cfg
if pretrained:
if not Path(model_path).is_file():
url = model.default_cfg['url']
torch.hub.download_url_to_file(url=url, dst=model_path)
model._load_state_dict(model_path)
return model
@register_model
def mamba_vision_B(pretrained=False, **kwargs):
model_path = kwargs.pop("model_path", "/tmp/mamba_vision_B.pth.tar")
pretrained_cfg = resolve_pretrained_cfg('mamba_vision_B').to_dict()
update_args(pretrained_cfg, kwargs, kwargs_filter=None)
model = MambaVision(depths=[3, 3, 10, 5],
num_heads=[2, 4, 8, 16],
window_size=[8, 8, 14, 7],
dim=128,
in_dim=64,
mlp_ratio=4,
resolution=224,
drop_path_rate=0.3,
layer_scale=1e-5,
layer_scale_conv=None,
**kwargs)
model.pretrained_cfg = pretrained_cfg
model.default_cfg = model.pretrained_cfg
if pretrained:
if not Path(model_path).is_file():
url = model.default_cfg['url']
torch.hub.download_url_to_file(url=url, dst=model_path)
model._load_state_dict(model_path)
return model
@register_model
def mamba_vision_L(pretrained=False, **kwargs):
model_path = kwargs.pop("model_path", "/tmp/mamba_vision_L.pth.tar")
pretrained_cfg = resolve_pretrained_cfg('mamba_vision_L').to_dict()
update_args(pretrained_cfg, kwargs, kwargs_filter=None)
model = MambaVision(depths=[3, 3, 10, 5],
num_heads=[4, 8, 16, 32],
window_size=[8, 8, 14, 7],
dim=196,
in_dim=64,
mlp_ratio=4,
resolution=224,
drop_path_rate=0.3,
layer_scale=1e-5,
layer_scale_conv=None,
**kwargs)
model.pretrained_cfg = pretrained_cfg
model.default_cfg = model.pretrained_cfg
if pretrained:
if not Path(model_path).is_file():
url = model.default_cfg['url']
torch.hub.download_url_to_file(url=url, dst=model_path)
model._load_state_dict(model_path)
return model
@register_model
def mamba_vision_L2(pretrained=False, **kwargs):
model_path = kwargs.pop("model_path", "/tmp/mamba_vision_L2.pth.tar")
pretrained_cfg = resolve_pretrained_cfg('mamba_vision_L2').to_dict()
update_args(pretrained_cfg, kwargs, kwargs_filter=None)
model = MambaVision(depths=[3, 3, 12, 5],
num_heads=[4, 8, 16, 32],
window_size=[8, 8, 14, 7],
dim=196,
in_dim=64,
mlp_ratio=4,
resolution=224,
drop_path_rate=0.3,
layer_scale=1e-5,
layer_scale_conv=None,
**kwargs)
model.pretrained_cfg = pretrained_cfg
model.default_cfg = model.pretrained_cfg
if pretrained:
if not Path(model_path).is_file():
url = model.default_cfg['url']
torch.hub.download_url_to_file(url=url, dst=model_path)
model._load_state_dict(model_path)
return model

BIN
mambavision_large2_1k.pth.tar (Stored with Git LFS) Normal file

Binary file not shown.

BIN
model.safetensors (Stored with Git LFS) Normal file

Binary file not shown.

764
modeling_mambavision.py Normal file
View File

@ -0,0 +1,764 @@
#!/usr/bin/env python3
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import torch
import torch.nn as nn
from timm.models.registry import register_model
import math
from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
from timm.models._builder import resolve_pretrained_cfg
try:
from timm.models._builder import _update_default_kwargs as update_args
except:
from timm.models._builder import _update_default_model_kwargs as update_args
from timm.models.vision_transformer import Mlp, PatchEmbed
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
import torch.nn.functional as F
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from einops import rearrange, repeat
from transformers import PreTrainedModel
from configuration_mambavision import MambaVisionConfig
def _cfg(url='', **kwargs):
return {'url': url,
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': None,
'crop_pct': 0.875,
'interpolation': 'bicubic',
'fixed_input_size': True,
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
**kwargs
}
default_cfgs = {
'mamba_vision_T': _cfg(url='https://huggingface.co/nvidia/MambaVision-T-1K/resolve/main/mambavision_tiny_1k.pth.tar',
crop_pct=1.0,
input_size=(3, 224, 224),
crop_mode='center'),
'mamba_vision_T2': _cfg(url='https://huggingface.co/nvidia/MambaVision-T2-1K/resolve/main/mambavision_tiny2_1k.pth.tar',
crop_pct=0.98,
input_size=(3, 224, 224),
crop_mode='center'),
'mamba_vision_S': _cfg(url='https://huggingface.co/nvidia/MambaVision-S-1K/resolve/main/mambavision_small_1k.pth.tar',
crop_pct=0.93,
input_size=(3, 224, 224),
crop_mode='center'),
'mamba_vision_B': _cfg(url='https://huggingface.co/nvidia/MambaVision-B-1K/resolve/main/mambavision_base_1k.pth.tar',
crop_pct=1.0,
input_size=(3, 224, 224),
crop_mode='center'),
'mamba_vision_L': _cfg(url='https://huggingface.co/nvidia/MambaVision-L-1K/resolve/main/mambavision_large_1k.pth.tar',
crop_pct=1.0,
input_size=(3, 224, 224),
crop_mode='center'),
'mamba_vision_L2': _cfg(url='https://huggingface.co/nvidia/MambaVision-L2-1K/resolve/main/mambavision_large2_1k.pth.tar',
crop_pct=1.0,
input_size=(3, 224, 224),
crop_mode='center')
}
def window_partition(x, window_size):
"""
Args:
x: (B, C, H, W)
window_size: window size
h_w: Height of window
w_w: Width of window
Returns:
local window features (num_windows*B, window_size*window_size, C)
"""
B, C, H, W = x.shape
x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: local window features (num_windows*B, window_size, window_size, C)
window_size: Window size
H: Height of image
W: Width of image
Returns:
x: (B, C, H, W)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 5, 1, 3, 2, 4).reshape(B,windows.shape[2], H, W)
return x
def _load_state_dict(module, state_dict, strict=False, logger=None):
"""Load state_dict to a module.
This method is modified from :meth:`torch.nn.Module.load_state_dict`.
Default value for ``strict`` is set to ``False`` and the message for
param mismatch will be shown even if strict is False.
Args:
module (Module): Module that receives the state_dict.
state_dict (OrderedDict): Weights.
strict (bool): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
logger (:obj:`logging.Logger`, optional): Logger to log the error
message. If not specified, print function will be used.
"""
unexpected_keys = []
all_missing_keys = []
err_msg = []
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(
prefix[:-1], {})
module._load_from_state_dict(state_dict, prefix, local_metadata, True,
all_missing_keys, unexpected_keys,
err_msg)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(module)
load = None
missing_keys = [
key for key in all_missing_keys if 'num_batches_tracked' not in key
]
if unexpected_keys:
err_msg.append('unexpected key in source '
f'state_dict: {", ".join(unexpected_keys)}\n')
if missing_keys:
err_msg.append(
f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
if len(err_msg) > 0:
err_msg.insert(
0, 'The model and loaded state dict do not match exactly\n')
err_msg = '\n'.join(err_msg)
if strict:
raise RuntimeError(err_msg)
elif logger is not None:
logger.warning(err_msg)
else:
print(err_msg)
def _load_checkpoint(model,
filename,
map_location='cpu',
strict=False,
logger=None):
"""Load checkpoint from a file or URI.
Args:
model (Module): Module to load checkpoint.
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and
checkpoint.
logger (:mod:`logging.Logger` or None): The logger for error message.
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
checkpoint = torch.load(filename, map_location=map_location)
if not isinstance(checkpoint, dict):
raise RuntimeError(
f'No state_dict found in checkpoint file {filename}')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in state_dict.items()}
if sorted(list(state_dict.keys()))[0].startswith('encoder'):
state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
_load_state_dict(model, state_dict, strict, logger)
return checkpoint
class Downsample(nn.Module):
"""
Down-sampling block"
"""
def __init__(self,
dim,
keep_dim=False,
):
"""
Args:
dim: feature size dimension.
norm_layer: normalization layer.
keep_dim: bool argument for maintaining the resolution.
"""
super().__init__()
if keep_dim:
dim_out = dim
else:
dim_out = 2 * dim
self.reduction = nn.Sequential(
nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False),
)
def forward(self, x):
x = self.reduction(x)
return x
class PatchEmbed(nn.Module):
"""
Patch embedding block"
"""
def __init__(self, in_chans=3, in_dim=64, dim=96):
"""
Args:
in_chans: number of input channels.
dim: feature size dimension.
"""
# in_dim = 1
super().__init__()
self.proj = nn.Identity()
self.conv_down = nn.Sequential(
nn.Conv2d(in_chans, in_dim, 3, 2, 1, bias=False),
nn.BatchNorm2d(in_dim, eps=1e-4),
nn.ReLU(),
nn.Conv2d(in_dim, dim, 3, 2, 1, bias=False),
nn.BatchNorm2d(dim, eps=1e-4),
nn.ReLU()
)
def forward(self, x):
x = self.proj(x)
x = self.conv_down(x)
return x
class ConvBlock(nn.Module):
def __init__(self, dim,
drop_path=0.,
layer_scale=None,
kernel_size=3):
super().__init__()
self.conv1 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
self.norm1 = nn.BatchNorm2d(dim, eps=1e-5)
self.act1 = nn.GELU(approximate= 'tanh')
self.conv2 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
self.norm2 = nn.BatchNorm2d(dim, eps=1e-5)
self.layer_scale = layer_scale
if layer_scale is not None and type(layer_scale) in [int, float]:
self.g = nn.Parameter(layer_scale * torch.ones(dim))
self.layer_scale = True
else:
self.layer_scale = False
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
input = x
x = self.conv1(x)
x = self.norm1(x)
x = self.act1(x)
x = self.conv2(x)
x = self.norm2(x)
if self.layer_scale:
x = x * self.g.view(1, -1, 1, 1)
x = input + self.drop_path(x)
return x
class MambaVisionMixer(nn.Module):
def __init__(
self,
d_model,
d_state=16,
d_conv=4,
expand=2,
dt_rank="auto",
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
conv_bias=True,
bias=False,
use_fast_path=True,
layer_idx=None,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
self.use_fast_path = use_fast_path
self.layer_idx = layer_idx
self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs)
self.x_proj = nn.Linear(
self.d_inner//2, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
)
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner//2, bias=True, **factory_kwargs)
dt_init_std = self.dt_rank**-0.5 * dt_scale
if dt_init == "constant":
nn.init.constant_(self.dt_proj.weight, dt_init_std)
elif dt_init == "random":
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
dt = torch.exp(
torch.rand(self.d_inner//2, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
self.dt_proj.bias.copy_(inv_dt)
self.dt_proj.bias._no_reinit = True
A = repeat(
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
"n -> d n",
d=self.d_inner//2,
).contiguous()
A_log = torch.log(A)
self.A_log = nn.Parameter(A_log)
self.A_log._no_weight_decay = True
self.D = nn.Parameter(torch.ones(self.d_inner//2, device=device))
self.D._no_weight_decay = True
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
self.conv1d_x = nn.Conv1d(
in_channels=self.d_inner//2,
out_channels=self.d_inner//2,
bias=conv_bias//2,
kernel_size=d_conv,
groups=self.d_inner//2,
**factory_kwargs,
)
self.conv1d_z = nn.Conv1d(
in_channels=self.d_inner//2,
out_channels=self.d_inner//2,
bias=conv_bias//2,
kernel_size=d_conv,
groups=self.d_inner//2,
**factory_kwargs,
)
def forward(self, hidden_states):
"""
hidden_states: (B, L, D)
Returns: same shape as hidden_states
"""
_, seqlen, _ = hidden_states.shape
xz = self.in_proj(hidden_states)
xz = rearrange(xz, "b l d -> b d l")
x, z = xz.chunk(2, dim=1)
A = -torch.exp(self.A_log.float())
x = F.silu(F.conv1d(input=x, weight=self.conv1d_x.weight, bias=self.conv1d_x.bias, padding='same', groups=self.d_inner//2))
z = F.silu(F.conv1d(input=z, weight=self.conv1d_z.weight, bias=self.conv1d_z.bias, padding='same', groups=self.d_inner//2))
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = rearrange(self.dt_proj(dt), "(b l) d -> b d l", l=seqlen)
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
y = selective_scan_fn(x,
dt,
A,
B,
C,
self.D.float(),
z=None,
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
return_last_state=None)
y = torch.cat([y, z], dim=1)
y = rearrange(y, "b d l -> b l d")
out = self.out_proj(y)
return out
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_norm=False,
attn_drop=0.,
proj_drop=0.,
norm_layer=nn.LayerNorm,
):
super().__init__()
assert dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.fused_attn = True
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if self.fused_attn:
x = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self,
dim,
num_heads,
counter,
transformer_blocks,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=False,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
Mlp_block=Mlp,
layer_scale=None,
):
super().__init__()
self.norm1 = norm_layer(dim)
if counter in transformer_blocks:
self.mixer = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
norm_layer=norm_layer,
)
else:
self.mixer = MambaVisionMixer(d_model=dim,
d_state=8,
d_conv=3,
expand=1
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
self.g_1 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
self.g_2 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
def forward(self, x):
x = x + self.drop_path(self.g_1 * self.mixer(self.norm1(x)))
x = x + self.drop_path(self.g_2 * self.mlp(self.norm2(x)))
return x
class MambaVisionLayer(nn.Module):
"""
MambaVision layer"
"""
def __init__(self,
dim,
depth,
num_heads,
window_size,
conv=False,
downsample=True,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
layer_scale=None,
layer_scale_conv=None,
transformer_blocks = [],
):
"""
Args:
dim: feature size dimension.
depth: number of layers in each stage.
window_size: window size in each stage.
conv: bool argument for conv stage flag.
downsample: bool argument for down-sampling.
mlp_ratio: MLP ratio.
num_heads: number of heads in each stage.
qkv_bias: bool argument for query, key, value learnable bias.
qk_scale: bool argument to scaling query, key.
drop: dropout rate.
attn_drop: attention dropout rate.
drop_path: drop path rate.
norm_layer: normalization layer.
layer_scale: layer scaling coefficient.
layer_scale_conv: conv layer scaling coefficient.
transformer_blocks: list of transformer blocks.
"""
super().__init__()
self.conv = conv
self.transformer_block = False
if conv:
self.blocks = nn.ModuleList([ConvBlock(dim=dim,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
layer_scale=layer_scale_conv)
for i in range(depth)])
self.transformer_block = False
else:
self.transformer_block = True
self.blocks = nn.ModuleList([Block(dim=dim,
counter=i,
transformer_blocks=transformer_blocks,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
layer_scale=layer_scale)
for i in range(depth)])
self.transformer_block = True
self.downsample = None if not downsample else Downsample(dim=dim)
self.do_gt = False
self.window_size = window_size
def forward(self, x):
_, _, H, W = x.shape
if self.transformer_block:
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
if pad_r > 0 or pad_b > 0:
x = torch.nn.functional.pad(x, (0,pad_r,0,pad_b))
_, _, Hp, Wp = x.shape
else:
Hp, Wp = H, W
x = window_partition(x, self.window_size)
for _, blk in enumerate(self.blocks):
x = blk(x)
if self.transformer_block:
x = window_reverse(x, self.window_size, Hp, Wp)
if pad_r > 0 or pad_b > 0:
x = x[:, :, :H, :W].contiguous()
if self.downsample is None:
return x, x
return self.downsample(x), x
class MambaVision(nn.Module):
"""
MambaVision,
"""
def __init__(self,
dim,
in_dim,
depths,
window_size,
mlp_ratio,
num_heads,
drop_path_rate=0.2,
in_chans=3,
num_classes=1000,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
layer_scale=None,
layer_scale_conv=None,
**kwargs):
"""
Args:
dim: feature size dimension.
depths: number of layers in each stage.
window_size: window size in each stage.
mlp_ratio: MLP ratio.
num_heads: number of heads in each stage.
drop_path_rate: drop path rate.
in_chans: number of input channels.
num_classes: number of classes.
qkv_bias: bool argument for query, key, value learnable bias.
qk_scale: bool argument to scaling query, key.
drop_rate: dropout rate.
attn_drop_rate: attention dropout rate.
norm_layer: normalization layer.
layer_scale: layer scaling coefficient.
layer_scale_conv: conv layer scaling coefficient.
"""
super().__init__()
num_features = int(dim * 2 ** (len(depths) - 1))
self.num_classes = num_classes
self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
self.levels = nn.ModuleList()
for i in range(len(depths)):
conv = True if (i == 0 or i == 1) else False
level = MambaVisionLayer(dim=int(dim * 2 ** i),
depth=depths[i],
num_heads=num_heads[i],
window_size=window_size[i],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
conv=conv,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
downsample=(i < 3),
layer_scale=layer_scale,
layer_scale_conv=layer_scale_conv,
transformer_blocks=list(range(depths[i]//2+1, depths[i])) if depths[i]%2!=0 else list(range(depths[i]//2, depths[i])),
)
self.levels.append(level)
self.norm = nn.BatchNorm2d(num_features)
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, LayerNorm2d):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'rpb'}
def forward_features(self, x):
x = self.patch_embed(x)
outs = []
for level in self.levels:
x, xo = level(x)
outs.append(xo)
x = self.norm(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
return x, outs
def forward(self, x):
x, outs = self.forward_features(x)
x = self.head(x)
return x
def _load_state_dict(self,
pretrained,
strict: bool = False):
_load_checkpoint(self,
pretrained,
strict=strict)
class MambaVisionModel(PreTrainedModel):
config_class = MambaVisionConfig
def __init__(self, config):
super().__init__(config)
self.model = MambaVision(
depths=config.depths,
num_heads=config.num_heads,
window_size=config.window_size,
dim=config.dim,
in_dim=config.in_dim,
mlp_ratio=config.mlp_ratio,
layer_scale=config.layer_scale,
layer_scale_conv=config.layer_scale_conv
)
def forward(self, tensor):
return self.model.forward_features(tensor)
class MambaVisionModelForImageClassification(PreTrainedModel):
config_class = MambaVisionConfig
def __init__(self, config):
super().__init__(config)
self.model = MambaVision(
depths=config.depths,
num_heads=config.num_heads,
window_size=config.window_size,
dim=config.dim,
in_dim=config.in_dim,
mlp_ratio=config.mlp_ratio,
layer_scale=config.layer_scale,
layer_scale_conv=config.layer_scale_conv
)
def forward(self, tensor, labels=None):
logits = self.model(tensor)
if labels is not None:
loss = torch.nn.cross_entropy(logits, labels)
return {"loss": loss, "logits": logits}
return {"logits": logits}