import timm
from timm.models._factory import load_checkpoint
import torch
import os
from typing import List, Union, Optional, Tuple
from torch import nn
from torch.jit import Final
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from utils.dl.common.model import get_model_device, set_module, get_module, get_model_latency, get_model_size, LayerActivation3
import torch.nn.functional as F
from utils.common.log import logger
from transformers import AutoTokenizer
from maskrcnn_benchmark.modeling.detector.generalized_vl_rcnn import GeneralizedVLRCNN
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.structures.bounding_box import BoxList
from torchvision import transforms as T
import matplotlib.pyplot as plt
import nltk
import re
from copy import deepcopy
from abc import ABC, abstractmethod
from methods.elasticdnn.pipeline.offline.fm_to_md.base import FM_to_MD_Util
from methods.elasticdnn.pipeline.offline.fm_lora.base import FMLoRA_Util, LoRA
from new_impl.cv.elasticdnn.api.model import ElasticDNN_OfflineFMModel, ElasticDNN_OfflineMDModel
from methods.elasticdnn.model.base import Abs, KTakesAll, ElasticDNNUtil, Layer_WrappedWithFBS
from transformers.models.bert.modeling_bert import BertSelfAttention
from transformers import BertConfig
import math
import time
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
def collect_mm_fn(batch):
    # Collate a batch of multi-modal detection samples. Each sample is either
    # (img, target) or (img, target, info_imgs, ids); samples with an empty target are skipped.
    if len(batch[0]) == 2:
        res = {'images': [], 'targets': []}
    else:
        res = {'images': [], 'targets': [], 'info_imgs': [], 'ids': []}
    for item in batch:
        if len(item) == 2:
            img, new_target = item
            if len(new_target) == 0:
                continue
            res['images'].append(img)
            res['targets'].append(new_target)
        else:
            img, new_target, info_imgs, ids = item
            if len(new_target) == 0:
                continue
            res['images'].append(img)
            res['targets'].append(new_target)
            res['info_imgs'].append(info_imgs)
            res['ids'].append(ids)
    return res, torch.Tensor([0])
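# Illustrative usage sketch (an assumption, not part of the original pipeline): `collect_mm_fn`
# is meant to be passed as the `collate_fn` of a DataLoader whose dataset yields
# (image, BoxList target) pairs, optionally followed by (info_imgs, ids).
def _example_collect_mm_fn_usage(dataset):
    from torch.utils.data import DataLoader
    loader = DataLoader(dataset, batch_size=2, collate_fn=collect_mm_fn)
    batch_dict, dummy_label = next(iter(loader))
    # batch_dict['images'] and batch_dict['targets'] are lists with one entry per kept sample
    return batch_dict, dummy_label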
def run_ner(caption):
noun_phrases = find_noun_phrases(caption)
noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases]
noun_phrases = [phrase for phrase in noun_phrases if phrase != '']
relevant_phrases = noun_phrases
labels = noun_phrases
tokens_positive = []
for entity, label in zip(relevant_phrases, labels):
try:
# search all occurrences and mark them as different entities
for m in re.finditer(entity, caption.lower()):
tokens_positive.append([[m.start(), m.end()]])
        except Exception:
            print("noun entities:", noun_phrases)
            print("entity:", entity)
            print("caption:", caption.lower())
return tokens_positive
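# Illustrative sketch: `run_ner` returns character spans of the noun phrases in a caption
# (one span group per phrase occurrence). It assumes the NLTK 'punkt' and
# 'averaged_perceptron_tagger' resources have been downloaded.
def _example_run_ner():
    caption = "a person riding a bicycle"
    spans = run_ner(caption)  # e.g. [[[0, 8]], [[16, 25]]]; exact spans depend on the NLTK tagger
    for span_group in spans:
        for beg, end in span_group:
            print(caption[beg:end])
    return spans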
def build_transform(cfg, min_image_size):
"""
Creates a basic transformation that was used to train the models
"""
# we are loading images with OpenCV, so we don't need to convert them
# to BGR, they are already! So all we need to do is to normalize
# by 255 if we want to convert to BGR255 format, or flip the channels
# if we want it to be in RGB in [0-1] range.
if cfg.INPUT.TO_BGR255:
to_bgr_transform = T.Lambda(lambda x: x * 255)
else:
to_bgr_transform = T.Lambda(lambda x: x[[2, 1, 0]])
normalize_transform = T.Normalize(
mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
)
transform = T.Compose(
[
T.ToPILImage(),
T.Resize(min_image_size) if min_image_size is not None else lambda x: x,
T.ToTensor(),
to_bgr_transform,
normalize_transform,
]
)
return transform
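# Illustrative sketch: applying the evaluation transform to one image loaded with OpenCV
# (a BGR numpy array). The min_image_size value below is an assumption for demonstration.
def _example_build_transform(config, bgr_image):
    transform = build_transform(config, min_image_size=800)
    return transform(bgr_image)  # normalized tensor ready to be fed to the detector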
def remove_punctuation(text: str) -> str:
punct = ['|', ':', ';', '@', '(', ')', '[', ']', '{', '}', '^',
'\'', '\"', '’', '`', '?', '$', '%', '#', '!', '&', '*', '+', ',', '.'
]
for p in punct:
text = text.replace(p, '')
return text.strip()
def create_positive_map_label_to_token_from_positive_map(positive_map, plus=0):
positive_map_label_to_token = {}
for i in range(len(positive_map)):
positive_map_label_to_token[i + plus] = torch.nonzero(positive_map[i], as_tuple=True)[0].tolist()
return positive_map_label_to_token
def create_positive_map(tokenized, tokens_positive):
"""construct a map such that positive_map[i,j] = True iff box i is associated to token j"""
positive_map = torch.zeros((len(tokens_positive), 256), dtype=torch.float)
for j, tok_list in enumerate(tokens_positive):
for (beg, end) in tok_list:
try:
beg_pos = tokenized.char_to_token(beg)
end_pos = tokenized.char_to_token(end - 1)
except Exception as e:
print("beg:", beg, "end:", end)
print("token_positive:", tokens_positive)
# print("beg_pos:", beg_pos, "end_pos:", end_pos)
raise e
if beg_pos is None:
try:
beg_pos = tokenized.char_to_token(beg + 1)
if beg_pos is None:
beg_pos = tokenized.char_to_token(beg + 2)
except:
beg_pos = None
if end_pos is None:
try:
end_pos = tokenized.char_to_token(end - 2)
if end_pos is None:
end_pos = tokenized.char_to_token(end - 3)
except:
end_pos = None
if beg_pos is None or end_pos is None:
continue
assert beg_pos is not None and end_pos is not None
positive_map[j, beg_pos: end_pos + 1].fill_(1)
return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
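# Illustrative sketch: building a (phrase, token) positive map and its label->token form for a
# caption with two phrases. The 'bert-base-uncased' checkpoint name is an assumption here; GLIP
# normally uses the tokenizer configured in cfg.MODEL.LANGUAGE_BACKBONE.
def _example_positive_map():
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    caption = "a dog . a cat ."
    tokens_positive = [[[0, 5]], [[8, 13]]]  # character spans of "a dog" and "a cat"
    tokenized = tokenizer(caption, return_tensors='pt')
    positive_map = create_positive_map(tokenized, tokens_positive)  # shape (2, 256), each row normalized
    return create_positive_map_label_to_token_from_positive_map(positive_map, plus=1)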
def find_noun_phrases(caption: str) -> List[str]:
caption = caption.lower()
tokens = nltk.word_tokenize(caption)
pos_tags = nltk.pos_tag(tokens)
grammar = "NP: {<DT>?<JJ.*>*<NN.*>+}"
cp = nltk.RegexpParser(grammar)
result = cp.parse(pos_tags)
noun_phrases = list()
for subtree in result.subtrees():
if subtree.label() == 'NP':
noun_phrases.append(' '.join(t[0] for t in subtree.leaves()))
return noun_phrases
class Glip(nn.Module):
    def __init__(self, config, pretrain_path, min_image_size=None, confidence_threshold=0.7):
        super(Glip, self).__init__()
        state_dict = torch.load(pretrain_path)['model']
        self.min_image_size = min_image_size
        self.cfg = config
        self.confidence_threshold = confidence_threshold
        self.tokenizer = AutoTokenizer.from_pretrained(config.MODEL.LANGUAGE_BACKBONE.MODEL_PATH)
        self.device = torch.device(config.MODEL.DEVICE)
for k in list(state_dict.keys()):
if k.startswith('module'):
new_k = k.replace('module.', '')
state_dict[new_k] = state_dict.pop(k)
self.model = GeneralizedVLRCNN(config)
self.model.load_state_dict(state_dict, strict=False)
# self.transform = build_transform(config, min_image_size)
def forward(self, images, targets, for_training=None):
# img_list = []
# for image in images:
# img_list.append(self.transform(image).to(self.device))
# if isinstance(texts, list):
# # we directly provided a list of category names
# caption_string = ""
# tokens_positive = []
# seperation_tokens = " . "
# for word in texts:
# tokens_positive.append([len(caption_string), len(caption_string) + len(word)])
# caption_string += word
# caption_string += seperation_tokens
# tokenized = self.tokenizer([caption_string], return_tensors="pt")
# tokens_positive = [tokens_positive]
# texts = [caption_string]
# print(tokens_positive)
# else:
        device = self.device
images = [image.to(device) for image in images]
targets = [target.to(device) for target in targets]
texts = [t.get_field("caption") for t in targets if "caption" in t.fields()]
positive_map = []
# if custom_entity is None:
# tokens_positive = self.run_ner(texts)
# print(tokens_positive)
# process positive map
        if not self.training:
            try:
                tokens_positive = run_ner(texts[0])
            except Exception as e:
                logger.warning(f'run_ner failed on caption "{texts[0]}": {e}')
                tokens_positive = []
tokenized = self.tokenizer(texts, return_tensors="pt")
positive_map = create_positive_map(tokenized, tokens_positive)
if self.cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD":
plus = 1
else:
plus = 0
positive_map = create_positive_map_label_to_token_from_positive_map(positive_map, plus=plus)
else:
for i, text in enumerate(texts):
tokenized = self.tokenizer(text, return_tensors="pt")
tokens_positive = targets[i].get_field('tokens_positive')
positive_map.append(create_positive_map(tokenized, tokens_positive))
positive_map = torch.cat(positive_map, dim=0).to(device)
if self.training:
proposal_losses = self.model(images, targets, texts, positive_map=positive_map)
return proposal_losses
else:
proposals, token_logits, dot_product_logits = self.model(images, targets, texts, positive_map=positive_map)
proposal = self._post_process(proposals[0])
return proposal, token_logits, dot_product_logits
def _post_process_fixed_thresh(self, predictions):
scores = predictions.get_field("scores")
labels = predictions.get_field("labels").tolist()
thresh = scores.clone()
for i, lb in enumerate(labels):
if isinstance(self.confidence_threshold, float):
thresh[i] = self.confidence_threshold
elif len(self.confidence_threshold) == 1:
thresh[i] = self.confidence_threshold[0]
else:
thresh[i] = self.confidence_threshold[lb - 1]
keep = torch.nonzero(scores > thresh).squeeze(1)
predictions = predictions[keep]
scores = predictions.get_field("scores")
_, idx = scores.sort(0, descending=True)
return predictions[idx]
def _post_process(self, predictions, threshold=0.5):
scores = predictions.get_field("scores")
labels = predictions.get_field("labels").tolist()
thresh = scores.clone()
for i, lb in enumerate(labels):
if isinstance(self.confidence_threshold, float):
thresh[i] = threshold
elif len(self.confidence_threshold) == 1:
thresh[i] = threshold
else:
thresh[i] = self.confidence_threshold[lb - 1]
keep = torch.nonzero(scores > thresh).squeeze(1)
predictions = predictions[keep]
scores = predictions.get_field("scores")
_, idx = scores.sort(0, descending=True)
return predictions[idx]
# @torch.no_grad()
# def clip_vit_b_16():
# # https://huggingface.co/openai/clip-vit-base-patch16
# model = CLIPModelCanReceiveTextEmbeds.from_pretrained("openai/clip-vit-base-patch16")
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
# print(model)
# from PIL import Image
# import requests
# image = Image.open('/data/zql/datasets/Caltech-256/data/caltech256/256_ObjectCategories/003.backpack/003_0001.jpg')
# inputs = processor(text=["a photo of a dog", "a photo of a backpack", "a photo of a cat"], images=image, return_tensors="pt", padding=True)
# print(inputs)
# from utils.dl.common.model import LayerActivation2, get_module
# input_embed_hook = LayerActivation2(get_module(model, 'text_model.embeddings'))
# outputs = model(**inputs)
# logits_per_image = outputs.logits_per_image # this is the image-text similarity score
# probs = logits_per_image.softmax(dim=1)
# print(probs)
# input_embed = input_embed_hook.output
# input_embed_hook.remove()
# torch.save(input_embed, os.path.join(os.path.dirname(__file__), './test_input_embed.pth'))
# print('embed', input_embed.size())
# del inputs['input_ids']
# inputs['input_embeds'] = input_embed
# outputs = model(**inputs)
# logits_per_image = outputs.logits_per_image # this is the image-text similarity score
# probs = logits_per_image.softmax(dim=1)
# print(probs)
@torch.no_grad()
def glip_model(config_path, pretrain_path):
    # Build the GLIP detector from a maskrcnn_benchmark config file and a pretrained checkpoint.
cfg.merge_from_file(config_path)
return cfg, Glip(cfg, pretrain_path)
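# Illustrative usage sketch for `glip_model`. The config/checkpoint paths are caller-supplied
# assumptions (nothing is bundled here); in eval mode the wrapper returns
# (post-processed BoxList, token_logits, dot_product_logits).
def _example_build_glip(config_path, pretrain_path):
    config, model = glip_model(config_path, pretrain_path)
    model = model.to(torch.device(config.MODEL.DEVICE))
    model.eval()
    return config, model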
class ToQKV_WrappedWithLoRA(nn.Module):
def __init__(self, fc: nn.Linear, ab_r: int):
super(ToQKV_WrappedWithLoRA, self).__init__()
self.fc = fc
self.ab = self.create_ab_as_linear(fc.weight.data, ab_r)
def create_ab_as_linear(self, fc_weight: torch.Tensor, ab_r: int):
res = nn.Sequential(
LoRA(fc_weight.size(1), fc_weight.size(0) // ab_r, bias=False),
LoRA(fc_weight.size(0) // ab_r, fc_weight.size(0), bias=False)
).to(fc_weight.device)
nn.init.kaiming_uniform_(res[0].weight, a=5 ** 0.5)
nn.init.zeros_(res[1].weight)
return res
def forward(self, x):
x1 = self.fc(x)
x2 = self.ab(x)
return x1 + x2
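# Illustrative sketch: wrapping a projection layer with the LoRA adapter above. Assuming `LoRA`
# is a Linear-like module (as `create_ab_as_linear` implies), the zero-initialized second matrix
# makes the wrapped layer reproduce the original output before any fine-tuning.
def _example_lora_wrap():
    fc = nn.Linear(768, 768)
    wrapped = ToQKV_WrappedWithLoRA(fc, ab_r=16)
    x = torch.randn(2, 10, 768)
    assert torch.allclose(fc(x), wrapped(x), atol=1e-6)
    return wrapped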
def get_model_latency_2(model: torch.nn.Module, sample: dict, sample_num: int,
device: str, warmup_sample_num: int, return_detail=False):
"""Get the latency (inference time) of a PyTorch model.
Reference: https://deci.ai/resources/blog/measure-inference-time-deep-neural-networks/
Args:
model (torch.nn.Module): A PyTorch model.
model_input_size (Tuple[int]): Typically be `(1, 3, 32, 32)` or `(1, 3, 224, 224)`.
sample_num (int): How many inputs which size is :attr:`model_input_size` will be tested and compute the average latency as result.
device (str): Typically be 'cpu' or 'cuda'.
warmup_sample_num (int): Let model perform some dummy inference to warm up the test environment to avoid measurement loss.
return_detail (bool, optional): Beside the average latency, return all result measured. Defaults to False.
Returns:
Union[float, Tuple[float, List[float]]]: The average latency (and all lantecy data) of :attr:`model`.
"""
# if isinstance(model_input_size, tuple):
# dummy_input = torch.rand(model_input_size).to(device)
# else:
# dummy_input = model_input_size
model = model.to(device)
model.eval()
# warm up
with torch.no_grad():
for _ in range(warmup_sample_num):
model(**sample)
infer_time_list = []
if device == 'cuda' or 'cuda' in str(device):
with torch.no_grad():
for _ in range(sample_num):
s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
s.record()
model(**sample)
e.record()
torch.cuda.synchronize()
cur_model_infer_time = s.elapsed_time(e) / 1000.
infer_time_list += [cur_model_infer_time]
else:
with torch.no_grad():
for _ in range(sample_num):
start = time.time()
model(**sample)
cur_model_infer_time = time.time() - start
infer_time_list += [cur_model_infer_time]
avg_infer_time = sum(infer_time_list) / sample_num
if return_detail:
return avg_infer_time, infer_time_list
return avg_infer_time
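# Illustrative sketch: timing the GLIP wrapper with `get_model_latency_2`. `sample` must be a
# dict of keyword arguments accepted by the model's forward(), e.g. images and targets.
def _example_measure_latency(model, images, targets):
    sample = {'images': images, 'targets': targets}
    avg_s = get_model_latency_2(model, sample, sample_num=20,
                                device=get_model_device(model), warmup_sample_num=5)
    logger.info(f'average latency: {avg_s:.4f}s/sample')
    return avg_s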
class WindowAttention(nn.Module):
""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
""" Forward function.
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
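# Illustrative sketch: WindowAttention maps (num_windows*B, N, C) -> (num_windows*B, N, C),
# where N = window_height * window_width. The sizes below are assumptions for demonstration.
def _example_window_attention():
    attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
    x = torch.randn(4, 49, 96)  # 4 windows, 7*7 tokens per window, 96 channels
    out = attn(x)
    assert out.shape == x.shape
    return out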
class BiMultiHeadAttention(nn.Module):
def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, cfg=None):
super(BiMultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.v_dim = v_dim
self.l_dim = l_dim
assert (
self.head_dim * self.num_heads == self.embed_dim
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
self.scale = self.head_dim ** (-0.5)
self.dropout = dropout
self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)
self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)
self.stable_softmax_2d = cfg.MODEL.DYHEAD.FUSE_CONFIG.STABLE_SOFTMAX_2D
self.clamp_min_for_underflow = cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MIN_FOR_UNDERFLOW
self.clamp_max_for_overflow = cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MAX_FOR_OVERFLOW
self._reset_parameters()
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def _reset_parameters(self):
nn.init.xavier_uniform_(self.v_proj.weight)
self.v_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.l_proj.weight)
self.l_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.values_v_proj.weight)
self.values_v_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.values_l_proj.weight)
self.values_l_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.out_v_proj.weight)
self.out_v_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.out_l_proj.weight)
self.out_l_proj.bias.data.fill_(0)
def forward(self, v, l, attention_mask_l=None):
bsz, tgt_len, embed_dim = v.size()
query_states = self.v_proj(v) * self.scale
key_states = self._shape(self.l_proj(l), -1, bsz)
value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
value_l_states = self._shape(self.values_l_proj(l), -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_v_states = value_v_states.view(*proj_shape)
value_l_states = value_l_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
)
# attn_weights_l = nn.functional.softmax(attn_weights.transpose(1, 2), dim=-1)
if self.stable_softmax_2d:
attn_weights = attn_weights - attn_weights.max()
if self.clamp_min_for_underflow:
attn_weights = torch.clamp(attn_weights, min=-50000) # Do not increase -50000, data type half has quite limited range
if self.clamp_max_for_overflow:
attn_weights = torch.clamp(attn_weights, max=50000) # Do not increase 50000, data type half has quite limited range
attn_weights_T = attn_weights.transpose(1, 2)
        attn_weights_l = attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[0]
if self.clamp_min_for_underflow:
attn_weights_l = torch.clamp(attn_weights_l, min=-50000) # Do not increase -50000, data type half has quite limited range
if self.clamp_max_for_overflow:
attn_weights_l = torch.clamp(attn_weights_l, max=50000) # Do not increase 50000, data type half has quite limited range
attn_weights_l = attn_weights_l.softmax(dim=-1)
if attention_mask_l is not None:
assert (attention_mask_l.dim() == 2)
attention_mask = attention_mask_l.unsqueeze(1).unsqueeze(1)
attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len)
attention_mask = attention_mask.masked_fill(attention_mask == 0, -9e15)
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights_v = nn.functional.softmax(attn_weights, dim=-1)
attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)
attn_output_v = torch.bmm(attn_probs_v, value_l_states)
attn_output_l = torch.bmm(attn_probs_l, value_v_states)
if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}"
)
if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
raise ValueError(
f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}"
)
attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output_v = attn_output_v.transpose(1, 2)
attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)
attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
attn_output_l = attn_output_l.transpose(1, 2)
attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)
attn_output_v = self.out_v_proj(attn_output_v)
attn_output_l = self.out_l_proj(attn_output_l)
return attn_output_v, attn_output_l
class BertSelfAttentionPrunable(BertSelfAttention):
def __init__(self):
config = BertConfig.from_pretrained('new_impl/cv/glip/object_detection/bert-base-uncased')
super(BertSelfAttentionPrunable, self).__init__(config)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.query.out_features,) # NOTE: modified
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
@staticmethod
def init_from_exist_self_attn(attn: BertSelfAttention):
# print(attn)
res = BertSelfAttentionPrunable()
for attr in dir(attn):
# if str(attr) in ['transpose_for_scores'] or str(attr).startswith('_'):
# continue
# if isinstance(getattr(attn, attr), nn.Module):
# print(attr)
if isinstance(getattr(attn, attr), nn.Module):
try:
# print(attr, 'ok')
setattr(res, attr, getattr(attn, attr))
except Exception as e:
print(attr, str(e))
return res
class FM_to_MD_GLIP_Util(FM_to_MD_Util):
def init_md_from_fm_by_reducing_width_with_perf_test_2(self, fm: nn.Module, reducing_width_ratio: int,
samples: torch.Tensor) -> nn.Module:
fm_size = get_model_size(fm, True)
fm_latency = get_model_latency_2(fm, samples, 20,
get_model_device(fm), 20, False)
master_dnn = self.init_md_from_fm_by_reducing_width(fm, reducing_width_ratio)
master_dnn_size = get_model_size(master_dnn, True)
logger.debug(f'inited master DNN: {master_dnn}')
# from utils.dl.common.model import get_module
# print('after generating')
# get_module(fm, 'head').debug()
# get_module(master_dnn, 'head').debug()
# print('test master latency')
        master_dnn_latency = get_model_latency_2(master_dnn, samples, 20,
                                                 get_model_device(master_dnn), 20, False)
logger.info(f'init master DNN (w/o FBS yet) by reducing foundation model\'s width (by {reducing_width_ratio:d}x)')
logger.info(f'foundation model ({fm_size:.3f}MB, {fm_latency:.4f}s/sample) -> '
f'master DNN ({master_dnn_size:.3f}MB, {master_dnn_latency:.4f}s/sample)\n'
f'(model size: ↓ {(fm_size / master_dnn_size):.2f}x, '
f'latency: ↓ {(fm_latency / master_dnn_latency):.2f}x)')
return master_dnn
def init_md_from_fm_by_reducing_width(self, fm: nn.Module, reducing_width_ratio: int, sparsity=0.0) -> nn.Module:
        # sparsity: mainly used to build a distilled model for the baseline algorithm; it ensures
        # that the distilled model has the same size as the model used in the online algorithm.
fm_vit = deepcopy(fm)
def _f(n):
return int(n // reducing_width_ratio)
# def _rand_indexes(n):
# return torch.randperm(n)[0: int(n // reducing_width_ratio)]
def l1_max_indexes(p: torch.Tensor, dim=0):
assert dim in [0, 1]
assert p.dim() in [1, 2, 4]
if dim == 1:
p = p.T
p_norm = p.abs().contiguous().view(p.size(0), -1).sum(dim=1)
n = p.size(0)
            return p_norm.argsort(descending=True)[0: int(n // reducing_width_ratio)].sort()[0]
def l1_max_indexes_with_sparsity(p: torch.Tensor, dim=0):
assert dim in [0, 1]
assert p.dim() in [1, 2, 4]
if dim == 1:
p = p.T
p_norm = p.abs().contiguous().view(p.size(0), -1).sum(dim=1)
n = p.size(0)
return p_norm.argsort(descending=True)[0: int(n // reducing_width_ratio * (1 - sparsity))].sort()[0]
for layer_i, layer in enumerate(fm_vit.model.backbone.body.layers):
for block in layer.blocks:
ori_attn = block.attn
new_attn = WindowAttention(ori_attn.dim, ori_attn.window_size, ori_attn.num_heads, True, ori_attn.scale, 0., 0.)
new_attn.relative_position_index = ori_attn.relative_position_index
new_attn.relative_position_bias_table = ori_attn.relative_position_bias_table
new_attn.qkv = ori_attn.qkv
new_attn.attn_drop = ori_attn.attn_drop
new_attn.proj = ori_attn.proj
new_attn.proj_drop = ori_attn.proj_drop
set_module(block, 'attn', new_attn)
# first_attn = True
for layer_i, layer in enumerate(fm_vit.model.backbone.body.layers):
for block_i, block in enumerate(layer.blocks):
qkv = block.attn.qkv
new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features),
qkv.bias is not None, qkv.weight.device)
indexes = l1_max_indexes(qkv.weight.data, 0)
new_qkv.weight.data.copy_(qkv.weight.data[indexes])
if qkv.bias is not None:
new_qkv.bias.data.copy_(qkv.bias.data[indexes])
# fm_vit.model.backbone.body.layers[0].blocks.0.attn.qkv
set_module(fm_vit, f'model.backbone.body.layers.{layer_i}.blocks.{block_i}.attn.qkv', new_qkv)
proj = block.attn.proj
new_proj = nn.Linear(_f(proj.in_features), proj.out_features,
proj.bias is not None, proj.weight.device)
new_proj.weight.data.copy_(proj.weight.data[:, l1_max_indexes(proj.weight.data, 1)])
if proj.bias is not None:
new_proj.bias.data.copy_(proj.bias.data)
set_module(fm_vit, f'model.backbone.body.layers.{layer_i}.blocks.{block_i}.attn.proj', new_proj)
fc1 = block.mlp.fc1
new_fc1 = nn.Linear(fc1.in_features, int(_f(fc1.out_features) * (1 - sparsity)),
fc1.bias is not None, fc1.weight.device)
indexes = l1_max_indexes_with_sparsity(fc1.weight.data, 0)
new_fc1.weight.data.copy_(fc1.weight.data[indexes])
if fc1.bias is not None:
new_fc1.bias.data.copy_(fc1.bias.data[indexes])
set_module(fm_vit, f'model.backbone.body.layers.{layer_i}.blocks.{block_i}.mlp.fc1', new_fc1)
fc2 = block.mlp.fc2
new_fc2 = nn.Linear(int(_f(fc2.in_features) * (1 - sparsity)), fc2.out_features,
fc2.bias is not None, fc2.weight.device)
new_fc2.weight.data.copy_(fc2.weight.data[:, l1_max_indexes_with_sparsity(fc2.weight.data, 1)])
if fc2.bias is not None:
new_fc2.bias.data.copy_(fc2.bias.data)
set_module(fm_vit, f'model.backbone.body.layers.{layer_i}.blocks.{block_i}.mlp.fc2', new_fc2)
for block in fm_vit.model.language_backbone.body.model.encoder.layer:
set_module(block, 'attention.self', BertSelfAttentionPrunable.init_from_exist_self_attn(block.attention.self))
for block_i, block in enumerate(fm_vit.model.language_backbone.body.model.encoder.layer):
for k in ['query', 'key', 'value']:
qkv = get_module(block, f'attention.self.{k}')
new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features),
qkv.bias is not None, qkv.weight.device)
indexes = l1_max_indexes(qkv.weight.data, 0)
new_qkv.weight.data.copy_(qkv.weight.data[indexes])
if qkv.bias is not None:
new_qkv.bias.data.copy_(qkv.bias.data[indexes])
set_module(block, f'attention.self.{k}', new_qkv)
proj = get_module(block, f'attention.output.dense')
new_proj = nn.Linear(_f(proj.in_features), proj.out_features,
proj.bias is not None, proj.weight.device)
new_proj.weight.data.copy_(proj.weight.data[:, l1_max_indexes(proj.weight.data, 1)])
if proj.bias is not None:
new_proj.bias.data.copy_(proj.bias.data)
set_module(block, f'attention.output.dense', new_proj)
fc1 = get_module(block, f'intermediate.dense')
new_fc1 = nn.Linear(fc1.in_features, int(_f(fc1.out_features) * (1 - sparsity)),
fc1.bias is not None, fc1.weight.device)
indexes = l1_max_indexes_with_sparsity(fc1.weight.data, 0)
new_fc1.weight.data.copy_(fc1.weight.data[indexes])
if fc1.bias is not None:
new_fc1.bias.data.copy_(fc1.bias.data[indexes])
set_module(block, f'intermediate.dense', new_fc1)
fc2 = get_module(block, f'output.dense')
new_fc2 = nn.Linear(int(_f(fc2.in_features) * (1 - sparsity)), fc2.out_features,
fc2.bias is not None, fc2.weight.device)
new_fc2.weight.data.copy_(fc2.weight.data[:, l1_max_indexes_with_sparsity(fc2.weight.data, 1)])
if fc2.bias is not None:
new_fc2.bias.data.copy_(fc2.bias.data)
set_module(block, f'output.dense', new_fc2)
for block_i, block in enumerate(fm_vit.model.rpn.head.dyhead_tower):
if block_i % 3 == 0:
tmp = block.b_attn.attn
tmp.head_dim = int(tmp.head_dim // reducing_width_ratio)
tmp.embed_dim = int(tmp.embed_dim // reducing_width_ratio)
set_module(block, 'b_attn.attn', tmp)
for k in ['v_proj', 'l_proj', 'values_v_proj', 'values_l_proj']:
qkv = get_module(block, f'b_attn.attn.{k}')
new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features),
qkv.bias is not None, qkv.weight.device)
indexes = l1_max_indexes(qkv.weight.data, 0)
new_qkv.weight.data.copy_(qkv.weight.data[indexes])
if qkv.bias is not None:
new_qkv.bias.data.copy_(qkv.bias.data[indexes])
set_module(block, f'b_attn.attn.{k}', new_qkv)
for k in ['out_v_proj', 'out_l_proj']:
qkv = get_module(block, f'b_attn.attn.{k}')
new_qkv = nn.Linear(_f(qkv.in_features), qkv.out_features,
qkv.bias is not None, qkv.weight.device)
new_qkv.weight.data.copy_(qkv.weight.data[:, l1_max_indexes(qkv.weight.data, 1)])
if qkv.bias is not None:
new_qkv.bias.data.copy_(qkv.bias.data)
set_module(block, f'b_attn.attn.{k}', new_qkv)
elif block_i % 3 == 1:
set_module(block, 'attention.self', BertSelfAttentionPrunable.init_from_exist_self_attn(block.attention.self))
for k in ['query', 'key', 'value']:
qkv = get_module(block, f'attention.self.{k}')
new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features),
qkv.bias is not None, qkv.weight.device)
indexes = l1_max_indexes(qkv.weight.data, 0)
new_qkv.weight.data.copy_(qkv.weight.data[indexes])
if qkv.bias is not None:
new_qkv.bias.data.copy_(qkv.bias.data[indexes])
set_module(block, f'attention.self.{k}', new_qkv)
proj = get_module(block, f'attention.output.dense')
new_proj = nn.Linear(_f(proj.in_features), proj.out_features,
proj.bias is not None, proj.weight.device)
new_proj.weight.data.copy_(proj.weight.data[:, l1_max_indexes(proj.weight.data, 1)])
if proj.bias is not None:
new_proj.bias.data.copy_(proj.bias.data)
set_module(block, f'attention.output.dense', new_proj)
fc1 = get_module(block, f'intermediate.dense')
new_fc1 = nn.Linear(fc1.in_features, int(_f(fc1.out_features) * (1 - sparsity)),
fc1.bias is not None, fc1.weight.device)
indexes = l1_max_indexes_with_sparsity(fc1.weight.data, 0)
new_fc1.weight.data.copy_(fc1.weight.data[indexes])
if fc1.bias is not None:
new_fc1.bias.data.copy_(fc1.bias.data[indexes])
set_module(block, f'intermediate.dense', new_fc1)
fc2 = get_module(block, f'output.dense')
new_fc2 = nn.Linear(int(_f(fc2.in_features) * (1 - sparsity)), fc2.out_features,
fc2.bias is not None, fc2.weight.device)
new_fc2.weight.data.copy_(fc2.weight.data[:, l1_max_indexes_with_sparsity(fc2.weight.data, 1)])
if fc2.bias is not None:
new_fc2.bias.data.copy_(fc2.bias.data)
set_module(block, f'output.dense', new_fc2)
# reduce dim_embedding
# if name.endswith('patch_embed.proj'):
# continue
# new_layer = nn.Conv2d(module.in_channels, _f(module.out_channels), module.kernel_size, module.stride,
# module.padding, module.dilation, module.groups, module.bias is not None, module.padding_mode,
# module.weight.device)
# rand_indexes = l1_max_indexes(module.weight.data)
# new_layer.weight.data.copy_(module.weight.data[rand_indexes])
# if new_layer.bias is not None:
# new_layer.bias.data.copy_(module.bias.data[rand_indexes])
# fm_vit.cls_token.data = fm_vit.cls_token.data[:, :, rand_indexes]
# fm_vit.pos_embed.data = fm_vit.pos_embed.data[:, :, rand_indexes]
# elif isinstance(module, nn.Linear):
# if 'head' in name:
# continue
# new_layer = nn.Linear(_f(module.in_features), module.out_features,
# module.bias is not None, module.weight.device)
# new_layer.weight.data.copy_(module.weight.data[:, l1_max_indexes(module.weight.data, 1)])
# if new_layer.bias is not None:
# new_layer.bias.data.copy_(module.bias.data)
# else:
# first_attn = False
# if first_attn:
# first_attn = False
# new_layer = nn.Linear(module.in_features, _f(module.out_features),
# module.bias is not None, module.weight.device)
# rand_indexes = l1_max_indexes(module.weight.data)
# new_layer.weight.data.copy_(module.weight.data[rand_indexes])
# if new_layer.bias is not None:
# new_layer.bias.data.copy_(module.bias.data[rand_indexes])
# else:
# new_layer = nn.Linear(_f(module.in_features), _f(module.out_features),
# module.bias is not None, module.weight.device)
# rand_indexes = l1_max_indexes(module.weight.data)
# new_layer.weight.data.copy_(module.weight.data[rand_indexes][:, l1_max_indexes(module.weight.data, 1)])
# if new_layer.bias is not None:
# new_layer.bias.data.copy_(module.bias.data[rand_indexes])
# elif isinstance(module, nn.LayerNorm) and ('block' in name or name == 'norm' or name == 'norm.0'):
# new_layer = nn.LayerNorm(_f(module.normalized_shape[0]), eps=module.eps, device=module.weight.device)
# rand_indexes = l1_max_indexes(module.weight.data)
# new_layer.weight.data.copy_(module.weight.data[rand_indexes])
# new_layer.bias.data.copy_(module.bias.data[rand_indexes])
# else:
# continue
# original_layer_str = str(module)
# set_module(fm_vit, name, new_layer)
# logger.debug(f'set_module, {name}, {new_layer}')
# logger.debug(f'slim {name} from {original_layer_str} to {new_layer}')
return fm_vit
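# Illustrative sketch: deriving a narrower master DNN from the GLIP foundation model. `fm` is
# assumed to be the Glip wrapper built by `glip_model`, and `samples` a dict of forward()
# keyword arguments used only for the size/latency comparison.
def _example_init_master_dnn(fm, samples):
    util = FM_to_MD_GLIP_Util()
    master_dnn = util.init_md_from_fm_by_reducing_width_with_perf_test_2(
        fm, reducing_width_ratio=4, samples=samples)
    return master_dnn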
class FMLoRA_GLIP_Util(FMLoRA_Util):
def train_only_lora_and_conv(self, fm: nn.Module):
res = []
for n, m in fm.named_modules():
if isinstance(m, LoRA) or isinstance(m, nn.Conv2d):
for p in m.parameters():
p.requires_grad = True
res += [p]
else:
for p in m.parameters():
p.requires_grad = False
return res
@torch.no_grad()
def add_lora_ab_to_fm(self, fm: nn.Module, ab_r: int, samples):
fm.eval()
# samples = {'images' : samples[0], 'targets' : samples[1]}
for k, v in samples.items():
if isinstance(v, torch.Tensor) or isinstance(v, BoxList):
samples[k] = v.to(get_model_device(fm))
print(k)
_, o1_token_logits, o1_dot_product_logits = fm(**samples)
mo_list = {k:v for k, v in fm.named_modules()}
for name, module in fm.named_modules():
if '.proj' in name or 'out' in name:
continue
if name.endswith(('k_proj', 'q_proj', 'v_proj', 'qkv', 'attn.proj', 'l_proj', 'query', 'key', 'value')):
set_module(fm, name, ToQKV_WrappedWithLoRA(module, ab_r))
_, o2_token_logits, o2_dot_product_logits = fm(**samples)
output_diff = 0.
for o1, o2 in list(zip(o1_dot_product_logits, o2_dot_product_logits)):
output_diff += ((o1 - o2) ** 2).sum()
if o1_token_logits is not None:
output_diff += ((o1_token_logits - o2_token_logits) ** 2).sum()
assert output_diff < 1e-5
return fm
@torch.no_grad()
def absorb_lora_and_recover_net_structure(self, fm: nn.Module, samples: dict):
fm.eval()
# print('absorb lora before')
for k, v in samples.items():
if isinstance(v, torch.Tensor):
samples[k] = v.to(get_model_device(fm))
print(k)
_, o1_token_logits, o1_dot_product_logits = fm(**samples)
for name, module in fm.named_modules():
if not isinstance(module, ToQKV_WrappedWithLoRA):
continue
fc = module.fc
ab = module.ab
fc.weight.add_(ab[1].weight @ ab[0].weight)
set_module(fm, name, fc)
# print('absorb lora after')
_, o2_token_logits, o2_dot_product_logits = fm(**samples)
output_diff = 0.
for o1, o2 in list(zip(o1_dot_product_logits, o2_dot_product_logits)):
output_diff += ((o1 - o2) ** 2).sum()
if o1_token_logits is not None:
output_diff += ((o1_token_logits - o2_token_logits) ** 2).sum()
assert output_diff < 1e-3, output_diff
return fm
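# Illustrative sketch of the intended LoRA life cycle with `FMLoRA_GLIP_Util`: add adapters,
# train only the LoRA (and conv) parameters, then absorb the adapters back into plain Linears.
# `fm` and `samples` follow the same assumptions as in the sketch above.
def _example_lora_round_trip(fm, samples):
    lora_util = FMLoRA_GLIP_Util()
    fm = lora_util.add_lora_ab_to_fm(fm, ab_r=16, samples=samples)
    trainable_params = lora_util.train_only_lora_and_conv(fm)  # pass these to the optimizer
    # ... LoRA fine-tuning loop would go here ...
    fm = lora_util.absorb_lora_and_recover_net_structure(fm, samples)
    return fm, trainable_params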
class ElasticDNN_OfflineMMDetFMModel(ElasticDNN_OfflineFMModel):
def __init__(self, name: str, models_dict_path: str, device: str, num_classes=10, collate_fn=None):
super().__init__(name, models_dict_path, device)
self.num_classes = num_classes
self.collate_fn = collate_fn
def get_accuracy(self, test_loader, *args, **kwargs):
# print('DeeplabV3: start test acc')
_d = test_loader.dataset
from data import build_dataloader
if _d.__class__.__name__ == 'MergedDataset':
# print('\neval on merged datasets')
datasets = _d.datasets
if self.collate_fn is None:
test_loaders = [build_dataloader(d, test_loader.batch_size, test_loader.num_workers, False, None, collate_fn=None) for d in datasets]
else:
test_loaders = [build_dataloader(d, test_loader.batch_size, test_loader.num_workers, False, None, collate_fn=self.collate_fn) for d in datasets]
accs = [self.get_accuracy(loader) for loader in test_loaders]
# print(accs)
return sum(accs) / len(accs)
# print('dataset len', len(test_loader.dataset))
model = self.models_dict['main']
device = self.device
model.eval()
# print('# classes', model.num_classes)
model = model.to(device)
from evaluator import COCOEvaluator, MMCOCODecoder
from utils.common.others import HiddenPrints
with torch.no_grad():
with HiddenPrints():
evaluator = COCOEvaluator(
dataloader=test_loader,
img_size=(416, 416),
confthre=0.01,
nmsthre=0.65,
num_classes=len(test_loader.dataset.classes),
testdev=True
)
res = evaluator.evaluate(model, False, False, decoder=MMCOCODecoder)
map50 = res[1]
# print('eval info', res[-1])
return map50
def infer(self, x, *args, **kwargs):
if len(args) > 0:
print(args, len(args))
return self.models_dict['main'](x, *args) # forward(x, label)
return self.models_dict['main'](**x)
class ElasticDNN_OfflineMMDetMDModel(ElasticDNN_OfflineMDModel):
def __init__(self, name: str, models_dict_path: str, device: str, num_classes=10, collate_fn=None):
super().__init__(name, models_dict_path, device)
self.num_classes = num_classes
self.collate_fn = collate_fn
def get_accuracy(self, test_loader, *args, **kwargs):
# print('DeeplabV3: start test acc')
_d = test_loader.dataset
from data import build_dataloader
if _d.__class__.__name__ == 'MergedDataset':
# print('\neval on merged datasets')
datasets = _d.datasets
if self.collate_fn is None:
test_loaders = [build_dataloader(d, test_loader.batch_size, test_loader.num_workers, False, None, collate_fn=None) for d in datasets]
else:
test_loaders = [build_dataloader(d, test_loader.batch_size, test_loader.num_workers, False, None, collate_fn=self.collate_fn) for d in datasets]
accs = [self.get_accuracy(loader) for loader in test_loaders]
# print(accs)
return sum(accs) / len(accs)
# print('dataset len', len(test_loader.dataset))
model = self.models_dict['main']
device = self.device
model.eval()
# print('# classes', model.num_classes)
model = model.to(device)
from evaluator import COCOEvaluator, MMCOCODecoder
from utils.common.others import HiddenPrints
with torch.no_grad():
with HiddenPrints():
evaluator = COCOEvaluator(
dataloader=test_loader,
img_size=(416, 416),
confthre=0.01,
nmsthre=0.65,
num_classes=len(test_loader.dataset.classes),
testdev=True
)
res = evaluator.evaluate(model, False, False, decoder=MMCOCODecoder)
map50 = res[1]
# print('eval info', res[-1])
return map50
def infer(self, x, *args, **kwargs):
if len(args) > 0:
return self.models_dict['main'](x, *args) # forward(x, label)
return self.models_dict['main'](**x)
class SqueezeLast(nn.Module):
def __init__(self):
super(SqueezeLast, self).__init__()
def forward(self, x):
return x.squeeze(-1)
class ProjConv_WrappedWithFBS(Layer_WrappedWithFBS):
def __init__(self, raw_conv2d: nn.Conv2d, r):
super(ProjConv_WrappedWithFBS, self).__init__()
self.fbs = nn.Sequential(
Abs(),
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(raw_conv2d.in_channels, raw_conv2d.out_channels // r),
nn.ReLU(),
nn.Linear(raw_conv2d.out_channels // r, raw_conv2d.out_channels),
nn.ReLU()
)
self.raw_conv2d = raw_conv2d
# self.raw_bn = raw_bn # remember clear the original BNs in the network
nn.init.constant_(self.fbs[5].bias, 1.)
nn.init.kaiming_normal_(self.fbs[5].weight)
def forward(self, x):
raw_x = self.raw_conv2d(x)
if self.use_cached_channel_attention and self.cached_channel_attention is not None:
channel_attention = self.cached_channel_attention
else:
self.cached_raw_channel_attention = self.fbs(x)
self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention)
channel_attention = self.cached_channel_attention
return raw_x * channel_attention.unsqueeze(2).unsqueeze(3)
class Linear_WrappedWithFBS(Layer_WrappedWithFBS):
def __init__(self, linear: nn.Linear, r):
super(Linear_WrappedWithFBS, self).__init__()
self.linear = linear
# for conv: (B, C_in, H, W) -> (B, C_in) -> (B, C_out)
# for mlp in ViT: (B, #patches, D: dim of patches embedding) -> (B, D) -> (B, C_out)
self.fbs = nn.Sequential(
Rearrange('b n d -> b d n'),
Abs(),
nn.AdaptiveAvgPool1d(1),
SqueezeLast(),
nn.Linear(linear.in_features, max(linear.out_features // r, 36)),
nn.ReLU(),
nn.Linear(max(linear.out_features // r, 36), linear.out_features),
nn.ReLU()
)
nn.init.constant_(self.fbs[6].bias, 1.)
nn.init.kaiming_normal_(self.fbs[6].weight)
def forward(self, x):
if self.use_cached_channel_attention and self.cached_channel_attention is not None:
channel_attention = self.cached_channel_attention
else:
self.cached_raw_channel_attention = self.fbs(x)
self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention)
channel_attention = self.cached_channel_attention
raw_res = self.linear(x)
return channel_attention.unsqueeze(1) * raw_res
class ToQKV_WrappedWithFBS(Layer_WrappedWithFBS):
"""
This regards to_q/to_k/to_v as a whole (in fact it consists of multiple heads) and prunes it.
It seems different channels of different heads are pruned according to the input.
This is different from "removing some head" or "removing the same channels in each head".
"""
def __init__(self, to_qkv: nn.Linear, r):
super(ToQKV_WrappedWithFBS, self).__init__()
# self.to_qkv = to_qkv
self.to_qk = nn.Linear(to_qkv.in_features, to_qkv.out_features // 3 * 2, bias=to_qkv.bias is not None)
self.to_v = nn.Linear(to_qkv.in_features, to_qkv.out_features // 3, bias=to_qkv.bias is not None)
self.to_qk.weight.data.copy_(to_qkv.weight.data[0: to_qkv.out_features // 3 * 2])
if to_qkv.bias is not None:
self.to_qk.bias.data.copy_(to_qkv.bias.data[0: to_qkv.out_features // 3 * 2])
self.to_v.weight.data.copy_(to_qkv.weight.data[to_qkv.out_features // 3 * 2: ])
if to_qkv.bias is not None:
self.to_v.bias.data.copy_(to_qkv.bias.data[to_qkv.out_features // 3 * 2: ])
self.fbs = nn.Sequential(
Rearrange('b n d -> b d n'),
Abs(),
nn.AdaptiveAvgPool1d(1),
SqueezeLast(),
nn.Linear(to_qkv.in_features, to_qkv.out_features // 3 // r),
nn.ReLU(),
# nn.Linear(to_qkv.out_features // 3 // r, to_qkv.out_features // 3),
nn.Linear(to_qkv.out_features // 3 // r, self.to_v.out_features),
nn.ReLU()
)
nn.init.constant_(self.fbs[6].bias, 1.)
nn.init.kaiming_normal_(self.fbs[6].weight)
def forward(self, x):
if self.use_cached_channel_attention and self.cached_channel_attention is not None:
channel_attention = self.cached_channel_attention
else:
self.cached_raw_channel_attention = self.fbs(x)
# print()
# for attn in self.cached_raw_channel_attention.chunk(3, dim=1)[0: 1]:
# print(self.cached_raw_channel_attention.size(), attn.size())
# print(self.k_takes_all.k)
# print(attn[0].nonzero(as_tuple=True)[0].size(), attn[0])
self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention)
# for attn in self.cached_channel_attention.chunk(3, dim=1)[0: 1]:
# print(self.cached_channel_attention.size(), attn.size())
# print(self.k_takes_all.k)
# print(attn[0].nonzero(as_tuple=True)[0].size(), attn[0])
# print()
channel_attention = self.cached_channel_attention
qk = self.to_qk(x)
v = channel_attention.unsqueeze(1) * self.to_v(x)
return torch.cat([qk, v], dim=-1)
# qkv = raw_res.chunk(3, dim = -1)
# raw_v = qkv[2]
# print('raw_k, raw_v', qkv[0].sum((0, 1))[0: 10], qkv[0].sum((0, 1)).nonzero(as_tuple=True)[0].size(),
# qkv[1].sum((0, 1))[0: 10], qkv[1].sum((0, 1)).nonzero(as_tuple=True)[0].size(),)
# print('raw_v', raw_v.size(), raw_v.sum((0, 1))[0: 10], raw_v.sum((0, 1)).nonzero(as_tuple=True)[0].size())
# qkv_attn = channel_attention.chunk(3, dim=-1)
# print('attn', [attn[0][0: 10] for attn in qkv_attn])
# print(channel_attention.unsqueeze(1).size(), raw_res.size())
# print('fbs', channel_attention.size(), raw_res.size())
# return channel_attention.unsqueeze(1) * raw_res
class StaticFBS(nn.Module):
def __init__(self, static_channel_attention):
super(StaticFBS, self).__init__()
assert static_channel_attention.dim() == 2 and static_channel_attention.size(0) == 1
self.static_channel_attention = nn.Parameter(static_channel_attention, requires_grad=False) # (1, dim)
def forward(self, x):
# print('staticfbs', x, self.static_channel_attention.unsqueeze(1))
return x * self.static_channel_attention.unsqueeze(1)
class ElasticGLIPUtil(ElasticDNNUtil):
def convert_raw_dnn_to_master_dnn(self, raw_dnn: nn.Module, r: float, ignore_layers=[]):
assert len(ignore_layers) == 0, 'not supported yet'
raw_vit = deepcopy(raw_dnn)
for name, module in raw_vit.named_modules():
# if name.endswith('patch_embed'):
# set_module(module, 'proj', ProjConv_WrappedWithFBS(module.proj, r))
# if name.endswith('attn') and not name.endswith('b_attn.attn') and not name.endswith('b_attn'):
# set_module(module, 'qkv', ToQKV_WrappedWithFBS(module.qkv, r))
if name.endswith('intermediate'):
set_module(module, 'dense', Linear_WrappedWithFBS(module.dense, r))
elif name.endswith('mlp'):
set_module(module, 'fc1', Linear_WrappedWithFBS(module.fc1, r))
return raw_vit
def set_master_dnn_sparsity(self, master_dnn: nn.Module, sparsity: float):
# for name, module in master_dnn.named_modules():
# if not name.endswith('attn'):
# continue
# q_features = module.qkv.to_qk.out_features // 2
# if (q_features - int(q_features * sparsity)) % module.num_heads != 0:
# # tune sparsity to ensure #unpruned channel % num_heads == 0
# # so that the pruning seems to reduce the dim_head of each head
# tuned_sparsity = 1. - int((q_features - int(q_features * sparsity)) / module.num_heads) * module.num_heads / q_features
# logger.debug(f'tune sparsity from {sparsity:.2f} to {tuned_sparsity}')
# sparsity = tuned_sparsity
# break
return super().set_master_dnn_sparsity(master_dnn, sparsity)
def select_most_rep_sample(self, master_dnn: nn.Module, samples: torch.Tensor):
# print(samples)
sample={}
sample['images'] = [samples['images'][0]]
sample['targets'] = [samples['targets'][0]]
# return samples[0].unsqueeze(0)
# res = {k: v[0: 1] for k, v in samples.items()}
return sample
    def extract_surrogate_dnn_via_samples(self, master_dnn: nn.Module, samples: torch.Tensor, return_detail=False):  # steps to generate the surrogate (small) model
sample = self.select_most_rep_sample(master_dnn, samples)
# assert sample.dim() == 4 and sample.size(0) == 1
# print('before')
master_dnn.eval()
self.clear_cached_channel_attention_in_master_dnn(master_dnn)
with torch.no_grad():
_, o1_token_logits, o1_dot_product_logits = master_dnn(**sample)
# print('after')
boosted_vit = deepcopy(master_dnn)
def get_unpruned_indexes_from_channel_attn(channel_attn: torch.Tensor, k):
assert channel_attn.size(0) == 1, 'use A representative sample to generate channel attentions'
# print('attn_in_unpruned', channel_attn[0][0: 10])
res = channel_attn[0].nonzero(as_tuple=True)[0] # should be one-dim
# res = channel_attn[0].argsort(descending=True)[0: -int(channel_attn.size(1) * k)].sort()[0]
# g = channel_attn
# k = g.size(1) - int(g.size(1) * k)
# res = g.topk(k, 1)[1][0].sort()[0]
return res
unpruned_indexes_of_layers = {}
# for attn, ff in boosted_vit.transformer.layers:
# for block_i, block in enumerate(boosted_vit.blocks):
for layer_i, layer in enumerate(boosted_vit.model.backbone.body.layers):
for block_i, block in enumerate(layer.blocks):
# attn = block.attn
# ff = block.mlp
ff_0 = get_module(block, f'mlp.fc1')
# ff_0_unpruned_indexes = get_unpruned_indexes_from_channel_attn(ff_0.cached_channel_attention, k)
ff_0_pruned_indexes = ff_0.k_takes_all.cached_i[0].sort()[0]
ff_0_unpruned_indexes = torch.LongTensor([ii for ii in range(ff_0.cached_channel_attention.size(1)) if ii not in ff_0_pruned_indexes])
new_ff_0 = nn.Linear(ff_0.linear.in_features, ff_0_unpruned_indexes.size(0), ff_0.linear.bias is not None)
new_ff_0.weight.data.copy_(ff_0.linear.weight.data[ff_0_unpruned_indexes])
if ff_0.linear.bias is not None:
new_ff_0.bias.data.copy_(ff_0.linear.bias.data[ff_0_unpruned_indexes])
set_module(block, 'mlp.fc1', nn.Sequential(new_ff_0, StaticFBS(ff_0.cached_channel_attention[:, ff_0_unpruned_indexes])))
ff_1 = get_module(block, f'mlp.fc2')
new_ff_1 = nn.Linear(ff_0_unpruned_indexes.size(0), ff_1.out_features, ff_1.bias is not None)
new_ff_1.weight.data.copy_(ff_1.weight.data[:, ff_0_unpruned_indexes])
if ff_1.bias is not None:
new_ff_1.bias.data.copy_(ff_1.bias.data)
set_module(block, 'mlp.fc2', new_ff_1)
unpruned_indexes_of_layers[f'model.backbone.body.layers.{layer_i}.blocks.{block_i}.mlp.fc1.0.weight'] = ff_0_unpruned_indexes
# for block_i,block in enumerate(boosted_vit.vision_model.encoder.layers):
# attn = block.self_attn
# ff = block.mlp
# ff_0 = ff.fc1
# # ff_0_unpruned_indexes = get_unpruned_indexes_from_channel_attn(ff_0.cached_channel_attention, k)
# ff_0_pruned_indexes = ff_0.k_takes_all.cached_i[0].sort()[0]
# ff_0_unpruned_indexes = torch.LongTensor([ii for ii in range(ff_0.cached_channel_attention.size(1)) if ii not in ff_0_pruned_indexes])
# new_ff_0 = nn.Linear(ff_0.linear.in_features, ff_0_unpruned_indexes.size(0), ff_0.linear.bias is not None)
# new_ff_0.weight.data.copy_(ff_0.linear.weight.data[ff_0_unpruned_indexes])
# if ff_0.linear.bias is not None:
# new_ff_0.bias.data.copy_(ff_0.linear.bias.data[ff_0_unpruned_indexes])
# set_module(ff, 'fc1', nn.Sequential(new_ff_0, StaticFBS(ff_0.cached_channel_attention[:, ff_0_unpruned_indexes])))
# ff_1 = ff.fc2
# new_ff_1 = nn.Linear(ff_0_unpruned_indexes.size(0), ff_1.out_features, ff_1.bias is not None)
# new_ff_1.weight.data.copy_(ff_1.weight.data[:, ff_0_unpruned_indexes])
# if ff_1.bias is not None:
# new_ff_1.bias.data.copy_(ff_1.bias.data)
# set_module(ff, 'fc2', new_ff_1)
# unpruned_indexes_of_layers[f'vision_model.encoder.layers.{block_i}.mlp.fc1.0.weight'] = ff_0_unpruned_indexes
# for block_i, block in enumerate(boosted_vit.text_decoder.bert.encoder.layer):
# # attn = block.attn
# # ff = block.mlp
# ff_0 = get_module(block, f'intermediate.dense')
# # ff_0_unpruned_indexes = get_unpruned_indexes_from_channel_attn(ff_0.cached_channel_attention, k)
# ff_0_pruned_indexes = ff_0.k_takes_all.cached_i[0].sort()[0]
# ff_0_unpruned_indexes = torch.LongTensor([ii for ii in range(ff_0.cached_channel_attention.size(1)) if ii not in ff_0_pruned_indexes])
# new_ff_0 = nn.Linear(ff_0.linear.in_features, ff_0_unpruned_indexes.size(0), ff_0.linear.bias is not None)
# new_ff_0.weight.data.copy_(ff_0.linear.weight.data[ff_0_unpruned_indexes])
# if ff_0.linear.bias is not None:
# new_ff_0.bias.data.copy_(ff_0.linear.bias.data[ff_0_unpruned_indexes])
# set_module(block, 'intermediate.dense', nn.Sequential(new_ff_0, StaticFBS(ff_0.cached_channel_attention[:, ff_0_unpruned_indexes])))
# ff_1 = get_module(block, f'output.dense')
# new_ff_1 = nn.Linear(ff_0_unpruned_indexes.size(0), ff_1.out_features, ff_1.bias is not None)
# new_ff_1.weight.data.copy_(ff_1.weight.data[:, ff_0_unpruned_indexes])
# if ff_1.bias is not None:
# new_ff_1.bias.data.copy_(ff_1.bias.data)
# set_module(block, 'output.dense', new_ff_1)
# unpruned_indexes_of_layers[f'text_decoder.bert.encoder.layer.{block_i}.intermediate.dense.0.weight'] = ff_0_unpruned_indexes
surrogate_dnn = boosted_vit
surrogate_dnn.eval()
surrogate_dnn = surrogate_dnn.to(get_model_device(master_dnn))
# logger.debug(surrogate_dnn)
with torch.no_grad():
_, o2_token_logits, o2_dot_product_logits = surrogate_dnn(**sample)
output_diff = 0.
for o1, o2 in zip(o1_dot_product_logits, o2_dot_product_logits):
output_diff += ((o1 - o2) ** 2).sum()
if o1_token_logits is not None:
output_diff += ((o1_token_logits - o2_token_logits) ** 2).sum()
# assert output_diff < 1e-4, output_diff
logger.info(f'output diff of master and surrogate DNN: {output_diff}')
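# Sanity check: the surrogate is run on the same sample and compared to the master via a
# summed squared difference over the per-image dot-product logits (plus token logits when
# present); the value should stay close to zero because only channels whose gated
# attention was already zero were removed.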
# logger.debug(f'example output of master/surrogate: {master_dnn_output.sum(0)[0: 10]}, {surrogate_dnn_output.sum(0)[0: 10]}')
# logger.info(f'\nonly prune mlp!!!!\n')
if return_detail:
return boosted_vit, unpruned_indexes_of_layers
return boosted_vit
def extract_surrogate_dnn_via_samples_with_perf_test(self, master_dnn: nn.Module, samples, return_detail=False):
master_dnn_size = get_model_size(master_dnn, True)
sample = {}
sample['images'] = [samples['images'][0]]
sample['targets'] = [samples['targets'][0]]
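# Latency is profiled on a single (image, target) pair so the reported seconds/sample
# number does not depend on how many samples the caller packed into `samples`.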
master_dnn_latency = self._get_model_latency(master_dnn, sample, 50,
get_model_device(master_dnn), 50, False)
res = self.extract_surrogate_dnn_via_samples(master_dnn, samples, return_detail)
if not return_detail:
surrogate_dnn = res
else:
surrogate_dnn, unpruned_indexes_of_layers = res
surrogate_dnn_size = get_model_size(surrogate_dnn, True)
surrogate_dnn_latency = self._get_model_latency(surrogate_dnn, samples, 50,
get_model_device(master_dnn), 50, False)
logger.info(f'master DNN ({master_dnn_size:.3f}MB, {master_dnn_latency:.4f}s/sample) -> '
f'surrogate DNN ({surrogate_dnn_size:.3f}MB, {surrogate_dnn_latency:.4f}s/sample)\n'
f'(model size: ↓ {(master_dnn_size / surrogate_dnn_size):.2f}x, '
f'latency: ↓ {(master_dnn_latency / surrogate_dnn_latency):.2f}x)')
return res
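# Minimal usage sketch (illustrative names only: `elastic_util` stands for an instance of
# this util class and `raw_items` for a list of (img, target) pairs; neither is defined here):
#   batch, _ = collect_mm_fn(raw_items)
#   surrogate, kept_indexes = elastic_util.extract_surrogate_dnn_via_samples_with_perf_test(
#       master_dnn, batch, return_detail=True)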
def _get_model_latency(self, model: torch.nn.Module, sample, sample_num: int,
device: str, warmup_sample_num: int, return_detail=False):
import time
model = model.to(device)
model.eval()
# keep a single (image, target) pair for timing without mutating the caller's dict
sample = dict(sample, images=[sample['images'][0]], targets=[sample['targets'][0]])
# warm up
with torch.no_grad():
for _ in range(warmup_sample_num):
model(**sample)
infer_time_list = []
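# On GPU, timing uses CUDA events plus an explicit synchronize so asynchronous kernel
# launches are fully accounted for; on CPU a plain wall-clock measurement suffices.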
if 'cuda' in str(device):
with torch.no_grad():
for _ in range(sample_num):
s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
s.record()
model(**sample)
e.record()
torch.cuda.synchronize()
cur_model_infer_time = s.elapsed_time(e) / 1000.
infer_time_list += [cur_model_infer_time]
else:
with torch.no_grad():
for _ in range(sample_num):
start = time.time()
model(**sample)
cur_model_infer_time = time.time() - start
infer_time_list += [cur_model_infer_time]
avg_infer_time = sum(infer_time_list) / sample_num
if return_detail:
return avg_infer_time, infer_time_list
return avg_infer_time
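# Note: Event.elapsed_time returns milliseconds (hence the division by 1000 above), and
# the reported latency is the mean over `sample_num` timed forward passes.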
# from typing import Any, Dict
# from schema import Schema, Or
# import schema
# from data import Scenario, MergedDataset
# from methods.base.alg import BaseAlg
# from data import build_dataloader
# from ..model import ElasticDNN_OfflineFMModel, ElasticDNN_OfflineMDModel
# from ...model.base import ElasticDNNUtil
# import torch.optim
# import tqdm
# import torch.nn.functional as F
# from torch import nn
# from utils.dl.common.env import create_tbwriter
# import os
# import random
# import numpy as np
# from copy import deepcopy
# from utils.dl.common.model import LayerActivation2, get_module
# from utils.common.log import logger
# class ElasticDNN_Det_MDPretrainingWoFBSAlg(BaseAlg):
# """
# TODO: fine-tuned FM -> init MD -> trained MD -> construct indexes (only between similar weights) and fine-tune
# """
# def get_required_models_schema(self) -> Schema:
# return Schema({
# 'fm': ElasticDNN_OfflineFMModel,
# 'md': ElasticDNN_OfflineMDModel
# })
# def get_required_hyp_schema(self) -> Schema:
# return Schema({
# 'launch_tbboard': bool,
# 'samples_size': any,
# 'generate_md_width_ratio': int,
# 'train_batch_size': int,
# 'val_batch_size': int,
# 'num_workers': int,
# 'optimizer': str,
# 'optimizer_args': dict,
# 'scheduler': str,
# 'scheduler_args': dict,
# 'num_iters': int,
# 'val_freq': int,
# 'distill_loss_weight': float
# })
# def run(self, scenario: Scenario, hyps: Dict) -> Dict[str, Any]:
# super().run(scenario, hyps)
# assert isinstance(self.models['md'], ElasticDNN_OfflineMDModel) # for auto completion
# assert isinstance(self.models['fm'], ElasticDNN_OfflineFMModel) # for auto completion
# # 1. add FBS
# device = self.models['md'].device
# if self.models['md'].models_dict['main'] == -1:
# logger.info(f'init master DNN by reducing width of an adapted foundation model (already tuned by LoRA)...')
# before_fm_model = deepcopy(self.models['fm'].models_dict['main'])
# lora_util = self.models['fm'].get_lora_util()
# sample = hyps['samples_size']
# if isinstance(sample, (tuple, list)) and isinstance(sample[0], int):
# sample = torch.rand(hyps['samples_size']).to(device)
# lora_absorbed_fm_model = lora_util.absorb_lora_and_recover_net_structure(self.models['fm'].models_dict['main'],
# sample)
# self.models['fm'].models_dict['main'] = lora_absorbed_fm_model
# master_dnn = self.models['fm'].generate_md_by_reducing_width(hyps['generate_md_width_ratio'],
# sample)
# self.models['fm'].models_dict['main'] = before_fm_model
# self.models['md'].models_dict['main'] = master_dnn
# self.models['md'].to(device)
# # 2. train (knowledge distillation, index relationship)
# offline_datasets = scenario.get_offline_datasets()
# train_dataset = MergedDataset([d['train'] for d in offline_datasets.values()])
# val_dataset = MergedDataset([d['val'] for d in offline_datasets.values()])
# train_loader = iter(build_dataloader(train_dataset, hyps['train_batch_size'], hyps['num_workers'],
# True, None))
# val_loader = build_dataloader(val_dataset, hyps['val_batch_size'], hyps['num_workers'],
# False, False)
# # logger.info(f'FM acc: {self.models["fm"].get_accuracy(val_loader):.4f}')
# # 2.1 train whole master DNN (knowledge distillation)
# for p in master_dnn.parameters():
# p.requires_grad = True
# self.models['md'].to_train_mode()
# optimizer = torch.optim.__dict__[hyps['optimizer']]([
# {'params': self.models['md'].models_dict['main'].parameters(), **hyps['optimizer_args']}
# ])
# scheduler = torch.optim.lr_scheduler.__dict__[hyps['scheduler']](optimizer, **hyps['scheduler_args'])
# tb_writer = create_tbwriter(os.path.join(self.res_save_dir, 'tb_log'), launch_tbboard=hyps['launch_tbboard'])
# pbar = tqdm.tqdm(range(hyps['num_iters']), dynamic_ncols=True)
# best_avg_val_acc = 0.
# md_output_hook = None
# for iter_index in pbar:
# self.models['md'].to_train_mode()
# self.models['fm'].to_eval_mode()
# # rand_sparsity = random.random() * (hyps['max_sparsity'] - hyps['min_sparsity']) + hyps['min_sparsity']
# # elastic_dnn_util.set_master_dnn_sparsity(self.models['md'].models_dict['main'], rand_sparsity)
# if md_output_hook is None:
# md_output_hook = self.models['md'].get_feature_hook()
# fm_output_hook = self.models['fm'].get_feature_hook()
# x, y = next(train_loader)
# if isinstance(x, dict):
# for k, v in x.items():
# if isinstance(v, torch.Tensor):
# x[k] = v.to(device)
# y = y.to(device)
# else:
# x, y = x.to(device), y.to(device)
# with torch.no_grad():
# fm_output = self.models['fm'].infer(x)
# task_loss = self.models['md'].forward_to_get_task_loss(x, y)
# md_output = md_output_hook.output
# fm_output = fm_output_hook.output
# distill_loss = hyps['distill_loss_weight'] * self.models['md'].get_distill_loss(md_output, fm_output)
# total_loss = task_loss + distill_loss
# optimizer.zero_grad()
# total_loss.backward()
# optimizer.step()
# scheduler.step()
# if (iter_index + 1) % hyps['val_freq'] == 0:
# # elastic_dnn_util.clear_cached_channel_attention_in_master_dnn(self.models['md'].models_dict['main'])
# md_output_hook.remove()
# md_output_hook = None
# fm_output_hook.remove()
# fm_output_hook = None
# cur_md = self.models['md'].models_dict['main']
# md_for_test = deepcopy(self.models['md'].models_dict['main'])
# val_acc = 0.
# self.models['md'].models_dict['main'] = md_for_test
# self.models['md'].to_eval_mode()
# val_acc = self.models['md'].get_accuracy(val_loader)
# self.models['md'].models_dict['main'] = cur_md
# self.models['md'].save_model(os.path.join(self.res_save_dir, 'models/md_last.pt'))
# self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_last.pt'))
# if val_acc > best_avg_val_acc:
# best_avg_val_acc = val_acc
# self.models['md'].save_model(os.path.join(self.res_save_dir, 'models/md_best.pt'))
# self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_best.pt'))
# tb_writer.add_scalars(f'losses', dict(task=task_loss, distill=distill_loss, total=total_loss), iter_index)
# pbar.set_description(f'loss: {total_loss:.6f}')
# if (iter_index + 1) >= hyps['val_freq']:
# tb_writer.add_scalar(f'accs/val_acc', val_acc, iter_index)
# pbar.set_description(f'loss: {total_loss:.6f}, val_acc: {val_acc:.4f}')
# if __name__ == '__main__':
# model = glip_model('new_impl/cv/glip/object_detection/pretrained_model/glip_Swin_T_O365_GoldG.yaml','new_impl/cv/glip/object_detection/pretrained_model/glip_tiny_model_o365_goldg_cc_sbu.pth').cuda()
# model.eval()
# # print(model)
# # exit()
# # config = CLIPConfig.from_pretrained('openai/clip-vit-base-patch16')
# # print(config)
# # # test 1: single image inference
# from PIL import Image, ImageDraw
# import requests
# import numpy as np
# ori_image = Image.open('new_impl/cv/glip/object_detection/9472793441_b7822c00de_z.jpg').convert("RGB")
# image = [np.asarray(ori_image)[:, :, [2, 1, 0]]]
# text = 'sofa . remote . dog . person . car . sky . plane .'
# target = torch.Tensor()
# o = model(image, text)
# o = model._post_process(o[0])
# print(o)
# bboxes = o.bbox.cpu()
# a = ImageDraw.ImageDraw(ori_image)
# for box in bboxes:
# box = box.int()
# a.rectangle(((box[0], box[1]), (box[2], box[3])), fill=None, outline='red', width=2)
# ori_image.save('test.jpg')
# # print(o.logits_per_image.softmax(dim=1))
# # o = model(image, torch.load('dnns/clip/test_input_embed.pth'), False)
# # # print(o)
# # print(o.logits_per_image.softmax(dim=1))
# # exit()
# # test 2: normal training using clip loss (batch)
# from data import get_dataset, build_dataloader
# from torchvision.transforms import Compose, ToTensor, Resize
# dataset = get_dataset('Caltech256', '/data/zql/datasets/Caltech-256/data/caltech256/256_ObjectCategories/', 'train', transform=Compose([
# Resize((32, 32)), ToTensor()
# ]))
# dataloader = build_dataloader(dataset, 8, 0, True, None)
# from PIL import Image
# import requests
# images, labels = next(iter(dataloader))
# # torch.save(images, 'dnns/clip/test_image.pth')
# classes = dataset.classes
# text = [f"a photo of a {classes[i]}" for i in labels] # should be ground truth
# print(text)
# print(images.size())
# o = model(images, text, True)
# print(o)
# print(o.logits_per_image.softmax(dim=1))
# # o = model(image, torch.load('dnns/clip/test_input_embed.pth'), False)
# # # print(o)
# # print(o.logits_per_image.softmax(dim=1))