from typing import List, Optional, Tuple, Union
import io
import math

import fire
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from PIL import Image

import transformers
from flash_attn import flash_attn_varlen_func
from transformers.activations import ACT2FN


class OmniVisualEncoder(transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VisionTransformerPretrainedModel):
    def __init__(self, config):
        # Force the flash-attention path. This must be set before
        # super().__init__ builds the vision blocks, which read
        # config._attn_implementation to pick their attention class.
        config._attn_implementation = 'flash_attention_2'
        super().__init__(config)
        self.gradient_checkpointing = True  # force-enable gradient checkpointing
        self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
        self.merge_size = config.merge_size if hasattr(config, 'merge_size') else 2
        # Drop the built-in patch merger; OmniVisualBridge below replaces it.
        del self.merger

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: torch.Tensor,
    ):
        hidden_states = pixel_values.to(self.get_dtype())
        grid_thw = grid_thw.to(pixel_values.device)

        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # Cumulative sequence lengths for the varlen flash-attention kernels:
        # each temporal frame of every image/video is one sequence of h * w patches.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0, dtype=torch.int32
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for blk in self.blocks:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    blk.__call__, hidden_states, cu_seqlens, rotary_pos_emb
                )
            else:
                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)

        # No merging here: raw per-patch features are returned and merged by OmniVisualBridge.
        return hidden_states

    @torch.no_grad()
    def fake_input(self, device):
        # Minimal zero-valued placeholder input spanning a single merge window,
        # flattened the same way as the real preprocessor output.
        merge_size = max(self.merge_size, self.config.spatial_merge_size)
        fake_image = torch.zeros([
            1,
            self.config.temporal_patch_size,
            3,
            merge_size // self.config.spatial_merge_size,
            self.config.spatial_merge_size,
            self.config.patch_size,
            merge_size // self.config.spatial_merge_size,
            self.config.spatial_merge_size,
            self.config.patch_size,
        ], dtype=torch.float32, device=device)
        patches = fake_image.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            merge_size * merge_size, 3 * self.config.temporal_patch_size * self.config.patch_size * self.config.patch_size
        )
        return [flatten_patches], [(1, merge_size, merge_size)], [1]
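

# A minimal illustrative sketch (not part of the model): how forward() builds
# cu_seqlens for the varlen flash-attention kernels. The grid_thw values here
# are invented for the example.
def demo_cu_seqlens():
    # Two inputs: a 2-frame video of 4x6 patches and a single 8x8-patch image.
    grid_thw = torch.tensor([[2, 4, 6], [1, 8, 8]])
    seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
    cu_seqlens = F.pad(seqlens.cumsum(dim=0, dtype=torch.int32), (1, 0), value=0)
    # One attention sequence per frame: boundaries at 0, 24, 48, 112.
    print(cu_seqlens)  # tensor([  0,  24,  48, 112], dtype=torch.int32)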


class OmniVisualBridge(nn.Module):
    """Patch merger: projects each group of merge_size**2 patch features to one config.hidden_size vector."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
        self.hidden_size = config.embed_dim * (self.merge_size**2)
        self.ln_q = nn.LayerNorm(config.embed_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, config.hidden_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize per patch, then group merge_size**2 neighboring patches
        # into one vector and project it to the target hidden size.
        x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
        return x
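

# A hedged smoke test, not part of the model: pushes random patch features
# through OmniVisualBridge. Assumes the stock Qwen2VLVisionConfig defaults
# (embed_dim=1280, hidden_size=3584). Exposed via fire, so it can be run as
# e.g. `python this_file.py smoke_test`.
def smoke_test():
    config = transformers.models.qwen2_vl.configuration_qwen2_vl.Qwen2VLVisionConfig()
    bridge = OmniVisualBridge(config)
    # Two merge windows of merge_size**2 patches, embed_dim features each.
    x = torch.randn(2 * bridge.merge_size**2, config.embed_dim)
    out = bridge(x)
    print(out.shape)  # torch.Size([2, 3584]) with the default config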


if __name__ == '__main__':
    fire.Fire()