first commit
This commit is contained in:
parent
8a151fcb32
commit
5de8f3e15b
|
@ -0,0 +1,63 @@
|
||||||
|
Copyright (c) 2024, NVIDIA Corporation. All rights reserved.
|
||||||
|
|
||||||
|
Nvidia Source Code License-NC
|
||||||
|
|
||||||
|
1. Definitions
|
||||||
|
|
||||||
|
“Licensor” means any person or entity that distributes its Work.
|
||||||
|
|
||||||
|
“Work” means (a) the original work of authorship made available under this license, which may include software, documentation,
|
||||||
|
or other files, and (b) any additions to or derivative works thereof that are made available under this license.
|
||||||
|
|
||||||
|
The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S.
|
||||||
|
copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that
|
||||||
|
remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
|
||||||
|
|
||||||
|
Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing
|
||||||
|
the applicability of this license to the Work, or (b) a copy of this license.
|
||||||
|
|
||||||
|
2. License Grant
|
||||||
|
|
||||||
|
2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual,
|
||||||
|
worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly
|
||||||
|
display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
|
||||||
|
|
||||||
|
3. Limitations
|
||||||
|
|
||||||
|
3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a
|
||||||
|
complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent,
|
||||||
|
trademark, or attribution notices that are present in the Work.
|
||||||
|
|
||||||
|
3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution
|
||||||
|
of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3
|
||||||
|
applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms.
|
||||||
|
Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply
|
||||||
|
to the Work itself.
|
||||||
|
|
||||||
|
3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially.
|
||||||
|
Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially.
|
||||||
|
As used herein, “non-commercially” means for research or evaluation purposes only.
|
||||||
|
|
||||||
|
3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim
|
||||||
|
or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under
|
||||||
|
this license from such Licensor (including the grant in Section 2.1) will terminate immediately.
|
||||||
|
|
||||||
|
3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks,
|
||||||
|
except as necessary to reproduce the notices described in this license.
|
||||||
|
|
||||||
|
3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1)
|
||||||
|
will terminate immediately.
|
||||||
|
|
||||||
|
4. Disclaimer of Warranty.
|
||||||
|
|
||||||
|
THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES
|
||||||
|
OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING
|
||||||
|
ANY ACTIVITIES UNDER THIS LICENSE.
|
||||||
|
|
||||||
|
5. Limitation of Liability.
|
||||||
|
|
||||||
|
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT,
|
||||||
|
OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL
|
||||||
|
DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL,
|
||||||
|
BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR
|
||||||
|
HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,28 @@
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
class MambaVisionConfig(PretrainedConfig):
|
||||||
|
model_type = "mambavision"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
depths=[3, 3, 12, 5],
|
||||||
|
num_heads=[4, 8, 16, 32],
|
||||||
|
window_size=[8, 8, 14, 7],
|
||||||
|
dim=196,
|
||||||
|
in_dim=64,
|
||||||
|
mlp_ratio=4,
|
||||||
|
drop_path_rate=0.3,
|
||||||
|
layer_scale=1e-5,
|
||||||
|
layer_scale_conv=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.depths = depths
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.window_size = window_size
|
||||||
|
self.dim = dim
|
||||||
|
self.in_dim = in_dim
|
||||||
|
self.mlp_ratio = mlp_ratio
|
||||||
|
self.drop_path_rate = drop_path_rate
|
||||||
|
self.layer_scale=layer_scale
|
||||||
|
self.layer_scale_conv=layer_scale_conv
|
||||||
|
super().__init__(**kwargs)
|
|
@ -0,0 +1,865 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||||
|
# and proprietary rights in and to this software, related documentation
|
||||||
|
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||||
|
# distribution of this software and related documentation without an express
|
||||||
|
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from timm.models.registry import register_model
|
||||||
|
import math
|
||||||
|
from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
|
||||||
|
from timm.models._builder import resolve_pretrained_cfg
|
||||||
|
try:
|
||||||
|
from timm.models._builder import _update_default_kwargs as update_args
|
||||||
|
except:
|
||||||
|
from timm.models._builder import _update_default_model_kwargs as update_args
|
||||||
|
from timm.models.vision_transformer import Mlp, PatchEmbed
|
||||||
|
from timm.models.layers import DropPath, trunc_normal_
|
||||||
|
from timm.models.registry import register_model
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from pathlib import Path
|
||||||
|
from huggingface_hub import PyTorchModelHubMixin
|
||||||
|
|
||||||
|
|
||||||
|
def _cfg(url='', **kwargs):
|
||||||
|
return {'url': url,
|
||||||
|
'num_classes': 1000,
|
||||||
|
'input_size': (3, 224, 224),
|
||||||
|
'pool_size': None,
|
||||||
|
'crop_pct': 0.875,
|
||||||
|
'interpolation': 'bicubic',
|
||||||
|
'fixed_input_size': True,
|
||||||
|
'mean': (0.485, 0.456, 0.406),
|
||||||
|
'std': (0.229, 0.224, 0.225),
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
default_cfgs = {
|
||||||
|
'mamba_vision_T': _cfg(url='https://huggingface.co/nvidia/MambaVision-T-1K/resolve/main/mambavision_tiny_1k.pth.tar',
|
||||||
|
crop_pct=1.0,
|
||||||
|
input_size=(3, 224, 224),
|
||||||
|
crop_mode='center'),
|
||||||
|
'mamba_vision_T2': _cfg(url='https://huggingface.co/nvidia/MambaVision-T2-1K/resolve/main/mambavision_tiny2_1k.pth.tar',
|
||||||
|
crop_pct=0.98,
|
||||||
|
input_size=(3, 224, 224),
|
||||||
|
crop_mode='center'),
|
||||||
|
'mamba_vision_S': _cfg(url='https://huggingface.co/nvidia/MambaVision-S-1K/resolve/main/mambavision_small_1k.pth.tar',
|
||||||
|
crop_pct=0.93,
|
||||||
|
input_size=(3, 224, 224),
|
||||||
|
crop_mode='center'),
|
||||||
|
'mamba_vision_B': _cfg(url='https://huggingface.co/nvidia/MambaVision-B-1K/resolve/main/mambavision_base_1k.pth.tar',
|
||||||
|
crop_pct=1.0,
|
||||||
|
input_size=(3, 224, 224),
|
||||||
|
crop_mode='center'),
|
||||||
|
'mamba_vision_L': _cfg(url='https://huggingface.co/nvidia/MambaVision-L-1K/resolve/main/mambavision_large_1k.pth.tar',
|
||||||
|
crop_pct=1.0,
|
||||||
|
input_size=(3, 224, 224),
|
||||||
|
crop_mode='center'),
|
||||||
|
'mamba_vision_L2': _cfg(url='https://huggingface.co/nvidia/MambaVision-L2-1K/resolve/main/mambavision_large2_1k.pth.tar',
|
||||||
|
crop_pct=1.0,
|
||||||
|
input_size=(3, 224, 224),
|
||||||
|
crop_mode='center')
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def window_partition(x, window_size):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (B, C, H, W)
|
||||||
|
window_size: window size
|
||||||
|
h_w: Height of window
|
||||||
|
w_w: Width of window
|
||||||
|
Returns:
|
||||||
|
local window features (num_windows*B, window_size*window_size, C)
|
||||||
|
"""
|
||||||
|
B, C, H, W = x.shape
|
||||||
|
x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
|
||||||
|
windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C)
|
||||||
|
return windows
|
||||||
|
|
||||||
|
|
||||||
|
def window_reverse(windows, window_size, H, W):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
windows: local window features (num_windows*B, window_size, window_size, C)
|
||||||
|
window_size: Window size
|
||||||
|
H: Height of image
|
||||||
|
W: Width of image
|
||||||
|
Returns:
|
||||||
|
x: (B, C, H, W)
|
||||||
|
"""
|
||||||
|
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||||
|
x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1)
|
||||||
|
x = x.permute(0, 5, 1, 3, 2, 4).reshape(B,windows.shape[2], H, W)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def _load_state_dict(module, state_dict, strict=False, logger=None):
|
||||||
|
"""Load state_dict to a module.
|
||||||
|
|
||||||
|
This method is modified from :meth:`torch.nn.Module.load_state_dict`.
|
||||||
|
Default value for ``strict`` is set to ``False`` and the message for
|
||||||
|
param mismatch will be shown even if strict is False.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (Module): Module that receives the state_dict.
|
||||||
|
state_dict (OrderedDict): Weights.
|
||||||
|
strict (bool): whether to strictly enforce that the keys
|
||||||
|
in :attr:`state_dict` match the keys returned by this module's
|
||||||
|
:meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
|
||||||
|
logger (:obj:`logging.Logger`, optional): Logger to log the error
|
||||||
|
message. If not specified, print function will be used.
|
||||||
|
"""
|
||||||
|
unexpected_keys = []
|
||||||
|
all_missing_keys = []
|
||||||
|
err_msg = []
|
||||||
|
|
||||||
|
metadata = getattr(state_dict, '_metadata', None)
|
||||||
|
state_dict = state_dict.copy()
|
||||||
|
if metadata is not None:
|
||||||
|
state_dict._metadata = metadata
|
||||||
|
|
||||||
|
def load(module, prefix=''):
|
||||||
|
local_metadata = {} if metadata is None else metadata.get(
|
||||||
|
prefix[:-1], {})
|
||||||
|
module._load_from_state_dict(state_dict, prefix, local_metadata, True,
|
||||||
|
all_missing_keys, unexpected_keys,
|
||||||
|
err_msg)
|
||||||
|
for name, child in module._modules.items():
|
||||||
|
if child is not None:
|
||||||
|
load(child, prefix + name + '.')
|
||||||
|
|
||||||
|
load(module)
|
||||||
|
load = None
|
||||||
|
missing_keys = [
|
||||||
|
key for key in all_missing_keys if 'num_batches_tracked' not in key
|
||||||
|
]
|
||||||
|
|
||||||
|
if unexpected_keys:
|
||||||
|
err_msg.append('unexpected key in source '
|
||||||
|
f'state_dict: {", ".join(unexpected_keys)}\n')
|
||||||
|
if missing_keys:
|
||||||
|
err_msg.append(
|
||||||
|
f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
|
||||||
|
|
||||||
|
|
||||||
|
if len(err_msg) > 0:
|
||||||
|
err_msg.insert(
|
||||||
|
0, 'The model and loaded state dict do not match exactly\n')
|
||||||
|
err_msg = '\n'.join(err_msg)
|
||||||
|
if strict:
|
||||||
|
raise RuntimeError(err_msg)
|
||||||
|
elif logger is not None:
|
||||||
|
logger.warning(err_msg)
|
||||||
|
else:
|
||||||
|
print(err_msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_checkpoint(model,
|
||||||
|
filename,
|
||||||
|
map_location='cpu',
|
||||||
|
strict=False,
|
||||||
|
logger=None):
|
||||||
|
"""Load checkpoint from a file or URI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (Module): Module to load checkpoint.
|
||||||
|
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
|
||||||
|
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
|
||||||
|
details.
|
||||||
|
map_location (str): Same as :func:`torch.load`.
|
||||||
|
strict (bool): Whether to allow different params for the model and
|
||||||
|
checkpoint.
|
||||||
|
logger (:mod:`logging.Logger` or None): The logger for error message.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict or OrderedDict: The loaded checkpoint.
|
||||||
|
"""
|
||||||
|
checkpoint = torch.load(filename, map_location=map_location)
|
||||||
|
if not isinstance(checkpoint, dict):
|
||||||
|
raise RuntimeError(
|
||||||
|
f'No state_dict found in checkpoint file {filename}')
|
||||||
|
if 'state_dict' in checkpoint:
|
||||||
|
state_dict = checkpoint['state_dict']
|
||||||
|
elif 'model' in checkpoint:
|
||||||
|
state_dict = checkpoint['model']
|
||||||
|
else:
|
||||||
|
state_dict = checkpoint
|
||||||
|
if list(state_dict.keys())[0].startswith('module.'):
|
||||||
|
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
||||||
|
|
||||||
|
if sorted(list(state_dict.keys()))[0].startswith('encoder'):
|
||||||
|
state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
|
||||||
|
|
||||||
|
_load_state_dict(model, state_dict, strict, logger)
|
||||||
|
return checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample(nn.Module):
|
||||||
|
"""
|
||||||
|
Down-sampling block"
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
dim,
|
||||||
|
keep_dim=False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
dim: feature size dimension.
|
||||||
|
norm_layer: normalization layer.
|
||||||
|
keep_dim: bool argument for maintaining the resolution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
if keep_dim:
|
||||||
|
dim_out = dim
|
||||||
|
else:
|
||||||
|
dim_out = 2 * dim
|
||||||
|
self.reduction = nn.Sequential(
|
||||||
|
nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.reduction(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEmbed(nn.Module):
|
||||||
|
"""
|
||||||
|
Patch embedding block"
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_chans=3, in_dim=64, dim=96):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
in_chans: number of input channels.
|
||||||
|
dim: feature size dimension.
|
||||||
|
"""
|
||||||
|
# in_dim = 1
|
||||||
|
super().__init__()
|
||||||
|
self.proj = nn.Identity()
|
||||||
|
self.conv_down = nn.Sequential(
|
||||||
|
nn.Conv2d(in_chans, in_dim, 3, 2, 1, bias=False),
|
||||||
|
nn.BatchNorm2d(in_dim, eps=1e-4),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(in_dim, dim, 3, 2, 1, bias=False),
|
||||||
|
nn.BatchNorm2d(dim, eps=1e-4),
|
||||||
|
nn.ReLU()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.proj(x)
|
||||||
|
x = self.conv_down(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim,
|
||||||
|
drop_path=0.,
|
||||||
|
layer_scale=None,
|
||||||
|
kernel_size=3):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
|
||||||
|
self.norm1 = nn.BatchNorm2d(dim, eps=1e-5)
|
||||||
|
self.act1 = nn.GELU(approximate= 'tanh')
|
||||||
|
self.conv2 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
|
||||||
|
self.norm2 = nn.BatchNorm2d(dim, eps=1e-5)
|
||||||
|
self.layer_scale = layer_scale
|
||||||
|
if layer_scale is not None and type(layer_scale) in [int, float]:
|
||||||
|
self.gamma = nn.Parameter(layer_scale * torch.ones(dim))
|
||||||
|
self.layer_scale = True
|
||||||
|
else:
|
||||||
|
self.layer_scale = False
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
input = x
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = self.act1(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = self.norm2(x)
|
||||||
|
if self.layer_scale:
|
||||||
|
x = x * self.gamma.view(1, -1, 1, 1)
|
||||||
|
x = input + self.drop_path(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MambaVisionMixer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model,
|
||||||
|
d_state=16,
|
||||||
|
d_conv=4,
|
||||||
|
expand=2,
|
||||||
|
dt_rank="auto",
|
||||||
|
dt_min=0.001,
|
||||||
|
dt_max=0.1,
|
||||||
|
dt_init="random",
|
||||||
|
dt_scale=1.0,
|
||||||
|
dt_init_floor=1e-4,
|
||||||
|
conv_bias=True,
|
||||||
|
bias=False,
|
||||||
|
use_fast_path=True,
|
||||||
|
layer_idx=None,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
):
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
super().__init__()
|
||||||
|
self.d_model = d_model
|
||||||
|
self.d_state = d_state
|
||||||
|
self.d_conv = d_conv
|
||||||
|
self.expand = expand
|
||||||
|
self.d_inner = int(self.expand * self.d_model)
|
||||||
|
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
|
||||||
|
self.use_fast_path = use_fast_path
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs)
|
||||||
|
self.x_proj = nn.Linear(
|
||||||
|
self.d_inner//2, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
|
||||||
|
)
|
||||||
|
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner//2, bias=True, **factory_kwargs)
|
||||||
|
dt_init_std = self.dt_rank**-0.5 * dt_scale
|
||||||
|
if dt_init == "constant":
|
||||||
|
nn.init.constant_(self.dt_proj.weight, dt_init_std)
|
||||||
|
elif dt_init == "random":
|
||||||
|
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
dt = torch.exp(
|
||||||
|
torch.rand(self.d_inner//2, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
|
||||||
|
+ math.log(dt_min)
|
||||||
|
).clamp(min=dt_init_floor)
|
||||||
|
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
||||||
|
with torch.no_grad():
|
||||||
|
self.dt_proj.bias.copy_(inv_dt)
|
||||||
|
self.dt_proj.bias._no_reinit = True
|
||||||
|
A = repeat(
|
||||||
|
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
|
||||||
|
"n -> d n",
|
||||||
|
d=self.d_inner//2,
|
||||||
|
).contiguous()
|
||||||
|
A_log = torch.log(A)
|
||||||
|
self.A_log = nn.Parameter(A_log)
|
||||||
|
self.A_log._no_weight_decay = True
|
||||||
|
self.D = nn.Parameter(torch.ones(self.d_inner//2, device=device))
|
||||||
|
self.D._no_weight_decay = True
|
||||||
|
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
|
||||||
|
self.conv1d_x = nn.Conv1d(
|
||||||
|
in_channels=self.d_inner//2,
|
||||||
|
out_channels=self.d_inner//2,
|
||||||
|
bias=conv_bias//2,
|
||||||
|
kernel_size=d_conv,
|
||||||
|
groups=self.d_inner//2,
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
|
self.conv1d_z = nn.Conv1d(
|
||||||
|
in_channels=self.d_inner//2,
|
||||||
|
out_channels=self.d_inner//2,
|
||||||
|
bias=conv_bias//2,
|
||||||
|
kernel_size=d_conv,
|
||||||
|
groups=self.d_inner//2,
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
"""
|
||||||
|
hidden_states: (B, L, D)
|
||||||
|
Returns: same shape as hidden_states
|
||||||
|
"""
|
||||||
|
_, seqlen, _ = hidden_states.shape
|
||||||
|
xz = self.in_proj(hidden_states)
|
||||||
|
xz = rearrange(xz, "b l d -> b d l")
|
||||||
|
x, z = xz.chunk(2, dim=1)
|
||||||
|
A = -torch.exp(self.A_log.float())
|
||||||
|
x = F.silu(F.conv1d(input=x, weight=self.conv1d_x.weight, bias=self.conv1d_x.bias, padding='same', groups=self.d_inner//2))
|
||||||
|
z = F.silu(F.conv1d(input=z, weight=self.conv1d_z.weight, bias=self.conv1d_z.bias, padding='same', groups=self.d_inner//2))
|
||||||
|
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))
|
||||||
|
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
||||||
|
dt = rearrange(self.dt_proj(dt), "(b l) d -> b d l", l=seqlen)
|
||||||
|
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
||||||
|
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
||||||
|
y = selective_scan_fn(x,
|
||||||
|
dt,
|
||||||
|
A,
|
||||||
|
B,
|
||||||
|
C,
|
||||||
|
self.D.float(),
|
||||||
|
z=None,
|
||||||
|
delta_bias=self.dt_proj.bias.float(),
|
||||||
|
delta_softplus=True,
|
||||||
|
return_last_state=None)
|
||||||
|
|
||||||
|
y = torch.cat([y, z], dim=1)
|
||||||
|
y = rearrange(y, "b d l -> b l d")
|
||||||
|
out = self.out_proj(y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
num_heads=8,
|
||||||
|
qkv_bias=False,
|
||||||
|
qk_norm=False,
|
||||||
|
attn_drop=0.,
|
||||||
|
proj_drop=0.,
|
||||||
|
norm_layer=nn.LayerNorm,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert dim % num_heads == 0
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
self.scale = self.head_dim ** -0.5
|
||||||
|
self.fused_attn = True
|
||||||
|
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||||
|
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||||
|
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||||
|
self.attn_drop = nn.Dropout(attn_drop)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
self.proj_drop = nn.Dropout(proj_drop)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
B, N, C = x.shape
|
||||||
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||||
|
q, k, v = qkv.unbind(0)
|
||||||
|
q, k = self.q_norm(q), self.k_norm(k)
|
||||||
|
|
||||||
|
if self.fused_attn:
|
||||||
|
x = F.scaled_dot_product_attention(
|
||||||
|
q, k, v,
|
||||||
|
dropout_p=self.attn_drop.p,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
q = q * self.scale
|
||||||
|
attn = q @ k.transpose(-2, -1)
|
||||||
|
attn = attn.softmax(dim=-1)
|
||||||
|
attn = self.attn_drop(attn)
|
||||||
|
x = attn @ v
|
||||||
|
|
||||||
|
x = x.transpose(1, 2).reshape(B, N, C)
|
||||||
|
x = self.proj(x)
|
||||||
|
x = self.proj_drop(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Block(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
dim,
|
||||||
|
num_heads,
|
||||||
|
counter,
|
||||||
|
transformer_blocks,
|
||||||
|
mlp_ratio=4.,
|
||||||
|
qkv_bias=False,
|
||||||
|
qk_scale=False,
|
||||||
|
drop=0.,
|
||||||
|
attn_drop=0.,
|
||||||
|
drop_path=0.,
|
||||||
|
act_layer=nn.GELU,
|
||||||
|
norm_layer=nn.LayerNorm,
|
||||||
|
Mlp_block=Mlp,
|
||||||
|
layer_scale=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.norm1 = norm_layer(dim)
|
||||||
|
if counter in transformer_blocks:
|
||||||
|
self.mixer = Attention(
|
||||||
|
dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_norm=qk_scale,
|
||||||
|
attn_drop=attn_drop,
|
||||||
|
proj_drop=drop,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.mixer = MambaVisionMixer(d_model=dim,
|
||||||
|
d_state=8,
|
||||||
|
d_conv=3,
|
||||||
|
expand=1
|
||||||
|
)
|
||||||
|
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||||
|
self.norm2 = norm_layer(dim)
|
||||||
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
|
self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||||
|
use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
|
||||||
|
self.gamma_1 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
|
||||||
|
self.gamma_2 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x + self.drop_path(self.gamma_1 * self.mixer(self.norm1(x)))
|
||||||
|
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MambaVisionLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
MambaVision layer"
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
dim,
|
||||||
|
depth,
|
||||||
|
num_heads,
|
||||||
|
window_size,
|
||||||
|
conv=False,
|
||||||
|
downsample=True,
|
||||||
|
mlp_ratio=4.,
|
||||||
|
qkv_bias=True,
|
||||||
|
qk_scale=None,
|
||||||
|
drop=0.,
|
||||||
|
attn_drop=0.,
|
||||||
|
drop_path=0.,
|
||||||
|
layer_scale=None,
|
||||||
|
layer_scale_conv=None,
|
||||||
|
transformer_blocks = [],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
dim: feature size dimension.
|
||||||
|
depth: number of layers in each stage.
|
||||||
|
window_size: window size in each stage.
|
||||||
|
conv: bool argument for conv stage flag.
|
||||||
|
downsample: bool argument for down-sampling.
|
||||||
|
mlp_ratio: MLP ratio.
|
||||||
|
num_heads: number of heads in each stage.
|
||||||
|
qkv_bias: bool argument for query, key, value learnable bias.
|
||||||
|
qk_scale: bool argument to scaling query, key.
|
||||||
|
drop: dropout rate.
|
||||||
|
attn_drop: attention dropout rate.
|
||||||
|
drop_path: drop path rate.
|
||||||
|
norm_layer: normalization layer.
|
||||||
|
layer_scale: layer scaling coefficient.
|
||||||
|
layer_scale_conv: conv layer scaling coefficient.
|
||||||
|
transformer_blocks: list of transformer blocks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.conv = conv
|
||||||
|
self.transformer_block = False
|
||||||
|
if conv:
|
||||||
|
self.blocks = nn.ModuleList([ConvBlock(dim=dim,
|
||||||
|
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||||
|
layer_scale=layer_scale_conv)
|
||||||
|
for i in range(depth)])
|
||||||
|
self.transformer_block = False
|
||||||
|
else:
|
||||||
|
self.transformer_block = True
|
||||||
|
self.blocks = nn.ModuleList([Block(dim=dim,
|
||||||
|
counter=i,
|
||||||
|
transformer_blocks=transformer_blocks,
|
||||||
|
num_heads=num_heads,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_scale=qk_scale,
|
||||||
|
drop=drop,
|
||||||
|
attn_drop=attn_drop,
|
||||||
|
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||||
|
layer_scale=layer_scale)
|
||||||
|
for i in range(depth)])
|
||||||
|
self.transformer_block = True
|
||||||
|
|
||||||
|
self.downsample = None if not downsample else Downsample(dim=dim)
|
||||||
|
self.do_gt = False
|
||||||
|
self.window_size = window_size
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
_, _, H, W = x.shape
|
||||||
|
|
||||||
|
if self.transformer_block:
|
||||||
|
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
||||||
|
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
||||||
|
if pad_r > 0 or pad_b > 0:
|
||||||
|
x = torch.nn.functional.pad(x, (0,pad_r,0,pad_b))
|
||||||
|
_, _, Hp, Wp = x.shape
|
||||||
|
else:
|
||||||
|
Hp, Wp = H, W
|
||||||
|
x = window_partition(x, self.window_size)
|
||||||
|
|
||||||
|
for _, blk in enumerate(self.blocks):
|
||||||
|
x = blk(x)
|
||||||
|
if self.transformer_block:
|
||||||
|
x = window_reverse(x, self.window_size, Hp, Wp)
|
||||||
|
if pad_r > 0 or pad_b > 0:
|
||||||
|
x = x[:, :, :H, :W].contiguous()
|
||||||
|
if self.downsample is None:
|
||||||
|
return x
|
||||||
|
return self.downsample(x)
|
||||||
|
|
||||||
|
|
||||||
|
class MambaVision(nn.Module, PyTorchModelHubMixin):
|
||||||
|
"""
|
||||||
|
MambaVision,
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
dim,
|
||||||
|
in_dim,
|
||||||
|
depths,
|
||||||
|
window_size,
|
||||||
|
mlp_ratio,
|
||||||
|
num_heads,
|
||||||
|
drop_path_rate=0.2,
|
||||||
|
in_chans=3,
|
||||||
|
num_classes=1000,
|
||||||
|
qkv_bias=True,
|
||||||
|
qk_scale=None,
|
||||||
|
drop_rate=0.,
|
||||||
|
attn_drop_rate=0.,
|
||||||
|
layer_scale=None,
|
||||||
|
layer_scale_conv=None,
|
||||||
|
**kwargs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
dim: feature size dimension.
|
||||||
|
depths: number of layers in each stage.
|
||||||
|
window_size: window size in each stage.
|
||||||
|
mlp_ratio: MLP ratio.
|
||||||
|
num_heads: number of heads in each stage.
|
||||||
|
drop_path_rate: drop path rate.
|
||||||
|
in_chans: number of input channels.
|
||||||
|
num_classes: number of classes.
|
||||||
|
qkv_bias: bool argument for query, key, value learnable bias.
|
||||||
|
qk_scale: bool argument to scaling query, key.
|
||||||
|
drop_rate: dropout rate.
|
||||||
|
attn_drop_rate: attention dropout rate.
|
||||||
|
norm_layer: normalization layer.
|
||||||
|
layer_scale: layer scaling coefficient.
|
||||||
|
layer_scale_conv: conv layer scaling coefficient.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
num_features = int(dim * 2 ** (len(depths) - 1))
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim)
|
||||||
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
||||||
|
self.levels = nn.ModuleList()
|
||||||
|
for i in range(len(depths)):
|
||||||
|
conv = True if (i == 0 or i == 1) else False
|
||||||
|
level = MambaVisionLayer(dim=int(dim * 2 ** i),
|
||||||
|
depth=depths[i],
|
||||||
|
num_heads=num_heads[i],
|
||||||
|
window_size=window_size[i],
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_scale=qk_scale,
|
||||||
|
conv=conv,
|
||||||
|
drop=drop_rate,
|
||||||
|
attn_drop=attn_drop_rate,
|
||||||
|
drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
|
||||||
|
downsample=(i < 3),
|
||||||
|
layer_scale=layer_scale,
|
||||||
|
layer_scale_conv=layer_scale_conv,
|
||||||
|
transformer_blocks=list(range(depths[i]//2+1, depths[i])) if depths[i]%2!=0 else list(range(depths[i]//2, depths[i])),
|
||||||
|
)
|
||||||
|
self.levels.append(level)
|
||||||
|
self.norm = nn.BatchNorm2d(num_features)
|
||||||
|
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
||||||
|
self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
|
||||||
|
def _init_weights(self, m):
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
trunc_normal_(m.weight, std=.02)
|
||||||
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.LayerNorm):
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
nn.init.constant_(m.weight, 1.0)
|
||||||
|
elif isinstance(m, LayerNorm2d):
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
nn.init.constant_(m.weight, 1.0)
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
nn.init.ones_(m.weight)
|
||||||
|
nn.init.zeros_(m.bias)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def no_weight_decay_keywords(self):
|
||||||
|
return {'rpb'}
|
||||||
|
|
||||||
|
def forward_features(self, x):
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
for level in self.levels:
|
||||||
|
x = level(x)
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.avgpool(x)
|
||||||
|
x = torch.flatten(x, 1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.forward_features(x)
|
||||||
|
x = self.head(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _load_state_dict(self,
|
||||||
|
pretrained,
|
||||||
|
strict: bool = False):
|
||||||
|
_load_checkpoint(self,
|
||||||
|
pretrained,
|
||||||
|
strict=strict)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def mamba_vision_T(pretrained=False, **kwargs):
|
||||||
|
model_path = kwargs.pop("model_path", "/tmp/mamba_vision_T.pth.tar")
|
||||||
|
pretrained_cfg = resolve_pretrained_cfg('mamba_vision_T').to_dict()
|
||||||
|
update_args(pretrained_cfg, kwargs, kwargs_filter=None)
|
||||||
|
model = MambaVision(depths=[1, 3, 8, 4],
|
||||||
|
num_heads=[2, 4, 8, 16],
|
||||||
|
window_size=[8, 8, 14, 7],
|
||||||
|
dim=80,
|
||||||
|
in_dim=32,
|
||||||
|
mlp_ratio=4,
|
||||||
|
resolution=224,
|
||||||
|
drop_path_rate=0.2,
|
||||||
|
**kwargs)
|
||||||
|
model.pretrained_cfg = pretrained_cfg
|
||||||
|
model.default_cfg = model.pretrained_cfg
|
||||||
|
if pretrained:
|
||||||
|
if not Path(model_path).is_file():
|
||||||
|
url = model.default_cfg['url']
|
||||||
|
torch.hub.download_url_to_file(url=url, dst=model_path)
|
||||||
|
model._load_state_dict(model_path)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def mamba_vision_T2(pretrained=False, **kwargs):
|
||||||
|
model_path = kwargs.pop("model_path", "/tmp/mamba_vision_T2.pth.tar")
|
||||||
|
pretrained_cfg = resolve_pretrained_cfg('mamba_vision_T2').to_dict()
|
||||||
|
update_args(pretrained_cfg, kwargs, kwargs_filter=None)
|
||||||
|
model = MambaVision(depths=[1, 3, 11, 4],
|
||||||
|
num_heads=[2, 4, 8, 16],
|
||||||
|
window_size=[8, 8, 14, 7],
|
||||||
|
dim=80,
|
||||||
|
in_dim=32,
|
||||||
|
mlp_ratio=4,
|
||||||
|
resolution=224,
|
||||||
|
drop_path_rate=0.2,
|
||||||
|
**kwargs)
|
||||||
|
model.pretrained_cfg = pretrained_cfg
|
||||||
|
model.default_cfg = model.pretrained_cfg
|
||||||
|
if pretrained:
|
||||||
|
if not Path(model_path).is_file():
|
||||||
|
url = model.default_cfg['url']
|
||||||
|
torch.hub.download_url_to_file(url=url, dst=model_path)
|
||||||
|
model._load_state_dict(model_path)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def mamba_vision_S(pretrained=False, **kwargs):
|
||||||
|
model_path = kwargs.pop("model_path", "/tmp/mamba_vision_S.pth.tar")
|
||||||
|
pretrained_cfg = resolve_pretrained_cfg('mamba_vision_S').to_dict()
|
||||||
|
update_args(pretrained_cfg, kwargs, kwargs_filter=None)
|
||||||
|
model = MambaVision(depths=[3, 3, 7, 5],
|
||||||
|
num_heads=[2, 4, 8, 16],
|
||||||
|
window_size=[8, 8, 14, 7],
|
||||||
|
dim=96,
|
||||||
|
in_dim=64,
|
||||||
|
mlp_ratio=4,
|
||||||
|
resolution=224,
|
||||||
|
drop_path_rate=0.2,
|
||||||
|
**kwargs)
|
||||||
|
model.pretrained_cfg = pretrained_cfg
|
||||||
|
model.default_cfg = model.pretrained_cfg
|
||||||
|
if pretrained:
|
||||||
|
if not Path(model_path).is_file():
|
||||||
|
url = model.default_cfg['url']
|
||||||
|
torch.hub.download_url_to_file(url=url, dst=model_path)
|
||||||
|
model._load_state_dict(model_path)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def mamba_vision_B(pretrained=False, **kwargs):
|
||||||
|
model_path = kwargs.pop("model_path", "/tmp/mamba_vision_B.pth.tar")
|
||||||
|
pretrained_cfg = resolve_pretrained_cfg('mamba_vision_B').to_dict()
|
||||||
|
update_args(pretrained_cfg, kwargs, kwargs_filter=None)
|
||||||
|
model = MambaVision(depths=[3, 3, 10, 5],
|
||||||
|
num_heads=[2, 4, 8, 16],
|
||||||
|
window_size=[8, 8, 14, 7],
|
||||||
|
dim=128,
|
||||||
|
in_dim=64,
|
||||||
|
mlp_ratio=4,
|
||||||
|
resolution=224,
|
||||||
|
drop_path_rate=0.3,
|
||||||
|
layer_scale=1e-5,
|
||||||
|
layer_scale_conv=None,
|
||||||
|
**kwargs)
|
||||||
|
model.pretrained_cfg = pretrained_cfg
|
||||||
|
model.default_cfg = model.pretrained_cfg
|
||||||
|
if pretrained:
|
||||||
|
if not Path(model_path).is_file():
|
||||||
|
url = model.default_cfg['url']
|
||||||
|
torch.hub.download_url_to_file(url=url, dst=model_path)
|
||||||
|
model._load_state_dict(model_path)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def mamba_vision_L(pretrained=False, **kwargs):
|
||||||
|
model_path = kwargs.pop("model_path", "/tmp/mamba_vision_L.pth.tar")
|
||||||
|
pretrained_cfg = resolve_pretrained_cfg('mamba_vision_L').to_dict()
|
||||||
|
update_args(pretrained_cfg, kwargs, kwargs_filter=None)
|
||||||
|
model = MambaVision(depths=[3, 3, 10, 5],
|
||||||
|
num_heads=[4, 8, 16, 32],
|
||||||
|
window_size=[8, 8, 14, 7],
|
||||||
|
dim=196,
|
||||||
|
in_dim=64,
|
||||||
|
mlp_ratio=4,
|
||||||
|
resolution=224,
|
||||||
|
drop_path_rate=0.3,
|
||||||
|
layer_scale=1e-5,
|
||||||
|
layer_scale_conv=None,
|
||||||
|
**kwargs)
|
||||||
|
model.pretrained_cfg = pretrained_cfg
|
||||||
|
model.default_cfg = model.pretrained_cfg
|
||||||
|
if pretrained:
|
||||||
|
if not Path(model_path).is_file():
|
||||||
|
url = model.default_cfg['url']
|
||||||
|
torch.hub.download_url_to_file(url=url, dst=model_path)
|
||||||
|
model._load_state_dict(model_path)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def mamba_vision_L2(pretrained=False, **kwargs):
|
||||||
|
model_path = kwargs.pop("model_path", "/tmp/mamba_vision_L2.pth.tar")
|
||||||
|
pretrained_cfg = resolve_pretrained_cfg('mamba_vision_L2').to_dict()
|
||||||
|
update_args(pretrained_cfg, kwargs, kwargs_filter=None)
|
||||||
|
model = MambaVision(depths=[3, 3, 12, 5],
|
||||||
|
num_heads=[4, 8, 16, 32],
|
||||||
|
window_size=[8, 8, 14, 7],
|
||||||
|
dim=196,
|
||||||
|
in_dim=64,
|
||||||
|
mlp_ratio=4,
|
||||||
|
resolution=224,
|
||||||
|
drop_path_rate=0.3,
|
||||||
|
layer_scale=1e-5,
|
||||||
|
layer_scale_conv=None,
|
||||||
|
**kwargs)
|
||||||
|
model.pretrained_cfg = pretrained_cfg
|
||||||
|
model.default_cfg = model.pretrained_cfg
|
||||||
|
if pretrained:
|
||||||
|
if not Path(model_path).is_file():
|
||||||
|
url = model.default_cfg['url']
|
||||||
|
torch.hub.download_url_to_file(url=url, dst=model_path)
|
||||||
|
model._load_state_dict(model_path)
|
||||||
|
return model
|
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,764 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||||
|
# and proprietary rights in and to this software, related documentation
|
||||||
|
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||||
|
# distribution of this software and related documentation without an express
|
||||||
|
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from timm.models.registry import register_model
|
||||||
|
import math
|
||||||
|
from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
|
||||||
|
from timm.models._builder import resolve_pretrained_cfg
|
||||||
|
try:
|
||||||
|
from timm.models._builder import _update_default_kwargs as update_args
|
||||||
|
except:
|
||||||
|
from timm.models._builder import _update_default_model_kwargs as update_args
|
||||||
|
from timm.models.vision_transformer import Mlp, PatchEmbed
|
||||||
|
from timm.models.layers import DropPath, trunc_normal_
|
||||||
|
from timm.models.registry import register_model
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
from transformers import PreTrainedModel
|
||||||
|
|
||||||
|
from configuration_mambavision import MambaVisionConfig
|
||||||
|
|
||||||
|
|
||||||
|
def _cfg(url='', **kwargs):
|
||||||
|
return {'url': url,
|
||||||
|
'num_classes': 1000,
|
||||||
|
'input_size': (3, 224, 224),
|
||||||
|
'pool_size': None,
|
||||||
|
'crop_pct': 0.875,
|
||||||
|
'interpolation': 'bicubic',
|
||||||
|
'fixed_input_size': True,
|
||||||
|
'mean': (0.485, 0.456, 0.406),
|
||||||
|
'std': (0.229, 0.224, 0.225),
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
default_cfgs = {
|
||||||
|
'mamba_vision_T': _cfg(url='https://huggingface.co/nvidia/MambaVision-T-1K/resolve/main/mambavision_tiny_1k.pth.tar',
|
||||||
|
crop_pct=1.0,
|
||||||
|
input_size=(3, 224, 224),
|
||||||
|
crop_mode='center'),
|
||||||
|
'mamba_vision_T2': _cfg(url='https://huggingface.co/nvidia/MambaVision-T2-1K/resolve/main/mambavision_tiny2_1k.pth.tar',
|
||||||
|
crop_pct=0.98,
|
||||||
|
input_size=(3, 224, 224),
|
||||||
|
crop_mode='center'),
|
||||||
|
'mamba_vision_S': _cfg(url='https://huggingface.co/nvidia/MambaVision-S-1K/resolve/main/mambavision_small_1k.pth.tar',
|
||||||
|
crop_pct=0.93,
|
||||||
|
input_size=(3, 224, 224),
|
||||||
|
crop_mode='center'),
|
||||||
|
'mamba_vision_B': _cfg(url='https://huggingface.co/nvidia/MambaVision-B-1K/resolve/main/mambavision_base_1k.pth.tar',
|
||||||
|
crop_pct=1.0,
|
||||||
|
input_size=(3, 224, 224),
|
||||||
|
crop_mode='center'),
|
||||||
|
'mamba_vision_L': _cfg(url='https://huggingface.co/nvidia/MambaVision-L-1K/resolve/main/mambavision_large_1k.pth.tar',
|
||||||
|
crop_pct=1.0,
|
||||||
|
input_size=(3, 224, 224),
|
||||||
|
crop_mode='center'),
|
||||||
|
'mamba_vision_L2': _cfg(url='https://huggingface.co/nvidia/MambaVision-L2-1K/resolve/main/mambavision_large2_1k.pth.tar',
|
||||||
|
crop_pct=1.0,
|
||||||
|
input_size=(3, 224, 224),
|
||||||
|
crop_mode='center')
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def window_partition(x, window_size):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (B, C, H, W)
|
||||||
|
window_size: window size
|
||||||
|
h_w: Height of window
|
||||||
|
w_w: Width of window
|
||||||
|
Returns:
|
||||||
|
local window features (num_windows*B, window_size*window_size, C)
|
||||||
|
"""
|
||||||
|
B, C, H, W = x.shape
|
||||||
|
x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
|
||||||
|
windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C)
|
||||||
|
return windows
|
||||||
|
|
||||||
|
|
||||||
|
def window_reverse(windows, window_size, H, W):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
windows: local window features (num_windows*B, window_size, window_size, C)
|
||||||
|
window_size: Window size
|
||||||
|
H: Height of image
|
||||||
|
W: Width of image
|
||||||
|
Returns:
|
||||||
|
x: (B, C, H, W)
|
||||||
|
"""
|
||||||
|
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||||
|
x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1)
|
||||||
|
x = x.permute(0, 5, 1, 3, 2, 4).reshape(B,windows.shape[2], H, W)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def _load_state_dict(module, state_dict, strict=False, logger=None):
|
||||||
|
"""Load state_dict to a module.
|
||||||
|
|
||||||
|
This method is modified from :meth:`torch.nn.Module.load_state_dict`.
|
||||||
|
Default value for ``strict`` is set to ``False`` and the message for
|
||||||
|
param mismatch will be shown even if strict is False.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (Module): Module that receives the state_dict.
|
||||||
|
state_dict (OrderedDict): Weights.
|
||||||
|
strict (bool): whether to strictly enforce that the keys
|
||||||
|
in :attr:`state_dict` match the keys returned by this module's
|
||||||
|
:meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
|
||||||
|
logger (:obj:`logging.Logger`, optional): Logger to log the error
|
||||||
|
message. If not specified, print function will be used.
|
||||||
|
"""
|
||||||
|
unexpected_keys = []
|
||||||
|
all_missing_keys = []
|
||||||
|
err_msg = []
|
||||||
|
|
||||||
|
metadata = getattr(state_dict, '_metadata', None)
|
||||||
|
state_dict = state_dict.copy()
|
||||||
|
if metadata is not None:
|
||||||
|
state_dict._metadata = metadata
|
||||||
|
|
||||||
|
def load(module, prefix=''):
|
||||||
|
local_metadata = {} if metadata is None else metadata.get(
|
||||||
|
prefix[:-1], {})
|
||||||
|
module._load_from_state_dict(state_dict, prefix, local_metadata, True,
|
||||||
|
all_missing_keys, unexpected_keys,
|
||||||
|
err_msg)
|
||||||
|
for name, child in module._modules.items():
|
||||||
|
if child is not None:
|
||||||
|
load(child, prefix + name + '.')
|
||||||
|
|
||||||
|
load(module)
|
||||||
|
load = None
|
||||||
|
missing_keys = [
|
||||||
|
key for key in all_missing_keys if 'num_batches_tracked' not in key
|
||||||
|
]
|
||||||
|
|
||||||
|
if unexpected_keys:
|
||||||
|
err_msg.append('unexpected key in source '
|
||||||
|
f'state_dict: {", ".join(unexpected_keys)}\n')
|
||||||
|
if missing_keys:
|
||||||
|
err_msg.append(
|
||||||
|
f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
|
||||||
|
|
||||||
|
|
||||||
|
if len(err_msg) > 0:
|
||||||
|
err_msg.insert(
|
||||||
|
0, 'The model and loaded state dict do not match exactly\n')
|
||||||
|
err_msg = '\n'.join(err_msg)
|
||||||
|
if strict:
|
||||||
|
raise RuntimeError(err_msg)
|
||||||
|
elif logger is not None:
|
||||||
|
logger.warning(err_msg)
|
||||||
|
else:
|
||||||
|
print(err_msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_checkpoint(model,
|
||||||
|
filename,
|
||||||
|
map_location='cpu',
|
||||||
|
strict=False,
|
||||||
|
logger=None):
|
||||||
|
"""Load checkpoint from a file or URI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (Module): Module to load checkpoint.
|
||||||
|
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
|
||||||
|
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
|
||||||
|
details.
|
||||||
|
map_location (str): Same as :func:`torch.load`.
|
||||||
|
strict (bool): Whether to allow different params for the model and
|
||||||
|
checkpoint.
|
||||||
|
logger (:mod:`logging.Logger` or None): The logger for error message.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict or OrderedDict: The loaded checkpoint.
|
||||||
|
"""
|
||||||
|
checkpoint = torch.load(filename, map_location=map_location)
|
||||||
|
if not isinstance(checkpoint, dict):
|
||||||
|
raise RuntimeError(
|
||||||
|
f'No state_dict found in checkpoint file {filename}')
|
||||||
|
if 'state_dict' in checkpoint:
|
||||||
|
state_dict = checkpoint['state_dict']
|
||||||
|
elif 'model' in checkpoint:
|
||||||
|
state_dict = checkpoint['model']
|
||||||
|
else:
|
||||||
|
state_dict = checkpoint
|
||||||
|
if list(state_dict.keys())[0].startswith('module.'):
|
||||||
|
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
||||||
|
|
||||||
|
if sorted(list(state_dict.keys()))[0].startswith('encoder'):
|
||||||
|
state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
|
||||||
|
|
||||||
|
_load_state_dict(model, state_dict, strict, logger)
|
||||||
|
return checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample(nn.Module):
|
||||||
|
"""
|
||||||
|
Down-sampling block"
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
dim,
|
||||||
|
keep_dim=False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
dim: feature size dimension.
|
||||||
|
norm_layer: normalization layer.
|
||||||
|
keep_dim: bool argument for maintaining the resolution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
if keep_dim:
|
||||||
|
dim_out = dim
|
||||||
|
else:
|
||||||
|
dim_out = 2 * dim
|
||||||
|
self.reduction = nn.Sequential(
|
||||||
|
nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.reduction(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEmbed(nn.Module):
|
||||||
|
"""
|
||||||
|
Patch embedding block"
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_chans=3, in_dim=64, dim=96):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
in_chans: number of input channels.
|
||||||
|
dim: feature size dimension.
|
||||||
|
"""
|
||||||
|
# in_dim = 1
|
||||||
|
super().__init__()
|
||||||
|
self.proj = nn.Identity()
|
||||||
|
self.conv_down = nn.Sequential(
|
||||||
|
nn.Conv2d(in_chans, in_dim, 3, 2, 1, bias=False),
|
||||||
|
nn.BatchNorm2d(in_dim, eps=1e-4),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(in_dim, dim, 3, 2, 1, bias=False),
|
||||||
|
nn.BatchNorm2d(dim, eps=1e-4),
|
||||||
|
nn.ReLU()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.proj(x)
|
||||||
|
x = self.conv_down(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim,
|
||||||
|
drop_path=0.,
|
||||||
|
layer_scale=None,
|
||||||
|
kernel_size=3):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
|
||||||
|
self.norm1 = nn.BatchNorm2d(dim, eps=1e-5)
|
||||||
|
self.act1 = nn.GELU(approximate= 'tanh')
|
||||||
|
self.conv2 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
|
||||||
|
self.norm2 = nn.BatchNorm2d(dim, eps=1e-5)
|
||||||
|
self.layer_scale = layer_scale
|
||||||
|
if layer_scale is not None and type(layer_scale) in [int, float]:
|
||||||
|
self.g = nn.Parameter(layer_scale * torch.ones(dim))
|
||||||
|
self.layer_scale = True
|
||||||
|
else:
|
||||||
|
self.layer_scale = False
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
input = x
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = self.act1(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = self.norm2(x)
|
||||||
|
if self.layer_scale:
|
||||||
|
x = x * self.g.view(1, -1, 1, 1)
|
||||||
|
x = input + self.drop_path(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MambaVisionMixer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model,
|
||||||
|
d_state=16,
|
||||||
|
d_conv=4,
|
||||||
|
expand=2,
|
||||||
|
dt_rank="auto",
|
||||||
|
dt_min=0.001,
|
||||||
|
dt_max=0.1,
|
||||||
|
dt_init="random",
|
||||||
|
dt_scale=1.0,
|
||||||
|
dt_init_floor=1e-4,
|
||||||
|
conv_bias=True,
|
||||||
|
bias=False,
|
||||||
|
use_fast_path=True,
|
||||||
|
layer_idx=None,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
):
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
super().__init__()
|
||||||
|
self.d_model = d_model
|
||||||
|
self.d_state = d_state
|
||||||
|
self.d_conv = d_conv
|
||||||
|
self.expand = expand
|
||||||
|
self.d_inner = int(self.expand * self.d_model)
|
||||||
|
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
|
||||||
|
self.use_fast_path = use_fast_path
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs)
|
||||||
|
self.x_proj = nn.Linear(
|
||||||
|
self.d_inner//2, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
|
||||||
|
)
|
||||||
|
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner//2, bias=True, **factory_kwargs)
|
||||||
|
dt_init_std = self.dt_rank**-0.5 * dt_scale
|
||||||
|
if dt_init == "constant":
|
||||||
|
nn.init.constant_(self.dt_proj.weight, dt_init_std)
|
||||||
|
elif dt_init == "random":
|
||||||
|
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
dt = torch.exp(
|
||||||
|
torch.rand(self.d_inner//2, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
|
||||||
|
+ math.log(dt_min)
|
||||||
|
).clamp(min=dt_init_floor)
|
||||||
|
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
||||||
|
with torch.no_grad():
|
||||||
|
self.dt_proj.bias.copy_(inv_dt)
|
||||||
|
self.dt_proj.bias._no_reinit = True
|
||||||
|
A = repeat(
|
||||||
|
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
|
||||||
|
"n -> d n",
|
||||||
|
d=self.d_inner//2,
|
||||||
|
).contiguous()
|
||||||
|
A_log = torch.log(A)
|
||||||
|
self.A_log = nn.Parameter(A_log)
|
||||||
|
self.A_log._no_weight_decay = True
|
||||||
|
self.D = nn.Parameter(torch.ones(self.d_inner//2, device=device))
|
||||||
|
self.D._no_weight_decay = True
|
||||||
|
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
|
||||||
|
self.conv1d_x = nn.Conv1d(
|
||||||
|
in_channels=self.d_inner//2,
|
||||||
|
out_channels=self.d_inner//2,
|
||||||
|
bias=conv_bias//2,
|
||||||
|
kernel_size=d_conv,
|
||||||
|
groups=self.d_inner//2,
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
|
self.conv1d_z = nn.Conv1d(
|
||||||
|
in_channels=self.d_inner//2,
|
||||||
|
out_channels=self.d_inner//2,
|
||||||
|
bias=conv_bias//2,
|
||||||
|
kernel_size=d_conv,
|
||||||
|
groups=self.d_inner//2,
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
"""
|
||||||
|
hidden_states: (B, L, D)
|
||||||
|
Returns: same shape as hidden_states
|
||||||
|
"""
|
||||||
|
_, seqlen, _ = hidden_states.shape
|
||||||
|
xz = self.in_proj(hidden_states)
|
||||||
|
xz = rearrange(xz, "b l d -> b d l")
|
||||||
|
x, z = xz.chunk(2, dim=1)
|
||||||
|
A = -torch.exp(self.A_log.float())
|
||||||
|
x = F.silu(F.conv1d(input=x, weight=self.conv1d_x.weight, bias=self.conv1d_x.bias, padding='same', groups=self.d_inner//2))
|
||||||
|
z = F.silu(F.conv1d(input=z, weight=self.conv1d_z.weight, bias=self.conv1d_z.bias, padding='same', groups=self.d_inner//2))
|
||||||
|
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))
|
||||||
|
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
||||||
|
dt = rearrange(self.dt_proj(dt), "(b l) d -> b d l", l=seqlen)
|
||||||
|
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
||||||
|
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
||||||
|
y = selective_scan_fn(x,
|
||||||
|
dt,
|
||||||
|
A,
|
||||||
|
B,
|
||||||
|
C,
|
||||||
|
self.D.float(),
|
||||||
|
z=None,
|
||||||
|
delta_bias=self.dt_proj.bias.float(),
|
||||||
|
delta_softplus=True,
|
||||||
|
return_last_state=None)
|
||||||
|
|
||||||
|
y = torch.cat([y, z], dim=1)
|
||||||
|
y = rearrange(y, "b d l -> b l d")
|
||||||
|
out = self.out_proj(y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
num_heads=8,
|
||||||
|
qkv_bias=False,
|
||||||
|
qk_norm=False,
|
||||||
|
attn_drop=0.,
|
||||||
|
proj_drop=0.,
|
||||||
|
norm_layer=nn.LayerNorm,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert dim % num_heads == 0
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
self.scale = self.head_dim ** -0.5
|
||||||
|
self.fused_attn = True
|
||||||
|
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||||
|
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||||
|
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||||
|
self.attn_drop = nn.Dropout(attn_drop)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
self.proj_drop = nn.Dropout(proj_drop)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
B, N, C = x.shape
|
||||||
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||||
|
q, k, v = qkv.unbind(0)
|
||||||
|
q, k = self.q_norm(q), self.k_norm(k)
|
||||||
|
|
||||||
|
if self.fused_attn:
|
||||||
|
x = F.scaled_dot_product_attention(
|
||||||
|
q, k, v,
|
||||||
|
dropout_p=self.attn_drop.p,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
q = q * self.scale
|
||||||
|
attn = q @ k.transpose(-2, -1)
|
||||||
|
attn = attn.softmax(dim=-1)
|
||||||
|
attn = self.attn_drop(attn)
|
||||||
|
x = attn @ v
|
||||||
|
|
||||||
|
x = x.transpose(1, 2).reshape(B, N, C)
|
||||||
|
x = self.proj(x)
|
||||||
|
x = self.proj_drop(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Block(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
dim,
|
||||||
|
num_heads,
|
||||||
|
counter,
|
||||||
|
transformer_blocks,
|
||||||
|
mlp_ratio=4.,
|
||||||
|
qkv_bias=False,
|
||||||
|
qk_scale=False,
|
||||||
|
drop=0.,
|
||||||
|
attn_drop=0.,
|
||||||
|
drop_path=0.,
|
||||||
|
act_layer=nn.GELU,
|
||||||
|
norm_layer=nn.LayerNorm,
|
||||||
|
Mlp_block=Mlp,
|
||||||
|
layer_scale=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.norm1 = norm_layer(dim)
|
||||||
|
if counter in transformer_blocks:
|
||||||
|
self.mixer = Attention(
|
||||||
|
dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_norm=qk_scale,
|
||||||
|
attn_drop=attn_drop,
|
||||||
|
proj_drop=drop,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.mixer = MambaVisionMixer(d_model=dim,
|
||||||
|
d_state=8,
|
||||||
|
d_conv=3,
|
||||||
|
expand=1
|
||||||
|
)
|
||||||
|
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||||
|
self.norm2 = norm_layer(dim)
|
||||||
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
|
self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||||
|
use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
|
||||||
|
self.g_1 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
|
||||||
|
self.g_2 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x + self.drop_path(self.g_1 * self.mixer(self.norm1(x)))
|
||||||
|
x = x + self.drop_path(self.g_2 * self.mlp(self.norm2(x)))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MambaVisionLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
MambaVision layer"
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
dim,
|
||||||
|
depth,
|
||||||
|
num_heads,
|
||||||
|
window_size,
|
||||||
|
conv=False,
|
||||||
|
downsample=True,
|
||||||
|
mlp_ratio=4.,
|
||||||
|
qkv_bias=True,
|
||||||
|
qk_scale=None,
|
||||||
|
drop=0.,
|
||||||
|
attn_drop=0.,
|
||||||
|
drop_path=0.,
|
||||||
|
layer_scale=None,
|
||||||
|
layer_scale_conv=None,
|
||||||
|
transformer_blocks = [],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
dim: feature size dimension.
|
||||||
|
depth: number of layers in each stage.
|
||||||
|
window_size: window size in each stage.
|
||||||
|
conv: bool argument for conv stage flag.
|
||||||
|
downsample: bool argument for down-sampling.
|
||||||
|
mlp_ratio: MLP ratio.
|
||||||
|
num_heads: number of heads in each stage.
|
||||||
|
qkv_bias: bool argument for query, key, value learnable bias.
|
||||||
|
qk_scale: bool argument to scaling query, key.
|
||||||
|
drop: dropout rate.
|
||||||
|
attn_drop: attention dropout rate.
|
||||||
|
drop_path: drop path rate.
|
||||||
|
norm_layer: normalization layer.
|
||||||
|
layer_scale: layer scaling coefficient.
|
||||||
|
layer_scale_conv: conv layer scaling coefficient.
|
||||||
|
transformer_blocks: list of transformer blocks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.conv = conv
|
||||||
|
self.transformer_block = False
|
||||||
|
if conv:
|
||||||
|
self.blocks = nn.ModuleList([ConvBlock(dim=dim,
|
||||||
|
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||||
|
layer_scale=layer_scale_conv)
|
||||||
|
for i in range(depth)])
|
||||||
|
self.transformer_block = False
|
||||||
|
else:
|
||||||
|
self.transformer_block = True
|
||||||
|
self.blocks = nn.ModuleList([Block(dim=dim,
|
||||||
|
counter=i,
|
||||||
|
transformer_blocks=transformer_blocks,
|
||||||
|
num_heads=num_heads,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_scale=qk_scale,
|
||||||
|
drop=drop,
|
||||||
|
attn_drop=attn_drop,
|
||||||
|
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||||
|
layer_scale=layer_scale)
|
||||||
|
for i in range(depth)])
|
||||||
|
self.transformer_block = True
|
||||||
|
|
||||||
|
self.downsample = None if not downsample else Downsample(dim=dim)
|
||||||
|
self.do_gt = False
|
||||||
|
self.window_size = window_size
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
_, _, H, W = x.shape
|
||||||
|
|
||||||
|
if self.transformer_block:
|
||||||
|
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
||||||
|
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
||||||
|
if pad_r > 0 or pad_b > 0:
|
||||||
|
x = torch.nn.functional.pad(x, (0,pad_r,0,pad_b))
|
||||||
|
_, _, Hp, Wp = x.shape
|
||||||
|
else:
|
||||||
|
Hp, Wp = H, W
|
||||||
|
x = window_partition(x, self.window_size)
|
||||||
|
|
||||||
|
for _, blk in enumerate(self.blocks):
|
||||||
|
x = blk(x)
|
||||||
|
if self.transformer_block:
|
||||||
|
x = window_reverse(x, self.window_size, Hp, Wp)
|
||||||
|
if pad_r > 0 or pad_b > 0:
|
||||||
|
x = x[:, :, :H, :W].contiguous()
|
||||||
|
if self.downsample is None:
|
||||||
|
return x, x
|
||||||
|
return self.downsample(x), x
|
||||||
|
|
||||||
|
|
||||||
|
class MambaVision(nn.Module):
|
||||||
|
"""
|
||||||
|
MambaVision,
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
dim,
|
||||||
|
in_dim,
|
||||||
|
depths,
|
||||||
|
window_size,
|
||||||
|
mlp_ratio,
|
||||||
|
num_heads,
|
||||||
|
drop_path_rate=0.2,
|
||||||
|
in_chans=3,
|
||||||
|
num_classes=1000,
|
||||||
|
qkv_bias=True,
|
||||||
|
qk_scale=None,
|
||||||
|
drop_rate=0.,
|
||||||
|
attn_drop_rate=0.,
|
||||||
|
layer_scale=None,
|
||||||
|
layer_scale_conv=None,
|
||||||
|
**kwargs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
dim: feature size dimension.
|
||||||
|
depths: number of layers in each stage.
|
||||||
|
window_size: window size in each stage.
|
||||||
|
mlp_ratio: MLP ratio.
|
||||||
|
num_heads: number of heads in each stage.
|
||||||
|
drop_path_rate: drop path rate.
|
||||||
|
in_chans: number of input channels.
|
||||||
|
num_classes: number of classes.
|
||||||
|
qkv_bias: bool argument for query, key, value learnable bias.
|
||||||
|
qk_scale: bool argument to scaling query, key.
|
||||||
|
drop_rate: dropout rate.
|
||||||
|
attn_drop_rate: attention dropout rate.
|
||||||
|
norm_layer: normalization layer.
|
||||||
|
layer_scale: layer scaling coefficient.
|
||||||
|
layer_scale_conv: conv layer scaling coefficient.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
num_features = int(dim * 2 ** (len(depths) - 1))
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim)
|
||||||
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
||||||
|
self.levels = nn.ModuleList()
|
||||||
|
for i in range(len(depths)):
|
||||||
|
conv = True if (i == 0 or i == 1) else False
|
||||||
|
level = MambaVisionLayer(dim=int(dim * 2 ** i),
|
||||||
|
depth=depths[i],
|
||||||
|
num_heads=num_heads[i],
|
||||||
|
window_size=window_size[i],
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_scale=qk_scale,
|
||||||
|
conv=conv,
|
||||||
|
drop=drop_rate,
|
||||||
|
attn_drop=attn_drop_rate,
|
||||||
|
drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
|
||||||
|
downsample=(i < 3),
|
||||||
|
layer_scale=layer_scale,
|
||||||
|
layer_scale_conv=layer_scale_conv,
|
||||||
|
transformer_blocks=list(range(depths[i]//2+1, depths[i])) if depths[i]%2!=0 else list(range(depths[i]//2, depths[i])),
|
||||||
|
)
|
||||||
|
self.levels.append(level)
|
||||||
|
self.norm = nn.BatchNorm2d(num_features)
|
||||||
|
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
||||||
|
self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
|
||||||
|
def _init_weights(self, m):
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
trunc_normal_(m.weight, std=.02)
|
||||||
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.LayerNorm):
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
nn.init.constant_(m.weight, 1.0)
|
||||||
|
elif isinstance(m, LayerNorm2d):
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
nn.init.constant_(m.weight, 1.0)
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
nn.init.ones_(m.weight)
|
||||||
|
nn.init.zeros_(m.bias)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def no_weight_decay_keywords(self):
|
||||||
|
return {'rpb'}
|
||||||
|
|
||||||
|
def forward_features(self, x):
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
outs = []
|
||||||
|
for level in self.levels:
|
||||||
|
x, xo = level(x)
|
||||||
|
outs.append(xo)
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.avgpool(x)
|
||||||
|
x = torch.flatten(x, 1)
|
||||||
|
return x, outs
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x, outs = self.forward_features(x)
|
||||||
|
x = self.head(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _load_state_dict(self,
|
||||||
|
pretrained,
|
||||||
|
strict: bool = False):
|
||||||
|
_load_checkpoint(self,
|
||||||
|
pretrained,
|
||||||
|
strict=strict)
|
||||||
|
|
||||||
|
|
||||||
|
class MambaVisionModel(PreTrainedModel):
|
||||||
|
config_class = MambaVisionConfig
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.model = MambaVision(
|
||||||
|
depths=config.depths,
|
||||||
|
num_heads=config.num_heads,
|
||||||
|
window_size=config.window_size,
|
||||||
|
dim=config.dim,
|
||||||
|
in_dim=config.in_dim,
|
||||||
|
mlp_ratio=config.mlp_ratio,
|
||||||
|
layer_scale=config.layer_scale,
|
||||||
|
layer_scale_conv=config.layer_scale_conv
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, tensor):
|
||||||
|
return self.model.forward_features(tensor)
|
||||||
|
|
||||||
|
|
||||||
|
class MambaVisionModelForImageClassification(PreTrainedModel):
|
||||||
|
config_class = MambaVisionConfig
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.model = MambaVision(
|
||||||
|
depths=config.depths,
|
||||||
|
num_heads=config.num_heads,
|
||||||
|
window_size=config.window_size,
|
||||||
|
dim=config.dim,
|
||||||
|
in_dim=config.in_dim,
|
||||||
|
mlp_ratio=config.mlp_ratio,
|
||||||
|
layer_scale=config.layer_scale,
|
||||||
|
layer_scale_conv=config.layer_scale_conv
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, tensor, labels=None):
|
||||||
|
logits = self.model(tensor)
|
||||||
|
if labels is not None:
|
||||||
|
loss = torch.nn.cross_entropy(logits, labels)
|
||||||
|
return {"loss": loss, "logits": logits}
|
||||||
|
return {"logits": logits}
|
Loading…
Reference in New Issue