3260 lines
136 KiB
Python
3260 lines
136 KiB
Python
# coding=utf-8
|
||
# Copyright 2025 The OpenBMB Team. All rights reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
import json
|
||
import logging
|
||
import math
|
||
import os
|
||
import types
|
||
from collections.abc import Iterator
|
||
from copy import deepcopy
|
||
from dataclasses import dataclass
|
||
from threading import Thread
|
||
from typing import List
|
||
from typing import Literal
|
||
from typing import Optional
|
||
from typing import Tuple
|
||
from typing import Union
|
||
|
||
import numpy as np
|
||
import soundfile as sf
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
import torch.nn.utils.parametrize as P
|
||
from huggingface_hub import hf_hub_download
|
||
from PIL import Image
|
||
from torch.nn.utils.parametrizations import weight_norm
|
||
from tqdm import tqdm
|
||
from transformers import AutoProcessor
|
||
from transformers import BertTokenizerFast
|
||
from transformers import LlamaConfig
|
||
from transformers import LlamaModel
|
||
from transformers import LogitsWarper
|
||
from transformers import PreTrainedModel
|
||
from transformers import Qwen2ForCausalLM
|
||
from transformers import Qwen2PreTrainedModel
|
||
from transformers import TextIteratorStreamer
|
||
from transformers import TopKLogitsWarper
|
||
from transformers import TopPLogitsWarper
|
||
from transformers.cache_utils import Cache
|
||
from transformers.cache_utils import DynamicCache
|
||
from transformers.cache_utils import EncoderDecoderCache
|
||
from transformers.cache_utils import StaticCache
|
||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||
from transformers.modeling_outputs import ModelOutput
|
||
from transformers.models.whisper.modeling_whisper import ACT2FN
|
||
from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES
|
||
from transformers.models.whisper.modeling_whisper import WhisperConfig
|
||
from transformers.models.whisper.modeling_whisper import WhisperEncoder
|
||
|
||
try:
|
||
from vector_quantize_pytorch import GroupedResidualFSQ
|
||
from vocos import Vocos
|
||
from vocos.pretrained import instantiate_class
|
||
|
||
_tts_deps = True
|
||
except:
|
||
_tts_deps = False
|
||
|
||
from .configuration_minicpm import ConditionalChatTTSConfig
|
||
from .configuration_minicpm import MiniCPMOConfig
|
||
from .modeling_navit_siglip import SiglipVisionTransformer
|
||
from .resampler import Resampler
|
||
from .utils import NumberToTextConverter
|
||
from .utils import sentence_end
|
||
from .utils import VoiceChecker
|
||
|
||
# Module-level logger; used for warnings such as audio KV-cache resets and a missing reference audio.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
class OmniOutput(ModelOutput):
    """
    Output container for MiniCPM-o: generated text plus optional speaker embedding and audio.

    Every field defaults to None. NOTE(review): the API that returns this object (presumably
    ``chat`` with ``return_dict=True``) is outside this view — confirm against the caller.
    """

    # Generated text: a single string, a list of strings, or an iterator (presumably for streaming).
    text: Optional[Union[str, List[str], Iterator]] = None
    # Speaker embedding tensor; presumably consumed by the TTS module — confirm against caller.
    spk_embeds: Optional[torch.FloatTensor] = None
    # Generated waveform samples, when audio was produced.
    audio_wav: Optional[np.ndarray] = None
    # Sampling rate of ``audio_wav``.
    sampling_rate: Optional[int] = None
|
||
|
||
|
||
class MiniCPMOPreTrainedModel(Qwen2PreTrainedModel):
    """Base pre-trained model for MiniCPM-o: Qwen2's pre-trained-model machinery bound to ``MiniCPMOConfig``."""

    # Configuration class associated with this model family.
    config_class = MiniCPMOConfig
|
||
|
||
|
||
class MiniCPMO(MiniCPMOPreTrainedModel):
|
||
def __init__(self, config):
    """
    Build the language model and the optional vision / audio / TTS sub-modules.

    Sub-modules are created according to ``config.init_vision``, ``config.init_audio`` and
    ``config.init_tts``. The processor is loaded with ``AutoProcessor.from_pretrained`` from
    ``config._name_or_path``.

    Raises:
        ImportError: if ``config.init_tts`` is set but the optional TTS dependencies
            (vector_quantize_pytorch, vocos) could not be imported.
    """
    super().__init__(config)
    self.llm = Qwen2ForCausalLM(config)
    # Replace the LLM's prepare_inputs_for_generation with this file's version (bound to self.llm).
    self.llm.prepare_inputs_for_generation = types.MethodType(prepare_inputs_for_generation, self.llm)  # patch llm

    self.embed_dim = self.llm.config.hidden_size

    # init vision module
    if self.config.init_vision:
        self.vpm = self.init_vision_module()
        self.vision_dim = self.vpm.embed_dim
        self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)

    # init audio module
    if self.config.init_audio:
        self.apm = self.init_audio_module()
        audio_output_dim = int(self.apm.config.encoder_ffn_dim // 4)
        self.audio_avg_pooler = nn.AvgPool1d(self.config.audio_pool_step, stride=self.config.audio_pool_step)
        self.audio_projection_layer = MultiModalProjector(in_dim=audio_output_dim, out_dim=self.embed_dim)
        # Index into the encoder's hidden_states used as the audio feature (-1 = last layer).
        self.audio_encoder_layer = -1

    # init tts module
    if self.config.init_tts:
        # Explicit raise instead of `assert`, which is stripped when Python runs with -O.
        if not _tts_deps:
            raise ImportError("please make sure vector_quantize_pytorch and vocos are installed.")
        self.tts = self.init_tts_module()

    self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)

    # Tokens that end an assistant turn; converted to ids when decoding.
    self.terminators = ["<|im_end|>", "<|endoftext|>"]

    self.default_tts_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}"
    self.force_no_stop = False

    # for stream api
    self.reset_session()
|
||
|
||
def reset_session(self):
    """
    Reset all streaming-session state to its initial values.

    Clears the session id, the per-turn flags and both cached key/value stores
    (the LLM cache and the audio-encoder (apm) cache).
    """
    initial_state = {
        "session_id": None,
        "new_user_msg": True,
        "llm_generated": False,
        "llm_generate_completed": False,
        "llm_past_key_values": None,
        "audio_past_key_values": None,  # apm kv cache
    }
    for attr_name, value in initial_state.items():
        setattr(self, attr_name, value)
|
||
|
||
def init_tts(
    self,
    tts_text_tokenizer_path=None,
    vocos_ckpt_path=None,
):
    """
    Load the TTS text tokenizer and the Vocos vocoder.

    Each asset is first looked up locally and, if not found there, loaded from Hugging Face.

    Args:
        tts_text_tokenizer_path: path of the ChatTTS text tokenizer. Defaults to
            ``<config._name_or_path>/assets/chattts_tokenizer``; when the resolved path does not
            exist locally, the hub id ``openbmb/chattts_tokenizer`` is used instead.
        vocos_ckpt_path: path of the Vocos state-dict file. Defaults to
            ``<config._name_or_path>/assets/Vocos.pt``; when the resolved path does not exist
            locally, ``assets/Vocos.pt`` is downloaded from ``openbmb/MiniCPM-o-2_6``.

    Raises:
        FileNotFoundError: if the resolved Vocos checkpoint still does not exist.
    """
    from .processing_minicpmo import ChatTTSProcessor

    if tts_text_tokenizer_path is None:
        tts_text_tokenizer_path = os.path.join(self.config._name_or_path, "assets/chattts_tokenizer")
    if not os.path.exists(tts_text_tokenizer_path):
        # try from hf model_id
        tts_text_tokenizer_path = "openbmb/chattts_tokenizer"

    tts_text_tokenizer = BertTokenizerFast.from_pretrained(tts_text_tokenizer_path)
    self.tts_processor = ChatTTSProcessor(text_tokenizer=tts_text_tokenizer)

    if vocos_ckpt_path is None:
        vocos_ckpt_path = os.path.join(self.config._name_or_path, "assets/Vocos.pt")
    if not os.path.exists(vocos_ckpt_path):
        vocos_ckpt_path = hf_hub_download(repo_id="openbmb/MiniCPM-o-2_6", subfolder="assets", filename="Vocos.pt")

    # Explicit check (an `assert` would be stripped under -O and gave no useful message).
    if not os.path.exists(vocos_ckpt_path):
        raise FileNotFoundError(f"Vocos checkpoint not found: {vocos_ckpt_path}")
    self.vocos = self.initialize_vocos(vocos_ckpt_path)
|
||
|
||
def initialize_vocos(self, ckpt_path):
    """
    Build the Vocos vocoder and load its weights from ``ckpt_path``.

    Architecture:
    - mel feature extractor: 24 kHz, n_fft=1024, hop=256, 100 mels;
    - 8-layer backbone: dim 512, intermediate 1536;
    - ISTFT head: n_fft=1024, hop=256.

    The module is placed on CUDA when available (CPU otherwise), switched to eval mode and
    kept in float32.

    Args:
        ckpt_path: path to a state-dict file loadable with ``torch.load(weights_only=True)``.

    Returns:
        The initialised ``Vocos`` module.
    """
    feature_extractor = instantiate_class(
        args=(),
        init={
            "class_path": "vocos.feature_extractors.MelSpectrogramFeatures",
            "init_args": {"sample_rate": 24000, "n_fft": 1024, "hop_length": 256, "n_mels": 100},
        },
    )
    backbone = instantiate_class(
        args=(),
        init={
            "class_path": "vocos.models.VocosBackbone",
            "init_args": {"input_channels": 100, "dim": 512, "intermediate_dim": 1536, "num_layers": 8},
        },
    )
    head = instantiate_class(
        args=(),
        init={"class_path": "vocos.heads.ISTFTHead", "init_args": {"dim": 512, "n_fft": 1024, "hop_length": 256}},
    )
    # Use CUDA when present; fall back to CPU instead of failing on CUDA-less machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    vocos = Vocos(feature_extractor, backbone, head).to(device).eval().to(torch.float32)
    # Map storages to CPU so checkpoints saved from GPU tensors also load on CPU-only hosts;
    # load_state_dict then copies them onto the module's device.
    state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True, mmap=True)
    vocos.load_state_dict(state_dict)
    return vocos
|
||
|
||
def init_vision_module(self):
    """
    Create the SigLIP vision encoder (vpm) from ``config.vision_config``.

    The vision config inherits ``flash_attention_2`` from the top-level config; any other
    attention implementation maps to ``eager``. When ``config.drop_vision_last_layer`` is set,
    the final encoder layer is removed. ``embed_dim`` and ``patch_size`` from the embeddings are
    exposed directly on the returned module.
    """
    vision_config = self.config.vision_config
    wants_flash = self.config._attn_implementation == "flash_attention_2"
    vision_config._attn_implementation = "flash_attention_2" if wants_flash else "eager"

    vision_model = SiglipVisionTransformer(vision_config)
    if self.config.drop_vision_last_layer:
        vision_model.encoder.layers = vision_model.encoder.layers[:-1]

    patch_embeddings = vision_model.embeddings
    vision_model.embed_dim = patch_embeddings.embed_dim
    vision_model.patch_size = patch_embeddings.patch_size
    return vision_model
|
||
|
||
def init_resampler(self, embed_dim, vision_dim):
    """
    Create the adaptive Resampler that maps vision features (width ``vision_dim``) to
    ``config.query_num`` query embeddings of width ``embed_dim``.

    One attention head is allotted per 128 channels of ``embed_dim``.
    """
    query_count = self.config.query_num
    head_count = embed_dim // 128
    return Resampler(
        num_queries=query_count,
        embed_dim=embed_dim,
        num_heads=head_count,
        kv_dim=vision_dim,
        adaptive=True,
    )
|
||
|
||
def init_audio_module(self):
    """Create the audio encoder (apm) from ``config.audio_config``."""
    return MiniCPMWhisperEncoder(self.config.audio_config)
|
||
|
||
def init_tts_module(self):
    """Create the conditional ChatTTS module from ``config.tts_config``."""
    return ConditionalChatTTS(self.config.tts_config)
|
||
|
||
def get_input_embeddings(self):
    """Return the input token-embedding module of the wrapped LLM (delegates to ``self.llm``)."""
    language_model = self.llm
    return language_model.get_input_embeddings()
|
||
|
||
def set_input_embeddings(self, value):
    """
    Replace the input token-embedding module of the wrapped LLM.

    Delegates to ``self.llm.set_input_embeddings`` so that the module actually used by the model
    (and returned by ``get_input_embeddings``) is updated. Assigning ``self.llm.embed_tokens`` only
    created an unused attribute on the causal-LM wrapper.
    """
    self.llm.set_input_embeddings(value)
|
||
|
||
def get_output_embeddings(self):
    """Return the output projection (``lm_head``) of the wrapped LLM."""
    language_model = self.llm
    return language_model.lm_head
|
||
|
||
def set_output_embeddings(self, new_embeddings):
    """Install ``new_embeddings`` as the wrapped LLM's output projection (``lm_head``)."""
    setattr(self.llm, "lm_head", new_embeddings)
|
||
|
||
def set_decoder(self, decoder):
    """Replace the wrapped language model (``self.llm``) with ``decoder``."""
    setattr(self, "llm", decoder)
|
||
|
||
def get_decoder(self):
    """Return the wrapped language model (``self.llm``)."""
    decoder = self.llm
    return decoder
|
||
|
||
def subsequent_chunk_mask(
    self,
    size: int,
    chunk_size: int,
    num_left_chunks: int = -1,
    device: torch.device = torch.device("cpu"),
    num_lookhead: int = 0,
) -> torch.Tensor:
    """Create a chunk-wise attention mask of shape (size, size) for a streaming encoder.

    Row ``i`` (query position) may attend to columns (key positions) from the start of its
    allowed left context up to the end of its own chunk plus ``num_lookhead`` extra positions,
    clipped to ``size``.

    Args:
        size (int): number of positions; the mask is ``size x size``.
        chunk_size (int): number of positions per chunk (must be > 0).
        num_left_chunks (int): number of left chunks visible to each position
            <0: use full chunk (all previous chunks are visible)
            >=0: use num_left_chunks (only that many chunks before the current one)
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
        num_lookhead (int): extra positions after the current chunk that are visible
            (expected to be >= 0).

    Returns:
        torch.Tensor: boolean mask of shape ``(size, size)``; True marks visible positions.

    Example:
        ``subsequent_chunk_mask(4, 2)`` returns (True shown as 1)::

            [[1, 1, 0, 0],
             [1, 1, 0, 0],
             [1, 1, 1, 1],
             [1, 1, 1, 1]]
    """
    # Vectorised form of the per-row window: build start/end bounds for every row at once
    # instead of one slice assignment per row (which is one kernel launch each on GPU).
    positions = torch.arange(size, device=device)
    chunk_index = torch.div(positions, chunk_size, rounding_mode="floor")
    # Exclusive end of each row's visible window, clipped to the sequence length.
    ending = torch.clamp((chunk_index + 1) * chunk_size + num_lookhead, max=size)
    if num_left_chunks < 0:
        start = torch.zeros_like(positions)
    else:
        start = torch.clamp((chunk_index - num_left_chunks) * chunk_size, min=0)
    columns = positions.unsqueeze(0)
    return (columns >= start.unsqueeze(1)) & (columns < ending.unsqueeze(1))
|
||
|
||
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
    """
    Computes the output length of the convolutional layers and the output length of the audio encoder

    Returns:
        ``(lengths_after_cnn, lengths_after_pooling)``. The first is ``ceil(input_lengths / 2)``.
        The second is the count of ``config.audio_pool_step``-wide, non-overlapping pooling
        windows, as int32.
    """
    pool_step = self.config.audio_pool_step
    lengths_after_cnn = (input_lengths - 1) // 2 + 1
    window_count = (lengths_after_cnn - pool_step) // pool_step + 1
    return lengths_after_cnn, window_count.to(dtype=torch.int32)
|
||
|
||
def get_vllm_embedding(self, data):
    """
    Compute all visual embeddings, and set into llm embeddings.

    Embeds ``data["input_ids"]`` with the LLM token embedding (multiplied by
    ``llm.config.scale_emb`` when that attribute exists). Per-sample vision features are then
    scattered into the token positions given by ``data["image_bound"]``; this modifies the
    embedding tensor in place.

    Args:
        data: Dict
            tgt_sizes: image size after patch embedding
            pixel_values: image features
            image_bound: position of each picture corresponding to input_ids
            input_ids: full input_ids, include placeholder
            vision_hidden_states: optional precomputed per-sample vision features; when this
                key is present the vision encoder and resampler are skipped
    Returns:
        embedding with vision, vision_hidden_states
    """
    if "vision_hidden_states" not in data:
        dtype = self.llm.model.embed_tokens.weight.dtype
        device = self.llm.model.embed_tokens.weight.device
        tgt_sizes = data["tgt_sizes"]
        pixel_values_list = data["pixel_values"]
        vision_hidden_states = []
        all_pixel_values = []
        # NOTE(review): this list is never read; `img_cnt` is rebound to an int in the
        # splitting loop further down.
        img_cnt = []
        for pixel_values in pixel_values_list:
            img_cnt.append(len(pixel_values))
            # Each image tensor is 3-D here (flatten(end_dim=1) then a 2-D permute); merge its
            # first two dims and move them last so pad_sequence pads along the remaining axis.
            all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values])

        # exist image
        if all_pixel_values:
            # Keep only tensor entries and stack them into an (N, 2) int32 tensor of patch-grid sizes.
            tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
            tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)

            # Largest patch count (rows * cols) over all images; sizes the patch attention mask.
            max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])

            # Pad all images to the same length, then restore a (B, 3, -1, L) layout.
            all_pixel_values = torch.nn.utils.rnn.pad_sequence(
                all_pixel_values, batch_first=True, padding_value=0.0
            )
            B, L, _ = all_pixel_values.shape
            all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)

            # True for each image's real (non-padded) patches.
            patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
            for i in range(B):
                patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True

            vision_batch_size = self.config.vision_batch_size
            all_pixel_values = all_pixel_values.type(dtype)
            # Run the vision encoder in mini-batches of at most vision_batch_size images.
            if B > vision_batch_size:
                hs = []
                for i in range(0, B, vision_batch_size):
                    start_idx = i
                    end_idx = i + vision_batch_size
                    tmp_hs = self.vpm(
                        all_pixel_values[start_idx:end_idx],
                        patch_attention_mask=patch_attn_mask[start_idx:end_idx],
                        tgt_sizes=tgt_sizes[start_idx:end_idx],
                    ).last_hidden_state
                    hs.append(tmp_hs)
                vision_embedding = torch.cat(hs, dim=0)
            else:
                vision_embedding = self.vpm(
                    all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes
                ).last_hidden_state
            vision_embedding = self.resampler(vision_embedding, tgt_sizes)

            # Split the flat per-image embeddings back into per-sample groups ([] for samples without images).
            start = 0
            for pixel_values in pixel_values_list:
                img_cnt = len(pixel_values)
                if img_cnt > 0:
                    vision_hidden_states.append(vision_embedding[start : start + img_cnt])
                    start += img_cnt
                else:
                    vision_hidden_states.append([])
        else:  # no image
            if self.training:
                # Dummy 224x224 image through vpm + resampler. It is added with weight 0 further
                # down, presumably so these modules still take part in the training graph.
                dummy_image = torch.zeros((1, 3, 224, 224), device=device, dtype=dtype)
                tgt_sizes = torch.Tensor(
                    [[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]
                ).type(torch.int32)
                dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes)
            else:
                dummy_feature = []
            for _ in range(len(pixel_values_list)):
                vision_hidden_states.append(dummy_feature)

    else:
        vision_hidden_states = data["vision_hidden_states"]

    if hasattr(self.llm.config, "scale_emb"):
        vllm_embedding = self.llm.model.embed_tokens(data["input_ids"]) * self.llm.config.scale_emb
    else:
        vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])

    # Cast tensor entries to the text-embedding dtype; non-tensor entries (e.g. []) pass through.
    vision_hidden_states = [
        i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
    ]

    bs = len(data["input_ids"])
    for i in range(bs):
        cur_vs_hs = vision_hidden_states[i]
        if len(cur_vs_hs) > 0:
            cur_vllm_emb = vllm_embedding[i]
            cur_image_bound = data["image_bound"][i]
            if len(cur_image_bound) > 0:
                # Token positions of every [start, end) image bound. torch.stack requires all
                # bounds of a sample to span the same number of tokens.
                image_indices = torch.stack(
                    [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
                ).to(vllm_embedding.device)

                # Write the flattened vision features into those rows; cur_vllm_emb is a view
                # of vllm_embedding[i], so this updates vllm_embedding in place.
                cur_vllm_emb.scatter_(
                    0,
                    image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
                    cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
                )
            elif self.training:
                # Zero-weighted term: connects the vision features to the output without changing values.
                cur_vllm_emb += cur_vs_hs[0].mean() * 0

    return vllm_embedding, vision_hidden_states
|
||
|
||
def get_audio_embedding_streaming(self, data):
    r"""
    Extract audio embeddings in a streaming manner using cached key-value pairs.

    This method processes incoming audio features incrementally and stores/updates
    ``self.audio_past_key_values`` for faster inference on subsequent audio frames. It only
    supports batch_size=1 and is intended for streaming scenarios. When the cached length plus
    the new frames would reach the encoder's position-embedding capacity, the cache is reset
    (with a warning) before encoding.

    Args:
        data (dict):
            - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape
              `(batch_size, 80, frames)`. May be missing, ``None`` or empty, in which case
              ``[]`` is returned.
            - **"audio_feature_lens"**: per-sample lists of segment lengths. They are
              concatenated with ``torch.hstack``, so the entries should be tensors.

    Returns:
        List[List[torch.Tensor]]: audio embeddings, one list of segment embeddings per sample.

    Raises:
        ValueError: if the batch size of ``audio_features`` is not 1.
    """
    wavforms = data.get("audio_features")  # (bs, 80, frames) or [], multi audios need filled in advance
    # Treat a missing / None / empty entry as "no audio" (len(None) would raise).
    if wavforms is None or len(wavforms) == 0:
        return []
    audio_feature_lens_raw = data.get("audio_feature_lens", [])  # list, [[x1, x2], [y1], [z1]]

    audio_feature_lens = torch.hstack(audio_feature_lens_raw)
    batch_size, _, max_mel_seq_len = wavforms.shape
    # Explicit check: an `assert` would be stripped under -O and silently corrupt the shared cache.
    if batch_size != 1:
        raise ValueError(f"streaming audio embedding only supports batch_size=1, got {batch_size}")
    # Encoder sequence length: ceil(frames / 2).
    max_seq_len = (max_mel_seq_len - 1) // 2 + 1

    if self.audio_past_key_values is not None:
        cache_length = self.audio_past_key_values[0][0].shape[2]
        apm_max_len = self.apm.embed_positions.weight.shape[0]
        if cache_length + max_seq_len >= apm_max_len:
            logger.warning(
                f"audio_past_key_values length {cache_length + max_seq_len} exceed {apm_max_len}, reset."
            )
            self.audio_past_key_values = None

    audio_outputs = self.apm(wavforms, past_key_values=self.audio_past_key_values, use_cache=True)
    audio_states = audio_outputs.last_hidden_state  # [:, :audio_feat_lengths, :]
    self.audio_past_key_values = audio_outputs.past_key_values

    audio_embeds = self.audio_projection_layer(audio_states)
    # AvgPool1d pools over the last axis, so move the time axis there and back.
    audio_embeds = audio_embeds.transpose(1, 2)
    audio_embeds = self.audio_avg_pooler(audio_embeds)
    audio_embeds = audio_embeds.transpose(1, 2)

    _, num_audio_tokens = self._get_feat_extract_output_lengths(audio_feature_lens)

    # Regroup the flat per-segment rows into per-sample lists, trimming padding.
    final_audio_embeds = []
    idx = 0
    for sample_lens in audio_feature_lens_raw:
        target_audio_embeds = []
        for _ in range(len(sample_lens)):
            target_audio_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :])
            idx += 1
        final_audio_embeds.append(target_audio_embeds)
    return final_audio_embeds
|
||
|
||
def get_audio_embedding(self, data, chunk_length=-1):
    r"""
    Extract full audio embeddings with optional chunk-based attention.

    This method computes embeddings for all audio frames at once, either using full attention
    (when `chunk_length` is -1) or chunk-based attention (when `chunk_length` is a positive
    number). It does not use key-value caching and is suitable for non-streaming inference.

    Args:
        data (dict):
            - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape
              `(batch_size, 80, frames)`. May be missing, ``None`` or empty, in which case
              ``[]`` is returned.
            - **"audio_feature_lens"**: per-sample lists of segment lengths. They are
              concatenated with ``torch.hstack``, so the entries should be tensors.
        chunk_length (int, optional): Determines whether to use full attention (-1) or
            chunk-based attention (>0) during embedding computation.

    Returns:
        List[List[torch.Tensor]]: audio embeddings, one list of segment embeddings per sample.
    """
    wavforms = data.get("audio_features")  # (bs, 80, frames) or [], multi audios need filled in advance
    # Treat a missing / None / empty entry as "no audio". `generate` forwards
    # audio_features=None by default, and len(None) would raise.
    if wavforms is None or len(wavforms) == 0:
        return []
    audio_feature_lens_raw = data.get("audio_feature_lens", [])  # list, [[x1, x2], [y1], [z1]]

    audio_feature_lens = torch.hstack(audio_feature_lens_raw)
    batch_size, _, max_mel_seq_len = wavforms.shape
    # Encoder sequence length: ceil(frames / 2).
    max_seq_len = (max_mel_seq_len - 1) // 2 + 1

    # Create a sequence tensor of shape (batch_size, max_seq_len)
    seq_range = (
        torch.arange(0, max_seq_len, dtype=audio_feature_lens.dtype, device=audio_feature_lens.device)
        .unsqueeze(0)
        .expand(batch_size, max_seq_len)
    )
    lengths_expand = audio_feature_lens.unsqueeze(1).expand(batch_size, max_seq_len)
    # Create mask
    padding_mask = seq_range >= lengths_expand  # 1 for padded values

    # Boolean (batch, 1, query, key) mask: True where a key position is blocked.
    audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
        batch_size, 1, max_seq_len, max_seq_len
    )
    audio_attention_mask = audio_attention_mask_.to(
        dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device
    )

    if chunk_length > 0:
        # chunk_length is scaled by 50 to get encoder frames (presumably seconds at 50 frames/s).
        chunk_num_frame = int(chunk_length * 50)
        chunk_mask = self.subsequent_chunk_mask(
            size=max_seq_len,
            chunk_size=chunk_num_frame,
            num_left_chunks=-1,
            device=audio_attention_mask_.device,
        )
        audio_attention_mask_ = torch.logical_or(audio_attention_mask_, torch.logical_not(chunk_mask))

    # Additive mask: blocked positions become -inf; all other entries are 0.
    audio_attention_mask[audio_attention_mask_] = float("-inf")
    audio_states = self.apm(
        wavforms, output_hidden_states=True, attention_mask=audio_attention_mask
    ).hidden_states[self.audio_encoder_layer]
    audio_embeds = self.audio_projection_layer(audio_states)

    # AvgPool1d pools over the last axis, so move the time axis there and back.
    audio_embeds = audio_embeds.transpose(1, 2)
    audio_embeds = self.audio_avg_pooler(audio_embeds)
    audio_embeds = audio_embeds.transpose(1, 2)

    _, num_audio_tokens = self._get_feat_extract_output_lengths(audio_feature_lens)

    # Regroup the flat per-segment rows into per-sample lists, trimming padding.
    final_audio_embeds = []
    idx = 0
    for sample_lens in audio_feature_lens_raw:
        target_audio_embeds = []
        for _ in range(len(sample_lens)):
            target_audio_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :])
            idx += 1
        final_audio_embeds.append(target_audio_embeds)
    return final_audio_embeds
|
||
|
||
def get_omni_embedding(self, data, input_embeddings, chunk_length=-1, stream_input=False):
    """
    Write audio embeddings into the audio placeholder positions of ``input_embeddings``.

    Args:
        data: dict with optional ``audio_features`` / ``audio_feature_lens``. When audio is
            present it must also contain ``audio_bounds``: per sample, the [start, end) token
            positions of each audio segment.
        input_embeddings: token embeddings indexed as ``[batch, position]``; updated in place
            when audio is present.
        chunk_length: whisper use full attention (-1) or chunk attention (>0)
        stream_input: use streaming audio embedding (KV-cached, batch size 1)

    Returns:
        final embeddings with audio feature

    Raises:
        ValueError: in non-chunked mode, if a segment's embedding length differs from its bound.
    """
    audio_features = data.get("audio_features")
    has_audio = audio_features is not None and len(audio_features) > 0

    if not has_audio:
        if self.training:
            # dummy audio_embeddings: with no audio in the batch there is nothing to index.
            # Instead add a zero-weighted term built from the audio modules' parameters, so
            # they still take part in the autograd graph (with zero gradients). Summing in
            # float32 avoids fp16 overflow to inf, which would turn `* 0` into NaN.
            dummy = input_embeddings.new_zeros((), dtype=torch.float32)
            for module_name in ("apm", "audio_projection_layer"):
                module = getattr(self, module_name, None)
                if module is None:
                    continue
                for param in module.parameters():
                    dummy = dummy + param.sum(dtype=torch.float32).to(dummy.device) * 0
            input_embeddings = input_embeddings + dummy.to(input_embeddings.dtype)
        return input_embeddings

    if stream_input:
        audio_embeddings = self.get_audio_embedding_streaming(data)
    else:
        audio_embeddings = self.get_audio_embedding(data, chunk_length)

    bs = len(input_embeddings)
    assert len(audio_embeddings) == len(input_embeddings)
    if len(audio_embeddings) == 0:
        return input_embeddings

    audio_bounds = data["audio_bounds"]

    if self.config.chunk_input:
        # All segments of a sample are concatenated and consumed bound by bound, in order.
        for i in range(bs):
            if len(audio_embeddings[i]) == 0:
                # No audio segments for this sample (torch.cat would fail on an empty list).
                continue
            audio_embs = torch.cat(audio_embeddings[i], dim=0).to(
                device=input_embeddings.device, dtype=input_embeddings.dtype
            )
            audio_start_pos = 0
            for bound in audio_bounds[i]:
                audio_len = bound[1] - bound[0]
                # Write into sample i (not always row 0) so batched inputs are filled correctly.
                input_embeddings[i, bound[0] : bound[1]] = audio_embs[
                    audio_start_pos : audio_start_pos + audio_len, :
                ]
                audio_start_pos += audio_len
    else:
        # One embedding tensor per bound; lengths must match exactly.
        for i in range(bs):
            audio_embs = audio_embeddings[i]
            bounds = audio_bounds[i]
            for embs, bound in zip(audio_embs, bounds):
                audio_indices = torch.arange(bound[0], bound[1], dtype=torch.long).to(
                    input_embeddings.device
                )

                if embs.shape[0] != len(audio_indices):
                    raise ValueError(
                        f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} "
                        f"to input indices of length {len(audio_indices)}"
                    )
                input_embeddings[i, audio_indices] = embs.to(
                    device=input_embeddings.device, dtype=input_embeddings.dtype
                )

    return input_embeddings
|
||
|
||
def forward(self, data, **kwargs):
    """
    Build multimodal input embeddings from ``data`` and run the wrapped LLM on them.

    Vision features are always merged. Audio features are merged when ``config.init_audio``
    is set.

    Args:
        data: dict consumed by the embedding helpers; must also contain ``position_ids``.
        **kwargs: forwarded to ``self.llm``. Any ``input_ids``, ``inputs_embeds`` or
            ``position_ids`` entries are dropped because they are supplied explicitly here.

    Returns:
        The result of ``self.llm(...)``.
    """
    inputs_embeds, _ = self.get_vllm_embedding(data)

    if self.config.init_audio:
        inputs_embeds = self.get_omni_embedding(
            data, input_embeddings=inputs_embeds, chunk_length=self.config.audio_chunk_length
        )

    # .long() hands back the same tensor when it is already int64.
    position_ids = data["position_ids"].long()

    # compatible with llama factory
    for duplicated_key in ("input_ids", "inputs_embeds", "position_ids"):
        kwargs.pop(duplicated_key, None)

    return self.llm(input_ids=None, position_ids=position_ids, inputs_embeds=inputs_embeds, **kwargs)
|
||
|
||
def _decode(self, inputs_embeds, tokenizer, attention_mask, **kwargs):
    """
    Run non-streaming generation from precomputed input embeddings.

    Generation stops at any of ``self.terminators``; pad id is 0. The result is returned as a
    dict-style generate output that includes hidden states. Extra ``kwargs`` are passed
    straight to ``self.llm.generate``.
    """
    eos_ids = [tokenizer.convert_tokens_to_ids(token) for token in self.terminators]
    return self.llm.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        eos_token_id=eos_ids,
        pad_token_id=0,
        output_hidden_states=True,
        return_dict_in_generate=True,
        **kwargs,
    )
|
||
|
||
def _decode_stream(self, inputs_embeds, tokenizer, **kwargs):
    """
    Start ``self.llm.generate`` on a background thread and return a text streamer.

    Defaults are pad id 0, ``self.terminators`` as EOS, and the streamer itself. Entries in
    ``kwargs`` override these defaults. The thread is not joined here; the caller consumes
    output by iterating the returned ``TextIteratorStreamer``.
    """
    eos_ids = [tokenizer.convert_tokens_to_ids(token) for token in self.terminators]
    text_streamer = TextIteratorStreamer(tokenizer=tokenizer)
    default_kwargs = {
        "inputs_embeds": inputs_embeds,
        "pad_token_id": 0,
        "eos_token_id": eos_ids,
        "streamer": text_streamer,
    }
    worker = Thread(target=self.llm.generate, kwargs={**default_kwargs, **kwargs})
    worker.start()
    return text_streamer
|
||
|
||
def _decode_text(self, result_ids, tokenizer):
    """
    Decode generated token ids into one string per sequence.

    Steps for each sequence:
    1. Truncate at the first terminator (``self.terminators``).
    2. Strip trailing padding (id 0, the ``pad_token_id`` used by ``_decode``).
    3. Drop a leading BOS token if present.

    Only *trailing* pad ids are removed. Id 0 can also be a genuine vocabulary token, so
    removing every 0 would delete real text.

    Args:
        result_ids: iterable of 1-D id sequences (e.g. ``generate(...).sequences``).
        tokenizer: tokenizer used for ``decode`` and terminator lookup.

    Returns:
        List[str]: decoded text, one entry per sequence (empty string for empty output).
    """
    terminator_ids = {tokenizer.convert_tokens_to_ids(i) for i in self.terminators}
    pad_id = 0
    bos_id = getattr(tokenizer, "bos_id", None)
    if bos_id is None:
        bos_id = getattr(tokenizer, "bos_token_id", None)

    result_text = []
    for result in result_ids:
        ids = result.tolist() if hasattr(result, "tolist") else list(result)
        # Everything after the first terminator is padding added once the sequence finished.
        for pos, token_id in enumerate(ids):
            if token_id in terminator_ids:
                ids = ids[:pos]
                break
        # A sequence stopped by other criteria can still end in pad ids without a terminator.
        while ids and ids[-1] == pad_id:
            ids.pop()
        if ids and bos_id is not None and ids[0] == bos_id:
            ids = ids[1:]
        result_text.append(tokenizer.decode(ids))
    return result_text
|
||
|
||
def get_sys_prompt(self, ref_audio=None, mode="default", language="zh"):
    """
    Choose different system prompts according to different tasks

    Args:
        ref_audio: reference audio as a ``np.ndarray``. If not None, the voice-cloning prompts
            are used and the generated voice follows the timbre of the reference audio.
        mode:
            "default": default system prompt, not tied to any task (always English)
            "omni": input video and audio simultaneously
            "audio_assistant": default voice-only mode; the model replies as a helpful
                assistant using ref_audio's voice
            "audio_roleplay": role-play voice-only mode; the model replies with ref_audio's
                voice and role-plays the character based on the audio prompt
            "voice_cloning": TTS mode; the model clones the voice of ref_audio (required)
        language: prompt language; "zh" selects the Chinese prompts, anything else English.
            The model can still choose its response language based on the question language.

    Returns:
        dict: a message ``{"role": "user", "content": [...]}``. The content list holds the
        prompt strings and, when given, ``ref_audio``.

    Raises:
        AssertionError: if ``ref_audio`` is given but is not a ``np.ndarray``.
        ValueError: if ``mode == "voice_cloning"`` and ``ref_audio`` is None.
    """
    if ref_audio is not None:
        assert isinstance(ref_audio, np.ndarray), "ref_audio error"
    if mode == "omni":
        if language == "zh":
            sys_prompt = "你是一个AI助手。你能接受视频,音频和文本输入并输出语音和文本。"
            vc_prompt_prefix = sys_prompt + "模仿输入音频中的声音特征。"
            vc_prompt_suffix = "作为助手,你将使用这种声音风格说话。"
        else:
            sys_prompt = "You are a helpful assistant. You can accept video, audio and text input and output voice and text. "
            vc_prompt_prefix = sys_prompt + "Clone the voice in the provided audio prompt."
            vc_prompt_suffix = "As an assistant, you will speak using this voice style."

        if ref_audio is not None:
            sys_msgs = {"role": "user", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]}
        else:
            sys_msgs = {"role": "user", "content": [sys_prompt]}

        return sys_msgs
    elif mode == "audio_assistant":
        if language == "zh":
            vc_prompt_prefix = "模仿输入音频中的声音特征。"
            vc_prompt_suffix = "作为助手,你将使用这种声音风格说话。"
        else:
            vc_prompt_prefix = "Clone the voice in the provided audio prompt."
            vc_prompt_suffix = "As an assistant, you will speak using this voice style."

        if ref_audio is not None:
            sys_msgs = {"role": "user", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]}
        else:
            logger.warning(
                "Warning: ref_audio is None, speech generation will be performed based on the default voice."
            )
            sys_msgs = {"role": "user", "content": ["Use the <reserved_53> voice.", vc_prompt_suffix]}

        return sys_msgs
    elif mode == "audio_roleplay":
        if language == "zh":
            vc_prompt_prefix = "模仿输入音频中的声音特征。"
            vc_prompt_suffix = "假装你是上述音频中的人物,与我进行对话。"
        else:
            vc_prompt_prefix = "Clone the voice in the provided audio prompt."
            vc_prompt_suffix = "Try to role-play the character based on the audio prompt above."

        if ref_audio is not None:
            sys_msgs = {"role": "user", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]}
        else:
            # Same channel as the audio_assistant branch (was a bare print to stdout).
            logger.warning(
                "Warning: ref_audio is None, speech generation will be performed based on the default voice."
            )
            sys_msgs = {"role": "user", "content": ["Use the <reserved_53> voice.", vc_prompt_suffix]}

        return sys_msgs
    elif mode == "voice_cloning":
        if language == "zh":
            vc_prompt_prefix = "模仿输入音频中的声音特征。"
        else:
            vc_prompt_prefix = "Clone the voice in the provided audio prompt."

        if ref_audio is not None:
            sys_msgs = {"role": "user", "content": [vc_prompt_prefix, ref_audio]}
        else:
            raise ValueError("ref_audio can't be None in voice_cloning mode.")

        return sys_msgs
    else:
        sys_prompt = "You are a helpful assistant. You can accept audio and text input and output voice and text."
        sys_msgs = {"role": "user", "content": [sys_prompt]}

        return sys_msgs
|
||
|
||
def generate(
    self,
    input_ids=None,
    pixel_values=None,
    tgt_sizes=None,
    audio_features=None,
    audio_feature_lens=None,
    image_bound=None,
    audio_bounds=None,
    spk_bounds=None,
    attention_mask=None,
    tokenizer=None,
    vision_hidden_states=None,
    stream=False,
    **kwargs,
):
    """
    Generate text from already-processed multimodal inputs.

    Under ``torch.inference_mode`` this:
    1. merges vision features (from ``pixel_values``/``tgt_sizes``, or from precomputed
       ``vision_hidden_states``) into the token embeddings;
    2. merges audio features into them;
    3. decodes.

    Args:
        input_ids: token ids including multimodal placeholders (required).
        pixel_values / tgt_sizes: per-sample image inputs. They are only used when
            ``vision_hidden_states`` is None.
        audio_features / audio_feature_lens / audio_bounds: audio inputs and their token
            positions; ``None`` means no audio.
        image_bound / spk_bounds: placeholder positions forwarded to the embedding helpers.
        attention_mask: attention mask for non-streaming decoding.
        tokenizer: tokenizer used for terminators and text decoding.
        vision_hidden_states: optional precomputed vision features.
        stream: if True, return a text streamer instead of decoded strings.
        **kwargs: forwarded to the LLM's ``generate``.

    Returns:
        ``(result, outputs)``:
        - with ``stream=True``: ``result`` is the ``TextIteratorStreamer`` and ``outputs`` is ``{}``;
        - otherwise: ``result`` is the list of decoded strings and ``outputs`` is the raw
          ``generate`` output.
    """
    assert input_ids is not None
    # pixel_values may be omitted when precomputed vision_hidden_states are supplied.
    if pixel_values is not None:
        assert len(input_ids) == len(pixel_values)

    model_inputs = {
        "input_ids": input_ids,
        "audio_features": audio_features,
        "audio_feature_lens": audio_feature_lens,
        "image_bound": image_bound,
        "audio_bounds": audio_bounds,
        "spk_bounds": spk_bounds,
    }

    if vision_hidden_states is None:
        model_inputs["pixel_values"] = pixel_values
        model_inputs["tgt_sizes"] = tgt_sizes
    else:
        model_inputs["vision_hidden_states"] = vision_hidden_states

    with torch.inference_mode():
        model_inputs["inputs_embeds"], _ = self.get_vllm_embedding(model_inputs)
        model_inputs["inputs_embeds"] = self.get_omni_embedding(
            model_inputs,
            input_embeddings=model_inputs["inputs_embeds"],
            chunk_length=self.config.audio_chunk_length,
        )

        if stream:
            result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs)
            # if stream return TextIteratorStreamer and output is empty
            outputs = {}
        else:
            outputs = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, **kwargs)
            result = self._decode_text(outputs.sequences, tokenizer)

    return result, outputs
|
||
|
||
def chat(
    self,
    image=None,
    msgs=None,
    tokenizer=None,
    processor=None,
    vision_hidden_states=None,
    max_new_tokens=2048,
    min_new_tokens=0,
    sampling=True,
    max_inp_length=32768,
    stream=False,
    chunk_input=True,
    omni_input=False,
    max_slice_nums=None,
    use_image_id=None,
    use_tts_template=False,
    generate_audio=False,
    return_spk_embed=False,
    return_dict=False,
    output_audio_path=None,
    **kwargs,
):
    """
    Unified chat function

    Args:
        image: used for batch_size=1 VQA. Prefer putting images inside ``msgs``;
            this parameter must be None for batched input.
        msgs: the input chat msgs. Content parts may be text (str), image
            (PIL.Image) or audio (numpy.ndarray). A list of message lists runs
            batched inference, and a JSON string is decoded with ``json.loads``.
        tokenizer: tokenizer for the LLM. When None, ``processor.tokenizer`` is used.
        processor: when None, ``self.processor`` is used; it is loaded via
            ``AutoProcessor`` on first use.
        max_new_tokens: the maximum length of the generation.
        min_new_tokens: the minimum length of the generation (applied only when > 0).
        sampling: use sampling decoding (True) or beam search decoding (False).
        max_inp_length: the maximum length of input.
        stream: return a generator of text chunks. Requires ``sampling=True``.
        chunk_input: whether to split audio into 1s chunks.
        omni_input: join the parts of a message with "" instead of a newline.
        max_slice_nums: control the maximum number of image slices.
        use_image_id: for video understanding or omni understanding, use_image_id should be False.
        use_tts_template: use ``self.default_tts_chat_template``. Forced to True
            once any audio part is encountered.
        generate_audio: synthesize audio output. Implies ``return_dict=True``.
        return_spk_embed: return the speaker embedding. Implies ``return_dict=True``.
        return_dict: return an ``OmniOutput`` instead of plain text.
        output_audio_path: audio save path when ``generate_audio`` is set.
        **kwargs: overrides for the default generation options. Only keys already
            present in the default sampling / beam-search config are honored.

    Returns:
        In stream mode: a generator of text chunks, wrapped in
        ``OmniOutput(text=...)`` when ``return_dict`` is set.
        Otherwise: the answer string (a list of strings when batched), or an
        ``OmniOutput`` carrying text, speaker embeddings, waveform and sampling
        rate when ``return_dict`` is set.
    """
    batched = isinstance(msgs[0], list)

    if generate_audio or return_spk_embed:
        return_dict = True

    # This check does not depend on the individual sample, so run it once up front.
    assert sampling or not stream, "if use stream mode, make sure sampling=True"

    msgs_list = msgs
    images_list = image

    if not batched:
        images_list, msgs_list = [images_list], [msgs_list]
    else:
        assert images_list is None, "Please integrate image to msgs when using batch inference."
        images_list = [None] * len(msgs_list)
    assert len(images_list) == len(msgs_list), "The batch dim of images_list and msgs_list should be the same."

    if processor is None:
        if self.processor is None:
            self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
        processor = self.processor

    if tokenizer is None:
        # Both generate() and the tts_end stripping below need a tokenizer.
        tokenizer = processor.tokenizer

    assert (
        self.config.query_num == processor.image_processor.image_feature_size
    ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
    assert (
        self.config.patch_size == processor.image_processor.patch_size
    ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
    assert (
        self.config.use_image_id == processor.image_processor.use_image_id
    ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
    assert (
        self.config.slice_config.max_slice_nums == processor.image_processor.max_slice_nums
    ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
    assert (
        self.config.slice_mode == processor.image_processor.slice_mode
    ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."

    prompts_lists = []
    input_images_list = []
    input_audios_list = []
    audio_parts_list = []

    # `sample_image` / `sample_msgs` avoid shadowing the `image` / `msgs` parameters.
    for sample_image, sample_msgs in zip(images_list, msgs_list):
        if isinstance(sample_msgs, str):
            sample_msgs = json.loads(sample_msgs)
        copy_msgs = deepcopy(sample_msgs)

        assert len(sample_msgs) > 0, "msgs is empty"

        if sample_image is not None and isinstance(copy_msgs[0]["content"], str):
            copy_msgs[0]["content"] = [sample_image, copy_msgs[0]["content"]]

        images = []
        audios = []
        audio_parts = []
        for i, msg in enumerate(copy_msgs):
            role = msg["role"]
            content = msg["content"]
            assert role in ["system", "user", "assistant"]
            if i == 0:
                assert role in ["user", "system"], "The role of first msg should be user"
            if isinstance(content, str):
                content = [content]
            cur_msgs = []
            for c in content:
                if isinstance(c, Image.Image):
                    images.append(c)
                    cur_msgs.append("(<image>./</image>)")
                elif isinstance(c, np.ndarray):  # audio
                    audios.append(c)
                    audio_parts.append(i)  # index of the message this audio belongs to
                    cur_msgs.append("(<audio>./</audio>)")
                    # NOTE(review): this flag is not reset per sample, so in a batch every
                    # later sample also gets the TTS template. Confirm this is intended.
                    use_tts_template = True
                elif isinstance(c, str):
                    cur_msgs.append(c)
            msg["content"] = "".join(cur_msgs) if omni_input else "\n".join(cur_msgs)

        prompts_lists.append(
            processor.tokenizer.apply_chat_template(
                copy_msgs,
                tokenize=False,
                add_generation_prompt=True,
                chat_template=self.default_tts_chat_template if use_tts_template else None,
            )
        )
        input_images_list.append(images)
        input_audios_list.append(audios)
        audio_parts_list.append(audio_parts)

    inputs = processor(
        prompts_lists,
        input_images_list,
        input_audios_list,
        audio_parts_list,
        max_slice_nums=max_slice_nums,
        use_image_id=use_image_id,
        chunk_input=chunk_input,
        return_tensors="pt",
        max_length=max_inp_length,
    ).to(self.device)

    if sampling:
        generation_config = {
            "top_p": 0.8,
            "top_k": 100,
            "temperature": 0.7,
            "do_sample": True,
            "repetition_penalty": 1.01,
        }
    else:
        generation_config = {
            "num_beams": 3,
            "repetition_penalty": 1.2,
        }

    if min_new_tokens > 0:
        generation_config["min_new_tokens"] = min_new_tokens

    generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())

    # image_sizes is removed before calling generate(). Tolerate processors that do not emit it.
    inputs.pop("image_sizes", None)
    with torch.inference_mode():
        res, outputs = self.generate(
            **inputs,
            tokenizer=tokenizer,
            max_new_tokens=max_new_tokens,
            vision_hidden_states=vision_hidden_states,
            stream=stream,
            **generation_config,
        )

    if stream:

        def stream_gen():
            # Strip terminator strings from every streamed chunk.
            for text in res:
                for term in self.terminators:
                    text = text.replace(term, "")
                yield text

        if return_dict:
            return OmniOutput(text=stream_gen())
        return stream_gen()

    spk_embeds = wav_numpy = sr = None

    answer = res if batched else res[0]

    if use_tts_template and generate_audio:
        mel_spec = self._generate_mel_spec(inputs, outputs, answer)
        wav_numpy, sr = self.decode_mel_to_audio(mel_spec, output_audio_path)

    if return_spk_embed:
        spk_embeds = self._get_last_spk_embeds(inputs, outputs)

    if isinstance(answer, list):
        answer = [i.replace(tokenizer.tts_end, "") for i in answer]
    else:
        answer = answer.replace(tokenizer.tts_end, "")

    if return_dict:
        return OmniOutput(text=answer, spk_embeds=spk_embeds, audio_wav=wav_numpy, sampling_rate=sr)
    return answer
|
||
|
||
@torch.inference_mode()
def streaming_prefill(
    self,
    session_id,
    msgs,
    tokenizer,
    omni_input=True,
    max_slice_nums=None,
    ls_temperature=1.0,
    **kwargs,
):
    """
    Streaming video/audio input and output audio stream, Only support batch_size=1

    Prefills one message into the LLM KV cache (``self.llm_past_key_values``).

    The first call of a session resets the session state and applies the TTS
    chat template. Later calls append the raw content. When a new user turn
    follows an assistant generation, the required turn markers are inserted.

    Args:
        session_id: conversation id. Note: new connection should use a new session_id.
        msgs: a list holding exactly one message dict with ``role`` and ``content``.
            ``content`` may be a str or a list of str / PIL.Image / numpy.ndarray
            (audio) parts.
        tokenizer: used to apply the chat template on the first call of a session.
        omni_input: join content parts with "" (True) or with newlines (False).
        max_slice_nums: maximum image slices (1 when None).
        ls_temperature: not used by this method; kept for API compatibility.
        **kwargs: not used by this method.
    """
    assert session_id is not None
    # A session id different from the current one starts a new conversation.
    self.is_first = self.session_id is None or session_id != self.session_id

    assert len(msgs) == 1
    copy_msgs = deepcopy(msgs)
    msg = copy_msgs[0]

    assert msg["role"] in ["system", "user", "assistant"]

    content = msg["content"]
    if isinstance(content, str):
        # Treat plain text as one part, as chat() does. Iterating a str would split it into characters.
        content = [content]

    images = []
    audios = []
    cur_msgs = []
    for c in content:
        if isinstance(c, Image.Image):
            images.append(c)
            cur_msgs.append("(<image>./</image>)")
        elif isinstance(c, np.ndarray):  # audio
            audios.append(c)
            cur_msgs.append("(<audio>./</audio>)")
        elif isinstance(c, str):
            cur_msgs.append(c)
        else:
            logger.error("Invalid content type: %s", type(c))

    # Join the collected parts themselves (not the `omni_input` flag).
    cur_contents = "".join(cur_msgs) if omni_input else "\n".join(cur_msgs)
    if not self.is_first and self.new_user_msg and msg["role"] == "user":  # new user add im_start
        if self.llm_generated:
            if self.llm_generate_completed:
                msg["content"] = "<|im_end|>\n<|im_start|>user\n" + cur_contents
            else:  # break llm gen, add tts_eos
                msg["content"] = "<|tts_eos|><|im_end|>\n<|im_start|>user\n" + cur_contents
        else:
            msg["content"] = "<|im_start|>user\n" + cur_contents
        self.new_user_msg = False
    else:
        msg["content"] = cur_contents

    if msg["role"] in ["system", "assistant"]:
        self.new_user_msg = True
        self.audio_past_key_values = None  # apm kv cache

    if self.is_first:
        # init past_key_values
        logger.info(f"new session_id: {session_id}, reset kv cache")
        self.reset_session()
        self.session_id = session_id

        prompt = tokenizer.apply_chat_template(
            copy_msgs, tokenize=False, add_generation_prompt=False, chat_template=self.default_tts_chat_template
        )
        add_special_tokens = True  # add bos
    else:
        prompt = copy_msgs[0]["content"]
        add_special_tokens = False

    if self.processor is None:
        # Same lazy initialisation as chat(). Streaming may be the first entry point used.
        self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)

    model_inputs = self.processor(
        [prompt],
        [images],
        [audios],
        max_slice_nums=1 if max_slice_nums is None else max_slice_nums,
        use_image_id=False,
        chunk_input=True,
        return_tensors="pt",
        max_length=None,
        sampling_rate=16000,
        add_special_tokens=add_special_tokens,
    ).to(self.device)

    # 1. prepare input embeddings
    model_inputs["inputs_embeds"], _ = self.get_vllm_embedding(model_inputs)
    # get audio embedding with audio_past_key_values
    inputs_embeds = self.get_omni_embedding(
        model_inputs, input_embeddings=model_inputs["inputs_embeds"], stream_input=True
    )

    if self.is_first:
        # clean audio_past_key_values after first prefill
        self.audio_past_key_values = None

    if self.llm_past_key_values is not None:
        cache_length = self.llm_past_key_values[0][0].shape[2]
    else:
        cache_length = 0

    attention_mask = torch.ones((1, cache_length + inputs_embeds.shape[1]), dtype=torch.bool, device=self.device)

    # 2. do prefill and predict listen/speak label
    outputs = self.llm(
        past_key_values=self.llm_past_key_values,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=None,
        use_cache=True,
        return_dict=True,
    )
    self.llm_past_key_values = outputs["past_key_values"]
|
||
|
||
@torch.inference_mode()
def streaming_generate(
    self,
    session_id,
    tokenizer,
    max_new_tokens=512,
    min_new_tokens=0,
    sampling=True,
    generate_audio=True,
    enable_regenerate=False,
    **kwargs,
):
    """
    Streaming video/audio input and output audio stream

    Generates the assistant reply for content previously prefilled with
    ``streaming_prefill``, continuing from ``self.llm_past_key_values``.

    Args:
        session_id: not used by this method; kept for API compatibility.
        tokenizer: LLM tokenizer, used for the generation prompt and the terminator ids.
        max_new_tokens: cap on generated LLM tokens.
        min_new_tokens: forwarded to generation.
        sampling: pick the sampling defaults (True) or the beam-search defaults (False).
        generate_audio: when True, return the generator from
            ``_generate_mel_spec_audio_streaming``. Otherwise return the text
            chunk generator from ``llm_generate_chunk``.
        enable_regenerate: forwarded to ``_generate_mel_spec_audio_streaming``.
        **kwargs: overrides for keys already present in the default generation config.

    Raises:
        RuntimeError: if there is no LLM KV cache, i.e. ``streaming_prefill`` has not run.
    """
    if sampling:
        generation_config = {
            "top_p": 0.8,
            "top_k": 100,
            "temperature": 0.7,
            "do_sample": True,
            "repetition_penalty": 1.01,
        }
    else:
        generation_config = {
            "num_beams": 3,
            "repetition_penalty": 1.2,
        }
    generation_config["min_new_tokens"] = min_new_tokens
    generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())

    if self.llm_past_key_values is None:
        # Without a prefilled cache there is nothing to continue from.
        raise RuntimeError("streaming_generate requires a prior streaming_prefill call (LLM KV cache is empty).")

    # do generate
    # reset buffer
    self.new_user_msg = True
    self.llm_generated = True
    self.llm_generate_completed = False
    self.audio_past_key_values = None  # apm kv cache

    terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
    generate_prompt = "<|im_end|>\n<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>"
    # Put the prompt on the model device (matching attention_mask) instead of the default CUDA device.
    input_ids = tokenizer(generate_prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].to(self.device)

    spk_start_idx = torch.where(input_ids[0] == tokenizer.spk_start_id)[0]
    spk_end_idx = torch.where(input_ids[0] == tokenizer.spk_end_id)[0]
    spk_bounds = [
        torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
    ]  # List[Tensor], (1,2)

    cache_length = self.llm_past_key_values[0][0].shape[2]
    attention_mask = torch.ones((1, cache_length + input_ids.shape[1]), dtype=torch.bool, device=self.device)

    generation_config["max_new_tokens"] = max_new_tokens
    streamer = self.llm_generate_chunk(input_ids, attention_mask, tokenizer, terminators, generation_config)

    if generate_audio:
        return self._generate_mel_spec_audio_streaming(
            spk_bounds, streamer, output_chunk_size=25, enable_regenerate=enable_regenerate
        )
    return streamer
|
||
|
||
def llm_generate_chunk(self, input_ids, attention_mask, tokenizer, terminators, generation_config):
    """Generate LLM tokens a few at a time and yield decoded text chunks.

    Each step calls ``self.llm.generate`` for at most 3 new tokens, continuing
    from ``self.llm_past_key_values``, and stores the updated cache back there.
    Tokens whose decoded text ends in the Unicode replacement character
    (an incomplete multi-byte character) are held back and prepended to the
    next step's tokens.

    Args:
        input_ids: prompt ids for the first step.
        attention_mask: mask covering the cache plus ``input_ids``.
        tokenizer: LLM tokenizer, used for decoding.
        terminators: token ids that end generation.
        generation_config: generation options. ``max_new_tokens`` (default 2048)
            is read from it as the overall cap. The dict itself is not modified.

    Yields:
        dict with ``"text"`` (newly decoded text, possibly ""). The first chunk
        also carries ``"hidden_states"`` from ``self.llm.generate``.
    """
    # U+FFFD is produced when decoding stops in the middle of a multi-byte character.
    replacement_char = "\ufffd"

    def check_uncompleted_token(ids):
        """Return the longest prefix length of ``ids`` whose decoded text does not end in U+FFFD."""
        end = len(ids)
        cur_text = tokenizer.decode(ids)
        # endswith() is also safe on empty text, where cur_text[-1] would raise.
        while cur_text.endswith(replacement_char):
            end -= 1
            if end == 0:
                break
            cur_text = tokenizer.decode(ids[:end])
        return end

    # Copy so the caller's dict is not mutated by the pop below.
    generation_config = dict(generation_config)
    max_new_tokens = int(generation_config.pop("max_new_tokens", 2048))
    new_len = 0
    first_chunk = True
    left_ids = None

    while True:
        outputs = self.llm.generate(
            input_ids=input_ids,
            past_key_values=self.llm_past_key_values,
            attention_mask=attention_mask,
            use_cache=True,
            max_new_tokens=3,  # reduce first token delay
            pad_token_id=0,
            output_hidden_states=first_chunk,
            return_dict_in_generate=True,
            eos_token_id=terminators,
            **generation_config,
        )
        eos = outputs.sequences[0, -1].item() in terminators

        input_len = input_ids.shape[1]
        cur_ids = outputs.sequences[:, input_len:]
        new_len += cur_ids.shape[1]

        if left_ids is not None and left_ids.shape[1] > 0:
            cur_ids = torch.cat([left_ids, cur_ids], dim=1)
        end = check_uncompleted_token(cur_ids[0])
        left_ids = cur_ids[:, end:]
        cur_ids = cur_ids[:, :end]
        text = self._decode_text(cur_ids, tokenizer)[0] if end > 0 else ""

        # Continue from the updated cache with only the last generated token as input.
        self.llm_past_key_values = outputs.past_key_values
        input_ids = outputs.sequences[:, -1:]
        cache_length = self.llm_past_key_values[0][0].shape[2]
        attention_mask = torch.ones((1, cache_length + input_ids.shape[1]), dtype=torch.bool, device=self.device)

        res = {"text": text}
        if first_chunk:
            res["hidden_states"] = outputs.hidden_states
            first_chunk = False
        yield res

        if eos:
            self.llm_generate_completed = True
            break
        if new_len >= max_new_tokens:
            logger.debug(f"LLM generation {new_len} exceeds max_new_tokens({max_new_tokens}), break.")
            break
|
||
|
||
def prepare_tts_text(self, text):
    """Format ``text`` as a fixed-length TTS condition string.

    The text is tokenized with the TTS text tokenizer.

    - Fewer than ``self.tts.streaming_text_reserved_len`` tokens: the text is
      followed by one ``[Etts]`` marker and enough ``[PAD]`` markers to fill
      the reserved length.
    - Otherwise: the ids are cut to the reserved length and decoded back to
      text, with no end marker.

    The result is wrapped as ``[Stts]`` + ``[spk_emb]`` * num_spk_embs + text
    + padding + ``[Ptts]``.

    Returns:
        (formatted_text, kept_token_count)
    """
    text_tokenizer = self.tts_processor.text_tokenizer
    token_ids = text_tokenizer.encode(text, add_special_tokens=False)
    reserved = self.tts.streaming_text_reserved_len

    if len(token_ids) >= reserved:
        token_ids = token_ids[:reserved]
        text = text_tokenizer.decode(token_ids, add_special_tokens=False)
        padding = ""
    else:
        missing = reserved - len(token_ids)
        padding = "[Etts]" + "[PAD]" * (missing - 1)

    speaker_slots = "[spk_emb]" * self.tts.num_spk_embs
    formatted = f"[Stts]{speaker_slots}{text}{padding}[Ptts]"
    return formatted, len(token_ids)
|
||
|
||
def get_tts_text_start_token_ids(self):
    """Return the TTS text-tokenizer ids for ``[Stts]`` followed by ``num_spk_embs`` ``[spk_emb]`` markers.

    The ids are moved to ``self.tts.device``, where the TTS prefill consumes
    them. Previously they went to the default CUDA device, which breaks CPU
    runs and multi-GPU placements.

    Returns:
        The ``"input_ids"`` tensor produced by the tokenizer with ``return_tensors="pt"``.
    """
    prefix = "[Stts]" + "[spk_emb]" * self.tts.num_spk_embs
    encoded = self.tts_processor.text_tokenizer(prefix, return_tensors="pt", add_special_tokens=False)
    return encoded["input_ids"].to(self.tts.device)
|
||
|
||
def _build_streaming_mask(self, tts_tokens_len):
|
||
tts_sequence_full_length = (
|
||
1 + self.tts.num_spk_embs * self.tts.use_speaker_embedding + self.tts.streaming_text_reserved_len + 1
|
||
)
|
||
streaming_attention_mask = torch.zeros(tts_sequence_full_length, dtype=torch.int8)
|
||
streaming_attention_mask[0 : 1 + 1 + tts_tokens_len + 1] = 1
|
||
streaming_attention_mask[-1] = 1
|
||
return streaming_attention_mask
|
||
|
||
def _get_last_spk_embeds(self, inputs, outputs):
|
||
last_hidden_states = [hs[-1] for hs in outputs.hidden_states]
|
||
|
||
# batch = 1
|
||
last_hidden_states = torch.vstack([i[0] for i in last_hidden_states])
|
||
|
||
# last spk
|
||
spk_bound = inputs["spk_bounds"][0][-1]
|
||
|
||
spk_embeds = last_hidden_states[spk_bound[0] : spk_bound[1]]
|
||
return spk_embeds
|
||
|
||
def _generate_mel_spec(self, inputs, outputs, text, output_chunk_size=25, tts_max_new_tokens=2048):
    """Synthesize a mel spectrogram for the spoken part of an LLM answer (non-streaming).

    Steps:
    1. The speaker embedding is taken from the LLM hidden states at the last
       ``spk_bounds`` span.
    2. The text between ``<|tts_bos|>`` and ``<|tts_eos|>`` is formatted by
       ``prepare_tts_text``.
    3. The formatted text is prefilled into the TTS model in chunks of
       ``self.tts.streaming_text_chunk_size`` tokens. After each chunk, one
       generate step of up to ``output_chunk_size`` audio tokens runs.
    4. If the TTS model has not finished when the chunks run out, generation
       continues until it finishes or more than ``tts_max_new_tokens`` audio
       tokens exist.

    Args:
        inputs: processor outputs containing ``spk_bounds``.
        outputs: LLM generate outputs with ``hidden_states``.
        text: decoded LLM answer.
        output_chunk_size: max new audio tokens per TTS generate call.
        tts_max_new_tokens: length cap, checked only in the continuation loop.

    Returns:
        ``self.tts.decode_to_mel_specs`` applied to all generated audio ids.
    """
    spk_embeds = self._get_last_spk_embeds(inputs, outputs)

    text = text.split("<|tts_bos|>")[-1]
    gen_text = text.split("<|tts_eos|>")[0]
    tts_text, tts_token_lens = self.prepare_tts_text(gen_text)
    tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False)
    # Build int64 ids directly on the TTS device. The old float Tensor round-trip
    # and hard-coded "cuda" device were unnecessary and broke non-default placements.
    tts_input_ids = torch.tensor(tts_inputs, dtype=torch.long, device=self.tts.device).unsqueeze(0)
    streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)

    logits_warpers, logits_processors = gen_logits(
        num_code=626, top_P=self.tts.top_p, top_K=self.tts.top_k, repetition_penalty=self.tts.repetition_penalty
    )

    spk_slots = self.tts.use_speaker_embedding * self.tts.num_spk_embs
    condition_length = 1 + spk_slots + self.tts.streaming_text_reserved_len + 1
    chunk_size = self.tts.streaming_text_chunk_size

    dtype = self.tts.emb_text.weight.dtype
    num_heads = self.tts.config.num_attention_heads
    kv_shape = (1, num_heads, condition_length - 1, self.tts.config.hidden_size // num_heads)
    past_key_values = [
        (
            torch.zeros(kv_shape, dtype=dtype, device=self.tts.device),
            torch.zeros(kv_shape, dtype=dtype, device=self.tts.device),
        )
        for _ in range(self.tts.config.num_hidden_layers)
    ]

    audio_input_ids = torch.zeros(1, condition_length, self.tts.num_vq, dtype=torch.long, device=self.tts.device)

    def _generate_audio_chunk(audio_ids, kv_cache):
        """Run one TTS generate call producing up to ``output_chunk_size`` audio tokens."""
        return self.tts.generate(
            input_ids=audio_ids,
            past_key_values=kv_cache,
            streaming_tts_text_mask=streaming_tts_text_mask,
            max_new_token=output_chunk_size,
            force_no_stop=self.force_no_stop,
            temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device),
            eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device),
            logits_warpers=logits_warpers,
            logits_processors=logits_processors,
        )

    eos_lab = False
    for chunk_idx in range(math.ceil(condition_length / chunk_size)):
        if chunk_idx == 0:
            # The first chunk also covers [Stts] and the speaker-embedding slots.
            begin = 0
            end = chunk_size + 1 + spk_slots
        else:
            begin = chunk_idx * chunk_size + 1 + spk_slots
            end = min((chunk_idx + 1) * chunk_size + 1 + spk_slots, condition_length - 1)

        if end - begin > 0:
            text_input_ids = tts_input_ids[:, begin:end]
            position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)

            if begin == 0:
                past_key_values = self.tts.prefill_text(
                    input_ids=text_input_ids,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    lm_spk_emb_last_hidden_states=spk_embeds,
                )
            else:
                past_key_values = self.tts.prefill_text(
                    input_ids=text_input_ids, position_ids=position_ids, past_key_values=past_key_values
                )

        tts_outputs = _generate_audio_chunk(audio_input_ids, past_key_values)
        audio_input_ids = tts_outputs.audio_input_ids
        past_key_values = tts_outputs.past_key_values

        if tts_outputs.finished:
            logger.debug("Generation finished.")
            eos_lab = True
            break

    if not eos_lab:
        logger.debug("eos_lab False, Generation continue.")
        while True:
            tts_outputs = _generate_audio_chunk(audio_input_ids, past_key_values)
            audio_input_ids = tts_outputs.audio_input_ids
            past_key_values = tts_outputs.past_key_values

            if tts_outputs.finished:
                logger.debug("Generation finished.")
                break
            if tts_outputs.new_ids.shape[1] > tts_max_new_tokens:
                logger.debug(f"Generation length > {tts_max_new_tokens}, stopped.")
                break

    mel_spec = self.tts.decode_to_mel_specs(tts_outputs.new_ids)
    return mel_spec
|
||
|
||
def _linear_overlap_add2_wav(self, frames: List[torch.Tensor], overlap: int):
|
||
"""
|
||
Merge two audio waveforms with smooth in streaming audio generation.
|
||
Borrowed some codes from `https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py`
|
||
"""
|
||
assert len(frames) == 2
|
||
device = frames[0].device
|
||
dtype = frames[0].dtype
|
||
# shape = frames[0].shape[:-1]
|
||
|
||
frame0_length = frames[0].shape[-1]
|
||
frame1_length = frames[1].shape[-1]
|
||
total_size = frame0_length + frame1_length - overlap
|
||
weight_len = max(frame0_length, frame1_length) + overlap
|
||
t = torch.linspace(0, 1, weight_len + 2, device=device, dtype=dtype)[1:-1]
|
||
weight = 0.5 - (t - 0.5).abs()
|
||
|
||
sum_weight = torch.zeros(total_size, device=device, dtype=dtype)
|
||
out = torch.zeros(total_size, device=device, dtype=dtype)
|
||
offset: int = 0
|
||
|
||
out[offset : offset + frame0_length] += weight[-frame0_length:] * frames[0]
|
||
sum_weight[offset : offset + frame0_length] += weight[-frame0_length:]
|
||
offset += frame0_length - overlap
|
||
out[offset : offset + frame1_length] += weight[:frame1_length] * frames[1]
|
||
sum_weight[offset : offset + frame1_length] += weight[:frame1_length]
|
||
|
||
assert sum_weight.min() > 0
|
||
out = out / sum_weight
|
||
return out[:frame0_length], out[frame0_length:]
|
||
|
||
def _generate_mel_spec_audio_streaming(
|
||
self,
|
||
spk_bounds,
|
||
streamer,
|
||
output_chunk_size=25,
|
||
spk_embeds=None,
|
||
prev_seg_text_ids=None,
|
||
prev_seg_text_left="",
|
||
prev_seg_audio_ids=None,
|
||
enable_regenerate=False,
|
||
):
|
||
# get spk_embedding
|
||
gen_text = ""
|
||
tts_text = ""
|
||
new_segment_gen = False
|
||
if spk_embeds is None:
|
||
spk_bound = spk_bounds[0][-1]
|
||
r = next(streamer)
|
||
txt = r["text"]
|
||
gen_text += txt.split("<|tts_eos|>")[0]
|
||
tts_text, tts_token_lens = self.prepare_tts_text(gen_text)
|
||
last_hidden_states = r["hidden_states"][0][-1][0] # output: (input_seq_len, dim)
|
||
spk_embeds = last_hidden_states[spk_bound[0] : spk_bound[1]]
|
||
|
||
# init past_key_values
|
||
logits_warpers, logits_processors = gen_logits(
|
||
num_code=626, top_P=self.tts.top_p, top_K=self.tts.top_k, repetition_penalty=self.tts.repetition_penalty
|
||
)
|
||
condition_length = (
|
||
1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs + self.tts.streaming_text_reserved_len + 1
|
||
)
|
||
tts_start_token_len = 1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs
|
||
dtype = self.tts.emb_text.weight.dtype
|
||
past_key_values = [
|
||
(
|
||
torch.zeros(
|
||
1,
|
||
self.tts.config.num_attention_heads,
|
||
condition_length - 1,
|
||
self.tts.config.hidden_size // self.tts.config.num_attention_heads,
|
||
dtype=dtype,
|
||
device=self.tts.device,
|
||
),
|
||
torch.zeros(
|
||
1,
|
||
self.tts.config.num_attention_heads,
|
||
condition_length - 1,
|
||
self.tts.config.hidden_size // self.tts.config.num_attention_heads,
|
||
dtype=dtype,
|
||
device=self.tts.device,
|
||
),
|
||
)
|
||
for _ in range(self.tts.config.num_hidden_layers)
|
||
]
|
||
audio_input_ids = torch.zeros(1, condition_length, self.tts.num_vq, dtype=torch.long, device=self.tts.device)
|
||
|
||
# prefill prev segment for smooth
|
||
chunk_idx = 0
|
||
new_ids_len = 0
|
||
prev_text_len = 0
|
||
if prev_seg_text_ids is not None and prev_seg_audio_ids is not None:
|
||
tts_token_lens = prev_seg_text_ids.shape[1]
|
||
# assert tts_token_lens % self.tts.streaming_text_chunk_size == 0
|
||
streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
|
||
position_ids = torch.arange(
|
||
0, tts_token_lens + tts_start_token_len, dtype=torch.long, device=self.tts.device
|
||
).unsqueeze(0)
|
||
|
||
text_input_ids = self.get_tts_text_start_token_ids()
|
||
text_input_ids = torch.cat([text_input_ids, prev_seg_text_ids], dim=1)
|
||
past_key_values = self.tts.prefill_text(
|
||
input_ids=text_input_ids,
|
||
position_ids=position_ids,
|
||
past_key_values=past_key_values,
|
||
lm_spk_emb_last_hidden_states=spk_embeds,
|
||
)
|
||
past_key_values = self.tts.prefill_audio_ids(
|
||
input_ids=prev_seg_audio_ids[:, :-1, :],
|
||
# not prefill last id, which will be input_id of next generation
|
||
past_key_values=past_key_values,
|
||
streaming_tts_text_mask=streaming_tts_text_mask,
|
||
)
|
||
|
||
# update init
|
||
chunk_idx += int(tts_token_lens / self.tts.streaming_text_chunk_size)
|
||
audio_input_ids = torch.cat([audio_input_ids, prev_seg_audio_ids], dim=1)
|
||
text = self.tts_processor.text_tokenizer.decode(prev_seg_text_ids[0].tolist(), add_special_tokens=False)
|
||
|
||
gen_text += text
|
||
gen_text += prev_seg_text_left
|
||
prev_text_len = len(gen_text) # takecare the position
|
||
new_ids_len += prev_seg_audio_ids.shape[1]
|
||
|
||
prev_wav = None
|
||
eos_lab = False
|
||
stop = False
|
||
shift_len = 180
|
||
voice_checker = VoiceChecker()
|
||
number_converter = NumberToTextConverter()
|
||
lang = None
|
||
gen_text_raw = gen_text
|
||
for t, r in enumerate(streamer):
|
||
t += 1
|
||
txt = r["text"]
|
||
txt = txt.split("<|tts_eos|>")[0]
|
||
gen_text_raw += txt
|
||
if t == 1 and txt == "" and prev_seg_text_ids is not None:
|
||
logger.warning("New segment is empty, generation finished.")
|
||
return
|
||
if t <= 2: # do just one time, more token greater certainty
|
||
lang = number_converter.detect_language(gen_text_raw)
|
||
gen_text += number_converter.replace_numbers_with_text(txt, lang).replace("*", "") # markdown **
|
||
|
||
# TODO speed up
|
||
tts_text, tts_token_lens = self.prepare_tts_text(gen_text)
|
||
|
||
if tts_token_lens >= self.tts.streaming_text_reserved_len - shift_len:
|
||
end_c = sentence_end(txt)
|
||
if end_c:
|
||
end_c_idx = gen_text.rfind(end_c)
|
||
assert end_c_idx != -1
|
||
text_left = gen_text[end_c_idx + 1 :]
|
||
gen_text = gen_text[: end_c_idx + 1]
|
||
tts_text, tts_token_lens = self.prepare_tts_text(gen_text)
|
||
new_segment_gen = True
|
||
logger.debug(
|
||
f"tts_text tokens {tts_token_lens} exceed {self.tts.streaming_text_reserved_len - shift_len}, starting a new segment generation"
|
||
)
|
||
break
|
||
|
||
if tts_token_lens >= (chunk_idx + 1) * self.tts.streaming_text_chunk_size:
|
||
|
||
# do prefill and generate
|
||
if chunk_idx == 0:
|
||
begin = 0
|
||
end = (chunk_idx + 1) * self.tts.streaming_text_chunk_size + tts_start_token_len
|
||
else:
|
||
begin = chunk_idx * self.tts.streaming_text_chunk_size + tts_start_token_len
|
||
end = min(
|
||
(chunk_idx + 1) * self.tts.streaming_text_chunk_size + tts_start_token_len, condition_length - 1
|
||
)
|
||
|
||
tts_input_ids = self.tts_processor.text_tokenizer(
|
||
tts_text, return_tensors="pt", add_special_tokens=False
|
||
)["input_ids"].cuda()
|
||
text_input_ids = tts_input_ids[:, begin:end]
|
||
streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
|
||
position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
|
||
|
||
past_key_values = self.tts.prefill_text(
|
||
input_ids=text_input_ids,
|
||
position_ids=position_ids,
|
||
past_key_values=past_key_values,
|
||
lm_spk_emb_last_hidden_states=spk_embeds if chunk_idx == 0 else None,
|
||
)
|
||
outputs = self.tts.generate(
|
||
input_ids=audio_input_ids,
|
||
past_key_values=past_key_values,
|
||
streaming_tts_text_mask=streaming_tts_text_mask,
|
||
max_new_token=output_chunk_size,
|
||
force_no_stop=self.force_no_stop,
|
||
temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device),
|
||
eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device),
|
||
logits_warpers=logits_warpers,
|
||
logits_processors=logits_processors,
|
||
)
|
||
audio_input_ids = (
|
||
outputs.audio_input_ids
|
||
) # [1,seq_len,4] seq_len=tts.streaming_text_reserved_len + 3 + len(new_ids)
|
||
past_key_values = outputs.past_key_values
|
||
chunk_idx += 1
|
||
|
||
mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids[:, max(new_ids_len - 4, 0) :, :])
|
||
new_ids_len = outputs.new_ids.shape[1] # [1, seq_len, 4]
|
||
|
||
wav_np, sr = self.decode_mel_to_audio(mel_spec) # [1,100,50] -> [50*256]
|
||
|
||
if enable_regenerate:
|
||
if prev_wav is not None:
|
||
check_wav_np = wav_np[2048:].cpu().numpy() # 2*4*256(hop)
|
||
check_mel = mel_spec[0, :, 8:].cpu().numpy() # 2*4
|
||
else:
|
||
check_wav_np = wav_np.cpu().numpy()
|
||
check_mel = mel_spec[0].cpu().numpy()
|
||
if enable_regenerate and voice_checker.is_bad(check_wav_np, check_mel, chunk_size=2560):
|
||
voice_checker.reset()
|
||
# regenerate
|
||
N = output_chunk_size if prev_wav is None else output_chunk_size * 2
|
||
past_kv = []
|
||
for i in range(len(past_key_values)):
|
||
past_kv.append(
|
||
(
|
||
past_key_values[i][0][:, :, :-N, :], # .clone(),
|
||
past_key_values[i][1][:, :, :-N, :], # .clone(),
|
||
)
|
||
)
|
||
outputs = self.tts.generate(
|
||
input_ids=audio_input_ids[:, :-N, :],
|
||
past_key_values=past_kv,
|
||
streaming_tts_text_mask=streaming_tts_text_mask,
|
||
max_new_token=N,
|
||
force_no_stop=self.force_no_stop,
|
||
temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device),
|
||
eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device),
|
||
logits_warpers=logits_warpers,
|
||
logits_processors=logits_processors,
|
||
)
|
||
audio_input_ids = outputs.audio_input_ids
|
||
past_key_values = outputs.past_key_values
|
||
|
||
new_ids_len -= N
|
||
mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids[:, new_ids_len:, :])
|
||
new_ids_len = outputs.new_ids.shape[1] # [1, seq_len, 4]
|
||
wav_np, sr = self.decode_mel_to_audio(mel_spec)
|
||
|
||
if prev_wav is not None:
|
||
wav_y = wav_np[: len(prev_wav)]
|
||
prev_wav = wav_np[len(prev_wav) :]
|
||
cur_text = gen_text_raw[prev_text_len:]
|
||
prev_text_len = len(gen_text_raw)
|
||
yield OmniOutput(text=cur_text, audio_wav=wav_y, sampling_rate=sr)
|
||
|
||
else:
|
||
prev_wav = wav_np
|
||
else:
|
||
# smooth wav
|
||
if prev_wav is not None:
|
||
wav_np, prev_wav = self._linear_overlap_add2_wav(
|
||
[prev_wav, wav_np], overlap=512 * 4
|
||
) # tts_hop256*2
|
||
cur_text = gen_text_raw[prev_text_len:]
|
||
prev_text_len = len(gen_text_raw)
|
||
yield OmniOutput(text=cur_text, audio_wav=wav_np, sampling_rate=sr)
|
||
|
||
else:
|
||
prev_wav = wav_np
|
||
|
||
if outputs.finished:
|
||
logger.debug("Generation finished.")
|
||
eos_lab = True
|
||
break
|
||
|
||
if not eos_lab and tts_text:
|
||
logger.debug("eos_lab False, Generation continue.")
|
||
|
||
if chunk_idx == 0:
|
||
begin = 0
|
||
else:
|
||
begin = chunk_idx * self.tts.streaming_text_chunk_size + tts_start_token_len
|
||
end = tts_token_lens + tts_start_token_len + 1 # 1 for [Etts]
|
||
if end > begin:
|
||
tts_input_ids = self.tts_processor.text_tokenizer(
|
||
tts_text, return_tensors="pt", add_special_tokens=False
|
||
)["input_ids"].cuda()
|
||
text_input_ids = tts_input_ids[:, begin:end]
|
||
streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
|
||
position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
|
||
|
||
past_key_values = self.tts.prefill_text(
|
||
input_ids=text_input_ids,
|
||
position_ids=position_ids,
|
||
past_key_values=past_key_values,
|
||
lm_spk_emb_last_hidden_states=spk_embeds if chunk_idx == 0 else None,
|
||
)
|
||
|
||
while True:
|
||
# temp = [0.1, 0.3, 0.1, 0.3] if chunk_idx < 21 else [0.1] * self.tts.num_vq
|
||
outputs = self.tts.generate(
|
||
input_ids=audio_input_ids,
|
||
past_key_values=past_key_values,
|
||
streaming_tts_text_mask=streaming_tts_text_mask,
|
||
max_new_token=output_chunk_size,
|
||
force_no_stop=self.force_no_stop,
|
||
# temperature=torch.tensor([0.1] * self.tts.num_vq, dtype=torch.float, device=self.tts.device),
|
||
temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device),
|
||
eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device),
|
||
logits_warpers=logits_warpers,
|
||
logits_processors=logits_processors,
|
||
)
|
||
audio_input_ids = outputs.audio_input_ids
|
||
past_key_values = outputs.past_key_values
|
||
chunk_idx += 1
|
||
|
||
mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids[:, max(new_ids_len - 4, 0) :, :])
|
||
new_ids_len = outputs.new_ids.shape[1] # [1, seq_len, 4]
|
||
|
||
wav_np, sr = self.decode_mel_to_audio(mel_spec)
|
||
|
||
if enable_regenerate:
|
||
if prev_wav is not None:
|
||
check_wav_np = wav_np[2048:].cpu().numpy() # 2*4*256(hop)
|
||
check_mel = mel_spec[0, :, 8:].cpu().numpy() # 2*4
|
||
else:
|
||
check_wav_np = wav_np.cpu().numpy()
|
||
check_mel = mel_spec[0].cpu().numpy()
|
||
if enable_regenerate and voice_checker.is_bad(check_wav_np, check_mel, chunk_size=2560):
|
||
voice_checker.reset()
|
||
# regenerate
|
||
N = output_chunk_size if prev_wav is None else output_chunk_size * 2
|
||
past_kv = []
|
||
for i in range(len(past_key_values)):
|
||
past_kv.append(
|
||
(
|
||
past_key_values[i][0][:, :, :-N, :], # .clone(),
|
||
past_key_values[i][1][:, :, :-N, :], # .clone(),
|
||
)
|
||
)
|
||
outputs = self.tts.generate(
|
||
input_ids=audio_input_ids[:, :-N, :],
|
||
past_key_values=past_kv,
|
||
streaming_tts_text_mask=streaming_tts_text_mask,
|
||
max_new_token=N,
|
||
force_no_stop=self.force_no_stop,
|
||
temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device),
|
||
eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device),
|
||
logits_warpers=logits_warpers,
|
||
logits_processors=logits_processors,
|
||
)
|
||
audio_input_ids = outputs.audio_input_ids
|
||
past_key_values = outputs.past_key_values
|
||
|
||
new_ids_len -= N
|
||
mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids[:, new_ids_len:, :])
|
||
new_ids_len = outputs.new_ids.shape[1] # [1, seq_len, 4]
|
||
wav_np, sr = self.decode_mel_to_audio(mel_spec)
|
||
|
||
if prev_wav is not None:
|
||
wav_y = wav_np[: len(prev_wav)]
|
||
prev_wav = wav_np[len(prev_wav) :]
|
||
cur_text = gen_text_raw[prev_text_len:]
|
||
prev_text_len = len(gen_text_raw)
|
||
yield OmniOutput(text=cur_text, audio_wav=wav_y, sampling_rate=sr)
|
||
else:
|
||
prev_wav = wav_np
|
||
else:
|
||
# smooth wav
|
||
if prev_wav is not None:
|
||
wav_np, prev_wav = self._linear_overlap_add2_wav(
|
||
[prev_wav, wav_np], overlap=512 * 4
|
||
) # tts_hop256*2
|
||
cur_text = gen_text_raw[prev_text_len:]
|
||
prev_text_len = len(gen_text_raw)
|
||
yield OmniOutput(text=cur_text, audio_wav=wav_np, sampling_rate=sr)
|
||
else:
|
||
prev_wav = wav_np
|
||
|
||
if outputs.finished:
|
||
logger.debug("Generation finished.")
|
||
break
|
||
if outputs.new_ids.shape[1] > 2048:
|
||
stop = True
|
||
logger.debug("Generation length > 2048, stopped.")
|
||
break
|
||
|
||
if prev_wav is not None:
|
||
cur_text = gen_text_raw[prev_text_len:]
|
||
yield OmniOutput(text=cur_text, audio_wav=prev_wav, sampling_rate=sr) # yield last chunk wav without smooth
|
||
|
||
if new_segment_gen and not stop:
|
||
logger.debug(
|
||
f"tts_text tokens {tts_token_lens} exceed {self.tts.streaming_text_reserved_len - shift_len}, start a new segment generation"
|
||
)
|
||
tid_len = 5 # self.tts.streaming_text_chunk_size
|
||
prev_seg_text_ids = tts_input_ids[:, end - 1 - tid_len : end - 1] # exclude last Etts
|
||
aid_len = 50 # int(tid_len * new_ids_len / tts_token_lens)
|
||
prev_seg_audio_ids = outputs.new_ids[:, -aid_len:, :]
|
||
|
||
result = self._generate_mel_spec_audio_streaming(
|
||
spk_bounds,
|
||
streamer,
|
||
output_chunk_size,
|
||
spk_embeds,
|
||
prev_seg_text_ids,
|
||
text_left,
|
||
prev_seg_audio_ids,
|
||
enable_regenerate=enable_regenerate,
|
||
)
|
||
for res in result:
|
||
yield res
|
||
|
||
def decode_mel_to_audio(self, mel_spec, output_path=""):
    """Turn a mel spectrogram into a waveform using `self.vocos`.

    Args:
        mel_spec: mel spectrogram tensor; it is cast to float32 before being
            passed to `self.vocos.decode`. Callers in this file pass tensors
            noted as `[1, 100, frames]` (assumption — TODO confirm for others).
        output_path: when non-empty, the waveform is also written to this path
            with `soundfile.write` and a log line is emitted.

    Returns:
        A `(waveform, sample_rate)` pair. `waveform` is the squeezed decoder
        output moved to CPU — a torch tensor, not a numpy array — and
        `sample_rate` is the constant 24000.
    """
    sample_rate = 24000
    with torch.inference_mode():
        decoded = self.vocos.decode(mel_spec.float())
        waveform = decoded.cpu().squeeze()

    if output_path:
        sf.write(output_path, waveform.numpy(), samplerate=sample_rate)
        logger.info(f"Audio saved to {output_path}")

    return waveform, sample_rate
|
||
|
||
|
||
# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer, with use_cache added for streaming inference
|
||
class MiniCPMWhisperEncoderLayer(nn.Module):
    """Pre-LayerNorm Whisper encoder layer that threads a KV cache through self-attention.

    Structure: LayerNorm -> self-attention -> dropout -> residual, then
    LayerNorm -> fc1 -> activation -> dropout -> fc2 -> dropout -> residual.
    Unlike the upstream layer, the attention module receives `past_key_values`,
    and the cache it hands back can be returned to the caller.
    """

    def __init__(self, config: WhisperConfig, layer_idx: int = None):
        super().__init__()
        self.embed_dim = config.d_model

        # Attention implementation is chosen by the config (eager / sdpa / ...).
        attention_cls = WHISPER_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
            layer_idx=layer_idx,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_head_mask: torch.Tensor,
        output_attentions: bool = False,
        past_key_values: Optional[EncoderDecoderCache] = None,
        use_cache: Optional[bool] = False,
    ) -> torch.Tensor:
        r"""
        Run one encoder layer.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, embed_dim)`):
                Layer input.
            attention_mask (`torch.FloatTensor` of shape `(batch_size, 1, tgt_len, src_len)`):
                Additive attention mask; padded positions carry large negative values.
            layer_head_mask (`torch.FloatTensor` of shape `(encoder_attention_heads,)`):
                Per-head mask forwarded to the attention module.
            output_attentions (`bool`, *optional*):
                If True, the attention weights are appended to the output tuple.
            past_key_values (`EncoderDecoderCache`, *optional*):
                Cache handed to the attention module (as `past_key_value`).
            use_cache (`bool`, *optional*):
                If True, the cache returned by the attention module is appended
                to the output tuple.

        Returns:
            A tuple (despite the `torch.Tensor` annotation):
            `(hidden_states[, attn_weights][, past_key_values])`, where the
            optional entries appear only when requested.
        """
        # --- self-attention sub-block ---
        normed = self.self_attn_layer_norm(hidden_states)
        attn_out, attn_weights, past_key_values = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            past_key_value=past_key_values,
        )
        attn_out = nn.functional.dropout(attn_out, p=self.dropout, training=self.training)
        hidden_states = hidden_states + attn_out

        # --- feed-forward sub-block ---
        ffn = self.final_layer_norm(hidden_states)
        ffn = self.activation_fn(self.fc1(ffn))
        ffn = nn.functional.dropout(ffn, p=self.activation_dropout, training=self.training)
        ffn = self.fc2(ffn)
        ffn = nn.functional.dropout(ffn, p=self.dropout, training=self.training)
        hidden_states = hidden_states + ffn

        # fp16 can overflow; pull inf/nan back into the representable range.
        if hidden_states.dtype == torch.float16:
            if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
                limit = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-limit, max=limit)

        extras = []
        if output_attentions:
            extras.append(attn_weights)
        if use_cache:
            extras.append(past_key_values)
        return (hidden_states, *extras)
|
||
|
||
|
||
# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoder, with use_cache added for streaming inference
|
||
class MiniCPMWhisperEncoder(WhisperEncoder):
    """Whisper encoder whose layers can carry a self-attention KV cache.

    This allows audio features to be encoded chunk by chunk: each call can pass
    the cache from the previous call, and positional embeddings continue from
    the cached length.
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [MiniCPMWhisperEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)]
        )

    @staticmethod
    def _select_positional_embeddings(embed_pos: torch.Tensor, seq_len: int, offset: int) -> torch.Tensor:
        """Return `seq_len` rows of the positional table `embed_pos`, starting at row `offset`.

        If the requested span runs past the end of the table, a warning is logged.
        The missing rows are then filled by repeating the table's last row, so the
        result always has exactly `seq_len` rows. This also holds when `offset`
        itself is already past the end of the table.
        """
        max_positions = embed_pos.shape[0]
        if seq_len + offset <= max_positions:
            return embed_pos[offset : offset + seq_len, :]

        logger.warning("seems the audio is longer than 30s. repeating the last part of the audio")
        front = embed_pos[offset:, :]  # empty when offset >= max_positions
        n_repeat = seq_len - front.shape[0]
        tail = torch.repeat_interleave(embed_pos[-1, :].unsqueeze(0), n_repeat, dim=0)
        return torch.cat((front, tail), dim=0)

    def forward(
        self,
        input_features,
        attention_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        past_key_values: Optional[EncoderDecoderCache] = None,
        use_cache: Optional[bool] = None,
    ):
        r"""
        Forward pass of the Whisper encoder with optional KV caching.

        Args:
            input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
                Log-mel features. They are cast to the dtype/device of `conv1`,
                passed through `conv1`/`conv2` with GELU, and then transposed to
                `(batch_size, frames, hidden_size)`.
            attention_mask (`torch.Tensor`, *optional*):
                Passed unchanged to every layer's self-attention as its
                attention mask (see `MiniCPMWhisperEncoderLayer.forward`).
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Per-layer head mask; the first dimension must equal the number of layers.
            output_attentions (`bool`, *optional*):
                Return every layer's attention weights. Defaults to `config.output_attentions`.
            output_hidden_states (`bool`, *optional*):
                Return the hidden states before each layer, plus the final
                normalised output. Defaults to `config.output_hidden_states`.
            return_dict (`bool`, *optional*):
                Return a `BaseModelOutputWithPast` instead of a tuple. Defaults
                to `config.use_return_dict`.
            past_key_values (`EncoderDecoderCache`, `DynamicCache`, or legacy list/tuple, *optional*):
                Cache from a previous call. It is only consulted when
                `use_cache=True`. A `DynamicCache` or legacy list/tuple is wrapped
                into an `EncoderDecoderCache`. `None` starts a fresh cache.
            use_cache (`bool`, *optional*):
                When True, positional embeddings start at the cached length, and
                the updated cache is returned in the `past_key_values` field of
                the `BaseModelOutputWithPast`.

        Positional embeddings:
            If the cached length plus the current length exceeds the positional
            table, a warning is logged and the last table row is repeated for
            the overflow. This happens both with and without caching.

        Returns:
            If `return_dict` is True, a `BaseModelOutputWithPast` with
            `last_hidden_state`, `hidden_states`, `attentions` and
            `past_key_values` (the cache when `use_cache`, else `None`).
            Otherwise, a tuple `(last_hidden_state[, hidden_states][, attentions])`.
            The cache is not included in the tuple form.

        Example (chunked encoding, with `return_dict` enabled):
            cache = None
            for chunk in feature_chunks:
                out = encoder(chunk, use_cache=True, past_key_values=cache)
                cache = out.past_key_values
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Ignore copy
        input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        seq_len = inputs_embeds.shape[1]

        past_key_values_length = 0
        if use_cache:
            # Normalise every accepted cache format to an EncoderDecoderCache.
            if past_key_values is None:
                past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
            elif isinstance(past_key_values, (list, tuple)):
                past_key_values = EncoderDecoderCache(DynamicCache.from_legacy_cache(past_key_values), DynamicCache())
            elif isinstance(past_key_values, DynamicCache):
                past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
            past_key_values_length = past_key_values.self_attention_cache.get_usable_length(seq_len)

        embed_pos = self._select_positional_embeddings(
            self.embed_positions.weight, seq_len, past_key_values_length
        )

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        # Cache produced by the last executed layer. It is initialised so that it
        # is defined even when there are no layers or every layer is dropped.
        next_encoder_cache = past_key_values if use_cache else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            if to_drop:
                # A skipped layer leaves hidden states and cache untouched and
                # contributes no attention map.
                if output_attentions:
                    all_attentions = all_attentions + (None,)
                continue

            layer_head_mask = head_mask[idx] if head_mask is not None else None
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    output_attentions,
                    past_key_values,
                    use_cache,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    layer_head_mask=layer_head_mask,
                    output_attentions=output_attentions,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_encoder_cache = layer_outputs[2 if output_attentions else 1]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
            past_key_values=next_encoder_cache,
        )
|
||
|
||
|
||
# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py`
|
||
class ConvNeXtBlock(nn.Module):
    """1-D ConvNeXt residual block (the Vocos variant, as used by ChatTTS).

    Pipeline on a `(B, C, T)` input:
    depthwise dilated conv -> LayerNorm -> Linear -> GELU -> Linear
    -> optional per-channel scale `coef` -> add the input back.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        kernel: int,
        dilation: int,
        layer_scale_init_value: float = 1e-6,
    ):
        # ConvNeXt Block copied from Vocos.
        super().__init__()
        # For odd kernels this padding keeps the time length unchanged.
        same_padding = dilation * (kernel // 2)
        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=kernel,
            padding=same_padding,
            dilation=dilation,
            groups=dim,  # depthwise: one filter per channel
        )
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

        # Learnable layer scale; disabled (None) for a non-positive init value.
        if layer_scale_init_value > 0:
            self.coef = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
        else:
            self.coef = None

    def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
        # `cond` is accepted for interface compatibility and is not used.
        shortcut = x

        h = self.dwconv(x)
        h.transpose_(1, 2)  # (B, C, T) -> (B, T, C): LayerNorm/Linear act on the last dim
        h = self.pwconv2(self.act(self.pwconv1(self.norm(h))))
        if self.coef is not None:
            h *= self.coef
        h.transpose_(1, 2)  # (B, T, C) -> (B, C, T)

        return h + shortcut
|
||
|
||
|
||
# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py`
|
||
class GFSQ(nn.Module):
    """Wrapper around `GroupedResidualFSQ` with `G` groups of `R` residual quantizers.

    `forward` maps features to flattened code indices; `_embed` maps such
    indices back to quantized features. When `transpose` is True, tensors use a
    channel-first `(B, C, T)` layout at this interface.
    """

    def __init__(
        self,
        dim: int,
        levels: List[int],
        G: int,
        R: int,
        eps=1e-5,
        transpose=True,
    ):
        super(GFSQ, self).__init__()
        self.quantizer = GroupedResidualFSQ(
            dim=dim,
            levels=list(levels),
            num_quantizers=R,
            groups=G,
        )
        self.n_ind = math.prod(levels)
        self.eps = eps
        self.transpose = transpose
        self.G = G
        self.R = R

    def _embed(self, x: torch.Tensor):
        """Decode code indices back into quantized features.

        `x` is viewed as `(B, T, G, R)` after the optional transpose, so it is
        presumably `(B, G*R, T)` when `transpose` is True (TODO confirm).
        """
        if self.transpose:
            x = x.transpose(1, 2)
        x = x.view(x.size(0), x.size(1), self.G, self.R).permute(2, 0, 1, 3)
        feat = self.quantizer.get_output_from_indices(x)
        return feat.transpose_(1, 2) if self.transpose else feat

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return super().__call__(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Quantize features `x` and return the flattened code indices.

        Returns a tensor laid out `(B, G*R, T)` when `transpose` is True, else
        `(B, T, G*R)`. The quantizer's indices are treated as `(G, B, T, R)`,
        inferred from the permute below (TODO confirm).
        """
        if self.transpose:
            # Out-of-place: the caller's tensor must not be transposed under it.
            x = x.transpose(1, 2)
        _, ind = self.quantizer(x)
        ind = ind.permute(1, 2, 0, 3).contiguous()
        ind = ind.view(ind.size(0), ind.size(1), -1)
        return ind.transpose_(1, 2) if self.transpose else ind
|
||
|
||
|
||
# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py`
|
||
class DVAEDecoder(nn.Module):
    """Conv stem, a stack of `n_layer` ConvNeXt blocks, and a 1x1 output conv.

    Maps `(B, idim, T)` to `(B, odim, T)`. In this file, `DVAE` uses one
    instance as its encoder and another as its decoder.
    """

    def __init__(
        self,
        idim: int,
        odim: int,
        n_layer=12,
        bn_dim=64,
        hidden=256,
        kernel=7,
        dilation=2,
        up=False,
    ):
        super().__init__()
        self.up = up  # stored only; not read by this class
        # Bottlenecked input stem: idim -> bn_dim -> hidden (length-preserving).
        self.conv_in = nn.Sequential(
            nn.Conv1d(idim, bn_dim, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Conv1d(bn_dim, hidden, kernel_size=3, stride=1, padding=1),
        )
        blocks = [ConvNeXtBlock(hidden, hidden * 4, kernel, dilation) for _ in range(n_layer)]
        self.decoder_block = nn.ModuleList(blocks)
        self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)

    def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor:
        # Layout throughout is (B, C, T).
        h = self.conv_in(x)
        for block in self.decoder_block:
            h = block(h, conditioning)
        return self.conv_out(h)
|
||
|
||
|
||
# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/dvae.py`
|
||
class DVAE(nn.Module):
    """Discrete VAE over 100-bin mel spectrograms (ChatTTS).

    `forward(inp, "encode")` turns a mel spectrogram into GFSQ code indices.
    Any other mode decodes code indices back into a mel spectrogram.
    """

    def __init__(
        self,
    ):
        super().__init__()

        # Per-mel-bin scale stored as (1, 100, 1). Mels are divided by it before
        # encoding and multiplied by it after decoding. Construction order below
        # is kept as-is because each layer draws from the global RNG at init.
        self.coef = nn.Parameter(torch.rand(100).unsqueeze(0).unsqueeze_(2))

        self.downsample_conv = nn.Sequential(
            nn.Conv1d(100, 512, 3, 1, 1),
            nn.GELU(),
            nn.Conv1d(512, 512, 4, 2, 1),  # stride 2 halves the frame count
            nn.GELU(),
        )

        trunk = dict(idim=512, hidden=256, n_layer=12, bn_dim=128)
        self.encoder = DVAEDecoder(odim=1024, **trunk)
        self.decoder = DVAEDecoder(odim=512, **trunk)

        self.out_conv = nn.Conv1d(512, 100, 3, 1, 1, bias=False)

        self.vq_layer = GFSQ(
            dim=1024,
            levels=(5, 5, 5, 5),
            G=2,
            R=2,
        )

    @torch.inference_mode()
    def forward(self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode") -> torch.Tensor:
        can_encode = hasattr(self, "encoder") and self.vq_layer is not None
        if mode == "encode" and can_encode:
            # Normalise a copy of the mel by the per-bin scale, in place on the copy.
            mel = inp.clone()
            scale = self.coef.view(100, 1).expand(mel.shape)
            normalised = torch.div(mel, scale, out=mel)
            x = self.downsample_conv(normalised).unsqueeze_(0)
            return self.vq_layer(self.encoder(x))

        # Decode path: indices -> features (unless there is no quantizer).
        vq_feats = self.vq_layer._embed(inp) if self.vq_layer is not None else inp

        # Split channels into two halves and interleave them along time:
        # (B, C, T) -> (B, 2, C/2, T) -> (B, C/2, T, 2) -> (B, C/2, 2T).
        batch, channels, frames = vq_feats.size(0), vq_feats.size(1), vq_feats.size(2)
        vq_feats = vq_feats.view((batch, 2, channels // 2, frames)).permute(0, 2, 3, 1).flatten(2)

        dec_out = self.out_conv(self.decoder(x=vq_feats))
        # Undo the per-bin normalisation, writing into dec_out.
        return torch.mul(dec_out, self.coef, out=dec_out)
|
||
|
||
|
||
def apply_spk_emb(
    input_ids: torch.Tensor = None,
    spk_emb: torch.Tensor = None,
    input_embeds: torch.Tensor = None,
    spk_emb_token_id: int = 0,
    num_spk_embs: int = 1,
):
    """
    Write speaker embeddings into the placeholder positions of `input_embeds`, in place.

    For every batch item, the positions where `input_ids == spk_emb_token_id`
    are found in ascending order. The k-th such position receives
    `spk_emb[item][k]`. Placeholders do not need to be contiguous. The
    embeddings are cast to the dtype/device of `input_embeds` before writing.

    Args:
        input_ids (torch.Tensor): Input ID tensor, shape [batch_size, seq_len_max]
        spk_emb (torch.Tensor): Speaker embedding tensor, shape [batch_size, num_spk_emb, hidden_dim]
        input_embeds (torch.Tensor): Input embedding tensor, shape [batch_size, seq_len_max, hidden_dim];
            modified in place
        spk_emb_token_id (int): ID of the speaker embedding placeholder token
        num_spk_embs (int): Number of placeholders expected in every batch item

    Returns:
        None

    Raises:
        AssertionError: if a batch item does not contain exactly `num_spk_embs` placeholders.
    """
    batch_size = input_ids.shape[0]

    for idx in range(batch_size):
        # 1-D tensor of placeholder positions in this item, ascending.
        positions = (input_ids[idx] == spk_emb_token_id).nonzero(as_tuple=True)[0]
        assert positions.shape[0] == num_spk_embs, (
            f"expected {num_spk_embs} speaker embedding placeholder(s) (token id {spk_emb_token_id}) "
            f"in batch item {idx}, found {positions.shape[0]}"
        )
        # Index assignment requires matching dtype/device, unlike slice copy.
        values = spk_emb[idx].to(dtype=input_embeds.dtype, device=input_embeds.device)  # [num_spk_emb, hidden_dim]
        input_embeds[idx, positions.to(input_embeds.device)] = values

    return
|
||
|
||
|
||
def make_streaming_chunk_mask_generation(
    inputs_embeds: torch.Tensor,
    past_seen_tokens: int,
    streaming_tts_text_mask: torch.Tensor,
    streaming_reserved_length: int = 300,
    streaming_audio_chunk_size: int = 50,
    streaming_text_chunk_size: int = 10,
    num_spk_emb: int = 1,
    use_spk_emb: bool = True,
) -> torch.Tensor:
    """
    Build the additive attention mask used while streaming audio tokens.

    The mask decides which positions of the reserved `text` region the TTS
    model may attend to when generating the current audio tokens.

    Visibility rule:
        The number of visible text positions grows by `streaming_text_chunk_size`
        for every `streaming_audio_chunk_size` positions by which
        `past_seen_tokens` exceeds `streaming_reserved_length` (rounded up).
        It is clamped to the range `[0, streaming_reserved_length]`.
        Positions from the end of the visible text up to
        `streaming_reserved_length + prefix + 1` are masked, where
        `prefix = 1 + num_spk_emb * use_spk_emb`.
        In addition, every position in `[0, prefix + streaming_reserved_length + 1)`
        where `streaming_tts_text_mask == 0` is masked.

    Args:
        inputs_embeds (torch.Tensor): Embeddings of the tokens fed in this step; batch size must be 1.
            Its dtype/device determine the mask's dtype/device.
        past_seen_tokens (int): Number of tokens already in the KV cache.
        streaming_tts_text_mask (torch.Tensor): Text padding mask; 0 marks masked positions. Must be
            broadcastable to length `1 + num_spk_emb * use_spk_emb + streaming_reserved_length + 1`.
        streaming_reserved_length (int, optional): Length of the reserved text region. Defaults to 300.
        streaming_audio_chunk_size (int, optional): Audio positions per visibility step. Defaults to 50.
        streaming_text_chunk_size (int, optional): Text positions revealed per step. Defaults to 10.
        num_spk_emb (int, optional): Number of speaker-embedding slots after [Stts]. Defaults to 1.
        use_spk_emb (bool, optional): Whether the speaker-embedding slots are present. Defaults to True.

    Returns:
        torch.Tensor: Mask of shape `[1, 1, 1, past_seen_tokens + inputs_embeds.shape[1]]`.
        It holds 0 for visible positions and `torch.finfo(dtype).min` for masked ones.

    Raises:
        AssertionError: If the batch size is not 1 (only batch size 1 is supported for inference).
    """
    assert inputs_embeds.shape[0] == 1

    dtype = inputs_embeds.dtype
    device = inputs_embeds.device
    min_dtype = torch.finfo(dtype).min

    # Leading tokens before the reserved text: 1 for [Stts], plus N for [spk_emb] if enabled.
    prefix_len = 1 + num_spk_emb * use_spk_emb

    # One column per cached token plus one per token fed in this step.
    causal_mask = torch.full((1, past_seen_tokens + inputs_embeds.shape[1]), fill_value=0, dtype=dtype, device=device)

    # Number of visibility steps reached so far. Clamp at 0 so that a call made
    # before the reserved region is filled cannot yield a negative slice index,
    # which Python would wrap around to the end of the mask.
    steps = math.ceil((past_seen_tokens - streaming_reserved_length) / streaming_audio_chunk_size)
    visible_text_tokens = min(max(steps, 0) * streaming_text_chunk_size, streaming_reserved_length)

    invisible_text_tokens_start = visible_text_tokens + prefix_len
    invisible_text_tokens_end = streaming_reserved_length + prefix_len + 1  # Add 1 for [Ptts] (aka `audio_bos_token_id`)

    # Set invisible text tokens to min_dtype (effectively -inf)
    causal_mask[0, invisible_text_tokens_start:invisible_text_tokens_end] = min_dtype

    # Mask padding positions in the text mask
    causal_mask[0, 0 : prefix_len + streaming_reserved_length + 1].masked_fill_(streaming_tts_text_mask == 0, min_dtype)

    # Add extra dimensions for batch and heads
    return causal_mask.unsqueeze(0).unsqueeze(0)
|
||
|
||
|
||
# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/processors.py`
|
||
class CustomRepetitionPenaltyLogitsProcessorRepeat:
    """Frequency-scaled repetition penalty over a sliding window of recent tokens.

    A token id seen `k` times in the last `past_window` positions gets its
    logit multiplied by `penalty**k` if the logit is negative, and divided by
    `penalty**k` otherwise.
    """

    def __init__(self, penalty: float, max_input_ids: int, past_window: int):
        if not (isinstance(penalty, float) and penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty
        self.max_input_ids = max_input_ids
        self.past_window = past_window

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        window = self.past_window
        if input_ids.size(1) > window:
            # Keep only the most recent `window` tokens of each row.
            input_ids = input_ids.narrow(1, -window, window)

        # counts[row, token] = occurrences of `token` in that row's window.
        counts = F.one_hot(input_ids, scores.size(1)).sum(1)

        n_rows = counts.size(0)
        if n_rows > self.max_input_ids:
            # NOTE(review): this zeroes whole rows (dim 0) past `max_input_ids`,
            # not token ids past it. Kept as in the ChatTTS source — confirm intent.
            counts.narrow(0, self.max_input_ids, n_rows - self.max_input_ids).zero_()

        alpha = torch.pow(self.penalty, counts)
        scores = scores.contiguous()
        return torch.where(scores < 0, scores.multiply(alpha), scores.divide(alpha))
|
||
|
||
|
||
@dataclass
class ConditionalChatTTSGenerationOutput(ModelOutput):
    """
    Result of a ConditionalChatTTS generation step.

    Attributes:
        new_ids (torch.LongTensor): Audio codes generated so far, shape (batch_size, sequence_length, num_vq).
        audio_input_ids (torch.LongTensor): Input IDs including condition and generated audio codes,
            shape (batch_size, full_sequence_length, num_vq).
        past_key_values: Per-layer (key, value) pairs for attention, each shaped
            (batch_size, num_heads, sequence_length, embed_size_per_head). Callers
            in this file also pass a list of such tuples back in.
        finished (bool): Whether generation has completed.
    """

    new_ids: Optional[torch.LongTensor] = None
    audio_input_ids: Optional[torch.LongTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    finished: Optional[bool] = None
|
||
|
||
|
||
class MultiModalProjector(nn.Module):
    """Two-layer MLP, Linear(in_dim -> out_dim) -> ReLU -> Linear(out_dim -> out_dim)."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear1 = nn.Linear(in_dim, out_dim, bias=True)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(out_dim, out_dim, bias=True)

    def forward(self, audio_features):
        return self.linear2(self.relu(self.linear1(audio_features)))
|
||
|
||
|
||
class ConditionalChatTTS(PreTrainedModel):
    """A conditional text-to-speech model that generates discrete audio codes from text, conditioned on an LLM.

    This model extends `PreTrainedModel` to provide text-to-speech capabilities with:
    - LLM hidden-state conditioning (projected into the TTS embedding space as speaker embeddings)
    - Streaming generation

    The model uses a Llama transformer backbone conditioned on LLM hidden states and can operate in both
    streaming and non-streaming modes.

    Each processed sequence has the following layout:

    | text bos token | LLM embedding projected to tts embedding space | text tokens (fixed length, reserved for future tokens) | audio bos token | audio tokens (variable length) | audio eos token |

    The layout is designed to support LLM-conditioned streaming audio generation.

    Usage:

    To support streaming generation, two pieces of state must be maintained outside of the model:

    - `audio_input_ids`: stores *discrete* audio codes. It is a tensor with shape [1, sequence length + 1, num_vq].
    - `past_key_values`: stores the KV cache for both text tokens and audio codes. It is a list of tuples; each
      tuple contains two tensors with shape
      [1, num_attention_heads, sequence length, hidden_size // num_attention_heads].

    Here `num_vq` is the number of audio codebooks (`4` in the default setting).

    1. Create an empty `past_key_values`:

    ```python
    initial_kv_cache_length = 1 + model.num_spk_embs + model.streaming_text_reserved_len  # `1` is the `bos` token
    dtype = model.emb_text.weight.dtype
    device = model.emb_text.weight.device
    head_dim = model.config.hidden_size // model.config.num_attention_heads
    past_key_values = [
        (
            torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, head_dim, dtype=dtype, device=device),
            torch.zeros(1, model.config.num_attention_heads, initial_kv_cache_length, head_dim, dtype=dtype, device=device),
        )
        for _ in range(model.config.num_hidden_layers)
    ]
    ```

    2. At the same time, create an empty `audio_input_ids` of shape [1, sequence length, num_vq]. The sequence
       also covers the text positions; those entries stay zero and are only placeholders.

    ```python
    initial_audio_input_ids_length = 1 + model.num_spk_embs + model.streaming_text_reserved_len + 1
    # [bos token, speaker embeddings, text tokens, audio bos token]
    audio_input_ids = torch.zeros(1, initial_audio_input_ids_length, model.num_vq, dtype=torch.long, device=device)
    ```

    3. Prefill some text tokens (for example, 10 tokens) with the `prefill_text` method. `position_ids` index
       directly into the KV-cache layout above (slot 0 is the `bos` token, followed by the speaker-embedding
       slots), and must be 2-D ([1, seq_len]) because `prefill_text` reads `position_ids[:, 0]`.

    ```python
    outputs = llm.generate(**kwargs)
    llm_tokens = some_function_to_extract_llm_tokens(outputs)
    lm_spk_emb_last_hidden_states = some_function_to_extract_lm_spk_emb_last_hidden_states(outputs)
    tts_text_input_ids = tts_tokenizer.encode(llm_tokenizer.decode(llm_tokens))  # LongTensor of shape [1, seq_len]
    # Here we assume we are prefilling text token 0 to text token 9 (inclusive), 10 tokens in total.
    begin = 0
    end = 9 + 1
    position_ids = torch.arange(begin, end, dtype=torch.long, device=device).unsqueeze(0)

    past_key_values = model.prefill_text(
        input_ids=tts_text_input_ids,
        position_ids=position_ids,
        past_key_values=past_key_values,
        lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states,
    )
    ```

    4. Build a `streaming_tts_text_mask` marking which reserved text positions already hold valid text tokens,
       similar to `attention_mask` in standard causal attention.

    ```python
    streaming_tts_text_mask = torch.zeros(model.streaming_text_reserved_len)
    streaming_tts_text_mask[0:end] = 1  # these positions now contain valid text tokens
    ```

    5. Generate audio codes with the `generate` method.

    ```python
    outputs = model.generate(
        input_ids=audio_input_ids,
        past_key_values=past_key_values,
        temperature=torch.full((model.num_vq,), 0.3, device=device),  # example value, one per codebook
        eos_token=audio_eos_token_id,  # id of the audio EOS code in your setup
        streaming_tts_text_mask=streaming_tts_text_mask,
        max_new_token=50,
    )

    # update past_key_values and audio_input_ids
    past_key_values = outputs.past_key_values
    audio_input_ids = outputs.audio_input_ids
    ```

    After the call, `past_key_values` and `audio_input_ids` are extended by at most `max_new_token=50`
    positions (fewer if generation stops early).

    6. After prefilling `10` text tokens the model can generate up to `50` audio tokens; to generate more audio
       tokens, prefill the next `10` text tokens first. Generating only `25` audio tokens is fine for a faster
       initial response. Repeat steps 3-5 as needed in your streaming scenario, following the guidelines above.
    """

    config_class = ConditionalChatTTSConfig

    def __init__(self, config: ConditionalChatTTSConfig):
        super().__init__(config)

        self.use_speaker_embedding = config.use_speaker_embedding
        self.use_llm_hidden_state = config.use_llm_hidden_state
        self.num_spk_embs = config.num_spk_embs
        self.spk_emb_token_id = config.spk_emb_token_id

        self.use_text = config.use_text
        self.streaming = config.streaming
        self.streaming_text_chunk_size = config.streaming_text_chunk_size
        self.streaming_audio_chunk_size = config.streaming_audio_chunk_size
        self.streaming_text_reserved_len = config.streaming_text_reserved_len
        self.audio_bos_token_id = config.audio_bos_token_id
        self.num_mel_bins = config.num_mel_bins
        self.num_vq = config.num_vq
        self.num_audio_tokens = config.num_audio_tokens

        self.top_p = config.top_p
        self.top_k = config.top_k
        self.repetition_penalty = config.repetition_penalty

        # Projects LLM hidden states (llm_dim) into the TTS embedding space (hidden_size).
        if self.config.use_mlp:
            self.projector = MultiModalProjector(config.llm_dim, config.hidden_size)
        else:
            self.projector = nn.Linear(config.llm_dim, config.hidden_size, bias=False)
        # One embedding table and one weight-normalised output head per audio codebook.
        self.emb_code = nn.ModuleList(
            [nn.Embedding(config.num_audio_tokens, config.hidden_size) for _ in range(config.num_vq)]
        )
        self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size)
        self.head_code = nn.ModuleList(
            [
                weight_norm(
                    nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False),
                    name="weight",
                )
                for _ in range(config.num_vq)
            ]
        )
        # Decodes discrete audio codes into mel spectrograms (see `decode_to_mel_specs`).
        self.dvae = DVAE()

        model_config = LlamaConfig(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            num_attention_heads=config.num_attention_heads,
            num_hidden_layers=config.num_hidden_layers,
            max_position_embeddings=config.max_position_embeddings,
            attn_implementation=config.attn_implementation,
        )
        self.model = LlamaModel(model_config)

    @torch.inference_mode()
    def merge_inputs_embeds(
        self,
        input_ids: torch.Tensor,
        lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None,
    ):
        """Merge `input_ids` and `lm_spk_emb_last_hidden_states` into `inputs_embeds`.

        Text tokens are embedded with `emb_text`; the positions holding `spk_emb_token_id` are then filled
        (via `apply_spk_emb`) with the projected, L2-normalised LLM hidden states.

        Args:
            input_ids (torch.Tensor): Input token IDs of shape [1, seq_len].
            lm_spk_emb_last_hidden_states (Optional[torch.Tensor], optional): LLM last hidden states used as
                speaker embeddings, shape [batch_size, num_spk_emb, llm_dim]. Required when `input_ids`
                contains `spk_emb_token_id`. Defaults to None.

        Raises:
            NotImplementedError: If the model is configured without speaker embeddings
                (`use_speaker_embedding=False`).

        Returns:
            torch.Tensor: Input embeddings for the backbone model.
        """
        assert input_ids.shape[0] == 1

        # Embed input_ids to input_embeds
        inputs_embeds = self.emb_text(input_ids)

        if not self.use_speaker_embedding:
            raise NotImplementedError(
                "ConditionalChatTTS without speaker embeddings (use_speaker_embedding=False) is not implemented."
            )

        # Inject speaker embeddings into inputs_embeds if the sequence has speaker-embedding slots.
        spk_emb_mask = input_ids == self.spk_emb_token_id
        if spk_emb_mask.any():
            assert lm_spk_emb_last_hidden_states is not None
            # Works for both projector variants (MultiModalProjector and a plain nn.Linear).
            projector_dtype = next(self.projector.parameters()).dtype
            # Project spk emb to tts hidden size: [batch_size, num_spk_emb, llm_dim] -> [batch_size, num_spk_emb, hidden_size]
            lm_spk_emb_last_hidden_states = lm_spk_emb_last_hidden_states.to(projector_dtype)
            projected_spk_emb = self.projector(lm_spk_emb_last_hidden_states)
            projected_spk_emb = F.normalize(projected_spk_emb, p=2, dim=-1)
            apply_spk_emb(
                input_ids=input_ids,
                spk_emb=projected_spk_emb,
                input_embeds=inputs_embeds,
                spk_emb_token_id=self.spk_emb_token_id,
                num_spk_embs=self.num_spk_embs,
            )

        return inputs_embeds

    @torch.inference_mode()
    def prefill_text(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.LongTensor,
        past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
        lm_spk_emb_last_hidden_states: Optional[torch.Tensor] = None,
    ):
        """Prefill a chunk of new text tokens in the streaming setting.

        The keys/values of the new text tokens are computed and written into `past_key_values` in place, at
        the slots given by `position_ids`, so subsequent generation can attend to them.

        Args:
            input_ids (Tensor): Tensor of shape [batch_size, seq_len].
            position_ids (LongTensor): Tensor of shape [batch_size, seq_len]; assumed to be contiguous, since
                only its first and last entries bound the slots that are updated.
            past_key_values (List[Tuple[Tensor]]): KV cache of all layers; each layer is a tuple
                (Tensor, Tensor) of keys and values. Updated in place and also returned.
            lm_spk_emb_last_hidden_states (Tensor, optional): Tensor of shape [batch_size, num_spk_emb, llm_dim].
                Defaults to None.

        Returns:
            List[Tuple[Tensor, Tensor]]: The same `past_key_values` object, updated.

        Note that all `batch_size` should be `1`.
        """
        assert input_ids.shape[0] == 1
        assert past_key_values is not None

        # Slot range covered by this chunk.
        begin = int(position_ids[0, 0])
        end = int(position_ids[0, -1]) + 1

        # Merge text and LLM embeddings
        inputs_embeds = self.merge_inputs_embeds(
            input_ids=input_ids,
            lm_spk_emb_last_hidden_states=lm_spk_emb_last_hidden_states,
        )

        # Truncated copy of the cache holding only the slots before `begin`; the model appends the new
        # keys/values right after them. Indexed (not unpacked) to stay compatible with Cache objects.
        past_key_values_for_prefill = []
        for layer_idx in range(len(past_key_values)):
            past_key_values_for_prefill.append(
                (
                    past_key_values[layer_idx][0][:, :, :begin, :].clone(),
                    past_key_values[layer_idx][1][:, :, :begin, :].clone(),
                )
            )

        # Model forward
        outputs_prefill: BaseModelOutputWithPast = self.model(
            attention_mask=None,  # text uses the standard causal attention mask
            position_ids=position_ids,  # positions of the new text tokens in the sequence
            past_key_values=past_key_values_for_prefill,
            inputs_embeds=inputs_embeds,  # contains text and language model embedding
            use_cache=True,
            output_attentions=False,
            cache_position=position_ids,  # the cache slots written by this chunk, same as position_ids
        )

        # Copy the freshly computed keys/values back into the caller's cache.
        past_key_values_for_prefill_updated = outputs_prefill.past_key_values
        for layer_idx in range(len(past_key_values)):
            past_key_values[layer_idx][0][:, :, begin:end, :] = past_key_values_for_prefill_updated[layer_idx][0][
                :, :, begin:end
            ]
            past_key_values[layer_idx][1][:, :, begin:end, :] = past_key_values_for_prefill_updated[layer_idx][1][
                :, :, begin:end
            ]

        # TODO: del past_key_values_for_prefill_updated recursively
        # TODO: del outputs_prefill recursively

        return past_key_values

    @torch.inference_mode()
    def prefill_audio_ids(
        self,
        input_ids: torch.Tensor,
        past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
        streaming_tts_text_mask=None,
        add_audio_bos: bool = True,
    ):
        """Prefill a chunk of audio ids into the model. Used in sliding-window long audio generation.

        Specifically, prefill many audio ids (typically from the last window) into the model in the new window.

        Args:
            input_ids (torch.Tensor): (1, seq_len, num_vq) Audio input token ids.
            past_key_values (List[Tuple[torch.Tensor, torch.Tensor]]): Past key values for the attention mechanism.
            streaming_tts_text_mask (optional): Mask of valid reserved text positions, forwarded to
                `make_streaming_chunk_mask_generation`. Defaults to None.
            add_audio_bos (bool, optional): Prepend the audio bos token embedding before the audio codes.
                Defaults to True.

        Returns:
            The updated `past_key_values` returned by the backbone.
        """
        assert input_ids.shape[0] == 1
        assert past_key_values is not None

        # Sum the per-codebook embeddings: [1, seq_len, num_vq] -> [1, seq_len, hidden_size].
        code_emb = [self.emb_code[vq](input_ids[:, :, vq]) for vq in range(self.num_vq)]
        inputs_embeds = torch.stack(code_emb, 3).sum(3)
        input_len = input_ids.shape[1]

        if add_audio_bos:
            narrowed_input_ids = torch.tensor([[self.audio_bos_token_id]], dtype=torch.long, device=self.device)
            bos_inputs_embeds = self.emb_text(narrowed_input_ids)
            inputs_embeds = torch.cat([bos_inputs_embeds, inputs_embeds], dim=1)
            input_len += 1

        past_key_values_length = past_key_values[0][0].shape[2]
        # NOTE(review): positions here start at `past_key_values_length`, whereas `generate` feeds each token at
        # `past_key_values_length + 1`; confirm this offset difference is intended.
        position_ids = torch.arange(
            past_key_values_length, past_key_values_length + input_len, dtype=torch.long, device=self.device
        ).unsqueeze(0)

        cache_position = position_ids.clone()
        causal_mask = make_streaming_chunk_mask_generation(
            inputs_embeds=inputs_embeds,
            past_seen_tokens=past_key_values_length,
            streaming_tts_text_mask=streaming_tts_text_mask,
            streaming_reserved_length=self.streaming_text_reserved_len,
            streaming_text_chunk_size=self.streaming_text_chunk_size,
        )  # [1, 1, 1, past_key_values_length + input_len]

        # Model forward
        outputs: BaseModelOutputWithPast = self.model(
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=True,
            output_attentions=False,
            cache_position=cache_position,
        )
        past_key_values = outputs.past_key_values
        return past_key_values

    @torch.inference_mode()
    def generate(
        self,
        input_ids: torch.Tensor,
        past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
        temperature: torch.Tensor,
        eos_token: Union[int, torch.Tensor],
        streaming_tts_text_mask=None,
        force_no_stop=False,
        min_new_token=10,
        max_new_token=50,
        logits_warpers: Optional[List[LogitsWarper]] = None,
        logits_processors: Optional[List[CustomRepetitionPenaltyLogitsProcessorRepeat]] = None,
        show_tqdm=False,
    ):
        """Generate audio codes in the streaming or non-streaming setting.

        Specifically, generate audio codes when not all text tokens are prefilled.

        Always pass a valid `past_key_values` to the method. The method does not `prefill` by itself; it relies on
        `prefill_text` to provide a valid `past_key_values`. Please refer to the docstring of this class for details.

        In this method, we borrowed a lot of code from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/gpt.py`.

        Args:
            input_ids (torch.Tensor): Audio input ids of shape [1, seq_len, num_vq] (condition prefix plus any
                previously generated codes).
            past_key_values (List[Tuple[torch.Tensor, torch.Tensor]]): Past key values; must be exactly one
                position shorter than `input_ids`.
            temperature (torch.Tensor): Sampling temperature, one value per codebook (shape [num_vq]) or a
                single value.
            eos_token (Union[int, torch.Tensor]): End-of-sequence audio code.
            streaming_tts_text_mask (Optional[torch.Tensor], optional): Mask for streaming TTS text. Defaults to None.
            force_no_stop (bool, optional): Never sample `eos_token`. Defaults to False.
            min_new_token (int, optional): Number of initial steps during which `eos_token` cannot be sampled.
                Defaults to 10.
            max_new_token (int, optional): Maximum number of new tokens to generate. Defaults to 50.
            logits_warpers (List[LogitsWarper], optional): Logits warpers. Defaults to none.
            logits_processors (List[CustomRepetitionPenaltyLogitsProcessorRepeat], optional): Logits processors.
                Defaults to none.
            show_tqdm (bool, optional): Whether to show a progress bar. Defaults to False.

        Returns:
            ConditionalChatTTSGenerationOutput: `new_ids` (all audio codes after the condition prefix, without a
            trailing EOS), `audio_input_ids` and `past_key_values` (to pass to the next call) and `finished`
            (0-dim bool tensor).
        """
        # `None` instead of shared mutable `[]` defaults.
        logits_warpers = [] if logits_warpers is None else logits_warpers
        logits_processors = [] if logits_processors is None else logits_processors

        # We only support batch size `1` for now
        assert input_ids.shape[0] == 1
        assert past_key_values is not None

        # Length of [bos, speaker embeddings, reserved text tokens, audio bos]. Generated audio codes start right
        # after it, so it is the start of both the repetition-penalty history and the returned codes. (It must
        # not be `input_ids.shape[1]`, which grows across streaming calls.)
        condition_length = 1 + self.num_spk_embs * self.use_speaker_embedding + self.streaming_text_reserved_len + 1

        batch_size = input_ids.shape[0]
        finish = torch.zeros(batch_size, device=input_ids.device).bool()

        # Reshape to [batch_size * num_vq, 1] to match the row layout of the flattened logits below.
        temperature = temperature.unsqueeze(0).expand(batch_size, -1).contiguous().view(-1, 1)

        progress = input_ids.shape[1]

        # Pre-allocate [batch_size, current_len + max_new_token, num_vq]; `input_ids` is a growing view into it.
        input_ids_buf = torch.zeros(
            batch_size,
            progress + max_new_token,
            input_ids.shape[2],
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        input_ids_buf.narrow(1, 0, progress).copy_(input_ids)
        del input_ids
        input_ids = input_ids_buf.narrow(1, 0, progress)

        pbar: Optional[tqdm] = None
        if show_tqdm:
            pbar = tqdm(
                total=max_new_token,
                desc="code",
                bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
            )

        # Whether the sampled EOS code was committed to `input_ids` (it is not when EOS comes on step 0).
        eos_appended = False

        for step in range(max_new_token):
            # The very first audio token is special: it is produced by feeding the audio bos token.
            audio_bos = progress == condition_length

            # The last entry of `input_ids` has not been fed yet, so the cache is one position behind.
            assert progress == (
                past_key_values[0][0].shape[2] + 1
            )  # If you are using according to the guidelines, this should be passed.

            if audio_bos:
                # Without the audio bos token it is impossible to generate the first audio token in our
                # streaming setting, so activate the model with `self.audio_bos_token_id`.
                narrowed_input_ids = torch.tensor([[self.audio_bos_token_id]], dtype=torch.long, device=self.device)
                inputs_embeds = self.emb_text(narrowed_input_ids)
                del narrowed_input_ids
            else:
                # Feed the most recent audio code (sum of its per-codebook embeddings); this covers all other
                # steps, including later calls of `generate`.
                narrowed_input_ids = input_ids.narrow(dim=1, start=input_ids.shape[1] - 1, length=1)
                code_emb = [self.emb_code[vq](narrowed_input_ids[:, :, vq]) for vq in range(self.num_vq)]
                inputs_embeds = torch.stack(code_emb, 3).sum(3)

            position_ids = torch.tensor(
                [past_key_values[0][0].shape[2] + 1], dtype=torch.long, device=self.device
            ).unsqueeze(0)
            cache_position = position_ids.clone()

            causal_mask = make_streaming_chunk_mask_generation(
                inputs_embeds=inputs_embeds,
                past_seen_tokens=past_key_values[0][0].shape[2],
                streaming_tts_text_mask=streaming_tts_text_mask,
                streaming_reserved_length=self.streaming_text_reserved_len,
                streaming_text_chunk_size=self.streaming_text_chunk_size,
            )

            # Model forward
            outputs: BaseModelOutputWithPast = self.model(
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=True,
                output_attentions=False,
                cache_position=cache_position,
            )
            del position_ids, inputs_embeds, cache_position, causal_mask

            hidden_states = outputs.last_hidden_state
            past_key_values = outputs.past_key_values

            with P.cached():
                logits = torch.empty(
                    hidden_states.size(0),
                    hidden_states.size(1),
                    self.num_audio_tokens,
                    self.num_vq,
                    dtype=torch.float,
                    device=self.device,
                )
                for vq in range(self.num_vq):
                    logits[..., vq] = self.head_code[vq](hidden_states)
            del hidden_states

            # Last position only, then [batch, num_audio_tokens, num_vq] -> [batch * num_vq, num_audio_tokens].
            logits = logits.narrow(1, -1, 1).squeeze_(1).float()
            logits = logits.permute(0, 2, 1)
            logits = logits.reshape(-1, logits.size(2))
            logits /= temperature

            # NOTE(review): both processors and warpers are skipped for the first (audio bos) step.
            if not audio_bos:
                # Audio codes generated so far, flattened to [batch * num_vq, num_generated].
                input_ids_sliced = input_ids.narrow(
                    1,
                    condition_length,
                    input_ids.size(1) - condition_length,
                ).permute(0, 2, 1)
                logits_token = input_ids_sliced.reshape(
                    input_ids_sliced.size(0) * input_ids_sliced.size(1),
                    -1,
                ).to(self.device)
                del input_ids_sliced

                for logits_processor in logits_processors:
                    logits = logits_processor(logits_token, logits)
                for logits_warper in logits_warpers:
                    logits = logits_warper(logits_token, logits)
                del logits_token

            if step < min_new_token or force_no_stop:
                logits[:, eos_token] = -torch.inf

            scores = F.softmax(logits, dim=-1)
            del logits
            idx_next = torch.multinomial(scores, num_samples=1)
            del scores

            # [batch * num_vq, 1] -> [batch, num_vq]; a row finishes once any codebook samples EOS.
            idx_next = idx_next.view(-1, self.num_vq)
            finish.logical_or_(idx_next.eq(eos_token).any(1))

            # Store the new token in the buffer; it becomes part of `input_ids` only once committed below.
            input_ids_buf.narrow(1, progress, 1).copy_(idx_next.unsqueeze_(1))
            del idx_next

            if step == 0 and finish.any():
                # EOS on the very first step: stop without committing it to `input_ids`.
                break

            progress += 1
            input_ids = input_ids_buf.narrow(1, 0, progress)

            if finish.all():
                # The EOS code is now the last entry of `input_ids`.
                eos_appended = True
                break

            if pbar is not None:
                pbar.update(1)

        if pbar is not None:
            pbar.close()

        if not finish.all() and show_tqdm:
            logger.info(f"incomplete result. hit max_new_token: {max_new_token}")

        del input_ids_buf

        # Drop the trailing EOS only when it was actually appended; otherwise every entry is a real audio code.
        if eos_appended:
            generated_input_ids = input_ids[:, condition_length:-1, :]
        else:
            generated_input_ids = input_ids[:, condition_length:, :]

        return ConditionalChatTTSGenerationOutput(
            new_ids=generated_input_ids,
            audio_input_ids=input_ids,  # for update purpose
            past_key_values=past_key_values,  # for update purpose
            finished=finish.all(),
        )

    @torch.inference_mode()
    def decode_to_mel_specs(
        self,
        result_list: List[torch.Tensor],
    ):
        """Decode discrete audio codes to mel spectrograms with `self.dvae`.

        Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/core.py`

        Args:
            result_list (List[torch.Tensor]): Audio code sequences, each of shape [seq_len, num_codebooks]
                (e.g. `new_ids[0]` from `generate`). Shorter sequences are right-padded with zeros to the longest.

        Returns:
            torch.Tensor: Output of `self.dvae` for the padded batch of shape
            [len(result_list), num_codebooks, max_seq_len]. For an empty `result_list`, an empty float32
            `np.ndarray` is returned instead.
        """
        if len(result_list) == 0:
            return np.array([], dtype=np.float32)

        max_x_len = max(result.size(0) for result in result_list)
        first = result_list[0]
        # The decoder consumes codes laid out as [batch, num_codebooks, time].
        batch_result = torch.zeros(
            (len(result_list), first.size(1), max_x_len),
            dtype=first.dtype,
            device=first.device,
        )
        for i, src in enumerate(result_list):
            batch_result[i].narrow(1, 0, src.size(0)).copy_(src.permute(1, 0))

        mel_specs = self.dvae(batch_result)
        del batch_result
        return mel_specs
|
||
|
||
|
||
# Borrowed from `https://github.com/2noise/ChatTTS/blob/main/ChatTTS/model/processors.py`
def gen_logits(
    num_code: int,
    top_P=0.7,
    top_K=20,
    repetition_penalty=1.0,
):
    """Build the logits warpers and processors lists in the shape `ConditionalChatTTS.generate` accepts.

    Args:
        num_code: Vocabulary bound forwarded to the repetition-penalty processor
            (its second positional argument, presumably `max_input_ids`).
        top_P: Nucleus-sampling threshold; ``None`` disables the top-p warper.
        top_K: Top-k cutoff; ``None`` disables the top-k warper.
        repetition_penalty: Penalty factor; ``None`` or ``1`` disables the processor.

    Returns:
        A ``(logits_warpers, logits_processors)`` pair of lists. Warpers are ordered
        top-p then top-k and each keeps at least 3 candidate tokens; the processor's
        look-back window is 16 tokens.
    """
    min_keep = 3  # never let a warper shrink the candidate set below three tokens
    history = 16  # look-back window handed to the repetition-penalty processor

    warpers = []
    if top_P is not None:
        warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=min_keep))
    if top_K is not None:
        warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=min_keep))

    penalty_enabled = repetition_penalty is not None and repetition_penalty != 1
    processors = (
        [CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, num_code, history)]
        if penalty_enabled
        else []
    )

    return warpers, processors
|
||
|
||
|
||
# Copied and modified from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    position_ids=None,
    use_cache=True,
    **kwargs,
):
    """Assemble the model inputs for one step of Hugging Face `generate`.

    Defined at module level with an explicit `self`; it reads `self.lm_head`, so it is presumably bound onto a
    causal-LM model elsewhere (the binding site is not visible here).

    Args:
        input_ids: Token ids seen so far, [batch_size, seq_len].
        past_key_values: Either a `Cache` object or a legacy per-layer sequence of (key, value) tensors.
        attention_mask: Optional padding mask [batch_size, total_len].
        inputs_embeds: Optional embeddings used instead of `input_ids`, only on the first step.
        cache_position: Positions of the new tokens in the cache; must be given when `inputs_embeds` is.
        position_ids: Optional explicit position ids; derived from `attention_mask` when omitted.
        use_cache: Forwarded unchanged.
        **kwargs: Ignored.

    Returns:
        dict: `input_ids` or `inputs_embeds`, plus `position_ids`, `past_key_values`, `use_cache` and
        `attention_mask` (`cache_position` is intentionally not forwarded).
    """
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            # `seen_tokens` may be missing or return None depending on the transformers release; fall back to
            # the number of cached positions in that case.
            past_length = getattr(past_key_values, "seen_tokens", None)
            if past_length is None:
                past_length = past_key_values.get_seq_length()
        else:
            past_length = past_key_values[0][0].shape[2]

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
        # input)
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
        # input_ids based on the past_length.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

            # This `clone` call is needed to avoid recapturing CUDA graphs with `torch.compile`'s
            # `mode="reduce-overhead"`, as otherwise the input `position_ids` would have various strides during
            # decoding. Simply using `.contiguous()` is not sufficient: in the batch size = 1 case, `position_ids`
            # is already contiguous but with varying stride, which retriggers a capture.
            position_ids = position_ids.clone(memory_format=torch.contiguous_format)

    # If `inputs_embeds` are passed, we only want to use them in the first generation step.
    if inputs_embeds is not None and cache_position[0] == 0:
        model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
    else:
        # The clone here is for the same reason as for `position_ids`.
        model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

    # A 2-D padding mask must be expanded to 4-D for a static cache; guard against a missing mask.
    if isinstance(past_key_values, StaticCache) and attention_mask is not None and attention_mask.ndim == 2:
        if model_inputs["inputs_embeds"] is not None:
            batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
            device = model_inputs["inputs_embeds"].device
        else:
            batch_size, sequence_length = model_inputs["input_ids"].shape
            device = model_inputs["input_ids"].device

        dtype = self.lm_head.weight.dtype
        min_dtype = torch.finfo(dtype).min

        attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=past_key_values.get_max_length(),
            dtype=dtype,
            device=device,
            min_dtype=min_dtype,
            cache_position=cache_position,
            batch_size=batch_size,
        )

    model_inputs.update(
        {
            "position_ids": position_ids,
            # "cache_position": cache_position,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "attention_mask": attention_mask,
        }
    )
    return model_inputs
|