import io
import math
from typing import List, Optional, Tuple, Union

import fire
import torch
import torch.utils.checkpoint
import transformers
from flash_attn import flash_attn_varlen_func
from PIL import Image
from torch import nn
from torch.nn import functional as F
from transformers.activations import ACT2FN


class OmniVisualEncoder(transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VisionTransformerPretrainedModel):
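    """Vision encoder built on the Qwen2-VL vision transformer.

    Differences from the stock tower: attention is pinned to flash-attention-2,
    gradient checkpointing is always on, and the built-in patch merger is
    deleted so that merging and projection are handled by OmniVisualBridge.
    """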
    def __init__(self, config):
        super().__init__(config)
        self.config._attn_implementation = 'flash_attention_2'
        self.gradient_checkpointing = True  # force-enabled
        self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
        self.merge_size = config.merge_size if hasattr(config, 'merge_size') else 2
        # Drop the built-in patch merger; OmniVisualBridge below takes over
        # the merge-and-project step.
        del self.merger

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: torch.Tensor,
    ):
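        """Encode packed image/video patches.

        Shapes follow the Qwen2-VL convention: pixel_values is
        (total_patches, in_channels * temporal_patch_size * patch_size**2) and
        grid_thw is (num_images, 3) holding each image's (t, h, w) patch grid.
        Returns per-patch hidden states of shape (total_patches, embed_dim);
        merging into the LLM embedding space is left to OmniVisualBridge.
        """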
        hidden_states = pixel_values.to(self.get_dtype())
        grid_thw = grid_thw.to(pixel_values.device)
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        # Cumulative sequence lengths: one entry per temporal frame, so each
        # frame's h * w patches form a separate attention window inside the
        # packed sequence.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0, dtype=torch.int32
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
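        # Worked example: grid_thw = [[1, 4, 6]] (one image, a 4x6 patch grid)
        # yields a single frame of 4 * 6 = 24 patches, so cu_seqlens ends up
        # as [0, 24] and attention is confined to that span.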
        for blk in self.blocks:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    blk.__call__, hidden_states, cu_seqlens, rotary_pos_emb
                )
            else:
                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
        return hidden_states

    @torch.no_grad()
    def fake_input(self, device):
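        """Build a minimal all-zero input: one fake image covering exactly
        merge_size x merge_size patches, flattened the same way the Qwen2-VL
        preprocessor flattens real images. Presumably used as a placeholder so
        the vision tower still participates in the graph on text-only batches.

        Returns ([patches], [(t, h, w)], [num_images]) with patches of shape
        (merge_size**2, 3 * temporal_patch_size * patch_size**2).
        """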
        merge_size = max(self.merge_size, self.config.spatial_merge_size)
        fake_image = torch.zeros([
            1,
            self.config.temporal_patch_size,
            3,
            merge_size // self.config.spatial_merge_size,
            self.config.spatial_merge_size,
            self.config.patch_size,
            merge_size // self.config.spatial_merge_size,
            self.config.spatial_merge_size,
            self.config.patch_size,
        ], dtype=torch.float32, device=device)
        patches = fake_image.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            merge_size * merge_size,
            3 * self.config.temporal_patch_size * self.config.patch_size * self.config.patch_size,
        )
        return [flatten_patches], [(1, merge_size, merge_size)], [1]

class OmniVisualBridge(nn.Module):
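    """Replacement for Qwen2-VL's PatchMerger: layer-norms per-patch features,
    concatenates each merge_size x merge_size group of neighboring patches, and
    projects the result into the language model's hidden size via a 2-layer MLP.
    """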
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
        self.hidden_size = config.embed_dim * (self.merge_size**2)
        self.ln_q = nn.LayerNorm(config.embed_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, config.hidden_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
        return x
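

# Illustrative smoke test, not part of the original module: it wires the
# encoder and bridge together on the zero-valued fake input. Assumes the
# default Qwen2-VL vision config as a stand-in (embed_dim=1280,
# hidden_size=3584) and a CUDA device with flash-attn installed, since the
# encoder pins flash_attention_2; real use would load a checkpoint's config
# and weights instead. Run as: python <this file> smoke_test
def smoke_test(device: str = 'cuda'):
    from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig

    config = Qwen2VLVisionConfig()
    encoder = OmniVisualEncoder(config).to(device=device, dtype=torch.bfloat16).eval()
    bridge = OmniVisualBridge(config).to(device=device, dtype=torch.bfloat16).eval()
    patches, grids, _ = encoder.fake_input(device)
    grid_thw = torch.tensor(grids, dtype=torch.long, device=device)
    with torch.no_grad():
        feats = encoder(patches[0], grid_thw)  # (merge_size**2, embed_dim)
        merged = bridge(feats)                 # (1, config.hidden_size)
    print(feats.shape, merged.shape)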

if __name__ == '__main__':
    fire.Fire()