# RexSeek-3B / modeling_rexseek.py
# Uploaded by Mountchicken ("Upload 16 files", commit 692ce93, verified).
import logging
import math
import os
import re
from typing import List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from torchvision.ops import roi_align
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
Qwen2Config,
Qwen2ForCausalLM,
StoppingCriteria,
StoppingCriteriaList,
)
from transformers.generation.utils import GenerateOutput
from transformers.utils import logging, strtobool
from .clip import CLIPVisionTower
from .convnext import ConvNextVisionEncoder
logger = logging.get_logger(__name__)

# XLA bfloat16 environment switches, upper-cased; both default to "0" when unset.
XLA_USE_BF16 = os.getenv("XLA_USE_BF16", "0").upper()
XLA_DOWNCAST_BF16 = os.getenv("XLA_DOWNCAST_BF16", "0").upper()

# Integer sentinels that appear in ``input_ids`` / ``labels``.
IGNORE_INDEX = -100  # label value given to image/object slots and padding
DEFAULT_PAD_TOKEN_INDEX = 0
IMAGE_TOKEN_INDEX = -200  # placeholder id replaced by projected image features
DEFAULT_OBJECT_INDEX = -300  # placeholder id replaced by per-box object features

# Textual special tokens: image and objects.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_OBJECT_TOKEN = "<obj<i>>"
DEFAULT_OBJECT_FEATURE_TOKEN = "<objfeat>"

# Textual special tokens: grounding.
DEFAULT_GROUNDING_START = "<ground>"
DEFAULT_GROUNDING_END = "</ground>"
DEFAULT_GROUNDING_OBJECTS_START = "<objects>"
DEFAULT_GROUNDING_OBJECTS_END = "</objects>"
def is_fsdp_enabled():
    """Report whether Accelerate FSDP with CPU-RAM-efficient loading is active.

    All of the following must hold, checked in this order with short-circuiting:
    ``torch.distributed`` is available, a process group is initialised,
    ``ACCELERATE_USE_FSDP`` parses as true, and ``FSDP_CPU_RAM_EFFICIENT_LOADING``
    parses as true. Unset environment variables default to ``"False"``.

    Returns:
        bool: True only when every condition holds.
    """
    dist = torch.distributed
    if not dist.is_available() or not dist.is_initialized():
        return False
    if strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) != 1:
        return False
    return strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
class IdentityMap(nn.Module):
    """Parameter-free projector that hands back its first input untouched."""

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        # Any extra positional/keyword arguments are accepted and ignored.
        del args, kwargs
        return x

    @property
    def config(self):
        """Projector description keyed by ``mm_projector_type``."""
        return dict(mm_projector_type="identity")
class SimpleResBlock(nn.Module):
    """Pre-norm residual MLP block computing ``LN(x) + MLP(LN(x))``.

    The residual branch is added to the *normalised* input, not to the raw ``x``.
    The sub-module names (``pre_norm``, ``proj``) and the layer order inside
    ``proj`` define the state-dict keys, so they must not change.
    """

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)
        layers = [
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        ]
        self.proj = nn.Sequential(*layers)

    def forward(self, x):
        normed = self.pre_norm(x)
        return normed + self.proj(normed)
def build_vision_projector(
    config, start_hidden_size, delay_load=False, projector_type="mlp2x_gelu", **kwargs
):
    """Build the projector that maps vision features into the LLM hidden space.

    Args:
        config: Object exposing ``hidden_size`` (output width of the projector).
        start_hidden_size (int): Width of the incoming vision features.
        delay_load (bool): Unused; kept for signature compatibility.
        projector_type (str): ``"mlp<N>x_gelu"`` builds N linear layers with a
            GELU between consecutive layers; ``"identity"`` returns an
            ``IdentityMap``. Defaults to ``"mlp2x_gelu"``, the architecture this
            function always built before the parameter existed.
        **kwargs: Unused; kept for signature compatibility.

    Returns:
        nn.Module: ``nn.Sequential`` for MLP types, ``IdentityMap`` for identity.

    Raises:
        ValueError: If ``projector_type`` is not recognised.
    """
    mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(start_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.extend(
                [nn.GELU(), nn.Linear(config.hidden_size, config.hidden_size)]
            )
        return nn.Sequential(*modules)
    if projector_type == "identity":
        return IdentityMap()
    raise ValueError(f"Unknown projector type: {projector_type}")
def get_token_slices(input_ids: torch.Tensor):
    """Split a 1-D token sequence into ordered text / image / object slices.

    Args:
        input_ids (torch.Tensor): 1-D tensor of token ids. ``IMAGE_TOKEN_INDEX``
            marks an image placeholder, ``DEFAULT_OBJECT_INDEX`` marks an object
            placeholder, and every other id is treated as text.

    Returns:
        Tuple[List[Dict[str, Any]], bool]:
            - slices: list of ``{"type": t, "span": [start, end]}`` dicts in
              sequence order, with ``t`` in ``{"text", "image", "object"}`` and
              ``end`` exclusive. Every placeholder gets its own length-1 slice;
              runs of text tokens between placeholders form one slice each.
            - has_object: True if at least one object placeholder is present.
    """
    type_map = {IMAGE_TOKEN_INDEX: "image", DEFAULT_OBJECT_INDEX: "object"}

    is_image = input_ids == IMAGE_TOKEN_INDEX
    is_object = input_ids == DEFAULT_OBJECT_INDEX
    has_object = bool(is_object.any())

    # nonzero() yields positions in ascending order. Move positions and their
    # token ids to Python once, instead of one .item() host sync per marker.
    positions = torch.nonzero(is_image | is_object, as_tuple=True)[0]
    marker_ids = input_ids[positions].tolist()
    positions = positions.tolist()

    seq_len = input_ids.shape[0]
    slices = []
    start = 0
    for pos, marker_id in zip(positions, marker_ids):
        if start < pos:
            slices.append({"type": "text", "span": [start, pos]})
        slices.append({"type": type_map[marker_id], "span": [pos, pos + 1]})
        start = pos + 1
    if start < seq_len:
        slices.append({"type": "text", "span": [start, seq_len]})
    return slices, has_object
class StopWordStoppingCriteria(StoppingCriteria):
    """Stop once the decoded text of the first sequence ends with a stop word.

    Only ``input_ids[0]`` is decoded, so in a batch the decision is driven by
    the first sequence alone. The full sequence is decoded on every call, and
    carriage returns / newlines are removed before the suffix comparison.
    """

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        self.length = len(stop_word)

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        text = self.tokenizer.decode(input_ids[0])
        for newline_char in ("\r", "\n"):
            text = text.replace(newline_char, "")
        tail = text[-self.length :]
        return tail == self.stop_word
def get_stop_criteria(
    tokenizer,
    stop_words=None,
):
    """Build a ``StoppingCriteriaList`` with one stop-word criterion per word.

    Args:
        tokenizer: Tokenizer passed to each ``StopWordStoppingCriteria``.
        stop_words (Optional[Iterable[str]]): Stop words to watch for. ``None``
            (the default) or an empty iterable yields an empty list.

    Returns:
        StoppingCriteriaList: Criteria in the same order as ``stop_words``.
    """
    # ``None`` default instead of a shared mutable ``[]`` default.
    if stop_words is None:
        stop_words = []
    stop_criteria = StoppingCriteriaList()
    for word in stop_words:
        stop_criteria.append(StopWordStoppingCriteria(tokenizer, word))
    return stop_criteria
def gen_sineembed_for_position(pos_tensor, dim_of_pos_feats):
    """Generate a sine/cosine position embedding from point or box coordinates.

    Each coordinate ``v`` is multiplied by ``2*pi`` and divided by
    ``10000 ** (2 * (i // 2) / dim_of_pos_feats)`` for ``i`` in
    ``range(dim_of_pos_feats)``. Even feature slots take ``sin`` and odd slots
    take ``cos``, interleaved as ``[sin, cos, sin, cos, ...]``.

    Args:
        pos_tensor (torch.Tensor): shape ``[batch_size, N, 2]`` holding
            ``(cx, cy)`` or ``[batch_size, N, 4]`` holding ``(cx, cy, w, h)``,
            normally in normalized ``[0, 1]`` coordinates.
        dim_of_pos_feats (int): number of embedding features per coordinate.
            Must be even, otherwise the sin and cos halves differ in length.

    Returns:
        torch.Tensor: shape ``[batch_size, N, k * dim_of_pos_feats]`` with ``k``
        equal to 2 or 4. Per-coordinate blocks are concatenated in the order
        ``(cy, cx)`` or ``(cy, cx, w, h)``; note that y comes before x.

    Raises:
        ValueError: if the last dimension of ``pos_tensor`` is not 2 or 4.
    """
    scale = 2 * math.pi
    dim_t = torch.arange(
        dim_of_pos_feats, dtype=torch.float32, device=pos_tensor.device
    )
    dim_t = 10000 ** (2 * (dim_t // 2) / dim_of_pos_feats)

    def _embed(coord):
        # coord: [batch_size, N] -> [batch_size, N, dim_of_pos_feats]
        angles = (coord * scale)[:, :, None] / dim_t
        return torch.stack(
            (angles[:, :, 0::2].sin(), angles[:, :, 1::2].cos()), dim=3
        ).flatten(2)

    num_coords = pos_tensor.size(-1)
    if num_coords == 2:
        order = (1, 0)  # cy, cx
    elif num_coords == 4:
        order = (1, 0, 2, 3)  # cy, cx, w, h
    else:
        raise ValueError("Unknown pos_tensor shape(-1):{}".format(num_coords))
    return torch.cat([_embed(pos_tensor[:, :, i]) for i in order], dim=2)
class MultiLevelROIVisualPrompt(nn.Module):
    """RoI-Align visual-prompt encoder over multi-level feature maps.

    Every level is bilinearly resized to the largest height/width, the levels
    are concatenated along channels, RoI Align is applied, and each RoI is
    average-pooled into one vector. Optionally, a sine position embedding of the
    normalized box is added to each pooled vector.

    Args:
        output_size (Optional[int]): RoI Align output size. Default is None.
        channel_per_level (Optional[List[int]]): Channels of each level. Stored
            as an attribute only; not used in the computation here. Defaults to
            ``[192, 384, 768, 1536]``.
        spatail_scale (Optional[float]): RoI Align ``spatial_scale`` mapping
            box coordinates onto the largest feature map. The misspelled name is
            kept for backward compatibility. Default is None.
        add_pos_embedding (bool): Whether to add a box position embedding.
            Default is False.
        pos_embedding_dim (int): The position embedding has
            ``4 * (pos_embedding_dim // 4)`` features and is added to the pooled
            RoI features, so it must match their channel count. Default is 1024.
    """

    def __init__(
        self,
        output_size: int = None,
        channel_per_level: Optional[List[int]] = None,
        spatail_scale: float = None,
        add_pos_embedding: bool = False,
        pos_embedding_dim: int = 1024,
    ):
        super().__init__()
        if channel_per_level is None:
            channel_per_level = [192, 384, 768, 1536]
        self.output_size = output_size
        self.channel_per_level = channel_per_level
        self.spatail_scale = spatail_scale
        self.add_pos_embedding = add_pos_embedding
        self.pos_embedding_dim = pos_embedding_dim

    def forward(
        self,
        multi_level_features: List[torch.Tensor],
        boxes: Union[torch.Tensor, List[torch.Tensor]],
    ) -> torch.Tensor:
        """Pool one feature vector per RoI from multi-level feature maps.

        Args:
            multi_level_features (List[Tensor[N, C_l, H_l, W_l]]): Feature maps.
                The first level is used at its native size and the others are
                resized to the maximum height/width, so the first level is
                expected to be the largest one.
            boxes (Tensor[K, 5] or List[Tensor[L, 4]]): RoIs in unnormalized
                ``(x1, y1, x2, y2)`` coordinates; ``Tensor[K, 5]`` rows carry a
                leading batch index. The caller's list and tensors are never
                modified.

        Returns:
            Tensor[1, K, C]: One pooled feature per RoI, where ``C`` is the
            total channel count of all levels.
        """
        # Cast to float32 into fresh objects. The in-place normalization below
        # must never write back into the caller's boxes.
        if isinstance(boxes, torch.Tensor):
            rois = boxes.float()
        else:
            rois = [box.float() for box in boxes]

        max_height = max(feature.shape[2] for feature in multi_level_features)
        max_width = max(feature.shape[3] for feature in multi_level_features)

        # Level 0 is used as-is; the other levels are resized to the largest map.
        resized = [multi_level_features[0].float()]
        for feature in multi_level_features[1:]:
            resized.append(
                F.interpolate(
                    feature.float(),
                    size=(max_height, max_width),
                    mode="bilinear",
                    align_corners=False,
                )
            )
        concat_multi_level_feature = torch.cat(resized, dim=1)

        out_box_feat = roi_align(
            concat_multi_level_feature,
            rois,
            output_size=self.output_size,
            spatial_scale=self.spatail_scale,
        )
        # Average pooling: (K, C, s, s) -> (K, C) -> (1, K, C)
        out_box_feat = out_box_feat.mean(dim=(2, 3)).unsqueeze(0)

        if self.add_pos_embedding:
            out_box_feat = out_box_feat + self._box_position_embedding(
                rois, max_height, max_width, out_box_feat.dtype
            )
        return out_box_feat

    def _box_position_embedding(self, rois, max_height, max_width, dtype):
        """Sine-embed the RoIs as normalized (cx, cy, w, h); returns (1, K, D)."""
        # Gather all boxes in the same order RoI Align consumed them.
        if isinstance(rois, torch.Tensor):
            xyxy = rois[:, 1:]
        else:
            xyxy = torch.cat(rois, dim=0)
        xyxy = xyxy.to(dtype)
        # Boxes are unnormalized xyxy; recover the image size from the largest
        # feature map and the RoI Align scale, then normalize.
        original_img_width = max_width / self.spatail_scale
        original_img_height = max_height / self.spatail_scale
        x1 = xyxy[:, 0] / original_img_width
        y1 = xyxy[:, 1] / original_img_height
        x2 = xyxy[:, 2] / original_img_width
        y2 = xyxy[:, 3] / original_img_height
        w = x2 - x1
        h = y2 - y1
        cxcywh = torch.stack((x1 + w / 2, y1 + h / 2, w, h), dim=-1)
        return gen_sineembed_for_position(
            cxcywh.unsqueeze(0), self.pos_embedding_dim // 4
        )
class RexSeekQwenConfig(Qwen2Config):
    """Qwen2 configuration registered under the ``rexseek_qwen`` model type.

    Declares no fields of its own. Extra attributes that the model reads with
    ``getattr`` (``mm_vision_tower`` / ``vision_tower``,
    ``tokenizer_model_max_length``) are expected to be present on the config
    instance (presumably loaded from the checkpoint's config; verify).
    """

    model_type = "rexseek_qwen"
class RexSeekQwenForCausalLM(Qwen2ForCausalLM):
    """Qwen2 causal LM extended with image and object (box) inputs.

    A CLIP tower (low resolution) and a ConvNeXt tower (high resolution) encode
    each image. Their token features are concatenated and projected to the LLM
    width by ``mm_projector``. Box prompts are pooled from the ConvNeXt
    multi-level features by ``box_encoder`` and projected by
    ``mm_object_projector``. Placeholder ids in ``input_ids``
    (``IMAGE_TOKEN_INDEX`` / ``DEFAULT_OBJECT_INDEX``) are replaced by these
    features before the embeddings reach the language model.
    """

    config_class = RexSeekQwenConfig

    def __init__(self, config):
        super().__init__(config)
        # Low-resolution vision encoder.
        vision_tower = getattr(
            config,
            "mm_vision_tower",
            getattr(config, "vision_tower", None),
        )
        self.vision_tower = CLIPVisionTower(
            vision_tower,
            args=config,
        )
        # High-resolution vision encoder.
        self.vision_tower_aux = ConvNextVisionEncoder()
        # Projector for image tokens. Its 2560 input features are expected to be
        # CLIP width + ConvNeXt last-stage width (1024 + 1536, see encode_images).
        self.mm_projector = build_vision_projector(config, start_hidden_size=2560)
        # Projector for object tokens; 2880 = 192 + 384 + 768 + 1536, the
        # concatenated ConvNeXt level channels pooled by ``box_encoder``.
        self.mm_object_projector = build_vision_projector(
            config, start_hidden_size=2880
        )
        self.vocab_size = config.vocab_size
        # NOTE(review): the Qwen2ForCausalLM parent presumably already defines
        # ``lm_head``; this replaces it with a layer of the same shape.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Visual prompt (box) encoder over the ConvNeXt multi-level features.
        self.box_encoder = MultiLevelROIVisualPrompt(
            output_size=7,
            channel_per_level=[192, 384, 768, 1536],  # ConvNeXt Large
            spatail_scale=192 / 768,
            add_pos_embedding=True,
            pos_embedding_dim=2880,
        )
        # Initialize weights and apply final processing.
        self.post_init()
        logger.info("model initialized")

    def get_vision_tower(self):
        """Return the CLIP tower, unwrapping it if stored inside a list."""
        vision_tower = getattr(self, "vision_tower", None)
        if isinstance(vision_tower, list):
            vision_tower = vision_tower[0]
        return vision_tower

    def get_vision_tower_aux(self):
        """Return the ConvNeXt tower, unwrapping it if stored inside a list."""
        vision_tower_aux = getattr(self, "vision_tower_aux", None)
        if isinstance(vision_tower_aux, list):
            vision_tower_aux = vision_tower_aux[0]
        return vision_tower_aux

    def get_model(self):
        """Return the underlying decoder (``self.model``)."""
        return self.model

    def encode_images(self, images, images_aux):
        """Encode images with both towers and project to the LLM hidden size.

        Args:
            images: Input for the CLIP tower.
            images_aux: Input for the ConvNeXt tower.

        Returns:
            Tuple[Tensor, Any]: Projected image tokens of shape
            ``(B, h * w, hidden_size)``, and the ConvNeXt ``"image_features"``
            output (the multi-level maps consumed by ``encode_objects``).
        """
        low_res_feat = self.get_vision_tower()(images)  # e.g. (B, 576, 1024)
        aux_output = self.get_vision_tower_aux()(images_aux)
        visual_outputs_aux = aux_output["image_features"]
        high_res_feat = aux_output["last_feat"]  # e.g. (B, 1536, 24, 24)
        # Flatten the high-res map into tokens and concatenate channel-wise with
        # the low-res tokens; both must have the same token count (h * w).
        b, c, h, w = high_res_feat.shape
        high_res_feat = high_res_feat.view(b, c, h * w).transpose(1, 2)
        image_features = torch.cat((low_res_feat, high_res_feat), dim=-1)
        image_features = self.mm_projector(image_features)
        return image_features, visual_outputs_aux

    def encode_objects(
        self, bboxes, visual_outputs_aux, dtype, num_gt_boxes_per_image=None
    ):
        """Encode per-image box prompts into object-token embeddings.

        Args:
            bboxes: Per-image boxes; ``bboxes[i]`` holds the xyxy boxes of image
                ``i`` in unnormalized coordinates.
            visual_outputs_aux: Multi-level ConvNeXt feature maps, each indexed
                by image along its first dimension.
            dtype (torch.dtype): dtype the pooled features are cast to before
                projection.
            num_gt_boxes_per_image (Optional[Sequence[int]]): If given, only the
                first ``num_gt_boxes_per_image[i]`` boxes of image ``i`` are used
                (presumably to skip padded entries; verify against callers).

        Returns:
            List[Optional[Tensor]]: Per image, a ``(num_boxes, hidden_size)``
            tensor, or ``None`` when the image has no boxes.
        """
        bbox_visual_outputs = []
        for batch_idx, boxes in enumerate(bboxes):
            num_box = (
                num_gt_boxes_per_image[batch_idx]
                if num_gt_boxes_per_image is not None
                else len(boxes)
            )
            boxes = boxes[:num_box]
            if len(boxes) == 0:
                bbox_visual_outputs.append(None)
                continue
            multi_level_aux_features = [
                level_features[batch_idx].unsqueeze(0)
                for level_features in visual_outputs_aux
            ]
            # (1, num_boxes, C) -> (num_boxes, C)
            out_vp_feat = self.box_encoder(multi_level_aux_features, [boxes]).squeeze(0)
            out_vp_feat = out_vp_feat.to(dtype)
            bbox_visual_outputs.append(self.mm_object_projector(out_vp_feat))
        return bbox_visual_outputs

    @staticmethod
    def _lookup_object_feature(bbox_feats, image_idx, box_idx):
        """Return the ``(1, hidden_size)`` feature of one object placeholder.

        Raises:
            ValueError: If no box feature exists for this placeholder.
        """
        if bbox_feats is None:
            raise ValueError(
                "Found an object placeholder token but no `gt_boxes` were provided."
            )
        if image_idx >= len(bbox_feats):
            raise ValueError(
                f"No box features for sample {image_idx}: box features were "
                f"computed for {len(bbox_feats)} images."
            )
        image_box_feats = bbox_feats[image_idx]
        if image_box_feats is None:
            raise ValueError(
                f"Object placeholder in sample {image_idx}, but that sample has no boxes."
            )
        if box_idx >= image_box_feats.shape[0]:
            raise ValueError(
                f"Object placeholder #{box_idx} in sample {image_idx}, but "
                f"current boxe_feats.shape: {tuple(image_box_feats.shape)}"
            )
        return image_box_feats[box_idx].unsqueeze(0)

    def prepare_inputs_labels_for_multimodal(
        self,
        input_ids,
        position_ids,
        attention_mask,
        past_key_values,
        labels,
        pixel_values=None,
        pixel_values_aux=None,
        gt_boxes=None,
        num_gt_boxes_per_image=None,
    ):
        """Replace image/object placeholders in ``input_ids`` with embeddings.

        Args:
            input_ids: ``(B, L)`` token ids containing ``IMAGE_TOKEN_INDEX`` /
                ``DEFAULT_OBJECT_INDEX`` placeholders.
            position_ids: Optional ``(B, L)`` positions.
            attention_mask: Optional ``(B, L)`` padding mask; masked-out
                positions are dropped before re-padding.
            past_key_values: Passed through unchanged.
            labels: Optional ``(B, L)`` labels.
            pixel_values / pixel_values_aux: Inputs of the CLIP / ConvNeXt towers.
            gt_boxes / num_gt_boxes_per_image: Forwarded to ``encode_objects``.

        Returns:
            Tuple ``(None, position_ids, attention_mask, past_key_values,
            inputs_embeds, labels)``. Sequences are right-padded to the longest
            one. Outputs whose input was None stay None. When ``pixel_values``
            is None the inputs are returned unchanged with ``inputs_embeds=None``.

        Raises:
            ValueError: If an object placeholder has no matching box feature.
        """
        if pixel_values is None:
            return (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                None,
                labels,
            )
        image_features, visual_outputs_aux = self.encode_images(
            pixel_values, pixel_values_aux
        )
        # Stays None without gt_boxes, so any object placeholder raises a clear
        # ValueError in _lookup_object_feature.
        bbox_feats = None
        if gt_boxes is not None:
            bbox_feats = self.encode_objects(
                gt_boxes, visual_outputs_aux, image_features.dtype, num_gt_boxes_per_image
            )

        # Remember which optional inputs the caller supplied; the defaults built
        # below are replaced by None again before returning.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()  # padding mask, shape (B, L)
        if position_ids is None:
            position_ids = torch.arange(
                0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
            )
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # Drop padded positions; sequences may now differ in length.
        input_ids = [ids[mask] for ids, mask in zip(input_ids, attention_mask)]
        labels = [lab[mask] for lab, mask in zip(labels, attention_mask)]

        embed_tokens = self.get_model().embed_tokens
        new_input_embeds = []
        new_labels = []
        # Image features are consumed in order across the batch; box features
        # are indexed per sample.
        cur_image_idx = 0
        cur_object_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            cur_labels = labels[batch_idx]
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                # Text-only sample: concatenating a zero-length slice of the image
                # features presumably keeps the vision branch in the graph.
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds = torch.cat(
                    [embed_tokens(cur_input_ids), cur_image_features[0:0]], dim=0
                )
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(cur_labels)
                cur_image_idx += 1
                cur_object_idx += 1
                continue

            token_slices, _ = get_token_slices(cur_input_ids)
            sample_embeds = []
            sample_labels = []
            cur_box_idx = 0
            for token_slice in token_slices:
                start, end = token_slice["span"]
                slice_type = token_slice["type"]
                if slice_type == "text":
                    sample_embeds.append(embed_tokens(cur_input_ids[start:end]))
                    sample_labels.append(cur_labels[start:end])
                elif slice_type == "image":
                    cur_image_features = image_features[cur_image_idx]
                    sample_embeds.append(cur_image_features)
                    sample_labels.append(
                        torch.full(
                            (cur_image_features.shape[0],),
                            IGNORE_INDEX,
                            device=cur_labels.device,
                            dtype=cur_labels.dtype,
                        )
                    )
                    cur_image_idx += 1
                elif slice_type == "object":
                    sample_embeds.append(
                        self._lookup_object_feature(
                            bbox_feats, cur_object_idx, cur_box_idx
                        )
                    )
                    cur_box_idx += 1
                    sample_labels.append(
                        torch.full(
                            (1,),
                            IGNORE_INDEX,
                            device=cur_labels.device,
                            dtype=cur_labels.dtype,
                        )
                    )
            cur_object_idx += 1
            sample_embeds = torch.cat(sample_embeds)
            sample_labels = torch.cat(sample_labels)
            assert len(sample_labels) == len(sample_embeds)
            new_input_embeds.append(sample_embeds)
            new_labels.append(sample_labels)

        # Truncate sequences to max length as image embeddings can make the
        # sequence longer.
        tokenizer_model_max_length = getattr(
            self.config, "tokenizer_model_max_length", None
        )
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Right-pad every sequence to the longest one in the batch.
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)
        new_input_embeds_padded = []
        new_labels_padded = torch.full(
            (batch_size, max_len),
            IGNORE_INDEX,
            dtype=new_labels[0].dtype,
            device=new_labels[0].device,
        )
        attention_mask = torch.zeros(
            (batch_size, max_len),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        position_ids = torch.zeros(
            (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
        )
        for i, (cur_new_embed, cur_new_labels) in enumerate(
            zip(new_input_embeds, new_labels)
        ):
            cur_len = cur_new_embed.shape[0]
            padding = torch.zeros(
                (max_len - cur_len, cur_new_embed.shape[1]),
                dtype=cur_new_embed.dtype,
                device=cur_new_embed.device,
            )
            new_input_embeds_padded.append(torch.cat((cur_new_embed, padding), dim=0))
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(
                    0, cur_len, dtype=position_ids.dtype, device=position_ids.device
                )
        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        new_labels = None if _labels is None else new_labels_padded
        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
        if _position_ids is None:
            position_ids = None
        return (
            None,
            position_ids,
            attention_mask,
            past_key_values,
            new_input_embeds,
            new_labels,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor],
        pixel_values: Optional[torch.Tensor],
        pixel_values_aux: Optional[torch.Tensor],
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate text conditioned on images and optional box prompts.

        Args:
            inputs: Token ids with image/object placeholders (used only when
                ``inputs_embeds`` is None).
            pixel_values / pixel_values_aux: Inputs of the CLIP / ConvNeXt towers.
                When ``pixel_values`` is None, ``inputs`` is embedded as plain text.
            position_ids / attention_mask: Optional; honored as passed.
            inputs_embeds: Precomputed embeddings; if given, the multimodal
                preparation is skipped.
            **kwargs: ``gt_boxes`` and ``num_gt_boxes_per_image`` are consumed
                here when ``inputs_embeds`` is None; everything else is forwarded
                to the parent ``generate``.
        """
        if inputs_embeds is None:
            # position_ids / attention_mask are named parameters and can never be
            # in kwargs, so the caller's values are used directly.
            gt_boxes = kwargs.pop("gt_boxes", None)
            num_gt_boxes_per_image = kwargs.pop("num_gt_boxes_per_image", None)
            if pixel_values is not None:
                # NOTE(review): the embeddings come back right-padded; with
                # unequal-length batches, confirm batched generation is correct.
                (_, position_ids, attention_mask, _, inputs_embeds, _) = (
                    self.prepare_inputs_labels_for_multimodal(
                        inputs,
                        position_ids,
                        attention_mask,
                        past_key_values=None,
                        labels=None,
                        pixel_values=pixel_values,
                        pixel_values_aux=pixel_values_aux,
                        gt_boxes=gt_boxes,
                        num_gt_boxes_per_image=num_gt_boxes_per_image,
                    )
                )
            else:
                inputs_embeds = self.get_model().embed_tokens(inputs)
        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
# Register the config and model with the transformers Auto* factories so that
# the "rexseek_qwen" model type resolves to these classes.
AutoConfig.register("rexseek_qwen", RexSeekQwenConfig)
AutoModelForCausalLM.register(RexSeekQwenConfig, RexSeekQwenForCausalLM)