|
|
|
|
|
""" |
|
================================================ |
|
@author: Jaron |
|
@time: 2024/08/21 17:41:52 |
|
@email: fjjth98@163.com |
|
@description: Video-CCAM |
|
================================================ |
|
""" |
|
from typing import Optional, Union |
|
|
|
import torch |
|
from PIL import Image |
|
from torch import nn |
|
from torch.nn import functional as F |
|
from transformers import (AutoImageProcessor, AutoModel, AutoModelForCausalLM, |
|
AutoTokenizer, Cache, DynamicCache, GenerationConfig, |
|
PreTrainedModel) |
|
from transformers.activations import ACT2FN |
|
|
|
from .configuration_videoccam import CCAMConfig, VideoCCAMConfig |
|
|
|
|
|
class CCAMMLP(nn.Module): |
|
|
|
def __init__(self, config): |
|
super().__init__() |
|
self.hidden_act = config.hidden_act |
|
self.hidden_size = config.hidden_size |
|
self.intermediate_size = config.intermediate_size |
|
self.output_size = config.output_size |
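        # With 'swiglu', fc1 produces the gate and up projections in a single matmul (hence the
        # doubled output width); forward() splits them, applies SiLU to the gate, and multiplies
        # by the up projection.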
|
if self.hidden_act == 'swiglu': |
|
self.fc1 = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.mlp_bias) |
|
self.act_fn = ACT2FN['silu'] |
|
else: |
|
self.fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) |
|
self.act_fn = ACT2FN[self.hidden_act] |
|
self.fc2 = nn.Linear(self.intermediate_size, self.output_size, bias=config.mlp_bias) |
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
|
hidden_states = self.fc1(hidden_states) |
|
if self.hidden_act == 'swiglu': |
|
gate, up = hidden_states.chunk(2, dim=-1) |
|
hidden_states = self.act_fn(gate) * up |
|
else: |
|
hidden_states = self.act_fn(hidden_states) |
|
hidden_states = self.fc2(hidden_states) |
|
return hidden_states |
|
|
|
|
|
class CCAMCrossAttention(nn.Module): |
|
"""Cross-attention layer of the CCAM projector. |
|
|
|
    Flash Attention 2 is not supported because the attention mask may be neither full nor causal; only the `eager` and `sdpa` values of `attn_implementation` are supported.
|
""" |
|
|
|
def __init__(self, config): |
|
super().__init__() |
|
self.num_heads = config.num_heads |
|
self.hidden_size = config.hidden_size |
|
self.attention_bias = config.attention_bias |
|
self.attention_dropout = config.attention_dropout |
|
self.cross_hidden_size = config.cross_hidden_size |
|
self.num_key_value_heads = config.num_key_value_heads |
|
self.attn_implementation = config._attn_implementation |
|
self.head_dim = self.hidden_size // self.num_heads |
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
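        # Grouped-query attention: keys/values use num_key_value_heads heads and are
        # repeat-interleaved in forward() to match the num_heads query heads.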
|
|
|
assert self.head_dim * self.num_heads == self.hidden_size, f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads}).' |
|
|
|
self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.attention_bias) |
|
self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=self.attention_bias) |
|
self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=self.attention_bias) |
|
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.attention_bias) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
cross_hidden_states: torch.Tensor, |
|
        attention_mask: Optional[torch.Tensor] = None
|
) -> torch.Tensor: |
|
B, Q, C = hidden_states.size() |
|
query_states = self.q_proj(hidden_states) |
|
key_states = self.k_proj(cross_hidden_states) |
|
value_states = self.v_proj(cross_hidden_states) |
|
|
|
L = key_states.size(1) |
|
query_states = query_states.view(B, Q, self.num_heads, self.head_dim).transpose(1, 2) |
|
key_states = key_states.view(B, L, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
value_states = value_states.view(B, L, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
if self.num_key_value_groups > 1: |
|
key_states = key_states.repeat_interleave(repeats=self.num_key_value_groups, dim=1) |
|
value_states = value_states.repeat_interleave(repeats=self.num_key_value_groups, dim=1) |
|
|
|
if self.attn_implementation == 'eager': |
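            # Eager path: explicit scaled dot-product scores with an additive mask,
            # softmax computed in float32 for numerical stability, then dropout.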
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / self.head_dim ** 0.5 |
|
if attention_mask is not None: |
|
attn_weights = attn_weights + attention_mask.view(1, 1, Q, L) |
|
|
|
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) |
|
attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) |
|
attn_output = torch.matmul(attn_weights, value_states) |
|
else: |
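            # Any non-eager setting (including a flash_attention_2 request) falls back to SDPA,
            # since the CCAM mask may be neither full nor causal (see class docstring).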
|
|
|
attn_output = F.scaled_dot_product_attention( |
|
query_states, |
|
key_states, |
|
value_states, |
|
attn_mask=attention_mask, |
|
dropout_p=self.attention_dropout if self.training else 0.0 |
|
) |
|
attn_output = attn_output.transpose(1, 2).reshape(B, Q, C) |
|
attn_output = self.o_proj(attn_output) |
|
|
|
return attn_output |
|
|
|
|
|
class CCAMModel(PreTrainedModel): |
|
config_class = CCAMConfig |
|
_no_split_modules = ['CCAMCrossAttention'] |
|
_supports_flash_attn_2 = True |
|
_supports_sdpa = True |
|
|
|
def __init__(self, config: CCAMConfig): |
|
super().__init__(config) |
|
self.num_query = config.num_query |
|
self.hidden_size = config.hidden_size |
|
self.output_size = config.output_size |
|
self.cross_hidden_size = config.cross_hidden_size |
|
|
|
self.query = nn.Parameter(torch.empty(1, self.num_query, self.hidden_size).normal_(mean=.0, std=.02)) |
|
self.pre_ccam = nn.Sequential( |
|
nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps), |
|
nn.Dropout(config.dropout) |
|
) |
|
self.ccam = CCAMCrossAttention(config) |
|
self.post_ccam = nn.Sequential( |
|
nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps), |
|
nn.Dropout(config.dropout), |
|
CCAMMLP(config) |
|
) |
|
|
|
def get_ccam(self, vision_hidden_state: torch.Tensor) -> torch.Tensor: |
|
"""Compute CCAM Mask for vision hidden state |
|
|
|
Args: |
|
vision_hidden_state (torch.Tensor): (T, L, C) |
|
|
|
Returns: |
|
            torch.Tensor: (Q, T*L) additive attention mask, where -inf marks masked positions
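
        Illustrative example (assuming num_query=4, T=2, L=2): queries assigned to the first
        frame attend only to frame-0 tokens, while queries assigned to the second frame attend
        to both frames:

            [[0., 0., -inf, -inf],
             [0., 0., -inf, -inf],
             [0., 0.,   0.,   0.],
             [0., 0.,   0.,   0.]]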
|
""" |
|
T, L, _ = vision_hidden_state.size() |
|
dtype, device = vision_hidden_state.dtype, vision_hidden_state.device |
|
base_mask = torch.zeros(T, T, dtype=dtype, device=device) |
|
t = torch.arange(T, device=device) |
|
base_mask.masked_fill_(t > t[:, None], float('-inf')) |
|
attention_mask = torch.zeros(self.num_query, T * L, dtype=dtype, device=device) |
|
attention_mask[:self.num_query // T * T] = torch.kron(base_mask, torch.ones(self.num_query // T, L, dtype=dtype, device=device)) |
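        # Queries are split evenly across the T frames; if num_query is not divisible by T,
        # the leftover queries keep an all-zero mask and attend to every frame token.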
|
return attention_mask |
|
|
|
def forward(self, vision_hidden_states: list[torch.Tensor]) -> torch.Tensor: |
|
"""Forward function, do not collect batch due to the support of zero3 |
|
|
|
Args: |
|
vision_hidden_states (list[torch.Tensor]): [(t0, L, C), (t1, L, C), ...] |
|
|
|
Returns: |
|
torch.Tensor: (B, Q, C) |
|
""" |
|
output = [] |
|
for hidden_states in vision_hidden_states: |
|
|
|
attention_mask = self.get_ccam(hidden_states) |
|
|
|
x = self.pre_ccam(self.query) |
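            # Cross-attend the learnable queries to the flattened frame tokens (t*L, C) under the
            # CCAM mask, add a residual connection, then apply post_ccam (LayerNorm -> Dropout -> MLP).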
|
x = self.ccam( |
|
hidden_states=x, |
|
cross_hidden_states=hidden_states.flatten(0, 1)[None], |
|
attention_mask=attention_mask[None] |
|
) + x |
|
x = self.post_ccam(x) |
|
output.append(x) |
|
output = torch.cat(output, dim=0) |
|
return output |
|
|
|
|
|
|
|
class VideoCCAM(PreTrainedModel): |
|
config_class = VideoCCAMConfig |
|
_auto_class = 'AutoModel' |
|
_supports_flash_attn_2 = True |
|
|
|
def __init__(self, config: VideoCCAMConfig): |
|
super().__init__(config) |
|
|
|
self.vision_encoder = AutoModel.from_config(config.vision_config, torch_dtype=config.torch_dtype, attn_implementation=config._attn_implementation) |
|
self.vision_encoder.vision_model.post_layernorm = nn.Identity() |
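        # post_layernorm is replaced with Identity, so last_hidden_state carries the encoder's
        # un-normalized features into the CCAM projector.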
|
self.projector = CCAMModel._from_config(config.projector_config, torch_dtype=config.torch_dtype, attn_implementation=config._attn_implementation) |
|
self.llm = AutoModelForCausalLM.from_config(config.text_config, torch_dtype=config.torch_dtype, attn_implementation=config._attn_implementation) |
|
self.post_init() |
|
|
|
|
|
def _init_weights(self, module, std=.02): |
|
if isinstance(module, (nn.Linear, nn.Conv2d)): |
|
module.weight.data.normal_(mean=0.0, std=std) |
|
if module.bias is not None: |
|
module.bias.data.zero_() |
|
elif isinstance(module, nn.Embedding): |
|
module.weight.data.normal_(mean=0.0, std=std) |
|
if module.padding_idx is not None: |
|
module.weight.data[module.padding_idx].zero_() |
|
|
|
@property |
|
def _supports_sdpa(self): |
|
""" |
|
Retrieve language_model's attribute to check whether the model supports |
|
SDPA or not. |
|
""" |
|
return self.llm._supports_sdpa |
|
|
|
@property |
|
def _no_split_modules(self): |
|
""" |
|
        Aggregate the `_no_split_modules` lists of the vision encoder, projector, and language model for device-map sharding.
|
""" |
|
return self.vision_encoder._no_split_modules + self.projector._no_split_modules + self.llm._no_split_modules |
|
|
|
    @torch.inference_mode()
|
def generate( |
|
self, |
|
input_ids: list[list[int]] = None, |
|
pixel_values: torch.FloatTensor = None, |
|
vision_split_sizes: list[int] = None, |
|
past_key_values: Union[tuple, Cache] = None, |
|
batch_generation: bool = False, |
|
generation_config: GenerationConfig = None, |
|
**kwargs |
|
) -> tuple[torch.LongTensor, Optional[Cache]]: |
|
"""Generation for multi-modal inputs |
|
|
|
Args: |
|
input_ids (list[list[int]]): input token indices, use list[int] for efficient embeddings concatenation. |
|
pixel_values (torch.FloatTensor): input image/video (processed) pixel values. |
|
            vision_split_sizes (list[int]): number of frames assigned to each vision placeholder token (<image>, <video>).

            past_key_values (Union[tuple, Cache]): cached key/values for efficient multi-turn generation; only supported for a single input sequence. If provided, the updated past_key_values are returned alongside the generated ids.

            batch_generation (bool, optional): whether to left-pad the inputs and generate them as a single batch. Defaults to False.

            generation_config (GenerationConfig, optional): generation configuration forwarded to `self.llm.generate`. Defaults to None.
|
|
|
Returns: |
|
            torch.LongTensor: generated token ids. When `past_key_values` is None, a tensor (batch generation) or a list of per-sample tensors is returned; otherwise a tuple of (generated ids, updated Cache).
|
""" |
|
if past_key_values is not None and len(input_ids) != 1: |
|
            raise ValueError('`past_key_values` is only supported when there is exactly one `input_ids` sequence.')
|
|
|
device = self.llm.get_input_embeddings().weight.device |
|
_input_ids, text_split_pos = [], [0] |
|
for ids in input_ids: |
|
_input_ids += ids |
|
text_split_pos.append(text_split_pos[-1] + len(ids)) |
|
_input_ids = torch.tensor(_input_ids, dtype=torch.long, device=device) |
|
vision_pos = torch.where((_input_ids == self.config.image_token_id) | (_input_ids == self.config.video_token_id))[0].tolist() |
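        # Positions of the <image>/<video> placeholder tokens in the flattened id sequence;
        # each one is later replaced by the corresponding projected visual embeddings.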
|
_inputs_embeds = self.llm.get_input_embeddings()(_input_ids) |
|
|
|
|
|
if pixel_values is not None: |
|
assert len(vision_pos) == len(vision_split_sizes), f'The number of visual tokens ({len(vision_pos)}) should be equal to the number of visual features ({len(vision_split_sizes)}).' |
|
vision_embeds = self.vision_encoder(pixel_values, output_hidden_states=False).last_hidden_state |
|
vision_embeds = self.projector(vision_embeds.split(vision_split_sizes, dim=0)) |
|
|
|
|
|
inputs_embeds_len, inputs_embeds, idx = [], [], 0 |
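        # Re-split the flattened embeddings back into per-sample lists, splicing the projected
        # visual embeddings in at each placeholder position and recording each sample's length.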
|
for i in range(1, len(text_split_pos)): |
|
start, cur_inputs_embeds = text_split_pos[i-1], [] |
|
while idx < len(vision_pos) and vision_pos[idx] < text_split_pos[i]: |
|
cur_inputs_embeds.append(_inputs_embeds[start:vision_pos[idx]]) |
|
cur_inputs_embeds.append(vision_embeds[idx]) |
|
start, idx = vision_pos[idx] + 1, idx + 1 |
|
if start < text_split_pos[i]: |
|
cur_inputs_embeds.append(_inputs_embeds[start:text_split_pos[i]]) |
|
            inputs_embeds_len.append(sum(e.size(0) for e in cur_inputs_embeds))
|
inputs_embeds.append(cur_inputs_embeds) |
|
|
|
|
|
if past_key_values is None: |
|
|
|
if batch_generation: |
|
B, L = len(input_ids), max(inputs_embeds_len) |
|
padded_inputs_embeds, attention_mask = [], [] |
|
pad_embeds = self.llm.get_input_embeddings()(torch.tensor([self.config.text_config.pad_token_id], dtype=torch.long, device=device)) |
|
for l, embeds in zip(inputs_embeds_len, inputs_embeds): |
|
padded_inputs_embeds.append(pad_embeds.expand(L - l, -1)) |
|
padded_inputs_embeds += embeds |
|
                    attention_mask += [0] * (L - l) + [1] * l
|
padded_inputs_embeds = torch.cat(padded_inputs_embeds, dim=0).view(B, L, -1) |
|
attention_mask = torch.tensor(attention_mask, dtype=torch.long, device=device).view(B, L) |
|
output_ids = self.llm.generate( |
|
inputs_embeds=padded_inputs_embeds, |
|
attention_mask=attention_mask, |
|
generation_config=generation_config, |
|
**kwargs |
|
) |
|
else: |
|
output_ids = [] |
|
for l, embeds in zip(inputs_embeds_len, inputs_embeds): |
|
output_ids += self.llm.generate( |
|
inputs_embeds=torch.cat(embeds, dim=0)[None], |
|
attention_mask=torch.ones(1, l, dtype=torch.long, device=device), |
|
generation_config=generation_config, |
|
**kwargs |
|
) |
|
return output_ids |
|
else: |
|
inputs_embeds = torch.cat(inputs_embeds[0], dim=0) |
|
if not isinstance(past_key_values, Cache): |
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values) |
|
|
|
past_key_values = self.llm( |
|
inputs_embeds=inputs_embeds[None, :-1], |
|
past_key_values=past_key_values, |
|
return_dict=True |
|
).past_key_values |
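            # Pre-fill the cache with every token except the last, then call generate() with a
            # dummy prompt of matching length whose final position holds the real last token;
            # only the newly generated suffix is returned together with the updated cache.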
|
|
|
pseudo_input_ids_len = past_key_values.get_seq_length() + 1 |
|
pseudo_input_ids = torch.zeros(1, pseudo_input_ids_len, dtype=torch.long, device=device) |
|
pseudo_input_ids[0, -1] = _input_ids[-1] |
|
output = self.llm.generate( |
|
input_ids=pseudo_input_ids, |
|
past_key_values=past_key_values, |
|
generation_config=generation_config, |
|
return_dict_in_generate=True, |
|
**kwargs |
|
) |
|
return output.sequences[0, pseudo_input_ids_len:], output.past_key_values |
|
|
|
def chat( |
|
self, |
|
messages: list[list[dict]], |
|
images: list[list[Image.Image]] = None, |
|
tokenizer: AutoTokenizer = None, |
|
image_processor: AutoImageProcessor = None, |
|
batch_generation: bool = False, |
|
        generation_config: GenerationConfig = None,
|
**kwargs |
|
) -> list[str]: |
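        """Tokenize chat messages, preprocess images/frames, and generate decoded responses.

        Args:
            messages (list[list[dict]]): one chat-format conversation per sample, following the tokenizer's chat template.
            images (list[list[Image.Image]]): per-conversation lists of PIL images (e.g. sampled video frames).
            tokenizer (AutoTokenizer): tokenizer providing `apply_chat_template` and `batch_decode`.
            image_processor (AutoImageProcessor): processor that turns the PIL images into `pixel_values`.
            batch_generation (bool, optional): forwarded to `generate`. Defaults to False.
            generation_config (GenerationConfig, optional): forwarded to `generate`. Defaults to None.

        Returns:
            list[str]: decoded responses, one per conversation.
        """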
|
|
|
pixel_values, vision_split_sizes = [], [] |
|
for image in images: |
|
pixel_values += image |
|
vision_split_sizes.append(len(image)) |
|
if len(pixel_values) > 0: |
|
pixel_values = image_processor(pixel_values, return_tensors='pt')['pixel_values'].to( |
|
dtype=self.vision_encoder.get_input_embeddings().weight.dtype, |
|
device=self.vision_encoder.get_input_embeddings().weight.device |
|
) |
|
else: |
|
pixel_values = None |
|
|
|
input_ids = tokenizer.apply_chat_template( |
|
messages, |
|
tokenize=True, |
|
add_generation_prompt=True, |
|
return_dict=False |
|
) |
|
|
|
output_ids = self.generate( |
|
input_ids=input_ids, |
|
pixel_values=pixel_values, |
|
vision_split_sizes=vision_split_sizes, |
|
batch_generation=batch_generation, |
|
generation_config=generation_config, |
|
**kwargs |
|
) |
|
|
|
prediction = tokenizer.batch_decode(output_ids, skip_special_tokens=True) |
|
return prediction |
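
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The checkpoint path, the visual
# placeholder string in the prompt, and the sampling settings below are
# assumptions, not part of this module; the actual placeholder token is
# defined by the tokenizer's chat template and `config.video_token_id`.
#
#   import torch
#   from transformers import AutoImageProcessor, AutoModel, AutoTokenizer, GenerationConfig
#
#   path = 'path/to/video-ccam-checkpoint'   # hypothetical local path
#   model = AutoModel.from_pretrained(path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map='auto')
#   tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
#   image_processor = AutoImageProcessor.from_pretrained(path, trust_remote_code=True)
#
#   frames = [...]  # list[PIL.Image.Image] sampled from a video
#   messages = [[{'role': 'user', 'content': '<video>\nDescribe the video.'}]]
#   responses = model.chat(
#       messages,
#       images=[frames],
#       tokenizer=tokenizer,
#       image_processor=image_processor,
#       generation_config=GenerationConfig(max_new_tokens=128, do_sample=False),
#   )
#   print(responses[0])
# ---------------------------------------------------------------------------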
|
|