Spaces:
Runtime error
Runtime error
from typing import Optional, List | |
import math | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from open_clip.transformer import _expand_token, to_2tuple | |
def resample_abs_pos_embed(
        posemb,
        new_size: List[int],
        old_size: Optional[List[int]] = None,
        num_prefix_tokens: int = 1,
        interpolation: str = 'bicubic',
        antialias: bool = True
):
    """Resample a flattened 2-D absolute position embedding to a new grid size.

    ``posemb`` is laid out as ``(batch, num_prefix_tokens + H * W, dim)``. The
    leading ``num_prefix_tokens`` entries (e.g. a class token) are kept as-is;
    the remaining ``H * W`` entries are viewed as an ``H x W`` grid and resized
    with ``F.interpolate``.

    Args:
        posemb: Position embedding tensor ``(batch, num_prefix_tokens + H * W, dim)``.
        new_size: Target grid ``(H_new, W_new)``; an int means a square grid.
        old_size: Source grid ``(H, W)``; an int means a square grid. If falsy,
            the source grid is assumed square and inferred from the number of
            non-prefix tokens.
        num_prefix_tokens: Number of leading tokens that are not part of the grid.
        interpolation: ``mode`` forwarded to ``F.interpolate``.
        antialias: ``antialias`` flag forwarded to ``F.interpolate``.

    Returns:
        ``posemb`` itself when the grid size is unchanged, otherwise a tensor of
        shape ``(batch, num_prefix_tokens + H_new * W_new, dim)`` with the same
        dtype as ``posemb``.
    """
    def _as_hw(size):
        # Normalize an int (square grid) or a 2-element sequence to a tuple of
        # ints, so sizes given as lists and as tuples compare equal below.
        if isinstance(size, (list, tuple)):
            return tuple(int(s) for s in size)
        return int(size), int(size)

    new_size = _as_hw(new_size)
    if not old_size:
        # Square source grid assumed: side = sqrt(number of grid tokens).
        old_size = math.isqrt(posemb.shape[1] - num_prefix_tokens)
    old_size = _as_hw(old_size)
    if new_size == old_size:
        return posemb

    if num_prefix_tokens:
        posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
    else:
        posemb_prefix = None

    batch, dim = posemb.shape[0], posemb.shape[-1]
    orig_dtype = posemb.dtype
    # Interpolate in (at least) float32: bicubic/antialiased kernels are not
    # available for reduced-precision dtypes on every backend.
    if orig_dtype not in (torch.float32, torch.float64):
        posemb = posemb.float()

    # (B, H*W, C) -> (B, C, H, W) -> resize -> (B, H_new*W_new, C)
    grid = posemb.reshape(batch, old_size[0], old_size[1], dim).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=new_size, mode=interpolation, antialias=antialias)
    posemb = grid.permute(0, 2, 3, 1).reshape(batch, new_size[0] * new_size[1], dim)
    posemb = posemb.to(orig_dtype)

    # Re-attach the untouched prefix tokens (class token, etc.).
    if posemb_prefix is not None:
        posemb = torch.cat([posemb_prefix, posemb], dim=1)
    return posemb
class SelfSelfAttention(nn.Module):
    """Multi-head attention that also produces a GEM self-self attention output.

    One shared ``qkv`` projection feeds two paths:

    * the ordinary scaled dot-product attention path (with ``attn_drop``), and
    * the GEM path: queries, keys and values are each refined by
      ``ss_attn_iter`` rounds of self-self attention (L2-normalised features
      attending to themselves), then used as attention weights over the
      values; the three results are averaged. ``attn_drop`` is not applied here.

    Both paths go through ``proj`` and ``proj_drop``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., ss_attn_iter=1,
                 ss_attn_temp=None):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        # A falsy qk_scale falls back to 1/sqrt(head_dim).
        self.scale = qk_scale or per_head ** -0.5
        self.ss_attn_iter = ss_attn_iter
        self.ss_attn_temp = ss_attn_temp
        # Registration order kept as-is (affects parameter init order and state_dict).
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    @staticmethod
    def _self_self_attend(feats, inv_temp, n_iter, values):
        """Refine ``feats`` with ``n_iter`` self-self attention rounds, then attend over ``values``."""
        def affinity(f):
            unit = F.normalize(f, dim=-1)
            return unit, ((unit @ unit.transpose(-2, -1)) * inv_temp).softmax(dim=-1)

        for _ in range(n_iter):
            unit, weights = affinity(feats)
            feats = weights @ unit
        _, weights = affinity(feats)
        return weights @ values

    def forward(self, x, attn_bias=None, prev_attn=None):
        """Return ``[gem_out, ori_out]``, both shaped like ``x``.

        ``x`` is transposed on dims 0/1 before use, i.e. it is taken as
        ``(N, B, C)``. ``attn_bias`` and ``prev_attn`` are accepted but unused.
        The per-head values are stored on ``self.v_values`` as a side effect.
        """
        tokens = x.transpose(0, 1)
        B, N, C = tokens.shape
        heads = self.num_heads
        q, k, v = self.qkv(tokens).reshape(B, N, 3, heads, C // heads).permute(2, 0, 3, 1, 4)
        self.v_values = v

        # Original path: standard scaled dot-product attention.
        logits = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(logits.softmax(dim=-1))
        out_ori = (weights @ v).transpose(1, 2).reshape(B, N, C)
        out_ori = self.proj_drop(self.proj(out_ori))

        # GEM path: fixed inverse temperature if configured, otherwise
        # scale * mean L2 norm of the input tokens, per sample -> (B, 1, 1, 1).
        if self.ss_attn_temp is not None:
            inv_temp = self.ss_attn_temp
        else:
            mean_norm = torch.norm(tokens, dim=-1).mean(dim=-1, keepdim=True)
            inv_temp = mean_norm.unsqueeze(1).unsqueeze(-1) * self.scale

        out_v = self._self_self_attend(v, inv_temp, self.ss_attn_iter, v)
        out_k = self._self_self_attend(k, inv_temp, self.ss_attn_iter, v)
        out_q = self._self_self_attend(q, inv_temp, self.ss_attn_iter, v)
        merged = (out_v + out_k + out_q) / 3
        out_gem = self.proj_drop(self.proj(merged.transpose(1, 2).reshape(B, N, C)))

        return [out_gem.transpose(0, 1), out_ori.transpose(0, 1)]
class GEMResidualBlock(nn.Module):
    """Wrap a residual attention block so it carries a GEM path next to the original one.

    The wrapped block must expose ``ln_1``, ``attn``, ``ls_1``, ``ln_2``,
    ``mlp`` and ``ls_2``; its ``attn(x=...)`` must return a
    ``[gem_delta, ori_delta]`` pair (as ``SelfSelfAttention`` in this file does).
    """

    def __init__(self, res_block):
        super().__init__()
        self.res_block = res_block

    def forward(self,
                q_x: torch.Tensor,
                k_x: Optional[torch.Tensor] = None,
                v_x: Optional[torch.Tensor] = None,
                attn_mask: Optional[torch.Tensor] = None,
                ):
        """Return ``[gem_out, ori_out]``.

        ``q_x`` is either a single tensor (used as both paths' input) or a
        ``[gem, ori]`` list produced by a preceding ``GEMResidualBlock``.
        ``k_x``, ``v_x`` and ``attn_mask`` are accepted but ignored.
        """
        block = self.res_block
        gem_in, ori_in = q_x if isinstance(q_x, list) else (q_x, q_x)

        # Attention is always computed from the original path's features.
        gem_delta, ori_delta = block.attn(x=block.ln_1(ori_in))
        gem_delta = block.ls_1(gem_delta)
        ori_delta = block.ls_1(ori_delta)

        # Original path: attention residual followed by the MLP residual.
        ori_out = ori_in + ori_delta
        ori_out = ori_out + block.ls_2(block.mlp(block.ln_2(ori_out)))

        # GEM path: attention residual only, no MLP.
        gem_out = gem_in + gem_delta
        return [gem_out, ori_out]
class GEMViT(nn.Module):
    """``nn.Module`` container that holds a vision transformer as ``self.vit``.

    NOTE(review): no ``forward`` is defined here, so calling an instance
    raises ``NotImplementedError`` from ``nn.Module``.
    """

    def __init__(self, vit):
        # nn.Module.__init__ must run before submodules are assigned; without
        # it, ``self.vit = vit`` raises AttributeError when ``vit`` is a Module.
        super().__init__()
        self.vit = vit
def modified_vit_forward(self, x: torch.Tensor):
    """Vision-transformer forward pass that returns both GEM and original token features.

    Intended to be bound onto a vision transformer instance as ``self``. It
    reads ``conv1``, ``class_embedding``, ``positional_embedding``,
    ``patch_dropout``, ``ln_pre``, ``transformer``, ``ln_post`` and ``proj``
    from it. ``self.transformer`` must return a ``[gem_tokens, tokens]`` pair
    in sequence-first layout (e.g. when its blocks are ``GEMResidualBlock``s).

    Args:
        x: Image batch passed to ``self.conv1`` (presumably ``(B, 3, H, W)``).

    Returns:
        ``[x_gem, x]``: all tokens (class token first) for the GEM and the
        original path, batch-first, after ``ln_post`` and ``proj`` (if not None).
    """
    x = self.conv1(x)  # (B, width, grid_h, grid_w)
    grid_h, grid_w = x.shape[2:]
    x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)  # (B, grid_h * grid_w, width)

    # Prepend the class token, broadcast over the batch.
    cls_token = self.class_embedding.to(x.dtype).view(1, 1, -1).expand(x.shape[0], -1, -1)
    x = torch.cat([cls_token, x], dim=1)  # (B, 1 + grid_h * grid_w, width)

    # positional_embedding is 2-D (num_tokens, width), so its token count is
    # dim 0. Resample the grid part only when the input resolution changed;
    # the stored embedding grid is assumed square (old_size inferred).
    if x.shape[1] != self.positional_embedding.shape[0]:
        pos_emb = resample_abs_pos_embed(self.positional_embedding.unsqueeze(0),
                                         new_size=[grid_h, grid_w],
                                         num_prefix_tokens=1,
                                         interpolation='bicubic',
                                         antialias=True)
    else:
        pos_emb = self.positional_embedding
    x = x + pos_emb.to(x.dtype)

    x = self.patch_dropout(x)
    x = self.ln_pre(x)

    x = x.permute(1, 0, 2)  # NLD -> LND
    x_gem, x = self.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD
    x_gem = x_gem.permute(1, 0, 2)  # LND -> NLD

    # Shared output head for both paths.
    x = self.ln_post(x)
    x_gem = self.ln_post(x_gem)
    if self.proj is not None:
        x = x @ self.proj
        x_gem = x_gem @ self.proj
    return [x_gem, x]