first commit

This commit is contained in:
xxl 2024-12-26 10:53:19 +08:00
parent 14e1a26fda
commit 63eb429d0d
16 changed files with 457539 additions and 2 deletions

132
README.md

@@ -1,3 +1,131 @@
# mPLUG-Owl3-2B-241014_a14065968674238464724418
---
license: apache-2.0
language:
- en
pipeline_tag: visual-question-answering
tags:
- chat
---
# mPLUG-Owl3
## Introduction
mPLUG-Owl3 is a state-of-the-art multi-modal large language model designed to tackle the challenges of long image sequence understanding. We propose Hyper Attention, which boosts the speed of long visual sequence understanding in multimodal large language models by sixfold, allowing for processing of visual sequences that are eight times longer. Meanwhile, we maintain excellent performance on single-image, multi-image, and video tasks.
Github: [mPLUG-Owl](https://github.com/X-PLUG/mPLUG-Owl)
## Quickstart
Load mPLUG-Owl3. Currently, only `attn_implementation` values in `['sdpa', 'flash_attention_2']` are supported.
```Python
import torch
from modelscope import AutoConfig, AutoModel
model_path = 'iic/mPLUG-Owl3-2B-241014'
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
print(config)
# model = mPLUGOwl3Model(config).cuda().half()
model = AutoModel.from_pretrained(model_path, attn_implementation='sdpa', torch_dtype=torch.half, trust_remote_code=True)
model.eval().cuda()
```
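If the `flash-attn` package is installed, the model can also be loaded with the `flash_attention_2` backend; a minimal variant of the snippet above (only the attention implementation changes):
```Python
import torch
from modelscope import AutoModel
model_path = 'iic/mPLUG-Owl3-2B-241014'
# flash_attention_2 requires a CUDA GPU and `pip install flash-attn`
model = AutoModel.from_pretrained(model_path, attn_implementation='flash_attention_2', torch_dtype=torch.half, trust_remote_code=True)
model.eval().cuda()
```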
Chat with images.
```Python
from PIL import Image
from modelscope import AutoTokenizer
from decord import VideoReader, cpu
model_path = 'iic/mPLUG-Owl3-2B-241014'
tokenizer = AutoTokenizer.from_pretrained(model_path)
processor = model.init_processor(tokenizer)
image = Image.new('RGB', (500, 500), color='red')
messages = [
{"role": "user", "content": """<|image|>
Describe this image."""},
{"role": "assistant", "content": ""}
]
inputs = processor(messages, images=[image], videos=None)
inputs.to('cuda')
inputs.update({
'tokenizer': tokenizer,
'max_new_tokens':100,
'decode_text':True,
})
g = model.generate(**inputs)
print(g)
```
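The processor pairs each `<|image|>` placeholder with one entry of `images` in order (see `processing_mplugowl3.py` in this commit), so multi-image chat can be sketched the same way, reusing the `model`, `tokenizer`, and `processor` created above:
```Python
from PIL import Image
# Each <|image|> placeholder consumes one image from the list, in order.
image_a = Image.new('RGB', (500, 500), color='red')
image_b = Image.new('RGB', (500, 500), color='blue')
messages = [
{"role": "user", "content": """<|image|> <|image|>
What is the difference between these two images?"""},
{"role": "assistant", "content": ""}
]
inputs = processor(messages, images=[image_a, image_b], videos=None)
inputs.to('cuda')
inputs.update({
'tokenizer': tokenizer,
'max_new_tokens':100,
'decode_text':True,
})
g = model.generate(**inputs)
print(g)
```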
Chat with a video.
```Python
from PIL import Image
from modelscope import AutoTokenizer
from decord import VideoReader, cpu # pip install decord
model_path = 'iic/mPLUG-Owl3-2B-241014'
tokenizer = AutoTokenizer.from_pretrained(model_path)
processor = model.init_processor(tokenizer)
messages = [
{"role": "user", "content": """<|video|>
Describe this video."""},
{"role": "assistant", "content": ""}
]
videos = ['/nas-mmu-data/examples/car_room.mp4']
MAX_NUM_FRAMES=16
def encode_video(video_path):
def uniform_sample(l, n):
gap = len(l) / n
idxs = [int(i * gap + gap / 2) for i in range(n)]
return [l[i] for i in idxs]
vr = VideoReader(video_path, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1) # FPS
frame_idx = [i for i in range(0, len(vr), sample_fps)]
if len(frame_idx) > MAX_NUM_FRAMES:
frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
frames = vr.get_batch(frame_idx).asnumpy()
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
print('num frames:', len(frames))
return frames
video_frames = [encode_video(_) for _ in videos]
inputs = processor(messages, images=None, videos=video_frames)
inputs.to('cuda')
inputs.update({
'tokenizer': tokenizer,
'max_new_tokens':100,
'decode_text':True,
})
g = model.generate(**inputs)
print(g)
```
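Alternatively, the model exposes a higher-level `chat` helper (defined in `modeling_mplugowl3.py` in this commit) that builds the processor inputs itself and can optionally stream the output; a minimal sketch, reusing the `model` and `tokenizer` from above:
```Python
from PIL import Image
image = Image.new('RGB', (500, 500), color='red')
messages = [
{"role": "user", "content": "<|image|> Describe this image."},
{"role": "assistant", "content": ""}
]
# stream=False returns the decoded answer string; stream=True returns a generator of text chunks.
answer = model.chat(images=[image], videos=None, messages=messages, tokenizer=tokenizer, max_new_tokens=100, stream=False)
print(answer)
```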
## Citation
If you find our work helpful, please consider citing it.
```
@misc{ye2024mplugowl3longimagesequenceunderstanding,
title={mPLUG-Owl3: Towards Long Image-Sequence Understanding in Multi-Modal Large Language Models},
author={Jiabo Ye and Haiyang Xu and Haowei Liu and Anwen Hu and Ming Yan and Qi Qian and Ji Zhang and Fei Huang and Jingren Zhou},
year={2024},
eprint={2408.04840},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2408.04840},
}
```
mPLUG-Owl3-2B-241014

47
config.json Normal file

@@ -0,0 +1,47 @@
{
"architectures": [
"mPLUGOwl3Model"
],
"auto_map": {
"AutoConfig": "configuration_mplugowl3.mPLUGOwl3Config",
"AutoModel": "modeling_mplugowl3.mPLUGOwl3Model",
"AutoModelForCausalLM": "modeling_mplugowl3.mPLUGOwl3Model"
},
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"hidden_act": "silu",
"hidden_size": 1536,
"initializer_range": 0.02,
"intermediate_size": 8960,
"max_position_embeddings": 32768,
"max_window_layers": 28,
"model_type": "mplugowl3",
"num_attention_heads": 12,
"num_hidden_layers": 28,
"num_key_value_heads": 2,
"rms_norm_eps": 1e-06,
"rope_theta": 1000000.0,
"sliding_window": 32768,
"tie_word_embeddings": true,
"torch_dtype": "bfloat16",
"transformers_version": "4.41.2",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 151851,
"hyper_layers": [
7,
15,
23,
26
],
"vision_config": {
"hidden_size": 1152,
"image_size": 384,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 14
}
}

1
configuration.json Normal file

@@ -0,0 +1 @@
{"framework":"Pytorch","task":"image-text-to-text"}

123
configuration_hyper_qwen2.py Normal file

@@ -0,0 +1,123 @@
from transformers.configuration_utils import PretrainedConfig
class HyperQwen2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 151936):
Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Qwen2Model`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 22016):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 32):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
use_sliding_window (`bool`, *optional*, defaults to `False`):
Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
max_window_layers (`int`, *optional*, defaults to 28):
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import Qwen2Model, Qwen2Config
>>> # Initializing a Qwen2 style configuration
>>> configuration = Qwen2Config()
>>> # Initializing a model from the Qwen2-7B style configuration
>>> model = Qwen2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "qwen2"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
intermediate_size=22016,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=28,
attention_dropout=0.0,
hyper_layers=[1,9,17,25],
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window if use_sliding_window else None
self.max_window_layers = max_window_layers
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.hyper_layers = hyper_layers
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
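# Illustrative example (values mirror this repo's config.json): `hyper_layers` presumably
# selects the decoder layers that carry the hyper attention blocks, e.g.
# HyperQwen2Config(hidden_size=1536, num_hidden_layers=28, hyper_layers=[7, 15, 23, 26]).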

47
configuration_mplugowl3.py Normal file

@@ -0,0 +1,47 @@
# coding=utf-8
""" mPLUGOwl3 model configuration"""
import os
from typing import Union
from transformers.utils import logging
from .configuration_hyper_qwen2 import HyperQwen2Config
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
logger = logging.get_logger(__name__)
class mPLUGOwl3Config(HyperQwen2Config):
model_type = "mplugowl3"
keys_to_ignore_at_inference = ["past_key_values"]
default_vision_config = {
"hidden_size": 1152,
"image_size": 384,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 14
}
def __init__(
self,
use_cache=True,
vision_config=None,
**kwargs,
):
self.use_cache = use_cache
# same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit, with tgt_sizes added
if vision_config is None:
self.vision_config = SiglipVisionConfig(**self.default_vision_config)
logger.info("vision_config is None, using default vision config")
elif isinstance(vision_config, dict):
self.vision_config = SiglipVisionConfig(**vision_config)
elif isinstance(vision_config, SiglipVisionConfig):
self.vision_config = vision_config
self.image_size = self.vision_config.image_size
self.patch_size = self.vision_config.patch_size
super().__init__(**kwargs)

14
generation_config.json Normal file

@@ -0,0 +1,14 @@
{
"bos_token_id": 151643,
"pad_token_id": 151643,
"do_sample": true,
"eos_token_id": [
151645,
151643
],
"repetition_penalty": 1.1,
"temperature": 0.7,
"top_p": 0.8,
"top_k": 20,
"transformers_version": "4.37.0"
}

416
image_processing_mplugowl3.py Normal file

@@ -0,0 +1,416 @@
import random
from typing import Optional, Union, Dict, Any, List
from einops import rearrange, repeat
import torch
import math
import PIL.Image
import PIL.ImageSequence
import numpy as np
import PIL
from PIL import Image
from transformers.utils import TensorType, requires_backends, is_torch_dtype, is_torch_device
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers import AutoImageProcessor
from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import (
ImageInput,
make_list_of_images,
valid_images,
is_torch_tensor,
is_batched,
to_numpy_array,
infer_channel_dimension_format,
ChannelDimension
)
from torchvision.ops.boxes import box_area
from torchvision.transforms import functional as F
from torchvision.transforms.transforms import InterpolationMode
from torchvision import transforms
def recursive_converter(converter, value):
if isinstance(value, list):
new_value = []
for v in value:
new_value += [recursive_converter(converter, v)]
return new_value
else:
return converter(value)
def box_iou(boxes1, area1, boxes2, eps=1e-5):
area2 = box_area(boxes2)
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
union = area1[:, None] + area2 - inter
iou = inter / (union+eps)
return iou, union
available_anchor_strategy = ['docowl', 'random', 'highest', 'last', 'llava']
grid_dict = {
'grid_33':[
(1,1),
(1,2),(2,1),
(1,3),(3,1),
(2,2),(1,4),(4,1),
(1,5),(5,1),
(1,6),(6,1),(2,3),(3,2),
(1,7),(7,1),
(4,2),(2,4),(1,8),(8,1),
(3,3),(1,9),(9,1)],
'grid_squ_3x3':[
(1,1),(2,2),(3,3)
],
'grid_squ_4':[
(2,2),(1,3),(1,4),(3,1),(4,1)
],
'grid_squ_6':[
(2,2),(1,3),(1,4),(3,1),(4,1), (2,3),(3,2)
],
'grid_squ_2':[
(2,1)
],
'grid_squ_9':[
(1,1),
(1,2),(2,1),
(1,3),(3,1),
(2,2),(1,4),(4,1),
(1,5),(5,1),
(1,6),(6,1),(2,3),(3,2),
(1,7),(7,1),
(4,2),(2,4),(1,8),(8,1),
(3,3),(1,9),(9,1)],
}
cut_prompt_template_dict = {
'v0': lambda img_token, h, w: f''.join([f"{img_token}" for i in range(h) for j in range(w)]),
'v1': lambda img_token, h, w: f'Cut to {h} rows {w} columns, '+ ' '.join([f"subimg({i},{j}){img_token}"for i in range(h) for j in range(w)]),
'v1_global': lambda img_token, h, w: f'Cut to {h} rows {w} columns with a global view, '+ ' '.join([f"subimg({i},{j}){img_token}"for i in range(h) for j in range(w)]+[f"global_view{img_token}"]),
'v2_global': lambda img_token, h, w: f'Cut to {h} rows {w} columns with a global view\n'+ '\n'.join([' '.join([f"subimg({i},{j}){img_token}" for j in range(w)]) for i in range(h)])+f"\nglobal_view{img_token}",
'v3': lambda img_token, h, w: f'<|start_cut|>{h}*{w}'+ ' '.join([f"{img_token}"for i in range(h) for j in range(w)])+'<|end_cut|>',
'v3_global': lambda img_token, h, w: f'<|start_cut|>{h}*{w}\n'+ '\n'.join([' '.join([f"{img_token}" for j in range(w)]) for i in range(h)])+f'\n{img_token}<|end_cut|>',
}
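# Worked example (illustrative): with img_token='<|image|>', h=2, w=2 the templates render as
# 'v3'        -> '<|start_cut|>2*2<|image|> <|image|> <|image|> <|image|><|end_cut|>'
# 'v2_global' -> 'Cut to 2 rows 2 columns with a global view\nsubimg(0,0)<|image|> subimg(0,1)<|image|>\nsubimg(1,0)<|image|> subimg(1,1)<|image|>\nglobal_view<|image|>'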
def anchor_rank(anchors, anchors_areas, input_image_size, eps=1e-5):
# anchors x1 y1 x2 y2
# image_size: (h, w)
# xyxy
input_image_bbox = torch.tensor([0, 0, input_image_size[1], input_image_size[0]]).unsqueeze(0)
boxes1 = anchors
boxes2 = input_image_bbox
boxes3 = anchors.clone()
# y2
boxes3[:,3] = input_image_size[0]/input_image_size[1]*anchors[:,2] # used to compute a resolution-independent IoU
area1 = anchors_areas
iou, _ = box_iou(boxes1, area1, boxes2)
iou = iou.squeeze(1)
shape_iou, _ = box_iou(boxes1, area1, boxes3)
shape_iou = shape_iou.diag()
# prefer anchors with the closest shape first, then the closest resolution
index = torch.argmax(shape_iou*100+iou,dim=0)
return index
def select_best_resolution(anchors, anchors_areas, input_image_size): # TODO: needs a further check
"""
Selects the best-fitting anchor for the original image size, following the LLaVA strategy.
Args:
anchors (Tensor): Candidate anchors in xyxy format; (x2, y2) of each anchor gives a candidate (width, height).
anchors_areas (Tensor): Areas of the anchors (unused here; kept for interface parity with anchor_rank).
input_image_size (tuple): The original image size in the format (height, width).
Returns:
int: The index of the best-fit anchor.
"""
original_size = (input_image_size[1], input_image_size[0])
possible_resolutions = [(_[2], _[3]) for _ in anchors] # xyxy -> w,h
original_width, original_height = original_size
best_fit = None
max_effective_resolution = 0
min_wasted_resolution = float('inf')
index = 0
for i, (width, height) in enumerate(possible_resolutions):
scale = min(width / original_width, height / original_height)
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
wasted_resolution = (width * height) - effective_resolution
if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
max_effective_resolution = effective_resolution
min_wasted_resolution = wasted_resolution
best_fit = (width, height)
index = i
return index
def build_cut_shape_indices(cut_shape):
# cut_shape: a list of (nh,nw)
cut_shape_indices = []
for shape in cut_shape:
n=shape[0]*shape[1]
indices = torch.cat([
repeat(torch.tensor(shape),'l -> n l',n=n),
torch.arange(n).unsqueeze(1)
], dim=1)
assert indices.shape[0] == n
assert indices.shape[1] == 3 # nh,nw,idx
cut_shape_indices.append(indices)
cut_shape_indices = torch.cat(cut_shape_indices,dim=0).long()
return cut_shape_indices
class AnchorResize(torch.nn.Module):
def __init__(self, image_size, anchors, interpolation=InterpolationMode.BILINEAR, antialias=None, anchor_strategy='docowl'):
super().__init__()
self.image_size = image_size
# xyxy
self.anchors = torch.tensor(
[[0, 0, _[1]*image_size[1], _[0]*image_size[0]]
for _ in anchors], requires_grad=False
)
self.anchor_areas = box_area(self.anchors)
self.interpolation = interpolation
self.antialias = antialias
self.anchor_strategy = anchor_strategy
assert self.anchor_strategy in available_anchor_strategy
def resize_global(self, img):
return F.resize(img, self.image_size, self.interpolation, max_size=None, antialias=self.antialias)
def forward(self, img, skip_resize=False):
"""
Args:
img (PIL Image or Tensor): Image to be scaled.
Returns:
PIL Image or Tensor: Rescaled image.
"""
if self.anchor_strategy == 'docowl':
selected_anchor = anchor_rank(self.anchors, self.anchor_areas, (img.size[1], img.size[0]))
elif self.anchor_strategy == 'random':
selected_anchor = random.randint(0,len(self.anchors)-1)
elif self.anchor_strategy == 'highest':
# pick the anchor with the largest area, and among those prefer the most square one
selected_anchor = torch.argmax(self.anchors[:,2]*self.anchors[:,3]*100-torch.abs(self.anchors[:,2]-self.anchors[:,3]))
elif self.anchor_strategy == 'last':
selected_anchor = len(self.anchors)-1
elif self.anchor_strategy == 'llava':
selected_anchor = select_best_resolution(self.anchors, self.anchor_areas, (img.size[1], img.size[0]))
else:
selected_anchor = None
assert selected_anchor is not None
target_size = self.anchors[selected_anchor][2:].tolist() # w,h
if skip_resize:
# for debug
return selected_anchor
return F.resize(img, [target_size[1],target_size[0]], self.interpolation, max_size=None, antialias=self.antialias), selected_anchor
def __repr__(self) -> str:
detail = f"(size={self.image_size}, anchor={self.anchors}, interpolation={self.interpolation.value}, antialias={self.antialias})"
return f"{self.__class__.__name__}{detail}"
class CutMixin:
def __init__(self, cut_cfg={"anchors": "grid_squ_6", "anchor_strategy": "docowl", "cut_prompt": "v3", "add_global": True, "cut_prob": 1.0}) -> None:
if cut_cfg is None:
self.cut_enable = False
return
else:
self.cut_enable = True
image_size = self.image_size
anchors = cut_cfg.get('anchors','grid_33')
anchor_strategy = cut_cfg.get('anchor_strategy','docowl')
cut_prompt = cut_cfg.get('cut_prompt','v0')
self.cut_prob = cut_cfg.get('cut_prob', 1.0)
self.force_shape_cut = cut_cfg.get('force_shape_cut', False)
force_shape_cut_anchors = cut_cfg.get('force_shape_cut_anchors', 'force_shape_cut_anchors')
self.add_global = cut_cfg.get('add_global', False)
# h,w
if isinstance(image_size, int):
image_size = (image_size, image_size)
self.image_size = image_size
if anchors in grid_dict:
anchors = grid_dict[anchors]
else:
anchors = eval(anchors)
self.anchors = [tuple(_) for _ in anchors]
self.anchor_max = max([max(_) for _ in self.anchors])
self.resizer = AnchorResize(image_size=image_size, anchors=anchors, interpolation=InterpolationMode.BICUBIC, anchor_strategy=anchor_strategy)
if force_shape_cut_anchors in grid_dict:
force_shape_cut_anchors = grid_dict[force_shape_cut_anchors]
else:
force_shape_cut_anchors = eval(force_shape_cut_anchors)
self.force_shape_cut_anchors = [tuple(_) for _ in force_shape_cut_anchors]
self.force_shape_cut_anchors_max = max([max(_) for _ in self.force_shape_cut_anchors])
self.old_resizer = transforms.Resize(image_size,interpolation=InterpolationMode.BICUBIC)
# remove the resize step from the image processor's transform and keep only the subsequent transforms
self.image_transform = transforms.Compose(self.image_transform.transforms[1:])
if self.add_global:
self.cut_prompt_template = cut_prompt_template_dict[cut_prompt+'_global']
else:
self.cut_prompt_template = cut_prompt_template_dict[cut_prompt]
self.media_tokens = ["<|image|>", "<|video|>"]
def _process_image(self, images):
new_images = []
cut_shape = []
for image in images:
raw_image = image
image, selected_anchor = self.resizer(image)
image_input = self.image_transform(image) # h,w,3 -> 3,h,w
cut_shape.append((image_input.shape[1]//self.image_size[0], image_input.shape[2]//self.image_size[1])) # cut_h, cut_w
image_input = rearrange(image_input, 'C (num_h h) (num_w w) -> (num_h num_w) C h w', h=self.image_size[0], w=self.image_size[1])
new_images.append(image_input)
if self.add_global:
new_images.append(self.image_transform(self.resizer.resize_global(raw_image)).unsqueeze(0))
cut_shape.append((1,1))
new_images = torch.cat(new_images,dim=0)
cut_shape_indices = build_cut_shape_indices(cut_shape)
return new_images, cut_shape, cut_shape_indices
class mPLUGOwl3BatchFeature(BatchFeature):
r"""
Extend from BatchFeature for supporting various image size
"""
def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
super().__init__(data)
self.convert_to_tensors(tensor_type=tensor_type)
def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
if tensor_type is None:
return self
is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)
def converter(value):
try:
if not is_tensor(value):
tensor = as_tensor(value)
return tensor
except: # noqa E722
if key == "overflowing_values":
raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
raise ValueError(
"Unable to create tensor, you should probably activate padding "
"with 'padding=True' to have batched tensors with the same length."
)
for key, value in self.items():
self[key] = recursive_converter(converter, value)
return self
def to(self, *args, **kwargs) -> "mPLUGOwl3BatchFeature":
requires_backends(self, ["torch"])
import torch
def cast_tensor(v):
# check if v is a floating point
if torch.is_floating_point(v):
# cast and send to device
return v.to(*args, **kwargs)
elif device is not None:
return v.to(device=device)
else:
return v
new_data = {}
device = kwargs.get("device")
# Check if the args are a device or a dtype
if device is None and len(args) > 0:
# device should be always the first argument
arg = args[0]
if is_torch_dtype(arg):
# The first argument is a dtype
pass
elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
device = arg
else:
# it's something else
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
for k, v in self.items():
new_data[k] = recursive_converter(cast_tensor, v)
self.data = new_data
return self
class mPLUGOwl3ImageProcessor(BaseImageProcessor, CutMixin):
model_input_names = ["pixel_values"]
def __init__(
self,
image_size,
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5],
**kwargs):
super().__init__(**kwargs)
self.image_size = image_size
self.image_transform = transforms.Compose([
transforms.Resize((image_size, image_size), interpolation=Image.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean, std),
])
CutMixin.__init__(self)
def preprocess(
self,
images: Union[Image.Image, List[Image.Image]],
cut_enable=True,
**kwargs
) -> mPLUGOwl3BatchFeature:
if isinstance(images, Image.Image):
images_list = [images]
else:
images_list = images
if self.cut_enable and cut_enable:
image_data, cut_shape, cut_shape_indices = self._process_image(images_list)
else:
image_data = [self.image_transform(self.resizer.resize_global(image)) for image in images_list]
image_data = torch.stack(image_data, dim=0)
cut_shape = cut_shape_indices = None
return mPLUGOwl3BatchFeature(data={'pixel_values': image_data, 'cut_shape':cut_shape, 'cut_shape_indices':cut_shape_indices})
def to_dict(self):
encoder_dict = super().to_dict()
pop_keys = ['image_transform', 'resizer', 'old_resizer', 'cut_prompt_template']
for pk in pop_keys:
encoder_dict.pop(pk, None)
return encoder_dict
AutoImageProcessor.register("mPLUGOwl3ImageProcessor", mPLUGOwl3ImageProcessor)

151387
merges.txt Normal file

File diff suppressed because it is too large

BIN
model.safetensors (Stored with Git LFS) Normal file

Binary file not shown.

1532
modeling_hyper_qwen2.py Normal file

File diff suppressed because it is too large

231
modeling_mplugowl3.py Normal file

@@ -0,0 +1,231 @@
import math
from typing import List, Optional
import json
import torch
import torchvision
from threading import Thread
from copy import deepcopy
from PIL import Image
from transformers import AutoProcessor, Qwen2PreTrainedModel, Qwen2ForCausalLM, TextIteratorStreamer
from .processing_mplugowl3 import mPLUGOwl3Processor
from .image_processing_mplugowl3 import mPLUGOwl3ImageProcessor
from .configuration_mplugowl3 import mPLUGOwl3Config
# from .modeling_navit_siglip import SiglipVisionTransformer
from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer
from .x_sdpa import ScaleDotProductAttention
from .modeling_hyper_qwen2 import HyperQwen2ForCausalLM
from torch import nn
class mPLUGOwl3PreTrainedModel(Qwen2PreTrainedModel):
config_class = mPLUGOwl3Config
class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.language_model = HyperQwen2ForCausalLM(config)
self.vision_model = self.init_vision_module()
self.vision_dim = self.vision_model.embed_dim
self.embed_dim = self.language_model.config.hidden_size
self.vision2text_model = nn.Linear(self.vision_dim, self.embed_dim)
self.processor = None
self.terminators = ['<|im_end|>', '<|endoftext|>']
def init_vision_module(self):
self.config.vision_config._attn_implementation = self.config.vision_config._attn_implementation
model = SiglipVisionTransformer(self.config.vision_config)
setattr(model, 'embed_dim', model.embeddings.embed_dim)
setattr(model, 'patch_size', model.embeddings.patch_size)
return model
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.embed_tokens = value
def get_output_embeddings(self):
return self.language_model.lm_head
def set_output_embeddings(self, new_embeddings):
self.language_model.lm_head = new_embeddings
def set_decoder(self, decoder):
self.language_model = decoder
def get_decoder(self):
return self.language_model
def forward_image(self, pixel_values):
if pixel_values is None:
return None
dtype = self.language_model.model.embed_tokens.weight.dtype
with torch.inference_mode():
image_embeds = self.vision_model(pixel_values.to(dtype), output_hidden_states=True).hidden_states[-2]
if self.vision2text_model is not None:
image_embeds = self.vision2text_model(image_embeds)
else:
pass
return image_embeds
def forward(self, pixel_values=None, **kwargs):
image_embeds = self.forward_image(pixel_values)
return self.language_model(
image_embeds=image_embeds,
**kwargs
)
def _decode(self, input_ids, image_embeds, media_offset, tokenizer, attention_mask, decode_text=False, **kwargs):
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
output = self.language_model.generate(
input_ids=input_ids,
image_embeds=image_embeds,
media_offset=media_offset,
pad_token_id=0,
eos_token_id=terminators,
attention_mask=attention_mask,
**kwargs
)
output = output[:,input_ids.shape[1]:]
if decode_text:
return self._decode_text(output, tokenizer)
return output
def _decode_stream(self, input_ids, image_embeds, media_offset, tokenizer, **kwargs):
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
streamer = TextIteratorStreamer(tokenizer=tokenizer)
generation_kwargs = {
'input_ids': input_ids,
'image_embeds': image_embeds,
'media_offset': media_offset,
'pad_token_id': 0,
'eos_token_id': terminators,
'streamer': streamer
}
generation_kwargs.update(kwargs)
thread = Thread(target=self.language_model.generate, kwargs=generation_kwargs)
thread.start()
return streamer
def _decode_text(self, result_ids, tokenizer):
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
result_text = []
for result in result_ids:
result = result[result != 0]
if result[-1] in terminators:
result = result[:-1]
result_text.append(tokenizer.decode(result).strip())
return result_text
def init_processor(self, tokenizer):
ip = mPLUGOwl3ImageProcessor(image_size=384)
self.processor = mPLUGOwl3Processor(image_processor=ip, tokenizer=tokenizer)
processor = self.processor
return processor
def generate(
self,
input_ids=None,
pixel_values=None,
media_offset=None,
attention_mask=None,
tokenizer=None,
stream=False,
decode_text=False,
**kwargs
):
assert input_ids is not None
with torch.inference_mode():
image_embeds = self.forward_image(pixel_values)
if stream:
result = self._decode_stream(input_ids=input_ids, image_embeds=image_embeds, media_offset=media_offset, tokenizer=tokenizer, **kwargs)
else:
result = self._decode(input_ids=input_ids, image_embeds=image_embeds, media_offset=media_offset, tokenizer=tokenizer, attention_mask=attention_mask, decode_text=decode_text, **kwargs)
return result
def chat(
self,
images,
videos,
messages,
tokenizer,
processor=None,
max_new_tokens=2048,
min_new_tokens=0,
sampling=True,
max_inp_length=8192,
system_prompt='',
stream=False,
max_slice_nums=None,
use_image_id=None,
**kwargs
):
cut_flag = kwargs.get('cut_enable', True)
if processor is None:
if self.processor is None:
processor = self.init_processor(tokenizer)
else:
processor = self.processor
inputs = processor(messages, images=images, videos=videos, cut_enable=cut_flag)
inputs.to('cuda')
inputs.update({
'tokenizer': tokenizer,
'max_new_tokens': max_new_tokens,
# 'stream':True,
})
if sampling:
generation_config = {
"top_p": 0.8,
"top_k": 100,
"temperature": 0.7,
"do_sample": True,
# "repetition_penalty": 1.05
}
else:
generation_config = {
"num_beams": 3,
# "repetition_penalty": 1.2,
}
if min_new_tokens > 0:
generation_config['min_new_tokens'] = min_new_tokens
generation_config.update(
(k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()
)
with torch.inference_mode():
res = self.generate(
**inputs,
stream=stream,
decode_text=True,
**generation_config
)
if stream:
def stream_gen():
for text in res:
for term in self.terminators:
text = text.replace(term, '')
yield text
return stream_gen()
else:
answer = res[0]
return answer

396
processing_mplugowl3.py Normal file

@@ -0,0 +1,396 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for mPLUGOwl3.
"""
from typing import List, Optional, Union, Dict, Any
import warnings
import torch
import re
from transformers.image_processing_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import TensorType, requires_backends, is_torch_dtype, is_torch_device
from .image_processing_mplugowl3 import mPLUGOwl3BatchFeature, mPLUGOwl3ImageProcessor
OWL_MEDIA_TOKEN=['<|image|>']
class MediaIndicesHelper():
def __init__(self, tokenizer) -> None:
self.media_position = []
self.tokenizer = tokenizer
def has_media(self, text, media_tokens=None):
if media_tokens is None:
media_tokens = OWL_MEDIA_TOKEN
has_media_flag = any([media_token == text for media_token in media_tokens])
if any([media_token in text for media_token in media_tokens]):
# a media token must be its own chunk; a chunk that contains a media token mixed with other text is not allowed
assert has_media_flag, text
return has_media_flag
def add_media(self, text_chunk, text=None, tokenize_fn=None):
# cross
assert tokenize_fn is not None
assert text is not None
assert text in OWL_MEDIA_TOKEN
media_token_ids = tokenize_fn(text)
start = len(text_chunk)
end = start + len(media_token_ids)
self.media_position.append([start, end])
text_chunk.extend(media_token_ids)
return len(media_token_ids)
def cal_media_offset(self, input_ids):
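# Returns, for each token position in input_ids, the number of media segments that start strictly
# before that position; a large negative constant is returned when no media is present.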
if len(self.media_position) == 0:
return torch.ones_like(input_ids)*(-1000000)
media_starts = torch.tensor([_[0] for _ in self.media_position]).reshape(1,-1)
rng = torch.arange(input_ids.shape[0]).reshape(-1,1)
matrix = (rng > media_starts).sum(dim=1)
return matrix
def len_images(self,):
return len(self.media_position)
class mPLUGOwl3Processor(ProcessorMixin):
r"""
Args:
image_processor ([`mPLUGOwl3ImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`LlamaTokenizerWrapper`], *optional*):
The tokenizer is a required input.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor: mPLUGOwl3ImageProcessor = None, tokenizer=None, prompt_style='chatml', inference_mode=True, addition_eod="<|endoftext|>"):
super().__init__(image_processor, tokenizer)
self.image_processor: mPLUGOwl3ImageProcessor
self.prompt_style = prompt_style
self.inference_mode = inference_mode
self.media_tokens = ["<|image|>"]
self.addition_eod = addition_eod
def build_text_qwen(self, messages):
# role should be within ['system', 'user', 'assistant']
im_start, im_end = '<|im_start|>', '<|im_end|>'
text = []
for num_turn, message in enumerate(messages):
if num_turn == 0 and message['role'] != 'system':
if self.prompt_style != 'plain':
text.append({
"text": f"{im_start}system\n{im_end}",
"label": 0
})
if message['role'] == 'system':
if self.prompt_style != 'plain':
text.append({
"text": f"{im_start}system\n{message['content']}{im_end}",
"label": 0
})
elif message['role'] == 'user':
if self.prompt_style != 'plain':
content = f"\n{im_start}user\n{message['content']}{im_end}"
else:
content = message['content']
pattern = '|'.join(map(re.escape, self.media_tokens))
chunk_strs = re.split(f'({pattern})', content)
for chunk_str in chunk_strs:
text.append({
"text": chunk_str,
"label": 0
})
elif message['role'] == 'assistant':
if self.prompt_style != 'plain':
text.append({"text": f"\n{im_start}assistant\n", "label": 0})
text.append({"text": f"{message['content']}{im_end}", "label": 1})
else:
text.append({"text": f"{message['content']}", "label": 1})
text.append({"text": self.addition_eod, "label": 1})
else:
raise NotImplementedError
if self.inference_mode:
while text and text[-1]['label']==1: # while the list is non-empty and its last chunk is a label-1 (target) chunk
text.pop() # remove that last chunk
return text
def wrapped_tokenize(self, text):
return self.tokenizer(text).input_ids
def encode_text_sft(self, texts):
# output enc_chunk
enc_chunk = []
label_chunk = []
enc_length = 0
num_images = 0
media_helper = MediaIndicesHelper(tokenizer=self.tokenizer)
for current_ti, text_chunk in enumerate(texts):
text = text_chunk["text"]
label = text_chunk["label"]
if not media_helper.has_media(text):
curr_chunk=self.wrapped_tokenize(text)
if label == 1:
enc_length += len(curr_chunk)
enc_chunk += curr_chunk
label_chunk += [label] * len(curr_chunk)
else:
enc_length += len(curr_chunk)
enc_chunk += curr_chunk
label_chunk += [label] * len(curr_chunk)
# For media tokens
else:
add_length = media_helper.add_media(
enc_chunk,
text=text,
tokenize_fn=self.wrapped_tokenize)
enc_length += add_length
label_chunk += [label] * add_length
# enc_chunk.extend([self.media_tokens[text]] * self.media_lengths[text])
# enc_length += self.media_lengths[text]
# label_chunk += [label] * self.media_lengths[text]
num_images += 1
enc_chunk = torch.tensor(enc_chunk).long()
media_offset = []
media_before = 0
for i,_ in enumerate([media_helper]):
mo = _.cal_media_offset(enc_chunk)
media_offset.append(torch.cat([(torch.ones(mo.shape[0],1)*media_before).long().to(mo.device), (mo+media_before).unsqueeze(1)], dim=1)) # L 2
media_before += _.len_images()
media_offset = torch.stack(media_offset, dim=0)
return {
'input_ids': enc_chunk.unsqueeze(0),
'media_offset': media_offset,
}
def __call__(
self,
messages,
images = None,
videos = None,
max_length: Optional[int] = None,
cut_enable=True,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
**kwargs
) -> mPLUGOwl3BatchFeature:
medias = []
if videos is not None:
medias.extend([{'type': 'video', 'content': video, 'use_video_span': True} for video in videos])
if images is not None:
medias.extend([{'type':'image', 'content': image} for image in images])
if len(medias):
image_tensor_list = []
pattern = r"(<\|image\|>|<\|video\|>)"
# media items are present
image_token_ptr = 0
media_layout = []
for message in messages:
text_list = re.split(pattern, message['content'])
text = ''
for text_content in text_list:
if text_content in ['<|image|>', '<|video|>']:
media_item = medias[image_token_ptr]
image_token_ptr += 1
if text_content == '<|image|>':
assert media_item['type'] == 'image'
image = media_item['content']
image_inputs = self.image_processor([image], cut_enable=cut_enable, return_tensors=return_tensors)
if image_inputs.get('cut_shape',None) is not None:
cut_shape = image_inputs['cut_shape']
cut_text = self.image_processor.cut_prompt_template(img_token='<|image|>', h=cut_shape[0][0], w=cut_shape[0][1])
text += cut_text
image_tensor_list.append(image_inputs['pixel_values'])
else:
text += text_content
elif text_content == '<|video|>':
assert media_item['type'] == 'video'
video = media_item['content']
use_video_span = media_item['use_video_span']
image_tensor = self.image_processor(video, cut_enable=False)['pixel_values']
image_tensor_list.append(image_tensor)
num_video_frame = image_tensor.shape[0]
if use_video_span:
text_content = '<|start_video_frame|>'+'<|image|>'*num_video_frame+'<|end_video_frame|>'
else:
text_content = '<|image|>'*num_video_frame
text += text_content
else:
text += text_content
message['content'] = text
assert image_token_ptr == len(medias), (image_token_ptr,len(medias)) # ensure the number of media tokens matches the number of media items
assert all(len(_.shape) == 4 for _ in image_tensor_list), [_.shape for _ in image_tensor_list]
num_image_tokens = sum([_['content'].count('<|image|>')for _ in messages])
num_image_shapes = sum([_.shape[0] for _ in image_tensor_list])
assert num_image_tokens == num_image_shapes, (messages, [_.shape for _ in image_tensor_list])
image_tensor_list = torch.cat(image_tensor_list, dim=0)
# text = ''.join([_['text'] for _ in text])
text = self.build_text_qwen(messages)
model_inputs = self.encode_text_sft(text)
if len(medias):
model_inputs.update({'pixel_values': image_tensor_list})
# if 'cut_shape' in model_inputs:
# model_inputs.pop('cut_shape')
# if 'cut_shape_indices' in model_inputs:
# model_inputs.pop('cut_shape_indices')
return mPLUGOwl3BatchFeature(model_inputs)
def check_media(self, images, messages):
media_num = 0 if images is None else len(images)
media_count = sum([message['content'].count('<|image|>') for message in messages])
assert media_num == media_count
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
output_ids = args[0]
result_text = []
for result in output_ids:
result = result[result != 0]
if result[0] == self.tokenizer.bos_id:
result = result[1:]
if result[-1] == self.tokenizer.eos_id:
result = result[:-1]
result_text.append(self.tokenizer.decode(result, *args[1:], **kwargs).strip())
return result_text
# return self.tokenizer.batch_decode(*args, **kwargs)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
result = args[0]
result = result[result != 0]
if result[0] == self.tokenizer.bos_id:
result = result[1:]
if result[-1] == self.tokenizer.eos_id or (hasattr(self.tokenizer, "eot_id") and result[-1] == self.tokenizer.eot_id):
result = result[:-1]
return self.tokenizer.decode(result, *args[1:], **kwargs).strip()
def _convert(
self, input_str, max_inp_length: Optional[int] = None
):
if self.version > 2.5 or not getattr(self.tokenizer, "add_bos_token", False):
input_ids = self.tokenizer.encode(input_str)
else:
input_ids = [self.tokenizer.bos_id] + self.tokenizer.encode(input_str)
if max_inp_length is not None:
input_ids = input_ids[:max_inp_length]
input_ids = torch.tensor(input_ids, dtype=torch.int32)
start_cond = (input_ids == self.tokenizer.im_start_id) | (input_ids == self.tokenizer.slice_start_id)
end_cond = (input_ids == self.tokenizer.im_end_id) | (input_ids == self.tokenizer.slice_end_id)
image_start_tokens = torch.where(start_cond)[0]
image_start_tokens += 1
image_end_tokens = torch.where(end_cond)[0]
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
image_bounds = torch.hstack(
[
image_start_tokens[:valid_image_nums].unsqueeze(-1),
image_end_tokens[:valid_image_nums].unsqueeze(-1),
]
)
return input_ids, image_bounds
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
items = []
if isinstance(inputs[0], list):
assert isinstance(inputs[0][0], torch.Tensor)
for it in inputs:
for tr in it:
items.append(tr)
else:
assert isinstance(inputs[0], torch.Tensor)
items = inputs
batch_size = len(items)
shape = items[0].shape
dim = len(shape)
assert dim <= 2
if max_length is None:
max_length = 0
max_length = max(max_length, max(item.shape[-1] for item in items))
min_length = min(item.shape[-1] for item in items)
dtype = items[0].dtype
if dim == 0:
return torch.stack([item for item in items], dim=0), [0]
elif dim == 1:
if max_length == min_length:
return torch.stack([item for item in items], dim=0), [0] * batch_size
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
else:
tensor = (
torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype)
+ padding_value
)
padding_length = []
for i, item in enumerate(items):
if dim == 1:
if padding_side == "left":
tensor[i, -len(item) :] = item.clone()
else:
tensor[i, : len(item)] = item.clone()
elif dim == 2:
if padding_side == "left":
tensor[i, -len(item) :, :] = item.clone()
else:
tensor[i, : len(item), :] = item.clone()
padding_length.append(tensor.shape[-1] - len(item))
return tensor, padding_length

303111
tokenizer.json Normal file

File diff suppressed because it is too large

40
tokenizer_config.json Normal file

@@ -0,0 +1,40 @@
{
"add_prefix_space": false,
"added_tokens_decoder": {
"151643": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151644": {
"content": "<|im_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151645": {
"content": "<|im_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"additional_special_tokens": ["<|im_start|>", "<|im_end|>"],
"bos_token": null,
"chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"clean_up_tokenization_spaces": false,
"eos_token": "<|im_end|>",
"errors": "replace",
"model_max_length": 32768,
"pad_token": "<|endoftext|>",
"split_special_tokens": false,
"tokenizer_class": "Qwen2Tokenizer",
"unk_token": null
}
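For reference, the `chat_template` above is the standard Qwen2 ChatML template. A quick sketch of how it renders with the stock `apply_chat_template` API (note that the bundled `mPLUGOwl3Processor` builds prompts itself via `build_text_qwen`, so this is only for inspecting the template):
```Python
from modelscope import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('iic/mPLUG-Owl3-2B-241014')
messages = [{"role": "user", "content": "<|image|> Describe this image."}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# <|image|> Describe this image.<|im_end|>
# <|im_start|>assistant
```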

1
vocab.json Normal file

File diff suppressed because one or more lines are too long

60
x_sdpa.py Normal file

@@ -0,0 +1,60 @@
from torch import nn
import torch.nn.functional as F
from einops import rearrange
class ScaleDotProductAttention(nn.Module):
def __init__(self, layer_number, causal=False, softmax_scale=None, attention_dropout=0.0):
super().__init__()
self.layer_number = layer_number
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
# Qwen does not need an extra softmax scale
def forward(self, q, k, v, attn_mask=None, order='sbhd'):
"""Implements the multihead softmax attention.
Arguments
---------
q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
"""
# (N,...,L,E)
if order == 'sbhd':
q, k, v = [rearrange(x, 's b h d -> b h s d').contiguous()
for x in (q, k, v)]
elif order == 'bhsd':
pass
if attn_mask is not None:
attn_mask = (~attn_mask.clone().bool()).contiguous()
else:
attn_mask = None
# attention mask, True means it will take part in attention B H s_q s_k
if self.training:
# during training q,k,v always have same seqlen
if self.causal:
assert q.shape[-2] == k.shape[-2]
is_causal = self.causal
dropout_p = self.dropout_p
else:
# turn off FA causal mask after first inference autoregressive iteration
# only on first autoregressive step q,k,v have same seqlen
if self.causal:
is_causal = q.shape[-2] == k.shape[-2]
else:
is_causal = self.causal
dropout_p = 0.0
# if is_causal is True the provided mask is ignored; otherwise the provided mask is used
o = F.scaled_dot_product_attention(q, k, v,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=self.softmax_scale
)
# B Head L D -> L B (Head D)
o = rearrange(o, 'B Head L D -> L B (Head D)').contiguous()
return o