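"""LQ-Former: a lightweight querying transformer for vision-language fusion.

Learnable query tokens first attend to down-projected text tokens, then
cross-attend to image tokens (keys down-projected, values kept full-size);
the result is projected back to the language-model hidden size and added
residually to the input queries.
"""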
import torch
import torch.nn as nn
import math

class LQFormerAttention(nn.Module):
    """Multi-head cross-attention whose queries and keys live in a
    down-projected space (down_dim) while values keep their native
    embedding size (embed_dim)."""

    def __init__(self, embed_dim, num_heads, down_dim, up_dim):
        super().__init__()
        self.num_heads = num_heads
        self.down_dim = down_dim
        self.embed_dim = embed_dim
        self.down_head_dim = down_dim // num_heads
        self.head_dim = embed_dim // num_heads
        self.up_dim = up_dim
        self.q_proj = nn.Linear(self.down_dim, self.down_dim, bias=True)
        self.k_proj = nn.Linear(self.down_dim, self.down_dim, bias=True)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

        

    def forward(self, query, key, value, attention_mask=None):
        bsz, q_len, _ = query.size()
        k_len = key.size(1)
        v_len = value.size(1)

        query = self.q_proj(query).view(bsz, q_len, self.num_heads, self.down_head_dim).transpose(1, 2)
        key = self.k_proj(key).view(bsz, k_len, self.num_heads, self.down_head_dim).transpose(1, 2)
        value = self.v_proj(value).view(bsz, v_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        attn_weights = torch.matmul(
            query.to(torch.float32), key.to(torch.float32).transpose(2, 3)
        ) / math.sqrt(self.down_head_dim)

        if attention_mask is not None:
            # Convert a 1/0 key-validity mask into an additive mask:
            # 0 for valid positions, -1e4 (effectively -inf) for padded positions.
            attention_mask = (1.0 - attention_mask.to(attn_weights.dtype)) * -1e4
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value.dtype)
        attn_output = torch.matmul(attn_weights, value)

        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
        return attn_output, attn_weights
    
class LQFormerLayer(nn.Module):
    """One LQ-Former block: text-to-query attention followed by image-to-query
    cross-attention over down-projected keys, with full-size image values."""

    def __init__(self, d_model, mm_model, n_heads, down_dim, up_dim):
        super().__init__()
        self.t2q_attn = LQFormerAttention(embed_dim=down_dim, num_heads=n_heads, down_dim=down_dim, up_dim=up_dim)
        self.i2q_attn = LQFormerAttention(embed_dim=d_model, num_heads=n_heads, down_dim=down_dim, up_dim=up_dim)
        self.ln_text = nn.LayerNorm(down_dim)
        self.ln_q = nn.LayerNorm(down_dim)
        self.ln_kv = nn.LayerNorm(down_dim)
        self.n_heads = n_heads


    def forward(self, learnable_tokens, image_tokens, image_tokens_down, text_tokens, text_mask=None):
        # Inputs arrive already down-projected (see LQFormer.forward); the
        # residual connection around this block is also applied there.
        learnable_tokens = self.ln_q(learnable_tokens)
        text_tokens = self.ln_text(text_tokens)

        # Broadcast the (batch, text_len) padding mask over heads and query positions.
        if text_mask is not None:
            attention_mask = text_mask.unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, text_len)
            attention_mask = attention_mask.repeat(1, self.n_heads, learnable_tokens.size(1), 1)
        else:
            attention_mask = None

        # Text-to-query attention: learnable tokens gather information from text tokens.
        attn_output, _ = self.t2q_attn(query=learnable_tokens, key=text_tokens, value=text_tokens, attention_mask=attention_mask)
        
        # Cross-attention: learnable tokens query image tokens
        image_tokens_down = self.ln_kv(image_tokens_down)
        attn_output, attention_map = self.i2q_attn(query=attn_output, key=image_tokens_down, value=image_tokens, attention_mask=None)
        
        # Average the image attention map over heads: (batch, num_queries, num_image_tokens).
        attention_map = torch.mean(attention_map, dim=1)
        return attn_output, attention_map

class LQFormer(nn.Module):
    """Stack of LQFormerLayer blocks with shared down- and up-projections."""

    def __init__(self, config, num_layers=1):
        super().__init__()
        self.mm_model = config.hidden_size   # language-model hidden size
        self.d_model = 1152                  # image-token hidden size
        self.down_dim = 576                  # shared low-rank attention space
        self.up_dim = 2560
        # Learnable query tokens and text tokens share one down-projection;
        # image tokens get their own.
        self.down_projector_learnable_text = nn.Linear(self.mm_model, self.down_dim, bias=True)
        self.down_projector_image = nn.Linear(self.d_model, self.down_dim, bias=True)
        self.layers = nn.ModuleList([
            LQFormerLayer(mm_model=self.mm_model, d_model=self.d_model,
                          n_heads=config.num_attention_heads,
                          down_dim=self.down_dim, up_dim=self.up_dim)
            for _ in range(num_layers)
        ])
        self.up_projector = nn.Linear(self.d_model, self.mm_model)

    def forward(self, learnable_tokens, image_tokens, text_tokens, text_mask=None):
        # Down-project queries, text, and image tokens into the shared low-rank space.
        # Note: the down-projected queries are computed once from the initial
        # learnable tokens and are not refreshed between layers.
        learnable_tokens_down = self.down_projector_learnable_text(learnable_tokens)
        text_tokens_down = self.down_projector_learnable_text(text_tokens)
        image_tokens_down = self.down_projector_image(image_tokens)

        for layer in self.layers:
            residual = learnable_tokens
            learnable_tokens, attention_map = layer(learnable_tokens_down, image_tokens, image_tokens_down, text_tokens_down, text_mask)
            # Project the attended image features back to the LM hidden size,
            # then apply the residual connection around the whole block.
            learnable_tokens = self.up_projector(learnable_tokens)
            learnable_tokens = residual + learnable_tokens
        return learnable_tokens
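

if __name__ == "__main__":
    # Minimal smoke test / usage sketch. The config values below are assumptions
    # chosen for illustration (hidden_size=2560 to match up_dim, 8 attention
    # heads, 729 image tokens, 32 query tokens); they are not prescribed by this file.
    from types import SimpleNamespace

    cfg = SimpleNamespace(hidden_size=2560, num_attention_heads=8)
    model = LQFormer(cfg, num_layers=1)

    batch, num_queries, text_len, num_image_tokens = 2, 32, 40, 729
    learnable = torch.randn(batch, num_queries, cfg.hidden_size)
    image = torch.randn(batch, num_image_tokens, 1152)   # image-token hidden size (d_model)
    text = torch.randn(batch, text_len, cfg.hidden_size)
    text_mask = torch.ones(batch, text_len)              # 1 = valid token, 0 = padding

    out = model(learnable, image, text, text_mask)
    print(out.shape)  # expected: torch.Size([2, 32, 2560])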