first commit

xxl 2024-12-26 10:13:27 +08:00
parent 6ac9748b59
commit 707e9fd50f
22 changed files with 3186 additions and 2 deletions

README.md

@ -1,3 +1,67 @@
# DocOwl2_a14065830742978560965625
---
frameworks:
- Pytorch
license: Apache License 2.0
tasks:
- document-understanding
---
DocOwl2
# mPLUG-DocOwl2
## Introduction
mPLUG-DocOwl2 is a state-of-the-art Multimodal LLM for OCR-free Multi-page Document Understanding.
Through a compression module named High-resolution DocCompressor, each page is encoded with just 324 tokens.
GitHub: [mPLUG-DocOwl](https://github.com/X-PLUG/mPLUG-DocOwl)
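As an illustrative sanity check of the 324-token figure (computed here from the values in this repository's config.json, not an official derivation): each 504x504 view is split into 14x14 patches by the vision encoder, and the H-Reducer merges every 1x4 horizontal group of visual features.
```python
# Illustrative count, assuming image_size=504, patch_size=14 and
# H-Reducer conv_shape='1x4' as listed in config.json.
patches_per_side = 504 // 14        # 36
vit_tokens = patches_per_side ** 2  # 1296 visual tokens per 504x504 view
after_hreducer = vit_tokens // 4    # 1x4 horizontal merge -> 324 tokens
print(patches_per_side, vit_tokens, after_hreducer)  # 36 1296 324
```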
## Download
Download with the ModelScope SDK:
```bash
# Install ModelScope
pip install modelscope
```
```python
# Download the model with the ModelScope SDK
from modelscope import snapshot_download
model_dir = snapshot_download('iic/DocOwl2')
```
Or clone the repository with Git:
```bash
# Download the model with Git
git clone https://www.modelscope.cn/iic/DocOwl2.git
```
## Quickstart
```python
import torch
from modelscope import AutoTokenizer, AutoModel

class DocOwlInfer():
    def __init__(self, ckpt_path):
        self.tokenizer = AutoTokenizer.from_pretrained(ckpt_path, use_fast=False)
        self.model = AutoModel.from_pretrained(
            ckpt_path,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
            device_map='auto',
        )
        self.model.init_processor(tokenizer=self.tokenizer, basic_image_size=504, crop_anchors='grid_12')

    def inference(self, images, query):
        messages = [{'role': 'USER', 'content': '<|image|>' * len(images) + query}]
        answer = self.model.chat(messages=messages, images=images, tokenizer=self.tokenizer)
        return answer

docowl = DocOwlInfer(ckpt_path='$your_model_local_dir')

images = [
    '$your_model_local_dir' + '/examples/docowl2_page0.png',
    '$your_model_local_dir' + '/examples/docowl2_page1.png',
    '$your_model_local_dir' + '/examples/docowl2_page2.png',
    '$your_model_local_dir' + '/examples/docowl2_page3.png',
    '$your_model_local_dir' + '/examples/docowl2_page4.png',
    '$your_model_local_dir' + '/examples/docowl2_page5.png',
]

answer = docowl.inference(images, query='what is this paper about? provide detailed information.')
answer = docowl.inference(images, query='what is the third page about? provide detailed information.')
```
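If the checkpoint was fetched with `snapshot_download` above, the returned directory can be used directly as `ckpt_path`. A minimal sketch (the query text is only an example):
```python
from modelscope import snapshot_download

model_dir = snapshot_download('iic/DocOwl2')   # local checkpoint directory
docowl = DocOwlInfer(ckpt_path=model_dir)      # DocOwlInfer as defined above
pages = [f'{model_dir}/examples/docowl2_page{i}.png' for i in range(6)]
print(docowl.inference(pages, query='Summarize the first two pages.'))
```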

config.json Normal file

@ -0,0 +1,59 @@
{
"architectures": [
"mPLUGDocOwl2"
],
"auto_map": {
"AutoConfig": "configuration_mplug_docowl.MPLUGDocOwlConfig",
"AutoModel": "modeling_mplug_docowl.MPLUGDocOwl2",
"AutoModelForCausalLM": "modeling_mplug_docowl.MPLUGDocOwl2"
},
"attention_bias": false,
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 11008,
"max_position_embeddings": 2048,
"model_type": "mplug_docowl",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"num_key_value_heads": 32,
"pretraining_tp": 1,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 10000.0,
"tie_word_embeddings": false,
"transformers_version": "4.39.3",
"use_cache": true,
"visual_config": {
"visual_hrcompressor": {
"layer": 2,
"high_reso_cross_num_att_heads": 16,
"high_reso_cross_hid_size": 4096,
"high_reso_cross_dropout": 0.0
},
"visual_hreducer": {
"conv_shape": "1x4",
"hidden_size": 1024
},
"visual_model": {
"attention_dropout": 0.0,
"hidden_act": "quick_gelu",
"hidden_size": 1024,
"image_size": 504,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-06,
"model_type": "mplug_owl_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 24,
"patch_size": 14,
"projection_dim": 768,
"use_flash_attn": false
}
},
"vocab_size": 32000
}
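Because `auto_map` routes the auto classes to the custom code above, the configuration can be inspected without loading any weights. A small sketch with `transformers` (the path is a placeholder for your local checkpoint directory):
```python
from transformers import AutoConfig

# trust_remote_code is needed so configuration_mplug_docowl.py is used
config = AutoConfig.from_pretrained('path/to/DocOwl2', trust_remote_code=True)
print(config.model_type)                                      # mplug_docowl
print(config.visual_config['visual_model']['image_size'])     # 504
print(config.visual_config['visual_hreducer']['conv_shape'])  # 1x4
```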

configuration.json Normal file

@ -0,0 +1 @@
{"framework":"Pytorch","task":"document-understanding"}

configuration_mplug_docowl.py Normal file

@ -0,0 +1,358 @@
# Copyright (c) Alibaba.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import copy
import os
from typing import Union
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.utils import logging
from transformers.models.auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)  # used by the warnings in the from_pretrained methods below
class LlamaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the LLaMA-7B.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`LlamaModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
Llama 2 up to 4096, CodeLlama up to 16384.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
pretraining_tp (`int`, *optional*, defaults to 1):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
issue](https://github.com/pytorch/pytorch/issues/76232).
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
```python
>>> from transformers import LlamaModel, LlamaConfig
>>> # Initializing a LLaMA llama-7b style configuration
>>> configuration = LlamaConfig()
>>> # Initializing a model from the llama-7b style configuration
>>> model = LlamaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "llama"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
self.attention_bias = attention_bias
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
class MplugOwlVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MplugOwlVisionModel`]. It is used to instantiate
an mPLUG-Owl vision encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the mPLUG-Owl
[x-plug/x_plug-llama-7b](https://huggingface.co/x-plug/x_plug-llama-7b) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 32):
The size (resolution) of each patch.
hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
initializer_factor (`float`, *optional*, defaults to 1):
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
testing).
"""
model_type = "mplug_owl_vision_model"
def __init__(
self,
hidden_size=1024,
intermediate_size=4096,
projection_dim=768,
num_hidden_layers=24,
num_attention_heads=16,
num_channels=3,
image_size=448,
patch_size=14,
hidden_act="quick_gelu",
layer_norm_eps=1e-6,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
use_flash_attn=False,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.projection_dim = projection_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.initializer_range = initializer_range
self.initializer_factor = initializer_factor
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.use_flash_attn = use_flash_attn
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the vision config dict if we are loading from MplugOwlConfig
if config_dict.get("model_type") == "mplug-owl":
config_dict = config_dict["vision_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
class MplugDocOwlHReducerConfig(PretrainedConfig):
model_type = "mplug_docowl_hreducer"
def __init__(
self,
hidden_size=1024,
initializer_range=0.02,
layer_norm_eps=1e-6,
conv_shape='1x4',
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.conv_shape = conv_shape
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the visual_abstractor config dict if we are loading from MplugOwlConfig
if config_dict.get("model_type") == "mplug-docowl":
config_dict = config_dict["hreducer_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
class MplugDocOwlHRDocCompressorConfig(PretrainedConfig):
model_type = "mplug_docowl_hrcompressor"
def __init__(
self,
initializer_range=0.02,
layer_norm_eps=1e-6,
layer=2,
high_reso_cross_num_att_heads=16,
high_reso_cross_hid_size=4096,
high_reso_cross_dropout=0.0,
**kwargs,
):
super().__init__(**kwargs)
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.layer = layer
self.high_reso_cross_num_att_heads=high_reso_cross_num_att_heads
self.high_reso_cross_hid_size=high_reso_cross_hid_size
self.high_reso_cross_dropout=high_reso_cross_dropout
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the visual_abstractor config dict if we are loading from MplugOwlConfig
if config_dict.get("model_type") == "mplug-docowl":
config_dict = config_dict["hrcompressor_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
DEFAULT_VISUAL_CONFIG = {
"visual_model": MplugOwlVisionConfig().to_dict(),
"visual_hreducer": MplugDocOwlHReducerConfig().to_dict(),
"visual_hrcompressor": MplugDocOwlHRDocCompressorConfig().to_dict()
}
class MPLUGDocOwlConfig(LlamaConfig):
model_type = "mplug_docowl"
def __init__(self, visual_config=None, **kwargs):
if visual_config is None:
self.visual_config = DEFAULT_VISUAL_CONFIG
else:
self.visual_config = visual_config
super().__init__(
**kwargs,
)
if __name__ == "__main__":
print(MplugOwlVisionConfig().to_dict())
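For reference, a short sketch of how these config classes compose, assuming the file is importable from the checkpoint directory; the `rope_scaling` values are hypothetical and only exercise the validation above:
```python
from configuration_mplug_docowl import MPLUGDocOwlConfig

# Default construction fills visual_config with the three default sub-configs.
cfg = MPLUGDocOwlConfig(rope_scaling={'type': 'linear', 'factor': 2.0})
print(cfg.model_type)                    # mplug_docowl
print(sorted(cfg.visual_config.keys()))  # ['visual_hrcompressor', 'visual_hreducer', 'visual_model']

# An unsupported scaling type triggers _rope_scaling_validation.
try:
    MPLUGDocOwlConfig(rope_scaling={'type': 'ntk', 'factor': 2.0})
except ValueError as err:
    print(err)
```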

constants.py Normal file

@ -0,0 +1,9 @@
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15
LOGDIR = "./demo_logs"
# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<|image|>"
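These constants drive the multimodal prompt encoding: `DEFAULT_IMAGE_TOKEN` is the textual placeholder and `IMAGE_TOKEN_INDEX` is the sentinel id spliced into `input_ids` where visual features are later inserted (see `processor.py` and `prepare_inputs_labels_for_multimodal` below). A simplified, self-contained sketch of that splicing, with a toy tokenizer standing in for the real one:
```python
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<|image|>"

def splice_image_tokens(prompt, tokenize):
    """Replace each <|image|> placeholder with the -200 sentinel id."""
    chunks = [tokenize(chunk) for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
    ids = list(chunks[0])
    for chunk in chunks[1:]:
        ids.append(IMAGE_TOKEN_INDEX)  # later swapped for image embeddings
        ids.extend(chunk)
    return ids

toy_tokenize = lambda s: [len(w) for w in s.split()]  # stand-in, not a real tokenizer
print(splice_image_tokens("USER: <|image|>what is this page about?", toy_tokenize))
```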

BIN
examples/docowl2_page0.png Normal file (95 KiB)

Binary file not shown.

BIN
examples/docowl2_page1.png Normal file (65 KiB)

Binary file not shown.

BIN
examples/docowl2_page2.png Normal file (53 KiB)

Binary file not shown.

BIN
examples/docowl2_page3.png Normal file (62 KiB)

Binary file not shown.

BIN
examples/docowl2_page4.png Normal file (58 KiB)

Binary file not shown.

BIN
examples/docowl2_page5.png Normal file (52 KiB)

Binary file not shown.

generation_config.json Normal file

@ -0,0 +1,9 @@
{
"bos_token_id": 1,
"eos_token_id": 2,
"max_length": 4096,
"pad_token_id": 0,
"temperature": 0.9,
"top_p": 0.6,
"transformers_version": "4.31.0"
}
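Note that these sampling defaults (`temperature` 0.9, `top_p` 0.6) are only fallbacks for `generate`; the `chat` helper in `modeling_mplug_docowl.py` below passes `do_sample=False`, so decoding is greedy unless overridden. A small sketch of reading this file with `transformers` (placeholder path):
```python
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained('path/to/DocOwl2')  # reads generation_config.json
print(gen_cfg.max_length, gen_cfg.temperature, gen_cfg.top_p)  # 4096 0.9 0.6
```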

BIN
model.safetensors (Stored with Git LFS) Normal file

Binary file not shown.

modeling_llama2_mam.py Normal file (1,048 lines)

File diff suppressed because it is too large.

modeling_mplug_docowl.py Normal file

@ -0,0 +1,398 @@
# Copyright 2023 Haotian Liu & Qinghao Ye (Modified from LLaVA)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModelForCausalLM
from .modeling_llama2_mam import LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_mplug_docowl import (MPLUGDocOwlConfig, MplugOwlVisionConfig, MplugDocOwlHReducerConfig, MplugDocOwlHRDocCompressorConfig)
from .visual_encoder import MplugOwlVisionModel, MplugDocOwlHReducerModel
from .visual_compressor import MplugDocOwlHRDocCompressor
from .processor import DocProcessor
from .constants import IMAGE_TOKEN_INDEX, IGNORE_INDEX
from icecream import ic
from transformers import StoppingCriteria, TextStreamer
class KeywordsStoppingCriteria(StoppingCriteria):
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
self.keyword_ids = []
self.max_keyword_len = 0
for keyword in keywords:
cur_keyword_ids = tokenizer(keyword).input_ids
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
cur_keyword_ids = cur_keyword_ids[1:]
if len(cur_keyword_ids) > self.max_keyword_len:
self.max_keyword_len = len(cur_keyword_ids)
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
self.tokenizer = tokenizer
self.start_len = input_ids.shape[1]
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
for keyword_id in self.keyword_ids:
if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
return True
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
for keyword in self.keywords:
if keyword in outputs:
return True
return False
class MPLUGDocOwlMetaModel:
_no_split_modules = ["MplugOwlVisionModel", "MplugDocOwlHReducerModel", "MplugDocOwlHRDocCompressor"]
def __init__(self, config):
super(MPLUGDocOwlMetaModel, self).__init__(config)
self.vision_model = MplugOwlVisionModel(
MplugOwlVisionConfig(**config.visual_config["visual_model"])
)
v_img_row_tokens = int((config.visual_config["visual_model"]['image_size']/config.visual_config["visual_model"]['patch_size']))
v_img_col_tokens = v_img_row_tokens
self.vision2text = MplugDocOwlHReducerModel(
MplugDocOwlHReducerConfig(**config.visual_config["visual_hreducer"]), config.hidden_size
)
horizontal_reduce = int(config.visual_config["visual_hreducer"]['conv_shape'].split('x')[1])
v2t_img_col_tokens = int(v_img_row_tokens / horizontal_reduce)
self.hr_compressor = MplugDocOwlHRDocCompressor(
MplugDocOwlHRDocCompressorConfig(**config.visual_config["visual_hrcompressor"]),
config.hidden_size,
v2t_img_col_tokens
)
def get_vision_tower(self):
vision_model = getattr(self, 'vision_model', None)
if type(vision_model) is list:
vision_model = vision_model[0]
return vision_model
def get_vision2text(self):
vision2text = getattr(self, 'vision2text', None)
if type(vision2text) is list:
vision2text = vision2text[0]
return vision2text
def get_hrcompressor(self):
hrcompressor = getattr(self, 'hr_compressor', None)
if type(hrcompressor) is list:
hrcompressor = hrcompressor[0]
return hrcompressor
class MPLUGDocOwlMetaForCausalLM(ABC):
@abstractmethod
def get_model(self):
pass
def encode_images(self, images, patch_positions):
image_features = self.get_model().vision_model(images).last_hidden_state
image_features = self.get_model().vision2text(encoder_hidden_states=image_features)
image_features = self.get_model().hr_compressor(hidden_states=image_features, patch_positions=patch_positions)
return image_features
def prepare_inputs_labels_for_multimodal(
self, input_ids, attention_mask, past_key_values, labels, images, patch_positions
):
# ic(images.shape, patch_positions.shape)
if images is None or input_ids.shape[1] == 1:
if past_key_values is not None and images is not None and input_ids.shape[1] == 1:
attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
multiway_indices = torch.zeros_like(input_ids).long().to(self.device)
return input_ids, multiway_indices, attention_mask, past_key_values, None, labels
if type(images) is list or images.ndim == 5:
concat_images = torch.cat([image for image in images], dim=0)
image_features = self.encode_images(concat_images, patch_positions)
split_sizes = [image.shape[0] for image in images]
image_features = torch.split(image_features, split_sizes, dim=0)
image_features = [x.flatten(0, 1) for x in image_features]
else:
image_features = self.encode_images(images, patch_positions) # Sum(Crop Image Number) x L x d
new_input_embeds = []
new_modality_indicators = []
new_labels = [] if labels is not None else None
cur_image_idx = 0
for batch_idx, cur_input_ids in enumerate(input_ids):
if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
# multimodal LLM, but the current sample is not multimodal
# FIXME: this is a hacky fix, for deepspeed zero3 to work
half_len = cur_input_ids.shape[0] // 2
cur_image_features = image_features[cur_image_idx]
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
new_input_embeds.append(cur_input_embeds)
cur_modality_indicators = torch.zeros(len(cur_input_embeds)).long().to(self.device)
new_modality_indicators.append(cur_modality_indicators)
if labels is not None:
new_labels.append(labels[batch_idx])
cur_image_idx += 1
continue
image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
cur_new_input_embeds = []
cur_modality_indicators = []
if labels is not None:
cur_labels = labels[batch_idx]
cur_new_labels = []
assert cur_labels.shape == cur_input_ids.shape
while image_token_indices.numel() > 0:
cur_image_features = image_features[cur_image_idx]
image_token_start = image_token_indices[0]
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
cur_new_input_embeds.append(cur_image_features)
# Add modality indicator
assert image_token_start == len(cur_input_ids[:image_token_start])
cur_modality_indicators.append(torch.zeros(len(cur_input_ids[:image_token_start])).long())
cur_modality_indicators.append(torch.ones(len(cur_image_features)).long())
if labels is not None:
cur_new_labels.append(cur_labels[:image_token_start])
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
cur_labels = cur_labels[image_token_start+1:]
cur_image_idx += 1
cur_input_ids = cur_input_ids[image_token_start+1:]
image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
if cur_input_ids.numel() > 0:
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
cur_modality_indicators.append(torch.zeros(len(cur_input_ids)).long())
if labels is not None:
cur_new_labels.append(cur_labels)
cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
new_input_embeds.append(cur_new_input_embeds)
# Modality
cur_modality_indicators = [x.to(device=self.device) for x in cur_modality_indicators]
cur_modality_indicators = torch.cat(cur_modality_indicators, dim=0)
new_modality_indicators.append(cur_modality_indicators)
if labels is not None:
cur_new_labels = torch.cat(cur_new_labels, dim=0)
new_labels.append(cur_new_labels)
if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
max_len = max(x.shape[0] for x in new_input_embeds)
# Embedding
new_input_embeds_align = []
for cur_new_embed in new_input_embeds:
cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
new_input_embeds_align.append(cur_new_embed)
new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
# Modality
new_modality_indicators_align = []
for cur_modality_indicator in new_modality_indicators:
cur_new_embed = torch.cat((cur_modality_indicator, torch.zeros(max_len - cur_modality_indicator.shape[0], dtype=cur_modality_indicator.dtype, device=cur_modality_indicator.device)), dim=0)
new_modality_indicators_align.append(cur_new_embed)
new_modality_indicators = torch.stack(new_modality_indicators_align, dim=0)
# Label
if labels is not None:
new_labels_align = []
_new_labels = new_labels
for cur_new_label in new_labels:
cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
new_labels_align.append(cur_new_label)
new_labels = torch.stack(new_labels_align, dim=0)
# Attention Mask
if attention_mask is not None:
new_attention_mask = []
for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
new_attention_mask.append(cur_new_attention_mask)
attention_mask = torch.stack(new_attention_mask, dim=0)
assert attention_mask.shape == new_labels.shape
else:
new_input_embeds = torch.stack(new_input_embeds, dim=0)
new_modality_indicators = torch.stack(new_modality_indicators, dim=0)
if labels is not None:
new_labels = torch.stack(new_labels, dim=0)
if attention_mask is not None:
new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
assert attention_mask.shape == new_input_embeds.shape[:2]
return None, new_modality_indicators, attention_mask, past_key_values, new_input_embeds, new_labels
class MPLUGDocOwlLlamaModel(MPLUGDocOwlMetaModel, LlamaModel):
config_class = MPLUGDocOwlConfig
def __init__(self, config: MPLUGDocOwlConfig):
super(MPLUGDocOwlLlamaModel, self).__init__(config)
class MPLUGDocOwl2(LlamaForCausalLM, MPLUGDocOwlMetaForCausalLM):
config_class = MPLUGDocOwlConfig
def __init__(self, config):
super(LlamaForCausalLM, self).__init__(config)
self.model = MPLUGDocOwlLlamaModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def init_processor(self, tokenizer, basic_image_size, crop_anchors):
self.processor = DocProcessor(tokenizer=tokenizer, image_size=basic_image_size, anchors=crop_anchors)
return self.processor
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
# modality_indicators: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
patch_positions: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
# print('modeling_mplug_docow2.py patch_positions:', patch_positions)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
input_ids, modality_indicators, attention_mask, past_key_values, inputs_embeds, labels = \
self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images, patch_positions)
# ic(inputs_embeds.shape, labels.shape)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
modality_indicators=modality_indicators,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
# ic(outputs[0].shape)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model/pipeline parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
# ic(loss.shape)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values:
input_ids = input_ids[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"images": kwargs.get("images", None),
"patch_positions": kwargs.get("patch_positions", None),
}
)
return model_inputs
def chat(self, messages, images, tokenizer):
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
image_tensor, patch_positions, input_ids = self.processor(images=images, messages=messages)
image_tensor = image_tensor.to(self.model.device, dtype=torch.float16)
patch_positions = patch_positions.to(self.model.device)
input_ids = input_ids.unsqueeze(0).to(self.model.device)
stopping_criteria = KeywordsStoppingCriteria(["</s>"], tokenizer, input_ids)
with torch.inference_mode():
output_ids = self.generate(
input_ids,
images=image_tensor,
patch_positions=patch_positions,
do_sample=False,
temperature=1.0,
max_new_tokens=512,
streamer=streamer,
use_cache=True,
stopping_criteria=[stopping_criteria])
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
return outputs.replace('</s>', '')
AutoConfig.register("mplug_docowl", MPLUGDocOwlConfig)
AutoModelForCausalLM.register(MPLUGDocOwlConfig, MPLUGDocOwl2)
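Because of the two `register` calls above, the checkpoint can also be loaded through the plain `transformers` auto classes instead of the ModelScope wrappers used in the README. A hedged sketch (paths are placeholders; it mirrors the Quickstart rather than adding new behaviour):
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

path = 'path/to/DocOwl2'  # local checkpoint directory
tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    path, trust_remote_code=True, torch_dtype=torch.float16, device_map='auto')
model.init_processor(tokenizer=tokenizer, basic_image_size=504, crop_anchors='grid_12')

answer = model.chat(
    messages=[{'role': 'USER', 'content': '<|image|>What is this page about?'}],
    images=['path/to/page.png'],
    tokenizer=tokenizer)
print(answer)
```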

preprocessor_config.json Normal file

@ -0,0 +1,20 @@
{
"crop_size": 448,
"do_center_crop": true,
"do_normalize": true,
"do_resize": true,
"feature_extractor_type": "CLIPFeatureExtractor",
"image_mean": [
0.48145466,
0.4578275,
0.40821073
],
"image_std": [
0.26862954,
0.26130258,
0.27577711
],
"resample": 3,
"size": 448
}

processor.py Normal file

@ -0,0 +1,226 @@
from einops import rearrange, repeat
import torch
from torchvision import transforms
from PIL import Image, ImageFile
import random
from torchvision.ops.boxes import box_area
from torchvision.transforms.transforms import InterpolationMode
from torchvision.transforms import functional as F
import numpy as np
from icecream import ic
import re
ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None
from .constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
def box_iou(boxes1, area1, boxes2, eps=1e-5):
area2 = box_area(boxes2)
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
union = area1[:, None] + area2 - inter
iou = inter / (union+eps)
return iou, union
def anchor_rank(anchors, anchors_areas, input_image_size, eps=1e-5):
# anchors x1 y1 x2 y2
# image_size: (h, w)
# xyxy
input_image_bbox = torch.tensor([0, 0, input_image_size[1], input_image_size[0]]).unsqueeze(0)
boxes1 = anchors
boxes2 = input_image_bbox
boxes3 = anchors.clone()
# y2
boxes3[:,3] = input_image_size[0]/input_image_size[1]*anchors[:,2] # scale y2 to compute a resolution-independent IoU
area1 = anchors_areas
iou, _ = box_iou(boxes1, area1, boxes2)
iou = iou.squeeze(1)
shape_iou, _ = box_iou(boxes1, area1, boxes3)
shape_iou = shape_iou.diag()
# prefer anchors with a similar aspect ratio, then a similar resolution
index = torch.argmax(shape_iou*100+iou,dim=0)
return index
class AnchorResize(torch.nn.Module):
def __init__(self, image_size, anchors, interpolation=InterpolationMode.BILINEAR, antialias=None):
super().__init__()
self.image_size = image_size  # stored so __repr__ can report the configured base size
# xyxy
self.anchors = torch.tensor(
[[0, 0, _[1]*image_size[1], _[0]*image_size[0]]
for _ in anchors], requires_grad=False
)
self.anchor_areas = box_area(self.anchors)
self.interpolation = interpolation
self.antialias = antialias
def forward(self, img, skip_resize=False):
"""
Args:
img (PIL Image or Tensor): Image to be scaled.
Returns:
PIL Image or Tensor: Rescaled image.
"""
selected_anchor = anchor_rank(self.anchors, self.anchor_areas, (img.size[1], img.size[0]))
target_size = self.anchors[selected_anchor][2:].tolist() # w,h
if skip_resize:
# for debug
return selected_anchor
return F.resize(img, [target_size[1],target_size[0]], self.interpolation, max_size=None, antialias=self.antialias), selected_anchor
def __repr__(self) -> str:
detail = f"(size={self.image_size}, anchor={self.anchors}, interpolation={self.interpolation.value}, antialias={self.antialias})"
return f"{self.__class__.__name__}{detail}"
class DocProcessor():
def __init__(self, tokenizer=None, image_size=504, anchors='grid_12'):
self.media_token= "<|image|>"
# h,w
if isinstance(image_size, int):
image_size = (image_size, image_size)
self.image_size = image_size
# h,w
# anchors = grid_dict[anchors]
max_crop = int(anchors.split('_')[1])
anchors = [(j, int(i/j)) for i in range(1,max_crop+1) for j in range(1, i+1) if i%j==0]
self.anchors = [tuple(_) for _ in anchors]
self.anchor_max = max([max(_) for _ in self.anchors])
# xywh -> xyxy
self.resizer = AnchorResize(image_size=image_size, anchors=anchors, interpolation=InterpolationMode.BICUBIC)
self.old_resizer = transforms.Resize(image_size,interpolation=InterpolationMode.BICUBIC)
self.image_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
self.tokenizer = tokenizer
def _process_image(self, images):
new_images = []
new_patch_position = []
num_image_mult = []
for image in images:
nocut_image = self.image_transform(self.old_resizer(image)).unsqueeze(0)
image, selected_anchor = self.resizer(image)
image_input = self.image_transform(image) # h,w,3 -> 3,h,w
# rearrange(x,'B C (n1 h) (n2 w) -> (B n1 n2) C h w', n1=self.down_sample[0], n2=self.down_sample[1])
image_input = rearrange(image_input, 'C (num_h h) (num_w w) -> (num_h num_w) C h w', h=self.image_size[0], w=self.image_size[1])
image_input = torch.cat([nocut_image, image_input], dim=0)
anchor = self.anchors[selected_anchor] # w,h
patch_position = torch.cat([
repeat(torch.arange(anchor[0]), 'num_h -> num_h num_w 1', num_w=anchor[1]),
repeat(torch.arange(anchor[1]), 'num_w -> num_h num_w 1', num_h=anchor[0])],dim=2)
patch_position = rearrange(patch_position, 'num_h num_w p-> (num_h num_w) p', p=2) # num_patch, (ph,pw)
patch_position = torch.cat([torch.ones(1,2).long()*self.anchor_max, patch_position], dim=0)
new_images.append(image_input)
new_patch_position.append(patch_position)
num_image_mult.append(patch_position.shape[0])
new_images = torch.cat(new_images,dim=0)
new_patch_position = torch.cat(new_patch_position, dim=0)
return new_images, new_patch_position, num_image_mult
def __call__(self, images=None, messages=None):
assert images is not None
# print(images)
## 1. process images
if not isinstance(images, list):
images = [images]
image_pils = []
for image in images:
if isinstance(image, str):
image = Image.open(image).convert('RGB')
else:
image = image.convert('RGB')
# ic(image.size)
image_pils.append(image)
image_data, patch_position, num_image_mult = self._process_image(image_pils)
## 2. process text
# 2.1 add image ordinal token (e.g. <img 1>) before image placeholder <|image|>
image_index = 1 # start from 1
for m in messages:
try:
assert m['role'] in ['USER', 'ASSISTANT']
except Exception as e:
print("Unexpected role: "+m['role']+", only support 'USER' or 'ASSISTANT'")
exit(0)
if m['role'] == 'USER' and self.media_token in m.get('content', ''):
pattern = '|'.join(map(re.escape, [self.media_token]))
text_list = re.split(f'({pattern})', m['content'])
text = ''
for x in text_list:
if x == '<|image|>':
text += '<img '+str(image_index)+'><|image|>'
image_index += 1
else:
text += x
m['content'] = text
if messages[-1]['role'] == 'USER':
messages.append({'role':'ASSISTANT'})
else:
try:
assert messages[-1].get('content', '') == ''
except Exception as e:
print("Unexpected end message: "+str(messages[-1]), "only (role=='USER') or (role=='ASSISTANT' and content=='') are expected.")
exit(0)
# print('after adding img ordinal token: ', messages)
# 2.2 text tokenize
seps = [' ', '</s>']
prompt = ""
for i, m in enumerate(messages):
if 'content' in m:
prompt += m['role'] + ": " + m['content'] + seps[i % 2]
else:
prompt += m['role'] + ":"
ic(prompt)
assert self.media_token in prompt
input_ids = self.tokenizer_token(prompt)
return image_data, patch_position, input_ids
def tokenizer_token(self, prompt):
prompt_chunks = [self.tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
def insert_separator(X, sep):
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
input_ids = []
offset = 0
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == self.tokenizer.bos_token_id:
offset = 1
input_ids.append(prompt_chunks[0][0])
for x in insert_separator(prompt_chunks, [IMAGE_TOKEN_INDEX] * (offset + 1)):
input_ids.extend(x[offset:])
return torch.tensor(input_ids, dtype=torch.long)
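For orientation, a hedged sketch of calling `DocProcessor` directly, assuming `model` and `tokenizer` have already been loaded as in the README Quickstart and the image path is a placeholder. It returns the stacked crop tensor, the crop grid positions, and the spliced `input_ids` that `chat()` feeds into `generate`:
```python
# init_processor returns the DocProcessor instance defined in this file.
doc_processor = model.init_processor(tokenizer=tokenizer, basic_image_size=504, crop_anchors='grid_12')

messages = [{'role': 'USER', 'content': '<|image|>what is this page about?'}]
image_tensor, patch_positions, input_ids = doc_processor(
    images=['path/to/page.png'], messages=messages)

# image_tensor:    (1 global view + N crops, 3, 504, 504)
# patch_positions: (1 + N, 2) row/col grid index per crop; the global view uses (anchor_max, anchor_max)
# input_ids:       1-D LongTensor containing -200 (IMAGE_TOKEN_INDEX) where visual features are spliced in
print(image_tensor.shape, patch_positions.shape, input_ids.shape)
```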

special_tokens_map.json Normal file

@ -0,0 +1,24 @@
{
"bos_token": {
"content": "<s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": "<unk>",
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

BIN
tokenizer.model (Stored with Git LFS) Normal file

Binary file not shown.

tokenizer_config.json Normal file

@ -0,0 +1,35 @@
{
"add_bos_token": true,
"add_eos_token": false,
"bos_token": {
"__type": "AddedToken",
"content": "<s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"clean_up_tokenization_spaces": false,
"eos_token": {
"__type": "AddedToken",
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"legacy": false,
"model_max_length": 4096,
"pad_token": null,
"padding_side": "right",
"sp_model_kwargs": {},
"tokenizer_class": "LlamaTokenizer",
"unk_token": {
"__type": "AddedToken",
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

visual_compressor.py Normal file

@ -0,0 +1,426 @@
import math
from typing import Any, Optional, Tuple, Union
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, BaseModelOutputWithPastAndCrossAttentions
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
from icecream import ic
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
from einops import rearrange
class MplugDocOwlVisualMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
in_features = config.high_reso_cross_hid_size
self.act = nn.SiLU()
ffn_hidden_size = int(2 * 4 * in_features / 3)
multiple_of = 256
ffn_hidden_size = multiple_of * ((ffn_hidden_size + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(in_features, ffn_hidden_size)
self.w2 = nn.Linear(ffn_hidden_size, in_features)
self.w3 = nn.Linear(in_features, ffn_hidden_size)
self.ffn_ln = nn.LayerNorm(ffn_hidden_size, eps=config.layer_norm_eps)
torch.nn.init.zeros_(self.w1.bias.data)
torch.nn.init.zeros_(self.w2.bias.data)
torch.nn.init.zeros_(self.w3.bias.data)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.act(self.w1(hidden_states)) * self.w3(hidden_states)
hidden_states = self.ffn_ln(hidden_states)
hidden_states = self.w2(hidden_states)
return hidden_states
class FlashCrossAttention(torch.nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
device=None, dtype=None):
super().__init__()
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
def forward(self, q, k, v, **kwargs):
"""Implements the multihead softmax attention.
Arguments
---------
q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
or
q: (Sum_q, H, D), k,v : (Sum_k, H, D),
must with batch_size, max_seqlen_q, max_seqlen_k, cu_seqlens_q, cu_seqlens_k in kwargs
"""
assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v)))
assert all((i.is_cuda for i in (q,k,v)))
if q.dim() == 4:
batch_size, seqlen_q = q.shape[0], q.shape[1]
q = rearrange(q, 'b s ... -> (b s) ...')
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
device=q.device)
else:
batch_size, seqlen_q = kwargs['batch_size'], kwargs['max_seqlen_q']
cu_seqlens_q = kwargs['cu_seqlens_q']
if k.dim() == 4:
seqlen_k = k.shape[1]
k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [k, v]]
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
device=q.device)
else:
seqlen_k = kwargs['max_seqlen_k']
cu_seqlens_k = kwargs['cu_seqlens_k']
# q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
# self.dropout_p = 0
"""print('FlashCrossAttention: q.shape:', q.shape)
print('FlashCrossAttention: k.shape:', k.shape)
print('FlashCrossAttention: v.shape:', v.shape)
print('FlashCrossAttention: cu_seqlens_q:', cu_seqlens_q)
print('FlashCrossAttention: cu_seqlens_k:', cu_seqlens_k)"""
# print('visual_compressor.py q.shape:', q.shape, ' k.shape:', k.shape, ' v.shape:', v.shape)
output = flash_attn_unpadded_func(
q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=False
)
if q.dim() == 4: # keep the shape of output shape same as the input query
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
return output
class MplugDocOwlVisualMultiHeadAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
if config.high_reso_cross_hid_size % config.high_reso_cross_num_att_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention heads (%d)"
% (config.high_reso_cross_hid_size, config.high_reso_cross_num_att_heads)
)
if config.high_reso_cross_hid_size // config.high_reso_cross_num_att_heads > 256:
raise ValueError(
"The hidden size of each head (%d) > 256 and is illegal for flash attention"
% (config.high_reso_cross_hid_size // config.high_reso_cross_num_att_heads)
)
self.num_attention_heads = config.high_reso_cross_num_att_heads
self.attention_head_size = int(config.high_reso_cross_hid_size / config.high_reso_cross_num_att_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.high_reso_cross_hid_size, self.all_head_size)
self.key = nn.Linear(config.high_reso_cross_hid_size, self.all_head_size)
self.value = nn.Linear(config.high_reso_cross_hid_size, self.all_head_size)
self.core_attention_flash = FlashCrossAttention(attention_dropout=config.high_reso_cross_dropout)
# bias init
torch.nn.init.zeros_(self.query.bias.data)
torch.nn.init.zeros_(self.key.bias.data)
torch.nn.init.zeros_(self.value.bias.data)
def transpose_for_scores(self, x):
# [B, S, D] -> [B, S, H, D] or [Sum_S, D] -> [Sum_S, H, D]
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x
def forward(
self,
hidden_states,
encoder_hidden_states=None,
**kwargs
):
# assert not torch.isnan(hidden_states).any()
# assert not torch.isnan(encoder_hidden_states).any()
key = self.transpose_for_scores(self.key(encoder_hidden_states))
value = self.transpose_for_scores(self.value(encoder_hidden_states))
query = self.transpose_for_scores(self.query(hidden_states))
# print('visual_compressor.py key(after projection): ', key.shape, key)
# print('visual_compressor.py value(after projection): ', value.shape, value)
# print('visual_compressor.py query(after projection): ', query.shape, query)
# assert not torch.isnan(key).any()
# assert not torch.isnan(value).any()
# assert not torch.isnan(query).any()
outputs = self.core_attention_flash(q=query, k=key, v=value, **kwargs)
outputs = rearrange(outputs, 's h d -> s (h d)').contiguous()
# print('visual_compressor.py outputs(after cross_att): ', outputs.shape, outputs)
return outputs
class MplugDocOwlVisualCrossOutput(nn.Module):
def __init__(self, config):
super().__init__()
dim = config.high_reso_cross_hid_size
self.out_proj = nn.Linear(dim, dim, bias=True)
self.norm2 = nn.LayerNorm(dim)
self.mlp = MplugDocOwlVisualMLP(config)
# bias init
torch.nn.init.zeros_(self.out_proj.bias.data)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
input_tensor = input_tensor + self.out_proj(hidden_states)
input_tensor = input_tensor + self.mlp(self.norm2(input_tensor))
return input_tensor
class MplugDocOwlVisualCrossAttentionLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.attention = MplugDocOwlVisualMultiHeadAttention(config)
self.output = MplugDocOwlVisualCrossOutput(config)
self.norm1 = nn.LayerNorm(config.high_reso_cross_hid_size)
self.normk = nn.LayerNorm(config.high_reso_cross_hid_size)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
**kwargs
) -> Tuple[torch.Tensor]:
# print('visual_compressor.py hidden_states: ', hidden_states.shape, hidden_states)
# print('visual_compressor.py encoder_hidden_states: ', encoder_hidden_states.shape, encoder_hidden_states)
# assert not torch.isnan(hidden_states).any()
# assert not torch.isnan(encoder_hidden_states).any()
hidden_states = self.norm1(hidden_states)
encoder_hidden_states = self.normk(encoder_hidden_states)
# print('visual_compressor.py hidden_states(after norm): ', hidden_states.shape, hidden_states)
# print('visual_compressor.py encoder_hidden_states(after norm): ', encoder_hidden_states.shape, encoder_hidden_states)
attention_output = self.attention(
hidden_states,
encoder_hidden_states,
**kwargs
)
outputs = self.output(attention_output, hidden_states)
return outputs
class MplugDocOwlVisualCrossAttentionEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer_num = config.layer
self.layers = nn.ModuleList(
[MplugDocOwlVisualCrossAttentionLayer(config) for layer_idx in range(self.layer_num)]
)
self.gradient_checkpointing = True
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
**kwargs
):
for i in range(self.layer_num):
layer_module = self.layers[i]
layer_outputs = layer_module(
hidden_states,
encoder_hidden_states,
**kwargs
)
hidden_states = layer_outputs
return hidden_states
def ensemble_crop_feats(crop_feats, patch_positions, col_feat_num):
"""
ensemble vision feats from different crops to a feature map according the position of the raw image
crop_feats: [N_crop, Len_feat, D]
patch_positions: [N_crop, 2], 2 == (row_index, col_index)
col_feat_num: the feature num of a row in a crop image
"""
assert crop_feats.size(0) == patch_positions.size(0)
row_feats = []
crop_row = torch.max(patch_positions[:,0])+1 #
crop_feats = rearrange(crop_feats, '(R C) L D -> R C L D', R=crop_row) # [N_crop_row, N_crop_col, Len_feat, D]
crop_feats = rearrange(crop_feats, 'R C (X Y) D-> R C X Y D', Y=col_feat_num) # [N_crop_row, N_crop_col, Len_row_feat, Len_col_feat, D]
# 1. concatenate same row feats across crops; 2. ensemble row feats to get 1 feature map
hw_feats = rearrange(crop_feats, 'R C X Y D-> (R X) (C Y) D') # [N_crop_row x Len_row_feat, N_crop_col x Len_col_feat, D]
return hw_feats
def group_window_feats(feats, window):
"""
collect vision feats from a window (win_row, win_col) to 1 group
feats: [H, W, D]
window: (win_row, win_col)
return: [H/win_row, H/win_col, win_row x win_col, D]
"""
group_feats = rearrange(feats, '(X R) (Y C) D -> (X Y) (R C) D', R=window[0], C=window[1]) # [H/win_row x H/win_col, win_row x win_col, D]
return group_feats
def distinguish_global_crop_features(hidden_states, patch_positions, reorganize_crop_feats=True, col_feat_num=None, group_feats_by_crop_shape=False, keep_row_col=False):
"""
distinguish global and crop features with the help of patch_positions
# hidden_states: [B, s+1, h]
# (B is the sum of cropped num across samples in a micro_batch, s is the visual tokens, +1 means the vit end token)
# patch_positions: [B, 2],
# 2 == (row_index, col_index), the first crop is (0,0), global img is (anchor_max, anchor_max)
col_feat_num is used when reorganize_crop_feats == True
outputs:
img_global_features: list of [Len_global_feat, D]
img_crop_features: list of crop features per image (shape depends on reorganize_crop_feats / group_feats_by_crop_shape)
"""
hidden_states = hidden_states[:, :-1, :] # remove the last vit end token emb
# the first crop is (0,0)
first_crop_indices = (patch_positions.sum(dim=-1) == 0).nonzero().squeeze(1) # Num_img
# the global image is before the first crop
global_indices = first_crop_indices - 1 # Num_img
# print('vision2text_model.py patch_positions:', patch_positions)
# print('vision2text_model.py global_indices:', global_indices)
# collect cropped vision features of an identical image
batch_size = hidden_states.size(0)
img_global_features = []
img_crop_features = [] # store list of Num_crop (variable) x Len_feat (fixed)
img_crop_positions = [] # store list of Num_crop (variable) x 2
for i in range(len(global_indices)):
index = global_indices[i]
img_global_features.append(hidden_states[index])
if i == (len(global_indices)-1):
img_crop_features.append(hidden_states[index+1:])
img_crop_positions.append(patch_positions[index+1:])
else:
next_index = global_indices[i+1]
img_crop_features.append(hidden_states[index+1:next_index])
img_crop_positions.append(patch_positions[index+1:next_index])
if reorganize_crop_feats:
for i in range(len(img_crop_features)):
img_crop_features[i] = ensemble_crop_feats(img_crop_features[i], img_crop_positions[i], col_feat_num) # [H W D]
if group_feats_by_crop_shape: # collect vision feats from a window (crop_row, crop_col) to 1 group
crop_row = torch.max(img_crop_positions[i][:,0])+1 #
crop_col = torch.max(img_crop_positions[i][:,1])+1 #
img_crop_features[i] = group_window_feats(img_crop_features[i], window=(crop_row, crop_col)) # [H/crop_row x W/crop_col, crop_row x crop_row, D]
else:
# img_crop_features = [rearrange(x, 'H W D -> (H W) D') for x in img_crop_features]
if not keep_row_col:
img_crop_features[i] = rearrange(img_crop_features[i], 'H W D -> (H W) D')
else:
img_crop_features = [rearrange(x, 'N L D -> (N L) D') for x in img_crop_features]
return img_global_features, img_crop_features
class MplugDocOwlHRDocCompressor(PreTrainedModel):
"""
After the vision-to-text module, use low-resolution global features to select high-resolution crop features with cross-attention;
the keys/values from the high-resolution crop features are constrained to a window whose
positions in the raw image are the same as those of the global query features
"""
def __init__(self, config, output_hidden_size, v2t_img_col_tokens):
super().__init__(config)
self.use_flash_attn = True
assert self.use_flash_attn
self.v2t_img_col_tokens = v2t_img_col_tokens
self.compressor_crossatt = MplugDocOwlVisualCrossAttentionEncoder(config)
self.compressor_fc = torch.nn.Linear(output_hidden_size, output_hidden_size)
self.compressor_eos = torch.nn.Parameter(torch.randn(1, 1, output_hidden_size))
def forward(self, hidden_states, patch_positions=None):
# hidden_states: outputs of the vision2text module: [Sum(crop), s+1, h]
# (Sum(crop) is the sum of cropped num across samples in a micro_batch, s is the visual tokens, +1 is the special vit_eos token added in H-Reducer)
# patch_positions: [Sum(crop), 2]
# print('visual_compressor.py HRDocCompressor hidden_states.shape:', hidden_states.shape)
# print('visual_compressor.py HRDocCompressor patch_positions.shape:', patch_positions.shape)
# N_img x [L_global (fixed), D], N_img x [L_global (fixed), Crop_row x Crop_Col (Variable), D]
img_global_features, img_crop_features = distinguish_global_crop_features(hidden_states,
patch_positions,
reorganize_crop_feats=True,
col_feat_num=self.v2t_img_col_tokens,
group_feats_by_crop_shape=True)
# cross-attention to accumulate high-resolution features
# if self.use_flash_attn: # flash_attn_varlen_func don't need to pad crop_features
img_global_features = torch.stack(img_global_features, dim=0).to(hidden_states.device) # Num_img x Len_global_feat x D
batch_size, global_feat_num, seqlen_q = img_global_features.shape[0], img_global_features.shape[1], 1
img_global_features = rearrange(img_global_features, 'b s ... -> (b s) ...')
cu_seqlens_q = torch.arange(0, batch_size*global_feat_num+1, step=1, dtype=torch.int32, device=img_global_features.device) # # (Num_img x Len_global_feat +1, )
cu_seqlens_k = [0]
max_seqlens_k = 0
for crop_feat in img_crop_features:
for i in range(crop_feat.shape[0]):
cu_seqlens_k.append(cu_seqlens_k[-1]+crop_feat.shape[1]) # keys within an image share the same seq len
max_seqlens_k = max(max_seqlens_k, crop_feat.size(1))
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32).to(hidden_states.device) # (Num_img x Len_global_feat+1, )
# cu_seqlens_k = torch.arange(0, (batch_size + 1) * max_seqlens_k, step=max_seqlens_k, dtype=torch.int32, device=img_global_features.device) # # (Num_img+1, )
img_crop_features = torch.cat([rearrange(x, 'N L D -> (N L) D') for x in img_crop_features], dim=0).to(hidden_states.device) # Sum(L_hr) x D
flash_kwargs = {
'batch_size': batch_size*global_feat_num, # each feat in global feats use different keys
'max_seqlen_q': seqlen_q, # key are unique for each query
'max_seqlen_k': max_seqlens_k,
'cu_seqlens_q': cu_seqlens_q, # the seq len of each q
'cu_seqlens_k': cu_seqlens_k # the seq len of each k
}
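# Note on the varlen layout above (descriptive, added for clarity):
# cu_seqlens_q steps by 1, so every global token is treated as its own length-1 query sequence;
# cu_seqlens_k advances by one crop window (crop_row*crop_col features) per global token,
# so query i only attends to the high-resolution features covering its own spatial window.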
# print('visual_compressor.py HRDocCompressor img_global_features.shape:', img_global_features.shape, img_global_features)
# print('visual_compressor.py HRDocCompressor img_crop_features.shape:', img_crop_features.shape, img_crop_features)
"""print('visual_compressor.py HRDocCompressor cu_seqlens_q, cu_seqlens_q.shape:', cu_seqlens_q, cu_seqlens_q.shape)
print('visual_compressor.py HRDocCompressor cu_seqlens_k, cu_seqlens_k.shape:', cu_seqlens_k, cu_seqlens_k.shape)"""
# assert not torch.isnan(img_global_features).any()
# assert not torch.isnan(img_crop_features).any()
for x_name, x in self.compressor_crossatt.named_parameters():
try:
assert not torch.isnan(x).any()
# print('visual_compressor.py ', x_name, x.shape, x)
except Exception as e:
print(e)
print('visual_compressor.py nan', x_name, x.shape, x)
hidden_states = self.compressor_crossatt(
img_global_features.contiguous(), # Sum(L_global) x D
img_crop_features.contiguous(), # Sum(L_hr) x D
**flash_kwargs
) # Sum(L_global) x D
hidden_states = rearrange(hidden_states, '(B S) D -> S B D', B=batch_size) # L_global x N_img x D
hidden_states = self.compressor_fc(hidden_states) # L_global x N_img x D
hidden_states = hidden_states.transpose(0, 1).contiguous() # N_img x L_global x D
# print('visual_compressor.py hidden_states:', hidden_states.shape)
hidden_states = torch.cat([hidden_states, self.compressor_eos.repeat(hidden_states.shape[0], 1, 1)], dim=1) # N_img x (L_global+1) x D
# print('visual_compressor.py HRDocCompressor hidden_states.shape:', hidden_states.shape)
return hidden_states
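# Output shape note for MplugDocOwlHRDocCompressor.forward (illustrative, for the released
# 504px / patch-14 / conv_shape '1x4' config): each image contributes L_global = 36*9 = 324
# compressed tokens plus one learned eos embedding, i.e. hidden_states is [N_img, 325, D].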

501
visual_encoder.py Normal file

@ -0,0 +1,501 @@
import math
from typing import Any, Optional, Tuple, Union
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, BaseModelOutputWithPastAndCrossAttentions
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
import torch.nn.functional as F
from icecream import ic
import einops
from einops import rearrange
def get_abs_pos(abs_pos, tgt_size):
# abs_pos: L, C
# tgt_size: M
# return: M, C
src_size = int(math.sqrt(abs_pos.size(0)))
tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
if src_size != tgt_size:
return F.interpolate(
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size, tgt_size),
mode="bicubic",
align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
else:
return abs_pos
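# Example (illustrative): a pretrained abs_pos of length 1296 (a 36x36 grid) requested at
# tgt_size=576 is bicubically resized to a 24x24 grid and returned as a (576, C) tensor;
# if the grid sizes already match, abs_pos is returned unchanged.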
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
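# Example (illustrative): get_2d_sincos_pos_embed(embed_dim=1024, grid_size=36) returns a
# (1296, 1024) array (512 dims encode the row index, 512 the column index);
# with cls_token=True a zero row is prepended, giving (1297, 1024).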
class MplugOwlVisionEmbeddings(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))
self.patch_embed = nn.Conv2d(
in_channels=3,
out_channels=self.hidden_size,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.position_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, self.hidden_size))
self.pre_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.size(0)
image_embeds = self.patch_embed(pixel_values)
image_embeds = image_embeds.flatten(2).transpose(1, 2)
class_embeds = self.cls_token.expand(batch_size, 1, -1).to(image_embeds.dtype)
embeddings = torch.cat([class_embeds, image_embeds], dim=1)
embeddings = embeddings + self.position_embedding[:, : embeddings.size(1)].to(image_embeds.dtype)
embeddings = self.pre_layernorm(embeddings)
return embeddings
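# Shape note for MplugOwlVisionEmbeddings (assuming the released 504px / patch-14 config):
# pixel_values [B, 3, 504, 504] -> patch_embed [B, 1024, 36, 36] -> [B, 1296, 1024];
# with the cls token prepended and position embeddings added, embeddings are [B, 1297, 1024].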
class MplugOwlVisionAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
if self.head_dim * self.num_heads != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = nn.Dropout(config.attention_dropout)
self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size)
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
bsz, seq_len, embed_dim = hidden_states.size()
mixed_qkv = self.query_key_value(hidden_states)
mixed_qkv = mixed_qkv.reshape(bsz, seq_len, self.num_heads, 3, embed_dim // self.num_heads).permute(
3, 0, 2, 1, 4
) # [3, b, np, sq, hn]
query_states, key_states, value_states = (
mixed_qkv[0],
mixed_qkv[1],
mixed_qkv[2],
)
# if self.config.use_flash_attn and flash_attn_func is not None:
if False:
# [b*sq, np, hn]
query_states = query_states.permute(0, 2, 1, 3).contiguous()
query_states = query_states.view(query_states.size(0) * query_states.size(1), query_states.size(2), -1)
key_states = key_states.permute(0, 2, 1, 3).contiguous()
key_states = key_states.view(key_states.size(0) * key_states.size(1), key_states.size(2), -1)
value_states = value_states.permute(0, 2, 1, 3).contiguous()
value_states = value_states.view(value_states.size(0) * value_states.size(1), value_states.size(2), -1)
cu_seqlens = torch.arange(
0, (bsz + 1) * seq_len, step=seq_len, dtype=torch.int32, device=query_states.device
)
context_layer = flash_attn_func(
query_states,
key_states,
value_states,
cu_seqlens,
cu_seqlens,
seq_len,
seq_len,
self.dropout if self.training else 0.0,
softmax_scale=self.scale,
causal=False,
return_attn_probs=False,
)
# [b*sq, np, hn] => [b, sq, np, hn]
context_layer = context_layer.view(bsz, seq_len, context_layer.size(1), context_layer.size(2))
else:
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
attention_scores = attention_scores * self.scale
# Normalize the attention scores to probabilities.
attention_probs = torch.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer)
outputs = (output, attention_probs) if output_attentions else (output, None)
return outputs
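# Shape note for MplugOwlVisionAttention (descriptive): the fused qkv projection is reshaped to
# [3, B, num_heads, seq_len, head_dim]; after attention the context is permuted back and reshaped
# to [B, seq_len, hidden_size] before the output projection.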
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class MplugOwlMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = QuickGELU()
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class MplugOwlVisionEncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = MplugOwlVisionAttention(config)
self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
self.mlp = MplugOwlMLP(config)
self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
head_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = hidden_states + residual
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
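# Structure note (descriptive): MplugOwlVisionEncoderLayer is a pre-norm block:
# LayerNorm -> self-attention -> residual add, then LayerNorm -> MLP -> residual add.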
class MplugOwlVisionEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`MplugOwlVisionEncoderLayer`].
Args:
config (`MplugOwlVisionConfig`):
The corresponding vision configuration for the `MplugOwlEncoder`.
"""
def __init__(self, config):
super().__init__()
self.config = config
self.layers = nn.ModuleList([MplugOwlVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = True
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Embedded representation of the inputs. Should be float, not int tokens.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states,
attention_mask,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
class MplugOwlVisionModel(PreTrainedModel):
main_input_name = "pixel_values"
def __init__(self, config):
super().__init__(config)
self.config = config
self.hidden_size = config.hidden_size
self.embeddings = MplugOwlVisionEmbeddings(config)
self.encoder = MplugOwlVisionEncoder(config)
self.post_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
self.post_init()
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.embeddings(pixel_values)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.post_layernorm(last_hidden_state)
pooled_output = last_hidden_state[:, 0, :]
pooled_output = self.post_layernorm(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def get_input_embeddings(self):
return self.embeddings
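# Output shape note for MplugOwlVisionModel (assuming 504px inputs): last_hidden_state is
# [B, 1297, 1024] (cls token + 36x36 patches after post_layernorm) and pooler_output is the
# [B, 1024] vector taken from the cls position.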
class MplugDocOwlHReducerModel(PreTrainedModel):
def __init__(self, config, language_hidden_size):
super().__init__(config)
self.config = config
self.ln_q = torch.nn.LayerNorm(self.config.hidden_size, eps=1e-6)
self.conv_shape = (int(self.config.conv_shape.split('x')[0]), int(self.config.conv_shape.split('x')[1])) #
self.conv_patch=self.conv_shape[0]*self.conv_shape[1]
## feature interaction with a conv layer
self.reducer_before = torch.nn.Sequential(
nn.Conv2d(self.config.hidden_size, self.conv_patch*self.config.hidden_size, kernel_size=self.conv_shape, stride=self.conv_shape, bias=True),
nn.GELU()
)
## reduce visual feature length with a conv layer
self.reducer = nn.Conv2d(self.config.hidden_size, self.config.hidden_size, kernel_size=self.conv_shape, stride=self.conv_shape, bias=True)
## align visual features with language embedding with fc
self.visual_fc = torch.nn.Linear(self.config.hidden_size, language_hidden_size)
self.vit_eos = torch.nn.Parameter(torch.randn(1, 1, language_hidden_size))
self.post_init()
def forward(
self,
encoder_hidden_states=None
):
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
batch_size is the number of all images (global+crop) in a batch
Sequence of hidden-states at the output of the last layer of the encoder.
"""
encoder_hidden_states = encoder_hidden_states[:,1:,:] # remove the first cls token
B, L, C = encoder_hidden_states.shape # B, L=(image_size/patch_size)^2 (1296 for the 504/14 config), C=1024
## feature interaction with a conv layer
encoder_hidden_states = rearrange(encoder_hidden_states, 'B (H W) D -> B D H W', H=int(math.sqrt(L)))
hidden_states = self.reducer_before(encoder_hidden_states) # B 4D H W/4
## reduce seq length with a conv layer
"""hidden_states = hidden_states.flatten(2).transpose(1, 2) # B 4D H W/4 -> B 4D H*W/4 -> B H*W/4 4D
hidden_states = rearrange(hidden_states, 'B L (X D) -> B (L X) D', X=self.conv_patch) # B (H W) D
hidden_states = rearrange(hidden_states, 'B (H W) D -> B D H W', H=int(math.sqrt(L))) # B D H W """
hidden_states = rearrange(hidden_states, 'B (X D) H W -> B D H (W X)', X=self.conv_patch) # B 4D H W/4 -> B D H W
sequence_output = self.reducer(hidden_states) # B,C,H,W -> B,C,H/conv_shape[0],W/conv_shape[1]
sequence_output = sequence_output.flatten(2).transpose(1, 2) # B,C,H/conv_shape[0],W/conv_shape[1] -> B,C,L/conv_patch -> B,L/conv_patch,C
sequence_output = sequence_output.transpose(0, 1).contiguous() # L/conv_patch, B, C
## align visual features with language embedding with fc
sequence_output = self.visual_fc(sequence_output) # L/conv_patch, B, h
sequence_output = sequence_output.transpose(0, 1).contiguous() # B, s/4, h
sequence_output = torch.cat([sequence_output, self.vit_eos.repeat(B, 1, 1)], dim=1)
return sequence_output
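# Shape walkthrough for MplugDocOwlHReducerModel.forward (illustrative, 504px / patch-14 /
# conv_shape '1x4' config): encoder_hidden_states [B, 1297, 1024] -> drop cls -> [B, 1296, 1024]
# -> [B, 1024, 36, 36] -> reducer_before [B, 4096, 36, 9] -> rearrange [B, 1024, 36, 36]
# -> reducer [B, 1024, 36, 9] -> flatten [B, 324, 1024] -> visual_fc [B, 324, language_hidden_size]
# -> concat vit_eos -> [B, 325, language_hidden_size].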