182 lines
5.2 KiB
Python
182 lines
5.2 KiB
Python
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch
|
|
from torch import Tensor
|
|
from typing import Optional
|
|
|
|
class Attention(nn.Module):
    """Multi-head scaled dot-product attention (self- or cross-attention).

    Queries are projected from ``x``; keys and values are projected from
    ``context`` when it is given and from ``x`` otherwise. The fused
    ``F.scaled_dot_product_attention`` kernel is used when the installed
    PyTorch provides it; otherwise an explicit fallback is used that treats
    ``attn_mask``, ``is_causal`` and dropout the same way.

    Args:
        dim: Size of the input and output feature dimension.
        num_heads: Number of attention heads; must divide ``dim`` evenly.
        dropout_prob: Dropout probability applied to the attention weights
            while the module is in training mode.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        dropout_prob: float = 0.0,
    ):
        super().__init__()

        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )

        # Feature-detect the fused kernel instead of parsing the first
        # character of the version string.
        self.use_sdp = hasattr(F, "scaled_dot_product_attention")

        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.out = nn.Linear(dim, dim)

        self.dropout_prob = dropout_prob
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

    def forward(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor] = None,
        context: Optional[Tensor] = None,
        is_causal: bool = False,
    ) -> Tensor:
        """Attend from ``x`` to ``context`` (or to ``x`` itself).

        Args:
            x: Query-side input of shape ``(batch, seq_len, dim)``.
            attn_mask: Optional mask broadcastable to the attention scores.
                A boolean mask marks positions that may be attended to with
                ``True``; any other dtype is added to the scores.
            context: Optional key/value-side input of shape
                ``(batch, ctx_len, dim)``; defaults to ``x``.
            is_causal: When true, each query position only attends to key
                positions at or before its own index.

        Returns:
            Tensor of shape ``(batch, seq_len, dim)``.
        """
        source = x if context is None else context

        query = self.reshape(self.query(x))
        key = self.reshape(self.key(source))
        value = self.reshape(self.value(source))

        # Dropout is only active in training mode, on both code paths.
        dropout_p = self.dropout_prob if self.training else 0.0

        if self.use_sdp:
            x = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask,
                dropout_p=dropout_p,
                is_causal=is_causal,
            )
        else:
            x = self._manual_attention(
                query, key, value, attn_mask, dropout_p, is_causal
            )

        # (batch, heads, seq, head_dim) -> (batch, seq, dim)
        return self.out(x.transpose(2, 1).flatten(2))

    def _manual_attention(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attn_mask: Optional[Tensor],
        dropout_p: float,
        is_causal: bool,
    ) -> Tensor:
        """Explicit attention used when the fused kernel is unavailable.

        NOTE(review): the fused kernel's documentation says it rejects
        ``attn_mask`` together with ``is_causal``; this fallback applies both.
        """
        attn = query @ key.transpose(-2, -1) * self.scale

        if is_causal:
            q_len, k_len = attn.shape[-2:]
            allowed = torch.ones(
                q_len, k_len, dtype=torch.bool, device=attn.device
            ).tril()
            attn = attn.masked_fill(~allowed, float("-inf"))

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Boolean masks select positions; they are not additive biases.
                attn = attn.masked_fill(~attn_mask, float("-inf"))
            else:
                attn = attn + attn_mask

        attn = attn.softmax(dim=-1)
        if dropout_p > 0:
            attn = F.dropout(attn, p=dropout_p)

        return attn @ value

    def reshape(self, x: Tensor) -> Tensor:
        """Split the feature dimension into heads.

        Args:
            x: Tensor of shape ``(batch, seq_len, dim)``.

        Returns:
            Tensor of shape ``(batch, num_heads, seq_len, head_dim)``.
        """
        batch_size, seq_len, _ = x.shape
        x = x.view(batch_size, seq_len, self.num_heads, self.head_dim)
        return x.transpose(2, 1)
|
|
|
|
|
|
class MLP(nn.Module):
    """Two-layer feed-forward network with a GELU activation in between.

    Args:
        dim: Input and output feature size.
        dim_expand_factor: Multiplier giving the hidden width
            (``dim * dim_expand_factor``).
    """

    def __init__(self, dim: int, dim_expand_factor: int = 4):
        super().__init__()

        hidden_dim = dim * dim_expand_factor
        self.hidden_layer = nn.Linear(dim, hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, dim)

    def forward(self, x: Tensor) -> Tensor:
        """Project up to the hidden width, apply GELU, and project back."""
        activated = F.gelu(self.hidden_layer(x))
        return self.output_layer(activated)
|
|
|
|
|
|
class LayerScale(nn.Module):
    """Multiply the input by a learnable per-channel weight vector.

    Args:
        dim: Length of the weight vector.
        init_values: Initial value of every weight entry.
        inplace: When true, the input tensor itself is scaled and returned.
    """

    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5,
        inplace: bool = False,
    ):
        super().__init__()
        initial_weight = init_values * torch.ones(dim)
        self.weight = nn.Parameter(initial_weight)
        self.inplace = inplace

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` scaled by the weight, in place if configured so."""
        if self.inplace:
            return x.mul_(self.weight)
        return x * self.weight
|
|
|
|
|
|
class VisionEncoderBlock(nn.Module):
    """Transformer block with two residual branches: attention, then MLP.

    Each branch normalises its input with LayerNorm, runs the sub-layer, and
    scales the result with a LayerScale before adding it back to the input.

    Args:
        dim: Feature size of the tokens.
        num_heads: Number of attention heads.
    """

    def __init__(self, dim: int, num_heads: int):
        super().__init__()

        # Attention branch.
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(dim, num_heads)
        self.ls1 = LayerScale(dim)

        # Feed-forward branch.
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = MLP(dim)
        self.ls2 = LayerScale(dim)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the attention branch and then the MLP branch to ``x``."""
        attn_branch = self.ls1(self.attn(self.norm1(x)))
        x = x + attn_branch

        mlp_branch = self.ls2(self.mlp(self.norm2(x)))
        return x + mlp_branch
|
|
|
|
|
|
class VisionEncoder(nn.Module):
    """Patch-based transformer image encoder.

    The image is cut into non-overlapping patches by a strided convolution.
    Learned positional embeddings are added, resampled when the input
    resolution differs from the native one. A class token is prepended, and
    the sequence goes through the transformer blocks and a final LayerNorm.

    Args:
        dim: Feature size of the tokens.
        patch_size: Side length in pixels of each square patch.
        num_layers: Number of ``VisionEncoderBlock`` layers.
        num_heads: Number of attention heads per block.
        image_size: Side length in pixels of the square images the
            positional-embedding grid is sized for. Defaults to 224, the
            value that used to be hard-coded.
    """

    def __init__(
        self,
        dim: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        image_size: int = 224,
    ):
        super().__init__()

        self.image_size = image_size
        self.n_patch = image_size // patch_size
        self.seq_len = self.n_patch ** 2
        self.patch_size = patch_size

        self.patch_embed = nn.Conv2d(3, dim, patch_size, patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, self.seq_len, dim) * 0.02)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.interpolate_offset = 0.1
        self.interpolate_antialias = False

        self.blocks = nn.Sequential(
            *[
                VisionEncoderBlock(dim, num_heads)
                for _ in range(num_layers)
            ]
        )

        self.norm = nn.LayerNorm(dim, eps=1e-6)

    def interpolate_pos_encoding(self, x: Tensor, h: int, w: int) -> Tensor:
        """Return positional embeddings for the patch grid of an ``h`` x ``w`` image.

        Args:
            x: Patch tokens; only ``x.shape[1]`` (token count), ``x.shape[-1]``
                (feature size) and ``x.dtype`` are read.
            h: Input image height in pixels.
            w: Input image width in pixels.

        Returns:
            ``self.pos_embed`` itself when the token count equals the native
            sequence length and the image is square. Otherwise a bicubically
            resampled copy, flattened to ``(1, tokens, dim)`` and cast to
            ``x.dtype``.
        """
        previous_dtype = x.dtype

        if x.shape[1] == self.seq_len and w == h:
            return self.pos_embed

        # Interpolate in float32 and cast back afterwards.
        pos_embed = self.pos_embed.float()

        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
        sx, sy = float(w0) / self.n_patch, float(h0) / self.n_patch

        # (1, seq_len, dim) -> (1, dim, n_patch, n_patch) so the grid can be
        # resampled as a 2-D image; scale factors are given as (height, width).
        pos_embed = F.interpolate(
            pos_embed.reshape(1, self.n_patch, self.n_patch, dim).permute(0, 3, 1, 2),
            scale_factor=(sy, sx),
            mode="bicubic",
            antialias=self.interpolate_antialias,
        )

        return pos_embed.to(previous_dtype).flatten(start_dim=2).transpose(2, 1)

    def forward(self, x: Tensor) -> Tensor:
        """Encode a batch of images.

        Args:
            x: Images of shape ``(batch, 3, height, width)``.

        Returns:
            Normalised tokens of shape ``(batch, 1 + num_patches, dim)``, with
            the class token first.
        """
        h, w = x.shape[2:]
        x = self.patch_embed(x).flatten(start_dim=2).transpose(2, 1)
        x = x + self.interpolate_pos_encoding(x, h, w)
        # The class token is prepended after the positional embeddings are added.
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.blocks(x)
        return self.norm(x)