# -*- encoding: utf-8 -*-
# File: modeling_megrezo.py
# Description: This file contains the implementation of the Megrez-Omni model.

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoProcessor
from transformers import LlamaForCausalLM
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import add_start_docstrings
from transformers.utils import add_start_docstrings_to_model_forward
from transformers.utils import is_flash_attn_2_available
from transformers.utils import is_flash_attn_greater_or_equal_2_10
from transformers.utils import logging
from transformers.utils import replace_return_docstrings

from .audio import AudioEncoder
from .configuration_megrezo import MegrezOConfig
from .modeling_navit_siglip import SiglipVisionTransformer
from .resampler import Resampler


def _insert_embeddings(text_embeddings, inserted_embeddings, inserted_bounds, start_offset):
    """Overwrite spans of ``text_embeddings`` in place with ``inserted_embeddings``.

    Row ``i`` of ``inserted_bounds`` is read as ``(batch_index, start, end)`` and the
    slice ``text_embeddings[batch_index, start + start_offset : end]`` is replaced by
    ``inserted_embeddings[i]``.
    """
    inserted_bounds = inserted_bounds.long()
    for idx, embedding in enumerate(inserted_embeddings):
        bound = inserted_bounds[idx]
        bid, start_id, end_id = bound[0], bound[1], bound[2]
        text_embeddings[bid, start_id + start_offset : end_id] = embedding
    return text_embeddings


def insert_audio_embeddings(text_embeddings, inserted_embeddings, inserted_bounds):
    """Write audio embeddings into ``text_embeddings`` in place and return it.

    For every row ``(batch_index, start, end)`` of ``inserted_bounds`` the slice
    ``[start + 1, end)`` of that batch row is overwritten, i.e. the position at
    ``start`` itself is left untouched.

    Args:
        text_embeddings: Tensor indexed as ``[batch, position, ...]``; modified in place.
        inserted_embeddings: Sequence of embeddings, one per row of ``inserted_bounds``.
        inserted_bounds: Integer tensor whose rows hold ``(batch_index, start, end)``.

    Returns:
        The same ``text_embeddings`` object that was passed in.
    """
    return _insert_embeddings(text_embeddings, inserted_embeddings, inserted_bounds, start_offset=1)


def insert_image_embeddings(text_embeddings, inserted_embeddings, inserted_bounds):
    """Write image embeddings into ``text_embeddings`` in place and return it.

    For every row ``(batch_index, start, end)`` of ``inserted_bounds`` the slice
    ``[start, end)`` of that batch row is overwritten.

    Args:
        text_embeddings: Tensor indexed as ``[batch, position, ...]``; modified in place.
        inserted_embeddings: Sequence of embeddings, one per row of ``inserted_bounds``.
        inserted_bounds: Integer tensor whose rows hold ``(batch_index, start, end)``.

    Returns:
        The same ``text_embeddings`` object that was passed in.
    """
    return _insert_embeddings(text_embeddings, inserted_embeddings, inserted_bounds, start_offset=0)


MegrezO_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MegrezOConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare MegrezO Model outputting raw hidden-states without any specific head on top.",
    MegrezO_START_DOCSTRING,
)
class MegrezOPreTrainedModel(PreTrainedModel):
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    config_class = MegrezOConfig
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True


class AudioModel(torch.nn.Module):
    """Thin wrapper that builds an ``AudioEncoder`` from ``config.audio_config``."""

    def __init__(self, config: MegrezOConfig):
        super().__init__()
        self.config = config
        self.audio = AudioEncoder(**config.audio_config.to_dict())

    def forward(self, audio_info):
        """Encode audio described by a mapping.

        Args:
            audio_info: Mapping with the keys ``"input_audios"``,
                ``"input_audio_lengths"`` and ``"audio_span_tokens"``; the values are
                handed to ``AudioEncoder.encode`` unchanged.

        Returns:
            Whatever ``AudioEncoder.encode`` returns.
        """
        return self.audio.encode(
            audio_info["input_audios"],
            audio_info["input_audio_lengths"],
            audio_info["audio_span_tokens"],
        )


class VisionModel(torch.nn.Module):
    """Vision tower (``SiglipVisionTransformer``) followed by a ``Resampler``."""

    def __init__(self, config: MegrezOConfig):
        super().__init__()
        self.config = config
        self.vpm = self.init_vision_module()
        self.resampler = self.init_resampler(self.config.hidden_size, self.vpm.embed_dim)

    def init_vision_module(self):
        """Build the vision transformer described by ``config.vision_config``.

        Side effect: sets ``config.vision_config._attn_implementation``.
        """
        # Only flash-attention-2 and eager are selected here; sdpa is not supported.
        use_flash = self.config._attn_implementation == "flash_attention_2"
        self.config.vision_config._attn_implementation = "flash_attention_2" if use_flash else "eager"

        model = SiglipVisionTransformer(self.config.vision_config)
        if self.config.drop_vision_last_layer:
            model.encoder.layers = model.encoder.layers[:-1]

        # Expose the embedding attributes directly on the transformer.
        model.embed_dim = model.embeddings.embed_dim
        model.patch_size = model.embeddings.patch_size
        return model

    def init_resampler(self, embed_dim, vision_dim):
        """Build the resampler that maps ``vision_dim`` features to ``embed_dim``."""
        return Resampler(
            num_queries=self.config.query_num,
            embed_dim=embed_dim,
            num_heads=embed_dim // 128,
            kv_dim=vision_dim,
            adaptive=True,
        )

    def get_vision_embedding(
        self,
        all_pixel_values: torch.Tensor,
        patch_attention_mask: torch.Tensor,
        tgt_sizes: torch.Tensor,
    ):
        """Run the vision transformer and return its ``last_hidden_state``.

        When the leading dimension of ``all_pixel_values`` exceeds
        ``config.vision_batch_size`` the inputs are processed in chunks of that size
        and the chunk outputs are concatenated along dim 0.
        """
        total = all_pixel_values.size(0)
        chunk_size = self.config.vision_batch_size

        if total <= chunk_size:
            return self.vpm(
                all_pixel_values,
                patch_attention_mask=patch_attention_mask,
                tgt_sizes=tgt_sizes,
            ).last_hidden_state

        chunk_outputs = []
        for start_idx in range(0, total, chunk_size):
            end_idx = start_idx + chunk_size
            chunk_outputs.append(
                self.vpm(
                    all_pixel_values[start_idx:end_idx],
                    patch_attention_mask=patch_attention_mask[start_idx:end_idx],
                    tgt_sizes=tgt_sizes[start_idx:end_idx],
                ).last_hidden_state
            )
        return torch.cat(chunk_outputs, dim=0)

    def _prepare_vision_input(self, images, patch_attention_mask, tgt_sizes):
        """Normalize and stack ``images`` and move inputs to the vision tower's device.

        Each image is mapped with ``(x - 127.5) / 127.5``, the results are stacked and
        cast to the vision tower's dtype. ``tgt_sizes`` is returned unchanged.
        """
        # TODO: move this preprocessing to the processor.
        device = self.vpm.device
        dtype = self.vpm.dtype
        normalized = [(image.to(device) - 127.5) / 127.5 for image in images]
        pixel_values = torch.stack(normalized).type(dtype)
        patch_attention_mask = patch_attention_mask.to(device)
        return pixel_values, patch_attention_mask, tgt_sizes

    def forward(self, images, tgt_sizes, patch_attention_mask):
        """Return resampled vision embeddings for ``images``."""
        pixel_values, patch_attention_mask, tgt_sizes = self._prepare_vision_input(
            images, patch_attention_mask, tgt_sizes
        )
        embedding = self.get_vision_embedding(pixel_values, patch_attention_mask, tgt_sizes)
        return self.resampler(embedding, tgt_sizes)


class MegrezO(MegrezOPreTrainedModel):
    """Multimodal model combining a Llama LM with vision and audio encoders."""

    def __init__(self, config):
        super().__init__(config)

        self.llm = LlamaForCausalLM(config)
        self.vision = VisionModel(config)
        self.audio = AudioModel(config)
        self.post_init()

        self.processor = None  # Will be set in the training script
        self.tune_vision = False
        self.tune_audio = False

    def _get_or_init_processor(self):
        """Return ``self.processor``, loading it from ``config._name_or_path`` on first use."""
        if self.processor is None:
            self.processor = AutoProcessor.from_pretrained(
                self.config._name_or_path,
                trust_remote_code=True,
            )
        return self.processor

    def convert_to_device(self, mini_batch):
        """Move top-level tensors of ``mini_batch`` to ``self.device`` in place.

        Tensor values are moved directly; for list values every tensor element is
        moved and other elements are kept. Values of any other type (including nested
        dicts) are left untouched.

        Returns:
            The same ``mini_batch`` mapping.
        """
        device = self.device
        for key, value in mini_batch.items():
            if isinstance(value, torch.Tensor):
                mini_batch[key] = value.to(device)
            elif isinstance(value, list):
                mini_batch[key] = [
                    item.to(device) if isinstance(item, torch.Tensor) else item for item in value
                ]
        return mini_batch

    def compose_embeddings(self, mini_batch):
        """Build the LM input embeddings, splicing in image and audio embeddings.

        Args:
            mini_batch: Mapping with ``"input_ids"`` and ``"position_ids"`` and,
                optionally, ``"image_encoding"`` (keys ``"pixel_values"``,
                ``"tgt_sizes"``, ``"patch_attention_mask"``, ``"image_bounds"``) and
                ``"audio_encoding"`` (the keys read by ``AudioModel.forward`` plus
                ``"audio_bounds"``).

        Returns:
            Tuple ``(input_ids, input_embeds, position_ids)`` where ``position_ids``
            has been converted to ``torch.int64`` if necessary.
        """
        position_ids = mini_batch["position_ids"]
        input_ids = mini_batch["input_ids"]
        image_encoding = mini_batch.get("image_encoding")
        audio_encoding = mini_batch.get("audio_encoding")

        if position_ids.dtype != torch.int64:
            position_ids = position_ids.long()

        embeddings_text = self.llm.model.embed_tokens(input_ids)
        input_embeds = embeddings_text

        if image_encoding:
            embeddings_image = self.vision(
                image_encoding["pixel_values"],
                image_encoding["tgt_sizes"],
                patch_attention_mask=image_encoding["patch_attention_mask"],
            )
            input_embeds = insert_image_embeddings(
                embeddings_text, embeddings_image, image_encoding["image_bounds"]
            )
        elif self.training and self.tune_vision:
            # No image in this batch: run the vision branch on a dummy input and add a
            # zero-valued term so the result is numerically unchanged.
            # NOTE(review): presumably this keeps the vision parameters in the autograd
            # graph on text-only batches - confirm against the training setup.
            pixel_values = torch.zeros((3, 14, 3584), dtype=torch.float32)
            tgt_sizes = torch.tensor([[16, 16]], dtype=torch.int64)
            patch_attention_mask = torch.ones((3, 14), dtype=torch.float32)
            embeddings_image = self.vision(pixel_values, tgt_sizes, patch_attention_mask=patch_attention_mask)
            input_embeds += embeddings_image[0].sum() * 0.0

        if audio_encoding:
            embeddings_audio = self.audio(audio_encoding)
            input_embeds = insert_audio_embeddings(
                embeddings_text, embeddings_audio, audio_encoding["audio_bounds"]
            )
        elif self.training and self.tune_audio:
            # Same dummy-input trick for the audio branch. AudioModel.forward indexes
            # its argument with string keys, so the mapping must be passed directly
            # (wrapping it in a list raises TypeError).
            dummy_audio_encoding = {
                "input_audios": torch.zeros((1, 128, 3000), dtype=torch.float32),
                "input_audio_lengths": torch.tensor([[125, 62]], dtype=torch.int32),
                "audio_span_tokens": [64],
            }
            embeddings_audio = self.audio(dummy_audio_encoding)
            input_embeds += embeddings_audio[0].sum() * 0.0

        return input_ids, input_embeds, position_ids

    def forward(self, data, **kwargs):
        """Run the language model.

        In training mode the inputs are composed from ``data`` via
        ``compose_embeddings`` and passed to the LM as ``inputs_embeds`` together
        with ``kwargs``. Outside training mode only ``kwargs`` are forwarded.
        """
        if not self.training:
            # NOTE(review): `data` is ignored on this path; callers must supply the
            # LM inputs through kwargs. Confirm this is intended for evaluation.
            return self.llm.forward(**kwargs)

        _, input_embeds, position_ids = self.compose_embeddings(data)
        return self.llm.forward(
            input_ids=None,
            position_ids=position_ids,
            inputs_embeds=input_embeds,
            **kwargs,
        )

    def generate(
        self,
        input_ids,
        position_ids,
        attention_mask,
        image_encoding=None,
        audio_encoding=None,
        **kwargs,
    ):
        """Generate token ids from multimodal inputs.

        The inputs are moved to the model device, composed into embeddings and passed
        to ``self.llm.generate`` as ``inputs_embeds``. ``pad_token_id`` and
        ``eos_token_id`` default to the tokenizer's values and may be overridden
        through ``kwargs``.

        Returns:
            The output of ``self.llm.generate``.
        """
        tokenizer = self._get_or_init_processor().tokenizer
        data = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "image_encoding": image_encoding,
            "audio_encoding": audio_encoding,
        }
        data = self.convert_to_device(data)
        input_ids, input_embeds, position_ids = self.compose_embeddings(data)

        # Tokenizer ids are defaults only: passing them as explicit keywords next to
        # **kwargs raised "multiple values for keyword argument" whenever a caller
        # supplied pad_token_id / eos_token_id.
        generation_kwargs = {
            "pad_token_id": tokenizer.pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            **kwargs,
        }
        # NOTE(review): attention_mask and position_ids are not forwarded to
        # llm.generate - confirm this is intended for padded batches.
        return self.llm.generate(inputs_embeds=input_embeds, **generation_kwargs)

    def trim_stop_words(self, response, stop_words):
        """Cut ``response`` at the first occurrence of each stop word, in order.

        Stop words are applied sequentially to the progressively shortened text.
        A falsy ``stop_words`` leaves ``response`` unchanged.
        """
        if stop_words:
            for stop in stop_words:
                idx = response.find(stop)
                if idx != -1:
                    response = response[:idx]
        return response

    @torch.inference_mode()
    def chat(self, input_msgs, processor=None, sampling=False, **kwargs):
        """Process ``input_msgs``, generate, and return the decoded first sequence.

        Args:
            input_msgs: Messages passed to the processor as-is.
            processor: Processor to use; defaults to the lazily loaded one.
            sampling: Selects the sampling preset instead of the beam-search preset.
            **kwargs: Generation options that override the chosen preset. A
                ``temperature`` of 0 forces ``do_sample=False``.

        Returns:
            The decoded text of ``output_ids[0]``.
        """
        if processor is None:
            processor = self._get_or_init_processor()

        if sampling:
            generation_config = {
                "top_p": 0.8,
                "top_k": 100,
                "temperature": 0.7,
                "do_sample": True,
                "repetition_penalty": 1.05,
            }
        else:
            generation_config = {
                "num_beams": 3,
                "repetition_penalty": 1.2,
            }

        generation_config.update(kwargs)
        if generation_config.get("temperature") == 0:
            generation_config["do_sample"] = False

        data = processor(input_msgs)
        output_ids = self.generate(**data, **generation_config)
        return processor.tokenizer.decode(output_ids[0])