from typing import Optional, List
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from open_clip.transformer import _expand_token, to_2tuple

def resample_abs_pos_embed(
        posemb,
        new_size: List[int],
        old_size: Optional[List[int]] = None,
        num_prefix_tokens: int = 1,
        interpolation: str = 'bicubic',
        antialias: bool = True,
):
    # sort out sizes, assume square if old size not provided
    new_size = to_2tuple(new_size)
    new_ntok = new_size[0] * new_size[1]
    if not old_size:
        old_size = int(math.sqrt(posemb.shape[1] - num_prefix_tokens))
    old_size = to_2tuple(old_size)

    if new_size == old_size:  # might not both be same container type
        return posemb

    if num_prefix_tokens:
        posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
    else:
        posemb_prefix, posemb = None, posemb

    # do the interpolation
    posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
    posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
    posemb = posemb.permute(0, 2, 3, 1).reshape(1, new_ntok, -1)

    # add back extra (class, etc) prefix tokens
    if posemb_prefix is not None:
        posemb = torch.cat([posemb_prefix, posemb], dim=1)
    return posemb
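

# Example (a minimal sketch): resampling a 7x7 patch grid (49 patches plus one
# class token) to 14x14 for a larger input resolution; shapes are illustrative:
#
#   pe = torch.randn(1, 50, 768)
#   pe_new = resample_abs_pos_embed(pe, new_size=[14, 14])  # -> [1, 197, 768]
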
class SelfSelfAttention(nn.Module):
    """Self-self attention (GEM): q, k, and v each attend to themselves
    instead of the usual q-k cross attention; the original q-k path is
    computed alongside and returned as a second output."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
                 ss_attn_iter=1, ss_attn_temp=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.ss_attn_iter = ss_attn_iter  # number of self-self attention iterations
        self.ss_attn_temp = ss_attn_temp  # fixed inverse temperature; derived from token norms if None

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
    def forward(self, x, attn_bias=None, prev_attn=None):  # attn_bias/prev_attn kept for interface compatibility; unused
        x = x.transpose(0, 1)  # LND -> NLD
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        self.v_values = v  # cache the value tokens

        # original q-k self-attention for the original path
        attn_ori_return = (q @ k.transpose(-2, -1)) * self.scale
        attn_ori = attn_ori_return.softmax(dim=-1)
        attn_ori = self.attn_drop(attn_ori)

        x_ori = (attn_ori @ v).transpose(1, 2).reshape(B, N, C)
        x_ori = self.proj_drop(self.proj(x_ori))

        # GEM: iterative self-self attention, run separately on v, k, and q
        xs1 = v
        xs2 = k
        xs3 = q

        if self.ss_attn_temp is None:
            # derive an inverse temperature from the mean token norm
            pre_norm = torch.norm(x, dim=-1).mean(dim=-1, keepdim=True).unsqueeze(1).unsqueeze(-1)
            inv_temp = pre_norm * self.scale
        else:
            inv_temp = self.ss_attn_temp

        for it in range(self.ss_attn_iter):
            xs1 = F.normalize(xs1, dim=-1)
            xs2 = F.normalize(xs2, dim=-1)
            xs3 = F.normalize(xs3, dim=-1)

            attn_return1 = (xs1 @ xs1.transpose(-2, -1)) * inv_temp
            attn_return2 = (xs2 @ xs2.transpose(-2, -1)) * inv_temp
            attn_return3 = (xs3 @ xs3.transpose(-2, -1)) * inv_temp

            attn1 = attn_return1.softmax(dim=-1)
            attn2 = attn_return2.softmax(dim=-1)
            attn3 = attn_return3.softmax(dim=-1)

            xs1 = attn1 @ xs1
            xs2 = attn2 @ xs2
            xs3 = attn3 @ xs3

        # Assignment to v: one final self-self attention step, applied to the values
        xs1 = F.normalize(xs1, dim=-1)
        xs2 = F.normalize(xs2, dim=-1)
        xs3 = F.normalize(xs3, dim=-1)

        attn_return1 = (xs1 @ xs1.transpose(-2, -1)) * inv_temp
        attn_return2 = (xs2 @ xs2.transpose(-2, -1)) * inv_temp
        attn_return3 = (xs3 @ xs3.transpose(-2, -1)) * inv_temp

        attn1 = attn_return1.softmax(dim=-1)
        attn2 = attn_return2.softmax(dim=-1)
        attn3 = attn_return3.softmax(dim=-1)

        xs1 = attn1 @ v
        xs2 = attn2 @ v
        xs3 = attn3 @ v

        # ensemble the three self-self attention branches
        xs = (xs1 + xs2 + xs3) / 3

        x = xs.transpose(1, 2).reshape(B, N, C)
        x = self.proj_drop(self.proj(x))

        return [x.transpose(0, 1), x_ori.transpose(0, 1)]  # NLD -> LND
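

# Example (a minimal sketch): the module expects inputs in LND layout
# (sequence, batch, dim), matching open_clip's transformer blocks; the
# hyperparameters below are illustrative, not prescribed by this file:
#
#   attn = SelfSelfAttention(dim=768, num_heads=12, qkv_bias=True)
#   x = torch.randn(197, 2, 768)  # [seq, batch, dim]
#   x_gem, x_ori = attn(x)        # both [197, 2, 768]
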
class GEMResidualBlock(nn.Module):
    """Wraps an open_clip residual attention block so it propagates a
    [GEM, original] pair of activations. Assumes ``res_block.attn`` has
    already been replaced by a :class:`SelfSelfAttention`."""

    def __init__(self, res_block):
        super(GEMResidualBlock, self).__init__()
        self.res_block = res_block

    def forward(self,
                q_x: torch.Tensor,
                k_x: Optional[torch.Tensor] = None,
                v_x: Optional[torch.Tensor] = None,
                attn_mask: Optional[torch.Tensor] = None,
                ):
        if isinstance(q_x, list):
            x_gem, q_x = q_x
        else:
            x_gem = q_x

        x_gem_res, x_ori_res = self.res_block.attn(x=self.res_block.ln_1(q_x))
        x_gem_res, x_ori_res = self.res_block.ls_1(x_gem_res), self.res_block.ls_1(x_ori_res)
        # Original path: full residual block (attention + MLP)
        x_ori = q_x + x_ori_res
        x_ori = x_ori + self.res_block.ls_2(self.res_block.mlp(self.res_block.ln_2(x_ori)))
        # GEM path: residual attention only, the MLP is skipped
        x_gem = x_gem + x_gem_res
        return [x_gem, x_ori]
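

# Example (a minimal sketch; ``blk`` is assumed to be an open_clip
# ResidualAttentionBlock whose ``attn`` was already swapped for a
# SelfSelfAttention):
#
#   gem_blk = GEMResidualBlock(blk)
#   x_gem, x_ori = gem_blk(q_x=x)                 # x in LND layout
#   x_gem, x_ori = gem_blk(q_x=[x_gem, x_ori])    # chained blocks pass the pair through
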
class GEMViT(nn.Module):
    """Thin wrapper holding a GEM-modified ViT."""

    def __init__(self, vit):
        super(GEMViT, self).__init__()
        self.vit = vit

def modified_vit_forward(self, x: torch.Tensor):
    # Replacement forward for an open_clip VisionTransformer whose resblocks
    # have been wrapped in GEMResidualBlock; here ``self`` is the
    # VisionTransformer itself (e.g. bound via types.MethodType), not GEMViT.
    x = self.conv1(x)  # shape = [*, width, grid, grid]
    grid_h, grid_w = x.shape[2:]
    x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
    x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

    # class embeddings and positional embeddings
    x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
    # shape = [*, grid ** 2 + 1, width]
    if x.shape[1] != self.positional_embedding.shape[1]:
        # resample the positional embedding for non-native input resolutions
        pos_emb = resample_abs_pos_embed(self.positional_embedding.unsqueeze(0),
                                         new_size=[grid_h, grid_w],
                                         # old_size=list(self.grid_size),
                                         num_prefix_tokens=1,
                                         interpolation='bicubic',
                                         antialias=True)
    else:
        pos_emb = self.positional_embedding

    x = x + pos_emb.to(x.dtype)
    # x = x + self.positional_embedding.to(x.dtype)

    x = self.patch_dropout(x)
    x = self.ln_pre(x)
    x = x.permute(1, 0, 2)  # NLD -> LND

    x_gem, x = self.transformer(x)

    x = x.permute(1, 0, 2)  # LND -> NLD
    x_gem = x_gem.permute(1, 0, 2)  # LND -> NLD

    # Apply proj
    x = self.ln_post(x)
    x_gem = self.ln_post(x_gem)
    if self.proj is not None:
        x = x @ self.proj
        x_gem = x_gem @ self.proj
    return [x_gem, x]
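

# Example (a minimal sketch of how these pieces might be wired together,
# assuming an open_clip ViT-B/16 CLIP model; each block's attn must first be
# replaced by a SelfSelfAttention initialized from that block's attention
# weights, which is omitted here):
#
#   import types
#   import open_clip
#
#   model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', pretrained='openai')
#   visual = model.visual
#   for i, blk in enumerate(visual.transformer.resblocks):
#       visual.transformer.resblocks[i] = GEMResidualBlock(blk)
#   visual.forward = types.MethodType(modified_vit_forward, visual)
#   x_gem, x = visual(preprocess(img).unsqueeze(0))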