first commit

This commit is contained in:
xxl 2024-11-01 17:30:48 +08:00
parent 715a1c61d5
commit 3d7978286f
12 changed files with 153533 additions and 2 deletions

112
README.md
View File

@ -1,3 +1,111 @@
# GOT-OCR2_0_a13446664330014720201572
---
pipeline_tag: image-text-to-text
library_name: transformers
language:
- multilingual
tags:
- got
- vision-language
- ocr2.0
- custom_code
license: apache-2.0
studios:
- stepfun-ai/GOT_official_online_demo
---
General OCR Theory: Towards OCR-2.0 via a Unified End-to-end Model
<h1>General OCR Theory: Towards OCR-2.0 via a Unified End-to-end Model
</h1>
[🔋Online Demo](https://modelscope.cn/studios/stepfun-ai/GOT_official_online_demo) | [🌟GitHub](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/) | [📜Paper](https://arxiv.org/abs/2409.01704)</a>
[Haoran Wei*](https://scholar.google.com/citations?user=J4naK0MAAAAJ&hl=en), Chenglong Liu*, Jinyue Chen, Jia Wang, Lingyu Kong, Yanming Xu, [Zheng Ge](https://joker316701882.github.io/), Liang Zhao, [Jianjian Sun](https://scholar.google.com/citations?user=MVZrGkYAAAAJ&hl=en), [Yuang Peng](https://scholar.google.com.hk/citations?user=J0ko04IAAAAJ&hl=zh-CN&oi=ao), Chunrui Han, [Xiangyu Zhang](https://scholar.google.com/citations?user=yuB-cfoAAAAJ&hl=en)
![image/jpeg](https://cdn-uploads.huggingface.co/production/uploads/6653eee7a2d7a882a805ab95/QCEFY-M_YG3Bp5fn1GQ8X.jpeg)
## Usage
Inference using Huggingface transformers on NVIDIA GPUs. Requirements tested on python 3.10
```
torch==2.0.1
torchvision==0.15.2
transformers==4.37.2
tiktoken==0.6.0
verovio==4.3.1
accelerate==0.28.0
```
```python
from modelscope import AutoModel, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('stepfun-ai/GOT-OCR2_0', trust_remote_code=True)
model = AutoModel.from_pretrained('stepfun-ai/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
model = model.eval().cuda()
# input your test image
image_file = '/mnt/workspace/58F3EF14-E073-4BBE-B9D9-53CCFE6AE183.png'
# plain texts OCR
res = model.chat(tokenizer, image_file, ocr_type='ocr')
# format texts OCR:
# res = model.chat(tokenizer, image_file, ocr_type='format')
# fine-grained OCR:
# res = model.chat(tokenizer, image_file, ocr_type='ocr', ocr_box='')
# res = model.chat(tokenizer, image_file, ocr_type='format', ocr_box='')
# res = model.chat(tokenizer, image_file, ocr_type='ocr', ocr_color='')
# res = model.chat(tokenizer, image_file, ocr_type='format', ocr_color='')
# multi-crop OCR:
# res = model.chat_crop(tokenizer, image_file, ocr_type='ocr')
# res = model.chat_crop(tokenizer, image_file, ocr_type='format')
# render the formatted OCR results:
# res = model.chat(tokenizer, image_file, ocr_type='format', render=True, save_render_file = './demo.html')
print(res)
```
More details about 'ocr_type', 'ocr_box', 'ocr_color', and 'render' can be found at our GitHub.
Our training codes are available at our [GitHub](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/).
## More Multimodal Projects
👏 Welcome to explore more multimodal projects of our team:
[Vary](https://github.com/Ucas-HaoranWei/Vary) | [Fox](https://github.com/ucaslcl/Fox) | [OneChart](https://github.com/LingyvKong/OneChart)
## Citation
If you find our work helpful, please consider citing our papers 📝 and liking this project ❤️!
```bib
@article{wei2024general,
title={General OCR Theory: Towards OCR-2.0 via a Unified End-to-end Model},
author={Wei, Haoran and Liu, Chenglong and Chen, Jinyue and Wang, Jia and Kong, Lingyu and Xu, Yanming and Ge, Zheng and Zhao, Liang and Sun, Jianjian and Peng, Yuang and others},
journal={arXiv preprint arXiv:2409.01704},
year={2024}
}
@article{liu2024focus,
title={Focus Anywhere for Fine-grained Multi-page Document Understanding},
author={Liu, Chenglong and Wei, Haoran and Chen, Jinyue and Kong, Lingyu and Ge, Zheng and Zhu, Zining and Zhao, Liang and Sun, Jianjian and Han, Chunrui and Zhang, Xiangyu},
journal={arXiv preprint arXiv:2405.14295},
year={2024}
}
@article{wei2023vary,
title={Vary: Scaling up the Vision Vocabulary for Large Vision-Language Models},
author={Wei, Haoran and Kong, Lingyu and Chen, Jinyue and Zhao, Liang and Ge, Zheng and Yang, Jinrong and Sun, Jianjian and Han, Chunrui and Zhang, Xiangyu},
journal={arXiv preprint arXiv:2312.06109},
year={2023}
}
```

38
config.json Normal file
View File

@ -0,0 +1,38 @@
{
"_name_or_path": "ucaslcl/GOT-OCR2_0",
"architectures": [
"GOTQwenForCausalLM"
],
"auto_map": {
"AutoConfig": "modeling_GOT.GOTConfig",
"AutoModel": "modeling_GOT.GOTQwenForCausalLM"
},
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151643,
"freeze_vision_tower": false,
"hidden_act": "silu",
"hidden_size": 1024,
"im_end_token": 151858,
"im_patch_token": 151859,
"im_start_token": 151857,
"image_token_len": 256,
"initializer_range": 0.02,
"intermediate_size": 2816,
"max_position_embeddings": 32768,
"max_window_layers": 21,
"model_type": "GOT",
"num_attention_heads": 16,
"num_hidden_layers": 24,
"num_key_value_heads": 16,
"rms_norm_eps": 1e-06,
"rope_theta": 1000000.0,
"sliding_window": 32768,
"tie_word_embeddings": true,
"torch_dtype": "bfloat16",
"transformers_version": "4.37.2",
"use_cache": true,
"use_im_start_end": true,
"use_sliding_window": false,
"vocab_size": 151860
}

1
configuration.json Normal file
View File

@ -0,0 +1 @@
{}

6
generation_config.json Normal file
View File

@ -0,0 +1,6 @@
{
"bos_token_id": 151643,
"eos_token_id": 151643,
"max_new_tokens": 2048,
"transformers_version": "4.37.2"
}

468
got_vision_b.py Normal file
View File

@ -0,0 +1,468 @@
import torch
import torch.nn.functional as F
from typing import Optional, Tuple, Type
from functools import partial
import torch.nn as nn
from typing import Type
class MLPBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
mlp_dim: int,
act: Type[nn.Module] = nn.GELU,
) -> None:
super().__init__()
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
self.act = act()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lin2(self.act(self.lin1(x)))
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
class ImageEncoderViT(nn.Module):
def __init__(
self,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: Tuple[int, ...] = (),
) -> None:
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
global_attn_indexes (list): Indexes for blocks using global attention.
"""
super().__init__()
self.img_size = img_size
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
self.pos_embed: Optional[nn.Parameter] = None
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
)
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i not in global_attn_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
)
self.blocks.append(block)
self.neck = nn.Sequential(
nn.Conv2d(
embed_dim,
out_chans,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_chans),
nn.Conv2d(
out_chans,
out_chans,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(out_chans),
)
self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.pos_embed
for blk in self.blocks:
x = blk(x)
x = self.neck(x.permute(0, 3, 1, 2))
x = self.net_2(x)
x = self.net_3(x)
return x
class Block(nn.Module):
"""Transformer blocks with support of window attention and residual propagation blocks"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then
use global attention.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
self.window_size = window_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
assert (
input_size is not None
), "Input size must be provided if using relative positional encoding."
# initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.use_rel_pos:
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x)
return x
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
def window_unpartition(
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
attn: torch.Tensor,
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
attn = (
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
).view(B, q_h * q_w, k_h * k_w)
return attn
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
kernel_size: Tuple[int, int] = (16, 16),
stride: Tuple[int, int] = (16, 16),
padding: Tuple[int, int] = (0, 0),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
def build_GOT_vit_b(checkpoint=None):
return _build_GOT_vision(
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_global_attn_indexes=[2, 5, 8, 11],
checkpoint=checkpoint,
)
def _build_GOT_vision(
encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
):
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size
image_encoder=ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
)
return image_encoder

BIN
model.safetensors (Stored with Git LFS) Normal file

Binary file not shown.

881
modeling_GOT.py Normal file
View File

@ -0,0 +1,881 @@
from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, StoppingCriteria, TextStreamer
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from typing import List, Optional, Tuple, Union
from transformers.cache_utils import Cache
import requests
from PIL import Image
from io import BytesIO
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from .got_vision_b import build_GOT_vit_b
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import dataclasses
###
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'
from enum import auto, Enum
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
MPT = auto()
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "<|im_end|>"
sep2: str = None
version: str = "Unknown"
skip_next: bool = False
def get_prompt(self):
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system + self.sep + '\n'
for role, message in self.messages:
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + self.sep
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
if self.sep_style == SeparatorStyle.MPT:
if self.system:
ret = self.system + self.sep
else:
ret = ''
for role, message in self.messages:
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + message + self.sep
else:
ret += role
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def append_message(self, role, message):
self.messages.append([role, message])
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2)
class KeywordsStoppingCriteria(StoppingCriteria):
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1]
self.tokenizer = tokenizer
self.start_len = None
self.input_ids = input_ids
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
if self.start_len is None:
self.start_len = self.input_ids.shape[1]
else:
for keyword_id in self.keyword_ids:
if output_ids[0, -1] == keyword_id:
return True
outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
for keyword in self.keywords:
if keyword in outputs:
return True
return False
class GOTImageEvalProcessor:
def __init__(self, image_size=384, mean=None, std=None):
if mean is None:
mean = (0.48145466, 0.4578275, 0.40821073)
if std is None:
std = (0.26862954, 0.26130258, 0.27577711)
self.normalize = transforms.Normalize(mean, std)
self.transform = transforms.Compose(
[
transforms.Resize(
(image_size, image_size), interpolation=InterpolationMode.BICUBIC
),
transforms.ToTensor(),
self.normalize,
]
)
def __call__(self, item):
return self.transform(item)
class GOTConfig(Qwen2Config):
model_type = "GOT"
class GOTQwenModel(Qwen2Model):
config_class = GOTConfig
def __init__(self, config: Qwen2Config):
super(GOTQwenModel, self).__init__(config)
self.vision_tower_high = build_GOT_vit_b()
self.mm_projector_vary = nn.Linear(1024, 1024)
def initialize_vision_modules(
self,
vision_tower,
pretrained_stage1_model=None,
freeze_vision_tower=False,
use_im_start_end=False,
vision_select_layer=-1,
dtype=torch.float16,
device="cuda"
):
image_processor_high = GOTImageEvalProcessor(image_size=1024)
self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device)
image_token_len = 256
self.config.vision_tower = vision_tower
self.config.image_token_len = image_token_len
self.config.use_im_start_end = True
self.config.vision_select_layer = vision_select_layer
self.config.freeze_vision_tower = freeze_vision_tower
return dict(
image_processor_high=image_processor_high,
image_token_len=image_token_len,
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
# HACK: replace back original embeddings for LLaVA pretraining
orig_embeds_params = getattr(self, 'orig_embeds_params', None)
if orig_embeds_params is not None:
with torch.no_grad():
self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
vision_tower_high = getattr(self, 'vision_tower_high', None)
if vision_tower_high is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
use_im_start_end = getattr(self.config, "use_im_start_end", -1)
vision_select_layer = getattr(self.config, "vision_select_layer", -1)
im_patch_token = getattr(self.config, "im_patch_token", -1)
im_start_token = getattr(self.config, "im_start_token", -1)
im_end_token = getattr(self.config, "im_end_token", -1)
freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False)
im_patch_token = 151859
im_start_token = 151857
im_end_token = 151858
image_features = []
for image in images:
P, C, H, W = image.shape
if P == 1:
with torch.set_grad_enabled(False):
cnn_feature = vision_tower_high(image)
cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024
image_feature = self.mm_projector_vary(cnn_feature)
image_features.append(image_feature)
else:
image_patches = torch.unbind(image)
image_patches_features = []
for image_patch in image_patches:
image_p = torch.stack([image_patch])
with torch.set_grad_enabled(False):
cnn_feature_p = vision_tower_high(image_p)
cnn_feature_p = cnn_feature_p.flatten(2).permute(0, 2, 1)
image_feature_p = self.mm_projector_vary(cnn_feature_p)
image_patches_features.append(image_feature_p)
image_feature = torch.cat(image_patches_features, dim=1)
image_features.append(image_feature)
dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
dummy_image_features = dummy_image_features_2
use_im_start_end = True
new_input_embeds = []
for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
if (cur_input_ids == im_patch_token).sum() == 0:
cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
new_input_embeds.append(cur_input_embeds)
continue
if use_im_start_end:
if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
raise ValueError("The number of image start tokens and image end tokens should be the same.")
image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
num_patches = per_cur_image_features.shape[0]
if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
raise ValueError("The image end token should follow the image start token.")
cur_input_embeds = torch.cat(
(
cur_input_embeds[:image_start_token_pos+1],
per_cur_image_features,
cur_input_embeds[image_start_token_pos + num_patches + 1:]
),
dim=0
)
new_input_embeds.append(cur_input_embeds)
else:
raise NotImplementedError
inputs_embeds = torch.stack(new_input_embeds, dim=0)
return super(GOTQwenModel, self).forward(
input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids,
output_attentions=output_attentions, output_hidden_states=output_hidden_states,
return_dict=return_dict
)
class GOTQwenForCausalLM(Qwen2ForCausalLM):
config_class = GOTConfig
# supports_gradient_checkpointing = True
def __init__(self, config):
super(Qwen2ForCausalLM, self).__init__(config)
self.model = GOTQwenModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
images=images,
return_dict=return_dict
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
# logits
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
# Omit tokens covered by past_key_values
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"images": kwargs.get("images", None),
}
)
return model_inputs
def initialize_vision_tokenizer(
self,
tokenizer,
freeze_lm_model=False,
pretrained_stage1_model=None,
device="cuda"
):
config = self.get_model().config
self.resize_token_embeddings(len(tokenizer))
config.im_patch_token = 151859
config.use_im_start_end = True
if config.use_im_start_end:
self.resize_token_embeddings(len(tokenizer))
config.im_start_token, config.im_end_token = 151857, 151858
def load_image(self, image_file):
if image_file.startswith('http') or image_file.startswith('https'):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
image = Image.open(image_file).convert('RGB')
return image
def disable_torch_init(self):
"""
Disable the redundant torch default initialization to accelerate model creation.
"""
import torch
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
self.disable_torch_init()
image_processor_high = GOTImageEvalProcessor(image_size=1024)
use_im_start_end = True
image_token_len = 256
if gradio_input:
image = image_file.copy()
else:
image = self.load_image(image_file)
w, h = image.size
if ocr_type == 'format':
qs = 'OCR with format: '
else:
qs = 'OCR: '
if ocr_box:
bbox = eval(ocr_box)
if len(bbox) == 2:
bbox[0] = int(bbox[0]/w*1000)
bbox[1] = int(bbox[1]/h*1000)
if len(bbox) == 4:
bbox[0] = int(bbox[0]/w*1000)
bbox[1] = int(bbox[1]/h*1000)
bbox[2] = int(bbox[2]/w*1000)
bbox[3] = int(bbox[3]/h*1000)
if ocr_type == 'format':
qs = str(bbox) + ' ' + 'OCR with format: '
else:
qs = str(bbox) + ' ' + 'OCR: '
if ocr_color:
if ocr_type == 'format':
qs = '[' + ocr_color + ']' + ' ' + 'OCR with format: '
else:
qs = '[' + ocr_color + ']' + ' ' + 'OCR: '
if use_im_start_end:
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
else:
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
conv_mpt = Conversation(
system="""<|im_start|>system
You should follow the instructions carefully and explain your answers in detail.""",
# system = None,
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
version="mpt",
messages=(),
offset=0,
sep_style=SeparatorStyle.MPT,
sep="<|im_end|>",
)
conv = conv_mpt.copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
if print_prompt:
print(prompt)
inputs = tokenizer([prompt])
image_tensor_1 = image_processor_high(image)
input_ids = torch.as_tensor(inputs.input_ids).cuda()
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
if stream_flag:
with torch.autocast("cuda", dtype=torch.bfloat16):
output_ids = self.generate(
input_ids,
images=[image_tensor_1.unsqueeze(0).half().cuda()],
do_sample=False,
num_beams = 1,
no_repeat_ngram_size = 20,
streamer=streamer,
max_new_tokens=4096,
stopping_criteria=[stopping_criteria]
)
else:
with torch.autocast("cuda", dtype=torch.bfloat16):
output_ids = self.generate(
input_ids,
images=[image_tensor_1.unsqueeze(0).half().cuda()],
do_sample=False,
num_beams = 1,
no_repeat_ngram_size = 20,
# streamer=streamer,
max_new_tokens=4096,
stopping_criteria=[stopping_criteria]
)
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()
response_str = outputs
if render:
print('==============rendering===============')
from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table
if '**kern' in outputs:
import verovio
tk = verovio.toolkit()
tk.loadData(outputs)
tk.setOptions({"pageWidth": 2100, "footer": 'none',
'barLineWidth': 0.5, 'beamMaxSlope': 15,
'staffLineWidth': 0.2, 'spacingStaff': 6})
tk.getPageCount()
svg = tk.renderToSVG()
svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
svg_to_html(svg, save_render_file)
if ocr_type == 'format' and '**kern' not in outputs:
if '\\begin{tikzpicture}' not in outputs:
html_path_2 = save_render_file
right_num = outputs.count('\\right')
left_num = outputs.count('\left')
if right_num != left_num:
outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
outputs = outputs.replace('"', '``').replace('$', '')
outputs_list = outputs.split('\n')
gt= ''
for out in outputs_list:
gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
gt = gt[:-2]
lines = content_mmd_to_html
lines = lines.split("const text =")
new_web = lines[0] + 'const text =' + gt + lines[1]
else:
html_path_2 = save_render_file
outputs = outputs.translate(translation_table)
outputs_list = outputs.split('\n')
gt= ''
for out in outputs_list:
if out:
if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
while out[-1] == ' ':
out = out[:-1]
if out is None:
break
if out:
if out[-1] != ';':
gt += out[:-1] + ';\n'
else:
gt += out + '\n'
else:
gt += out + '\n'
lines = tik_html
lines = lines.split("const text =")
new_web = lines[0] + gt + lines[1]
with open(html_path_2, 'w') as web_f_new:
web_f_new.write(new_web)
return response_str
def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
# print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
return best_ratio
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
# print(target_ratios)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
# print(target_aspect_ratio)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
# Model
self.disable_torch_init()
multi_page=False
image_processor_high = GOTImageEvalProcessor(image_size=1024)
use_im_start_end = True
image_token_len = 256
image_list = []
# if len(image_file_list)>1:
# multi_page = True
if multi_page:
qs = 'OCR with format across multi pages: '
# only for png files
# import glob
# from natsort import natsorted
# patches = glob.glob(image_file + '/*png')
patches = image_file
# patches = natsorted(patches)
sub_images = []
for sub_image in patches:
sub_images.append(self.load_image(sub_image))
ll = len(patches)
# print(patches)
# print("len ll: ", ll)
else:
if ocr_type == 'format':
qs = 'OCR with format upon the patch reference: '
else:
qs = 'OCR upon the patch reference: '
if gradio_input:
img = image_file.copy()
else:
img = self.load_image(image_file)
sub_images = self.dynamic_preprocess(img)
ll = len(sub_images)
for image in sub_images:
image_tensor_1 = image_processor_high(image)
image_list.append(image_tensor_1)
image_list = torch.stack(image_list)
print('====new images batch size======: \n',image_list.shape)
if use_im_start_end:
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len*ll + DEFAULT_IM_END_TOKEN + '\n' + qs
else:
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
conv_mpt = Conversation(
system="""<|im_start|>system
You should follow the instructions carefully and explain your answers in detail.""",
# system = None,
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
version="mpt",
messages=(),
offset=0,
sep_style=SeparatorStyle.MPT,
sep="<|im_end|>",
)
conv = conv_mpt.copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
if print_prompt:
print(prompt)
inputs = tokenizer([prompt])
input_ids = torch.as_tensor(inputs.input_ids).cuda()
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
if stream_flag:
with torch.autocast("cuda", dtype=torch.bfloat16):
output_ids = self.generate(
input_ids,
images=[image_list.half().cuda()],
do_sample=False,
num_beams = 1,
# no_repeat_ngram_size = 20,
streamer=streamer,
max_new_tokens=4096,
stopping_criteria=[stopping_criteria]
)
else:
with torch.autocast("cuda", dtype=torch.bfloat16):
output_ids = self.generate(
input_ids,
images=[image_list.half().cuda()],
do_sample=False,
num_beams = 1,
# no_repeat_ngram_size = 20,
# streamer=streamer,
max_new_tokens=4096,
stopping_criteria=[stopping_criteria]
)
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()
response_str = outputs
if render:
print('==============rendering===============')
from .render_tools import content_mmd_to_html
html_path_2 = save_render_file
right_num = outputs.count('\\right')
left_num = outputs.count('\left')
if right_num != left_num:
outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
outputs = outputs.replace('"', '``').replace('$', '')
outputs_list = outputs.split('\n')
gt= ''
for out in outputs_list:
gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
gt = gt[:-2]
lines = content_mmd_to_html
lines = lines.split("const text =")
new_web = lines[0] + 'const text =' + gt + lines[1]
with open(html_path_2, 'w') as web_f_new:
web_f_new.write(new_web)
return response_str

151643
qwen.tiktoken Normal file

File diff suppressed because it is too large Load Diff

96
render_tools.py Normal file
View File

@ -0,0 +1,96 @@
punctuation_dict = {
"": ",",
"": ".",
}
translation_table = str.maketrans(punctuation_dict)
def svg_to_html(svg_content, output_filename):
html_content = f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>SVG Embedded in HTML</title>
</head>
<body>
<svg width="2100" height="15000" xmlns="http://www.w3.org/2000/svg">
{svg_content}
</svg>
</body>
</html>
"""
with open(output_filename, 'w') as file:
file.write(html_content)
content_mmd_to_html = """<!DOCTYPE html>
<html lang="en" data-lt-installed="true"><head>
<meta charset="UTF-8">
<title>Title</title>
<script>
const text =
</script>
<style>
#content {
max-width: 800px;
margin: auto;
}
</style>
<script>
let script = document.createElement('script');
script.src = "https://cdn.jsdelivr.net/npm/mathpix-markdown-it@1.3.6/es5/bundle.js";
document.head.append(script);
script.onload = function() {
const isLoaded = window.loadMathJax();
if (isLoaded) {
console.log('Styles loaded!')
}
const el = window.document.getElementById('content-text');
if (el) {
const options = {
htmlTags: true
};
const html = window.render(text, options);
el.outerHTML = html;
}
};
</script>
</head>
<body>
<div id="content"><div id="content-text"></div></div>
</body>
</html>
"""
tik_html = """
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Document</title>
<link rel="stylesheet" type="text/css" href="https://tikzjax.com/v1/fonts.css">
<script src="https://tikzjax.com/v1/tikzjax.js"></script>
</head>
<body>
<script type="text/tikz">
const text =
</script>
</body>
</html>"""
# print(tik_html)

9
special_tokens_map.json Normal file
View File

@ -0,0 +1,9 @@
{
"pad_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

264
tokenization_qwen.py Normal file
View File

@ -0,0 +1,264 @@
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Tokenization classes for QWen."""
import base64
import logging
import os
import unicodedata
from typing import Collection, Dict, List, Set, Tuple, Union
import tiktoken
from transformers import PreTrainedTokenizer, AddedToken
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"}
PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
ENDOFTEXT = "<|endoftext|>"
IMSTART = "<|im_start|>"
IMEND = "<|im_end|>"
# as the default behavior is changed to allow special tokens in
# regular texts, the surface forms of special tokens need to be
# as different as possible to minimize the impact
EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
SPECIAL_TOKENS = (
ENDOFTEXT,
IMSTART,
IMEND,
) + EXTRAS
def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
with open(tiktoken_bpe_file, "rb") as f:
contents = f.read()
return {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in contents.splitlines() if line)
}
class QWenTokenizer(PreTrainedTokenizer):
"""QWen tokenizer."""
vocab_files_names = VOCAB_FILES_NAMES
def __init__(
self,
vocab_file,
errors="replace",
image_start_tag='<img>',
image_end_tag='</img>',
image_pad_tag='<imgpad>',
ref_start_tag='<ref>',
ref_end_tag='</ref>',
box_start_tag='<box>',
box_end_tag='</box>',
quad_start_tag='<quad>',
quad_end_tag='</quad>',
**kwargs,
):
super().__init__(**kwargs)
self.image_start_tag = image_start_tag
self.image_end_tag = image_end_tag
self.image_pad_tag = image_pad_tag
self.ref_start_tag = ref_start_tag
self.ref_end_tag = ref_end_tag
self.box_start_tag = box_start_tag
self.box_end_tag = box_end_tag
self.quad_start_tag = quad_start_tag
self.quad_end_tag = quad_end_tag
self.IMAGE_ST = (
ref_start_tag, ref_end_tag,
box_start_tag, box_end_tag,
quad_start_tag, quad_end_tag,
image_start_tag, image_end_tag,
image_pad_tag
)
self.errors = errors # how to handle errors in decoding
self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int]
self.special_tokens = {
token: index
for index, token in enumerate(
SPECIAL_TOKENS + self.IMAGE_ST, start=len(self.mergeable_ranks)
)
}
self.img_start_id = self.special_tokens[self.image_start_tag]
self.img_end_id = self.special_tokens[self.image_end_tag]
self.img_pad_id = self.special_tokens[self.image_pad_tag]
self.ref_start_id = self.special_tokens[self.ref_start_tag]
self.ref_end_id = self.special_tokens[self.ref_end_tag]
self.box_start_id = self.special_tokens[self.box_start_tag]
self.box_end_id = self.special_tokens[self.box_end_tag]
self.quad_start_id = self.special_tokens[self.quad_start_tag]
self.quad_end_id = self.special_tokens[self.quad_end_tag]
enc = tiktoken.Encoding(
"Qwen",
pat_str=PAT_STR,
mergeable_ranks=self.mergeable_ranks,
special_tokens=self.special_tokens,
)
assert (
len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
self.decoder = {
v: k for k, v in self.mergeable_ranks.items()
} # type: dict[int, bytes|str]
self.decoder.update({v: k for k, v in self.special_tokens.items()})
self.tokenizer = enc # type: tiktoken.Encoding
self.eod_id = self.tokenizer.eot_token
self.im_start_id = self.special_tokens[IMSTART]
self.im_end_id = self.special_tokens[IMEND]
def __len__(self) -> int:
return self.tokenizer.n_vocab
def get_vocab(self) -> Dict[bytes, int]:
return self.mergeable_ranks
def convert_tokens_to_ids(
self, tokens: Union[bytes, str, List[Union[bytes, str]]]
) -> List[int]:
ids = []
if isinstance(tokens, (str, bytes)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.mergeable_ranks.get(tokens)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.mergeable_ranks.get(token))
return ids
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
if not special_tokens and new_tokens:
raise ValueError('Adding regular tokens is not supported')
for token in new_tokens:
surface_form = token.content if isinstance(token, AddedToken) else token
if surface_form not in SPECIAL_TOKENS:
raise ValueError('Adding unknown special tokens is not supported')
return 0
def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
"""
Save only the vocabulary of the tokenizer (vocabulary).
Returns:
`Tuple(str)`: Paths to the files saved.
"""
file_path = os.path.join(save_directory, "qwen.tiktoken")
with open(file_path, "w", encoding="utf8") as w:
for k, v in self.mergeable_ranks.items():
line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
w.write(line)
return (file_path,)
def tokenize(
self,
text: str,
allowed_special: Union[Set, str] = "all",
disallowed_special: Union[Collection, str] = (),
**kwargs,
) -> List[Union[bytes, str]]:
"""
Converts a string in a sequence of tokens.
Args:
text (`str`):
The sequence to be encoded.
allowed_special (`Literal["all"]` or `set`):
The surface forms of the tokens to be encoded as special tokens in regular texts.
Default to "all".
disallowed_special (`Literal["all"]` or `Collection`):
The surface forms of the tokens that should not be in regular texts and trigger errors.
Default to an empty tuple.
kwargs (additional keyword arguments, *optional*):
Will be passed to the underlying model specific encode method.
Returns:
`List[bytes|str]`: The list of tokens.
"""
tokens = []
text = unicodedata.normalize("NFC", text)
# this implementation takes a detour: text -> token id -> token surface forms
for t in self.tokenizer.encode(
text, allowed_special=allowed_special, disallowed_special=disallowed_special
):
tokens.append(self.decoder[t])
return tokens
def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
"""
Converts a sequence of tokens in a single string.
"""
text = ""
temp = b""
for t in tokens:
if isinstance(t, str):
if temp:
text += temp.decode("utf-8", errors=self.errors)
temp = b""
text += t
elif isinstance(t, bytes):
temp += t
else:
raise TypeError("token should only be of type types or str")
if temp:
text += temp.decode("utf-8", errors=self.errors)
return text
@property
def vocab_size(self):
return self.tokenizer.n_vocab
def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
"""Converts an id to a token, special tokens included"""
if index in self.decoder:
return self.decoder[index]
raise ValueError("unknown ids")
def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
"""Converts a token to an id using the vocab, special tokens included"""
if token in self.special_tokens:
return self.special_tokens[token]
if token in self.mergeable_ranks:
return self.mergeable_ranks[token]
raise ValueError("unknown token")
def _tokenize(self, text: str, **kwargs):
"""
Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
Do NOT take care of added tokens.
"""
raise NotImplementedError
def _decode(
self,
token_ids: Union[int, List[int]],
skip_special_tokens: bool = False,
errors: str = None,
**kwargs,
) -> str:
if isinstance(token_ids, int):
token_ids = [token_ids]
if skip_special_tokens:
token_ids = [i for i in token_ids if i < self.eod_id]
return self.tokenizer.decode(token_ids, errors=errors or self.errors)

14
tokenizer_config.json Normal file
View File

@ -0,0 +1,14 @@
{
"added_tokens_decoder": {},
"auto_map": {
"AutoTokenizer": [
"tokenization_qwen.QWenTokenizer",
null
]
},
"clean_up_tokenization_spaces": true,
"model_max_length": 8000,
"pad_token": "<|endoftext|>",
"padding_side": "right",
"tokenizer_class": "QWenTokenizer"
}