88 lines
2.5 KiB
Python
88 lines
2.5 KiB
Python
"""MegrezO model configuration"""
|
|
|
|
from typing import Optional
|
|
|
|
from transformers.configuration_utils import PretrainedConfig
|
|
from transformers.models.llama.configuration_llama import LlamaConfig
|
|
from transformers.utils import logging
|
|
|
|
from .modeling_navit_siglip import SiglipVisionConfig
|
|
|
|
# Module-level logger obtained through the transformers logging utility.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class AudioConfig(PretrainedConfig):
    """Hyperparameter container for the audio component of MegrezO.

    Every constructor argument is stored unchanged as an attribute of the
    same name; any additional keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        n_mels: Stored as ``self.n_mels`` (default 128).
        n_ctx: Stored as ``self.n_ctx`` (default 1500).
        n_state: Stored as ``self.n_state`` (default 1280).
        n_head: Stored as ``self.n_head`` (default 20).
        n_layer: Stored as ``self.n_layer`` (default 32).
        output_dim: Stored as ``self.output_dim`` (default 2560).
        avg_pool: Stored as ``self.avg_pool`` (default ``True``).
        add_audio_bos_eos_token: Stored as ``self.add_audio_bos_eos_token``
            (default ``True``).
        **kwargs: Passed through to the base-class constructor.

    NOTE(review): the field names and defaults look like Whisper-style
    encoder dimensions (mel bins, context length, width, heads, layers) --
    confirm against the audio model that consumes this config.
    """

    model_type = "megrezo"

    def __init__(
        self,
        n_mels: int = 128,
        n_ctx: int = 1500,
        n_state: int = 1280,
        n_head: int = 20,
        n_layer: int = 32,
        output_dim: int = 2560,
        avg_pool: bool = True,
        add_audio_bos_eos_token: bool = True,
        **kwargs,
    ):
        # Base-class initialisation runs first so that the generic
        # PretrainedConfig attributes exist before the audio-specific ones.
        super().__init__(**kwargs)

        # Attributes are assigned left to right, i.e. in the same order as
        # the constructor parameters.
        self.n_mels, self.n_ctx, self.n_state = n_mels, n_ctx, n_state
        self.n_head, self.n_layer = n_head, n_layer
        self.output_dim, self.avg_pool = output_dim, avg_pool
        self.add_audio_bos_eos_token = add_audio_bos_eos_token
|
|
|
|
|
|
class MegrezOConfig(LlamaConfig):
    """Configuration for MegrezO: a ``LlamaConfig`` plus audio and vision sub-configs.

    The language-model settings are handled entirely by ``LlamaConfig`` via
    ``**kwargs``. On top of that, two nested configurations are attached as
    ``self.audio_config`` (an ``AudioConfig``) and ``self.vision_config``
    (a ``SiglipVisionConfig``).

    Args:
        audio_config: ``None`` (use ``_default_audio_config``), a ``dict`` of
            ``AudioConfig`` keyword arguments, or an ``AudioConfig`` instance.
        vision_config: ``None`` (use ``_default_vision_config``), a ``dict``
            of ``SiglipVisionConfig`` keyword arguments, or a
            ``SiglipVisionConfig`` instance.
        **kwargs: Forwarded to ``LlamaConfig.__init__``.

    Raises:
        TypeError: If ``audio_config`` or ``vision_config`` is neither
            ``None``, a ``dict``, nor an instance of the expected config
            class.
    """

    model_type = "megrezo"
    keys_to_ignore_at_inference = ["past_key_values"]
    is_composition = True

    # Keyword arguments used to build the AudioConfig when none is supplied.
    _default_audio_config = {
        "n_mels": 128,
        "n_ctx": 1500,
        "n_state": 1280,
        "n_head": 20,
        "n_layer": 32,
        "output_dim": 2560,
        "avg_pool": True,
        "add_audio_bos_eos_token": True,
    }

    # Keyword arguments used to build the SiglipVisionConfig when none is supplied.
    _default_vision_config = {
        "intermediate_size": 4304,
        "num_hidden_layers": 27,
        "num_attention_heads": 16,
        "image_size": 980,
        "hidden_size": 1152,
        "patch_size": 16,
        "model_type": "siglip_vision_model",
    }

    def __init__(
        self,
        audio_config: Optional[AudioConfig] = None,
        vision_config: Optional[SiglipVisionConfig] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.audio_config = self._resolve_sub_config(
            audio_config, AudioConfig, self._default_audio_config, "audio_config"
        )
        self.vision_config = self._resolve_sub_config(
            vision_config, SiglipVisionConfig, self._default_vision_config, "vision_config"
        )

    @staticmethod
    def _resolve_sub_config(value, config_cls, defaults, arg_name):
        """Turn ``value`` into an instance of ``config_cls``.

        ``None`` builds the config from ``defaults``; a ``dict`` is expanded
        into keyword arguments; an existing ``config_cls`` instance is
        returned as-is. Anything else raises ``TypeError`` instead of
        silently leaving the attribute unset.
        """
        if value is None:
            # Unpacking copies the entries, so the shared class-level
            # defaults dict is never handed to the sub-config directly.
            return config_cls(**defaults)
        if isinstance(value, dict):
            return config_cls(**value)
        if isinstance(value, config_cls):
            return value
        raise TypeError(
            f"`{arg_name}` must be None, a dict, or a {config_cls.__name__} "
            f"instance, got {type(value).__name__}."
        )
|