31 lines
910 B
Python
31 lines
910 B
Python
|
from typing import List, Optional

from PIL.Image import Image
from transformers import CLIPImageProcessor
from transformers.image_processing_utils import BaseImageProcessor

from .mm_utils import process_images
|
||
|
|
||
|
# TODO: this class could inherit from CLIPImageProcessor instead and use its processing function directly.
|
||
|
class InstellaVLImageProcessor(BaseImageProcessor):
    r"""Image processor for Instella-VL that delegates to ``process_images``.

    The image transformation itself is performed by ``process_images`` from
    ``.mm_utils``; this class only wraps that output in a
    ``{"pixel_values": ...}`` mapping.
    """

    def __init__(self, **kwargs):
        # No extra state of its own; all keyword arguments go to the base class.
        super().__init__(**kwargs)

    def process(
        self,
        images: Optional[List[Image]],
        processor: CLIPImageProcessor,
        model_cfg: dict,
    ) -> dict:
        """Pre-process a batch of images.

        Args:
            images: PIL images to pre-process, or ``None`` when there are no
                images to process.
            processor: CLIP image processor, forwarded unchanged to
                ``process_images``.
            model_cfg: Model configuration, forwarded unchanged to
                ``process_images``.

        Returns:
            A dict with a single ``"pixel_values"`` key. Its value is the
            output of ``process_images``, or ``None`` when ``images`` is
            ``None``.
        """
        # Short-circuit before calling process_images so it is never handed
        # a None input.
        if images is None:
            return {"pixel_values": None}
        return {"pixel_values": process_images(images, processor, model_cfg)}
|
||
|
|
||
|
# Module-level side effect: register this processor class via the
# transformers base-class hook.
# NOTE(review): called with no argument, this presumably targets the library's
# default auto class (AutoImageProcessor), so the class can be loaded as
# custom code -- confirm against the transformers version in use.
InstellaVLImageProcessor.register_for_auto_class()
|