866 lines
44 KiB
Python
866 lines
44 KiB
Python
import requests
|
||
import re, ujson, os, sys, fire, glob, random, time, json
|
||
import numpy as np
|
||
import io
|
||
import torch
|
||
from torch.utils.data import default_collate
|
||
import torchaudio
|
||
from typing import *
|
||
from dataclasses import dataclass, field
|
||
import transformers
|
||
from transformers.modeling_outputs import ModelOutput
|
||
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
|
||
from functools import lru_cache
|
||
from io import BytesIO
|
||
from PIL import Image
|
||
import concurrent.futures as cf
|
||
from transformers.image_transforms import resize, center_crop, get_resize_output_image_size
|
||
from transformers.image_utils import PILImageResampling
|
||
from PIL import Image, ImageOps
|
||
from PIL import ImageFile
|
||
torch.set_num_threads(1) # 限制torch的线程数 否则可能会卡住
|
||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||
import base64
|
||
from decord import VideoReader, cpu
|
||
import cv2
|
||
import av
|
||
import imagesize
|
||
import tempfile
|
||
import math
|
||
from multiprocessing import Pool
|
||
from cairosvg import svg2png
|
||
import hashlib
|
||
|
||
IMAGE_FACTOR = 28
|
||
MIN_PIXELS = 4 * 28 * 28
|
||
MAX_PIXELS = 16384 * 28 * 28
|
||
MAX_RATIO = 200
|
||
|
||
VIDEO_MIN_PIXELS = 128 * 28 * 28
|
||
VIDEO_MAX_PIXELS = 768 * 28 * 28
|
||
VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
|
||
FRAME_FACTOR = 2
|
||
FPS = 2.0
|
||
FPS_MIN_FRAMES = 4
|
||
FPS_MAX_FRAMES = 768
|
||
|
||
def round_by_factor(number: int, factor: int) -> int:
|
||
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
|
||
return round(number / factor) * factor
|
||
|
||
|
||
def ceil_by_factor(number: int, factor: int) -> int:
|
||
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
|
||
return math.ceil(number / factor) * factor
|
||
|
||
|
||
def floor_by_factor(number: int, factor: int) -> int:
|
||
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
||
return math.floor(number / factor) * factor
|
||
|
||
|
||
def smart_resize(
|
||
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
|
||
) -> tuple[int, int]:
|
||
"""
|
||
Rescales the image so that the following conditions are met:
|
||
|
||
1. Both dimensions (height and width) are divisible by 'factor'.
|
||
|
||
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
||
|
||
3. The aspect ratio of the image is maintained as closely as possible.
|
||
"""
|
||
if max(height, width) / min(height, width) > MAX_RATIO:
|
||
raise ValueError(
|
||
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
|
||
)
|
||
h_bar = max(factor, round_by_factor(height, factor))
|
||
w_bar = max(factor, round_by_factor(width, factor))
|
||
if h_bar * w_bar > max_pixels:
|
||
beta = math.sqrt((height * width) / max_pixels)
|
||
h_bar = floor_by_factor(height / beta, factor)
|
||
w_bar = floor_by_factor(width / beta, factor)
|
||
elif h_bar * w_bar < min_pixels:
|
||
beta = math.sqrt(min_pixels / (height * width))
|
||
h_bar = ceil_by_factor(height * beta, factor)
|
||
w_bar = ceil_by_factor(width * beta, factor)
|
||
return h_bar, w_bar
|
||
|
||
|
||
def split_text(text, match_regex):
|
||
matches = list(re.finditer(match_regex, text))
|
||
# 初始化结果列表
|
||
result = []
|
||
match_flag_list = []
|
||
# 上一个匹配的结束位置
|
||
last_end = 0
|
||
# 遍历所有匹配项
|
||
for match in matches:
|
||
# 添加匹配项之前的部分
|
||
if text[last_end:match.start()]:
|
||
result.append(text[last_end:match.start()])
|
||
match_flag_list.append(False)
|
||
# 添加匹配项
|
||
result.append(match.group(0))
|
||
match_flag_list.append(True)
|
||
# 更新上一个匹配的结束位置
|
||
last_end = match.end()
|
||
# 添加最后一个匹配项之后的部分
|
||
if text[last_end:]:
|
||
result.append(text[last_end:])
|
||
match_flag_list.append(False)
|
||
return result, match_flag_list
|
||
|
||
|
||
def read_video(image_path, max_frame_number, decode_way):
|
||
if decode_way=='1fps':
|
||
try:
|
||
# print(image_path)
|
||
vr = VideoReader(image_path, ctx=cpu(0))
|
||
total_frame_num = len(vr)
|
||
fps = round(vr.get_avg_fps())
|
||
frame_idx = [i for i in range(0, len(vr), fps)]
|
||
frames = vr.get_batch(frame_idx).asnumpy()
|
||
cnt = len(frames)
|
||
frame_times = range(cnt)
|
||
except Exception as e:
|
||
print(image_path)
|
||
print('error is', e)
|
||
return None
|
||
elif decode_way=='key':
|
||
try:
|
||
with av.open(image_path) as container:
|
||
stream = container.streams.video[0]
|
||
stream.codec_context.skip_frame = 'NONKEY'
|
||
frames = []
|
||
frame_times = []
|
||
fps = int(stream.average_rate)
|
||
cnt = 0
|
||
for frame in container.decode(stream): # 关键帧存成image patch
|
||
image = np.array(frame.to_image())
|
||
frames.append(image)
|
||
frame_time = int(frame.time)
|
||
frame_times.append(frame_time)
|
||
cnt += 1
|
||
except Exception as e:
|
||
print('error is', e)
|
||
return None
|
||
if frames is None or len(frames)==0:
|
||
return None
|
||
if len(frames)>max_frame_number and max_frame_number>0:
|
||
# 生成14个均匀间隔的索引
|
||
indices = np.linspace(0, len(frames) - 1, max_frame_number, dtype=int)
|
||
# 根据索引获取对应元素
|
||
frames = frames[indices]
|
||
frame_times = frame_times[indices]
|
||
return frames, frame_times
|
||
|
||
|
||
class OmniImageProcessor:
|
||
def __init__(self, config, **kwargs):
|
||
self.config = config # visual_config
|
||
self.min_pixels = self.config.min_pixels if hasattr(self.config, 'min_pixels') else 56 * 56
|
||
self.max_pixels = self.config.max_pixels if hasattr(self.config, 'max_pixels') else 28 * 28 * 1280
|
||
self.patch_size = self.config.patch_size if hasattr(self.config, 'patch_size') else 14
|
||
self.temporal_patch_size = self.config.temporal_patch_size if hasattr(self.config, 'temporal_patch_size') else 2
|
||
self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
|
||
self.spatial_merge_size = self.config.spatial_merge_size if hasattr(self.config, 'spatial_merge_size') else 2
|
||
|
||
def image_transform(self, strseq, return_mm_data = True):
|
||
image = None
|
||
if isinstance(strseq, str):
|
||
if return_mm_data:
|
||
image = Image.open(strseq).convert("RGB")
|
||
else:
|
||
try:
|
||
image = Image.open(BytesIO(strseq)).convert("RGB")
|
||
except:
|
||
image = Image.open(BytesIO(svg2png(bytestring=strseq))).convert("RGB") # interleaved有的是矢量图,需要转换
|
||
|
||
image = np.array(image.convert("RGB")) # 这一步首先将图像转换为 RGB 格式,确保图像有三个通道(R、G、B)。然后使用 np.array() 将其转换为 NumPy 数组,方便后续处理。
|
||
image_org_size = image.shape[:2] # 这里保存了图像的原始大小(高度和宽度),image.shape 返回图像的形状 (高度, 宽度, 通道数),而 image.shape[:2] 提取了前两个值,即原始的高度和宽度。这个信息可以用于后续的对比或其他处理。
|
||
|
||
# resize, crop, scale, normalize
|
||
# 输出一个新的尺寸,这个尺寸通常是 (宽度, 高度) 格式,用于后续的图像调整操作,如缩放或裁剪。
|
||
resized_height, resized_width = smart_resize(
|
||
image_org_size[0], image_org_size[1],
|
||
factor=self.patch_size * self.spatial_merge_size,
|
||
min_pixels=self.min_pixels,
|
||
max_pixels=self.max_pixels,
|
||
)
|
||
output_size = (resized_height, resized_width)
|
||
|
||
# 使用 resize 函数将图像调整到 output_size 大小。PILImageResampling.BICUBIC 指定使用双三次插值法来进行图像缩放,这种方法通常能够提供较好的图像质量。
|
||
# image: 输入的图像数据,可以是 NumPy 数组或 PIL 图像对象;output_size: 目标大小,通常是一个二元组 (宽度, 高度)。这个尺寸可以是图像的绝对大小,也可以是相对于原始图像的比例;
|
||
# resample: 可选的重采样方法,通常用于确定如何插值像素。例如,PILImageResampling.BICUBIC 表示使用双三次插值法,这是一种平滑的插值方法,常用于图像缩放。
|
||
image = resize(image, output_size, PILImageResampling.BICUBIC)
|
||
img = image.transpose(2, 0, 1)
|
||
# 对图像进行归一化和标准化处理
|
||
image = (img / 255.0 - np.array(self.config.image_mean)[:, np.newaxis, np.newaxis]) / np.array(self.config.image_std)[:,np.newaxis,np.newaxis]
|
||
# 处理成patch
|
||
patches = image[np.newaxis, :]
|
||
if patches.shape[0] == 1:
|
||
patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
|
||
channel = patches.shape[1]
|
||
grid_t = patches.shape[0] // self.temporal_patch_size
|
||
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
|
||
patches = patches.reshape(
|
||
grid_t,
|
||
self.temporal_patch_size,
|
||
channel,
|
||
grid_h // self.spatial_merge_size,
|
||
self.spatial_merge_size,
|
||
self.patch_size,
|
||
grid_w // self.spatial_merge_size,
|
||
self.spatial_merge_size,
|
||
self.patch_size,
|
||
)
|
||
patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
|
||
flatten_patches = patches.reshape(
|
||
grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
|
||
)
|
||
|
||
return flatten_patches, image_org_size, (grid_t, grid_h, grid_w)
|
||
|
||
|
||
class OmniAudioProcessor:
|
||
# 包含基本的音频特征抽取模块 + 输入数据解析模块
|
||
def __init__(
|
||
self,
|
||
config, # audio processor config
|
||
**kwargs
|
||
):
|
||
# make sure you have install 'conda install -c conda-forge 'ffmpeg<7'' for torchaudio
|
||
assert(len(torchaudio.list_audio_backends()) > 0)
|
||
self.config = config
|
||
self.mel_filters = mel_filter_bank(
|
||
num_frequency_bins=1 + self.config.n_fft // 2,
|
||
num_mel_filters=self.config.num_mel_bins,
|
||
min_frequency=0.0,
|
||
max_frequency=self.config.sampling_rate / 2.0,
|
||
sampling_rate=self.config.sampling_rate,
|
||
norm="slaney",
|
||
mel_scale="slaney",
|
||
)
|
||
self.window = torch.hann_window(self.config.n_fft)
|
||
|
||
@staticmethod
|
||
def dynamic_range_compression(x, C=1, clip_val=1e-6):
|
||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||
|
||
@staticmethod
|
||
def zero_mean_unit_var_norm(x):
|
||
return (x - x.mean()) / torch.sqrt(x.var() + 1e-8)
|
||
|
||
def load_audio_waveform(self, uri, return_tensors=True, do_normalize=False):
|
||
metadata = torchaudio.info(uri) # sample_rate, num_frames, num_channels, bits_per_sample, encoding=PCM_S
|
||
assert(metadata.num_channels <= 2), "acoustic file with {} channels.".format(metadata.num_channels) # whisper only accept mono channel audio
|
||
waveform_tensor, _ = torchaudio.load(uri, normalize=True)
|
||
if self.config.sampling_rate != metadata.sample_rate:
|
||
waveform_tensor = torchaudio.functional.resample(waveform_tensor, metadata.sample_rate, self.config.sampling_rate, lowpass_filter_width=128)
|
||
|
||
# downmix to mono channel https://trac.ffmpeg.org/wiki/AudioChannelManipulation
|
||
if metadata.num_channels > 1:
|
||
waveform_tensor = torch.mean(waveform_tensor, dim=0, keepdim=True)
|
||
|
||
# normalized to zero mean
|
||
if do_normalize:
|
||
waveform_tensor = self.zero_mean_unit_var_norm(waveform_tensor)
|
||
|
||
if return_tensors: # (channels, samples)
|
||
return waveform_tensor
|
||
else:
|
||
return waveform_tensor.numpy()
|
||
|
||
def split_with_overlap(self, waveform): # 如果长度超过最大长度限制 分割为带overlap的多段
|
||
channels, wave_samples = waveform.shape
|
||
max_audio_samples = self.config.max_audio_seconds * self.config.sampling_rate
|
||
if wave_samples <= max_audio_samples or self.config.split_overlap < 0:
|
||
return [waveform] # 没有超出最大长度or截断逻辑 统一返回list
|
||
|
||
split_waveform, start = [], 0
|
||
while start < wave_samples: # 统一按秒数对齐overlap
|
||
if start > int(self.config.sampling_rate * self.config.split_overlap):
|
||
start -= int(self.config.sampling_rate * self.config.split_overlap) # 0表示没有overlap,>0 overlap对应秒数
|
||
end = min(start + max_audio_samples, wave_samples)
|
||
if end - start>= self.config.n_fft: # 保证至少有一帧数据
|
||
split_waveform.append(waveform[:, start:end]) # 注意这里可能会切割出特别短的片段 需要在预处理判断并丢弃
|
||
start = end
|
||
return split_waveform
|
||
|
||
@classmethod
|
||
def inference_output_length(cls, config, input_length):
|
||
# for whisper + bridge
|
||
kernel_size = config.kernel_size
|
||
stride_size = config.stride_size
|
||
avg_pooler = config.avg_pooler
|
||
encoder_length = (input_length + 2 * (kernel_size // 2) - kernel_size) // 1 + 1 # conv layer1 with pad=1
|
||
encoder_length = (encoder_length + 2 * (kernel_size // 2) - kernel_size) // stride_size + 1 # conv layer2 with pad=1
|
||
if avg_pooler > 1:
|
||
bridge_length = encoder_length // avg_pooler
|
||
return encoder_length, bridge_length
|
||
|
||
def extract_fbank_features(self, waveform):
|
||
# ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py
|
||
channels, wave_samples = waveform.shape
|
||
assert(wave_samples >= self.config.n_fft)
|
||
valid_frame_nums = min(self.config.max_audio_seconds * self.config.sampling_rate // self.config.hop_length, wave_samples // self.config.hop_length + 1)
|
||
if wave_samples < self.config.max_audio_seconds * self.config.sampling_rate:
|
||
waveform = torch.nn.functional.pad(waveform, (0, self.config.max_audio_seconds * self.config.sampling_rate - wave_samples), "constant", 0)
|
||
else:
|
||
waveform = waveform[:, :self.config.max_audio_seconds * self.config.sampling_rate]
|
||
|
||
# window = torch.hann_window(self.config.n_fft)
|
||
stft = torch.stft(waveform, self.config.n_fft, self.config.hop_length, window=self.window, return_complex=True) # fft, len(wave) // n_fft // 2 + 1
|
||
magnitudes = stft[..., :-1].abs() ** 2
|
||
|
||
mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
|
||
mel_spec = mel_filters.T @ magnitudes
|
||
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||
if waveform.dim() == 2:
|
||
max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
|
||
log_spec = torch.maximum(log_spec, max_val - 8.0)
|
||
else:
|
||
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||
log_spec = (log_spec + 4.0) / 4.0
|
||
|
||
log_spec = log_spec[0].numpy() # (channel, filters, samples) -> (filters, samples)
|
||
log_spec[:, valid_frame_nums:] = 0.0 # pad0
|
||
|
||
return log_spec, valid_frame_nums
|
||
|
||
def data_augment(self, feature: np.array, input_length, training=True):
|
||
# reference https://arxiv.org/pdf/1904.08779
|
||
def mask_start_indices(input_length, mask_length, min_masks, mask_prob):
|
||
num_masked_span = int(mask_prob * input_length / mask_length + random.random())
|
||
num_masked_span = max(num_masked_span, min_masks)
|
||
start_indices = list(range(input_length - mask_length))
|
||
random.shuffle(start_indices)
|
||
start_indices = start_indices[:num_masked_span]
|
||
return start_indices
|
||
|
||
if not training or (self.config.mask_time_prob <= 0 and self.config.mask_feature_prob <= 0):
|
||
return feature
|
||
if input_length < self.config.mask_time_length * self.config.mask_time_min_masks + 1:
|
||
return feature
|
||
if self.config.num_mel_bins < self.config.mask_feature_length * self.config.mask_feature_min_masks + 1:
|
||
return feature
|
||
|
||
if self.config.mask_time_prob > 0:
|
||
start_indices = mask_start_indices(input_length, self.config.mask_time_length, self.config.mask_time_min_masks, self.config.mask_time_prob)
|
||
for start_idx in start_indices:
|
||
feature[:, start_idx: start_idx + self.config.mask_time_length] = 0.0
|
||
if self.config.mask_feature_prob > 0:
|
||
start_indices = mask_start_indices(self.config.num_mel_bins, self.config.mask_feature_length, self.config.mask_feature_min_masks, self.config.mask_feature_prob)
|
||
for start_idx in start_indices:
|
||
feature[start_idx: start_idx + self.config.mask_feature_length, :] = 0.0
|
||
|
||
return feature
|
||
|
||
@dataclass
|
||
class OmniProcessorOutput(ModelOutput):
|
||
input_ids: Optional["List|torch.Tensor"] = None
|
||
labels: Optional["List|torch.Tensor"] = None
|
||
attention_mask: Optional["List|torch.Tensor"] = None
|
||
position_ids: Optional["List|torch.Tensor"] = None
|
||
seqlens: Optional["List|torch.Tensor"] = None # 需要配合Omni Modeling使用
|
||
# audio fields
|
||
audios: Optional["List|torch.Tensor"] = None
|
||
encoder_length: Optional["List|torch.Tensor"] = None
|
||
bridge_length: Optional["List|torch.Tensor"] = None
|
||
# image fields
|
||
images: Optional["List|torch.Tensor"] = None
|
||
patch_nums: Optional["List|torch.Tensor"] = None
|
||
images_size: Optional["List|torch.Tensor"] = None
|
||
crop_size: Optional["List|torch.Tensor"] = None
|
||
images_grid: Optional["List|torch.Tensor"] = None
|
||
# video fields
|
||
videos: Optional["List|torch.Tensor"] = None
|
||
videos_patch_nums: Optional["List|torch.Tensor"] = None
|
||
videos_size: Optional["List|torch.Tensor"] = None
|
||
videos_crop_size: Optional["List|torch.Tensor"] = None
|
||
videos_grid: Optional["List|torch.Tensor"] = None
|
||
# processor fields
|
||
raw_text: Optional[str] = None
|
||
index: Optional[int] = None
|
||
|
||
def concatenate(self, other): # 仅限list使用
|
||
def concat_one(a, b):
|
||
if a is None and b is None:
|
||
return None
|
||
elif a is None and b is not None:
|
||
return b
|
||
elif a is not None and b is None:
|
||
return a
|
||
else:
|
||
return a + b
|
||
return OmniProcessorOutput(
|
||
input_ids=concat_one(self.input_ids, other.input_ids),
|
||
labels=concat_one(self.labels, other.labels),
|
||
audios=concat_one(self.audios, other.audios),
|
||
encoder_length=concat_one(self.encoder_length, other.encoder_length),
|
||
bridge_length=concat_one(self.bridge_length, other.bridge_length),
|
||
images=concat_one(self.images, other.images),
|
||
images_grid=concat_one(self.images_grid, other.images_grid),
|
||
patch_nums=concat_one(self.patch_nums, other.patch_nums),
|
||
|
||
videos=concat_one(self.videos, other.videos),
|
||
videos_grid=concat_one(self.videos_grid, other.videos_grid),
|
||
videos_patch_nums=concat_one(self.videos_patch_nums, other.videos_patch_nums),
|
||
|
||
position_ids=concat_one(self.position_ids, other.position_ids),
|
||
seqlens=concat_one(self.seqlens, other.seqlens),
|
||
images_size=concat_one(self.images_size, other.images_size),
|
||
videos_size=concat_one(self.videos_size, other.videos_size),
|
||
index = self.index # concat保持index不变
|
||
)
|
||
|
||
class OmniMMProcessor(object):
|
||
def __init__(self,
|
||
tokenizer: transformers.PreTrainedTokenizer,
|
||
config,
|
||
training,
|
||
relative_path=None,
|
||
parallel=None,
|
||
**kwargs,
|
||
):
|
||
self.tokenizer = tokenizer
|
||
self.config = config
|
||
self.audio_processor = OmniAudioProcessor(config.audio_config)
|
||
self.visual_processor = None
|
||
if hasattr(config, "visual_config"):
|
||
self.visual_processor = OmniImageProcessor(config.visual_config)
|
||
self.video_processor = None
|
||
if hasattr(config, "video_config"):
|
||
self.video_processor = OmniImageProcessor(config.video_config)
|
||
self.training = training
|
||
self.relative_path = relative_path
|
||
self.parallel = parallel
|
||
# audio tag
|
||
self.audio_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_start_token_id)
|
||
self.audio_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_end_token_id)
|
||
self.audio_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_pad_token_id)
|
||
self.audio_delim_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_delim_token_id)
|
||
self.audiogen_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audiogen_start_token_id)
|
||
self.audiogen_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audiogen_end_token_id)
|
||
# image tag
|
||
self.image_start_tag = None
|
||
self.image_end_tag = None
|
||
self.image_pad_tag = None
|
||
self.video_start_tag = None
|
||
self.video_end_tag = None
|
||
# videoframe tag只是为了兼容图片帧作为输入的情况,没有token id,在抽取视频帧的时候,会将这个替换成image tag的start、end
|
||
self.videoframe_start_tag = '<videoframe_start_omni>'
|
||
self.videoframe_end_tag = '<videoframe_end_omni>'
|
||
if hasattr(self.config, "visual_config"):
|
||
# special token for start_tag
|
||
self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_start_token_id)
|
||
# special token for end_tag
|
||
self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_end_token_id)
|
||
# special token for pad_tag
|
||
self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_pad_token_id)
|
||
self.image_line_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_line_token_id)
|
||
self.image_delimiter_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_delimiter_token_id)
|
||
if hasattr(self.config, "video_config"):
|
||
self.video_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_start_token_id)
|
||
self.video_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_end_token_id)
|
||
self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_start_token_id)
|
||
self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_end_token_id)
|
||
self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_pad_token_id)
|
||
self.video_place_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_place_token_id)
|
||
|
||
self.frame_pattern = getattr(self.config.video_config, 'frame_pattern', '<frame>')
|
||
|
||
|
||
# @lru_cache(maxsize=1024)
|
||
def _get_audio(self, audio_info):
|
||
try:
|
||
audio_info = ujson.loads(audio_info)
|
||
if 'path' in audio_info.keys():
|
||
audio_uri = None
|
||
if os.path.exists(audio_info['path']):
|
||
audio_uri = audio_info['path']
|
||
elif self.relative_path is not None:
|
||
audio_uri = os.path.join(self.relative_path, audio_info['path'].lstrip('/'))
|
||
if not os.path.exists(audio_uri):
|
||
audio_uri = None
|
||
if audio_uri is not None:
|
||
waveform = self.audio_processor.load_audio_waveform(audio_uri, True)
|
||
waveforms = self.audio_processor.split_with_overlap(waveform)
|
||
|
||
ret = OmniProcessorOutput() # 默认初始化 audios字段为None
|
||
for i, waveform in enumerate(waveforms): #(zip(waveforms,vocoder_waveforms)):
|
||
audio, input_length = self.audio_processor.extract_fbank_features(waveform)
|
||
audio = self.audio_processor.data_augment(audio, input_length, self.training)
|
||
encoder_length, bridge_length = self.audio_processor.inference_output_length(self.config.audio_config, input_length)
|
||
if bridge_length <= 0:
|
||
continue
|
||
current_ret = OmniProcessorOutput(
|
||
audios=[audio[:,:input_length]],
|
||
encoder_length=[encoder_length],
|
||
bridge_length=[bridge_length],
|
||
)
|
||
if ret.audios is None:
|
||
ret = current_ret
|
||
else:
|
||
ret = ret.concatenate(current_ret) # 拼接多个切片
|
||
return ret
|
||
else:
|
||
raise ValueError("can not find path in audio_info")
|
||
except Exception as e:
|
||
print("**** get audio error: {}, info: {} *****".format(str(e), str(audio_info)))
|
||
return OmniProcessorOutput()
|
||
|
||
# @lru_cache(maxsize=1024)
|
||
def _get_image(self, image_info):
|
||
try:
|
||
try:
|
||
image_info = ujson.loads(image_info)
|
||
except:
|
||
image_info = re.sub(r"(?<!\\)'", '"', image_info)
|
||
image_info = ujson.loads(image_info)
|
||
if 'base64' in image_info.keys():
|
||
image_data = base64.b64decode(image_info['base64'])
|
||
image_feat, org_size, image_list = self.visual_processor.image_transform(image_data)
|
||
elif 'local' in image_info.keys():
|
||
image_feat, org_size, image_list = self.visual_processor.image_transform(image_info['local'])
|
||
elif 'path' in image_info.keys() and os.path.exists(image_info['path']):
|
||
image_feat, org_size, image_list = self.visual_processor.image_transform(image_info['path'])
|
||
elif 'url' in image_info.keys():
|
||
image_bytes = self._get_vision_obj_byte('url', image_info['url'])
|
||
image_feat, org_size, image_list = self.visual_processor.image_transform(image_bytes)
|
||
else:
|
||
raise ValueError("can not find any path in image_info")
|
||
|
||
merge_length = self.visual_processor.merge_size**2
|
||
patch_nums = np.array(image_list).prod() // merge_length
|
||
|
||
if org_size[0] * org_size[1] > 16**2: # 极端小的图过滤
|
||
return OmniProcessorOutput(
|
||
images=[image_feat],
|
||
patch_nums=[patch_nums],
|
||
crop_size=[image_list],
|
||
images_size= [org_size],
|
||
images_grid=[image_list]
|
||
)
|
||
else:
|
||
print("**** image too small: {}, info: {} *****".format(str(org_size), str(image_info)))
|
||
return OmniProcessorOutput()
|
||
|
||
except Exception as e:
|
||
print("**** get image error: {}, info: {} *****".format(str(e), str(image_info)))
|
||
return OmniProcessorOutput()
|
||
|
||
# @lru_cache(maxsize=1024)
|
||
def _get_video_frame(self, video_frame_infos):
|
||
try:
|
||
pattern = r'\{.*?\}'
|
||
matches = re.findall(pattern, video_frame_infos)
|
||
ret = OmniProcessorOutput()
|
||
# 逐个解析
|
||
for match in matches:
|
||
video_frame_info = ujson.loads(match)
|
||
# video_frame_info = ujson.loads(video_frame_info)
|
||
if 'local' in video_frame_info.keys():
|
||
image_feat, org_size, image_list = self.video_processor.image_transform(video_frame_info['local'])
|
||
elif 'path' in video_frame_info.keys() and os.path.exists(video_frame_info['path']):
|
||
image_feat, org_size, image_list = self.video_processor.image_transform(video_frame_info['path'])
|
||
else:
|
||
raise ValueError("can not find any path in video_info")
|
||
|
||
merge_length = self.video_processor.merge_size**2
|
||
patch_nums = np.array(image_list).prod() // merge_length
|
||
|
||
if org_size[0] * org_size[1] > 16**2: # 极端小的图过滤
|
||
ret = ret.concatenate(
|
||
OmniProcessorOutput(
|
||
videos=[image_feat],
|
||
videos_patch_nums=[patch_nums],
|
||
videos_crop_size=[image_list],
|
||
videos_size= [org_size],
|
||
videos_grid=[image_list]
|
||
)
|
||
)
|
||
else:
|
||
print("**** video too small: {}, info: {} *****".format(str(org_size), str(video_frame_info)))
|
||
return ret
|
||
|
||
except Exception as e:
|
||
print("**** get video error: {}, info: {} *****".format(str(e), str(video_frame_info)))
|
||
return OmniProcessorOutput()
|
||
|
||
# 读取视频
|
||
def _get_vision_obj_byte(self, source, path):
|
||
vision_obj_byte = None
|
||
if source == "local":
|
||
if os.path.exists(path):
|
||
vision_obj_byte = open(path, "rb").read()
|
||
else:
|
||
vision_obj_byte = None
|
||
if source == "base64":
|
||
vision_obj_byte = base64.b64decode(path)
|
||
if source == "url":
|
||
vision_obj_byte = requests.get(url=path).content
|
||
return vision_obj_byte
|
||
|
||
# 将视频切分为帧,保存至子目录中
|
||
def _split_video_to_frames(self, video_info, max_frame_number=-1, decode_way="1fps"):
|
||
if decode_way=='1fps':
|
||
frame_suffix = f'_frames'
|
||
elif decode_way=='key':
|
||
frame_suffix = f'_keyframes'
|
||
else:
|
||
raise ValueError('unvalid decode way!!!')
|
||
|
||
server = "local"
|
||
if 'local' in video_info.keys():
|
||
# 本地路径
|
||
video_path = video_info['local']
|
||
# 帧保存本地路径
|
||
frame_path = video_path.split('.')[0] + frame_suffix
|
||
mm_obj_byte = self._get_vision_obj_byte('local', video_path)
|
||
elif 'base64' in video_info.keys():
|
||
md5 = hashlib.md5(video_info['base64'].encode('utf-8')).hexdigest()
|
||
if self.relative_path is not None:
|
||
video_path = os.path.join(self.relative_path, md5)
|
||
else:
|
||
video_path = os.path.join(os.getcwd(), md5)
|
||
frame_path = md5 + frame_suffix
|
||
mm_obj_byte = self._get_vision_obj_byte('base64', video_info['base64'])
|
||
elif 'url' in video_info.keys():
|
||
md5 = hashlib.md5(video_info['url'].encode('utf-8')).hexdigest()
|
||
if self.relative_path is not None:
|
||
video_path = os.path.join(self.relative_path, md5)
|
||
else:
|
||
video_path = os.path.join(os.getcwd(), md5)
|
||
frame_path = md5 + frame_suffix
|
||
mm_obj_byte = self._get_vision_obj_byte('url', video_info['url'])
|
||
else:
|
||
raise ValueError('unvalid video server !!!')
|
||
return ""
|
||
|
||
if mm_obj_byte is None: # 未读取到视频文件
|
||
return ""
|
||
if not os.path.exists(frame_path) or len(os.listdir(frame_path))==0:
|
||
# 保存帧
|
||
os.makedirs(frame_path, exist_ok=True)
|
||
frames, frame_times = read_video(io.BytesIO(mm_obj_byte), max_frame_number=-1, decode_way=decode_way) #读取全部帧
|
||
for frame_idx, frame in enumerate(frames):
|
||
output_filename = os.path.join(frame_path, f"{frame_times[frame_idx]}.jpg")
|
||
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
||
cv2.imwrite(output_filename, frame)
|
||
frame_paths = os.listdir(frame_path)
|
||
|
||
# 选取帧
|
||
frame_times = [int(filename.split('/')[-1].replace('.jpg', '')) for filename in frame_paths if filename.endswith('.jpg')] # 文件名对应秒数
|
||
frame_times.sort() #从小到大排序
|
||
frame_number = len(frame_times)
|
||
if frame_number > max_frame_number:
|
||
indices = np.linspace(0, frame_number - 1, max_frame_number, dtype=int)
|
||
else:
|
||
indices = np.linspace(0, frame_number - 1, frame_number, dtype=int)
|
||
# 拼接模式
|
||
replace_str = ""
|
||
for frame_idx, idx in enumerate(indices):
|
||
frame_time = frame_times[idx] # frame_time表示帧对应的时间 单位为s 同时也是存储的文件名
|
||
frame_dict = {"local": os.path.join(frame_path, f'{frame_time}.jpg')}
|
||
frame_str = self.frame_pattern.format(frame_idx) if '{}' in self.frame_pattern else self.frame_pattern # {}对应的是第几张图片
|
||
frame_str = frame_str.replace('<TIMEIDX>', str(frame_time)) # TIMEIDX对应的是第几秒
|
||
frame_str = frame_str.replace('<TIMESTAMP>', time.strftime("%H:%M:%S", time.gmtime(frame_time))) # TIMESTAMP对应的是时间戳
|
||
frame_str = frame_str.replace('<frame>', f'{self.image_start_tag}{json.dumps(frame_dict)}{self.image_end_tag}')
|
||
replace_str += frame_str
|
||
|
||
return replace_str
|
||
|
||
def sample_frame(self,frames_str,max_frame = 32):
|
||
def uniform_sample(lst, num_samples):
|
||
if num_samples > len(lst):
|
||
return lst
|
||
interval = len(lst) / num_samples
|
||
samples = [lst[int(i * interval)] for i in range(num_samples)]
|
||
return samples
|
||
p = rf'({self.image_start_tag}.*?{self.image_end_tag})'
|
||
frames_str_split = re.split(p,frames_str)
|
||
frame_idxs = [idx for idx in range(len(frames_str_split)) if self.image_start_tag in frames_str_split[idx]]
|
||
sample_frame_idxs = set(uniform_sample(frame_idxs, max_frame))
|
||
return ''.join([item for idx,item in enumerate(frames_str_split) if idx in sample_frame_idxs or self.image_start_tag not in frames_str_split[idx]])
|
||
|
||
def _get_video_frame_str(self, video_info):
|
||
try:
|
||
if self.videoframe_start_tag in video_info:#如果是以视频帧的形式表示一个视频,则替换成image tag
|
||
frames_str = video_info
|
||
frames_str = frames_str.replace(self.videoframe_start_tag,self.image_start_tag).replace(self.videoframe_end_tag,self.image_end_tag)
|
||
return self.sample_frame(frames_str, max_frame = self.config.video_config.max_frame_num)
|
||
video_info = ujson.loads(video_info)
|
||
# 获取包含多帧图像路径的字符串,最大帧数量max_frame_number
|
||
frames_str = self._split_video_to_frames(video_info, max_frame_number=self.config.video_config.max_frame_num, decode_way=self.config.video_config.decode_way)
|
||
return frames_str
|
||
except Exception as e:
|
||
print("**** get video error: {}, info: {} *****".format(str(e), str(video_info)))
|
||
return ""
|
||
|
||
def _replace_image(self, image_text):
|
||
image_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', image_text)
|
||
ret = self._get_image(image_info) # 重复取结果 cached result
|
||
if ret.patch_nums is None:
|
||
return ''
|
||
return ret, self.image_start_tag + self.image_pad_tag * ret.patch_nums[0] + self.image_end_tag
|
||
|
||
def _replace_video_frame(self, video_frame_text):
|
||
video_frame_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', video_frame_text)
|
||
ret = self._get_video_frame(video_frame_info) # 重复取结果 cached result
|
||
if ret.videos_patch_nums is None:
|
||
return ''
|
||
video_frame_str = [self.image_start_tag + self.video_place_tag * ret.videos_patch_nums[i] + self.image_end_tag for i in range(len(ret.videos_patch_nums))]
|
||
return ret, ''.join(video_frame_str)
|
||
|
||
|
||
def split_multimodal_chunk(self, text_list, mm_label_list, trainable_list, mtype='audio'):
|
||
# 抽取text中的json格式音频/图像信息,读取并转化为特征,同时估计encoder token数,填入对应数量的pad token
|
||
if (self.audio_start_tag != None) and (mtype == 'audio'):
|
||
match_regex = re.compile(self.audio_start_tag + '.*?' + self.audio_end_tag,re.S)
|
||
drop_regex = re.compile(self.audio_start_tag + "|" + self.audio_end_tag,re.S)
|
||
elif (self.image_start_tag != None) and (mtype == 'image'):
|
||
match_regex = re.compile(self.image_start_tag + '.*?' + self.image_end_tag,re.S)
|
||
drop_regex = re.compile(self.image_start_tag + "|" + self.image_end_tag,re.S)
|
||
elif (self.audiogen_start_tag != None) and (mtype == 'audiogen'):
|
||
match_regex = re.compile(self.audiogen_start_tag + '.*?' + self.audiogen_end_tag,re.S)
|
||
drop_regex = re.compile(self.audiogen_start_tag + "|" + self.audiogen_end_tag,re.S)
|
||
elif (self.video_start_tag != None) and (mtype == 'video'):
|
||
match_regex = re.compile(self.video_start_tag + '.*?' + self.video_end_tag,re.S)
|
||
drop_regex = re.compile(self.video_start_tag + "|" + self.video_end_tag,re.S)
|
||
else:
|
||
raise ValueError("mtype not supportted!")
|
||
new_text_list = []
|
||
new_mm_label_list = []
|
||
new_trainable_flag_list = []
|
||
for text,mm_label,trainable in zip(text_list,mm_label_list,trainable_list):
|
||
for t,m in zip(*split_text(text, match_regex)):
|
||
new_trainable_flag_list.append(trainable)
|
||
if m:
|
||
new_text_list.append(re.sub(drop_regex, '', t))
|
||
new_mm_label_list.append(mtype)
|
||
else:
|
||
new_text_list.append(t)
|
||
new_mm_label_list.append(mm_label)
|
||
return new_text_list, new_mm_label_list, new_trainable_flag_list
|
||
|
||
def process_multimodal_chunk(self, text, mm_label, trainable):
|
||
ret = OmniProcessorOutput()
|
||
if mm_label == 'audio':
|
||
ret = self._get_audio(text)
|
||
if ret.bridge_length is not None:
|
||
ret.input_ids = self.tokenizer.encode(self.audio_start_tag,add_special_tokens=False) + self.tokenizer.encode(self.audio_pad_tag,add_special_tokens=False) * sum(ret.bridge_length) + self.tokenizer.encode(self.audio_end_tag,add_special_tokens=False)
|
||
else:
|
||
raise ValueError(f"Get audio data Failed at Process audio chunk {text}")
|
||
elif mm_label == 'audiogen':
|
||
ret = self._get_audio(text)
|
||
if ret.bridge_length is not None:
|
||
ret.input_ids = self.tokenizer.encode(self.audiogen_start_tag,add_special_tokens=False) + self.tokenizer.encode(self.audio_pad_tag,add_special_tokens=False) * sum(ret.bridge_length) + self.tokenizer.encode(self.audiogen_end_tag,add_special_tokens=False)
|
||
else:
|
||
raise ValueError(f"Get audio data Failed at Process audio chunk {text}")
|
||
elif mm_label == 'image':
|
||
ret, input_str = self._replace_image(text)
|
||
if input_str:
|
||
ret.input_ids = self.tokenizer.encode(input_str, add_special_tokens=False)
|
||
else:
|
||
raise ValueError("Get image data Failed at Process image chunk")
|
||
elif mm_label == 'video':
|
||
frame_str = self.video_start_tag+self._get_video_frame_str(text)+self.video_end_tag
|
||
ret, input_str = self._replace_video_frame(frame_str)
|
||
if input_str:
|
||
ret.input_ids = self.tokenizer.encode(input_str, add_special_tokens=False)
|
||
else:
|
||
raise ValueError("Get video data Failed at Process video chunk")
|
||
elif mm_label == 'text':
|
||
ret.input_ids = self.tokenizer.encode(text, add_special_tokens=False)
|
||
if len(ret.input_ids) > self.tokenizer.model_max_length-1: # 过滤长文本
|
||
raise ValueError(f"Text too long, please check text length! 【{text[:5]+'...'*6+text[-5:]}】")
|
||
else:
|
||
raise ValueError(f"mm_label not supportted! must in ['audio', 'image', 'text'] but get {mm_label}")
|
||
return ret
|
||
|
||
def process_one(self, text, index=0, raw_only=False):
|
||
ret = OmniProcessorOutput(index=index)
|
||
all_text_list = []
|
||
all_mm_label_list = []
|
||
all_trainable_flag_list = []
|
||
text_list, match_flag = split_text(text, re.compile("<trainable_start>.*?<trainable_end>",re.S))
|
||
if len(text_list) == 1:
|
||
text = re.sub(re.compile("<trainable_start>|<trainable_end>",re.S), '', text_list[0])
|
||
all_text_list.append(text)
|
||
all_mm_label_list.append('text')
|
||
all_trainable_flag_list.append(True)
|
||
else:
|
||
for text, match in zip(text_list, match_flag):
|
||
text = re.sub(re.compile("<trainable_start>|<trainable_end>",re.S), '', text)
|
||
if text.strip() == '':
|
||
continue # 把多余的空格干掉
|
||
all_text_list.append(text)
|
||
all_mm_label_list.append('text')
|
||
all_trainable_flag_list.append(match)
|
||
# 处理多模态信息
|
||
for mtype in self.config.multimodal: # 循环获取音频 图像结果
|
||
all_text_list, all_mm_label_list, all_trainable_flag_list = self.split_multimodal_chunk(all_text_list, all_mm_label_list, all_trainable_flag_list, mtype)
|
||
if len(all_text_list) == 0:
|
||
print(f"Process {text} chunk error: No valid Text data!!!!!")
|
||
return OmniProcessorOutput(index=index)
|
||
|
||
for text, mm_label, trainable in zip(all_text_list, all_mm_label_list, all_trainable_flag_list):
|
||
try:
|
||
mret = self.process_multimodal_chunk(text, mm_label, trainable)
|
||
ret = ret.concatenate(mret)
|
||
except ValueError as e:
|
||
tt = text[:24].replace('\n','<LF>')
|
||
print(f"Process {tt if mm_label == 'text' else text} {mm_label} chunk error: {str(e)}")
|
||
return OmniProcessorOutput(index=index)
|
||
|
||
if raw_only:
|
||
ret.raw_text = self.tokenizer.decode(ret.input_ids, skip_special_tokens=False)
|
||
return ret
|
||
return ret
|
||
|
||
@torch.no_grad()
|
||
def __call__(self, example, parallel=128):
|
||
if isinstance(example, Dict):
|
||
pass
|
||
elif isinstance(example, str):
|
||
return self.process_one(example)
|
||
elif isinstance(example, List): # batch推理 异步多线程处理
|
||
with cf.ThreadPoolExecutor(min(parallel, len(example))) as executor:
|
||
future_list = [executor.submit(self.process_one, di, idx) for idx, di in enumerate(example)]
|
||
batch_data = [key.result() for key in cf.as_completed(future_list)]
|
||
valid_num = sum([1 if x.input_ids is not None else 0 for x in batch_data])
|
||
assert(valid_num == len(batch_data)) # 推理数据严格要求数量对齐
|
||
batch_data = sorted(batch_data, key=lambda x: x.index) # 保证顺序不变
|
||
|
||
ret = OmniProcessorOutput()
|
||
for i in range(len(batch_data)):
|
||
ret = ret.concatenate(batch_data[i])
|
||
self.tokenizer.padding_side = "left"
|
||
max_len = min(max([len(x.input_ids) for x in batch_data]),self.tokenizer.model_max_length)
|
||
padding_result = self.tokenizer.pad({"input_ids": [r.input_ids for r in batch_data]}, return_tensors='pt')
|
||
ret.input_ids = padding_result["input_ids"]
|
||
ret.attention_mask = padding_result["attention_mask"] # batch推理不pack 不需要seqlens
|
||
|
||
if ret.audios is not None:
|
||
max_audios_len = max([x.shape[-1] for x in ret.audios])
|
||
ret.audios = default_collate([np.pad(x, ((0,0),(0,max_audios_len - x.shape[-1])), 'constant', constant_values=0) for x in ret.audios])
|
||
|
||
ret.encoder_length = default_collate(ret.encoder_length)
|
||
ret.bridge_length = default_collate(ret.bridge_length)
|
||
|
||
if ret.images is not None:
|
||
ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.images]
|
||
ret.patch_nums = default_collate(ret.patch_nums)
|
||
|
||
if ret.videos is not None:
|
||
ret.videos = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.videos]
|
||
ret.videos_patch_nums = default_collate(ret.videos_patch_nums)
|
||
|
||
return ret
|
||
|
||
else:
|
||
raise ValueError("example format supported yet")
|