import fire
import torch
import torch.utils.checkpoint
import transformers
from torch import nn
from torch.nn import functional as F

class OmniVisualEncoder(transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VisionTransformerPretrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config._attn_implementation = 'flash_attention_2'  # request flash-attention kernels for the vision blocks
        self.gradient_checkpointing = True  # force-enable gradient checkpointing
        self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
        self.merge_size = config.merge_size if hasattr(config, 'merge_size') else 2
        del self.merger  # patch merging/projection is handled externally by OmniVisualBridge
        
    def forward(
        self,
        pixel_values: torch.Tensor, 
        grid_thw: torch.Tensor,
    ):
        hidden_states = pixel_values.to(self.get_dtype())
        grid_thw = grid_thw.to(pixel_values.device)
        
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # Cumulative per-frame sequence boundaries (prepended with 0): each (t, h, w) row of
        # grid_thw contributes t frames of h*w patch tokens; the blocks attend within them.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0, dtype=torch.int32
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for blk in self.blocks:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(blk.__call__, hidden_states, cu_seqlens, rotary_pos_emb)
            else:
                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)

        return hidden_states
    
    @torch.no_grad()
    def fake_input(self, device):
        # Build a minimal all-zero dummy image (one merge_size x merge_size patch grid),
        # flattened into the layout patch_embed expects; returns ([patches], [(t, h, w)], [1]).
        merge_size = max(self.merge_size, self.config.spatial_merge_size)
        fake_image = torch.zeros([
            1,
            self.config.temporal_patch_size,
            3,
            merge_size // self.config.spatial_merge_size,
            self.config.spatial_merge_size,
            self.config.patch_size,
            merge_size // self.config.spatial_merge_size,
            self.config.spatial_merge_size,
            self.config.patch_size,
        ], dtype=torch.float32, device=device)
        patches = fake_image.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            merge_size * merge_size, 3 * self.config.temporal_patch_size * self.config.patch_size * self.config.patch_size
        )
        return [flatten_patches], [(1, merge_size, merge_size)], [1]
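

# Illustrative usage sketch, not part of the original module: it drives the encoder with the
# all-zero dummy image from fake_input(). The small Qwen2VLVisionConfig values below are
# assumptions chosen only to keep the example light and the shapes concrete, and it assumes
# a transformers release that ships Qwen2-VL.
def _demo_encoder():
    from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig

    cfg = Qwen2VLVisionConfig(depth=2, embed_dim=32, num_heads=2, hidden_size=64, mlp_ratio=2)
    encoder = OmniVisualEncoder(cfg).eval()  # randomly initialised weights, CPU
    # Reuse fake_input() to get a correctly laid-out [num_patches, patch_dim] tensor.
    (patches,), (thw,), _ = encoder.fake_input(device='cpu')
    grid_thw = torch.tensor([thw], dtype=torch.long)
    with torch.no_grad():
        feats = encoder(patches, grid_thw)
    print(patches.shape, '->', feats.shape)  # [4, 1176] -> [4, 32]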


class OmniVisualBridge(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
        self.hidden_size = config.embed_dim * (self.merge_size**2)
        self.ln_q = nn.LayerNorm(config.embed_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, config.hidden_size),
        )
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # LayerNorm each patch embedding, fold merge_size**2 consecutive patches into one
        # vector, then project it to the language model's hidden size.
        x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
        return x
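

# Illustrative usage sketch, not part of the original module: the config values are
# hypothetical, chosen only to show how merge_size**2 patch embeddings are folded into
# a single language-model token.
def _demo_bridge():
    from types import SimpleNamespace

    cfg = SimpleNamespace(embed_dim=1280, merge_size=2, hidden_size=3584)
    bridge = OmniVisualBridge(cfg)
    patch_feats = torch.randn(64, cfg.embed_dim)  # e.g. OmniVisualEncoder output for one image
    tokens = bridge(patch_feats)                  # each group of 2*2=4 patches -> one token
    print(patch_feats.shape, '->', tokens.shape)  # [64, 1280] -> [16, 3584]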


if __name__ == '__main__':
    fire.Fire()