first commit

xxl 2025-01-26 17:53:40 +08:00
parent 0192f4532b
commit 3cb9f17068
22 changed files with 313344 additions and 2 deletions

README.md (+1115 lines)
File diff suppressed because it is too large.

added_tokens.json (new file, +61 lines)

@@ -0,0 +1,61 @@
{
"</tool_call>": 151658,
"<B_APE>": 151671,
"<B_CODE>": 151670,
"<B_FUNC>": 151669,
"<B_SYS>": 151665,
"<B_USYS>": 151666,
"<C_A>": 151668,
"<C_Q>": 151667,
"<audio_delim_baichuan>": 151693,
"<audio_end_baichuan>": 151677,
"<audio_pad_baichuan>": 151678,
"<audio_start_baichuan>": 151676,
"<audiogen_end_baichuan>": 151701,
"<audiogen_start_baichuan>": 151700,
"<audiotext_end_baichuan>": 151698,
"<audiotext_pad_baichuan>": 151699,
"<audiotext_start_baichuan>": 151697,
"<baichuan_pad_token>": 151691,
"<box_delim_baichuan>": 151685,
"<box_end_baichuan>": 151684,
"<box_start_baichuan>": 151683,
"<calc_end>": 151674,
"<calc_start>": 151673,
"<function_calling>": 151672,
"<img_delim_baichuan>": 151688,
"<img_end_baichuan>": 151680,
"<img_newline_baichuan>": 151682,
"<img_pad_baichuan>": 151681,
"<img_start_baichuan>": 151679,
"<inner_think>": 151675,
"<polygon_end_baichuan>": 151690,
"<polygon_start_baichuan>": 151689,
"<ref_end_baichuan>": 151687,
"<ref_start_baichuan>": 151686,
"<reserved_113>": 151692,
"<tool_call>": 151657,
"<video_end_baichuan>": 151696,
"<video_palce_baichuan>": 151694,
"<video_start_baichuan>": 151695,
"<|box_end|>": 151649,
"<|box_start|>": 151648,
"<|endoftext|>": 151643,
"<|file_sep|>": 151664,
"<|fim_middle|>": 151660,
"<|fim_pad|>": 151662,
"<|fim_prefix|>": 151659,
"<|fim_suffix|>": 151661,
"<|im_end|>": 151645,
"<|im_start|>": 151644,
"<|image_pad|>": 151655,
"<|object_ref_end|>": 151647,
"<|object_ref_start|>": 151646,
"<|quad_end|>": 151651,
"<|quad_start|>": 151650,
"<|repo_name|>": 151663,
"<|video_pad|>": 151656,
"<|vision_end|>": 151653,
"<|vision_pad|>": 151654,
"<|vision_start|>": 151652
}
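
A minimal sketch, assuming the full tokenizer files from this repository are available in the working directory, of how the special-token ids above can be checked; the local path and the trust_remote_code flag are assumptions, not part of the commit.

from transformers import AutoTokenizer

# Assumes the repository has been cloned locally.
tokenizer = AutoTokenizer.from_pretrained(".", trust_remote_code=True)

# Each added token should resolve to the id declared in added_tokens.json.
for token in ["<audio_start_baichuan>", "<img_start_baichuan>", "<|im_start|>"]:
    print(token, tokenizer.convert_tokens_to_ids(token))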

audio_modeling_omni.py (new file, +658 lines)

@@ -0,0 +1,658 @@
import torch, fire
from typing import Optional
import torch.distributed
from torch.nn import functional as F
from flash_attn import flash_attn_varlen_func
from torch import nn
import numpy as np
import deepspeed
from transformers.activations import ACT2FN
from dataclasses import dataclass
from transformers.modeling_outputs import ModelOutput
try:
from .vector_quantize import VectorQuantize
except ImportError:
from vector_quantize import VectorQuantize
from .flow_matching import (
ConditionalDecoder,
ConditionalCFM,
)
import math
import copy
def sinusoids(length, channels, max_timescale=10000):
"""Returns sinusoids for positional embedding"""
assert channels % 2 == 0
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
def get_sequence_mask(inputs, inputs_length):
if inputs.dim() == 3:
bsz, tgt_len, _ = inputs.size()
else:
bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
unpacking_index = torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1 # convert cumulative counts into flat gather indices
return sequence_mask, unpacking_index
def unpack_hidden_states(hidden_states, lengths):
bsz = lengths.shape[0]
sequence_mask, unpacking_index = get_sequence_mask(hidden_states, lengths)
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
bsz, torch.max(lengths), hidden_states.shape[-1]
)
hidden_states = torch.where(
sequence_mask, hidden_states, 0
) # 3d (bsz, max_input_len, d)
return hidden_states
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
class OmniWhisperAttention(nn.Module):
def __init__(self, embed_dim, num_heads, causal=False):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.causal = causal
def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
bsz, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
key_states = self.k_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
value_states = self.v_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(torch.int32)
max_seqlen = torch.max(seq_len).to(torch.int32).detach()
attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_len, cu_len, max_seqlen,
max_seqlen, causal=self.causal) # (bsz * qlen, nheads, headdim)
attn_output = attn_output.reshape(bsz, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output
class OmniWhisperTransformerLayer(nn.Module):
def __init__(
self,
act,
d_model,
encoder_attention_heads,
encoder_ffn_dim,
causal,
ln_type="LayerNorm",
):
super().__init__()
self.embed_dim = d_model
self.self_attn = OmniWhisperAttention(
self.embed_dim, encoder_attention_heads, causal
)
if ln_type == "LayerNorm":
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
elif ln_type == "RMSNorm":
self.self_attn_layer_norm = RMSNorm(self.embed_dim)
else:
raise ValueError(f"Unknown ln_type: {ln_type}")
self.activation_fn = act
self.fc1 = nn.Linear(self.embed_dim, encoder_ffn_dim)
self.fc2 = nn.Linear(encoder_ffn_dim, self.embed_dim)
if ln_type == "LayerNorm":
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
elif ln_type == "RMSNorm":
self.final_layer_norm = RMSNorm(self.embed_dim)
else:
raise ValueError(f"Unknown ln_type: {ln_type}")
def forward(
self, hidden_states: torch.Tensor, seq_len: torch.Tensor
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(hidden_states, seq_len)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.fc2(hidden_states)
hidden_states = residual + hidden_states
if (
hidden_states.dtype == torch.float16
or hidden_states.dtype == torch.bfloat16
) and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(
hidden_states, min=-clamp_value, max=clamp_value
)
return hidden_states
class OmniAudioEncoder(nn.Module):
def __init__(self, config):
super().__init__()
config._attn_implementation = 'flash_attention_2'  # force the flash-attention-2 implementation
self.config = config
self.max_source_positions = (config.max_audio_seconds * config.sampling_rate // config.hop_length) // config.stride_size
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.conv1 = nn.Conv1d(config.num_mel_bins, config.d_model, kernel_size=config.kernel_size, padding=1)
self.conv2 = nn.Conv1d(config.d_model, config.d_model, kernel_size=config.kernel_size,
stride=config.stride_size, padding=1)
self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, config.d_model)) # 1500 * d
self.layers = nn.ModuleList([OmniWhisperTransformerLayer(
ACT2FN[config.activation_function],
config.d_model,
config.encoder_attention_heads,
config.encoder_ffn_dim,
False) for _ in range(config.encoder_layers)])
self.layer_norm = nn.LayerNorm(config.d_model)
@torch.no_grad()
def fake_input(self, device):
input_features = torch.rand([2, self.config.num_mel_bins, 10], dtype=torch.float32, device=device)
encoder_length = torch.ones([2], dtype=torch.int32, device=device) * 3
bridge_length = torch.ones([2], dtype=torch.int32, device=device)
return input_features, encoder_length, bridge_length
def forward(
self,
input_features,
output_length,
):
input_features = input_features.to(self.conv1.weight.dtype)
inputs_embeds = nn.functional.gelu(self.conv1(input_features)) # (bs, channels, frames)
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) # (bs, channels, frames // 2)
inputs_embeds = inputs_embeds.permute(0, 2, 1) # (bs, frames, channels)
bsz, tgt_len, _ = inputs_embeds.size()
if tgt_len < self.positional_embedding.shape[0]:
current_positional_embedding = self.positional_embedding[:tgt_len]
else:
current_positional_embedding = self.positional_embedding
hidden_states = (inputs_embeds.to(torch.float32) + current_positional_embedding).to(inputs_embeds.dtype)
# packing hidden states
attention_mask, unpacking_index = get_sequence_mask(hidden_states, output_length)
hidden_states = torch.masked_select(hidden_states, attention_mask).view(torch.sum(output_length),
self.config.d_model)
for idx, encoder_layer in enumerate(self.layers):
hidden_states = encoder_layer(hidden_states, output_length)
hidden_states = self.layer_norm(hidden_states)
# unpacking
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, tgt_len, self.config.d_model)
hidden_states = torch.where(attention_mask, hidden_states, 0)
return hidden_states
class CasualConvTranspose1d(nn.Module): # transposed (de-)convolution
def __init__(self, in_channels, out_channels, kernel_size, stride):
super().__init__()
self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
self.norm = nn.GroupNorm(1, out_channels)
self.in_channels = in_channels
self.out_channels = out_channels
def forward(self, hidden_states, input_length, output_dim=None):
kernel_size = self.conv.kernel_size[0]
stride = self.conv.stride[0]
bsz = input_length.shape[0]
if output_dim is None:
output_dim = hidden_states.dim()
if hidden_states.dim() <= 2: # unpack sequence to 3d
sequence_mask, unpacking_index = get_sequence_mask(hidden_states, input_length)
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, torch.max(input_length),
self.in_channels)
hidden_states = torch.where(sequence_mask, hidden_states, 0) # 3d (bsz, max_input_len, d)
hidden_states = hidden_states.transpose(2, 1) # (N, L, C) -> (N, C, L)
hidden_states = self.conv(hidden_states)
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.transpose(2, 1) # (N, C, L) -> (N, L, C)
casual_padding_right = max(0, kernel_size - stride)
hidden_states = hidden_states[:, :hidden_states.shape[1] - casual_padding_right,
:]
output_length = (input_length - 1) * stride + kernel_size - casual_padding_right
sequence_mask, _ = get_sequence_mask(hidden_states, output_length)
if output_dim <= 2:
hidden_states = torch.masked_select(hidden_states, sequence_mask).view(-1, self.out_channels)
else:
hidden_states = torch.where(sequence_mask, hidden_states, 0)
hidden_states = hidden_states[:, :torch.max(output_length), :] # truncate to the maximum valid length
return hidden_states, output_length
class MelSpecRefineNet(nn.Module):
"""
Post net: refines coarse mel-spectrogram frames.
ref1: Autoregressive Speech Synthesis without Vector Quantization
ref2: CosyVoice length_regulator.py
ref3: Neural Speech Synthesis with Transformer Network, https://github.com/soobinseo/Transformer-TTS/blob/master/network.py
"""
def __init__(self, encoder_config, vocoder_config):
super().__init__()
self.encoder_config = encoder_config
self.vocoder_config = vocoder_config
layers = nn.ModuleList([])
in_channels = self.vocoder_config.num_mel_bins
for i, out_channels in enumerate(self.vocoder_config.channels[:-1]):
module = nn.Conv1d(in_channels, out_channels, 5, 1, 2) # cosyvoice kernel=3, stride=1, pad=1
in_channels = out_channels
norm = nn.GroupNorm(1, out_channels)
act = nn.Mish()
layers.extend([module, norm, act])
layers.append(nn.Conv1d(in_channels, self.vocoder_config.num_mel_bins, 1, 1)) # projector
self.layers = nn.Sequential(*layers)
def compute_output_length(self, input_length):
output_length = input_length.to(
torch.float32) * self.encoder_config.hop_length / self.encoder_config.sampling_rate
output_length = output_length * self.vocoder_config.sampling_rate / self.vocoder_config.hop_length
return output_length.to(torch.int64)
def forward(self, coarse_mel, input_length, output_length=None):
bsz, _, d = coarse_mel.shape
assert (d == self.vocoder_config.num_mel_bins)
if output_length is None or not self.training:
output_length = self.compute_output_length(input_length)
coarse_mel, default_dtype = coarse_mel[:, :torch.max(input_length), :], coarse_mel.dtype
coarse_mel = F.interpolate(coarse_mel.to(torch.float32).transpose(1, 2).contiguous(), size=output_length.max(),
mode='nearest').to(default_dtype)
refined_mel = self.layers(coarse_mel).transpose(1, 2).contiguous() # (bs, t, d)
coarse_mel = coarse_mel.transpose(1, 2) # (bs, max(output_length), d)
refined_mel += coarse_mel # residual connection
sequence_mask, _ = get_sequence_mask(refined_mel, output_length)
coarse_mel = torch.where(sequence_mask, coarse_mel, 0)
refined_mel = torch.where(sequence_mask, refined_mel, 0)
return refined_mel, coarse_mel, output_length
@dataclass
class OmniAudioDecoderOutput(ModelOutput):
refined_mel: Optional[torch.FloatTensor] = None
coarse_mel: Optional[torch.FloatTensor] = None
mel_length: Optional[torch.Tensor] = None
hidden_states_before_dconv2: Optional[torch.FloatTensor] = None
output_length_before_dconv2: Optional[torch.Tensor] = None
class OmniAudioDecoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config.audio_config
self.vocoder_config = config.vocoder_config
self.max_source_positions = self.config.max_audio_seconds * self.config.sampling_rate // self.config.hop_length
self.dconv1 = CasualConvTranspose1d(
self.config.d_model,
self.config.d_model,
self.config.decoder_kernel_size,
self.config.avg_pooler,
)
self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, self.config.d_model))
# causal transformer layers
self.layers = nn.ModuleList(
[OmniWhisperTransformerLayer(
ACT2FN[self.config.activation_function],
self.config.d_model,
self.config.decoder_attention_heads,
self.config.decoder_ffn_dim,
True # causal
) for _ in range(self.config.decoder_layers)
])
self.layer_norm = nn.LayerNorm(self.config.d_model)
self.dconv2 = CasualConvTranspose1d(
self.config.d_model,
self.vocoder_config.num_mel_bins,
self.config.decoder_kernel_size,
self.config.decoder_stride_size
)
self.post_net = MelSpecRefineNet(config.audio_config, config.vocoder_config)
self.gradient_checkpointing = True
@torch.no_grad()
def fake_input(self, device):
audio_embed = torch.rand([1, 10, self.config.d_model], dtype=torch.float32, device=device)
input_length = torch.ones([1], dtype=torch.int32, device=device) * 10
mel_labels_length = self.post_net.compute_output_length(input_length)
return audio_embed, input_length, None, mel_labels_length
def forward(self,
audio_embed,
input_length,
mel_labels=None,
mel_labels_length=None,
fake_input=False,
):
if fake_input:
audio_embed, input_length, mel_labels, mel_labels_length = self.fake_input(self.layer_norm.weight.device)
assert (audio_embed.shape[-1] == self.config.d_model)
audio_embed = audio_embed.to(self.layer_norm.weight) # device and type
audio_embed, output_length = self.dconv1(audio_embed, input_length, output_dim=3) # (b, l*2, d_model)
_, tgt_len, _ = audio_embed.size()
if tgt_len < self.positional_embedding.shape[0]:
current_positional_embedding = self.positional_embedding[:tgt_len]
else:
current_positional_embedding = self.positional_embedding
hidden_states = (audio_embed.to(torch.float32) + current_positional_embedding).to(audio_embed.dtype)
# packing hidden states
attention_mask, _ = get_sequence_mask(hidden_states, output_length)
hidden_states = torch.masked_select(hidden_states, attention_mask).view(torch.sum(output_length), self.config.d_model)
for idx, encoder_layer in enumerate(self.layers):
hidden_states = encoder_layer(hidden_states, output_length)
hidden_states = self.layer_norm(hidden_states)
hidden_states_before_dconv2 = hidden_states
output_length_before_dconv2 = output_length
coarse_mel, output_length = self.dconv2(hidden_states, output_length, output_dim=3)
refined_mel, coarse_mel, mel_labels_length = self.post_net(coarse_mel, output_length, mel_labels_length)
return OmniAudioDecoderOutput(
refined_mel=refined_mel,
coarse_mel=coarse_mel,
mel_length=mel_labels_length,
hidden_states_before_dconv2=hidden_states_before_dconv2,
output_length_before_dconv2=output_length_before_dconv2,
)
class OmniAudioVQBridgeTokenizer(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config.audio_config
self.gradient_checkpointing = False
self.intermediate_dim = self.config.d_model * self.config.avg_pooler
self.gate_proj = nn.Conv1d(self.config.d_model, self.intermediate_dim, self.config.avg_pooler, self.config.avg_pooler, bias=False)
self.up_proj = nn.Conv1d(self.config.d_model, self.intermediate_dim, self.config.avg_pooler, self.config.avg_pooler, bias=False)
self.down_proj = nn.Linear(self.intermediate_dim, self.intermediate_dim, bias=False)
self.act_fn = ACT2FN['silu']
self.layer_norm = nn.LayerNorm(self.intermediate_dim)
self.proj_decoder = nn.Linear(self.intermediate_dim, self.config.d_model)
self.vq_list = nn.ModuleList([])
for idx, codebook_size in enumerate(self.config.vq_config.codebook_sizes):
vq_config = copy.deepcopy(self.config.vq_config)
vq_config.dim = self.intermediate_dim
vq_config.codebook_size = codebook_size
self.vq_list.append(VectorQuantize(vq_config))
for vq_layer in self.vq_list:
deepspeed.zero.register_external_parameter(self, vq_layer.codebook.embed)
def rvq_op(self, inputs, output_length):
def rvq_layer_op(vq_layer, residual_encoding, output_length):
q_v_i, code_ids_i = vq_layer(residual_encoding, output_length)
residual_encoding = residual_encoding.float() - q_v_i.float()
residual_encoding = residual_encoding.to(inputs.dtype)
return residual_encoding, code_ids_i
cmt_loss, residual_encoding = 0, inputs
code_ids_list = []
for i, vq_layer in enumerate(self.vq_list):
residual_encoding, code_ids_i = rvq_layer_op(vq_layer, residual_encoding, output_length)
code_ids_list.append(code_ids_i)
return torch.stack(code_ids_list, -1)
def forward(self, x, output_length):
batch_size, _, _ = x.shape
output_length = output_length.to(x.device)
if x.shape[1] % self.config.avg_pooler != 0:
x = F.pad(x, (0, 0, 0, self.config.avg_pooler - x.shape[1] % self.config.avg_pooler), "constant", 0)
xt = x.permute(0, 2, 1)
g = self.gate_proj(xt).permute(0, 2, 1) # (bs, sl//pooler_size+1, d*2)
u = self.up_proj(xt).permute(0, 2, 1)
x = x.reshape(batch_size, -1, self.intermediate_dim) # (bs, sl//pooler_size+1, d*2)
c = self.down_proj(self.act_fn(g) * u)
res = self.layer_norm(c + x)
valid_mask, _ = get_sequence_mask(res, output_length)
code_ids = self.rvq_op(res, output_length)
code_ids = torch.masked_select(code_ids, valid_mask).reshape(-1, len(self.vq_list)) # (sum(valid_sequence_length), vq_num)
return code_ids
@torch.no_grad()
def decode(self, code_ids):
vq_num = code_ids.shape[-1]
res = sum(self.vq_list[i].get_output_from_indices(code_ids[:, i]).float() for i in range(vq_num-1,-1,-1)).to(self.proj_decoder.weight)
decoder_emb = self.proj_decoder(res.to(self.proj_decoder.weight))
return decoder_emb
@torch.no_grad()
def recover(self, code_ids):
vq_num = code_ids.shape[-1]
res = sum(self.vq_list[i].get_output_from_indices(code_ids[:, i]).float() for i in range(vq_num-1,-1,-1)).to(self.proj_decoder.weight)
return res
class FlowmatchingPrenet(nn.Module):
def __init__(
self,
input_feat_dim,
out_feat_dim,
d_model,
attention_heads,
ffn_dim,
nlayers,
activation_function,
max_source_positions,
target_mel_length_scale_ratio,
):
super().__init__()
self.d_model = d_model
self.target_mel_length_scale_ratio = target_mel_length_scale_ratio
self.gradient_checkpointing = False
self.register_buffer(
"positional_embedding", sinusoids(max_source_positions, d_model)
)
self.in_mlp = nn.Sequential(
nn.Linear(input_feat_dim, d_model * 4),
nn.SiLU(),
nn.Linear(d_model * 4, d_model),
)
self.transformer_layers = nn.ModuleList(
[
OmniWhisperTransformerLayer(
act=ACT2FN[activation_function],
d_model=d_model,
encoder_attention_heads=attention_heads,
encoder_ffn_dim=ffn_dim,
causal=True, # causal
ln_type="RMSNorm",
)
for _ in range(nlayers)
]
)
self.final_norm = RMSNorm(self.d_model)
self.out_proj = nn.Linear(d_model, out_feat_dim, bias=False)
def compute_output_length(self, input_length):
output_length = input_length.float() * self.target_mel_length_scale_ratio
return output_length.to(torch.int64)
def forward(self, input_feat, input_length, output_length=None):
"""
Args:
input_feat: [B, T, input_feat_dim]
input_length: [B]
output_length: [B]
"""
if output_length is None or not self.training:
output_length = self.compute_output_length(input_length)
input_feat = input_feat[:, : input_length.max(), :] # [B, T, D]
orig_dtype = input_feat.dtype
input_feat = F.interpolate(
input=input_feat.to(torch.float32).transpose(1, 2).contiguous(),
size=output_length.max(),
mode="nearest",
).to(orig_dtype)
input_feat = input_feat.transpose(1, 2).contiguous() # [B, T, D]
hidden_states = self.in_mlp(input_feat)
# packing hidden states
bsz, tgt_len, d_model = hidden_states.shape
attention_mask, unpacking_index = get_sequence_mask(
hidden_states, output_length
)
hidden_states = torch.masked_select(hidden_states, attention_mask).view(
torch.sum(output_length), self.d_model
)
for idx, encoder_layer in enumerate(self.transformer_layers):
hidden_states = encoder_layer(hidden_states, output_length)
# unpacking
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
bsz, tgt_len, d_model
)
hidden_states = torch.where(attention_mask, hidden_states, 0)
hidden_states = self.final_norm(hidden_states)
output = self.out_proj(hidden_states)
return output, output_length
@dataclass
class OmniAudioFlowMatchingDecoderOutput(ModelOutput):
flow_matching_mel: Optional[torch.FloatTensor] = None
flow_matching_mel_lengths: Optional[torch.FloatTensor] = None
class OmniAudioFlowMatchingDecoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config.flow_matching_config
self.in_channels = self.config.in_channels
self.spk_emb_dim = self.config.spk_emb_dim
self.diffusion_steps = self.config.diffusion_steps
self.cal_mel_mae = self.config.cal_mel_mae
self.forward_step = -1
self.prenet = FlowmatchingPrenet(
input_feat_dim=self.config.prenet_in_dim,
out_feat_dim=self.config.prenet_out_dim,
d_model=self.config.prenet_d_model,
attention_heads=self.config.prenet_attention_heads,
ffn_dim=self.config.prenet_ffn_dim,
nlayers=self.config.prenet_nlayers,
activation_function=self.config.prenet_activation_function,
max_source_positions=self.config.prenet_max_source_positions,
target_mel_length_scale_ratio=self.config.prenet_target_mel_length_scale_ratio,
)
self.conditional_decoder = ConditionalDecoder(
in_channels=self.in_channels * 2 + self.spk_emb_dim,
out_channels=self.in_channels,
causal=True,
channels=self.config.channels,
dropout=self.config.dropout,
attention_head_dim=self.config.attention_head_dim,
n_blocks=self.config.n_blocks,
num_mid_blocks=self.config.num_mid_blocks,
num_heads=self.config.num_heads,
act_fn=self.config.act_fn,
)
self.cfm = ConditionalCFM(
in_channels=self.in_channels,
cfm_params=self.config.cfm_params,
n_spks=0,
spk_emb_dim=self.spk_emb_dim,
)
def unpack_hidden_states(self, hidden_states, output_length):
unpacked = unpack_hidden_states(hidden_states, output_length)
return unpacked, output_length
def forward(
self, refined_mel, input_length, mel_labels=None, mel_labels_length=None
):
"""
:param refined_mel: [bs, max_input_len, mel_bin]
:param input_length: [batch_size]
:param refined_mel: [bs, mel_bin, max_input_len]
:return:
"""
self.forward_step += 1
orig_dtype = refined_mel.dtype
prenet_mae_metric = torch.tensor(0.0).to(refined_mel.device)
prenet_regression_loss = torch.tensor(0.0).to(refined_mel.device)
if self.prenet is not None:
refined_mel = refined_mel[:, : torch.max(input_length), :]
if mel_labels_length is None:
mel_labels_length = self.prenet.compute_output_length(input_length)
refined_mel, input_length = self.prenet(
refined_mel, input_length, mel_labels_length
)
float_dtype = refined_mel.dtype
refined_mel = refined_mel.float()
input_length = input_length.long()
refined_mel = refined_mel[:, : torch.max(input_length), :]
sequence_mask, unpacking_index = get_sequence_mask(refined_mel, input_length)
refined_mel = refined_mel.transpose(1, 2) # (bs, mel_bin, max_input_len)
sequence_mask = sequence_mask.transpose(2, 1) # (bs, 1, sl)
fm_mel = self.cfm.forward(
estimator=self.conditional_decoder,
mu=refined_mel.to(float_dtype),
mask=sequence_mask.float(),
n_timesteps=self.diffusion_steps,
)
return OmniAudioFlowMatchingDecoderOutput(
flow_matching_mel=fm_mel.transpose(1, 2),
flow_matching_mel_lengths=mel_labels_length,
)
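
A minimal usage sketch of the packing convention used throughout this file: variable-length batches are flattened to (sum(lengths), d) for flash_attn_varlen_func and restored to a zero-padded (bsz, max_len, d) tensor with get_sequence_mask / unpack_hidden_states. Importing the module assumes its heavy dependencies (flash_attn, deepspeed, the sibling vector_quantize and flow_matching modules) are installed; the helpers themselves only need torch.

import torch
from audio_modeling_omni import get_sequence_mask, unpack_hidden_states

lengths = torch.tensor([3, 5])                 # two sequences, lengths 3 and 5
padded = torch.randn(2, 5, 4)                  # (bsz, max_len, d); rows past each length are padding
mask, _ = get_sequence_mask(padded, lengths)   # (bsz, max_len, 1) boolean validity mask
packed = torch.masked_select(padded, mask).view(-1, 4)    # (3 + 5, d) packed tokens
restored = unpack_hidden_states(packed, lengths)          # (2, 5, 4), zero-padded again
assert torch.allclose(restored[0, :3], padded[0, :3])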

config.json (new file, +255 lines)

@@ -0,0 +1,255 @@
{
"_name_or_path": "_",
"architectures": [
"OmniForCausalLM"
],
"attention_qkv_bias": true,
"attention_qkv_pack": true,
"audio_config": {
"audio_head_transformer_layers": 3,
"audio_delim_token_id": 151693,
"audio_end_token_id": 151677,
"audio_pad_token_id": 151678,
"audio_start_token_id": 151676,
"audiogen_end_token_id": 151701,
"audiogen_start_token_id": 151700,
"audiotext_end_token_id": 151698,
"audiotext_pad_token_id": 151699,
"audiotext_start_token_id": 151697,
"avg_pooler": 4,
"d_model": 1280,
"decoder_attention_heads": 20,
"decoder_ffn_dim": 5120,
"decoder_kernel_size": 3,
"decoder_layers": 8,
"decoder_stride_size": 2,
"enable": true,
"encoder_attention_heads": 20,
"encoder_ffn_dim": 5120,
"encoder_layers": 32,
"hop_length": 160,
"kernel_size": 3,
"max_audio_seconds": 30,
"n_fft": 400,
"num_mel_bins": 128,
"sampling_rate": 16000,
"stride_size": 2,
"split_overlap": 0.0,
"vq_config":{
"enable": true,
"codebook_sizes": [8192, 4096, 2048, 1024, 1024, 1024, 1024, 1024]
}
},
"auto_map": {
"AutoConfig": "configuration_omni.OmniConfig",
"AutoModelForCausalLM": "modeling_omni.OmniForCausalLM"
},
"omni_tokenizer_type": "auto",
"bos_token_id": 1,
"eos_token_id": 2,
"flow_matching_config": {
"enable": true,
"use_hires_mel": true,
"sampling_rate": 24000,
"hop_length": 480,
"max_audio_seconds": 30,
"split_overlap": 0.1,
"use_hidden_states_before_dconv2": true,
"prenet_in_dim": 1280,
"prenet_out_dim": 80,
"prenet_d_model": 512,
"prenet_attention_heads": 8,
"prenet_ffn_dim": 2048,
"prenet_nlayers": 12,
"prenet_activation_function": "gelu",
"prenet_max_source_positions": 5000,
"prenet_target_mel_length_scale_ratio": 1.0,
"prenet_loss_weight": 1.0,
"unet_use_omni_attn": false,
"loss_weight": 1.0,
"in_channels": 80,
"spk_emb_dim": 0,
"diffusion_steps": 10,
"channels": [256],
"dropout": 0.0,
"attention_head_dim": 64,
"n_blocks": 4,
"num_mid_blocks": 12,
"num_heads": 8,
"act_fn": "gelu",
"cal_mel_mae": true,
"cfm_params": {
"sigma_min": 1e-6,
"solver": "euler",
"t_scheduler": "cosine",
"training_cfg_rate": 0.2,
"inference_cfg_rate": 0.7,
"reg_loss_type": "l1"
}
},
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 3584,
"initializer_range": 0.02,
"intermediate_size": 18944,
"max_position_embeddings": 65536,
"max_window_layers": 28,
"model_type": "omni",
"multimodal": [
"audio",
"image",
"video",
"audiogen"
],
"multimodal_special_token_list": [
151676,
151677,
151678,
151679,
151680,
151681,
151682,
151683,
151684,
151685,
151686,
151687,
151688,
151693,
151694,
151695,
151696,
151697,
151698,
151699,
151700,
151701
],
"num_attention_heads": 28,
"num_hidden_layers": 28,
"num_key_value_heads": 4,
"pad_token_id": 0,
"position_embedding_type": "rope",
"rms_norm_eps": 1e-06,
"rope_theta": 1000000.0,
"sliding_window": 131072,
"sparse_attention_heads": null,
"sparse_attention_layers": [],
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"train_multimodal_special_tokens_only": false,
"transformers_version": "4.45.0.dev0",
"use_cache": false,
"use_norm_head": false,
"use_sliding_window": false,
"video_config": {
"_name_or_path": "",
"_attn_implementation": "flash_attention_2",
"decode_way": "1fps",
"depth": 32,
"embed_dim": 1280,
"enable": true,
"hidden_act": "quick_gelu",
"hidden_size": 3584,
"image_delimiter_token_id": 151688,
"image_end_token_id": 151680,
"image_line_token_id": 151682,
"image_mean": [
0.48145466,
0.4578275,
0.40821073
],
"image_pad_token_id": 151681,
"image_size": 224,
"image_start_token_id": 151679,
"image_std": [
0.26862954,
0.26130258,
0.27577711
],
"in_channels": 3,
"in_chans": 3,
"intermediate_size": 3072,
"layer_norm_eps": 1e-05,
"max_frame_num": 32,
"max_length": 20,
"max_pixels": 602112,
"merge_size": 2,
"min_length": 0,
"min_pixels": 3136,
"mlp_ratio": 4,
"model_type": "clip_vision_model",
"num_attention_heads": 12,
"num_channels": 3,
"num_heads": 16,
"num_hidden_layers": 12,
"patch_size": 14,
"spatial_merge_size": 2,
"spatial_patch_size": 14,
"temporal_patch_size": 2,
"video_end_token_id": 151696,
"video_place_token_id": 151694,
"video_start_token_id": 151695
},
"visual_config": {
"_name_or_path": "",
"_attn_implementation": "flash_attention_2",
"depth": 32,
"diversity_penalty": 0.0,
"do_sample": false,
"early_stopping": false,
"embed_dim": 1280,
"enable": true,
"hidden_act": "quick_gelu",
"hidden_size": 3584,
"image_delimiter_token_id": 151688,
"image_end_token_id": 151680,
"image_line_token_id": 151682,
"image_mean": [
0.48145466,
0.4578275,
0.40821073
],
"image_pad_token_id": 151681,
"image_size": 224,
"image_start_token_id": 151679,
"image_std": [
0.26862954,
0.26130258,
0.27577711
],
"in_channels": 3,
"in_chans": 3,
"intermediate_size": 3072,
"layer_norm_eps": 1e-05,
"length_penalty": 1.0,
"max_length": 20,
"max_pixels": 3211264,
"merge_size": 2,
"min_length": 0,
"min_pixels": 3136,
"mlp_ratio": 4,
"model_type": "clip_vision_model",
"num_attention_heads": 12,
"num_channels": 3,
"num_heads": 16,
"num_hidden_layers": 12,
"patch_size": 14,
"projection_dim": 512,
"spatial_merge_size": 2,
"spatial_patch_size": 14,
"temporal_patch_size": 2
},
"vocab_size": 152064,
"vocoder_config":{
"enable": true,
"enable_multi_scale": true,
"max_audio_seconds": 30,
"sampling_rate": 16000,
"hop_length": 256,
"split_overlap": 0.0,
"n_fft": 1024,
"num_mel_bins": 80,
"channels": [256, 256, 256, 256, 256]
}
}
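
A minimal sketch of reading this config through the custom class the auto_map entry points to (configuration_omni.OmniConfig); the local path is an assumption, and the repository files must be present for trust_remote_code to resolve.

from transformers import AutoConfig

config = AutoConfig.from_pretrained("./", trust_remote_code=True)
print(config.model_type)                              # "omni"
print(config.audio_config.d_model)                    # 1280, nested audio encoder width
print(config.flow_matching_config.cfm_params.solver)  # "euler"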

configuration_omni.py (new file, +120 lines)

@@ -0,0 +1,120 @@
# Copyright 2023 Baichuan Inc. All Rights Reserved.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers import WhisperConfig
from transformers import CLIPVisionConfig
logger = logging.get_logger(__name__)
class OmniConfig(PretrainedConfig):
model_type = "omni"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=125696,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
sparse_attention_heads=None,
sparse_attention_layers=[],
head_dim=None,
attention_qkv_pack=True,
attention_qkv_bias=False,
use_norm_head=True,
hidden_act="silu",
max_position_embeddings=4096,
position_embedding_type="rope",
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
audio_config=None,
visual_config=None,
video_config=None,
vocoder_config=None,
flow_matching_config=None,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads or self.num_attention_heads
self.sparse_attention_heads = sparse_attention_heads
self.sparse_attention_layers = sparse_attention_layers
self.head_dim = head_dim or self.hidden_size // self.num_attention_heads
self.attention_qkv_pack = attention_qkv_pack
self.attention_qkv_bias = attention_qkv_bias
self.use_norm_head = use_norm_head
self.hidden_act = hidden_act
self.position_embedding_type = position_embedding_type
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
assert self.position_embedding_type.lower() in ("rope", "alibi")
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
if audio_config is not None:
self.audio_config = WhisperConfig(**audio_config)
if self.audio_config.vq_config is not None:
self.audio_config.vq_config = PretrainedConfig(**self.audio_config.vq_config)
if vocoder_config is not None:
self.vocoder_config = WhisperConfig(**vocoder_config)
if flow_matching_config is not None:
self.flow_matching_config = PretrainedConfig(**flow_matching_config)
self.flow_matching_config.cfm_params = PretrainedConfig(**self.flow_matching_config.cfm_params)
if visual_config is not None:
self.visual_config = CLIPVisionConfig(**visual_config)
if video_config is not None:
self.video_config = CLIPVisionConfig(**video_config)
def to_diff_dict(self):
data = super().to_diff_dict()
data["model_type"] = self.model_type
return data
def get_rotary_base(self):
if hasattr(self, "rotary_emb_base"):
return self.rotary_emb_base
else:
return self.rope_theta
if __name__ == '__main__':
from transformers import AutoConfig
config = AutoConfig.from_pretrained("./", trust_remote_code=True)
print(config)

flow_matching.py (new file, +791 lines)

@@ -0,0 +1,791 @@
# Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice/tree/main
"""
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from abc import ABC
import torch
import torch.nn.functional as F
from typing import Dict, Optional
import torch.nn as nn
from einops import pack, rearrange, repeat
from .matcha_components import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
from .matcha_transformer import BasicTransformerBlock
from omegaconf import DictConfig
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
assert mask.dtype == torch.bool
assert dtype in [torch.float32, torch.bfloat16, torch.float16]
mask = mask.to(dtype)
# attention mask bias
# NOTE(Mddct): torch.finfo jit issues
# chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
mask = (1.0 - mask) * torch.finfo(dtype).min
return mask
def subsequent_chunk_mask(
size: int,
chunk_size: int,
num_left_chunks: int = -1,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size) with chunk size,
this is for streaming encoder
Args:
size (int): size of mask
chunk_size (int): size of chunk
num_left_chunks (int): number of left chunks
<0: use full chunk
>=0: use num_left_chunks
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_chunk_mask(4, 2)
[[1, 1, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 1],
[1, 1, 1, 1]]
"""
# NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
# actually this is not needed after we have inference cache implemented, will remove it later
pos_idx = torch.arange(size, device=device)
block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
return ret
def subsequent_mask(
size: int,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size).
This mask is used only in decoder which works in an auto-regressive mode.
This means the current step could only do attention with its left steps.
In the encoder, full attention is used when streaming is not necessary and
the sequence is not long. In this case, no attention mask is needed.
When streaming is needed, chunk-based attention is used in the encoder. See
subsequent_chunk_mask for the chunk-based attention mask.
Args:
size (int): size of mask
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_mask(3)
[[1, 0, 0],
[1, 1, 0],
[1, 1, 1]]
"""
arange = torch.arange(size, device=device)
mask = arange.expand(size, size)
arange = arange.unsqueeze(-1)
mask = mask <= arange
return mask
def add_optional_chunk_mask(xs: torch.Tensor,
masks: torch.Tensor,
use_dynamic_chunk: bool,
use_dynamic_left_chunk: bool,
decoding_chunk_size: int,
static_chunk_size: int,
num_decoding_left_chunks: int,
enable_full_context: bool = True):
""" Apply optional mask for encoder.
Args:
xs (torch.Tensor): padded input, (B, L, D), L for max length
mask (torch.Tensor): mask for xs, (B, 1, L)
use_dynamic_chunk (bool): whether to use dynamic chunk or not
use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
training.
decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
0: default for training, use random dynamic chunk.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
static_chunk_size (int): chunk size for static chunk training/decoding
if it's greater than 0, if use_dynamic_chunk is true,
this parameter will be ignored
num_decoding_left_chunks: number of left chunks, this is for decoding,
the chunk size is decoding_chunk_size.
>=0: use num_decoding_left_chunks
<0: use all left chunks
enable_full_context (bool):
True: chunk size is either [1, 25] or full context(max_len)
False: chunk size ~ U[1, 25]
Returns:
torch.Tensor: chunk mask of the input xs.
"""
# Whether to use chunk mask or not
if use_dynamic_chunk:
max_len = xs.size(1)
if decoding_chunk_size < 0:
chunk_size = max_len
num_left_chunks = -1
elif decoding_chunk_size > 0:
chunk_size = decoding_chunk_size
num_left_chunks = num_decoding_left_chunks
else:
# chunk size is either [1, 25] or full context(max_len).
# Since we use 4 times subsampling and allow up to 1s(100 frames)
# delay, the maximum frame is 100 / 4 = 25.
chunk_size = torch.randint(1, max_len, (1, )).item()
num_left_chunks = -1
if chunk_size > max_len // 2 and enable_full_context:
chunk_size = max_len
else:
chunk_size = chunk_size % 25 + 1
if use_dynamic_left_chunk:
max_left_chunks = (max_len - 1) // chunk_size
num_left_chunks = torch.randint(0, max_left_chunks,
(1, )).item()
chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
num_left_chunks,
xs.device) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
elif static_chunk_size > 0:
num_left_chunks = num_decoding_left_chunks
chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
num_left_chunks,
xs.device) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
else:
chunk_masks = masks
return chunk_masks
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
"""Make mask tensor containing indices of padded part.
See description of make_non_pad_mask.
Args:
lengths (torch.Tensor): Batch of lengths (B,).
Returns:
torch.Tensor: Mask tensor containing indices of padded part.
Examples:
>>> lengths = [5, 3, 2]
>>> make_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
"""
batch_size = lengths.size(0)
max_len = max_len if max_len > 0 else lengths.max().item()
seq_range = torch.arange(0,
max_len,
dtype=torch.int64,
device=lengths.device)
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
seq_length_expand = lengths.unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
return mask
# Causal
class Transpose(torch.nn.Module):
def __init__(self, dim0: int, dim1: int):
super().__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x: torch.Tensor):
x = torch.transpose(x, self.dim0, self.dim1)
return x
class CausalBlock1D(Block1D):
def __init__(self, dim: int, dim_out: int):
super(CausalBlock1D, self).__init__(dim, dim_out)
self.block = torch.nn.Sequential(
CausalConv1d(dim, dim_out, 3),
Transpose(1, 2),
nn.LayerNorm(dim_out),
Transpose(1, 2),
nn.Mish(),
)
def forward(self, x: torch.Tensor, mask: torch.Tensor):
output = self.block(x * mask)
return output * mask
class CausalResnetBlock1D(ResnetBlock1D):
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
self.block1 = CausalBlock1D(dim, dim_out)
self.block2 = CausalBlock1D(dim_out, dim_out)
class CausalConv1d(torch.nn.Conv1d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
device=None,
dtype=None
) -> None:
super(CausalConv1d, self).__init__(in_channels, out_channels,
kernel_size, stride,
padding=0, dilation=dilation,
groups=groups, bias=bias,
padding_mode=padding_mode,
device=device, dtype=dtype)
assert stride == 1
self.causal_padding = (kernel_size - 1, 0)
def forward(self, x: torch.Tensor):
x = F.pad(x, self.causal_padding)
x = super(CausalConv1d, self).forward(x)
return x
class BASECFM(torch.nn.Module, ABC):
def __init__(
self,
n_feats,
cfm_params,
n_spks=1,
spk_emb_dim=128,
):
super().__init__()
self.n_feats = n_feats
self.n_spks = n_spks
self.spk_emb_dim = spk_emb_dim
self.solver = cfm_params.solver
if hasattr(cfm_params, "sigma_min"):
self.sigma_min = cfm_params.sigma_min
else:
self.sigma_min = 1e-4
self.estimator = None
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
"""Forward diffusion
Args:
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
n_timesteps (int): number of diffusion steps
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
z = torch.randn_like(mu) * temperature
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
def solve_euler(self, x, t_span, mu, mask, spks, cond):
"""
Fixed euler solver for ODEs.
Args:
x (torch.Tensor): random noise
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
"""
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
# Or in future might add like a return_all_steps flag
sol = []
for step in range(1, len(t_span)):
dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
return sol[-1]
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
"""Computes diffusion loss
Args:
x1 (torch.Tensor): Target
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): target mask
shape: (batch_size, 1, mel_timesteps)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
shape: (batch_size, spk_emb_dim)
Returns:
loss: conditional flow matching loss
y: conditional flow
shape: (batch_size, n_feats, mel_timesteps)
"""
b, _, t = mu.shape
# random timestep
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
# sample noise p(x_0)
z = torch.randn_like(x1)
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
u = x1 - (1 - self.sigma_min) * z
loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
torch.sum(mask) * u.shape[1]
)
return loss, y
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
"""Make mask tensor containing indices of padded part.
See description of make_non_pad_mask.
Args:
lengths (torch.Tensor): Batch of lengths (B,).
Returns:
torch.Tensor: Mask tensor containing indices of padded part.
Examples:
>>> lengths = [5, 3, 2]
>>> make_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
"""
batch_size = lengths.size(0)
max_len = max_len if max_len > 0 else lengths.max().item()
seq_range = torch.arange(0,
max_len,
dtype=torch.int64,
device=lengths.device)
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
seq_length_expand = lengths.unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
return mask
class ConditionalDecoder(nn.Module):
def __init__(
self,
in_channels,
out_channels,
causal=False,
channels=(256, 256),
dropout=0.05,
attention_head_dim=64,
n_blocks=1,
num_mid_blocks=2,
num_heads=4,
act_fn="snake",
gradient_checkpointing=True,
):
"""
This decoder requires an input with the same shape as the target. So, if your text content
is shorter or longer than the outputs, please resample it before feeding it to the decoder.
"""
super().__init__()
channels = tuple(channels)
self.in_channels = in_channels
self.out_channels = out_channels
self.causal = causal
self.static_chunk_size = 2 * 25 * 2 # 2*input_frame_rate*token_mel_ratio
self.gradient_checkpointing = gradient_checkpointing
self.time_embeddings = SinusoidalPosEmb(in_channels)
time_embed_dim = channels[0] * 4
self.time_mlp = TimestepEmbedding(
in_channels=in_channels,
time_embed_dim=time_embed_dim,
act_fn="silu",
)
self.down_blocks = nn.ModuleList([])
self.mid_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
output_channel = in_channels
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
input_channel = output_channel
output_channel = channels[i]
is_last = i == len(channels) - 1
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
downsample = (
Downsample1D(output_channel) if not is_last else
CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
for _ in range(num_mid_blocks):
input_channel = channels[-1]
out_channels = channels[-1]
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
channels = channels[::-1] + (channels[0],)
for i in range(len(channels) - 1):
input_channel = channels[i] * 2
output_channel = channels[i + 1]
is_last = i == len(channels) - 2
resnet = CausalResnetBlock1D(
dim=input_channel,
dim_out=output_channel,
time_emb_dim=time_embed_dim,
) if self.causal else ResnetBlock1D(
dim=input_channel,
dim_out=output_channel,
time_emb_dim=time_embed_dim,
)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
upsample = (
Upsample1D(output_channel, use_conv_transpose=True)
if not is_last
else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
self.initialize_weights()
def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.GroupNorm):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x, mask, mu, t, spks=None, cond=None):
"""Forward pass of the UNet1DConditional model.
Args:
x (torch.Tensor): shape (batch_size, in_channels, time)
mask (_type_): shape (batch_size, 1, time)
t (_type_): shape (batch_size)
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
cond (_type_, optional): placeholder for future use. Defaults to None.
Returns:
torch.Tensor: output of shape (batch_size, out_channels, time)
"""
t = self.time_embeddings(t)
t = t.to(x.dtype)
t = self.time_mlp(t)
x = pack([x, mu], "b * t")[0]
mask = mask.to(x.dtype)
if spks is not None:
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
x = pack([x, spks], "b * t")[0]
if cond is not None:
x = pack([x, cond], "b * t")[0]
hiddens = []
masks = [mask]
for resnet, transformer_blocks, downsample in self.down_blocks:
mask_down = masks[-1]
x = resnet(x, mask_down, t)
x = rearrange(x, "b c t -> b t c").contiguous()
# attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
for transformer_block in transformer_blocks:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(transformer_block),
x,
attn_mask,
t,
)
else:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
hiddens.append(x) # Save hidden states for skip connections
x = downsample(x * mask_down)
masks.append(mask_down[:, :, ::2])
masks = masks[:-1]
mask_mid = masks[-1]
for resnet, transformer_blocks in self.mid_blocks:
x = resnet(x, mask_mid, t)
x = rearrange(x, "b c t -> b t c").contiguous()
# attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
for transformer_block in transformer_blocks:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(transformer_block),
x,
attn_mask,
t,
)
else:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
for resnet, transformer_blocks, upsample in self.up_blocks:
mask_up = masks.pop()
skip = hiddens.pop()
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
x = resnet(x, mask_up, t)
x = rearrange(x, "b c t -> b t c").contiguous()
# attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
for transformer_block in transformer_blocks:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(transformer_block),
x,
attn_mask,
t,
)
else:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
x = upsample(x * mask_up)
x = self.final_block(x, mask_up)
output = self.final_proj(x * mask_up)
return output * mask
class ConditionalCFM(BASECFM):
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64):
super().__init__(
n_feats=in_channels,
cfm_params=cfm_params,
n_spks=n_spks,
spk_emb_dim=spk_emb_dim,
)
self.t_scheduler = cfm_params.t_scheduler
self.training_cfg_rate = cfm_params.training_cfg_rate
self.inference_cfg_rate = cfm_params.inference_cfg_rate
@torch.inference_mode()
def forward(self, estimator, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
"""Forward diffusion
Args:
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
n_timesteps (int): number of diffusion steps
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
z = torch.randn_like(mu) * temperature
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
return self.solve_euler(estimator, z, t_span=t_span.to(mu.dtype), mu=mu, mask=mask, spks=spks, cond=cond)
def solve_euler(self, estimator, x, t_span, mu, mask, spks, cond):
"""
Fixed euler solver for ODEs.
Args:
x (torch.Tensor): random noise
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
"""
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
# Or in future might add like a return_all_steps flag
sol = []
for step in range(1, len(t_span)):
dphi_dt = estimator(x, mask, mu, t, spks, cond)
# Classifier-Free Guidance inference introduced in VoiceBox
if self.inference_cfg_rate > 0:
cfg_dphi_dt = estimator(
x, mask,
torch.zeros_like(mu), t,
torch.zeros_like(spks) if spks is not None else None,
cond=cond
)
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
self.inference_cfg_rate * cfg_dphi_dt)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
return sol[-1]
def compute_loss(self, estimator, x1, mask, mu, spks=None, cond=None):
"""Computes diffusion loss
Args:
x1 (torch.Tensor): Target
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): target mask
shape: (batch_size, 1, mel_timesteps)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
shape: (batch_size, spk_emb_dim)
Returns:
loss: conditional flow matching loss
y: conditional flow
shape: (batch_size, n_feats, mel_timesteps)
"""
org_dtype = x1.dtype
b, _, t = mu.shape
# random timestep
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t = 1 - torch.cos(t * 0.5 * torch.pi)
# sample noise p(x_0)
z = torch.randn_like(x1)
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
u = x1 - (1 - self.sigma_min) * z
# during training, we randomly drop condition to trade off mode coverage and sample fidelity
if self.training_cfg_rate > 0:
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
mu = mu * cfg_mask.view(-1, 1, 1)
if spks is not None:
spks = spks * cfg_mask.view(-1, 1)
if cond is not None:
cond = cond * cfg_mask.view(-1, 1, 1)
pred = estimator(y, mask, mu, t.squeeze(), spks, cond)
pred = pred.float()
u = u.float()
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
loss = loss.to(org_dtype)
return loss, y
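# A minimal, hedged usage sketch (not from the repository): the Euler loop in
# `solve_euler` integrates dx/dt = estimator(x, ...) from t=0 (noise) to t=1
# (data). The toy estimator and all sizes below are illustrative assumptions.
import torch

def toy_estimator(x, mask, mu, t, spks, cond):
    return mu - x  # a vector field that points from the current state towards mu

mu = torch.zeros(2, 80, 100)            # (batch, n_feats, mel_timesteps)
mask = torch.ones(2, 1, 100)
x = torch.randn_like(mu)                # start from pure noise
t_span = torch.linspace(0, 1, 11)       # 10 Euler steps
t = t_span[0]
for step in range(1, len(t_span)):
    dt = t_span[step] - t
    x = x + dt * toy_estimator(x, mask, mu, t, None, None)
    t = t_span[step]
print(x.abs().mean())                   # x has been pulled towards mu (zeros here)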

6
generation_config.json Normal file
View File

@ -0,0 +1,6 @@
{
"bos_token_id": 151643,
"eos_token_id": 151643,
"max_new_tokens": 2048,
"transformers_version": "4.45.0.dev0"
}
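The file above is the standard `generation_config.json` that `transformers` reads at load time; a minimal loading sketch (the checkpoint path is a placeholder assumption):
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("/path/to/this/checkpoint")  # placeholder path
print(gen_cfg.eos_token_id, gen_cfg.max_new_tokens)  # 151643 2048, per the JSON above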

83
generation_utils.py Normal file
View File

@ -0,0 +1,83 @@
from typing import List
from queue import Queue
import torch
def build_chat_input(model, tokenizer, messages: List[dict], max_new_tokens: int=0):
def _parse_messages(messages, split_role="user"):
system, rounds = "", []
round = []
for i, message in enumerate(messages):
if message["role"] == "system":
assert i == 0
system = message["content"]
continue
if message["role"] == split_role and round:
rounds.append(round)
round = []
round.append(message)
if round:
rounds.append(round)
return system, rounds
max_new_tokens = max_new_tokens or model.generation_config.max_new_tokens
max_input_tokens = model.config.model_max_length - max_new_tokens
system, rounds = _parse_messages(messages, split_role="user")
system_tokens = tokenizer.encode(system)
max_history_tokens = max_input_tokens - len(system_tokens)
history_tokens = []
for round in rounds[::-1]:
round_tokens = []
for message in round:
if message["role"] == "user":
round_tokens.append(model.generation_config.user_token_id)
else:
round_tokens.append(model.generation_config.assistant_token_id)
round_tokens.extend(tokenizer.encode(message["content"]))
if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
history_tokens = round_tokens + history_tokens # concat left
if len(history_tokens) < max_history_tokens:
continue
break
input_tokens = system_tokens + history_tokens
if messages[-1]["role"] != "assistant":
input_tokens.append(model.generation_config.assistant_token_id)
input_tokens = input_tokens[-max_input_tokens:] # truncate left
return torch.LongTensor([input_tokens]).to(model.device)
class TextIterStreamer:
def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
self.tokenizer = tokenizer
self.skip_prompt = skip_prompt
self.skip_special_tokens = skip_special_tokens
self.tokens = []
self.text_queue = Queue()
self.next_tokens_are_prompt = True
def put(self, value):
if self.skip_prompt and self.next_tokens_are_prompt:
self.next_tokens_are_prompt = False
else:
if len(value.shape) > 1:
value = value[0]
self.tokens.extend(value.tolist())
self.text_queue.put(
self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))
def end(self):
self.text_queue.put(None)
def __iter__(self):
return self
def __next__(self):
value = self.text_queue.get()
if value is None:
raise StopIteration()
else:
return value
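# A minimal, hedged usage sketch (not from the repository): exercising
# TextIterStreamer with a stand-in tokenizer so the queue/iteration behaviour is
# visible. FakeTokenizer is a hypothetical placeholder; in real use the class
# receives a transformers tokenizer and is passed to model.generate(streamer=...).
import torch

class FakeTokenizer:
    def decode(self, tokens, skip_special_tokens=False):
        return " ".join(str(t) for t in tokens)

streamer = TextIterStreamer(FakeTokenizer(), skip_prompt=False)
streamer.put(torch.tensor([[1, 2, 3]]))  # each put() enqueues a re-decoded prefix
streamer.put(torch.tensor([[4]]))
streamer.end()                           # the None sentinel terminates iteration
print(list(streamer))                    # ['1 2 3', '1 2 3 4']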

189
matcha_components.py Normal file
View File

@ -0,0 +1,189 @@
# Modified from Matcha-TTS https://github.com/shivammehta25/Matcha-TTS
"""
MIT License
Copyright (c) 2023 Shivam Mehta
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.activations import get_activation
class SinusoidalPosEmb(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
def forward(self, x, scale=1000):
if x.ndim < 1:
x = x.unsqueeze(0)
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class Block1D(torch.nn.Module):
def __init__(self, dim, dim_out, groups=8):
super().__init__()
self.block = torch.nn.Sequential(
torch.nn.Conv1d(dim, dim_out, 3, padding=1),
torch.nn.GroupNorm(groups, dim_out),
nn.Mish(),
)
def forward(self, x, mask):
output = self.block(x * mask)
return output * mask
class ResnetBlock1D(torch.nn.Module):
def __init__(self, dim, dim_out, time_emb_dim, groups=8):
super().__init__()
self.mlp = torch.nn.Sequential(
nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)
)
self.block1 = Block1D(dim, dim_out, groups=groups)
self.block2 = Block1D(dim_out, dim_out, groups=groups)
self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
def forward(self, x, mask, time_emb):
h = self.block1(x, mask)
h += self.mlp(time_emb).unsqueeze(-1)
h = self.block2(h, mask)
output = h + self.res_conv(x * mask)
return output
class Downsample1D(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
def forward(self, x):
return self.conv(x)
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
else:
self.cond_proj = None
self.act = get_activation(act_fn)
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
if post_act_fn is None:
self.post_act = None
else:
self.post_act = get_activation(post_act_fn)
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
class Upsample1D(nn.Module):
"""A 1D upsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
use_conv_transpose (`bool`, default `False`):
option to use a convolution transpose.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
"""
def __init__(
self,
channels,
use_conv=False,
use_conv_transpose=True,
out_channels=None,
name="conv",
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
self.conv = None
if use_conv_transpose:
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
def forward(self, inputs):
assert inputs.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(inputs)
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
if self.use_conv:
outputs = self.conv(outputs)
return outputs
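# A minimal, hedged shape check (not from the repository) for the 1D U-Net
# building blocks above; all sizes are illustrative assumptions.
import torch

x = torch.randn(2, 64, 100)                   # (batch, channels, time)
mask = torch.ones(2, 1, 100)
t_emb = torch.randn(2, 256)

block = ResnetBlock1D(dim=64, dim_out=128, time_emb_dim=256)
down = Downsample1D(128)
up = Upsample1D(128, use_conv_transpose=True)

h = block(x, mask, t_emb)                     # (2, 128, 100)
h_down = down(h)                              # (2, 128, 50): stride-2 conv halves time
h_up = up(h_down)                             # (2, 128, 100): transposed conv doubles it
print(h.shape, h_down.shape, h_up.shape)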

107
matcha_feat.py Normal file
View File

@ -0,0 +1,107 @@
# Modified from Matcha-TTS https://github.com/shivammehta25/Matcha-TTS
"""
MIT License
Copyright (c) 2023 Shivam Mehta
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read
MAX_WAV_VALUE = 32768.0
def load_wav(full_path):
sampling_rate, data = read(full_path)
return data, sampling_rate
def dynamic_range_compression(x, C=1, clip_val=1e-5):
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
def dynamic_range_decompression(x, C=1):
return np.exp(x) / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
output = dynamic_range_compression_torch(magnitudes)
return output
def spectral_de_normalize_torch(magnitudes):
output = dynamic_range_decompression_torch(magnitudes)
return output
mel_basis = {}
hann_window = {}
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
global mel_basis, hann_window # pylint: disable=global-statement
if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
)
y = y.squeeze(1)
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[str(y.device)],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
spec = spectral_normalize_torch(spec)
return spec
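# A minimal, hedged usage sketch (not from the repository): the parameter values
# below are common 80-bin mel settings and are illustrative, not necessarily the
# configuration this repo uses.
import torch

t = torch.linspace(0.0, 1.0, 22050)
wav = 0.1 * torch.sin(2 * torch.pi * 440.0 * t).unsqueeze(0)  # 1 s of mono audio in [-1, 1]
mel = mel_spectrogram(
    wav, n_fft=1024, num_mels=80, sampling_rate=22050,
    hop_size=256, win_size=1024, fmin=0, fmax=8000, center=False,
)
print(mel.shape)  # (1, 80, num_frames)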

480
matcha_transformer.py Normal file
View File

@ -0,0 +1,480 @@
# Modified from Matcha-TTS https://github.com/shivammehta25/Matcha-TTS
"""
MIT License
Copyright (c) 2023 Shivam Mehta
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
from diffusers.models.attention import (
GEGLU,
GELU,
AdaLayerNorm,
AdaLayerNormZero,
ApproximateGELU,
)
from diffusers.models.attention_processor import Attention
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.utils.torch_utils import maybe_allow_in_graph
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func
def get_sequence_mask(inputs, inputs_length):
if inputs.dim() == 3:
bsz, tgt_len, _ = inputs.size()
else:
bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(
bsz, tgt_len, 1
)
unpacking_index = (
torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1
) # convert the mask to flat indices
return sequence_mask, unpacking_index
class OmniWhisperAttention(nn.Module):
def __init__(self, embed_dim, num_heads, causal=False):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.causal = causal
def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
bsz, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(
bsz, self.num_heads, self.head_dim
)
key_states = self.k_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
value_states = self.v_proj(hidden_states).view(
bsz, self.num_heads, self.head_dim
)
cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(
torch.int32
)
max_seqlen = torch.max(seq_len).to(torch.int32).detach()
attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_len,
cu_len,
max_seqlen,
max_seqlen,
causal=self.causal,
) # (bsz * qlen, nheads, headdim)
attn_output = attn_output.reshape(bsz, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output
class SnakeBeta(nn.Module):
"""
A modified Snake function which uses separate parameters for the magnitude of the periodic components
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
References:
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = snakebeta(256)
>>> x = torch.randn(256)
>>> x = a1(x)
"""
def __init__(
self,
in_features,
out_features,
alpha=1.0,
alpha_trainable=True,
alpha_logscale=True,
):
"""
Initialization.
INPUT:
- in_features: shape of the input
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
alpha is initialized to 1 by default, higher values = higher-frequency.
beta is initialized to 1 by default, higher values = higher-magnitude.
alpha will be trained along with the rest of your model.
"""
super().__init__()
self.in_features = (
out_features if isinstance(out_features, list) else [out_features]
)
self.proj = LoRACompatibleLinear(in_features, out_features)
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # log scale alphas initialized to zeros
self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
else: # linear scale alphas initialized to ones
self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
self.alpha.requires_grad = alpha_trainable
self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
"""
Forward pass of the function.
Applies the function to the input elementwise.
SnakeBeta = x + 1/b * sin^2 (xa)
"""
x = self.proj(x)
if self.alpha_logscale:
alpha = torch.exp(self.alpha)
beta = torch.exp(self.beta)
else:
alpha = self.alpha
beta = self.beta
x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
torch.sin(x * alpha), 2
)
return x
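# A minimal, hedged shape check (not from the repository): SnakeBeta first
# projects with a linear layer, then applies x + sin^2(alpha * x) / beta
# elementwise; sizes are illustrative assumptions.
import torch

act = SnakeBeta(in_features=64, out_features=128)
y = act(torch.randn(2, 50, 64))   # (batch, time, channels)
print(y.shape)                    # torch.Size([2, 50, 128])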
class FeedForward(nn.Module):
r"""
A feed-forward layer.
Parameters:
dim (`int`): The number of channels in the input.
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
"""
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim)
if activation_fn == "gelu-approximate":
act_fn = GELU(dim, inner_dim, approximate="tanh")
elif activation_fn == "geglu":
act_fn = GEGLU(dim, inner_dim)
elif activation_fn == "geglu-approximate":
act_fn = ApproximateGELU(dim, inner_dim)
elif activation_fn == "snakebeta":
act_fn = SnakeBeta(dim, inner_dim)
self.net = nn.ModuleList([])
# project in
self.net.append(act_fn)
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
def forward(self, hidden_states):
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm",
final_dropout: bool = False,
use_omni_attn: bool = False,
):
super().__init__()
self.use_omni_attn = use_omni_attn
self.dim = dim
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm_zero = (
num_embeds_ada_norm is not None
) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (
num_embeds_ada_norm is not None
) and norm_type == "ada_norm"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_zero:
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
if self.use_omni_attn:
if only_cross_attention:
raise NotImplementedError
print(
"Use OmniWhisperAttention with flash attention. Dropout is ignored."
)
self.attn1 = OmniWhisperAttention(
embed_dim=dim, num_heads=num_attention_heads, causal=False
)
else:
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=(
cross_attention_dim if only_cross_attention else None
),
upcast_attention=upcast_attention,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=(
cross_attention_dim if not double_self_attention else None
),
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
# scale_qk=False, # uncomment this to not to use flash attention
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
):
bsz, tgt_len, d_model = hidden_states.shape
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Self-Attention
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
else:
norm_hidden_states = self.norm1(hidden_states)
cross_attention_kwargs = (
cross_attention_kwargs if cross_attention_kwargs is not None else {}
)
if self.use_omni_attn:
seq_len = attention_mask[:, 0, :].float().long().sum(dim=1)
var_len_attention_mask, unpacking_index = get_sequence_mask(
norm_hidden_states, seq_len
)
norm_hidden_states = torch.masked_select(
norm_hidden_states, var_len_attention_mask
)
norm_hidden_states = norm_hidden_states.view(torch.sum(seq_len), self.dim)
attn_output = self.attn1(norm_hidden_states, seq_len)
# unpacking
attn_output = torch.index_select(attn_output, 0, unpacking_index).view(
bsz, tgt_len, d_model
)
attn_output = torch.where(var_len_attention_mask, attn_output, 0)
else:
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=(
encoder_hidden_states if self.only_cross_attention else None
),
attention_mask=(
encoder_attention_mask
if self.only_cross_attention
else attention_mask
),
**cross_attention_kwargs,
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = attn_output + hidden_states
# 2. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = (
self.norm2(hidden_states, timestep)
if self.use_ada_layer_norm
else self.norm2(hidden_states)
)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 3. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = (
norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
)
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
ff_output = torch.cat(
[
self.ff(hid_slice)
for hid_slice in norm_hidden_states.chunk(
num_chunks, dim=self._chunk_dim
)
],
dim=self._chunk_dim,
)
else:
ff_output = self.ff(norm_hidden_states)
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = ff_output + hidden_states
return hidden_states
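# A minimal, hedged usage sketch (not from the repository): the plain layer-norm,
# self-attention-only configuration of the block above; argument values are
# illustrative assumptions.
import torch

block = BasicTransformerBlock(dim=256, num_attention_heads=4, attention_head_dim=64)
x = torch.randn(2, 50, 256)        # (batch, time, channels)
out = block(hidden_states=x)       # attention_mask and timestep are optional here
print(out.shape)                   # torch.Size([2, 50, 256])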

2420
model.safetensors.index.json Normal file

File diff suppressed because it is too large Load Diff

1011
modeling_omni.py Normal file

File diff suppressed because it is too large Load Diff

865
processor_omni.py Normal file
View File

@ -0,0 +1,865 @@
import requests
import re, ujson, os, sys, fire, glob, random, time, json
import numpy as np
import io
import torch
from torch.utils.data import default_collate
import torchaudio
from typing import *
from dataclasses import dataclass, field
import transformers
from transformers.modeling_outputs import ModelOutput
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
from functools import lru_cache
from io import BytesIO
from PIL import Image
import concurrent.futures as cf
from transformers.image_transforms import resize, center_crop, get_resize_output_image_size
from transformers.image_utils import PILImageResampling
from PIL import Image, ImageOps
from PIL import ImageFile
torch.set_num_threads(1) # limit torch's thread count, otherwise processing may hang
ImageFile.LOAD_TRUNCATED_IMAGES = True
import base64
from decord import VideoReader, cpu
import cv2
import av
import imagesize
import tempfile
import math
from multiprocessing import Pool
from cairosvg import svg2png
import hashlib
IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
VIDEO_MIN_PIXELS = 128 * 28 * 28
VIDEO_MAX_PIXELS = 768 * 28 * 28
VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768
def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: int, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: int, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
def smart_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
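# A hedged worked example (not from the repository): a 1080x1920 frame with the
# default factor of 28 snaps both sides to multiples of 28 within the pixel budget.
h_bar, w_bar = smart_resize(1080, 1920)   # -> (1092, 1932)
assert h_bar % 28 == 0 and w_bar % 28 == 0
assert MIN_PIXELS <= h_bar * w_bar <= MAX_PIXELS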
def split_text(text, match_regex):
matches = list(re.finditer(match_regex, text))
# initialize the result lists
result = []
match_flag_list = []
# end position of the previous match
last_end = 0
# iterate over all matches
for match in matches:
# append the text before this match
if text[last_end:match.start()]:
result.append(text[last_end:match.start()])
match_flag_list.append(False)
# append the match itself
result.append(match.group(0))
match_flag_list.append(True)
# update the end position of the last match
last_end = match.end()
# append the text after the last match
if text[last_end:]:
result.append(text[last_end:])
match_flag_list.append(False)
return result, match_flag_list
def read_video(image_path, max_frame_number, decode_way):
if decode_way=='1fps':
try:
# print(image_path)
vr = VideoReader(image_path, ctx=cpu(0))
total_frame_num = len(vr)
fps = round(vr.get_avg_fps())
frame_idx = [i for i in range(0, len(vr), fps)]
frames = vr.get_batch(frame_idx).asnumpy()
cnt = len(frames)
frame_times = range(cnt)
except Exception as e:
print(image_path)
print('error is', e)
return None
elif decode_way=='key':
try:
with av.open(image_path) as container:
stream = container.streams.video[0]
stream.codec_context.skip_frame = 'NONKEY'
frames = []
frame_times = []
fps = int(stream.average_rate)
cnt = 0
for frame in container.decode(stream): # save the key frames as images
image = np.array(frame.to_image())
frames.append(image)
frame_time = int(frame.time)
frame_times.append(frame_time)
cnt += 1
except Exception as e:
print('error is', e)
return None
if frames is None or len(frames)==0:
return None
if len(frames)>max_frame_number and max_frame_number>0:
# generate max_frame_number evenly spaced indices
indices = np.linspace(0, len(frames) - 1, max_frame_number, dtype=int)
# select the corresponding frames (indexing this way works for both list and ndarray inputs)
frames = [frames[i] for i in indices]
frame_times = [frame_times[i] for i in indices]
return frames, frame_times
class OmniImageProcessor:
def __init__(self, config, **kwargs):
self.config = config # visual_config
self.min_pixels = self.config.min_pixels if hasattr(self.config, 'min_pixels') else 56 * 56
self.max_pixels = self.config.max_pixels if hasattr(self.config, 'max_pixels') else 28 * 28 * 1280
self.patch_size = self.config.patch_size if hasattr(self.config, 'patch_size') else 14
self.temporal_patch_size = self.config.temporal_patch_size if hasattr(self.config, 'temporal_patch_size') else 2
self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
self.spatial_merge_size = self.config.spatial_merge_size if hasattr(self.config, 'spatial_merge_size') else 2
def image_transform(self, strseq, return_mm_data = True):
image = None
if isinstance(strseq, str):
if return_mm_data:
image = Image.open(strseq).convert("RGB")
else:
try:
image = Image.open(BytesIO(strseq)).convert("RGB")
except:
image = Image.open(BytesIO(svg2png(bytestring=strseq))).convert("RGB") # some interleaved images are SVG vector graphics and need to be rasterized first
image = np.array(image.convert("RGB")) # convert to RGB (three channels) and then to a NumPy array for further processing
image_org_size = image.shape[:2] # keep the original (height, width) for later use
# resize, crop, scale, normalize
# compute the new target size used by the subsequent resize/crop operations
resized_height, resized_width = smart_resize(
image_org_size[0], image_org_size[1],
factor=self.patch_size * self.spatial_merge_size,
min_pixels=self.min_pixels,
max_pixels=self.max_pixels,
)
output_size = (resized_height, resized_width)
# resize the image to output_size using bicubic interpolation (PILImageResampling.BICUBIC),
# which generally preserves image quality well
image = resize(image, output_size, PILImageResampling.BICUBIC)
img = image.transpose(2, 0, 1)
# scale to [0, 1], then normalize with the configured per-channel mean and std
image = (img / 255.0 - np.array(self.config.image_mean)[:, np.newaxis, np.newaxis]) / np.array(self.config.image_std)[:,np.newaxis,np.newaxis]
# split into patches
patches = image[np.newaxis, :]
if patches.shape[0] == 1:
patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
channel = patches.shape[1]
grid_t = patches.shape[0] // self.temporal_patch_size
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
patches = patches.reshape(
grid_t,
self.temporal_patch_size,
channel,
grid_h // self.spatial_merge_size,
self.spatial_merge_size,
self.patch_size,
grid_w // self.spatial_merge_size,
self.spatial_merge_size,
self.patch_size,
)
patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
flatten_patches = patches.reshape(
grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
)
return flatten_patches, image_org_size, (grid_t, grid_h, grid_w)
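# A minimal, hedged usage sketch (not from the repository): run the
# patchification above on an in-memory PNG. The config values mirror common
# Qwen2-VL-style settings and are illustrative assumptions, not necessarily
# this repo's visual_config.
from io import BytesIO
from types import SimpleNamespace
import numpy as np
from PIL import Image

cfg = SimpleNamespace(
    min_pixels=56 * 56, max_pixels=28 * 28 * 1280,
    patch_size=14, temporal_patch_size=2, merge_size=2, spatial_merge_size=2,
    image_mean=[0.48145466, 0.4578275, 0.40821073],
    image_std=[0.26862954, 0.26130258, 0.27577711],
)
proc = OmniImageProcessor(cfg)
buf = BytesIO()
Image.fromarray(np.random.randint(0, 255, (336, 448, 3), dtype=np.uint8)).save(buf, format="PNG")
patches, org_size, grid = proc.image_transform(buf.getvalue())
print(patches.shape, org_size, grid)  # (768, 1176) (336, 448) (1, 24, 32)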
class OmniAudioProcessor:
# basic audio feature extraction plus input-data parsing utilities
def __init__(
self,
config, # audio processor config
**kwargs
):
# make sure ffmpeg (<7) is installed for torchaudio, e.g. conda install -c conda-forge 'ffmpeg<7'
assert(len(torchaudio.list_audio_backends()) > 0)
self.config = config
self.mel_filters = mel_filter_bank(
num_frequency_bins=1 + self.config.n_fft // 2,
num_mel_filters=self.config.num_mel_bins,
min_frequency=0.0,
max_frequency=self.config.sampling_rate / 2.0,
sampling_rate=self.config.sampling_rate,
norm="slaney",
mel_scale="slaney",
)
self.window = torch.hann_window(self.config.n_fft)
@staticmethod
def dynamic_range_compression(x, C=1, clip_val=1e-6):
return torch.log(torch.clamp(x, min=clip_val) * C)
@staticmethod
def zero_mean_unit_var_norm(x):
return (x - x.mean()) / torch.sqrt(x.var() + 1e-8)
def load_audio_waveform(self, uri, return_tensors=True, do_normalize=False):
metadata = torchaudio.info(uri) # sample_rate, num_frames, num_channels, bits_per_sample, encoding=PCM_S
assert(metadata.num_channels <= 2), "acoustic file with {} channels.".format(metadata.num_channels) # whisper only accept mono channel audio
waveform_tensor, _ = torchaudio.load(uri, normalize=True)
if self.config.sampling_rate != metadata.sample_rate:
waveform_tensor = torchaudio.functional.resample(waveform_tensor, metadata.sample_rate, self.config.sampling_rate, lowpass_filter_width=128)
# downmix to mono channel https://trac.ffmpeg.org/wiki/AudioChannelManipulation
if metadata.num_channels > 1:
waveform_tensor = torch.mean(waveform_tensor, dim=0, keepdim=True)
# normalized to zero mean
if do_normalize:
waveform_tensor = self.zero_mean_unit_var_norm(waveform_tensor)
if return_tensors: # (channels, samples)
return waveform_tensor
else:
return waveform_tensor.numpy()
def split_with_overlap(self, waveform): # if the waveform exceeds the maximum length, split it into overlapping segments
channels, wave_samples = waveform.shape
max_audio_samples = self.config.max_audio_seconds * self.config.sampling_rate
if wave_samples <= max_audio_samples or self.config.split_overlap < 0:
return [waveform] # within the length limit, or splitting disabled: always return a list
split_waveform, start = [], 0
while start < wave_samples: # the overlap is aligned to whole seconds
if start > int(self.config.sampling_rate * self.config.split_overlap):
start -= int(self.config.sampling_rate * self.config.split_overlap) # 0 means no overlap; a value > 0 is the overlap in seconds
end = min(start + max_audio_samples, wave_samples)
if end - start >= self.config.n_fft: # make sure there is at least one frame of data
split_waveform.append(waveform[:, start:end]) # note: this may produce very short segments that should be detected and dropped during preprocessing
start = end
return split_waveform
@classmethod
def inference_output_length(cls, config, input_length):
# for whisper + bridge
kernel_size = config.kernel_size
stride_size = config.stride_size
avg_pooler = config.avg_pooler
encoder_length = (input_length + 2 * (kernel_size // 2) - kernel_size) // 1 + 1 # conv layer1 with pad=1
encoder_length = (encoder_length + 2 * (kernel_size // 2) - kernel_size) // stride_size + 1 # conv layer2 with pad=1
if avg_pooler > 1:
bridge_length = encoder_length // avg_pooler
return encoder_length, bridge_length
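# A hedged worked example (not from the repository) of the length bookkeeping
# above, with illustrative Whisper-style values (kernel_size=3, stride_size=2,
# avg_pooler=4): 3000 mel frames shrink to 1500 encoder states and 375 bridge tokens.
kernel_size, stride_size, avg_pooler, input_length = 3, 2, 4, 3000
enc = (input_length + 2 * (kernel_size // 2) - kernel_size) // 1 + 1           # 3000 after conv layer 1
enc = (enc + 2 * (kernel_size // 2) - kernel_size) // stride_size + 1          # 1500 after stride-2 conv layer 2
print(enc, enc // avg_pooler)                                                  # 1500 375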
def extract_fbank_features(self, waveform):
# ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py
channels, wave_samples = waveform.shape
assert(wave_samples >= self.config.n_fft)
valid_frame_nums = min(self.config.max_audio_seconds * self.config.sampling_rate // self.config.hop_length, wave_samples // self.config.hop_length + 1)
if wave_samples < self.config.max_audio_seconds * self.config.sampling_rate:
waveform = torch.nn.functional.pad(waveform, (0, self.config.max_audio_seconds * self.config.sampling_rate - wave_samples), "constant", 0)
else:
waveform = waveform[:, :self.config.max_audio_seconds * self.config.sampling_rate]
# window = torch.hann_window(self.config.n_fft)
stft = torch.stft(waveform, self.config.n_fft, self.config.hop_length, window=self.window, return_complex=True) # fft, len(wave) // n_fft // 2 + 1
magnitudes = stft[..., :-1].abs() ** 2
mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
mel_spec = mel_filters.T @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
if waveform.dim() == 2:
max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
log_spec = torch.maximum(log_spec, max_val - 8.0)
else:
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
log_spec = log_spec[0].numpy() # (channel, filters, samples) -> (filters, samples)
log_spec[:, valid_frame_nums:] = 0.0 # zero out the padded frames
return log_spec, valid_frame_nums
def data_augment(self, feature: np.array, input_length, training=True):
# reference https://arxiv.org/pdf/1904.08779
def mask_start_indices(input_length, mask_length, min_masks, mask_prob):
num_masked_span = int(mask_prob * input_length / mask_length + random.random())
num_masked_span = max(num_masked_span, min_masks)
start_indices = list(range(input_length - mask_length))
random.shuffle(start_indices)
start_indices = start_indices[:num_masked_span]
return start_indices
if not training or (self.config.mask_time_prob <= 0 and self.config.mask_feature_prob <= 0):
return feature
if input_length < self.config.mask_time_length * self.config.mask_time_min_masks + 1:
return feature
if self.config.num_mel_bins < self.config.mask_feature_length * self.config.mask_feature_min_masks + 1:
return feature
if self.config.mask_time_prob > 0:
start_indices = mask_start_indices(input_length, self.config.mask_time_length, self.config.mask_time_min_masks, self.config.mask_time_prob)
for start_idx in start_indices:
feature[:, start_idx: start_idx + self.config.mask_time_length] = 0.0
if self.config.mask_feature_prob > 0:
start_indices = mask_start_indices(self.config.num_mel_bins, self.config.mask_feature_length, self.config.mask_feature_min_masks, self.config.mask_feature_prob)
for start_idx in start_indices:
feature[start_idx: start_idx + self.config.mask_feature_length, :] = 0.0
return feature
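# A hedged standalone illustration (not from the repository) of the
# SpecAugment-style time masking performed above; parameter values are
# illustrative, not the repo's defaults.
import random
import numpy as np

feature = np.random.randn(80, 500)              # (mel_bins, frames)
mask_length, mask_prob, min_masks = 10, 0.05, 2
num_spans = max(int(mask_prob * 500 / mask_length + random.random()), min_masks)
for start in random.sample(range(500 - mask_length), num_spans):
    feature[:, start:start + mask_length] = 0.0  # zero out one contiguous time span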
@dataclass
class OmniProcessorOutput(ModelOutput):
input_ids: Optional["List|torch.Tensor"] = None
labels: Optional["List|torch.Tensor"] = None
attention_mask: Optional["List|torch.Tensor"] = None
position_ids: Optional["List|torch.Tensor"] = None
seqlens: Optional["List|torch.Tensor"] = None # used together with the Omni modeling code
# audio fields
audios: Optional["List|torch.Tensor"] = None
encoder_length: Optional["List|torch.Tensor"] = None
bridge_length: Optional["List|torch.Tensor"] = None
# image fields
images: Optional["List|torch.Tensor"] = None
patch_nums: Optional["List|torch.Tensor"] = None
images_size: Optional["List|torch.Tensor"] = None
crop_size: Optional["List|torch.Tensor"] = None
images_grid: Optional["List|torch.Tensor"] = None
# video fields
videos: Optional["List|torch.Tensor"] = None
videos_patch_nums: Optional["List|torch.Tensor"] = None
videos_size: Optional["List|torch.Tensor"] = None
videos_crop_size: Optional["List|torch.Tensor"] = None
videos_grid: Optional["List|torch.Tensor"] = None
# processor fields
raw_text: Optional[str] = None
index: Optional[int] = None
def concatenate(self, other): # only valid when the fields are lists
def concat_one(a, b):
if a is None and b is None:
return None
elif a is None and b is not None:
return b
elif a is not None and b is None:
return a
else:
return a + b
return OmniProcessorOutput(
input_ids=concat_one(self.input_ids, other.input_ids),
labels=concat_one(self.labels, other.labels),
audios=concat_one(self.audios, other.audios),
encoder_length=concat_one(self.encoder_length, other.encoder_length),
bridge_length=concat_one(self.bridge_length, other.bridge_length),
images=concat_one(self.images, other.images),
images_grid=concat_one(self.images_grid, other.images_grid),
patch_nums=concat_one(self.patch_nums, other.patch_nums),
videos=concat_one(self.videos, other.videos),
videos_grid=concat_one(self.videos_grid, other.videos_grid),
videos_patch_nums=concat_one(self.videos_patch_nums, other.videos_patch_nums),
position_ids=concat_one(self.position_ids, other.position_ids),
seqlens=concat_one(self.seqlens, other.seqlens),
images_size=concat_one(self.images_size, other.images_size),
videos_size=concat_one(self.videos_size, other.videos_size),
index = self.index # concatenation keeps the original index
)
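# A minimal, hedged usage sketch (not from the repository): list-typed fields
# are concatenated field by field, and the left-hand index is kept.
a = OmniProcessorOutput(input_ids=[1, 2], encoder_length=[10], index=0)
b = OmniProcessorOutput(input_ids=[3], encoder_length=[20], index=1)
merged = a.concatenate(b)
print(merged.input_ids, merged.encoder_length, merged.index)  # [1, 2, 3] [10, 20] 0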
class OmniMMProcessor(object):
def __init__(self,
tokenizer: transformers.PreTrainedTokenizer,
config,
training,
relative_path=None,
parallel=None,
**kwargs,
):
self.tokenizer = tokenizer
self.config = config
self.audio_processor = OmniAudioProcessor(config.audio_config)
self.visual_processor = None
if hasattr(config, "visual_config"):
self.visual_processor = OmniImageProcessor(config.visual_config)
self.video_processor = None
if hasattr(config, "video_config"):
self.video_processor = OmniImageProcessor(config.video_config)
self.training = training
self.relative_path = relative_path
self.parallel = parallel
# audio tag
self.audio_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_start_token_id)
self.audio_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_end_token_id)
self.audio_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_pad_token_id)
self.audio_delim_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_delim_token_id)
self.audiogen_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audiogen_start_token_id)
self.audiogen_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audiogen_end_token_id)
# image tag
self.image_start_tag = None
self.image_end_tag = None
self.image_pad_tag = None
self.video_start_tag = None
self.video_end_tag = None
# the videoframe tags exist only for compatibility when a video is given as image frames; they have no token ids and are replaced with the image start/end tags when the frames are extracted
self.videoframe_start_tag = '<videoframe_start_omni>'
self.videoframe_end_tag = '<videoframe_end_omni>'
if hasattr(self.config, "visual_config"):
# special token for start_tag
self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_start_token_id)
# special token for end_tag
self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_end_token_id)
# special token for pad_tag
self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_pad_token_id)
self.image_line_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_line_token_id)
self.image_delimiter_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_delimiter_token_id)
if hasattr(self.config, "video_config"):
self.video_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_start_token_id)
self.video_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_end_token_id)
self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_start_token_id)
self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_end_token_id)
self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_pad_token_id)
self.video_place_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_place_token_id)
self.frame_pattern = getattr(self.config.video_config, 'frame_pattern', '<frame>')
# @lru_cache(maxsize=1024)
def _get_audio(self, audio_info):
try:
audio_info = ujson.loads(audio_info)
if 'path' in audio_info.keys():
audio_uri = None
if os.path.exists(audio_info['path']):
audio_uri = audio_info['path']
elif self.relative_path is not None:
audio_uri = os.path.join(self.relative_path, audio_info['path'].lstrip('/'))
if not os.path.exists(audio_uri):
audio_uri = None
if audio_uri is not None:
waveform = self.audio_processor.load_audio_waveform(audio_uri, True)
waveforms = self.audio_processor.split_with_overlap(waveform)
ret = OmniProcessorOutput() # default initialization; the audios field is None
for i, waveform in enumerate(waveforms): #(zip(waveforms,vocoder_waveforms)):
audio, input_length = self.audio_processor.extract_fbank_features(waveform)
audio = self.audio_processor.data_augment(audio, input_length, self.training)
encoder_length, bridge_length = self.audio_processor.inference_output_length(self.config.audio_config, input_length)
if bridge_length <= 0:
continue
current_ret = OmniProcessorOutput(
audios=[audio[:,:input_length]],
encoder_length=[encoder_length],
bridge_length=[bridge_length],
)
if ret.audios is None:
ret = current_ret
else:
ret = ret.concatenate(current_ret) # concatenate multiple segments
return ret
else:
raise ValueError("can not find path in audio_info")
except Exception as e:
print("**** get audio error: {}, info: {} *****".format(str(e), str(audio_info)))
return OmniProcessorOutput()
# @lru_cache(maxsize=1024)
def _get_image(self, image_info):
try:
try:
image_info = ujson.loads(image_info)
except:
image_info = re.sub(r"(?<!\\)'", '"', image_info)
image_info = ujson.loads(image_info)
if 'base64' in image_info.keys():
image_data = base64.b64decode(image_info['base64'])
image_feat, org_size, image_list = self.visual_processor.image_transform(image_data)
elif 'local' in image_info.keys():
image_feat, org_size, image_list = self.visual_processor.image_transform(image_info['local'])
elif 'path' in image_info.keys() and os.path.exists(image_info['path']):
image_feat, org_size, image_list = self.visual_processor.image_transform(image_info['path'])
elif 'url' in image_info.keys():
image_bytes = self._get_vision_obj_byte('url', image_info['url'])
image_feat, org_size, image_list = self.visual_processor.image_transform(image_bytes)
else:
raise ValueError("can not find any path in image_info")
merge_length = self.visual_processor.merge_size**2
patch_nums = np.array(image_list).prod() // merge_length
if org_size[0] * org_size[1] > 16**2: # filter out extremely small images
return OmniProcessorOutput(
images=[image_feat],
patch_nums=[patch_nums],
crop_size=[image_list],
images_size= [org_size],
images_grid=[image_list]
)
else:
print("**** image too small: {}, info: {} *****".format(str(org_size), str(image_info)))
return OmniProcessorOutput()
except Exception as e:
print("**** get image error: {}, info: {} *****".format(str(e), str(image_info)))
return OmniProcessorOutput()
# @lru_cache(maxsize=1024)
def _get_video_frame(self, video_frame_infos):
try:
pattern = r'\{.*?\}'
matches = re.findall(pattern, video_frame_infos)
ret = OmniProcessorOutput()
# parse the matches one by one
for match in matches:
video_frame_info = ujson.loads(match)
# video_frame_info = ujson.loads(video_frame_info)
if 'local' in video_frame_info.keys():
image_feat, org_size, image_list = self.video_processor.image_transform(video_frame_info['local'])
elif 'path' in video_frame_info.keys() and os.path.exists(video_frame_info['path']):
image_feat, org_size, image_list = self.video_processor.image_transform(video_frame_info['path'])
else:
raise ValueError("can not find any path in video_info")
merge_length = self.video_processor.merge_size**2
patch_nums = np.array(image_list).prod() // merge_length
if org_size[0] * org_size[1] > 16**2: # filter out extremely small frames
ret = ret.concatenate(
OmniProcessorOutput(
videos=[image_feat],
videos_patch_nums=[patch_nums],
videos_crop_size=[image_list],
videos_size= [org_size],
videos_grid=[image_list]
)
)
else:
print("**** video too small: {}, info: {} *****".format(str(org_size), str(video_frame_info)))
return ret
except Exception as e:
print("**** get video error: {}, info: {} *****".format(str(e), str(video_frame_info)))
return OmniProcessorOutput()
# read raw video bytes
def _get_vision_obj_byte(self, source, path):
vision_obj_byte = None
if source == "local":
if os.path.exists(path):
vision_obj_byte = open(path, "rb").read()
else:
vision_obj_byte = None
if source == "base64":
vision_obj_byte = base64.b64decode(path)
if source == "url":
vision_obj_byte = requests.get(url=path).content
return vision_obj_byte
# split the video into frames and save them to a subdirectory
def _split_video_to_frames(self, video_info, max_frame_number=-1, decode_way="1fps"):
if decode_way=='1fps':
frame_suffix = f'_frames'
elif decode_way=='key':
frame_suffix = f'_keyframes'
else:
raise ValueError('invalid decode way!')
server = "local"
if 'local' in video_info.keys():
# local path
video_path = video_info['local']
# local path where the frames are saved
frame_path = video_path.split('.')[0] + frame_suffix
mm_obj_byte = self._get_vision_obj_byte('local', video_path)
elif 'base64' in video_info.keys():
md5 = hashlib.md5(video_info['base64'].encode('utf-8')).hexdigest()
if self.relative_path is not None:
video_path = os.path.join(self.relative_path, md5)
else:
video_path = os.path.join(os.getcwd(), md5)
frame_path = md5 + frame_suffix
mm_obj_byte = self._get_vision_obj_byte('base64', video_info['base64'])
elif 'url' in video_info.keys():
md5 = hashlib.md5(video_info['url'].encode('utf-8')).hexdigest()
if self.relative_path is not None:
video_path = os.path.join(self.relative_path, md5)
else:
video_path = os.path.join(os.getcwd(), md5)
frame_path = md5 + frame_suffix
mm_obj_byte = self._get_vision_obj_byte('url', video_info['url'])
else:
raise ValueError('invalid video source!')
return ""
if mm_obj_byte is None: # the video file could not be read
return ""
if not os.path.exists(frame_path) or len(os.listdir(frame_path))==0:
# save the frames
os.makedirs(frame_path, exist_ok=True)
frames, frame_times = read_video(io.BytesIO(mm_obj_byte), max_frame_number=-1, decode_way=decode_way) # read all frames
for frame_idx, frame in enumerate(frames):
output_filename = os.path.join(frame_path, f"{frame_times[frame_idx]}.jpg")
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
cv2.imwrite(output_filename, frame)
frame_paths = os.listdir(frame_path)
# select frames
frame_times = [int(filename.split('/')[-1].replace('.jpg', '')) for filename in frame_paths if filename.endswith('.jpg')] # file names correspond to seconds
frame_times.sort() # sort in ascending order
frame_number = len(frame_times)
if frame_number > max_frame_number:
indices = np.linspace(0, frame_number - 1, max_frame_number, dtype=int)
else:
indices = np.linspace(0, frame_number - 1, frame_number, dtype=int)
# concatenation mode: build the replacement string frame by frame
replace_str = ""
for frame_idx, idx in enumerate(indices):
frame_time = frame_times[idx] # frame_time is the frame's time in seconds and also its stored file name
frame_dict = {"local": os.path.join(frame_path, f'{frame_time}.jpg')}
frame_str = self.frame_pattern.format(frame_idx) if '{}' in self.frame_pattern else self.frame_pattern # {} is the index of the frame
frame_str = frame_str.replace('<TIMEIDX>', str(frame_time)) # <TIMEIDX> is the time in seconds
frame_str = frame_str.replace('<TIMESTAMP>', time.strftime("%H:%M:%S", time.gmtime(frame_time))) # <TIMESTAMP> is the formatted timestamp
frame_str = frame_str.replace('<frame>', f'{self.image_start_tag}{json.dumps(frame_dict)}{self.image_end_tag}')
replace_str += frame_str
return replace_str
def sample_frame(self,frames_str,max_frame = 32):
def uniform_sample(lst, num_samples):
if num_samples > len(lst):
return lst
interval = len(lst) / num_samples
samples = [lst[int(i * interval)] for i in range(num_samples)]
return samples
p = rf'({self.image_start_tag}.*?{self.image_end_tag})'
frames_str_split = re.split(p,frames_str)
frame_idxs = [idx for idx in range(len(frames_str_split)) if self.image_start_tag in frames_str_split[idx]]
sample_frame_idxs = set(uniform_sample(frame_idxs, max_frame))
return ''.join([item for idx,item in enumerate(frames_str_split) if idx in sample_frame_idxs or self.image_start_tag not in frames_str_split[idx]])
def _get_video_frame_str(self, video_info):
try:
if self.videoframe_start_tag in video_info: # if the video is represented as image frames, replace the frame tags with image tags
frames_str = video_info
frames_str = frames_str.replace(self.videoframe_start_tag,self.image_start_tag).replace(self.videoframe_end_tag,self.image_end_tag)
return self.sample_frame(frames_str, max_frame = self.config.video_config.max_frame_num)
video_info = ujson.loads(video_info)
# build a string containing the paths of the sampled frame images, capped at max_frame_number frames
frames_str = self._split_video_to_frames(video_info, max_frame_number=self.config.video_config.max_frame_num, decode_way=self.config.video_config.decode_way)
return frames_str
except Exception as e:
print("**** get video error: {}, info: {} *****".format(str(e), str(video_info)))
return ""
def _replace_image(self, image_text):
image_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', image_text)
ret = self._get_image(image_info) # cached result on repeated calls
if ret.patch_nums is None:
return ''
return ret, self.image_start_tag + self.image_pad_tag * ret.patch_nums[0] + self.image_end_tag
def _replace_video_frame(self, video_frame_text):
video_frame_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', video_frame_text)
ret = self._get_video_frame(video_frame_info) # cached result on repeated calls
if ret.videos_patch_nums is None:
return ''
video_frame_str = [self.image_start_tag + self.video_place_tag * ret.videos_patch_nums[i] + self.image_end_tag for i in range(len(ret.videos_patch_nums))]
return ret, ''.join(video_frame_str)
def split_multimodal_chunk(self, text_list, mm_label_list, trainable_list, mtype='audio'):
# extract the JSON-formatted audio/image info from the text, load it and turn it into features, estimate the number of encoder tokens, and insert that many pad tokens
if (self.audio_start_tag != None) and (mtype == 'audio'):
match_regex = re.compile(self.audio_start_tag + '.*?' + self.audio_end_tag,re.S)
drop_regex = re.compile(self.audio_start_tag + "|" + self.audio_end_tag,re.S)
elif (self.image_start_tag != None) and (mtype == 'image'):
match_regex = re.compile(self.image_start_tag + '.*?' + self.image_end_tag,re.S)
drop_regex = re.compile(self.image_start_tag + "|" + self.image_end_tag,re.S)
elif (self.audiogen_start_tag != None) and (mtype == 'audiogen'):
match_regex = re.compile(self.audiogen_start_tag + '.*?' + self.audiogen_end_tag,re.S)
drop_regex = re.compile(self.audiogen_start_tag + "|" + self.audiogen_end_tag,re.S)
elif (self.video_start_tag != None) and (mtype == 'video'):
match_regex = re.compile(self.video_start_tag + '.*?' + self.video_end_tag,re.S)
drop_regex = re.compile(self.video_start_tag + "|" + self.video_end_tag,re.S)
else:
raise ValueError("mtype not supportted!")
new_text_list = []
new_mm_label_list = []
new_trainable_flag_list = []
for text,mm_label,trainable in zip(text_list,mm_label_list,trainable_list):
for t,m in zip(*split_text(text, match_regex)):
new_trainable_flag_list.append(trainable)
if m:
new_text_list.append(re.sub(drop_regex, '', t))
new_mm_label_list.append(mtype)
else:
new_text_list.append(t)
new_mm_label_list.append(mm_label)
return new_text_list, new_mm_label_list, new_trainable_flag_list
def process_multimodal_chunk(self, text, mm_label, trainable):
ret = OmniProcessorOutput()
if mm_label == 'audio':
ret = self._get_audio(text)
if ret.bridge_length is not None:
ret.input_ids = self.tokenizer.encode(self.audio_start_tag,add_special_tokens=False) + self.tokenizer.encode(self.audio_pad_tag,add_special_tokens=False) * sum(ret.bridge_length) + self.tokenizer.encode(self.audio_end_tag,add_special_tokens=False)
else:
raise ValueError(f"Get audio data Failed at Process audio chunk {text}")
elif mm_label == 'audiogen':
ret = self._get_audio(text)
if ret.bridge_length is not None:
ret.input_ids = self.tokenizer.encode(self.audiogen_start_tag,add_special_tokens=False) + self.tokenizer.encode(self.audio_pad_tag,add_special_tokens=False) * sum(ret.bridge_length) + self.tokenizer.encode(self.audiogen_end_tag,add_special_tokens=False)
else:
raise ValueError(f"Get audio data Failed at Process audio chunk {text}")
elif mm_label == 'image':
ret, input_str = self._replace_image(text)
if input_str:
ret.input_ids = self.tokenizer.encode(input_str, add_special_tokens=False)
else:
raise ValueError("Get image data Failed at Process image chunk")
elif mm_label == 'video':
frame_str = self.video_start_tag+self._get_video_frame_str(text)+self.video_end_tag
ret, input_str = self._replace_video_frame(frame_str)
if input_str:
ret.input_ids = self.tokenizer.encode(input_str, add_special_tokens=False)
else:
raise ValueError("Get video data Failed at Process video chunk")
elif mm_label == 'text':
ret.input_ids = self.tokenizer.encode(text, add_special_tokens=False)
if len(ret.input_ids) > self.tokenizer.model_max_length-1: # filter overly long text
raise ValueError(f"Text too long, please check text length: {text[:5]+'...'*6+text[-5:]}")
else:
raise ValueError(f"mm_label not supportted! must in ['audio', 'image', 'text'] but get {mm_label}")
return ret
def process_one(self, text, index=0, raw_only=False):
ret = OmniProcessorOutput(index=index)
all_text_list = []
all_mm_label_list = []
all_trainable_flag_list = []
text_list, match_flag = split_text(text, re.compile("<trainable_start>.*?<trainable_end>",re.S))
if len(text_list) == 1:
text = re.sub(re.compile("<trainable_start>|<trainable_end>",re.S), '', text_list[0])
all_text_list.append(text)
all_mm_label_list.append('text')
all_trainable_flag_list.append(True)
else:
for text, match in zip(text_list, match_flag):
text = re.sub(re.compile("<trainable_start>|<trainable_end>",re.S), '', text)
if text.strip() == '':
continue # drop chunks that are only whitespace
all_text_list.append(text)
all_mm_label_list.append('text')
all_trainable_flag_list.append(match)
# process the multimodal information
for mtype in self.config.multimodal: # loop over modalities to collect audio/image results
all_text_list, all_mm_label_list, all_trainable_flag_list = self.split_multimodal_chunk(all_text_list, all_mm_label_list, all_trainable_flag_list, mtype)
if len(all_text_list) == 0:
print(f"Process {text} chunk error: No valid Text data!!!!!")
return OmniProcessorOutput(index=index)
for text, mm_label, trainable in zip(all_text_list, all_mm_label_list, all_trainable_flag_list):
try:
mret = self.process_multimodal_chunk(text, mm_label, trainable)
ret = ret.concatenate(mret)
except ValueError as e:
tt = text[:24].replace('\n','<LF>')
print(f"Process {tt if mm_label == 'text' else text} {mm_label} chunk error: {str(e)}")
return OmniProcessorOutput(index=index)
if raw_only:
ret.raw_text = self.tokenizer.decode(ret.input_ids, skip_special_tokens=False)
return ret
return ret
@torch.no_grad()
def __call__(self, example, parallel=128):
if isinstance(example, Dict):
pass
elif isinstance(example, str):
return self.process_one(example)
elif isinstance(example, List): # batch inference: asynchronous multi-threaded processing
with cf.ThreadPoolExecutor(min(parallel, len(example))) as executor:
future_list = [executor.submit(self.process_one, di, idx) for idx, di in enumerate(example)]
batch_data = [key.result() for key in cf.as_completed(future_list)]
valid_num = sum([1 if x.input_ids is not None else 0 for x in batch_data])
assert(valid_num == len(batch_data)) # inference requires every sample in the batch to be valid
batch_data = sorted(batch_data, key=lambda x: x.index) # restore the original input order
ret = OmniProcessorOutput()
for i in range(len(batch_data)):
ret = ret.concatenate(batch_data[i])
self.tokenizer.padding_side = "left"
max_len = min(max([len(x.input_ids) for x in batch_data]),self.tokenizer.model_max_length)
padding_result = self.tokenizer.pad({"input_ids": [r.input_ids for r in batch_data]}, return_tensors='pt')
ret.input_ids = padding_result["input_ids"]
ret.attention_mask = padding_result["attention_mask"] # batch inference is not packed, so seqlens are not needed
if ret.audios is not None:
max_audios_len = max([x.shape[-1] for x in ret.audios])
ret.audios = default_collate([np.pad(x, ((0,0),(0,max_audios_len - x.shape[-1])), 'constant', constant_values=0) for x in ret.audios])
ret.encoder_length = default_collate(ret.encoder_length)
ret.bridge_length = default_collate(ret.bridge_length)
if ret.images is not None:
ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.images]
ret.patch_nums = default_collate(ret.patch_nums)
if ret.videos is not None:
ret.videos = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.videos]
ret.videos_patch_nums = default_collate(ret.videos_patch_nums)
return ret
else:
raise ValueError("example format supported yet")

186
sequence_parallel_utils.py Normal file
View File

@ -0,0 +1,186 @@
from typing import Any, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from flash_attn import flash_attn_varlen_func
try:
import deepspeed.comm as dist
except:
dist = None
try:
from utils import (
get_sequence_parallel_group,
get_sequence_parallel_size,
get_sequence_parallel_rank
)
except (ModuleNotFoundError, ImportError):
# fetch the sequence-parallel settings from utils; if the import fails, default to disabled
get_sequence_parallel_group = lambda : None
get_sequence_parallel_size = lambda : 1
get_sequence_parallel_rank = lambda : 0
def single_all_to_all(input, scatter_idx, gather_idx, group):
seq_world_size = dist.get_world_size(group)
inp_shape = list(input.shape)
inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
if scatter_idx < 2:
input_t = input.reshape(
[seq_world_size, inp_shape[scatter_idx]] + \
inp_shape[scatter_idx + 1:]
).contiguous()
else:
# transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
input_t = input.reshape(
[-1, seq_world_size, inp_shape[scatter_idx]] + \
inp_shape[scatter_idx + 1:]
).transpose(0, 1).contiguous()
output = torch.empty_like(input_t)
dist.all_to_all_single(output, input_t, group=group)
# if scattering the seq-dim, transpose the heads back to the original dimension
# [sp_size, seq_len//sp_size, batch_size, head_num // sp_size, head_dim] -->
# [seq_len//sp_size,batch_size, sp_size, head_num // sp_size, head_dim]
if scatter_idx < 2:
output = output.transpose(0, 1).transpose(1, 2).contiguous()
return output.reshape(
inp_shape[: gather_idx] + \
[inp_shape[gather_idx] * seq_world_size,] + \
inp_shape[gather_idx + 1:]).contiguous()
class _SeqAllToAll(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, group: 'dist.ProcessGroup', input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
ctx.group = group
ctx.scatter_idx = scatter_idx
ctx.gather_idx = gather_idx
return single_all_to_all(input, scatter_idx, gather_idx, group)
@staticmethod
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None)
# import from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py
# but with fixes so the dimension layout matches the training setup used here
class DistributedAttention(nn.Module):
"""Initialization.
Arguments:
local_attention (Module): local attention with q,k,v
sequence_process_group (ProcessGroup): sequence parallel process group
scatter_idx (int): scatter_idx for all2all comm
gather_idx (int): gather_idx for all2all comm
"""
def __init__(
self,
local_attention: nn.Module,
sequence_process_group: 'dist.ProcessGroup',
scatter_idx: int = 2,
gather_idx: int = 0,
) -> None:
super(DistributedAttention, self).__init__()
self.local_attn = local_attention
self.spg = sequence_process_group
self.scatter_idx = scatter_idx
self.gather_idx = gather_idx
def pad_attention_head(self, query: Tensor, key: Tensor, value: Tensor):
# pad the head dimension of the inputs to a multiple of sp_size
sp_size = torch.distributed.get_world_size(self.spg)
pad_size = (sp_size - query.size(1) % sp_size) % sp_size
if pad_size > 0:
# [bs, num_head, seq_len, head_dim] -> [bs, num_head+pad_size, seq_len, head_dim]
query = torch.nn.functional.pad(query, (0,0,0,0,0,pad_size), value = 0.01)
key = torch.nn.functional.pad(key, (0,0,0,0,0,pad_size), value = 0.01)
value = torch.nn.functional.pad(value, (0,0,0,0,0,pad_size),value=0.0)
return query, key, value
def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor:
""" forward
Arguments:
query (Tensor): query input to the layer [batch_size, num_head, seq_len, head_dim]
key (Tensor): key input to the layer
value (Tensor): value input to the layer
args: other args
Returns:
* output (Tensor): context output
"""
# TODO Merge three alltoall calls into one
# TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
# [batch_size, num_head, seq_len, head_dim] -> [seq_len, batch_size, num_head, head_dim]
origin_num_head = query.size(1)
query, key, value = self.pad_attention_head(query,key,value)
query = query.transpose(1,2).transpose(0,1)
key = key.transpose(1,2).transpose(0,1)
value = value.transpose(1,2).transpose(0,1)
#in shape : e.g., [s/p,bs,h,head_dim]
query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx).transpose(0,1).transpose(1,2).contiguous()
key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx).transpose(0,1).transpose(1,2).contiguous()
value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx).transpose(0,1).transpose(1,2).contiguous()
context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)
context_layer = context_layer.transpose(0,1).contiguous()
# [seq_len, batch_size, num_head, head_dim]
output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
return output.transpose(0,1)[:,:,:origin_num_head,:]
class LocalAttention(nn.Module):
def __init__(self, hidden_size, num_heads, head_dim):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = head_dim
def forward(self, q, k, v, *args, use_flash=True, **kwargs):
# input q,k,v [batch_size, num_head, seq_len, head_dim]
# output [batch_size, seq_len, num_head, head_dim]
if use_flash:
q_len, num_heads = q.shape[2], q.shape[1]
q = q.transpose(1,2).reshape(-1, num_heads, self.head_dim)
k = k.transpose(1,2).reshape(-1, num_heads, self.head_dim)
v = v.transpose(1,2).reshape(-1, num_heads, self.head_dim)
return flash_attn_varlen_func(q,k,v,*args, **kwargs).reshape(-1,q_len, num_heads, self.head_dim)
else:
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
attn_output = F.scaled_dot_product_attention(
q,k,v, *args, **kwargs)
attn_output = attn_output.transpose(1, 2)
return attn_output
def create_attention_layer(hidden_size, num_heads, head_dim):
if get_sequence_parallel_group() is None:
return LocalAttention(hidden_size, num_heads, head_dim)
else:
return DistributedAttention(
local_attention=LocalAttention(hidden_size, num_heads, head_dim),
sequence_process_group=get_sequence_parallel_group()
)
def get_sequence_parallel_chunk(tensor, dim=1, shift=0):
assert tensor.size(dim) % get_sequence_parallel_size() == 0
original_size = tensor.size(dim)
if shift:
tensor = tensor.split([shift, tensor.size(dim) - shift], dim=dim)[1]
if get_sequence_parallel_group() is None:
return tensor
else:
chunk_size = original_size // get_sequence_parallel_size()
return tensor.split(chunk_size, dim=dim)[get_sequence_parallel_rank()]
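
DistributedAttention.pad_attention_head above pads the head dimension up to a multiple of the sequence-parallel world size so heads can be scattered evenly across ranks, and the extra heads are sliced off again after attention. Below is a minimal, self-contained sketch of that pad-then-trim idea; plain scaled_dot_product_attention stands in for the local attention and sp_size is a made-up world size, so no process group is involved:

import torch
import torch.nn.functional as F

def pad_heads_to_multiple(q, k, v, sp_size):
    # q, k, v: [batch, num_heads, seq_len, head_dim]; pad num_heads up to a multiple of sp_size
    pad = (sp_size - q.size(1) % sp_size) % sp_size
    if pad > 0:
        q = F.pad(q, (0, 0, 0, 0, 0, pad), value=0.01)  # small non-zero value for the dummy q/k heads
        k = F.pad(k, (0, 0, 0, 0, 0, pad), value=0.01)
        v = F.pad(v, (0, 0, 0, 0, 0, pad), value=0.0)   # zero-valued v heads contribute nothing
    return q, k, v

bsz, num_heads, seq_len, head_dim, sp_size = 2, 14, 8, 16, 4
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)

q_p, k_p, v_p = pad_heads_to_multiple(q, k, v, sp_size)  # 14 heads -> 16 heads
out = F.scaled_dot_product_attention(q_p, k_p, v_p)      # [bsz, 16, seq_len, head_dim]
out = out[:, :num_heads]                                  # trim back to the original 14 heads
print(out.shape)                                          # torch.Size([2, 14, 8, 16])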

68
special_tokens_map.json Normal file
View File

@ -0,0 +1,68 @@
{
"additional_special_tokens": [
"<|im_start|>",
"<|im_end|>",
"<|object_ref_start|>",
"<|object_ref_end|>",
"<|box_start|>",
"<|box_end|>",
"<|quad_start|>",
"<|quad_end|>",
"<|vision_start|>",
"<|vision_end|>",
"<|vision_pad|>",
"<|image_pad|>",
"<|video_pad|>",
"<B_SYS>",
"<B_USYS>",
"<C_Q>",
"<C_A>",
"<B_FUNC>",
"<B_CODE>",
"<B_APE>",
"<function_calling>",
"<calc_start>",
"<calc_end>",
"<inner_think>",
"<audio_start_baichuan>",
"<audio_end_baichuan>",
"<audio_pad_baichuan>",
"<img_start_baichuan>",
"<img_end_baichuan>",
"<img_pad_baichuan>",
"<img_newline_baichuan>",
"<box_start_baichuan>",
"<box_end_baichuan>",
"<box_delim_baichuan>",
"<ref_start_baichuan>",
"<ref_end_baichuan>",
"<img_delim_baichuan>",
"<polygon_start_baichuan>",
"<polygon_end_baichuan>",
"<baichuan_pad_token>",
"<reserved_113>",
"<audio_delim_baichuan>",
"<video_start_baichuan>",
"<video_end_baichuan>",
"<video_palce_baichuan>",
"<audiotext_start_baichuan>",
"<audiotext_end_baichuan>",
"<audiotext_pad_baichuan>",
"<audiogen_start_baichuan>",
"<audiogen_end_baichuan>"
],
"eos_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

303621
tokenizer.json Normal file

File diff suppressed because it is too large Load Diff

540
tokenizer_config.json Normal file
View File

@ -0,0 +1,540 @@
{
"add_bos_token": false,
"add_prefix_space": false,
"added_tokens_decoder": {
"151643": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151644": {
"content": "<|im_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151645": {
"content": "<|im_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151646": {
"content": "<|object_ref_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151647": {
"content": "<|object_ref_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151648": {
"content": "<|box_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151649": {
"content": "<|box_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151650": {
"content": "<|quad_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151651": {
"content": "<|quad_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151652": {
"content": "<|vision_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151653": {
"content": "<|vision_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151654": {
"content": "<|vision_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151655": {
"content": "<|image_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151656": {
"content": "<|video_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151657": {
"content": "<tool_call>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151658": {
"content": "</tool_call>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151659": {
"content": "<|fim_prefix|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151660": {
"content": "<|fim_middle|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151661": {
"content": "<|fim_suffix|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151662": {
"content": "<|fim_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151663": {
"content": "<|repo_name|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151664": {
"content": "<|file_sep|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151665": {
"content": "<B_SYS>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151666": {
"content": "<B_USYS>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151667": {
"content": "<C_Q>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151668": {
"content": "<C_A>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151669": {
"content": "<B_FUNC>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151670": {
"content": "<B_CODE>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151671": {
"content": "<B_APE>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": true,
"special": true
},
"151672": {
"content": "<function_calling>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": true,
"special": true
},
"151673": {
"content": "<calc_start>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": true,
"special": true
},
"151674": {
"content": "<calc_end>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": true,
"special": true
},
"151675": {
"content": "<inner_think>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": true,
"special": true
},
"151676": {
"content": "<audio_start_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151677": {
"content": "<audio_end_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151678": {
"content": "<audio_pad_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151679": {
"content": "<img_start_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151680": {
"content": "<img_end_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151681": {
"content": "<img_pad_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151682": {
"content": "<img_newline_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151683": {
"content": "<box_start_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151684": {
"content": "<box_end_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151685": {
"content": "<box_delim_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151686": {
"content": "<ref_start_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151687": {
"content": "<ref_end_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151688": {
"content": "<img_delim_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151689": {
"content": "<polygon_start_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151690": {
"content": "<polygon_end_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151691": {
"content": "<baichuan_pad_token>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151692": {
"content": "<reserved_113>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151693": {
"content": "<audio_delim_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151694": {
"content": "<video_palce_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151695": {
"content": "<video_start_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151696": {
"content": "<video_end_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151697": {
"content": "<audiotext_start_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151698": {
"content": "<audiotext_end_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151699": {
"content": "<audiotext_pad_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151700": {
"content": "<audiogen_start_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151701": {
"content": "<audiogen_end_baichuan>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"additional_special_tokens": [
"<|im_start|>",
"<|im_end|>",
"<|object_ref_start|>",
"<|object_ref_end|>",
"<|box_start|>",
"<|box_end|>",
"<|quad_start|>",
"<|quad_end|>",
"<|vision_start|>",
"<|vision_end|>",
"<|vision_pad|>",
"<|image_pad|>",
"<|video_pad|>",
"<B_SYS>",
"<B_USYS>",
"<C_Q>",
"<C_A>",
"<B_FUNC>",
"<B_CODE>",
"<B_APE>",
"<function_calling>",
"<calc_start>",
"<calc_end>",
"<inner_think>",
"<audio_start_baichuan>",
"<audio_end_baichuan>",
"<audio_pad_baichuan>",
"<img_start_baichuan>",
"<img_end_baichuan>",
"<img_pad_baichuan>",
"<img_newline_baichuan>",
"<box_start_baichuan>",
"<box_end_baichuan>",
"<box_delim_baichuan>",
"<ref_start_baichuan>",
"<ref_end_baichuan>",
"<img_delim_baichuan>",
"<polygon_start_baichuan>",
"<polygon_end_baichuan>",
"<baichuan_pad_token>",
"<reserved_113>",
"<audio_delim_baichuan>",
"<video_start_baichuan>",
"<video_end_baichuan>",
"<video_palce_baichuan>",
"<audiotext_start_baichuan>",
"<audiotext_end_baichuan>",
"<audiotext_pad_baichuan>",
"<audiogen_start_baichuan>",
"<audiogen_end_baichuan>"
],
"bos_token": null,
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
"clean_up_tokenization_spaces": false,
"eos_token": "<|endoftext|>",
"errors": "replace",
"model_max_length": 65536,
"pad_token": "<|endoftext|>",
"split_special_tokens": false,
"tokenizer_class": "Qwen2Tokenizer",
"unk_token": null
}
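
A minimal sketch of how this tokenizer configuration might be consumed, assuming the JSON files above are placed in a local directory (the "./" path is a placeholder): the ChatML-style chat_template renders a message list into the <|im_start|>...<|im_end|> format, and the Baichuan multimodal markers resolve to the IDs declared in added_tokens_decoder.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./")  # placeholder path containing the tokenizer files above

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)  # <|im_start|>system ... <|im_end|> ... <|im_start|>assistant

print(tok.convert_tokens_to_ids("<audio_start_baichuan>"))  # 151676 per added_tokens_decoder
print(tok.eos_token, tok.pad_token)                          # both "<|endoftext|>"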

78
vector_quantize.py Normal file
View File

@ -0,0 +1,78 @@
import torch, random
from torch.nn import functional as F
from torch import nn
import numpy as np
from torch.cuda.amp import autocast
def uniform_init(*shape):
t = torch.zeros(shape)
nn.init.kaiming_uniform_(t)
return t
def cdist(x, y):
x2 = torch.sum(x ** 2, dim=-1, keepdims=True) # (b, 1)
y2 = torch.sum(y ** 2, dim=-1).reshape(1, -1) # (1, c)
xy = torch.einsum('bd,cd->bc', x, y) * -2
return (x2 + y2 + xy).clamp(min=0).sqrt() # (b, c)
def get_sequence_mask(inputs, inputs_length):
if inputs.dim() == 3:
bsz, tgt_len, _ = inputs.size()
else:
bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
unpacking_index = torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1 # convert the cumulative sum to indices
return sequence_mask, unpacking_index
class EuclideanCodebook(nn.Module):
def __init__(
self,
dim,
codebook_size,
init_std=0.02,
):
super().__init__()
self.init_std = init_std
self.dim = dim
self.codebook_size = codebook_size
embed = uniform_init(codebook_size, dim).to(torch.float32)
self.cluster_size = nn.Parameter(torch.ones(codebook_size))
self.embed_avg = nn.Parameter(embed.clone())
self.embed = nn.Parameter(embed)
del embed
@autocast(enabled=True, dtype=torch.float32)
@torch.no_grad()
def forward(self, x):
assert(len(x.shape) == 2)
assert(x.dtype == torch.float32)
embed = self.embed.detach().to(x.device)
dist = -cdist(x, embed) # dist((bs*sl, d), (c, d)) --> (bs*sl, c)
embed_ind = dist.argmax(dim=-1)
quantize = embed[embed_ind] # (bs*sl, d)
return quantize, embed_ind, dist
class VectorQuantize(nn.Module):
def __init__(self, config, *args, **kwargs):
super().__init__(*args, **kwargs)
self.config = config
self.codebook = EuclideanCodebook(dim=config.dim, codebook_size=config.codebook_size)
def forward(self, x, input_length):
batch_size, seq_len, _ = x.shape
mask, unpacking_index = get_sequence_mask(x, input_length)
if x.dtype != torch.float32:
x = x.to(torch.float32)
x = torch.masked_select(x, mask).reshape(-1, self.config.dim) # (bs*sl?, d)
quantize, embed_ind, _ = self.codebook(x)
quantize = torch.index_select(quantize, 0, unpacking_index).view(batch_size, seq_len, self.config.dim)
quantize = torch.where(mask, quantize, 0)
embed_ind = torch.index_select(embed_ind.reshape(-1, 1), 0, unpacking_index).view(batch_size, seq_len, 1)
embed_ind = torch.where(mask, embed_ind, -1).squeeze()
return quantize, embed_ind
def get_output_from_indices(self, indices):
return self.codebook.embed[indices]
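
EuclideanCodebook.forward above snaps each input vector to its nearest codebook entry under Euclidean distance (the negated cdist plus an argmax). A self-contained sketch of that lookup on random data, with torch.cdist standing in for the custom cdist helper and made-up sizes:

import torch

torch.manual_seed(0)
dim, codebook_size, n = 8, 16, 5
codebook = torch.randn(codebook_size, dim)   # plays the role of self.embed
x = torch.randn(n, dim)                      # flattened (bs*sl, dim) inputs

dist = torch.cdist(x, codebook)              # (n, codebook_size) Euclidean distances
embed_ind = dist.argmin(dim=-1)              # nearest codebook entry per vector
quantize = codebook[embed_ind]               # (n, dim) quantized vectors

print(embed_ind.tolist())
print((x - quantize).norm(dim=-1))           # residual distance to the chosen code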

87
visual_modeling_omni.py Normal file
View File

@ -0,0 +1,87 @@
from typing import List, Optional, Tuple, Union
import torch, math
import torch.utils.checkpoint
from torch import nn
import transformers
from flash_attn import flash_attn_varlen_func
from transformers.activations import ACT2FN
from PIL import Image
import io, fire
from torch.nn import functional as F
class OmniVisualEncoder(transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VisionTransformerPretrainedModel):
def __init__(self, config):
super().__init__(config)
self.config_attn_implementation = 'flash_attention_2'
self.gradient_checkpointing = True # force-enable gradient checkpointing
self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
self.merge_size = config.merge_size if hasattr(config, 'merge_size') else 2
del self.merger
def forward(
self,
pixel_values: torch.Tensor,
grid_thw: torch.Tensor,
):
hidden_states = pixel_values.to(self.get_dtype())
grid_thw = grid_thw.to(pixel_values.device)
hidden_states = self.patch_embed(hidden_states)
rotary_pos_emb = self.rot_pos_emb(grid_thw)
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
dim=0, dtype=torch.int32
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
for blk in self.blocks:
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(blk.__call__, hidden_states, cu_seqlens, rotary_pos_emb)
else:
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
return hidden_states
@torch.no_grad()
def fake_input(self, device):
merge_size = max(self.merge_size, self.config.spatial_merge_size)
fake_image = torch.zeros([
1,
self.config.temporal_patch_size,
3,
merge_size // self.config.spatial_merge_size,
self.config.spatial_merge_size,
self.config.patch_size,
merge_size // self.config.spatial_merge_size,
self.config.spatial_merge_size,
self.config.patch_size,
], dtype=torch.float32, device=device)
patches = fake_image.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
flatten_patches = patches.reshape(
merge_size * merge_size, 3 * self.config.temporal_patch_size * self.config.patch_size * self.config.patch_size
)
return [flatten_patches], [(1, merge_size, merge_size)], [1]
class OmniVisualBridge(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
self.hidden_size = config.embed_dim * (self.merge_size**2)
self.ln_q = nn.LayerNorm(config.embed_dim, eps=1e-6)
self.mlp = nn.Sequential(
nn.Linear(self.hidden_size, self.hidden_size),
nn.GELU(),
nn.Linear(self.hidden_size, config.hidden_size),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
return x
if __name__ == '__main__':
fire.Fire()
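
OmniVisualBridge above fuses every merge_size**2 visual tokens into one by LayerNorm-ing them, concatenating along the feature dimension, and projecting into the language-model width. A self-contained sketch of that shape arithmetic with made-up sizes (embed_dim, lm_hidden_size and the token count are arbitrary):

import torch
from torch import nn

embed_dim, lm_hidden_size, merge_size = 32, 64, 2
num_tokens = 8 * merge_size**2                 # must be a multiple of merge_size**2

ln_q = nn.LayerNorm(embed_dim, eps=1e-6)
bridge_width = embed_dim * merge_size**2       # 32 * 4 = 128
mlp = nn.Sequential(
    nn.Linear(bridge_width, bridge_width),
    nn.GELU(),
    nn.Linear(bridge_width, lm_hidden_size),
)

x = torch.randn(num_tokens, embed_dim)         # visual encoder outputs
merged = ln_q(x).view(-1, bridge_width)        # (num_tokens/4, 128): 4 tokens fused into 1
out = mlp(merged)                              # (num_tokens/4, lm_hidden_size)
print(merged.shape, out.shape)                 # torch.Size([8, 128]) torch.Size([8, 64])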

1
vocab.json Normal file

File diff suppressed because one or more lines are too long

604
zero_to_fp32.py Normal file
View File

@ -0,0 +1,604 @@
#!/usr/bin/env python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# application.
#
# example: python zero_to_fp32.py . pytorch_model.bin
import argparse
import torch
import glob
import math
import os
import re
from collections import OrderedDict
from dataclasses import dataclass
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment.
from deepspeed.utils import logger
from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
@dataclass
class zero_model_state:
buffers: dict()
param_shapes: dict()
shared_params: list
ds_version: int
frozen_param_shapes: dict()
frozen_param_fragments: dict()
debug = 0
# load to cpu
device = torch.device('cpu')
def atoi(text):
return int(text) if text.isdigit() else text
def natural_keys(text):
'''
alist.sort(key=natural_keys) sorts in human order
http://nedbatchelder.com/blog/200712/human_sorting.html
(See Toothy's implementation in the comments)
'''
return [atoi(c) for c in re.split(r'(\d+)', text)]
def get_model_state_file(checkpoint_dir, zero_stage):
if not os.path.isdir(checkpoint_dir):
raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
# there should be only one file
if zero_stage <= 2:
file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
elif zero_stage == 3:
file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
if not os.path.exists(file):
raise FileNotFoundError(f"can't find model states file at '{file}'")
return file
def get_checkpoint_files(checkpoint_dir, glob_pattern):
# XXX: need to test that this simple glob rule works for multi-node setup too
ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
if len(ckpt_files) == 0:
raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
return ckpt_files
def get_optim_files(checkpoint_dir):
return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
def get_model_state_files(checkpoint_dir):
return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
def parse_model_states(files):
zero_model_states = []
for file in files:
state_dict = torch.load(file, map_location=device)
if BUFFER_NAMES not in state_dict:
raise ValueError(f"{file} is not a model state checkpoint")
buffer_names = state_dict[BUFFER_NAMES]
if debug:
print("Found buffers:", buffer_names)
# recover just the buffers while restoring them to fp32 if they were saved in fp16
buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
param_shapes = state_dict[PARAM_SHAPES]
# collect parameters that are included in param_shapes
param_names = []
for s in param_shapes:
for name in s.keys():
param_names.append(name)
# update with frozen parameters
frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
if frozen_param_shapes is not None:
if debug:
print(f"Found frozen_param_shapes: {frozen_param_shapes}")
param_names += list(frozen_param_shapes.keys())
# handle shared params
shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
ds_version = state_dict.get(DS_VERSION, None)
frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
z_model_state = zero_model_state(buffers=buffers,
param_shapes=param_shapes,
shared_params=shared_params,
ds_version=ds_version,
frozen_param_shapes=frozen_param_shapes,
frozen_param_fragments=frozen_param_fragments)
zero_model_states.append(z_model_state)
return zero_model_states
def parse_optim_states(files, ds_checkpoint_dir):
total_files = len(files)
state_dicts = []
for f in files:
state_dict = torch.load(f, map_location=device)
# immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
# and also handle the case where it was already removed by another helper script
state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
state_dicts.append(state_dict)
if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
raise ValueError(f"{files[0]} is not a zero checkpoint")
zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
# For ZeRO-2 each param group can have different partition_count as data parallelism for expert
# parameters can be different from data parallelism for non-expert parameters. So we can just
# use the max of the partition_count to get the dp world_size.
if type(world_size) is list:
world_size = max(world_size)
if world_size != total_files:
raise ValueError(
f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
"Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
)
# the groups are named differently in each stage
if zero_stage <= 2:
fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
elif zero_stage == 3:
fp32_groups_key = FP32_FLAT_GROUPS
else:
raise ValueError(f"unknown zero stage {zero_stage}")
if zero_stage <= 2:
fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
elif zero_stage == 3:
# if there is more than one param group, there will be multiple flattened tensors - one
# flattened tensor per group - for simplicity merge them into a single tensor
#
# XXX: could make the script more memory efficient for when there are multiple groups - it
# will require matching the sub-lists of param_shapes for each param group flattened tensor
fp32_flat_groups = [
torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
]
return zero_stage, world_size, fp32_flat_groups
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
"""
Returns fp32 state_dict reconstructed from ds checkpoint
Args:
- ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
"""
print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
optim_files = get_optim_files(ds_checkpoint_dir)
zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
model_files = get_model_state_files(ds_checkpoint_dir)
zero_model_states = parse_model_states(model_files)
print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
if zero_stage <= 2:
return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
exclude_frozen_parameters)
elif zero_stage == 3:
return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
exclude_frozen_parameters)
def _zero2_merge_frozen_params(state_dict, zero_model_states):
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
return
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
frozen_param_fragments = zero_model_states[0].frozen_param_fragments
if debug:
num_elem = sum(s.numel() for s in frozen_param_shapes.values())
print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
wanted_params = len(frozen_param_shapes)
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
print(f'Frozen params: Have {avail_numel} numels to process.')
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
total_params = 0
total_numel = 0
for name, shape in frozen_param_shapes.items():
total_params += 1
unpartitioned_numel = shape.numel()
total_numel += unpartitioned_numel
state_dict[name] = frozen_param_fragments[name]
if debug:
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
def _has_callable(obj, fn):
attr = getattr(obj, fn, None)
return callable(attr)
def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
param_shapes = zero_model_states[0].param_shapes
# Reconstruction protocol:
#
# XXX: document this
if debug:
for i in range(world_size):
for j in range(len(fp32_flat_groups[0])):
print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
# XXX: memory usage doubles here (zero2)
num_param_groups = len(fp32_flat_groups[0])
merged_single_partition_of_fp32_groups = []
for i in range(num_param_groups):
merged_partitions = [sd[i] for sd in fp32_flat_groups]
full_single_fp32_vector = torch.cat(merged_partitions, 0)
merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
avail_numel = sum(
[full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
if debug:
wanted_params = sum([len(shapes) for shapes in param_shapes])
wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
# not asserting if there is a mismatch due to possible padding
print(f"Have {avail_numel} numels to process.")
print(f"Need {wanted_numel} numels in {wanted_params} params.")
# params
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
# out-of-core computing solution
total_numel = 0
total_params = 0
for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
offset = 0
avail_numel = full_single_fp32_vector.numel()
for name, shape in shapes.items():
unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
total_numel += unpartitioned_numel
total_params += 1
if debug:
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
offset += unpartitioned_numel
# Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
# avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
# paddings performed in the code it's almost impossible to predict the exact numbers w/o the
# live optimizer object, so we are checking that the numbers are within the right range
align_to = 2 * world_size
def zero2_align(x):
return align_to * math.ceil(x / align_to)
if debug:
print(f"original offset={offset}, avail_numel={avail_numel}")
offset = zero2_align(offset)
avail_numel = zero2_align(avail_numel)
if debug:
print(f"aligned offset={offset}, avail_numel={avail_numel}")
# Sanity check
if offset != avail_numel:
raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
exclude_frozen_parameters):
state_dict = OrderedDict()
# buffers
buffers = zero_model_states[0].buffers
state_dict.update(buffers)
if debug:
print(f"added {len(buffers)} buffers")
if not exclude_frozen_parameters:
_zero2_merge_frozen_params(state_dict, zero_model_states)
_zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
# recover shared parameters
for pair in zero_model_states[0].shared_params:
if pair[1] in state_dict:
state_dict[pair[0]] = state_dict[pair[1]]
return state_dict
def zero3_partitioned_param_info(unpartitioned_numel, world_size):
remainder = unpartitioned_numel % world_size
padding_numel = (world_size - remainder) if remainder else 0
partitioned_numel = math.ceil(unpartitioned_numel / world_size)
return partitioned_numel, padding_numel
def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
return
if debug:
for i in range(world_size):
num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
wanted_params = len(frozen_param_shapes)
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
print(f'Frozen params: Have {avail_numel} numels to process.')
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
total_params = 0
total_numel = 0
for name, shape in zero_model_states[0].frozen_param_shapes.items():
total_params += 1
unpartitioned_numel = shape.numel()
total_numel += unpartitioned_numel
param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
if debug:
print(
f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
)
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
param_shapes = zero_model_states[0].param_shapes
avail_numel = fp32_flat_groups[0].numel() * world_size
# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
# param, re-consolidating each param, while dealing with padding if any
# merge list of dicts, preserving order
param_shapes = {k: v for d in param_shapes for k, v in d.items()}
if debug:
for i in range(world_size):
print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
wanted_params = len(param_shapes)
wanted_numel = sum(shape.numel() for shape in param_shapes.values())
# not asserting if there is a mismatch due to possible padding
avail_numel = fp32_flat_groups[0].numel() * world_size
print(f"Trainable params: Have {avail_numel} numels to process.")
print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
# params
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
# out-of-core computing solution
offset = 0
total_numel = 0
total_params = 0
for name, shape in param_shapes.items():
unpartitioned_numel = shape.numel()
total_numel += unpartitioned_numel
total_params += 1
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
if debug:
print(
f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
)
# XXX: memory usage doubles here
state_dict[name] = torch.cat(
tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
0).narrow(0, 0, unpartitioned_numel).view(shape)
offset += partitioned_numel
offset *= world_size
# Sanity check
if offset != avail_numel:
raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
exclude_frozen_parameters):
state_dict = OrderedDict()
# buffers
buffers = zero_model_states[0].buffers
state_dict.update(buffers)
if debug:
print(f"added {len(buffers)} buffers")
if not exclude_frozen_parameters:
_zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
_zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
# recover shared parameters
for pair in zero_model_states[0].shared_params:
if pair[1] in state_dict:
state_dict[pair[0]] = state_dict[pair[1]]
return state_dict
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
"""
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
via a model hub.
Args:
- ``checkpoint_dir``: path to the desired checkpoint folder
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
- ``exclude_frozen_parameters``: exclude frozen parameters
Returns:
- pytorch ``state_dict``
Note: this approach may not work if your application doesn't have sufficient free CPU memory and
you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
the checkpoint.
A typical usage might be ::
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
# do the training and checkpoint saving
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
model = model.cpu() # move to cpu
model.load_state_dict(state_dict)
# submit to model hub or save the model to share with others
In this example the ``model`` will no longer be usable in the deepspeed context of the same
application. i.e. you will need to re-initialize the deepspeed engine, since
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
"""
if tag is None:
latest_path = os.path.join(checkpoint_dir, 'latest')
if os.path.isfile(latest_path):
with open(latest_path, 'r') as fd:
tag = fd.read().strip()
else:
raise ValueError(f"Unable to find 'latest' file at {latest_path}")
ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
if not os.path.isdir(ds_checkpoint_dir):
raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False):
"""
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
Args:
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
- ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
- ``exclude_frozen_parameters``: exclude frozen parameters
"""
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
print(f"Saving fp32 state dict to {output_file}")
torch.save(state_dict, output_file)
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
"""
1. Put the provided model to cpu
2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
3. Load it into the provided model
Args:
- ``model``: the model object to update
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
Returns:
- ``model``: modified model
Make sure you have plenty of CPU memory available before you call this function. If you don't
have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
conveniently placed for you in the checkpoint folder.
A typical usage might be ::
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
# submit to model hub or save the model to share with others
Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
of the same application. i.e. you will need to re-initialize the deepspeed engine, since
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
"""
logger.info(f"Extracting fp32 weights")
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
logger.info(f"Overwriting model with fp32 weights")
model = model.cpu()
model.load_state_dict(state_dict, strict=False)
return model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("checkpoint_dir",
type=str,
help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
parser.add_argument(
"output_file",
type=str,
help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
parser.add_argument("-t",
"--tag",
type=str,
default=None,
help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
args = parser.parse_args()
debug = args.debug
convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
args.output_file,
tag=args.tag,
exclude_frozen_parameters=args.exclude_frozen_parameters)
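
A short usage sketch for the conversion helpers above, assuming a DeepSpeed checkpoint folder exists at the placeholder path path/checkpoint-12 and contains a 'latest' tag file; both calls mirror the docstring examples earlier in the script:

from zero_to_fp32 import (
    convert_zero_checkpoint_to_fp32_state_dict,
    get_fp32_state_dict_from_zero_checkpoint,
)

# consolidate the ZeRO shards into a single fp32 state_dict file (same as the CLI entry point)
convert_zero_checkpoint_to_fp32_state_dict(
    "path/checkpoint-12",                    # placeholder checkpoint folder
    "path/checkpoint-12/pytorch_model.bin",  # output fp32 state_dict file
)

# or build the consolidated state_dict in memory (already on CPU) and load it into a model
state_dict = get_fp32_state_dict_from_zero_checkpoint("path/checkpoint-12")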