# -*- encoding: utf-8 -*-
# File: modeling_megrezo.py
# Description: This file contains the implementation of the Megrez-Omni model.

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoProcessor
from transformers import LlamaForCausalLM
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import add_start_docstrings
from transformers.utils import add_start_docstrings_to_model_forward
from transformers.utils import is_flash_attn_2_available
from transformers.utils import is_flash_attn_greater_or_equal_2_10
from transformers.utils import logging
from transformers.utils import replace_return_docstrings

from .audio import AudioEncoder
from .configuration_megrezo import MegrezOConfig
from .modeling_navit_siglip import SiglipVisionTransformer
from .resampler import Resampler


def insert_audio_embeddings(text_embeddings, inserted_embeddings, inserted_bounds):
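    """Scatter audio embeddings into their reserved spans of the text embeddings.

    Each row of `inserted_bounds` is `(batch_index, start, end)`; the positions from
    `start + 1` up to (but excluding) `end` are overwritten in place with the
    corresponding audio embedding.
    """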
    inserted_bounds = inserted_bounds.long()

    for idx in range(len(inserted_embeddings)):
        bid = inserted_bounds[idx][0]
        start_id = inserted_bounds[idx][1]
        end_id = inserted_bounds[idx][2]
        embedding = inserted_embeddings[idx]
        text_embeddings[bid, start_id + 1 : end_id] = embedding

    return text_embeddings


def insert_image_embeddings(text_embeddings, inserted_embeddings, inserted_bounds):
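    """Scatter image embeddings into their reserved spans of the text embeddings.

    Each row of `inserted_bounds` is `(batch_index, start, end)`; the positions from
    `start` up to (but excluding) `end` are overwritten in place with the
    corresponding image embedding.
    """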
    inserted_bounds = inserted_bounds.long()

    for idx in range(len(inserted_embeddings)):
        bid = inserted_bounds[idx][0]
        start_id = inserted_bounds[idx][1]
        end_id = inserted_bounds[idx][2]
        embedding = inserted_embeddings[idx]
        text_embeddings[bid, start_id:end_id] = embedding

    return text_embeddings


MegrezO_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads, etc.).

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config ([`MegrezOConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare MegrezO Model outputting raw hidden-states without any specific head on top.",
    MegrezO_START_DOCSTRING,
)
class MegrezOPreTrainedModel(PreTrainedModel):
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    config_class = MegrezOConfig
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True


class AudioModel(torch.nn.Module):

    def __init__(self, config: MegrezOConfig):
        super(AudioModel, self).__init__()
        self.config = config
        self.audio = AudioEncoder(**config.audio_config.to_dict())

    def forward(self, audio_info):
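        """Encode a batch of audio inputs.

        `audio_info` is a dict with keys `input_audios`, `input_audio_lengths`, and
        `audio_span_tokens`, as assembled by the processor; the underlying
        `AudioEncoder` returns the embeddings that are later scattered into the
        text sequence.
        """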
        audios = audio_info["input_audios"]
        input_audio_lengths = audio_info["input_audio_lengths"]
        audio_span_tokens = audio_info["audio_span_tokens"]
        audios_features = self.audio.encode(audios, input_audio_lengths, audio_span_tokens)
        return audios_features


class VisionModel(torch.nn.Module):

    def __init__(self, config: MegrezOConfig):
        super(VisionModel, self).__init__()
        self.config = config
        self.vpm = self.init_vision_module()
        self.resampler = self.init_resampler(self.config.hidden_size, self.vpm.embed_dim)

    def init_vision_module(self):
        if self.config._attn_implementation == "flash_attention_2":
            self.config.vision_config._attn_implementation = "flash_attention_2"
        else:
            # SDPA is not supported by the vision encoder; fall back to eager attention.
            self.config.vision_config._attn_implementation = "eager"
        model = SiglipVisionTransformer(self.config.vision_config)
        if self.config.drop_vision_last_layer:
            model.encoder.layers = model.encoder.layers[:-1]

        setattr(model, "embed_dim", model.embeddings.embed_dim)
        setattr(model, "patch_size", model.embeddings.patch_size)

        return model

    def init_resampler(self, embed_dim, vision_dim):
        return Resampler(
            num_queries=self.config.query_num,
            embed_dim=embed_dim,
            num_heads=embed_dim // 128,
            kv_dim=vision_dim,
            adaptive=True,
        )

    def get_vision_embedding(
        self,
        all_pixel_values: torch.Tensor,
        patch_attention_mask: torch.Tensor,
        tgt_sizes: torch.Tensor,
    ):
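        """Run the vision tower over the pixel batch.

        The batch is processed in chunks of `config.vision_batch_size` to bound peak
        memory; the chunks' last hidden states are concatenated and returned.
        """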
        B = all_pixel_values.size(0)
        vision_batch_size = self.config.vision_batch_size
        if B > vision_batch_size:
            hs = []
            for i in range(0, B, vision_batch_size):
                start_idx = i
                end_idx = i + vision_batch_size
                tmp_hs = self.vpm(
                    all_pixel_values[start_idx:end_idx],
                    patch_attention_mask=patch_attention_mask[start_idx:end_idx],
                    tgt_sizes=tgt_sizes[start_idx:end_idx],
                ).last_hidden_state
                hs.append(tmp_hs)
            vision_embedding = torch.cat(hs, dim=0)
        else:
            vision_embedding = self.vpm(
                all_pixel_values,
                patch_attention_mask=patch_attention_mask,
                tgt_sizes=tgt_sizes,
            ).last_hidden_state

        return vision_embedding

    def _prepare_vision_input(self, images, patch_attention_mask, tgt_sizes):
        # (TODO) Move to processor
        device = self.vpm.device
        dtype = self.vpm.dtype

        pixel_values = torch.stack([(image.to(device) - 127.5) / 127.5 for image in images]).type(dtype)
        patch_attention_mask = patch_attention_mask.to(device)
        return pixel_values, patch_attention_mask, tgt_sizes

    def forward(self, images, tgt_sizes, patch_attention_mask):
        pixel_values, patch_attention_mask, tgt_sizes = self._prepare_vision_input(
            images, patch_attention_mask, tgt_sizes
        )
        embedding = self.get_vision_embedding(pixel_values, patch_attention_mask, tgt_sizes)
        embedding = self.resampler(embedding, tgt_sizes)
        return embedding


class MegrezO(MegrezOPreTrainedModel):
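    """The Megrez-Omni multimodal model.

    Wraps a `LlamaForCausalLM` language model together with a SigLIP-based vision
    encoder (`VisionModel`) and an audio encoder (`AudioModel`). Image and audio
    embeddings are written into reserved spans of the token embedding sequence
    before the language model runs.
    """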

    def __init__(self, config):
        super().__init__(config)
        self.llm = LlamaForCausalLM(config)
        self.vision = VisionModel(config)
        self.audio = AudioModel(config)
        self.post_init()
        self.processor = None

        # Will be set in the training script.
        self.tune_vision = False
        self.tune_audio = False

    def _get_or_init_processor(self):
        if self.processor is None:
            self.processor = AutoProcessor.from_pretrained(
                self.config._name_or_path,
                trust_remote_code=True,
            )

        return self.processor

    def convert_to_device(self, mini_batch):
        for key in mini_batch:
            if isinstance(mini_batch[key], torch.Tensor):
                mini_batch[key] = mini_batch[key].to(self.device)
            if isinstance(mini_batch[key], list):
                return_value = []
                for value in mini_batch[key]:
                    if isinstance(value, torch.Tensor):
                        value = value.to(self.device)
                    return_value.append(value)
                mini_batch[key] = return_value

        return mini_batch

    def compose_embeddings(self, mini_batch):
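        """Build the multimodal input embeddings for the language model.

        Text token embeddings are looked up first; image and audio embeddings are then
        inserted into their reserved spans. When training with `tune_vision`/`tune_audio`
        and the batch contains no image/audio, a dummy forward pass is added with zero
        weight, presumably to keep those parameters in the autograd graph (e.g. for
        gradient synchronization under distributed training).
        """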
        position_ids = mini_batch["position_ids"]
        input_ids = mini_batch["input_ids"]
        image_encoding = mini_batch.get("image_encoding")
        audio_encoding = mini_batch.get("audio_encoding")
        if position_ids.dtype != torch.int64:
            position_ids = position_ids.long()

        embeddings_text = self.llm.model.embed_tokens(input_ids)
        input_embeds = embeddings_text
        if image_encoding:
            pixel_values = image_encoding["pixel_values"]
            tgt_sizes = image_encoding["tgt_sizes"]
            patch_attention_mask = image_encoding["patch_attention_mask"]
            bounds_image = image_encoding["image_bounds"]
            embeddings_image = self.vision(pixel_values, tgt_sizes, patch_attention_mask=patch_attention_mask)
            input_embeds = insert_image_embeddings(embeddings_text, embeddings_image, bounds_image)
        elif self.training and self.tune_vision:
            pixel_values = torch.zeros((3, 14, 3584), dtype=torch.float32)
            tgt_sizes = torch.tensor([[16, 16]], dtype=torch.int64)
            patch_attention_mask = torch.ones((3, 14), dtype=torch.float32)
            embeddings_image = self.vision(pixel_values, tgt_sizes, patch_attention_mask=patch_attention_mask)
            input_embeds += embeddings_image[0].sum() * 0.0

        if audio_encoding:
            embeddings_audio = self.audio(audio_encoding)
            bounds_audio = audio_encoding["audio_bounds"]
            input_embeds = insert_audio_embeddings(embeddings_text, embeddings_audio, bounds_audio)
        elif self.training and self.tune_audio:
            dummy_audio = torch.zeros((1, 128, 3000), dtype=torch.float32)
            dummy_audio_lengths = torch.tensor([[125, 62]], dtype=torch.int32)
            dummy_span_tokens = [64]
            # AudioModel.forward indexes the encoding by key, so the dummy encoding is
            # passed as a single dict rather than wrapped in a list.
            dummy_audio_encoding = {
                "input_audios": dummy_audio,
                "input_audio_lengths": dummy_audio_lengths,
                "audio_span_tokens": dummy_span_tokens,
            }
            embeddings_audio = self.audio(dummy_audio_encoding)
            input_embeds += embeddings_audio[0].sum() * 0.0

        return input_ids, input_embeds, position_ids

    def forward(self, data, **kwargs):
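        """Forward pass.

        In training mode the multimodal embeddings are composed here and passed to the
        LLM as `inputs_embeds`; otherwise the call is forwarded to the LLM unchanged.
        """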
        if self.training:
            _, input_embeds, position_ids = self.compose_embeddings(data)
            return self.llm.forward(
                input_ids=None,
                position_ids=position_ids,
                inputs_embeds=input_embeds,
                **kwargs,
            )
        return self.llm.forward(**kwargs)

    def generate(
        self,
        input_ids,
        position_ids,
        attention_mask,
        image_encoding=None,
        audio_encoding=None,
        **kwargs,
    ):
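        """Compose multimodal embeddings and delegate generation to the underlying LLM.

        `image_encoding` and `audio_encoding` are optional dicts produced by the
        processor; all tensors are moved to the model device before composition.
        """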
        tokenizer = self._get_or_init_processor().tokenizer
        data = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "image_encoding": image_encoding,
            "audio_encoding": audio_encoding,
        }
        data = self.convert_to_device(data)
        input_ids, input_embeds, position_ids = self.compose_embeddings(data)

        output = self.llm.generate(
            inputs_embeds=input_embeds,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            **kwargs,
        )
        return output

    def trim_stop_words(self, response, stop_words):
        if stop_words:
            for stop in stop_words:
                idx = response.find(stop)
                if idx != -1:
                    response = response[:idx]
        return response

    @torch.inference_mode()
    def chat(self, input_msgs, processor=None, sampling=False, **kwargs):
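        """High-level chat helper: preprocess `input_msgs`, generate, and decode the answer.

        When `sampling` is True a sampling configuration is used, otherwise beam search;
        any keyword arguments override the defaults.

        Example (illustrative sketch only -- the exact message schema is defined by the
        bundled processor, and the field names below are an assumption):

            model = MegrezO.from_pretrained(model_path).cuda().eval()
            msgs = [{"role": "user", "content": {"text": "Describe this image.", "image": image}}]
            answer = model.chat(msgs, sampling=True, max_new_tokens=256)
        """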
        if processor is None:
            processor = self._get_or_init_processor()

        if sampling:
            generation_config = {
                "top_p": 0.8,
                "top_k": 100,
                "temperature": 0.7,
                "do_sample": True,
                "repetition_penalty": 1.05,
            }
        else:
            generation_config = {
                "num_beams": 3,
                "repetition_penalty": 1.2,
            }

        generation_config.update(kwargs)
        if generation_config.get("temperature") == 0:
            generation_config["do_sample"] = False

        data = processor(input_msgs)
        output_ids = self.generate(**data, **generation_config)
        tokenizer = processor.tokenizer
        answer = tokenizer.decode(output_ids[0])
        return answer