2464 lines
106 KiB
Python
2464 lines
106 KiB
Python
# Modification Copyright© 2025 Advanced Micro Devices, Inc. All rights reserved.
|
|
# Copyright 2023 Haotian Liu
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
from transformers import (AutoConfig, AutoModelForCausalLM,
|
|
OlmoConfig, OlmoModel, OlmoForCausalLM)
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
from transformers.generation.utils import GenerateOutput
|
|
from abc import ABC, abstractmethod
|
|
|
|
import re
|
|
import os
|
|
import math
|
|
import random
|
|
import shutil
|
|
from .mm_utils import get_anyres_image_grid_shape, rank0_print
|
|
|
|
from .mm_utils import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
|
|
|
import torch
|
|
from einops import rearrange, repeat
|
|
|
|
try:
|
|
from einops_exts import rearrange_many
|
|
except:
|
|
pass
|
|
|
|
|
|
|
|
from torch import einsum
|
|
|
|
from torch import Tensor, device
|
|
import torch.utils.checkpoint
|
|
from torch.nn import CrossEntropyLoss
|
|
|
|
from transformers.activations import ACT2FN
|
|
from transformers.modeling_outputs import (
|
|
BaseModelOutputWithPastAndCrossAttentions,
|
|
BaseModelOutputWithPoolingAndCrossAttentions,
|
|
CausalLMOutputWithCrossAttentions,
|
|
MaskedLMOutput,
|
|
)
|
|
from transformers.modeling_utils import (
|
|
PreTrainedModel,
|
|
apply_chunking_to_forward,
|
|
find_pruneable_heads_and_indices,
|
|
prune_linear_layer,
|
|
)
|
|
from transformers.utils import logging
|
|
from transformers.models.bert.configuration_bert import BertConfig
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
########## Projector ##############
|
|
class PoolerProjector(nn.Module):
    """Project vision tokens to the language-model width with a strided conv.

    The token sequence is laid out on a square ``hw x hw`` grid (``hw`` is
    ``image_size // patch_size`` of the vision config), down-sampled by a
    2x2 / stride-2 convolution, flattened back to a sequence and passed
    through GELU followed by a linear layer.
    """

    def __init__(self, config, vision_cfg):
        super().__init__()
        self._config = config
        # Side length of the patch grid implied by the vision configuration.
        self.hw = vision_cfg.image_size // vision_cfg.patch_size

        self.conv_pool = nn.Conv2d(
            config.mm_hidden_size,
            config.hidden_size,
            kernel_size=2,
            stride=2,
        )
        self.proj = nn.Sequential(
            nn.GELU(),
            nn.Linear(config.hidden_size, config.hidden_size),
        )

    def forward(self, x, *args, **kwargs):
        """Pool and project ``x`` of shape ``(batch, hw * hw, channels)``.

        Extra positional and keyword arguments are accepted and ignored.
        """
        side = self.hw
        assert side * side == x.shape[1]
        # (batch, tokens, channels) -> (batch, channels, side, side) for Conv2d.
        grid = x.view(x.shape[0], side, side, -1).permute(0, 3, 1, 2)
        pooled = self.conv_pool(grid)
        # Back to a token sequence: (batch, pooled_tokens, hidden).
        tokens = pooled.flatten(2).transpose(1, 2)
        return self.proj(tokens)

    @property
    def config(self):
        """Return the dictionary identifying this projector type."""
        return {"mm_projector_type": "pooler"}
|
|
|
|
class IdentityMap(nn.Module):
    """Projector that hands its input back untouched."""

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        """Return ``x`` as-is; any additional arguments are ignored."""
        return x

    @property
    def config(self):
        """Return the dictionary identifying this projector type."""
        projector_info = {"mm_projector_type": "identity"}
        return projector_info
|
|
|
|
|
|
class SimpleResBlock(nn.Module):
    """LayerNorm followed by a two-layer GELU MLP with a residual connection.

    The residual is taken from the *normalised* input, i.e. the block
    computes ``n + proj(n)`` with ``n = pre_norm(x)``.
    """

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        mlp_layers = [
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        ]
        self.proj = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Apply the block to ``x`` (last dimension must equal ``channels``)."""
        normed = self.pre_norm(x)
        update = self.proj(normed)
        return normed + update
|
|
|
|
|
|
def build_vision_projector(config, delay_load=False, **kwargs):
    """Build the vision-to-language projector named by ``config.mm_projector_type``.

    Supported values (default ``"linear"`` when the attribute is absent):

    * ``"linear"``: a single ``nn.Linear``.
    * ``"pooler"``: a ``PoolerProjector``; requires ``kwargs["vision_cfg"]``.
    * ``"mlp<N>x_gelu"``: N linear layers with GELU in between.
    * ``"mlp<N>x_res<M>x_gelu"``: as above, followed by M ``SimpleResBlock``s.
    * ``"identity"``: an ``IdentityMap``.

    ``delay_load`` is accepted but not used by this function.

    Raises:
        ValueError: if the projector type is not recognised.
    """
    projector_type = getattr(config, "mm_projector_type", "linear")

    def _mlp_stack(depth):
        # First layer maps vision width -> LM width; each further layer is
        # preceded by a GELU and keeps the LM width.
        layers = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, depth):
            layers.extend([nn.GELU(), nn.Linear(config.hidden_size, config.hidden_size)])
        return layers

    if projector_type == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    if projector_type == "pooler":
        return PoolerProjector(config, kwargs["vision_cfg"])

    plain_mlp = re.match(r"^mlp(\d+)x_gelu$", projector_type)
    if plain_mlp:
        return nn.Sequential(*_mlp_stack(int(plain_mlp.group(1))))

    residual_mlp = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type)
    if residual_mlp:
        mlp_depth = int(residual_mlp.group(1))
        res_depth = int(residual_mlp.group(2))
        layers = _mlp_stack(mlp_depth)
        layers.extend(SimpleResBlock(config.hidden_size) for _ in range(res_depth))
        return nn.Sequential(*layers)

    if projector_type == "identity":
        return IdentityMap()

    raise ValueError(f"Unknown projector type: {projector_type}")
|
|
|
|
################ Resampler: Spatial Pool ####################
|
|
class SpatialPool(nn.Module):
    """Down-sample a grid of vision tokens by spatial pooling.

    The pooling operator is selected by ``model_args.mm_spatial_pool_mode``
    (``"average"``, ``"max"`` or ``"conv"``) and uses
    ``model_args.mm_spatial_pool_stride`` as both kernel size and stride.
    ``mm_spatial_pool_out_channels`` defaults to ``vision_tower.hidden_size``
    and is only used as the output width of the ``"conv"`` operator.

    Raises:
        ValueError: if the pooling mode is not one of the supported values.
    """

    def __init__(self, model_args, vision_tower):
        super().__init__()

        self.mode = model_args.mm_spatial_pool_mode
        self.stride = model_args.mm_spatial_pool_stride
        self.out_channels = getattr(model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size)

        if self.mode == "average":
            self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride)
        elif self.mode == "max":
            self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride)
        elif self.mode == "conv":
            self.pool = nn.Conv2d(
                in_channels=vision_tower.hidden_size,
                out_channels=self.out_channels,
                kernel_size=self.stride,
                stride=self.stride,
            )
        else:
            # Report the offending mode. ``self.pool`` is never assigned on
            # this path, so formatting it would raise AttributeError instead.
            raise ValueError(f"Unknown pooling mode: {self.mode}.")

    def forward(self, image_features, images, *args, **kwargs):
        """Pool ``image_features`` on the grid implied by the image aspect ratio.

        Args:
            image_features: 3-D tensor ``(batch, num_tokens, channels)``.
            images: tensor whose dims 2 and 3 are read as height and width
                (presumably ``(batch, channels, H, W)`` -- confirm with caller).

        Returns:
            Contiguous tensor ``(batch, pooled_tokens, channels_out)``.
        """
        # Recover the token grid (ori_H x ori_W) from the token count and the
        # image aspect ratio.
        ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2]))
        ori_H = int(ori_W * images.shape[2] // images.shape[3])

        batch, _, channels = image_features.shape

        # Rows are ori_H and columns are ori_W; using ori_H for both axes only
        # worked for square grids and made ``view`` fail otherwise.
        spatial = image_features.view(batch, ori_H, ori_W, channels).permute(0, 3, 1, 2)
        pooled = self.pool(spatial)

        return pooled.flatten(2).transpose(1, 2).contiguous()

    @property
    def config(self):
        """Return the resampler settings as a plain dictionary."""
        return {
            "mm_resampler_type": "spatial_pool",
            "mm_spatial_pool_stride": self.stride,
            "mm_spatial_pool_mode": self.mode,
            "mm_spatial_pool_out_channels": self.out_channels,
        }

    @property
    def hidden_size(self):
        """Return ``out_channels``."""
        return self.out_channels
|
|
|
|
def disabled_train(self, mode=True):
    """Stand-in for ``model.train`` that leaves the train/eval mode untouched.

    Assign this over a module's ``train`` to freeze its mode: the ``mode``
    argument is ignored and the receiver is returned unchanged.
    """
    frozen = self
    return frozen
|
|
|
|
############## Qformer ####################
|
|
class BertEmbeddings(nn.Module):
    """Build input embeddings from word/position embeddings and optional query embeddings."""

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden)

        # The attribute keeps its TensorFlow-style capitalisation so that
        # checkpoints using that parameter name still load.
        self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Row vector 0..max_position_embeddings-1, stored as a buffer so it is
        # serialised with the module.
        all_positions = torch.arange(config.max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", all_positions)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        """Embed ``input_ids`` and/or prepend ``query_embeds``.

        When ``input_ids`` is given, tokens are embedded (plus absolute
        position embeddings if configured) and ``query_embeds``, if present,
        is concatenated in front along dim 1. When ``input_ids`` is ``None``
        the result is ``query_embeds`` alone. LayerNorm and dropout are
        applied to the result in both cases.
        """
        seq_length = 0 if input_ids is None else input_ids.size()[1]

        if position_ids is None:
            first = past_key_values_length
            position_ids = self.position_ids[:, first : seq_length + first].clone()

        if input_ids is None:
            embeddings = query_embeds
        else:
            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                embeddings = embeddings + self.position_embeddings(position_ids)
            if query_embeds is not None:
                embeddings = torch.cat((query_embeds, embeddings), dim=1)

        normed = self.LayerNorm(embeddings)
        return self.dropout(normed)
|
|
|
|
|
|
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention (self- or cross-attention).

    When built with ``is_cross_attention`` true, keys and values are projected
    from inputs of width ``config.encoder_width``; otherwise from
    ``config.hidden_size``.
    """

    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError("The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads))

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        # Keys/values read from the encoder stream in cross-attention.
        kv_width = config.encoder_width if is_cross_attention else config.hidden_size
        self.key = nn.Linear(kv_width, self.all_head_size)
        self.value = nn.Linear(kv_width, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        """Store ``attn_gradients`` (used as a backward hook on the attention map)."""
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        """Return the gradients stored by :meth:`save_attn_gradients`."""
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        """Store ``attention_map`` on the module."""
        self.attention_map = attention_map

    def get_attention_map(self):
        """Return the map stored by :meth:`save_attention_map`."""
        return self.attention_map

    def transpose_for_scores(self, x):
        """Split the last dim into heads: ``(B, T, H*D) -> (B, H, T, D)``."""
        head_shape = (self.num_attention_heads, self.attention_head_size)
        per_head = x.view(*(x.size()[:-1] + head_shape))
        return per_head.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run attention and return ``(context, [attention_probs], (key, value))``.

        Passing ``encoder_hidden_states`` switches to cross-attention: keys and
        values come from it and ``encoder_attention_mask`` replaces
        ``attention_mask``. Otherwise a given ``past_key_value`` pair is
        prepended to the freshly computed keys/values along dim 2. The
        attention probabilities are included only when ``output_attentions``
        is true; the trailing element is always the (key, value) pair used.
        """
        is_cross_attention = encoder_hidden_states is not None

        kv_input = encoder_hidden_states if is_cross_attention else hidden_states
        key_layer = self.transpose_for_scores(self.key(kv_input))
        value_layer = self.transpose_for_scores(self.value(kv_input))
        if is_cross_attention:
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

        query_layer = self.transpose_for_scores(self.query(hidden_states))

        past_key_value = (key_layer, value_layer)

        # Raw attention scores: query . key^T
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            seq_length = hidden_states.size()[1]
            rows = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            cols = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = rows - cols
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            query_term = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
            if self.position_embedding_type == "relative_key":
                attention_scores = attention_scores + query_term
            else:
                key_term = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + query_term + key_term

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Additive mask, already expanded by the caller.
            attention_scores = attention_scores + attention_mask

        # Normalise scores to probabilities over the key axis.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # Dropout on the probabilities drops whole attended-to tokens.
        attention_probs_dropped = self.dropout(attention_probs)

        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # Merge heads: (B, H, T, D) -> (B, T, H*D)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*merged_shape)

        outputs = (context_layer,)
        if output_attentions:
            outputs = outputs + (attention_probs,)
        return outputs + (past_key_value,)
|
|
|
|
|
|
class BertSelfOutput(nn.Module):
    """Output projection of an attention block: dense, dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
|
|
|
|
|
|
class BertAttention(nn.Module):
    """Attention sub-layer: ``BertSelfAttention`` followed by ``BertSelfOutput``."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the projection layers.

        Does nothing when ``heads`` is empty. Updates the head count, the
        total head width and the set of already-pruned heads.
        """
        if len(heads) == 0:
            return

        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads,
            attn.num_attention_heads,
            attn.attention_head_size,
            self.pruned_heads,
        )

        # Shrink the q/k/v projections (output rows) and the output dense
        # layer (input columns, hence dim=1).
        attn.query = prune_linear_layer(attn.query, index)
        attn.key = prune_linear_layer(attn.key, index)
        attn.value = prune_linear_layer(attn.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the bookkeeping in sync with the pruned layers.
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run attention, then the residual output projection.

        Returns a tuple whose first element is the projected attention output;
        the remaining elements are whatever the inner attention module
        returned after its context tensor.
        """
        inner = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        context, extras = inner[0], inner[1:]
        attention_output = self.output(context, hidden_states)
        return (attention_output,) + extras
|
|
|
|
|
|
class BertIntermediate(nn.Module):
    """First half of the feed-forward block: dense expansion plus activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # ``hidden_act`` is either a name looked up in ACT2FN or a callable.
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states):
        """Return ``act(dense(hidden_states))``."""
        expanded = self.dense(hidden_states)
        return self.intermediate_act_fn(expanded)
|
|
|
|
|
|
class BertOutput(nn.Module):
    """Second half of the feed-forward block: project back, dropout, residual, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
|
|
|
|
|
|
class BertLayer(nn.Module):
    """Transformer layer with separate feed-forward paths for query and text positions.

    Self-attention runs over the whole sequence. When
    ``config.add_cross_attention`` is set and ``layer_num`` is a multiple of
    ``config.cross_attention_freq``, the layer also owns a cross-attention
    module that is applied to the first ``query_length`` positions.
    """

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num

        wants_cross = bool(self.config.add_cross_attention and layer_num % self.config.cross_attention_freq == 0)
        if wants_cross:
            self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
        self.has_cross_attention = wants_cross

        # Feed-forward for the text positions ...
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

        # ... and a separate one for the query positions.
        self.intermediate_query = BertIntermediate(config)
        self.output_query = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        """Apply the layer and return ``(layer_output, *attention_maps, present_key_value)``.

        With ``query_length > 0`` the first ``query_length`` positions go
        through cross-attention (if this layer has it) and the query
        feed-forward; any remaining positions go through the text
        feed-forward and are concatenated behind them. With
        ``query_length == 0`` every position uses the text feed-forward.
        """
        # Only the first two cached entries belong to self-attention.
        cached_self_kv = past_key_value[:2] if past_key_value is not None else None
        attn_results = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=cached_self_kv,
        )
        attention_output = attn_results[0]
        present_key_value = attn_results[-1]
        outputs = attn_results[1:-1]

        if not query_length > 0:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        else:
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
                cross_results = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_results[0]
                # Append cross-attention maps when they were requested.
                outputs = outputs + cross_results[1:-1]

            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )
            if attention_output.shape[1] > query_length:
                text_output = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, text_output], dim=1)

        return (layer_output,) + outputs + (present_key_value,)

    def feed_forward_chunk(self, attention_output):
        """Feed-forward path used for text positions."""
        expanded = self.intermediate(attention_output)
        return self.output(expanded, attention_output)

    def feed_forward_chunk_query(self, attention_output):
        """Feed-forward path used for query positions."""
        expanded = self.intermediate_query(attention_output)
        return self.output_query(expanded, attention_output)
|
|
|
|
|
|
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``BertLayer`` modules."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        """Run all layers in order.

        Args:
            hidden_states: input to the first layer.
            attention_mask, encoder_hidden_states, encoder_attention_mask:
                forwarded unchanged to every layer.
            head_mask: optional per-layer sequence; entry ``i`` goes to layer ``i``.
            past_key_values: optional per-layer sequence of cached key/values.
            use_cache: collect each layer's last output as the new cache.
            output_attentions: collect self-attention maps and, when
                ``config.add_cross_attention`` is set, cross-attention maps of
                the layers that actually produced one.
            output_hidden_states: collect the input of every layer plus the
                final output.
            return_dict: return a ``BaseModelOutputWithPastAndCrossAttentions``
                instead of a tuple with the ``None`` entries removed.
            query_length: forwarded to every layer.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:

                if use_cache:
                    logger.warning("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
                    use_cache = False

                def create_custom_forward(module, layer_past):
                    # ``layer_past`` is bound per layer: checkpointing re-runs
                    # this closure during backward, after the loop variable
                    # ``past_key_value`` has moved on to later layers.
                    def custom_forward(*inputs):
                        return module(*inputs, layer_past, output_attentions, query_length)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module, past_key_value),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                # A layer returns (hidden, self_attn, [cross_attn], present_kv).
                # Only a 4-element result carries a cross-attention map; with
                # 3 elements index 2 is the key/value cache. Also skip when
                # cross-attentions are not being collected at all.
                if all_cross_attentions is not None and len(layer_outputs) > 3:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
|
|
|
|
|
|
class BertPooler(nn.Module):
    """Pool a sequence by transforming the hidden state of its first position."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return ``tanh(dense(hidden_states[:, 0]))``."""
        # "Pooling" here simply means picking the first token's hidden state.
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))
|
|
|
|
|
|
class BertPredictionHeadTransform(nn.Module):
    """Dense layer, activation and LayerNorm applied before the LM decoder."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # ``hidden_act`` is either a name looked up in ACT2FN or a callable.
        activation = config.hidden_act
        self.transform_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        """Return ``LayerNorm(act(dense(hidden_states)))``."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
|
|
|
|
|
|
class BertLMPredictionHead(nn.Module):
    """Language-modelling head: transform followed by a vocabulary-sized decoder."""

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The decoder is created without its own bias; a separate per-token
        # bias parameter is attached below.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Share the parameter so the bias follows `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Return vocabulary scores ``decoder(transform(hidden_states))``."""
        transformed = self.transform(hidden_states)
        return self.decoder(transformed)
|
|
|
|
|
|
class BertOnlyMLMHead(nn.Module):
    """Thin wrapper exposing a ``BertLMPredictionHead`` as ``predictions``."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        """Return the prediction scores for ``sequence_output``."""
        return self.predictions(sequence_output)
|
|
|
|
|
|
class BertPreTrainedModel(PreTrainedModel):
    """
    Abstract base that wires weight initialisation and the pretrained-model
    download/loading interface for the BERT modules in this file.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialise ``module`` in place.

        LayerNorm gets zero bias and unit weight; Linear and Embedding weights
        are drawn from ``N(0, config.initializer_range)``; Linear biases, when
        present, are zeroed.
        """
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, (nn.Linear, nn.Embedding)):
            # Plain normal rather than the truncated normal of the TF version,
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

        has_linear_bias = isinstance(module, nn.Linear) and module.bias is not None
        if has_linear_bias:
            module.bias.data.zero_()
|
|
|
|
|
|
class BertModel(BertPreTrainedModel):
|
|
"""
|
|
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
|
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
|
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
|
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
|
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
|
input to the forward pass.
|
|
"""
|
|
|
|
def __init__(self, config, add_pooling_layer=False):
    """Build embeddings, encoder and (optionally) pooler, then initialise weights."""
    super().__init__(config)
    self.config = config

    self.embeddings = BertEmbeddings(config)
    self.encoder = BertEncoder(config)

    # The pooler is optional; ``None`` marks its absence.
    if add_pooling_layer:
        self.pooler = BertPooler(config)
    else:
        self.pooler = None

    self.init_weights()
|
|
|
|
def get_input_embeddings(self):
    """Return the word-embedding module held by ``self.embeddings``."""
    embeddings = self.embeddings
    return embeddings.word_embeddings
|
|
|
|
def set_input_embeddings(self, value):
    """Replace the word-embedding module held by ``self.embeddings`` with ``value``."""
    setattr(self.embeddings, "word_embeddings", value)
|
|
|
|
def _prune_heads(self, heads_to_prune):
|
|
"""
|
|
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
|
class PreTrainedModel
|
|
"""
|
|
for layer, heads in heads_to_prune.items():
|
|
self.encoder.layer[layer].attention.prune_heads(heads)
|
|
|
|
def get_extended_attention_mask(
    self,
    attention_mask: Tensor,
    input_shape: Tuple[int],
    device: device,
    is_decoder: bool,
    has_query: bool = False,
) -> Tensor:
    """
    Turn a 0/1 attention mask into an additive mask broadcastable over heads.

    Arguments:
        attention_mask (:obj:`torch.Tensor`):
            Ones mark positions to attend to, zeros positions to ignore.
            Either ``[batch, seq]`` or ``[batch, from_seq, to_seq]``.
        input_shape (:obj:`Tuple[int]`):
            ``(batch_size, seq_length)`` of the model input; read only in the
            decoder branch.
        device: (:obj:`torch.device`):
            Device on which the causal mask is created.
        is_decoder (:obj:`bool`):
            Combine a causal mask with a 2-D padding mask.
        has_query (:obj:`bool`):
            When the padding mask is longer than ``seq_length``, also add
            all-zero rows for the extra prefix positions (UniLM style).

    Returns:
        :obj:`torch.Tensor` of dtype ``self.dtype``: 0.0 where attention is
        allowed and -10000.0 where it is masked.

    Raises:
        ValueError: if ``attention_mask`` is neither 2-D nor 3-D.
    """
    mask_rank = attention_mask.dim()
    if mask_rank == 3:
        # A full [batch, from, to] mask only needs a head axis.
        extended_attention_mask = attention_mask[:, None, :, :]
    elif mask_rank == 2:
        if not is_decoder:
            # Encoder: broadcast the padding mask over heads and query positions.
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            batch_size, seq_length = input_shape

            positions = torch.arange(seq_length, device=device)
            # causal_mask[b, i, j] is true when key j is not after query i.
            causal_mask = positions[None, None, :].repeat(batch_size, seq_length, 1) <= positions[None, :, None]
            # Match the padding mask's dtype before combining the two.
            causal_mask = causal_mask.to(attention_mask.dtype)

            prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
            if prefix_seq_len > 0:
                if has_query:  # UniLM style attention mask
                    # Extra rows for the prefix positions: all zeros here,
                    # the ones-columns below then cover the prefix keys.
                    prefix_rows = torch.zeros(
                        (batch_size, prefix_seq_len, seq_length),
                        device=device,
                        dtype=causal_mask.dtype,
                    )
                    causal_mask = torch.cat([prefix_rows, causal_mask], dim=1)
                # Every row may attend to all prefix keys.
                prefix_cols = torch.ones(
                    (batch_size, causal_mask.shape[1], prefix_seq_len),
                    device=device,
                    dtype=causal_mask.dtype,
                )
                causal_mask = torch.cat([prefix_cols, causal_mask], dim=-1)

            extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
    else:
        raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(input_shape, attention_mask.shape))

    # Convert the 1/0 keep-mask into an additive bias: 0.0 where attention is
    # allowed and -10000.0 where it is masked. Added to the raw scores before
    # the softmax, this effectively removes the masked positions.
    extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
    return (1.0 - extended_attention_mask) * -10000.0
|
|
|
|
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    position_ids=None,
    head_mask=None,
    query_embeds=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    past_key_values=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    is_decoder=False,
):
    r"""
    Embed the inputs, build the attention masks and run the encoder stack.

    Args:
        input_ids: token ids handed to ``self.embeddings``. May be ``None`` only when
            ``query_embeds`` is provided.
        attention_mask: self-attention mask. When ``None`` an all-ones mask of shape
            ``(batch_size, seq_length + past_key_values_length)`` is created.
        position_ids: forwarded unchanged to ``self.embeddings``.
        head_mask: converted with ``self.get_head_mask`` before reaching the encoder.
        query_embeds: query embeddings handed to ``self.embeddings``; their size along
            dim 1 is forwarded to the encoder as ``query_length`` (0 when absent).
        encoder_hidden_states: tensor, or list of tensors, forwarded to the encoder. Each
            tensor is unpacked as ``(batch, sequence, hidden)``.
        encoder_attention_mask: mask (or list of masks) matching ``encoder_hidden_states``.
            When ``None`` an all-ones mask is built. Every mask goes through
            ``self.invert_attention_mask``.
        past_key_values: cached states; ``past_key_values[0][0].shape[2]`` minus
            ``config.query_length`` is used as the past length.
        use_cache: forwarded to the encoder as given (``None`` is not replaced by a
            config default here).
        output_attentions / output_hidden_states / return_dict: fall back to the fields of
            the same purpose on ``self.config`` when ``None``.
        is_decoder: selects the decoder branch of ``self.get_extended_attention_mask``. That
            branch is called with ``input_ids.shape``, so ``input_ids`` must not be ``None``.

    Returns:
        ``(sequence_output, pooled_output) + encoder_outputs[1:]`` when ``return_dict`` is
        falsy, otherwise a ``BaseModelOutputWithPoolingAndCrossAttentions``.
    """
    cfg = self.config
    if output_attentions is None:
        output_attentions = cfg.output_attentions
    if output_hidden_states is None:
        output_hidden_states = cfg.output_hidden_states
    if return_dict is None:
        return_dict = cfg.use_return_dict

    if input_ids is None:
        assert query_embeds is not None, "You have to specify query_embeds when input_ids is None"

    # Length of the cached sequence, not counting the query positions.
    if past_key_values is None:
        past_key_values_length = 0
    else:
        past_key_values_length = past_key_values[0][0].shape[2] - cfg.query_length

    query_length = 0 if query_embeds is None else query_embeds.shape[1]

    embedding_output = self.embeddings(
        input_ids=input_ids,
        position_ids=position_ids,
        query_embeds=query_embeds,
        past_key_values_length=past_key_values_length,
    )

    input_shape = embedding_output.size()[:-1]
    batch_size, seq_length = input_shape
    device = embedding_output.device

    if attention_mask is None:
        attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

    # Make the self-attention mask broadcastable over heads. The decoder branch is sized
    # from the token ids and is told whether query embeddings are present.
    if is_decoder:
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask,
            input_ids.shape,
            device,
            is_decoder,
            has_query=(query_embeds is not None),
        )
    else:
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device, is_decoder)

    # Cross-attention mask: only needed when encoder states were supplied.
    encoder_extended_attention_mask = None
    if encoder_hidden_states is not None:
        reference_states = encoder_hidden_states[0] if type(encoder_hidden_states) is list else encoder_hidden_states
        encoder_batch_size, encoder_sequence_length, _ = reference_states.size()

        if type(encoder_attention_mask) is list:
            encoder_extended_attention_mask = [self.invert_attention_mask(single_mask) for single_mask in encoder_attention_mask]
        else:
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones((encoder_batch_size, encoder_sequence_length), device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

    # Expand the head mask to one entry per hidden layer.
    head_mask = self.get_head_mask(head_mask, cfg.num_hidden_layers)

    encoder_outputs = self.encoder(
        embedding_output,
        attention_mask=extended_attention_mask,
        head_mask=head_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_extended_attention_mask,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        query_length=query_length,
    )
    sequence_output = encoder_outputs[0]
    pooled_output = None if self.pooler is None else self.pooler(sequence_output)

    if return_dict:
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )

    return (sequence_output, pooled_output) + encoder_outputs[1:]
|
|
|
|
|
|
class BertLMHeadModel(BertPreTrainedModel):
    """
    BERT backbone (built without a pooling layer) topped with a language-modeling head.

    ``forward`` computes a next-token cross-entropy loss when ``labels`` are supplied.
    """

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the decoder layer of the prediction head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the decoder layer of the prediction head."""
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction="mean",
    ):
        r"""
        Run the backbone and the prediction head, optionally computing the LM loss.

        Args:
            input_ids, attention_mask, position_ids, head_mask, encoder_hidden_states,
            encoder_attention_mask, output_attentions, output_hidden_states, is_decoder:
                forwarded unchanged to ``self.bert``.
            query_embeds: forwarded to ``self.bert`` unless ``past_key_values`` is given, in
                which case it is dropped. When present, the leading ``query_embeds.shape[1]``
                positions are removed from the backbone output before the head is applied.
            labels: target token ids. When given, ``use_cache`` is forced to ``False`` and a
                shifted (next-token) cross-entropy loss with ``label_smoothing=0.1`` is
                computed over ``config.vocab_size`` classes.
            past_key_values: cached states forwarded to ``self.bert``.
            use_cache: forwarded to ``self.bert`` (overridden when ``labels`` is given).
            return_dict: falls back to ``config.use_return_dict`` when ``None``.
            return_logits: when true, return only the prediction scores with the last
                position dropped.
            reduction: reduction passed to ``CrossEntropyLoss``. With ``"none"`` the
                per-token losses are summed per sample.

        Returns:
            The truncated logits when ``return_logits`` is true; otherwise a tuple
            (``return_dict`` falsy) or a ``CausalLMOutputWithCrossAttentions``.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if labels is not None:
            use_cache = False
        if past_key_values is not None:
            query_embeds = None

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        # Keep only the positions that follow the query embeddings, if any were used.
        sequence_output = outputs[0]
        if query_embeds is not None:
            num_query = query_embeds.shape[1]
            sequence_output = sequence_output[:, num_query:, :]

        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # Next-token prediction: position i is scored against label i + 1.
            shifted_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            criterion = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = criterion(shifted_scores.view(-1, self.config.vocab_size), labels.view(-1))
            if reduction == "none":
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)

        if return_dict:
            return CausalLMOutputWithCrossAttentions(
                loss=lm_loss,
                logits=prediction_scores,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                cross_attentions=outputs.cross_attentions,
            )

        output = (prediction_scores,) + outputs[2:]
        if lm_loss is None:
            return output
        return (lm_loss,) + output

    def prepare_inputs_for_generation(self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs):
        """
        Assemble the keyword arguments for one generation step.

        A missing ``attention_mask`` is replaced by all ones over ``input_ids``. An all-ones
        mask covering the query positions is always prepended. When ``past`` is given only
        the last token of ``input_ids`` is kept.
        """
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)
        query_mask = input_ids.new_ones(query_embeds.shape[:-1])
        attention_mask = torch.cat([query_mask, attention_mask], dim=-1)

        # With cached states, only the newest token has to be fed.
        if past is not None:
            input_ids = input_ids[:, -1:]

        encoder_states = model_kwargs.get("encoder_hidden_states", None)
        encoder_mask = model_kwargs.get("encoder_attention_mask", None)
        return {
            "input_ids": input_ids,
            "query_embeds": query_embeds,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "encoder_hidden_states": encoder_states,
            "encoder_attention_mask": encoder_mask,
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        """Select the entries given by ``beam_idx`` along dim 0 of every cached state."""
        return tuple(
            tuple(state.index_select(0, beam_idx) for state in layer_past)
            for layer_past in past
        )
|
|
|
|
|
|
class BertForMaskedLM(BertPreTrainedModel):
    """BERT backbone (built without a pooling layer) topped with a masked-LM prediction head."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the decoder layer of the prediction head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the decoder layer of the prediction head."""
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=False,
    ):
        r"""
        Run the backbone and the prediction head, optionally computing the masked-LM loss.

        Args:
            input_ids, attention_mask, position_ids, head_mask, query_embeds,
            encoder_hidden_states, encoder_attention_mask, output_attentions,
            output_hidden_states, is_decoder:
                forwarded unchanged to ``self.bert``.
            query_embeds: when present, the leading ``query_embeds.shape[1]`` positions are
                removed from the backbone output before the head is applied.
            labels: target ids for the cross-entropy loss over ``config.vocab_size`` classes.
                ``CrossEntropyLoss`` is used with its default settings.
            return_dict: falls back to ``config.use_return_dict`` when ``None``.
            return_logits: when true, return only the prediction scores.

        Returns:
            The prediction scores when ``return_logits`` is true; otherwise a tuple
            (``return_dict`` falsy) or a ``MaskedLMOutput``.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        # Always start from the full backbone output. Previously this name was bound only
        # inside the ``query_embeds`` branch, so a call without query embeddings failed
        # with UnboundLocalError on the next line.
        sequence_output = outputs[0]
        if query_embeds is not None:
            sequence_output = sequence_output[:, query_embeds.shape[1] :, :]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # default ignore_index is -100
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
|
class Qformer(nn.Module):
    """
    Resampler that lets a fixed set of learned query tokens attend to image features
    through a BERT model with cross-attention.
    """

    def __init__(self, model_args, vision_tower):
        super().__init__()

        self.depth = model_args.mm_qformer_depth
        self.num_latents = model_args.mm_qformer_latents
        self.pretrained = model_args.mm_qformer_pretrained

        built = self.build_Qformer(vision_tower.hidden_size, self.depth, self.num_latents)
        self.Qformer, self.query_tokens, self.ln_vision = built

        if self.pretrained is not None:
            checkpoint = torch.load(self.pretrained, map_location="cpu")["model"]
            # Entries whose key starts with "t5_proj" are dropped before loading.
            weights = {name: tensor for name, tensor in checkpoint.items() if not name.startswith("t5_proj")}
            self.load_state_dict(weights)

    def build_Qformer(self, vision_width, cross_attention_freq, num_query_token):
        """
        Build the BERT-based Q-Former, its query tokens and the vision LayerNorm.

        Args:
            vision_width: feature width of the vision encoder; stored as
                ``encoder_width`` on the config and used for the returned LayerNorm.
            cross_attention_freq: stored on the config as ``cross_attention_freq``.
            num_query_token: number of learned query tokens; also stored as ``query_length``.

        Returns:
            ``(Qformer, query_tokens, layer_norm)``.
        """
        bert_cfg = BertConfig.from_pretrained("bert-base-uncased")
        bert_cfg.encoder_width = vision_width
        # Enable cross-attention; how often it is inserted is governed by cross_attention_freq.
        bert_cfg.add_cross_attention = True
        bert_cfg.cross_attention_freq = cross_attention_freq
        bert_cfg.query_length = num_query_token

        qformer = BertLMHeadModel(config=bert_cfg)

        query_tokens = nn.Parameter(torch.zeros(1, num_query_token, bert_cfg.hidden_size))
        query_tokens.data.normal_(mean=0.0, std=bert_cfg.initializer_range)

        # Drop the sub-modules that are not needed when only query embeddings are fed.
        qformer.cls = None
        embeddings = qformer.bert.embeddings
        embeddings.word_embeddings = None
        embeddings.position_embeddings = None
        for block in qformer.bert.encoder.layer:
            block.output = None
            block.intermediate = None

        return qformer, query_tokens, nn.LayerNorm(vision_width)

    def forward(self, image_features, *args, **kwargs):
        """
        Normalise ``image_features`` and return the query tokens after they attended to them.

        Extra positional and keyword arguments are accepted and ignored.
        """
        normed = self.ln_vision(image_features)
        # Every image position may be attended to.
        image_atts = torch.ones(normed.shape[:-1], dtype=torch.long, device=normed.device)

        batch_queries = self.query_tokens.expand(normed.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=batch_queries,
            encoder_hidden_states=normed,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        return query_output.last_hidden_state

    @property
    def hidden_size(self):
        """Output feature width (fixed at 768)."""
        return 768

    @property
    def config(self):
        """Resampler settings as a plain dict."""
        settings = {"mm_resampler_type": "qformer"}
        settings["mm_qformer_depth"] = self.depth
        settings["mm_qformer_latents"] = self.num_latents
        settings["mm_qformer_pretrained"] = self.pretrained
        return settings
|
|
|
|
|
|
################### Resampler: Perceiver ###################
|
|
def exists(val):
    """Return ``True`` for any value other than ``None``."""
    return not (val is None)
|
|
|
|
|
|
def FeedForward(dim, mult=4):
    """
    Build a bias-free feed-forward block: LayerNorm -> Linear -> GELU -> Linear.

    Args:
        dim: input and output feature width.
        mult: expansion factor; the hidden width is ``int(dim * mult)``.
    """
    hidden_dim = int(dim * mult)
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias=False),
        nn.GELU(),
        nn.Linear(hidden_dim, dim, bias=False),
    ]
    return nn.Sequential(*stages)
|
|
|
|
|
|
class PerceiverAttention(nn.Module):
    """
    Multi-head cross-attention in which latent vectors attend to the concatenation of
    media features and the latents themselves.
    """

    def __init__(self, *, dim, dim_head=64, heads=8):
        """
        Args:
            dim: feature width of both the media features and the latents.
            dim_head: width of each attention head.
            heads: number of attention heads.
        """
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def _split_heads(self, t):
        """Reshape ``(b, T, n, heads * d)`` into ``(b, heads, T, n, d)``."""
        b, T, n, _ = t.shape
        return t.reshape(b, T, n, self.heads, -1).permute(0, 3, 1, 2, 4)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, n1, D)
            latents (torch.Tensor): latent features
                shape (b, T, n2, D)

        Returns:
            torch.Tensor of shape (b, T, n2, D).
        """
        x = self.norm_media(x)
        latents = self.norm_latents(latents)

        q = self.to_q(latents)
        # Keys and values come from the media features followed by the latents.
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        # Native reshapes are used instead of ``rearrange_many``: that helper is imported
        # under a swallowed try/except, so it may be undefined at call time.
        q, k, v = (self._split_heads(t) for t in (q, k, v))
        q = q * self.scale

        # attention
        sim = einsum("... i d, ... j d -> ... i j", q, k)
        # Subtract the row maximum for numerical stability; softmax is unaffected.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("... i j, ... j d -> ... i d", attn, v)
        # (b, heads, T, n, d) -> (b, T, n, heads * d)
        out = out.permute(0, 2, 3, 1, 4).flatten(-2)
        return self.to_out(out)
|
|
|
|
|
|
class PerceiverResamplerModule(nn.Module):
    """
    Stack of attention + feed-forward layers that maps a variable number of media
    features onto a fixed number of learned latents.
    """

    def __init__(
        self,
        *,
        dim,
        depth=6,
        dim_head=64,
        heads=8,
        num_latents=64,
        max_num_media=None,
        max_num_frames=None,
        ff_mult=4,
    ):
        """
        Args:
            dim: feature width.
            depth: number of (attention, feed-forward) layer pairs.
            dim_head: width of each attention head.
            heads: number of attention heads.
            num_latents: number of learned latent vectors.
            max_num_media: when given, learned per-media embeddings of this length are created.
            max_num_frames: when given, learned per-frame embeddings of this length are created.
            ff_mult: feed-forward expansion factor; a value <= 0 replaces the feed-forward
                block with ``nn.Identity``.
        """
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        if max_num_frames is not None:
            self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim))
        else:
            self.frame_embs = None

        if max_num_media is not None:
            self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim))
        else:
            self.media_time_embs = None

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            attention = PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads)
            feed_forward = FeedForward(dim=dim, mult=ff_mult) if ff_mult > 0 else nn.Identity()
            self.layers.append(nn.ModuleList([attention, feed_forward]))

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, F, v, D)
        Returns:
            shape (b, T, n, D) where n is the number of latents
        """
        batch, num_media, num_frames, num_tokens = x.shape[:4]

        # Add the optional per-frame embeddings before the frame axis is folded away.
        if self.frame_embs is not None:
            x = x + repeat(self.frame_embs[:num_frames], "F d -> b T F v d", b=batch, T=num_media, v=num_tokens)

        # Flatten the frame and spatial dimensions into one sequence axis.
        x = rearrange(x, "b T F v d -> b T (F v) d")

        if self.media_time_embs is not None:
            x = x + self.media_time_embs[:num_media]

        # Residual attention and feed-forward updates of the latents.
        latents = repeat(self.latents, "n d -> b T n d", b=batch, T=num_media)
        for attention, feed_forward in self.layers:
            latents = attention(x, latents) + latents
            latents = feed_forward(latents) + latents

        return self.norm(latents)
|
|
|
|
|
|
class PerceiverResampler(nn.Module):
    """Resampler wrapper that configures a ``PerceiverResamplerModule`` from model arguments."""

    def __init__(self, model_args, vision_tower):
        """
        Args:
            model_args: provides ``mm_perceiver_depth``, ``mm_perceiver_latents``,
                ``mm_perceiver_ff_mult`` and ``mm_perceiver_pretrained``.
            vision_tower: its ``hidden_size`` sets the feature width of the perceiver.
        """
        super().__init__()

        self.depth = model_args.mm_perceiver_depth
        self.num_latents = model_args.mm_perceiver_latents
        self.ff_mult = model_args.mm_perceiver_ff_mult
        self.pretrained = model_args.mm_perceiver_pretrained

        self.perceiver = PerceiverResamplerModule(dim=vision_tower.hidden_size, depth=self.depth, num_latents=self.num_latents, ff_mult=self.ff_mult)

        if self.pretrained is not None:
            # Load onto CPU first so that a checkpoint saved from a GPU can also be
            # restored on a host without that device; load_state_dict copies the
            # values into the existing parameters afterwards.
            self.load_state_dict(torch.load(self.pretrained, map_location="cpu"))

    def forward(self, image_features, *args, **kwargs):
        """
        Insert two singleton axes after the batch axis, run the perceiver and squeeze
        axis 1 of the result. Extra arguments are accepted and ignored.
        """
        return self.perceiver(image_features[:, None, None]).squeeze(1)

    @property
    def config(self):
        """Resampler settings as a plain dict."""
        return {
            "mm_resampler_type": "perceiver",
            "mm_perceiver_depth": self.depth,
            "mm_perceiver_latents": self.num_latents,
            "mm_perceiver_ff_mult": self.ff_mult,
            "mm_perceiver_pretrained": self.pretrained,
        }
|
|
|
|
######################### Resampler: Masked Drop #########################
|
|
class MaskedDrop(nn.Module):
    """
    Training-time token dropping for image features.

    In eval mode, or when a training step is randomly skipped, the features are returned
    untouched.
    """

    def __init__(self, model_args):
        super().__init__()

        self.mode = model_args.mm_mask_drop_mode
        self.skip_percentage = model_args.mm_mask_drop_skip_percentage
        self.ratio = model_args.mm_mask_drop_ratio
        self.ratio_upper = model_args.mm_mask_drop_ratio_upper
        self.ratio_lower = model_args.mm_mask_drop_ratio_lower

    def forward(self, image_features, *args, **kwargs):
        """
        Drop tokens from every element of ``image_features`` according to ``self.mode``.

        Modes:
            ``"fixed"``: keep ``int(num_tokens * ratio)`` randomly chosen tokens.
            ``"range"``: keep a random share between ``ratio_lower`` and ``ratio_upper``;
                each result keeps its leading singleton axis and the output stays a list.
            ``"cls_only"``: keep only the first token.

        Results are stacked into one tensor unless the mode is ``"range"``, or the input
        was a list and the mode is not ``"cls_only"``.

        Raises:
            ValueError: for any other mode.
        """
        # The random draw only happens while training (short-circuit evaluation).
        if not self.training or self.skip_percentage > random.random():
            return image_features

        masked_features = []
        for tokens in image_features:
            num_tokens = tokens.shape[0]
            if self.mode == "fixed":
                keep = int(num_tokens * self.ratio)
                kept, _, _ = self.random_masking(tokens.unsqueeze(0), keep)
                masked_features.append(kept[0])
            elif self.mode == "range":
                keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper))
                kept, _, _ = self.random_masking(tokens.unsqueeze(0), keep)
                masked_features.append(kept)
            elif self.mode == "cls_only":
                masked_features.append(tokens[0:1])
            else:
                raise ValueError(f"Unexpected masked drop mode: {self.mode}")

        stackable = type(image_features) is not list or self.mode == "cls_only"
        if self.mode != "range" and stackable:
            masked_features = torch.stack(masked_features, dim=0)

        return masked_features

    @property
    def config(self):
        """Resampler settings as a plain dict."""
        settings = {"mm_resampler_type": "masked_drop"}
        settings["mm_mask_drop_mode"] = self.mode
        settings["mm_mask_drop_skip_percentage"] = self.skip_percentage
        settings["mm_mask_drop_ratio"] = self.ratio
        settings["mm_mask_drop_ratio_upper"] = self.ratio_upper
        settings["mm_mask_drop_ratio_lower"] = self.ratio_lower
        return settings

    def random_masking(self, x, len_keep):
        """
        Randomly keep ``len_keep`` positions per sample by arg-sorting uniform noise.

        Args:
            x: sequence of shape [N, L, D].
            len_keep: number of positions to keep per sample.

        Returns:
            ``(x_masked, mask, ids_restore)`` where ``x_masked`` is [N, len_keep, D],
            ``mask`` is [N, L] with 0 for kept and 1 for removed positions, and
            ``ids_restore`` is the inverse of the shuffling permutation.
        """
        batch, length, width = x.shape

        scores = torch.rand(batch, length, device=x.device)  # noise in [0, 1]

        # Ascending sort: positions with small noise are kept, large ones removed.
        shuffle_order = torch.argsort(scores, dim=1)
        ids_restore = torch.argsort(shuffle_order, dim=1)

        kept_ids = shuffle_order[:, :len_keep]
        gather_index = kept_ids.unsqueeze(-1).repeat(1, 1, width)
        x_masked = torch.gather(x, dim=1, index=gather_index)

        # Build the mask in shuffled order, then undo the shuffle.
        mask = torch.ones([batch, length], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore
|
|
|
|
class IdentityMap(torch.nn.Module):
    """Pass-through resampler: returns its first input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        """Return ``x`` as-is; any further arguments are ignored."""
        return x

    @property
    def config(self):
        """Resampler settings; the type is reported as ``None``."""
        return dict(mm_resampler_type=None)
|
|
|
|
###################### Resampler - Builder ######################
|
|
def build_vision_resampler(model_args, delay_load=False, **kwargs):
    """
    Instantiate the resampler named by ``model_args.mm_resampler_type``.

    Args:
        model_args: configuration object; a missing ``mm_resampler_type`` counts as ``None``.
        delay_load: accepted but not used by this function.
        **kwargs: forwarded to the ``spatial_pool``, ``perceiver`` and ``qformer`` constructors.

    Raises:
        ValueError: when the resampler type is not recognised.
    """
    kind = getattr(model_args, "mm_resampler_type", None)

    if kind is None:
        return IdentityMap()
    if kind == "masked_drop":
        return MaskedDrop(model_args)
    if kind == "spatial_pool":
        return SpatialPool(model_args, **kwargs)
    if kind == "perceiver":
        return PerceiverResampler(model_args, **kwargs)
    if kind == "qformer":
        return Qformer(model_args, **kwargs)

    raise ValueError(f"Unknown resampler type: {kind}")
|
|
|
|
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
|
|
|
|
######################## Vision Tower ######################
|
|
class CLIPVisionTower(nn.Module):
    r"""
    Wrapper around a CLIP vision model that exposes selected hidden states as image features.

    Attributes:
        is_loaded (bool): whether ``load_model`` has completed.
        vision_tower_name (str): name or path passed to ``from_pretrained``.
        select_layer (int): index into the hidden states to read features from.
        select_feature (str): feature selection mode, e.g. ``"patch"`` or ``"cls_patch"``,
            optionally prefixed with ``"slicefour_"`` or ``"slice_m25811_f6_"``.

    Until the model is loaded only ``cfg_only`` (a ``CLIPVisionConfig``) is available;
    ``dtype`` and ``device`` require the loaded model.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        """
        Args:
            vision_tower: name or path of the CLIP vision model.
            args: provides ``mm_vision_select_layer`` and, optionally,
                ``mm_vision_select_feature``, ``unfreeze_mm_vision_tower`` and
                ``mm_tunable_parts``.
            delay_load: when true, weights are loaded only if ``args`` indicates that the
                vision tower is trainable; otherwise just the config is fetched.
        """
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")

        # Decide whether to load now, and with which log message.
        if not delay_load:
            notice = f"Loading vision tower: {vision_tower}"
        elif getattr(args, "unfreeze_mm_vision_tower", False):
            # TODO: better detector is needed.
            notice = f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True."
        elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
            notice = f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`."
        else:
            notice = None

        if notice is None:
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
        else:
            rank0_print(notice)
            self.load_model()

    def load_model(self, device_map=None):
        """Load the image processor and the vision model with gradients disabled; no-op if already loaded."""
        if self.is_loaded:
            rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
            return

        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """
        Pick the image features out of ``image_forward_outs.hidden_states``.

        The ``slicefour_`` and ``slice_m25811_f6_`` modes concatenate several layers along
        the last axis; otherwise the single layer ``self.select_layer`` is used. A remaining
        mode of ``"patch"`` drops the first token, ``"cls_patch"`` keeps all tokens.

        Raises:
            ValueError: when the remaining mode is neither ``"patch"`` nor ``"cls_patch"``.
        """
        hidden_states = image_forward_outs.hidden_states
        feature_type = self.select_feature

        if feature_type in ("slicefour_patch", "slicefour_cls_patch"):
            stride = len(hidden_states) // 4
            layer_ids = range(stride + self.select_layer, len(hidden_states), stride)
            image_features = torch.cat([hidden_states[idx] for idx in layer_ids], dim=-1)
            feature_type = feature_type.replace("slicefour_", "")
        elif feature_type in ("slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"):
            layer_ids = [-2, -5, -8, -11, 6]
            image_features = torch.cat([hidden_states[idx] for idx in layer_ids], dim=-1)
            feature_type = feature_type.replace("slice_m25811_f6_", "")
        else:
            image_features = hidden_states[self.select_layer]

        if feature_type == "patch":
            return image_features[:, 1:]
        if feature_type == "cls_patch":
            return image_features
        raise ValueError(f"Unexpected select feature: {feature_type}")

    def forward(self, images):
        """
        Encode ``images`` and return the selected features cast back to the input dtype.

        A list input is processed one image at a time (each gets a leading batch axis) and
        yields a list; any other input is passed to the model as one batch.
        """
        if type(images) is not list:
            batch_out = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
            return self.feature_select(batch_out).to(images.dtype)

        image_features = []
        for image in images:
            single = image.to(device=self.device, dtype=self.dtype).unsqueeze(0)
            single_out = self.vision_tower(single, output_hidden_states=True)
            image_features.append(self.feature_select(single_out).to(image.dtype))
        return image_features

    @property
    def dummy_feature(self):
        """All-zero tensor of shape ``(1, hidden_size)`` on the model's device and dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Dtype of the loaded vision model."""
        return self.vision_tower.dtype

    @property
    def device(self):
        """Device of the loaded vision model."""
        return self.vision_tower.device

    @property
    def config(self):
        """Config of the loaded model, or the standalone config when not loaded."""
        return self.vision_tower.config if self.is_loaded else self.cfg_only

    @property
    def hidden_size(self):
        """Feature width, scaled by the number of concatenated layers in the slice modes."""
        width = self.config.hidden_size
        if "slicefour" in self.select_feature:
            width *= 4
        if "slice_m25811_f6" in self.select_feature:
            width *= 5
        return width

    @property
    def num_patches_per_side(self):
        """Number of patches along one image side."""
        return self.config.image_size // self.config.patch_size

    @property
    def num_patches(self):
        """Number of patch tokens, plus one when the first token is kept as well."""
        count = (self.config.image_size // self.config.patch_size) ** 2
        if "cls_patch" in self.select_feature:
            count += 1
        return count

    @property
    def image_size(self):
        """Input image size from the config."""
        return self.config.image_size
|
|
|
|
def build_vision_tower(vision_tower_cfg, **kwargs):
    """
    Build the vision tower named in ``vision_tower_cfg``.

    The name is read from ``mm_vision_tower`` and, if that attribute is absent, from
    ``vision_tower``. A ``CLIPVisionTower`` is returned when the name is an existing path,
    starts with ``"openai"`` or ``"laion"``, or contains ``"ShareGPT4V"``.

    Args:
        vision_tower_cfg: configuration object carrying the tower name; also passed to the
            tower as ``args``.
        **kwargs: forwarded to ``CLIPVisionTower``.

    Raises:
        ValueError: when no tower name is configured or the name is not recognised.
    """
    vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))

    # Guard against a missing name: os.path.exists(None) would raise TypeError and hide
    # the intended "Unknown vision tower" error below.
    if vision_tower is not None:
        is_absolute_path_exists = os.path.exists(vision_tower)
        if is_absolute_path_exists or vision_tower.startswith(("openai", "laion")) or "ShareGPT4V" in vision_tower:
            return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    raise ValueError(f"Unknown vision tower: {vision_tower}")
|
|
|
|
class InstellaVLMetaModel:
    """Mixin that attaches vision modules to a language-model backbone.

    The class calls ``super().__init__(config)`` cooperatively, so it is meant
    to be listed before the backbone class in the bases of a concrete model.
    It manages three sub-modules: ``vision_tower``, ``vision_resampler`` and
    ``mm_projector``, plus an optional ``image_newline`` parameter.
    """

    def __init__(self, config):
        """Initialise the backbone and build the vision modules if configured.

        Args:
            config: Model configuration. The vision modules are built only
                when it has an ``mm_vision_tower`` attribute.
        """
        super(InstellaVLMetaModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            delay_load = getattr(config, "delay_load", False)
            self.vision_tower = build_vision_tower(config, delay_load=delay_load)
            self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
            self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)

            if "unpad" in getattr(config, "mm_patch_merge_type", ""):
                # torch.empty leaves the values uninitialised; presumably they
                # are filled when checkpoint weights are loaded (confirm).
                self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))

    def get_vision_tower(self):
        """Return the vision tower, unwrapping the one-element list used for FSDP.

        Returns:
            The vision tower, or ``None`` when none has been attached.
        """
        vision_tower = getattr(self, "vision_tower", None)
        if isinstance(vision_tower, list):
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_vision_modules(self, model_args, fsdp=None):
        """Create or reload the vision modules and record their settings in ``self.config``.

        Args:
            model_args: Object providing ``vision_tower``,
                ``mm_vision_select_layer``, ``mm_vision_select_feature``,
                ``pretrain_mm_mlp_adapter``, ``mm_patch_merge_type`` and
                ``online_training``; ``vision_tower_pretrained`` and
                ``mm_projector_type`` are optional.
            fsdp: When non-empty, the tower and resampler are stored wrapped
                in one-element lists instead of as direct attributes.
        """
        vision_tower = model_args.vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
        mm_patch_merge_type = model_args.mm_patch_merge_type

        self.config.mm_vision_tower = vision_tower
        self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")

        use_fsdp = fsdp is not None and len(fsdp) > 0

        if self.get_vision_tower() is None:
            vision_tower = build_vision_tower(model_args)
            vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
            for k, v in vision_resampler.config.items():
                setattr(self.config, k, v)

            if use_fsdp:
                self.vision_tower = [vision_tower]
                self.vision_resampler = [vision_resampler]
            else:
                self.vision_tower = vision_tower
                self.vision_resampler = vision_resampler
        else:
            if use_fsdp:
                vision_resampler = self.vision_resampler[0]
                vision_tower = self.vision_tower[0]
            else:
                vision_resampler = self.vision_resampler
                vision_tower = self.vision_tower
            vision_tower.load_model()

            # In case it is frozen by LoRA.
            # Use the unwrapped local module: under FSDP self.vision_resampler
            # is a list and has no .parameters().
            for p in vision_resampler.parameters():
                p.requires_grad = True

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
        self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size)
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature
        self.config.mm_patch_merge_type = mm_patch_merge_type
        self.config.online_training = model_args.online_training

        if getattr(self, "mm_projector", None) is None:
            self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)

            if "unpad" in mm_patch_merge_type:
                embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
                self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
        else:
            # In case it is frozen by LoRA
            for p in self.mm_projector.parameters():
                p.requires_grad = True

        if pretrain_mm_mlp_adapter is not None:
            # NOTE(review): torch.load unpickles the file; only load adapter
            # checkpoints from trusted sources.
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")

            def get_w(weights, keyword):
                # Keep entries whose key contains `keyword` and strip
                # everything up to and including "<keyword>.".
                return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}

            incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
            rank0_print(f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
            # Unwrapped local module again, for the same FSDP reason as above.
            incompatible_keys = vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False)
            rank0_print(f"Loaded vision resampler weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")

            if 'tmp-' in pretrain_mm_mlp_adapter:
                # The directory containing a "tmp-" adapter file is deleted
                # once its weights have been loaded.
                pretrain_mm_mlp_adapter_folder = os.path.dirname(pretrain_mm_mlp_adapter)
                shutil.rmtree(pretrain_mm_mlp_adapter_folder, ignore_errors=True)
|
|
|
|
|
|
|
|
def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    The padding is assumed to be symmetric along exactly one spatial axis,
    chosen by comparing the original and current aspect ratios.

    Args:
        tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
        original_size (tuple): The original size of the image as
            (width, height) -- note the order; it is unpacked width first.

    Returns:
        torch.Tensor: The unpadded image tensor (a slice of ``tensor``).
    """
    original_width, original_height = original_size
    current_height, current_width = tensor.shape[1:]

    # Compute aspect ratios
    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    # Determine padding size and direction
    if original_aspect_ratio > current_aspect_ratio:
        # Padding was added to the height
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding : current_height - padding, :]
    else:
        # Padding was added to the width
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding : current_width - padding]

    return unpadded_tensor
|
|
|
|
|
|
class InstellaVLMetaForCausalLM(ABC):
    """Mixin implementing the multimodal input handling of a causal LM.

    Concrete subclasses implement :meth:`get_model`, which must return the
    object that owns the vision tower, ``mm_projector`` and ``embed_tokens``.
    The methods here also read ``self.config``, ``self.device``,
    ``self.training`` and ``self.model.image_newline``.
    """

    @abstractmethod
    def get_model(self):
        """Return the underlying model that owns the vision modules."""
        pass

    def get_vision_tower(self):
        """Return the vision tower of the underlying model."""
        return self.get_model().get_vision_tower()

    def get_2dPool(self, image_feature):
        """Spatially down-sample per-frame patch features.

        Args:
            image_feature: 3-D tensor ``(num_frames, num_tokens, num_dim)``;
                ``num_tokens`` must equal ``num_patches_per_side ** 2`` of the
                vision tower for the reshape below to succeed.

        Returns:
            Tensor of shape ``(num_frames, pooled_tokens, num_dim)``.

        Raises:
            ValueError: If ``config.mm_spatial_pool_mode`` is not one of
                ``"average"``, ``"max"`` or ``"bilinear"``.
        """
        height = width = self.get_vision_tower().num_patches_per_side
        num_frames, _, num_dim = image_feature.shape
        image_feature = image_feature.view(num_frames, height, width, -1)
        # Channels-first layout for the pooling / interpolation functions.
        image_feature = image_feature.permute(0, 3, 1, 2).contiguous()

        pool_mode = self.config.mm_spatial_pool_mode
        if pool_mode == "average":
            image_feature = nn.functional.avg_pool2d(image_feature, self.config.mm_spatial_pool_stride)
        elif pool_mode == "max":
            image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride)
        elif pool_mode == "bilinear":
            # Bilinear mode always halves each side (rounded up); it does not
            # read mm_spatial_pool_stride.
            cur_height, cur_width = image_feature.shape[2:]
            scaled_shape = [math.ceil(cur_height / 2), math.ceil(cur_width / 2)]
            image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear')
        else:
            raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}")

        image_feature = image_feature.permute(0, 2, 3, 1)
        image_feature = image_feature.view(num_frames, -1, num_dim)
        return image_feature

    def encode_images(self, images):
        """Run ``images`` through the vision tower and then ``mm_projector``."""
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None):
        """Encode a concatenated batch and project/pool it per sample.

        Args:
            videos_or_images: Concatenated input passed to the vision tower.
            video_idx_in_batch: Indices of samples that are videos; those are
                additionally pooled with :meth:`get_2dPool`.
            split_sizes: Sizes used to split the tower output along dim 0.

        Returns:
            List with one projected feature tensor per sample.
        """
        videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images)
        per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0)
        all_videos_or_images_features = []

        for idx, feat in enumerate(per_videos_or_images_features):
            feat = self.get_model().mm_projector(feat)
            if idx in video_idx_in_batch:
                feat = self.get_2dPool(feat)
            all_videos_or_images_features.append(feat)
        return all_videos_or_images_features

    def add_token_per_grid(self, image_feature):
        """Lay frames out as a grid and append ``image_newline`` to every grid row.

        Args:
            image_feature: 3-D tensor ``(num_frames, num_tokens, dim)`` where
                ``num_tokens`` is assumed to be a perfect square.

        Returns:
            2-D tensor ``(tokens_with_newlines, dim)``.
        """
        resize_h = int(math.sqrt(image_feature.shape[1]))
        num_frames = image_feature.shape[0]
        image_feature = image_feature.view(num_frames, 1, resize_h, resize_h, -1)
        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
        image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
        return image_feature

    def add_token_per_frame(self, image_feature):
        """Append one ``image_newline`` token at the end of every frame.

        Args:
            image_feature: 3-D tensor ``(num_frames, num_tokens, dim)``.

        Returns:
            3-D tensor ``(num_frames, num_tokens + 1, dim)``.
        """
        image_feature = image_feature.permute(2, 0, 1).contiguous()
        image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
        image_feature = image_feature.permute(1, 2, 0).contiguous()
        return image_feature

    def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None):
        """Replace image placeholder tokens with visual embeddings.

        Every occurrence of ``IMAGE_TOKEN_INDEX`` in ``input_ids`` is replaced
        by the features of the next image; text tokens are embedded with
        ``embed_tokens``. Sequences are then truncated to
        ``config.tokenizer_model_max_length`` (if set) and padded to a common
        length on the side given by ``config.tokenizer_padding_side``.

        Args:
            input_ids: 2-D tensor of token ids ``(batch, seq_len)``.
            position_ids: Optional position ids; ``None`` is passed through.
            attention_mask: Optional mask; ``None`` is passed through.
            past_key_values: Returned unchanged.
            labels: Optional labels; image positions get ``IGNORE_INDEX``.
            images: Either a tensor or a list of tensors. A list or a 5-D
                tensor selects the per-sample merge path, otherwise the whole
                tensor is encoded at once.
            modalities: Per-sample modality names (or a single string);
                ``"video"`` entries select the video handling. The default
                list is never mutated.
            image_sizes: Per-image sizes, indexed by image, used by the
                ``anyres`` and ``unpad`` merge paths.

        Returns:
            Tuple ``(input_ids, position_ids, attention_mask, past_key_values,
            inputs_embeds, labels)``. When embeddings are built, ``input_ids``
            is ``None``; when there is nothing to do (no vision tower, no
            images, or a single-token input) the inputs are returned as-is
            with ``inputs_embeds`` set to ``None``.

        Raises:
            ValueError: On an unexpected ``mm_patch_merge_type`` or
                ``mm_newline_position``, or a vision tower without
                ``image_size`` in the ``anyres`` path.
            NotImplementedError: If both ``tune_mm_mlp_adapter`` and
                ``mm_use_im_start_end`` are set in the config.
        """
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if isinstance(modalities, str):
            modalities = [modalities]

        if isinstance(images, list) or images.ndim == 5:
            if isinstance(images, list):
                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]

            video_idx_in_batch = [idx for idx, modality in enumerate(modalities) if modality == "video"]

            images_list = [image if image.ndim == 4 else image.unsqueeze(0) for image in images]
            concat_images = torch.cat(images_list, dim=0)
            split_sizes = [image.shape[0] for image in images_list]
            encoded_image_features = self.encode_images(concat_images)

            # One entry per sample, split back along the concatenated dim.
            encoded_image_features = torch.split(encoded_image_features, split_sizes)
            image_features = [
                self.get_2dPool(image_feat) if idx in video_idx_in_batch else image_feat
                for idx, image_feat in enumerate(encoded_image_features)
            ]

            mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
            image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")

            if mm_patch_merge_type == "flat":
                image_features = [x.flatten(0, 1) for x in image_features]

            elif mm_patch_merge_type.startswith("spatial"):
                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    # FIXME: now assume the image is square, and split to 2x2 patches
                    # num_patches = h * w, where h = w = sqrt(num_patches)
                    # currently image_feature is a tensor of shape (4, num_patches, hidden_size)
                    # we want to first unflatten it to (2, 2, h, w, hidden_size)
                    #
                    # Every branch below leaves exactly ONE merged tensor in
                    # `image_feature`; it is appended once after the chain so
                    # the list stays aligned with the image placeholders.
                    # (Video features used to be appended twice.)
                    if image_idx in video_idx_in_batch:  # video operations
                        if self.config.mm_newline_position == "grid":
                            # Grid-wise
                            image_feature = self.add_token_per_grid(image_feature)
                        elif self.config.mm_newline_position == "frame":
                            # Frame-wise
                            image_feature = self.add_token_per_frame(image_feature)
                            image_feature = image_feature.flatten(0, 1)
                        elif self.config.mm_newline_position == "one_token":
                            # one-token
                            image_feature = image_feature.flatten(0, 1)
                            if 'unpad' in mm_patch_merge_type:
                                image_feature = torch.cat((
                                    image_feature,
                                    self.model.image_newline[None].to(image_feature.device)
                                ), dim=0)
                        elif self.config.mm_newline_position == "no_token":
                            image_feature = image_feature.flatten(0, 1)
                        else:
                            raise ValueError(f"Unexpected mm_newline_position: {self.config.mm_newline_position}")

                    elif image_feature.shape[0] > 1:  # multi patches and multi images operations
                        base_image_feature = image_feature[0]
                        image_feature = image_feature[1:]
                        height = width = self.get_vision_tower().num_patches_per_side

                        assert height * width == base_image_feature.shape[0]

                        if "anyres_max" in image_aspect_ratio:
                            matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio)
                            if matched_anyres_max_num_patches:
                                max_num_patches = int(matched_anyres_max_num_patches.group(1))

                        if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
                            if hasattr(self.get_vision_tower(), "image_size"):
                                vision_tower_image_size = self.get_vision_tower().image_size
                            else:
                                raise ValueError("vision_tower_image_size is not found in the vision tower.")
                            try:
                                num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size)
                            except Exception as e:
                                # Best-effort fallback: log and assume a 2x2 grid.
                                rank0_print(f"Error: {e}")
                                num_patch_width, num_patch_height = 2, 2
                            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                        else:
                            image_feature = image_feature.view(2, 2, height, width, -1)

                        if "maxpool2x2" in mm_patch_merge_type:
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = nn.functional.max_pool2d(image_feature, 2)
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        elif "unpad" in mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
                            unit = image_feature.shape[2]
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = unpad_image(image_feature, image_sizes[image_idx])
                            c, h, w = image_feature.shape
                            # Shrink the grid when it exceeds the configured
                            # maximum number of patches by more than ~10% per side.
                            times = math.sqrt(h * w / (max_num_patches * unit**2))
                            if times > 1.1:
                                image_feature = image_feature[None]
                                image_feature = nn.functional.interpolate(image_feature, [int(h // times), int(w // times)], mode="bilinear")[0]
                            image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        elif "unpad" in mm_patch_merge_type:
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = unpad_image(image_feature, image_sizes[image_idx])
                            image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        else:
                            image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                            image_feature = image_feature.flatten(0, 3)

                        if "nobase" not in mm_patch_merge_type:
                            image_feature = torch.cat((base_image_feature, image_feature), dim=0)

                    else:  # single image operations
                        image_feature = image_feature[0]
                        if "unpad" in mm_patch_merge_type:
                            # Move the newline token to the feature's device,
                            # as every other branch does.
                            image_feature = torch.cat((image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0)

                    new_image_features.append(image_feature)
                image_features = new_image_features
            else:
                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
        else:
            image_features = self.encode_images(images)

        # TODO: image start / end is not implemented here to support pretraining.
        if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
            raise NotImplementedError

        # Let's just add dummy tensors if they do not exist,
        # it is a headache to deal with None all the time.
        # But it is not ideal, and if you have a better idea,
        # please open an issue / submit a PR, thanks.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # remove the padding using attention_mask -- FIXME
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        # Last image feature seen; reused by text-only samples when the
        # feature list is exhausted. Starts as None so that case is explicit
        # rather than a NameError.
        cur_image_features = None
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = int((cur_input_ids == IMAGE_TOKEN_INDEX).sum())
            if num_images == 0:
                try:
                    cur_image_features = image_features[cur_image_idx]
                except IndexError:
                    try:
                        cur_image_features = image_features[cur_image_idx - 1]
                    except IndexError:
                        pass
                cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)
                if cur_image_features is not None:
                    # Concatenating an empty slice keeps the vision branch
                    # attached to this sample without adding tokens.
                    cur_input_embeds = torch.cat([cur_input_embeds, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for start, end in zip(image_token_indices[:-1], image_token_indices[1:]):
                cur_input_ids_noim.append(cur_input_ids[start + 1 : end])
                cur_labels_noim.append(cur_labels[start + 1 : end])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []

            # Interleave text segments with image features.
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    try:
                        cur_image_features = image_features[cur_image_idx]
                    except IndexError:
                        cur_image_features = image_features[cur_image_idx - 1]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)

            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)

        # Truncate sequences to max length as image embeddings can make the sequence longer
        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)

        # Do not zip with `modalities` here: zip stops at the shortest input,
        # so a short modalities list would silently drop samples.
        new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
        new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Combine them
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        pad_left = getattr(self.config, "tokenizer_padding_side", "right") == "left"
        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            padding = torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
            if pad_left:
                new_input_embeds_padded.append(torch.cat((padding, cur_new_embed), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                new_input_embeds_padded.append(torch.cat((cur_new_embed, padding), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        # Only hand back the optional tensors the caller originally supplied.
        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None
        if getattr(self.config, "use_pos_skipping", False) and self.training:
            # Training-time augmentation: shift the positions left and right
            # of a random split point by random offsets.
            position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device)
            split_position = random.randint(0, new_input_embeds.size(1))
            left_add = random.randint(0, self.config.pos_skipping_range)
            right_add = random.randint(left_add, self.config.pos_skipping_range)
            position_ids[:, :split_position] += left_add
            position_ids[:, split_position:] += right_add
        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        """Add the image special tokens to ``tokenizer`` and resize embeddings.

        Args:
            model_args: Object providing ``mm_use_im_patch_token``,
                ``mm_use_im_start_end``, ``tune_mm_mlp_adapter`` and
                ``pretrain_mm_mlp_adapter``.
            tokenizer: Tokenizer to extend; must support ``add_tokens`` and
                ``len()``.

        Raises:
            ValueError: If the pretrained ``embed_tokens`` weight has a shape
                that matches neither the current embedding matrix nor the
                number of new tokens.
        """
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                # Initialise the new rows with the mean of the existing ones.
                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if model_args.pretrain_mm_mlp_adapter:
                # NOTE(review): torch.load unpickles the file; only load
                # adapter checkpoints from trusted sources.
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
                embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
                # Also guarantees `input_embeddings` was bound above.
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
|
|
|
|
class InstellaVLConfig(OlmoConfig):
    """Configuration for InstellaVL models.

    It adds nothing to ``OlmoConfig`` except its own ``model_type``
    identifier, ``"instellavl"``.
    """

    # Identifier under which this configuration is registered.
    model_type = "instellavl"
|
|
|
|
|
|
def disable_torch_init():
    r"""
    Turn off PyTorch's default parameter initialization to speed up model creation.

    ``reset_parameters`` is replaced by a no-op on ``torch.nn.Linear`` and
    ``torch.nn.LayerNorm``. The patch is applied to the classes themselves, so
    it affects every such layer created afterwards in the process.
    """
    import torch

    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        setattr(layer_cls, "reset_parameters", lambda self: None)
|
|
|
|
|
|
class InstellaVLModel(InstellaVLMetaModel, OlmoModel):
    """``OlmoModel`` backbone combined with the ``InstellaVLMetaModel`` vision mixin."""

    config_class = InstellaVLConfig

    def __init__(self, config: OlmoConfig):
        """Initialize the model through the cooperative ``__init__`` chain.

        Args:
            config (OlmoConfig): Configuration object for the model.
        """
        super().__init__(config)
|
|
|
|
|
|
class InstellaVLForCausalLM(OlmoForCausalLM, InstellaVLMetaForCausalLM):
    r"""
    Causal language model with multimodal (image) inputs.

    Combines ``OlmoForCausalLM`` with ``InstellaVLMetaForCausalLM``: before
    the language model runs, image placeholder tokens in the prompt are
    replaced by visual embeddings.

    Attributes:
        - config_class (type): The configuration class to use for this model.
        - model (InstellaVLModel): The underlying model.
        - lm_head (nn.Linear): The linear layer for language modeling head.
    """

    config_class = InstellaVLConfig

    def __init__(self, config: OlmoConfig):
        r"""
        Initializes the InstellaVLForCausalLM model.

        Args:
            - config (OlmoConfig): Configuration object for the model.

        Attributes:
            - model (InstellaVLModel): The main model instance.
            - lm_head (torch.nn.Linear): Linear layer that maps hidden states to vocabulary size.
        """
        # Start the MRO lookup *after* OlmoForCausalLM so that its own
        # __init__ is skipped and `self.model` can be an InstellaVLModel.
        super(OlmoForCausalLM, self).__init__(config)
        # NOTE: patches torch.nn.Linear / LayerNorm process-wide.
        disable_torch_init()
        config.model_type = "instellavl"
        self.model = InstellaVLModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        r"""Returns the underlying ``InstellaVLModel``."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        modalities: Optional[List[str]] = ["image"],
        cache_position=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Runs a forward pass, building multimodal embeddings first when needed.

        Args:
            - input_ids (torch.LongTensor, optional): Input token IDs.
            - attention_mask (torch.Tensor, optional): Attention mask.
            - position_ids (torch.LongTensor, optional): Position IDs.
            - past_key_values (List[torch.FloatTensor], optional): Past key values for caching.
            - inputs_embeds (torch.FloatTensor, optional): Input embeddings. When given,
              the multimodal preparation step is skipped.
            - labels (torch.LongTensor, optional): Labels for language modeling.
            - use_cache (bool, optional): Whether to use cache.
            - output_attentions (bool, optional): Whether to output attentions.
            - output_hidden_states (bool, optional): Whether to output hidden states.
            - images (torch.FloatTensor, optional): Input images.
            - image_sizes (List[List[int]], optional): Sizes of input images.
            - return_dict (bool, optional): Whether to return a dictionary.
            - modalities (List[str], optional): List of modalities. The default list is
              never mutated.
            - cache_position (optional): Accepted for interface compatibility; it is not
              forwarded to the parent ``forward``.

        Returns:
            Union[Tuple, CausalLMOutputWithPast]: The output of the forward pass.
        """
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                modalities,
                image_sizes
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        modalities: Optional[List[str]] = ["image"],
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        r"""
        Generates text, optionally conditioned on images.

        Args:
            - inputs (torch.Tensor, optional): Input tensor.
            - images (torch.Tensor, optional): Input images.
            - image_sizes (torch.Tensor, optional): Sizes of input images.
            - modalities (List[str], optional): List of modalities; ``None`` is treated
              as ``["image"]``.
            - **kwargs: Additional arguments forwarded to the parent ``generate``.
              ``position_ids`` and ``attention_mask`` are taken from here.

        Returns:
            Union[GenerateOutput, torch.LongTensor]: The generated text.

        Raises:
            NotImplementedError: If ``inputs_embeds`` is passed in ``kwargs``.
        """
        # `modalities` is a named parameter and can never appear in **kwargs,
        # so only an explicit None needs normalising.
        if modalities is None:
            modalities = ["image"]
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _
            ) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                # Forward the caller's modalities; they used to be dropped, so
                # every input was handled as an image.
                modalities,
                image_sizes=image_sizes
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)
        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        r"""
        Prepares the per-step inputs for generation and re-attaches image data.

        Args:
            - input_ids (torch.LongTensor): Input token IDs.
            - past_key_values (List[torch.FloatTensor], optional): Past key values for caching.
            - inputs_embeds (torch.FloatTensor, optional): Input embeddings.
            - **kwargs: Additional arguments; ``images`` and ``image_sizes`` are removed
              before delegating to the parent implementation.

        Returns:
            dict: Prepared inputs for generation, with ``images`` / ``image_sizes``
            added back when they were provided.
        """
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs
|
|
|
|
# Register the InstellaVL config and model with the transformers Auto*
# factories under the config's own model_type ("instellavl").
AutoConfig.register(InstellaVLConfig.model_type, InstellaVLConfig)
AutoModelForCausalLM.register(InstellaVLConfig, InstellaVLForCausalLM)
|