213 lines
9.1 KiB
Python
213 lines
9.1 KiB
Python
from PIL import ImageOps
|
|
from PIL.Image import Image
|
|
|
|
import torch
|
|
|
|
from typing import Union, List
|
|
from tqdm import tqdm
|
|
|
|
from transformers.image_utils import ImageInput
|
|
from transformers.tokenization_utils_base import TextInput
|
|
from transformers import CLIPImageProcessor
|
|
from transformers.processing_utils import (
|
|
ProcessorMixin,
|
|
)
|
|
from transformers import AutoTokenizer, PreTrainedTokenizer
|
|
|
|
from .image_processing_instellavl import InstellaVLImageProcessor
|
|
from .mm_utils import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, KeywordsStoppingCriteria
|
|
from .conversation import conv_templates
|
|
|
|
def tokenizer_image_token(prompt: str, tokenizer: PreTrainedTokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None)->Union[torch.Tensor, List[torch.Tensor]]:
    r"""
    Tokenize ``prompt`` and splice ``image_token_index`` in where each image placeholder was.

    The prompt is split on ``DEFAULT_IMAGE_TOKEN`` ("<image>"). Each text segment is
    tokenized on its own, and the segments' ids are joined with exactly one
    ``image_token_index`` between neighbouring segments.

    BOS handling: if the first segment's ids start with ``tokenizer.bos_token_id``,
    the first segment is kept whole (so that BOS appears once, at the very start),
    and the first id of every *later* segment is dropped. This presumably removes the
    BOS the tokenizer prepends to each separate call. NOTE(review): confirm this holds
    for the tokenizer in use, since a non-BOS first id of a later segment would also be
    dropped.

    Args:
        prompt (str): Text containing zero or more ``DEFAULT_IMAGE_TOKEN`` placeholders.
        tokenizer (PreTrainedTokenizer): Called once per segment; its ``input_ids`` are used.
        image_token_index (int): Id inserted for each placeholder. Defaults to ``IMAGE_TOKEN_INDEX``.
        return_tensors (str, optional): ``None`` to get a Python list, or ``"pt"`` to get a
            ``torch.long`` tensor.

    Returns:
        list or torch.Tensor: The spliced sequence of ids (a list of ids, despite the
        broader return annotation).

    Raises:
        ValueError: If ``return_tensors`` is neither ``None`` nor ``"pt"``.
    """
    pieces = [tokenizer(segment).input_ids for segment in prompt.split(DEFAULT_IMAGE_TOKEN)]

    # str.split always yields at least one segment, so pieces[0] exists.
    head = pieces[0]
    starts_with_bos = len(head) > 0 and head[0] == tokenizer.bos_token_id
    skip = 1 if starts_with_bos else 0

    token_ids = [head[0]] if starts_with_bos else []
    for position, piece in enumerate(pieces):
        if position:
            # One image id between consecutive text segments.
            token_ids.append(image_token_index)
        token_ids.extend(piece[skip:])

    if return_tensors is None:
        return token_ids
    if return_tensors != "pt":
        raise ValueError(f"Unsupported tensor type: {return_tensors}")
    return torch.tensor(token_ids, dtype=torch.long)
|
|
|
|
|
|
class InstellaVLProcessor(ProcessorMixin):
    """Processor that pairs an image processor with a tokenizer for Instella-VL.

    Methods:
        ``encode``: builds generation-ready inputs for one prompt (optionally with
            image(s)) using the ``"instella"`` conversation template.
        ``batch_encode``: applies ``encode`` to every sample of a batch.
        ``decode``: turns generated token ids back into text.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "GPTNeoXTokenizerFast"

    def __init__(self, image_processor: InstellaVLImageProcessor = None, tokenizer: AutoTokenizer = None, **kwargs):
        super().__init__(image_processor, tokenizer, **kwargs)

    def pad_sequence(self, input_ids: Union[List[torch.Tensor], List[List[torch.Tensor]]], batch_first: bool, padding_value: int, tokenizer: AutoTokenizer):
        """Pad variable-length id sequences into a single tensor.

        Args:
            input_ids: Sequences to pad. Each one is reversed with
                ``torch.flip(..., [0])`` for left padding, so they are expected to be
                1-D tensors.
            batch_first: Forwarded to ``torch.nn.utils.rnn.pad_sequence``. If True the
                result is laid out ``(batch, time)``, otherwise ``(time, batch)``.
            padding_value: Value written into padded positions.
            tokenizer: Only ``tokenizer.padding_side`` is read. ``"left"`` pads on the
                left; any other value pads on the right.

        Returns:
            torch.Tensor: The padded batch.
        """
        pad_left = tokenizer.padding_side == "left"
        if pad_left:
            # pad_sequence pads on the right here, so reverse each sequence,
            # right-pad, then reverse the result back along the time axis.
            input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
        if pad_left:
            # The time axis is dim 1 for (batch, time) and dim 0 for (time, batch).
            # Flipping dim 1 unconditionally would reverse the batch order when
            # batch_first=False.
            time_dim = 1 if batch_first else 0
            input_ids = torch.flip(input_ids, [time_dim])
        return input_ids

    def encode(self,
               text: TextInput = None,
               images: ImageInput = None,
               image_processor: CLIPImageProcessor = None,
               tokenizer: AutoTokenizer = None,
               model_cfg: dict = None,
               ) -> dict:
        """Encode one prompt (and optional image(s)) into model-ready inputs.

        Args:
            text: The user prompt. All ``DEFAULT_IMAGE_TOKEN`` occurrences are removed.
                A single image token is then put back at the front when images were
                supplied and processed to a non-empty result.
            images: ``None``, a single PIL image, or a list of PIL images. Images are
                EXIF-transposed *in place*, so the caller's image objects are modified.
            image_processor: Forwarded as the second argument of
                ``self.image_processor.process``.
            tokenizer: Tokenizer used to build the prompt ids. Falls back to
                ``self.tokenizer`` when ``None``. It is stored on ``self.tokenizer``
                afterwards, which is what ``decode`` uses.
            model_cfg: Forwarded as the third argument of
                ``self.image_processor.process``.

        Returns:
            dict: Always contains these keys:

            - ``input_ids``: tensor with a leading batch dimension of 1.
            - ``stopping_criteria``: a one-element list holding a
              ``KeywordsStoppingCriteria`` built from the template separator.
            - ``eos_token_id``: a list of terminator ids.

            When ``images`` is given, the dict also contains ``image_tensor`` (the
            processor's ``pixel_values``) and ``image_sizes`` (the PIL ``.size`` of
            each image).

        Raises:
            ValueError: If ``text`` is ``None``.
            TypeError: If ``images`` is not ``None``, a PIL image, or a list.
        """
        if text is None:
            raise ValueError("Text must be provided for encoding.")
        if tokenizer is None:
            tokenizer = self.tokenizer

        if images is not None:
            if isinstance(images, Image):
                # Handle images with EXIF orientation tags, which PIL will ignore by default
                # https://github.com/python-pillow/Pillow/issues/4703
                ImageOps.exif_transpose(images, in_place=True)
                image_sizes = [images.size]
                images = [images]
            elif isinstance(images, list):
                image_sizes = []
                for img in images:
                    ImageOps.exif_transpose(img, in_place=True)
                    image_sizes.append(img.size)
            else:
                # Without this branch `image_sizes` would be unbound when building the output.
                raise TypeError(
                    f"`images` must be a PIL image or a list of PIL images, got {type(images).__name__}."
                )
            image_tensor = self.image_processor.process(images, image_processor, model_cfg)['pixel_values']

        text = text.replace(DEFAULT_IMAGE_TOKEN, "").strip()
        # str.replace makes a single pass, so removing the token can leave a new
        # occurrence behind; the membership check guards against adding a second one.
        if images is not None and len(image_tensor) != 0 and DEFAULT_IMAGE_TOKEN not in text:
            question = DEFAULT_IMAGE_TOKEN + "\n" + text
        else:
            question = text

        conv = conv_templates["instella"].copy()
        conv.append_message(conv.roles[0], question)
        # Empty second-role turn; presumably leaves the prompt open for the reply.
        conv.append_message(conv.roles[1], None)
        prompt_question = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0)
        keywords = [conv.sep]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("|||IP_ADDRESS|||")]

        out = {
            "input_ids": input_ids,
            "stopping_criteria": [stopping_criteria],
            "eos_token_id": terminators,
        }
        if images is not None:
            out = {
                "image_tensor": image_tensor,
                "image_sizes": image_sizes,
                **out,
            }
        self.tokenizer = tokenizer
        return out

    def batch_encode(self,
                     texts: List[TextInput] = None,
                     images: List[ImageInput] = None,
                     image_processor: CLIPImageProcessor = None,
                     tokenizer: AutoTokenizer = None,
                     model_cfg: dict = None,
                     ):
        """Encode each (text, image) pair independently with ``encode``.

        No padding or collation is performed. The result is a list holding one
        ``encode`` output dict per sample, in input order.

        Args:
            texts: List of prompts.
            images: Per-sample images (each ``None``, a PIL image, or a list of them),
                or ``None`` for a text-only batch.
            image_processor, tokenizer, model_cfg: Forwarded unchanged to ``encode``.

        Returns:
            list[dict]: One ``encode`` result per sample.

        Raises:
            ValueError: If ``texts`` is ``None``.
            AssertionError: If ``texts`` is not a list, or if ``texts`` and ``images``
                differ in length.
        """
        if texts is None:
            raise ValueError("Text must be provided for batch encoding.")
        assert isinstance(texts, list), "Since batch encoding happening, provide batch of texts in a list."

        if images is None:
            # Text-only batch: pair every prompt with no image.
            images = [None] * len(texts)
        assert len(texts) == len(images), "The number of texts and images must be equal."

        batch_outs = []
        for txt, img in tqdm(zip(texts, images), total=len(texts), desc="Total Samples to encode"):
            batch_outs.append(self.encode(txt, img, image_processor, tokenizer, model_cfg))
        return batch_outs

    def decode(self, output_ids: torch.Tensor)->str:
        """Decode the first row of a 2-D id tensor with ``self.tokenizer``.

        Special tokens are skipped and surrounding whitespace is stripped.
        """
        return self.tokenizer.decode(output_ids[0, :], skip_special_tokens=True).strip()

    def batch_decode(self, output_ids_lst: List[torch.Tensor])->List[str]:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError("Batch decode is not implemented for InstellaVLProcessor")
|
|
|
|
|
|
|
|
# Register this custom processor with transformers' auto-class machinery.
# NOTE(review): the default target is presumably `AutoProcessor`, so that a saved
# checkpoint can resolve this class via remote code; confirm against the installed
# transformers version.
InstellaVLProcessor.register_for_auto_class()
|