import json
import logging
import math
import os
from pathlib import Path
from typing import List
from urllib.parse import urlparse

import torch
from models.swin_transformer import interpolate_relative_pos_embed
from models.vit import VisionTransformer, interpolate_pos_embed
from timm.models.hub import download_cached_file
from torch import nn
from transformers import BertTokenizer

logger = logging.getLogger(__name__)

CONFIG_PATH = Path(__file__).resolve().parents[1]


def read_json(rpath):
    with open(rpath) as f:
        return json.load(f)


def tie_encoder_decoder_weights(
    encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key: str
):
    uninitialized_encoder_weights: List[str] = []
    if decoder.__class__ != encoder.__class__:
        logger.info(
            f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
        )

    def tie_encoder_to_decoder_recursively(
        decoder_pointer: nn.Module,
        encoder_pointer: nn.Module,
        module_name: str,
        uninitialized_encoder_weights: List[str],
        skip_key: str,
        depth=0,
    ):
        assert isinstance(decoder_pointer, nn.Module) and isinstance(
            encoder_pointer, nn.Module
        ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
        if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
            assert hasattr(encoder_pointer, "weight")
            encoder_pointer.weight = decoder_pointer.weight
            if hasattr(decoder_pointer, "bias"):
                assert hasattr(encoder_pointer, "bias")
                encoder_pointer.bias = decoder_pointer.bias
            print(module_name + " is tied")
            return

        encoder_modules = encoder_pointer._modules
        decoder_modules = decoder_pointer._modules
        if len(decoder_modules) > 0:
            assert (
                len(encoder_modules) > 0
            ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"

            all_encoder_weights = {
                module_name + "/" + sub_name for sub_name in encoder_modules.keys()
            }
            encoder_layer_pos = 0
            for name, module in decoder_modules.items():
                if name.isdigit():
                    encoder_name = str(int(name) + encoder_layer_pos)
                    decoder_name = name
                    if not isinstance(
                        decoder_modules[decoder_name],
                        type(encoder_modules[encoder_name]),
                    ) and len(encoder_modules) != len(decoder_modules):
                        # this can happen when the name is a position in a list of
                        # layers (nn.ModuleList); in that case the decoder has added a
                        # cross-attention layer that the encoder does not have, so skip
                        # this step and subtract one from the encoder layer position
                        encoder_layer_pos -= 1
                        continue
                elif name not in encoder_modules:
                    continue
                elif depth > 500:
                    raise ValueError(
                        "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
                    )
                else:
                    decoder_name = encoder_name = name
                tie_encoder_to_decoder_recursively(
                    decoder_modules[decoder_name],
                    encoder_modules[encoder_name],
                    module_name + "/" + name,
                    uninitialized_encoder_weights,
                    skip_key,
                    depth=depth + 1,
                )
                all_encoder_weights.remove(module_name + "/" + encoder_name)

            uninitialized_encoder_weights += list(all_encoder_weights)

    # tie weights recursively
    tie_encoder_to_decoder_recursively(
        decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key
    )
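

# Illustrative sketch (not called in this module): tying a BERT decoder's
# weights to a BERT encoder, as BLIP-style models do with their text towers.
# The configs, model classes, and skip_key value here are assumptions for the
# example, not requirements of tie_encoder_decoder_weights itself.
def _example_tie_encoder_decoder():
    from transformers import BertConfig, BertLMHeadModel, BertModel

    config = BertConfig()  # bert-base-sized config, randomly initialized
    encoder = BertModel(config)
    dec_config = BertConfig(is_decoder=True, add_cross_attention=True)
    decoder = BertLMHeadModel(dec_config)
    # share every weight whose module path does not contain skip_key; with
    # "/attention", the attention blocks stay decoder-specific, and the
    # decoder-only cross-attention layers are skipped automatically because
    # the encoder has no matching submodules
    tie_encoder_decoder_weights(encoder, decoder.bert, "", skip_key="/attention")
    return encoder, decoder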


class GroupWiseLinear(nn.Module):
    # could be changed to:
    # output = torch.einsum('ijk,zjk->ij', x, self.W)
    # or output = torch.einsum('ijk,jk->ij', x, self.W[0])
    def __init__(self, num_class, hidden_dim, bias=True):
        super().__init__()
        self.num_class = num_class
        self.hidden_dim = hidden_dim
        self.bias = bias

        self.W = nn.Parameter(torch.Tensor(1, num_class, hidden_dim))
        if bias:
            self.b = nn.Parameter(torch.Tensor(1, num_class))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.W.size(2))
        for i in range(self.num_class):
            self.W[0][i].data.uniform_(-stdv, stdv)
        if self.bias:
            for i in range(self.num_class):
                self.b[0][i].data.uniform_(-stdv, stdv)

    def forward(self, x):
        # x: B,K,d
        x = (self.W * x).sum(-1)
        if self.bias:
            x = x + self.b
        return x
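

# Illustrative sketch: GroupWiseLinear scores one embedding per class with its
# own weight vector, e.g. in a multi-label tagging head. The shapes below are
# arbitrary example values, not anything this module prescribes.
def _example_groupwise_linear():
    head = GroupWiseLinear(num_class=10, hidden_dim=512)
    x = torch.randn(2, 10, 512)  # (batch, num_class, hidden_dim)
    logits = head(x)  # (2, 10): one logit per class for each sample
    return logits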


def init_tokenizer():
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    tokenizer.add_special_tokens({"bos_token": "[DEC]"})
    tokenizer.add_special_tokens({"additional_special_tokens": ["[ENC]"]})
    tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
    return tokenizer
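

# Illustrative sketch: in BLIP-style models the [DEC] token replaces [CLS] at
# the start of a caption for decoding, while enc_token_id marks encode mode.
# The example sentence is arbitrary.
def _example_tokenizer_usage():
    tokenizer = init_tokenizer()
    input_ids = tokenizer("a photo of a cat", return_tensors="pt").input_ids
    input_ids[:, 0] = tokenizer.bos_token_id  # start generation from [DEC]
    enc_ids = input_ids.clone()
    enc_ids[:, 0] = tokenizer.enc_token_id  # or tag the sequence as [ENC]
    return input_ids, enc_ids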


def create_vit(
    vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0
):
    assert vit in ["base", "large"], "vit parameter must be base or large"
    if vit == "base":
        vision_width = 768
        visual_encoder = VisionTransformer(
            img_size=image_size,
            patch_size=16,
            embed_dim=vision_width,
            depth=12,
            num_heads=12,
            use_grad_checkpointing=use_grad_checkpointing,
            ckpt_layer=ckpt_layer,
            drop_path_rate=drop_path_rate,
        )
    elif vit == "large":
        vision_width = 1024
        visual_encoder = VisionTransformer(
            img_size=image_size,
            patch_size=16,
            embed_dim=vision_width,
            depth=24,
            num_heads=16,
            use_grad_checkpointing=use_grad_checkpointing,
            ckpt_layer=ckpt_layer,
            # note: `0.1 or drop_path_rate` always evaluates to 0.1, so the
            # large model ignores the drop_path_rate argument
            drop_path_rate=0.1 or drop_path_rate,
        )
    return visual_encoder, vision_width
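

# Illustrative sketch: building the ViT-B/16 backbone at 224x224 resolution.
# The expected output shape assumes models.vit.VisionTransformer returns all
# patch tokens plus the [CLS] token, as in BLIP.
def _example_create_vit():
    visual_encoder, vision_width = create_vit("base", image_size=224)
    images = torch.randn(1, 3, 224, 224)
    feats = visual_encoder(images)  # expected (1, 197, 768): 14*14 patches + [CLS]
    return feats, vision_width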


def is_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")


def load_checkpoint(model, url_or_filename):
    if is_url(url_or_filename):
        cached_file = download_cached_file(
            url_or_filename, check_hash=False, progress=True
        )
        checkpoint = torch.load(cached_file, map_location="cpu")
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location="cpu")
    else:
        raise RuntimeError("checkpoint url or path is invalid")

    state_dict = checkpoint["model"]
    # resize the position embeddings if the checkpoint was trained at a
    # different image resolution than the current model
    state_dict["visual_encoder.pos_embed"] = interpolate_pos_embed(
        state_dict["visual_encoder.pos_embed"], model.visual_encoder
    )
    if "visual_encoder_m.pos_embed" in model.state_dict().keys():
        state_dict["visual_encoder_m.pos_embed"] = interpolate_pos_embed(
            state_dict["visual_encoder_m.pos_embed"], model.visual_encoder_m
        )
    # drop checkpoint tensors whose shape no longer matches the model
    for key in model.state_dict().keys():
        if key in state_dict.keys():
            if state_dict[key].shape != model.state_dict()[key].shape:
                del state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    print("load checkpoint from %s" % url_or_filename)
    return model, msg
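

# Illustrative sketch: loading weights into a model. The URL below is a
# placeholder, not a real release asset; substitute your own checkpoint
# path or URL.
def _example_load_checkpoint(model):
    ckpt = "https://example.com/model_base.pth"  # placeholder
    model, msg = load_checkpoint(model, ckpt)
    print(msg.missing_keys)  # parameters the checkpoint did not provide
    return model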


def load_checkpoint_swinbase(model, url_or_filename, kwargs):
    if kwargs["image_size"] == 224:
        vision_config_path = f"{CONFIG_PATH}/configs/swin/config_swinB_224.json"
    elif kwargs["image_size"] == 384:
        vision_config_path = f"{CONFIG_PATH}/configs/swin/config_swinB_384.json"
    else:
        raise ValueError(
            f"Unsupported image_size {kwargs['image_size']} for SwinB; expected 224 or 384"
        )
    window_size = read_json(vision_config_path)["window_size"]
    print("--------------")
    print(url_or_filename)
    print("--------------")
    if is_url(url_or_filename):
        cached_file = download_cached_file(
            url_or_filename, check_hash=False, progress=True
        )
        checkpoint = torch.load(cached_file, map_location="cpu")
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location="cpu")
    else:
        raise RuntimeError("checkpoint url or path is invalid")

    state_dict = checkpoint["model"]
    for k in list(state_dict.keys()):
        if "relative_position_bias_table" in k:
            # resize the relative position bias table to the target window size
            dst_num_pos = (2 * window_size - 1) ** 2
            state_dict[k] = interpolate_relative_pos_embed(
                state_dict[k], dst_num_pos, param_name=k
            )
        elif ("relative_position_index" in k) or ("attn_mask" in k):
            # these buffers are recomputed by the model, so drop them
            del state_dict[k]
        elif "vision_multi" in k:
            state_dict[k.replace("vision_multi", "tagging_head")] = state_dict.pop(k)

    msg = model.load_state_dict(state_dict, strict=False)
    print("load checkpoint from %s" % url_or_filename)
    return model, msg