# -*- encoding: utf-8 -*-
# File: processing_megrezo.py
# Description: Processor that prepares text, image, and audio inputs for the MegrezO model.

import io
import re
import subprocess
from collections import UserDict
from typing import List, Literal, Optional, Tuple, Union

import numpy as np
import PIL
import PIL.Image
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import TensorType
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessorMixin

from .image_processing_megrezo import MegrezOImageProcessor  # noqa: F401

AudioInput = Union[str, bytes, "np.ndarray", List[str], List[bytes], List["np.ndarray"]]
ReturnTensorType = Union[str, TensorType]

class ImageBatchFeature(BatchFeature):
    r"""
    Holds the image features of a batch of images.
    """

    pixel_values: Union[np.ndarray, torch.Tensor]
    image_sizes: Union[np.ndarray, torch.Tensor]
    tgt_sizes: Union[np.ndarray, torch.Tensor]
    patch_attention_mask: Union[np.ndarray, torch.Tensor]
    image_bounds: Union[np.ndarray, torch.Tensor]

class AudioBatchFeature(BatchFeature):
    r"""
    Holds the audio features of a batch of audio.
    """

    input_audios: List[Union[np.ndarray, torch.Tensor]]
    input_audio_lengths: List[Union[np.ndarray, torch.Tensor]]
    audio_span_tokens: List[Union[np.ndarray, torch.Tensor]]
    audio_bounds: Union[np.ndarray, torch.Tensor]

class ConvContent(UserDict):
    text: Optional[str] = None
    image: Optional[ImageInput] = None
    audio: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None

class Conversation(UserDict):
    role: Literal["user", "assistant"]
    content: Union[str, dict, ConvContent]

def load_audio(
    audio: Union[str, bytes],
    sample_rate: int = 16000,
) -> "np.ndarray":
    """Load audio from a file path or bytes and return it as a numpy array.

    Args:
        audio (Union[str, bytes]): path to an audio file or raw audio bytes.
        sample_rate (int, optional): target sample rate. Defaults to 16000.

    Raises:
        ValueError: if the input audio is neither a path nor bytes.

    Returns:
        np.ndarray: the decoded audio as a mono float32 numpy array.
    """
    if isinstance(audio, str):
        inp = audio
        out = "-"
        cmd_inp = None
    elif isinstance(audio, bytes):
        inp = "pipe:"
        out = "pipe:"
        cmd_inp = audio
    else:
        raise ValueError("input audio must be either a path or bytes")

    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads",
        "0",
        "-i",
        inp,
        "-f",
        "s16le",
        "-ac",
        "1",
        "-acodec",
        "pcm_s16le",
        "-ar",
        str(sample_rate),
        out,
    ]

    out = subprocess.check_output(cmd, input=cmd_inp, stderr=subprocess.PIPE)
    arr = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
    return arr

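# Usage sketch (illustrative only; requires the `ffmpeg` binary on PATH and a
# hypothetical local file "speech.wav"):
#
#     wav = load_audio("speech.wav")                      # decode from a path
#     wav = load_audio(open("speech.wav", "rb").read())   # decode from raw bytes
#     # `wav` is a mono float32 waveform at 16 kHz, scaled to [-1.0, 1.0).
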
def load_image(
    image: Union[str, bytes, PIL.Image.Image],
) -> PIL.Image.Image:
    """Load an image from a file path or bytes and return it as a PIL image.

    Args:
        image (Union[str, bytes, PIL.Image.Image]): path to an image file, image bytes, or a PIL image.

    Raises:
        ValueError: if the input image is neither a path, bytes, nor a PIL image.

    Returns:
        PIL.Image.Image: the image as a PIL image.
    """
    if isinstance(image, PIL.Image.Image):
        return image

    if isinstance(image, str):
        img = PIL.Image.open(image)
    elif isinstance(image, bytes):
        img = PIL.Image.open(io.BytesIO(image))
    else:
        raise ValueError("input image must be either a path or bytes")

    return img

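# Usage sketch (illustrative only; "photo.png" is a hypothetical local file):
#
#     img = load_image("photo.png")                       # from a path
#     img = load_image(open("photo.png", "rb").read())    # from raw bytes
#     img = load_image(PIL.Image.new("RGB", (32, 32)))    # PIL images pass through unchanged
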
class MegrezOProcessor(ProcessorMixin):
    attributes = ["image_processor", "audio_feature_extractor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    audio_feature_extractor_class = "WhisperFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    _image_placeholder = r"(<image>./</image>)"
    _audio_placeholder = r"(<audio>./</audio>)"

    def __init__(self, image_processor=None, audio_feature_extractor=None, tokenizer=None):
        super().__init__(image_processor, audio_feature_extractor, tokenizer)
        self.chat_template = self.tokenizer.chat_template

    def _parse_and_check_inputs(self, inputs) -> Tuple[List[Conversation], List[ImageInput], List[AudioInput]]:
        if not isinstance(inputs, list):
            raise ValueError("inputs must be a list of conversations")

        conversations = []
        images = []
        audios = []

        for turn in inputs:
            if not isinstance(turn, (dict, Conversation)):
                raise ValueError("each element of inputs must be a dictionary or a Conversation object")

            role = turn.get("role")
            content = turn.get("content")
            if role is None or content is None:
                raise ValueError("role and content must be provided in each conversation")

            if isinstance(content, dict):
                content = ConvContent({**content})
            elif not isinstance(content, (str, ConvContent)):
                raise ValueError("content must be a string, a dictionary or a ConvContent object")

            if not isinstance(content, str):
                if content.get("image") is not None:
                    images.extend(content["image"] if isinstance(content["image"], list) else [content["image"]])

                if content.get("audio") is not None:
                    audios.extend(content["audio"] if isinstance(content["audio"], list) else [content["audio"]])

            conv = Conversation({"role": role, "content": content})
            conversations.append(conv)

        return conversations, images, audios

    def __call__(
        self,
        conversations: List[Conversation],
        apply_chat_template: bool = True,
        max_length: Optional[int] = None,
        return_tensors: ReturnTensorType = TensorType.PYTORCH,
        apply_data_collator: bool = True,
        **kwargs,
    ):
        assert return_tensors is TensorType.PYTORCH, "Only PyTorch tensors are supported for now."
        convs, images, audios = self._parse_and_check_inputs(conversations)
        add_generation_prompt = kwargs.pop("add_generation_prompt", True)
        if apply_chat_template:
            prompt = self.tokenizer.apply_chat_template(
                convs,
                tokenize=False,
                add_generation_prompt=add_generation_prompt,
            )
        else:  # (TODO) Temporary fallback; confirm whether this no-template path can be removed.
            prompt = "\n".join([conv["content"] for conv in convs])

        prompt, multimodal_inputs = self.process_multimodal_inputs(
            prompt,
            images=images,
            audios=audios,
            return_tensors=return_tensors,
            **kwargs,
        )
        text_encodings = self.tokenizer(
            prompt,
            return_tensors=return_tensors,
            max_length=max_length,
            padding=True,
            padding_side="left",
            truncation=True,
            **kwargs,
        )

        merged = self.merge_encodings(text_encodings, multimodal_inputs)

        if apply_data_collator:
            return self.data_collator([merged])

        return merged

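    # Usage sketch (illustrative only; the checkpoint path, file name, and prompt
    # text below are assumptions, not part of this module):
    #
    #     from transformers import AutoProcessor
    #
    #     processor = AutoProcessor.from_pretrained("path/to/megrez-o-checkpoint", trust_remote_code=True)
    #     conversation = [
    #         {
    #             "role": "user",
    #             "content": {
    #                 "text": "(<image>./</image>) What is in this picture?",
    #                 "image": "example.jpg",
    #             },
    #         },
    #     ]
    #     batch = processor(conversation)
    #     # `batch` holds input_ids, position_ids, attention_mask and, when media are
    #     # present, the collated image_encoding / audio_encoding dictionaries.
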
    def merge_encodings(self, text_encodings, multimodal_inputs):
        result = {
            "image_encoding": None,
            "audio_encoding": None,
        }

        result["input_ids"] = text_encodings["input_ids"].reshape(-1).to(torch.int32)
        # Attend to non-pad positions; this assumes the pad token id is 0.
        result["attention_mask"] = result["input_ids"].ne(0)
        result["position_ids"] = torch.arange(result["input_ids"].size(0)).long()

        if "image_encoding" in multimodal_inputs and multimodal_inputs["image_encoding"]:
            result["image_encoding"] = multimodal_inputs["image_encoding"]
            result["image_encoding"]["image_bounds"] = self.compute_bounds_image(result["input_ids"])

        if "audio_encoding" in multimodal_inputs and multimodal_inputs["audio_encoding"]:
            result["audio_encoding"] = multimodal_inputs["audio_encoding"]
            result["audio_encoding"]["audio_bounds"] = self.compute_bounds_audio(result["input_ids"])

        return result

    def compute_bounds_image(self, input_ids: torch.Tensor) -> torch.Tensor:
        image_start_ids = (
            torch.where((input_ids == self.tokenizer.im_start_id) | (input_ids == self.tokenizer.slice_start_id))[0] + 1
        )
        image_end_ids = torch.where(
            (input_ids == self.tokenizer.im_end_id) | (input_ids == self.tokenizer.slice_end_id)
        )[0]

        valid_image_nums = max(len(image_start_ids), len(image_end_ids))
        bounds_image = torch.hstack(
            [
                image_start_ids[:valid_image_nums].unsqueeze(-1),
                image_end_ids[:valid_image_nums].unsqueeze(-1),
            ]
        )
        return bounds_image

    def compute_bounds_audio(self, input_ids: torch.Tensor) -> torch.Tensor:
        audio_bos_ids = torch.where(input_ids == self.tokenizer.audio_start_id)[0]
        audio_eos_ids = torch.where(input_ids == self.tokenizer.audio_end_id)[0]
        bounds_audio = torch.stack([audio_bos_ids, audio_eos_ids], 1)
        return bounds_audio

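    # Worked example (illustrative): if input_ids contains
    #     [..., im_start_id, t1, t2, im_end_id, ...]
    # with the opening tag at position 10 and the closing tag at position 13,
    # compute_bounds_image returns tensor([[11, 13]]): the start bound points just
    # past the opening tag and the end bound at the closing tag. compute_bounds_audio
    # keeps the positions of the start/end tokens themselves, one (start, end) row
    # per audio segment.
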
    def process_multimodal_inputs(
        self,
        text: str,
        images: Optional[ImageInput] = None,
        audios: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None,
        return_tensors: ReturnTensorType = TensorType.PYTORCH,
        **kwargs,
    ):
        # (NOTE) Only a single pair of multimodal inputs is supported at the moment.
        # (TODO) Check whether a single multimodal input is allowed.
        if text is None and images is None and audios is None:
            raise ValueError("At least one of text, images or audio must be provided")

        image_processor_kwargs, audio_feature_extractor_kwargs = {}, {}
        if kwargs:
            image_processor_kwargs = {
                k: v for k, v in kwargs.items() if k in self.image_processor._valid_processor_keys
            }
            audio_feature_extractor_kwargs = {
                k: v for k, v in kwargs.items() if k in self.audio_feature_extractor._valid_processor_keys
            }

        multimodal_encodings = {
            "image_encoding": None,
            "audio_encoding": None,
        }

        if images:
            image_encoding = self.process_image(
                images,
                return_tensors=return_tensors,
                **image_processor_kwargs,
            )
            text = self.insert_image_feature_placeholders(text, image_encoding)
            multimodal_encodings["image_encoding"] = image_encoding

        if audios:
            audio_encoding = self.process_audio(
                audios,
                **audio_feature_extractor_kwargs,
            )
            text = self.insert_audio_feature_placeholders(text, audio_encoding)
            multimodal_encodings["audio_encoding"] = audio_encoding

        return text, multimodal_encodings

    def insert_image_feature_placeholders(
        self,
        prompt: str,
        image_features: ImageBatchFeature,
        max_slice_nums: Optional[int] = None,
        use_image_id: Optional[bool] = None,
    ) -> str:
        # Check that the number of image tags matches the number of images.
        img_tags = re.findall(self._image_placeholder, prompt)
        assert len(img_tags) == len(
            image_features.image_sizes
        ), f"the number of image tags must match the number of images, got {len(img_tags)} and {len(image_features.image_sizes)}"

        # Replace image tags with image placeholders.
        text_chunks = prompt.split(self._image_placeholder)
        final_text = ""
        for i in range(len(img_tags)):
            final_text += text_chunks[i] + self.image_processor.get_slice_image_placeholder(
                image_features.image_sizes[i],
                i,
                max_slice_nums,
                use_image_id,
            )
        final_text += text_chunks[-1]

        return final_text

    def insert_audio_feature_placeholders(
        self,
        prompt: str,
        audio_features: AudioBatchFeature,
    ) -> str:
        # Check that the number of audio tags matches the number of audio clips.
        audio_tags = re.findall(self._audio_placeholder, prompt)
        assert len(audio_tags) == len(
            audio_features.input_audios
        ), "the number of audio tags must match the number of audios"

        # Replace audio tags with audio placeholders.
        text_chunks = prompt.split(self._audio_placeholder)
        final_text = ""
        for idx in range(len(audio_features.input_audios)):
            final_text += text_chunks[idx] + (
                self.tokenizer.audio_start
                + self.tokenizer.unk_token * audio_features.audio_span_tokens[idx]
                + self.tokenizer.audio_end
            )
        final_text += text_chunks[-1]

        return final_text

    def process_audio(
        self,
        audio_input: AudioInput,
        return_tensors: ReturnTensorType = TensorType.PYTORCH,
        **kwargs,
    ) -> AudioBatchFeature:
        if isinstance(audio_input, (str, bytes, np.ndarray)):
            audio_input = [audio_input]
        elif not isinstance(audio_input, list):
            raise ValueError("audio_input must be a path, bytes, a numpy waveform, or a list of these")

        # Paths and raw bytes are decoded with ffmpeg; numpy waveforms are used as-is.
        inputs = [x if isinstance(x, np.ndarray) else load_audio(x) for x in audio_input]

        features = self.audio_feature_extractor(
            inputs,
            sampling_rate=self.audio_feature_extractor.sampling_rate,
            return_attention_mask=True,
            return_token_timestamps=True,
            padding="max_length",
            return_tensors=return_tensors,
            **kwargs,
        )

        input_lengths = features["num_frames"]
        input_lengths = (input_lengths - 1) // 2 + 1
        output_lengths = (input_lengths - 2) // 2 + 1
        input_audio_lengths = torch.stack([input_lengths, output_lengths], dim=1)
        audio_span_tokens = (output_lengths + 2).tolist()  # add bos and eos tokens

        data = {
            "input_audios": features["input_features"],
            "input_audio_lengths": input_audio_lengths,
            "audio_span_tokens": audio_span_tokens,
        }

        # Tensor types are already converted by `self.audio_feature_extractor`.
        return AudioBatchFeature(data={**data})

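    # Worked example (illustrative): if the feature extractor reports
    # num_frames == 1000 for a clip, then
    #     input_lengths  = (1000 - 1) // 2 + 1 = 500
    #     output_lengths = (500 - 2) // 2 + 1 = 250
    #     audio_span_tokens = 250 + 2 = 252   # including the audio BOS/EOS tokens
    # so 252 placeholder tokens are reserved for that clip in the prompt.
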
    def pad_images(
        self,
        pixel_values_list: List[torch.Tensor],
        tgt_sizes: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pad images to the same size and return the padded pixel values and the patch attention mask.

        Sliced patches may have different sizes. We pad them to the same size and return the padded
        pixel values together with the corresponding patch attention mask.
        """
        all_pixel_values = []
        for pixel_value in pixel_values_list:
            all_pixel_values.append(pixel_value.flatten(end_dim=1).permute(1, 0))

        max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
        all_pixel_values = torch.nn.utils.rnn.pad_sequence(all_pixel_values, batch_first=True, padding_value=0.0)
        B, L, _ = all_pixel_values.shape
        all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)

        patch_attention_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool)
        for i in range(B):
            patch_attention_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True

        return all_pixel_values, patch_attention_mask

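    # Note on the mask layout: slice i contributes tgt_sizes[i][0] * tgt_sizes[i][1]
    # real patch positions, so patch_attention_mask[i, 0, :n_real] is True and the
    # positions zero-padded by pad_sequence above remain False.
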
    def process_image(
        self,
        image_input: ImageInput,
        do_pad: bool = True,
        max_slice_nums: Optional[int] = None,
        return_tensors: ReturnTensorType = TensorType.PYTORCH,
        **kwargs,
    ) -> ImageBatchFeature:
        if isinstance(image_input, list):
            image_input = [load_image(x) for x in image_input]
        elif isinstance(image_input, (str, bytes, PIL.Image.Image)):
            image_input = [load_image(image_input)]
        else:
            raise ValueError(f"image_input must be a path or bytes or a list of paths/bytes, not: {type(image_input)}")

        image_features = self.image_processor(
            image_input,
            do_pad=do_pad,
            max_slice_nums=max_slice_nums,
            return_tensors=return_tensors,
            **kwargs,
        )

        # Multiple images are packed into the first element of the list. We unpack them here.
        assert len(image_features.pixel_values) == 1, "images should be packed into one list."
        pixel_values = image_features.pixel_values[0]
        tgt_sizes = image_features.tgt_sizes[0]
        image_sizes = image_features.image_sizes[0]

        pixel_values, patch_attention_mask = self.pad_images(pixel_values, tgt_sizes)

        data = {
            "pixel_values": pixel_values,
            "image_sizes": image_sizes,
            "tgt_sizes": tgt_sizes,
            "patch_attention_mask": patch_attention_mask,
        }

        # Tensor types are already converted by `self.image_processor`.
        return ImageBatchFeature(data=data)

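    # Usage sketch (illustrative only; the file names are hypothetical and `processor`
    # refers to a constructed MegrezOProcessor):
    #
    #     feats = processor.process_image(["img_a.jpg", "img_b.jpg"])
    #     feats.pixel_values          # slices padded to a common size
    #     feats.patch_attention_mask  # True for real patches, False for padding
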
    def data_collator(self, examples, padding_value=0, max_length=4096, collate_labels=False):
        """Collate data for the MegrezO model.

        This function batches examples by trimming and padding the input_ids, position_ids, and
        attention_mask tensors. For bounds tensors, it prepends the batch index to each bound.
        """
        # (TODO) Remove this function?

        def trim_and_pad(seq, batch_first, padding_value):
            return pad_sequence(
                [s[:max_length] for s in seq],
                batch_first=batch_first,
                padding_value=padding_value,
            )

        input_ids = trim_and_pad(
            [example["input_ids"] for example in examples],
            batch_first=True,
            padding_value=padding_value,
        )
        position_ids = trim_and_pad(
            [example["position_ids"] for example in examples],
            batch_first=True,
            padding_value=padding_value,
        )

        attention_mask = trim_and_pad(
            [example["attention_mask"] for example in examples],
            batch_first=True,
            padding_value=padding_value,
        )

        image_encoding_list = {
            "pixel_values": [],
            "image_bounds": [],
            "tgt_sizes": [],
            "patch_attention_mask": [],
        }
        for bid, example in enumerate(examples):
            image_encoding = example.get("image_encoding")
            if not image_encoding:
                continue

            image_encoding_list["pixel_values"].append(image_encoding["pixel_values"])
            image_encoding_list["tgt_sizes"].append(image_encoding["tgt_sizes"])
            image_encoding_list["patch_attention_mask"].append(image_encoding["patch_attention_mask"])

            # (TODO) Remove?
            # Prepend the batch index to each bound: (bid, start, end).
            bounds_with_bid = image_encoding["image_bounds"].clone()
            bounds_with_bid = torch.hstack(
                [
                    torch.full((bounds_with_bid.size(0), 1), bid, dtype=bounds_with_bid.dtype),
                    bounds_with_bid,
                ]
            )
            image_encoding_list["image_bounds"].append(bounds_with_bid)

        audio_encoding_list = {
            "input_audios": [],
            "input_audio_lengths": [],
            "audio_span_tokens": [],
            "audio_bounds": [],
        }
        for bid, example in enumerate(examples):
            audio_encoding = example.get("audio_encoding")
            if not audio_encoding:
                continue

            audio_encoding_list["input_audios"].append(audio_encoding["input_audios"])
            audio_encoding_list["input_audio_lengths"].append(audio_encoding["input_audio_lengths"])
            audio_encoding_list["audio_span_tokens"].extend(audio_encoding["audio_span_tokens"])
            bounds_with_bid = audio_encoding["audio_bounds"].clone()
            bounds_with_bid = torch.hstack(
                [
                    torch.full((bounds_with_bid.size(0), 1), bid, dtype=bounds_with_bid.dtype),
                    bounds_with_bid,
                ]
            )
            audio_encoding_list["audio_bounds"].append(bounds_with_bid)

        result = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "image_encoding": None,
            "audio_encoding": None,
        }

        if collate_labels:
            labels = trim_and_pad(
                [example["labels"] for example in examples],
                batch_first=True,
                padding_value=-100,
            )
            result["labels"] = labels

        if any(image_encoding_list.values()):
            result["image_encoding"] = {
                "pixel_values": torch.vstack(image_encoding_list["pixel_values"]),
                "tgt_sizes": torch.vstack(image_encoding_list["tgt_sizes"]),
                "patch_attention_mask": torch.vstack(image_encoding_list["patch_attention_mask"]),
                "image_bounds": torch.vstack(image_encoding_list["image_bounds"]),
            }
        if any(audio_encoding_list.values()):
            result["audio_encoding"] = {
                "input_audios": torch.vstack(audio_encoding_list["input_audios"]),
                "input_audio_lengths": torch.vstack(audio_encoding_list["input_audio_lengths"]),
                "audio_span_tokens": audio_encoding_list["audio_span_tokens"],
                "audio_bounds": torch.vstack(audio_encoding_list["audio_bounds"]),
            }
        return result
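
# Collation sketch (illustrative only; `conversation_a` / `conversation_b` are
# hypothetical conversation lists in the format accepted by __call__):
#
#     merged_a = processor(conversation_a, apply_data_collator=False)
#     merged_b = processor(conversation_b, apply_data_collator=False)
#     batch = processor.data_collator([merged_a, merged_b])
#     # Image/audio bounds in `batch` carry a leading batch index: (bid, start, end).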