# Copyright 2023 Baichuan Inc. All Rights Reserved.
#
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch omni model."""
|
||
import os
|
||
import time
|
||
import json
|
||
import math
|
||
import numpy as np
|
||
from typing import List, Optional, Tuple, Union, Any
|
||
from threading import Thread
|
||
from easydict import EasyDict
|
||
|
||
import torch
|
||
import torch.distributed
|
||
import torch.utils.checkpoint
|
||
from torch import nn
|
||
from torch.nn import CrossEntropyLoss
|
||
from torch.nn import functional as F
|
||
import torch.distributed as dist
|
||
from transformers import PreTrainedModel
|
||
from transformers.activations import ACT2FN
|
||
from dataclasses import dataclass
|
||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
|
||
from transformers.generation.utils import GenerationConfig
|
||
from transformers.utils import logging
|
||
# import for dynamic import not used in this file
|
||
from .vector_quantize import VectorQuantize, EuclideanCodebook
|
||
from .matcha_components import (
|
||
SinusoidalPosEmb,
|
||
Block1D,
|
||
ResnetBlock1D,
|
||
Downsample1D,
|
||
TimestepEmbedding,
|
||
Upsample1D,
|
||
)
|
||
from .matcha_transformer import BasicTransformerBlock
|
||
from .flow_matching import ConditionalDecoder, ConditionalCFM
|
||
|
||
from .configuration_omni import OmniConfig
|
||
from .audio_modeling_omni import (RMSNorm,
|
||
OmniAudioEncoder,
|
||
OmniAudioDecoder,
|
||
OmniAudioVQBridgeTokenizer,
|
||
OmniAudioFlowMatchingDecoder)
|
||
from .visual_modeling_omni import OmniVisualEncoder, OmniVisualBridge
|
||
from .processor_omni import OmniMMProcessor
|
||
|
||
# support model path contain point(.)
|
||
try:
|
||
# step1: copy relative imports to transformers_modules
|
||
from .generation_utils import build_chat_input, TextIterStreamer
|
||
from .sequence_parallel_utils import (
|
||
create_attention_layer,
|
||
get_sequence_parallel_size,
|
||
get_sequence_parallel_chunk,
|
||
)
|
||
except ModuleNotFoundError:
|
||
# step2: direct import from transformers_modules
|
||
try: # bypass check_imports failure
|
||
import sys
|
||
sys.path.append(os.path.dirname(__file__))
|
||
from generation_utils import build_chat_input, TextIterStreamer
|
||
from sequence_parallel_utils import (
|
||
create_attention_layer,
|
||
get_sequence_parallel_size,
|
||
get_sequence_parallel_chunk,
|
||
)
|
||
except Exception:
|
||
raise
|
||
|
||
logger = logging.get_logger(__name__)
|
||
|
||
def get_slopes(n):
    def get_slopes_power_of_2(n):
        start = (2 ** (-2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * ratio ** i for i in range(n)]

    if math.log2(n).is_integer():
        # In the ALiBi paper, models are only trained with 2^a heads for some a; this
        # function has nice properties that only hold when the input is a power of 2.
        return get_slopes_power_of_2(n)
    else:
        # When the number of heads is not a power of 2, fall back to the closest power
        # of 2 and interleave slopes from the next power of 2.
        closest_power_of_2 = 2 ** math.floor(math.log2(n))
        return (
            get_slopes_power_of_2(closest_power_of_2)
            + get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]
        )
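
# Worked example (illustrative, not executed): with n = 8 attention heads,
# get_slopes(8) == [1/2, 1/4, 1/8, 1/16, 1/32, 1/64, 1/128, 1/256], a geometric
# sequence starting at 2 ** (-8/8). Each head uses its slope as the ALiBi
# distance-penalty coefficient added to its attention scores.
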
class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=5e6, device=None):
        super().__init__()
        # Fix the RoPE initialization precision issue, see https://zhuanlan.zhihu.com/p/678963442
        # DeepSpeed patches torch.arange to force it onto the GPU; use the original torch.arange here.
        try:
            import deepspeed
            self.arange = deepspeed.runtime.zero.partition_parameters._orig_torch_arange
        except Exception:
            self.arange = torch.arange

        self.inv_freq = 1.0 / (base ** (self.arange(0, dim, 2).float().to(device) / dim))
        self.max_seq_len_cached = max_position_embeddings
        t = self.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
        self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to run since sin/cos are built in `__init__`; keep it just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = self.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32).to(x.device)
            self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32).to(x.device)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(torch.float32).to(x.device),
            self.sin_cached[:, :, :seq_len, ...].to(torch.float32).to(x.device),
        )

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos_, sin_, position_ids):
    cos = cos_.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin_.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
    k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
    return q_embed.to(q.dtype), k_embed.to(k.dtype)

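# Rotation identity applied above (illustrative): for a feature pair (x1, x2) taken from the
# first and second halves of the head dimension, at position m with frequency theta, RoPE maps
#   (x1, x2) -> (x1 * cos(m * theta) - x2 * sin(m * theta),  x2 * cos(m * theta) + x1 * sin(m * theta)).
# rotate_half supplies the (-x2, x1) part for all pairs at once, so the rotation costs only two
# elementwise multiplies and one add per tensor.
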
class MLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

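# The MLP above is a gated feed-forward block, down_proj(act_fn(gate_proj(x)) * up_proj(x)),
# i.e. SwiGLU when hidden_act == "silu". Illustrative shapes: x (bs, seq, hidden_size) ->
# gate/up (bs, seq, intermediate_size) -> output (bs, seq, hidden_size).
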
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

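# Grouped-query attention example (illustrative numbers): with 32 query heads and 8 KV heads,
# repeat_kv(key_states, n_rep=4) expands (batch, 8, seqlen, head_dim) into
# (batch, 32, seqlen, head_dim), so every group of 4 query heads shares one KV head.
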
class Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""
    def __init__(self, config: OmniConfig, is_sparse=False):
        super().__init__()
        self.config = config
        self.position_embedding_type = config.position_embedding_type.lower()
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.hidden_size = config.num_attention_heads * self.head_dim
        self.hidden_kv_size = self.num_kv_heads * self.head_dim

        if is_sparse:
            self.num_heads = config.sparse_attention_heads
            assert self.num_kv_heads == config.num_attention_heads
            self.W_pack = nn.Linear(self.hidden_size, 3 * self.num_heads * self.head_dim, bias=config.attention_qkv_bias)
            self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        else:
            self.num_heads = config.num_attention_heads
            if self.config.attention_qkv_pack:
                self.W_pack = nn.Linear(config.hidden_size, self.hidden_size + self.hidden_kv_size * 2, bias=config.attention_qkv_bias)
            else:
                self.q_proj = nn.Linear(config.hidden_size, self.hidden_size, bias=config.attention_qkv_bias)
                self.k_proj = nn.Linear(config.hidden_size, self.hidden_kv_size, bias=config.attention_qkv_bias)
                self.v_proj = nn.Linear(config.hidden_size, self.hidden_kv_size, bias=config.attention_qkv_bias)

            self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)

        if self.position_embedding_type == 'rope':
            self.rotary_emb = RotaryEmbedding(
                dim=self.head_dim,
                max_position_embeddings=config.max_position_embeddings,
                base=config.get_rotary_base()
            )
        elif self.position_embedding_type == 'alibi':
            self.alibi_slopes = get_slopes(self.num_heads)
        self.attention = create_attention_layer(self.hidden_size, self.num_heads, self.head_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _repeat_kv(self, hidden_states: torch.Tensor, num_heads: int) -> torch.Tensor:
        assert hidden_states.size(1) <= num_heads and num_heads % hidden_states.size(1) == 0
        return repeat_kv(hidden_states, num_heads // hidden_states.size(1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        seqlens: Optional[torch.IntTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len = hidden_states.shape[:2]

        if self.config.attention_qkv_pack:
            proj = self.W_pack(hidden_states)
            query_states, key_states, value_states = proj.split([self.hidden_size, self.hidden_kv_size, self.hidden_kv_size], dim=-1)
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # (B, S, hidden_size) -> (B, num_heads, S, head_dim)
        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        # (B, S, hidden_size) -> (B, num_kv_heads, S, head_dim)
        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        if self.position_embedding_type == 'rope':
            max_position = position_ids.max().item() + 1 if position_ids is not None else kv_seq_len * get_sequence_parallel_size()
            cos, sin = self.rotary_emb(value_states, seq_len=max_position)
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin,
                get_sequence_parallel_chunk(position_ids)
            )

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = self._repeat_kv(key_states, query_states.size(1))
        value_states = self._repeat_kv(value_states, query_states.size(1))

        if seqlens is not None:
            seqlens = seqlens.to(dtype=torch.int32)
            max_seqlen = (seqlens[1:] - seqlens[:-1]).max().item()
            if self.position_embedding_type == 'alibi':
                alibi_slopes = torch.tensor(self.alibi_slopes, dtype=torch.float32).to(query_states.device)
            else:
                alibi_slopes = None
            attn_output = self.attention(
                query_states, key_states, value_states, seqlens, seqlens,
                max_seqlen, max_seqlen, causal=True, alibi_slopes=alibi_slopes, use_flash=True)
        else:
            attn_output = self.attention(
                query_states, key_states, value_states, attn_mask=attention_mask, use_flash=False)

        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

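# Note on `seqlens` (an assumption based on the varlen flash-attention convention used above and
# on how OmniModel.forward builds cu_seqlens): it holds cumulative sequence boundaries for packed
# samples, e.g. three sequences of lengths 3, 5 and 2 packed together give seqlens = [0, 3, 8, 10];
# max_seqlen is then the largest per-sample length (5 in this example), which is exactly what
# `(seqlens[1:] - seqlens[:-1]).max()` computes.
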
class DecoderLayer(nn.Module):
    def __init__(self, config: OmniConfig, is_sparse=False):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Attention(config=config, is_sparse=is_sparse)
        self.mlp = MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        seqlens: Optional[torch.IntTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        group_index=None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            seqlens=seqlens,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs

class OmniPreTrainedModel(PreTrainedModel):
    config_class = OmniConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DecoderLayer"]
    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv1d, nn.ConvTranspose1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
        elif isinstance(module, RMSNorm):
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, OmniModel):
            module.gradient_checkpointing = value

@dataclass
class OmniModelOutputWithPast(BaseModelOutputWithPast):
    audio_encoder_ret: Optional[Any] = None
    audio_decoder_ret: Optional[Any] = None


class OmniModel(OmniPreTrainedModel):
    def __init__(self, config: OmniConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        if config.visual_config.enable:
            self.visual_model = OmniVisualEncoder(config.visual_config)
            self.visual_bridge_model = OmniVisualBridge(config.visual_config)
        if config.video_config.enable and not config.visual_config.enable:  # in case only video_config is enabled, without visual_config
            self.visual_model = OmniVisualEncoder(config.video_config)
            self.visual_bridge_model = OmniVisualBridge(config.video_config)

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([
            DecoderLayer(config, is_sparse=layer_idx in config.sparse_attention_layers)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.audio_embed_layers = nn.ModuleList([
            nn.Embedding(codedim + 1, config.hidden_size)
            for i, codedim in enumerate(config.audio_config.vq_config.codebook_sizes)
        ])

        self.gradient_checkpointing = True
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @torch.no_grad()
    def get_multimodal_mask(self, input_ids, pad_token_id, special_token_list):
        '''
        Build the special masks for an arbitrary modality:
        1. pad mask: token positions reserved in the text for image/audio/video content
        2. special token mask: special tokens (e.g. <start>/<end> for the understanding model)
           that should be excluded from next-token prediction
        3. embedding mask / lm_head mask: marks where the special tokens live in the embedding table
        '''
        pad_mask = torch.eq(input_ids, pad_token_id)
        sp_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        lm_head_mask = torch.zeros([self.config.vocab_size, 1], dtype=torch.bool)
        for sp_id in special_token_list:
            sp_mask = torch.logical_or(sp_mask, torch.eq(input_ids, sp_id))
            lm_head_mask[sp_id, 0] = True
        return pad_mask, sp_mask, lm_head_mask

    def get_multimodal_embed(
        self,
        input_ids,
        text_embedding,  # either self.embed_tokens(input_ids) or the output of a previous modality merge
        multimodal_embed,
        pad_token_id,
        fake_input,
        group_index=None,  # running index identifying each modality group
    ):
        pad_mask, sp_mask, _ = self.get_multimodal_mask(input_ids, pad_token_id, self.config.multimodal_special_token_list)
        if not self.training:  # inference supports auto device map; move the multimodal output onto the same device as input_ids
            multimodal_embed = multimodal_embed.to(input_ids.device)
        if not fake_input:  # the number of multimodal tokens must match the pad-mask count (incorrect truncation breaks this)
            assert pad_mask.sum() == multimodal_embed.shape[0]
        else:
            assert pad_mask.sum() <= 0

        # merge the current modality's embeddings with the text embeddings
        input_ids = torch.where(pad_mask, torch.cumsum(pad_mask.view(-1).to(input_ids), dim=0).view(input_ids.shape) - 1, input_ids)
        text_embedding = (1 - pad_mask.to(text_embedding)).unsqueeze(-1) * text_embedding  # zero out the pad-token positions
        multimodal_embedding = torch.embedding(multimodal_embed, input_ids * pad_mask)  # non-pad positions gather row 0 of multimodal_embed
        multimodal_embedding = pad_mask.to(multimodal_embedding).unsqueeze(-1) * multimodal_embedding  # zero out the non-pad positions
        final_embedding = multimodal_embedding.to(text_embedding) + text_embedding

        if group_index is None:
            group_index = pad_mask.to(torch.int32)
        else:
            current_index = torch.max(group_index) + 1
            group_index += pad_mask.to(torch.int32) * current_index  # assumes modalities do not overlap

        return final_embedding, group_index

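    # Worked example for the merge above (illustrative): with input_ids = [t1, PAD, t2, PAD] and a
    # multimodal_embed of 2 rows, the cumsum over the pad mask yields gather indices [., 0, ., 1],
    # so the first PAD slot receives multimodal_embed[0] and the second receives multimodal_embed[1],
    # while text positions keep their token embeddings. Each position is zeroed in exactly one of
    # the two tensors, so their sum is the merged embedding.
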
    def get_visual_embed(
        self,
        input_ids,
        text_embedding,  # either self.embed_tokens(input_ids) or the output of a previous modality merge
        images=None,
        patch_nums=None,
        images_grid=None,
        videos=None,
        videos_patch_nums=None,
        videos_grid=None,
        group_index=None,  # running index identifying each modality group
    ):
        if images is None or len(images) <= 0:
            images, images_grid, patch_nums = self.visual_model.fake_input(input_ids.device)
            image_fake_input = True
        else:
            image_fake_input = False

        if videos is None or len(videos) <= 0:
            videos, videos_grid, videos_patch_nums = self.visual_model.fake_input(input_ids.device)
            video_fake_input = True
        else:
            video_fake_input = False

        visual_input = images + videos
        visual_grid = images_grid + videos_grid

        visual_input = torch.cat(visual_input, dim=0)
        visual_grid = torch.tensor(np.array(visual_grid))

        visual_embed = self.visual_model(visual_input, grid_thw=visual_grid)
        visual_embed = self.visual_bridge_model(visual_embed)

        assert sum(patch_nums) + sum(videos_patch_nums) == visual_embed.shape[0]
        images_embed = visual_embed[:sum(patch_nums)]
        videos_embed = visual_embed[sum(patch_nums):]

        final_embedding, group_index = self.get_multimodal_embed(input_ids, text_embedding, images_embed, self.config.visual_config.image_pad_token_id, image_fake_input, group_index=group_index)
        final_embedding, group_index = self.get_multimodal_embed(input_ids, final_embedding, videos_embed, self.config.video_config.video_place_token_id, video_fake_input, group_index=group_index)
        return final_embedding, group_index

    @torch.no_grad()
    def audio_fake_input(self, device):
        return torch.zeros(5, len(self.config.audio_config.vq_config.codebook_sizes), dtype=torch.int32, device=device)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        seqlens: Optional[torch.IntTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        audios_tokens: Optional[List | torch.Tensor] = None,  # audio tokens, shape (bs, seqlen, vq_num)
        images: Optional[List | torch.Tensor] = None,
        patch_nums: Optional[torch.Tensor] = None,
        images_grid: Optional[List | torch.Tensor] = None,
        videos: Optional[List | torch.Tensor] = None,
        videos_patch_nums: Optional[torch.Tensor] = None,
        videos_grid: Optional[List | torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, OmniModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = True if (return_dict is not None or self.training) else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        group_index, audio_decoder_ret = None, None
        if inputs_embeds is None:
            sp_input_ids = get_sequence_parallel_chunk(input_ids)
            inputs_embeds = self.embed_tokens(sp_input_ids)
            if audios_tokens is None or len(audios_tokens) <= 0:
                audios_tokens = torch.zeros(5, len(self.config.audio_config.vq_config.codebook_sizes), dtype=torch.int32, device=input_ids.device)  # a fake input
                fake_input = True
            else:
                fake_input = False
            for i, audio_emb_layer in enumerate(self.audio_embed_layers):
                if i == 0:
                    audio_embs = audio_emb_layer(audios_tokens[..., i])
                else:
                    audio_embs += audio_emb_layer(audios_tokens[..., i])
            inputs_embeds, group_index = self.get_multimodal_embed(sp_input_ids, inputs_embeds, audio_embs, self.config.audio_config.audio_pad_token_id, fake_input, group_index=group_index)

            if self.config.visual_config.enable or self.config.video_config.enable:
                inputs_embeds, group_index = self.get_visual_embed(sp_input_ids, inputs_embeds, images, patch_nums, images_grid, videos, videos_patch_nums, videos_grid, group_index=group_index)  # note: group_index is updated here

        if seqlens is not None and seqlens.ndim == 2:
            cu_seqlens = []
            offset, seqlen = 0, seqlens.size(1)
            for lens in seqlens:
                cu_seqlens.append(offset)
                cu_seqlens.extend((lens[(lens > 0) & (lens < seqlen)] + offset).tolist())
                offset += seqlen
            cu_seqlens.append(offset)
            seqlens = torch.tensor(cu_seqlens, dtype=seqlens.dtype, device=seqlens.device)
        elif seqlens is None and self.training:
            seqlens = torch.arange(
                end=input_ids.size(0) + 1,
                dtype=torch.int32,
                device=input_ids.device
            ) * input_ids.size(1)
        if seqlens is not None:
            attention_mask = None  # unset attention_mask to save memory

        if seqlens is None and attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        if attention_mask is not None:
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )

        # embed positions
        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, False, group_index)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    seqlens,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    seqlens=seqlens,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    group_index=group_index,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class NormHead(nn.Module):
    def __init__(self, hidden_size, vocab_size, bias=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.weight = nn.Parameter(torch.empty((self.vocab_size, self.hidden_size)))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states, mask=None):
        norm_weight = nn.functional.normalize(self.weight)
        if mask is not None:
            mask = mask.to(norm_weight)
            norm_weight = norm_weight * mask + (1 - mask) * norm_weight.detach()
        return nn.functional.linear(hidden_states, norm_weight)

    def extra_repr(self) -> str:
        return f'in_features={self.hidden_size}, out_features={self.vocab_size}'

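# NormHead note: logits are computed against L2-normalized rows of the weight matrix (a
# cosine-style classifier). When a mask is given (see lm_head_mask from get_multimodal_mask),
# only the masked rows keep gradients and the remaining rows are detached, so training a few
# special tokens does not perturb the rest of the vocabulary head.
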
@dataclass
class OmniMMCausalLMOutputWithPast(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    audios_emb_for_infer: Optional[torch.FloatTensor] = None  # embeddings fed to the audio head at inference time


class CasualDepthTransformerLayer(nn.Module):
    def __init__(self, config, depth):
        super().__init__()
        self.config = config
        embed_size = config.hidden_size
        assert embed_size % 128 == 0
        num_heads = embed_size // 128
        self.self_attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads, batch_first=True)
        self.layernorm1 = RMSNorm(embed_size)
        self.layernorm2 = RMSNorm(embed_size)
        self.linear1 = nn.Linear(embed_size * depth, 2 * embed_size)
        self.linear2 = nn.Linear(2 * embed_size * depth, embed_size)

    def forward(self, x):
        seq_len = x.size(1)
        res = x
        x = self.layernorm1(x)
        src_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(x.device)
        _x, _ = self.self_attention(x, x, x, is_causal=True, attn_mask=src_mask)
        res = _x + res  # (bs, sl, d)
        res = self.layernorm2(res)
        # per-depth feed-forward: the linear weights are reshaped to (out, depth, in) so that each
        # depth position uses its own projection
        x = torch.einsum('bld,tld->blt', res, torch.reshape(self.linear1.weight, (2 * self.config.hidden_size, -1, self.config.hidden_size)))
        x = torch.nn.functional.gelu(x)
        x = torch.einsum('blt,dlt->bld', x, torch.reshape(self.linear2.weight, (self.config.hidden_size, -1, 2 * self.config.hidden_size)))
        return res + x

class OmniAudioHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size
        self.transformer_layers = nn.ModuleList([
            CasualDepthTransformerLayer(config, len(config.audio_config.vq_config.codebook_sizes))
            for _ in range(config.audio_config.audio_head_transformer_layers)
        ])
        self.headnorm = RMSNorm(hidden_size)
        self.heads = nn.ModuleList([
            nn.Linear(hidden_size, vq_size + 1)
            for vq_size in config.audio_config.vq_config.codebook_sizes
        ])
        self.gradient_checkpointing = True

    def forward(self, x, audios_tokens, audio_emb_layers):
        cumsum_audio_embed = torch.stack([
            audio_emb_layers[i](audios_tokens[..., i])
            for i, vq_size in enumerate(self.config.audio_config.vq_config.codebook_sizes[:-1])
        ], dim=1)
        cumsum_audio_embed = torch.cumsum(cumsum_audio_embed, dim=1)  # (bs, depth-1, d)
        hidden_states = torch.concat([x.reshape(-1, 1, self.config.hidden_size), cumsum_audio_embed], dim=1)  # (bs, depth, d)
        assert hidden_states.size(1) == len(self.config.audio_config.vq_config.codebook_sizes)
        for i, tlayer in enumerate(self.transformer_layers):
            hidden_states = tlayer(hidden_states)
        hidden_states = self.headnorm(hidden_states)
        logits = [head(hidden_states[:, i]) for i, head in enumerate(self.heads)]
        return logits

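# Audio-head note: the depth transformer above runs over the codebook ("depth") axis rather than
# over time. Position 0 carries the LM hidden state; position i >= 1 adds the cumulative embedding
# of the first i already-known codebook tokens, and the causal mask lets head i predict codebook i
# from the LM state plus codebooks < i, factorizing the joint distribution over codebooks
# autoregressively along the depth axis.
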
class OmniForCausalLM(OmniPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = OmniModel(config)
        self.audio_tokenizer = OmniAudioTokenizer(config)
        self.audio_head = OmniAudioHead(config)
        if config.use_norm_head:
            self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False)
        else:
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    @property
    def main_device(self):
        return self.lm_head.weight.device

    def bind_processor(self, tokenizer, **kwargs):
        self.processor = OmniMMProcessor(
            tokenizer=tokenizer,
            config=self.config,
            **kwargs,
        )
        return self.processor

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        seqlens: Optional[torch.IntTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        audios: Optional[List | torch.Tensor] = None,
        audios_tokens: Optional[List | torch.Tensor] = None,
        encoder_length: Optional[torch.Tensor] = None,
        bridge_length: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        patch_nums: Optional[torch.Tensor] = None,
        images_grid: Optional[torch.Tensor] = None,
        videos: Optional[torch.Tensor] = None,
        videos_patch_nums: Optional[torch.Tensor] = None,
        videos_grid: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if audios_tokens is not None:
            assert isinstance(audios_tokens, torch.Tensor)
        else:
            if audios is None or len(audios) == 0:
                audios_tokens = None
            else:
                audios_tokens = self.audio_tokenizer(audios, encoder_length, bridge_length)

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            seqlens=seqlens,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            audios_tokens=audios_tokens,
            images=images,
            patch_nums=patch_nums,
            images_grid=images_grid,
            videos=videos,
            videos_patch_nums=videos_patch_nums,
            videos_grid=videos_grid,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs.last_hidden_state
        audios_emb_for_infer = hidden_states[:, -1, :]
        logits = self.lm_head(hidden_states)

        return OmniMMCausalLMOutputWithPast(
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            audios_emb_for_infer=audios_emb_for_infer,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, past_key_values[0][0].shape[-2]:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1)
            # position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, past_key_values[0][0].shape[-2]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        elif past_key_values is not None:
            model_inputs = {"input_ids": input_ids}
        else:
            model_inputs = {
                "input_ids": input_ids,
                "audios": kwargs.get("audios", None),
                "encoder_length": kwargs.get("encoder_length", None),
                "bridge_length": kwargs.get("bridge_length", None),
                "audios_tokens": kwargs.get("audios_tokens", None),
                "images": kwargs.get("images", None),
                "videos": kwargs.get("videos", None),
            }

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images_grid": kwargs.get("images_grid"),
                "videos_grid": kwargs.get("videos_grid"),
                "patch_nums": kwargs.get("patch_nums"),
                "videos_patch_nums": kwargs.get("videos_patch_nums"),
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past

    def chat(self, tokenizer, messages: List[dict], stream=False,
             generation_config: Optional[GenerationConfig] = None):
        generation_config = generation_config or self.generation_config
        input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens)
        if stream:
            streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
            Thread(target=self.generate, kwargs=dict(
                inputs=input_ids, streamer=streamer,
                generation_config=generation_config,
            )).start()
            return streamer
        else:
            outputs = self.generate(input_ids, generation_config=generation_config)
            response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
            return response

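# Minimal text-chat usage sketch (illustrative; the checkpoint path and the use of AutoTokenizer
# with trust_remote_code are assumptions, not part of this file):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("path/to/omni-checkpoint", trust_remote_code=True)
#   model = OmniForCausalLM.from_pretrained("path/to/omni-checkpoint", trust_remote_code=True)
#   model.bind_processor(tokenizer)
#   print(model.chat(tokenizer, [{"role": "user", "content": "Hello"}]))
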
class OmniAudioTokenizer(OmniPreTrainedModel):
    """
    Construct an audio tokenizer and decoder.
    """
    def __init__(self, config: OmniConfig):
        super().__init__(config)
        self.padding_idx = None
        self.vocab_size = config.vocab_size
        self.training = False
        self.eval()
        self.audio_model = OmniAudioEncoder(config.audio_config)
        self.audio_bridge_model = OmniAudioVQBridgeTokenizer(config)
        if config.vocoder_config.enable:
            self.audio_decoder = OmniAudioDecoder(config)
        if config.flow_matching_config.enable:
            self.audio_flow_matching_decoder = OmniAudioFlowMatchingDecoder(config)

    def encode(self, x, encoder_length: Optional[torch.Tensor] = None,
               bridge_length: Optional[torch.Tensor] = None):
        audio_emb = self.audio_model(x, encoder_length)
        audios_tokens = self.audio_bridge_model(audio_emb, bridge_length)
        return audios_tokens

    def decode(self, audio_code_ids, bridge_length: Optional[torch.Tensor] = None):
        assert self.config.vocoder_config.enable, "Vocoder is not enabled in config."
        audio_emb = self.audio_bridge_model.decode(audio_code_ids)
        audio_dec = self.audio_decoder(
            audio_emb.to(next(self.audio_decoder.parameters())), bridge_length
        )
        if self.config.flow_matching_config.enable:
            if self.config.flow_matching_config.use_hidden_states_before_dconv2:
                hidden_states, hidden_states_length = (
                    self.audio_flow_matching_decoder.unpack_hidden_states(
                        audio_dec.hidden_states_before_dconv2,
                        audio_dec.output_length_before_dconv2,
                    )
                )
                audio_flow_matching_decoder_ret = self.audio_flow_matching_decoder(
                    hidden_states, hidden_states_length
                )
            else:
                audio_flow_matching_decoder_ret = self.audio_flow_matching_decoder(
                    audio_dec.refined_mel, audio_dec.mel_length
                )
            return audio_flow_matching_decoder_ret
        else:
            return audio_dec

    @torch.no_grad()
    def forward(self, audios, encoder_length: Optional[torch.Tensor] = None, bridge_length: Optional[torch.Tensor] = None):
        self.eval()
        audios_tokens = self.encode(audios, encoder_length, bridge_length)
        return audios_tokens
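
# Round-trip sketch (illustrative; `features`, `encoder_length` and `bridge_length` are assumed
# to come from the bound OmniMMProcessor):
#   audios_tokens = model.audio_tokenizer.encode(features, encoder_length, bridge_length)
#   decoded = model.audio_tokenizer.decode(audios_tokens, bridge_length)  # vocoder or flow-matching output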