Megrez-3B-Omni_a13954817325.../configuration_megrezo.py

88 lines
2.5 KiB
Python

"""MegrezO model configuration"""
from typing import Optional, Union

from transformers.configuration_utils import PretrainedConfig
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.utils import logging

from .modeling_navit_siglip import SiglipVisionConfig
# Module-level logger obtained from transformers' logging utilities, named after this module.
logger = logging.get_logger(name=__name__)
class AudioConfig(PretrainedConfig):
    """Hyperparameters for the MegrezO audio branch.

    Every constructor argument is stored unchanged as an instance attribute of
    the same name; any extra keyword arguments go to ``PretrainedConfig``.

    NOTE(review): the field names (``n_mels``, ``n_ctx``, ``n_state``,
    ``n_head``, ``n_layer``) resemble Whisper-style encoder settings. The
    per-argument descriptions below are inferred from the names only and
    should be confirmed against the model code that consumes this config.

    Args:
        n_mels: Presumably the number of mel-spectrogram bins per frame.
        n_ctx: Presumably the maximum number of audio positions/frames.
        n_state: Presumably the encoder hidden width.
        n_head: Presumably the number of attention heads.
        n_layer: Presumably the number of encoder layers.
        output_dim: Presumably the width of the emitted audio features.
        avg_pool: Boolean flag stored as-is; its effect lives in the consumer.
        add_audio_bos_eos_token: Boolean flag stored as-is; its effect lives
            in the consumer.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    # NOTE(review): identical to MegrezOConfig.model_type; presumably
    # intentional, since this config is only used nested inside it — confirm.
    model_type = "megrezo"

    def __init__(
        self,
        n_mels: int = 128,
        n_ctx: int = 1500,
        n_state: int = 1280,
        n_head: int = 20,
        n_layer: int = 32,
        output_dim: int = 2560,
        avg_pool: bool = True,
        add_audio_bos_eos_token: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Attribute name -> value, in the order they are recorded on the instance.
        audio_fields = {
            "n_mels": n_mels,
            "n_ctx": n_ctx,
            "n_state": n_state,
            "n_head": n_head,
            "n_layer": n_layer,
            "output_dim": output_dim,
            "avg_pool": avg_pool,
            "add_audio_bos_eos_token": add_audio_bos_eos_token,
        }
        for field_name, field_value in audio_fields.items():
            setattr(self, field_name, field_value)
class MegrezOConfig(LlamaConfig):
    """Configuration for the MegrezO omni model.

    Extends :class:`LlamaConfig` (whose settings arrive through ``**kwargs``)
    with two nested sub-configurations: ``audio_config`` (an
    :class:`AudioConfig`) and ``vision_config`` (a :class:`SiglipVisionConfig`).

    Args:
        audio_config: Audio settings, given either as an ``AudioConfig``
            instance or as a ``dict`` of ``AudioConfig`` constructor arguments
            (the form a serialized config provides). ``None`` builds one from
            ``_default_audio_config``.
        vision_config: Vision settings, given either as a
            ``SiglipVisionConfig`` instance or as a ``dict`` of its constructor
            arguments. ``None`` builds one from ``_default_vision_config``.
        **kwargs: Forwarded unchanged to ``LlamaConfig.__init__``.

    Raises:
        TypeError: If ``audio_config`` or ``vision_config`` is not ``None``,
            not a ``dict``, and not an instance of the expected config class.
            (Previously such a value was silently ignored, leaving the
            attribute unset and failing later with an ``AttributeError``.)
    """

    model_type = "megrezo"
    # Keys the transformers framework skips in model output dicts at inference.
    keys_to_ignore_at_inference = ["past_key_values"]
    # Marks this config as composed of nested sub-configs (transformers flag).
    is_composition = True

    # Used when no audio_config is supplied; mirrors AudioConfig's own defaults.
    _default_audio_config = {
        "n_mels": 128,
        "n_ctx": 1500,
        "n_state": 1280,
        "n_head": 20,
        "n_layer": 32,
        "output_dim": 2560,
        "avg_pool": True,
        "add_audio_bos_eos_token": True,
    }
    # Used when no vision_config is supplied.
    _default_vision_config = {
        "intermediate_size": 4304,
        "num_hidden_layers": 27,
        "num_attention_heads": 16,
        "image_size": 980,
        "hidden_size": 1152,
        "patch_size": 16,
        "model_type": "siglip_vision_model",
    }

    def __init__(
        self,
        audio_config: Optional[Union[AudioConfig, dict]] = None,
        vision_config: Optional[Union[SiglipVisionConfig, dict]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.audio_config = self._build_sub_config(
            audio_config, AudioConfig, self._default_audio_config, "audio_config"
        )
        self.vision_config = self._build_sub_config(
            vision_config,
            SiglipVisionConfig,
            self._default_vision_config,
            "vision_config",
        )

    @staticmethod
    def _build_sub_config(
        value: object, config_cls: type, defaults: dict, arg_name: str
    ) -> PretrainedConfig:
        """Normalise one sub-config argument into an instance of ``config_cls``.

        ``None`` -> built from ``defaults``; ``dict`` -> passed as keyword
        arguments to ``config_cls``; an existing ``config_cls`` instance is
        returned unchanged; anything else raises ``TypeError``.
        """
        if value is None:
            # ``**`` unpacking creates a fresh kwargs dict, so the shared
            # class-level defaults cannot be mutated by the constructor.
            return config_cls(**defaults)
        if isinstance(value, dict):
            return config_cls(**value)
        if isinstance(value, config_cls):
            return value
        raise TypeError(
            f"`{arg_name}` must be None, a dict, or a {config_cls.__name__} "
            f"instance; got {type(value).__name__}."
        )