"""Factory and visualization helpers for GEM (Grounding Everything Module) models built on open_clip."""
import logging
import os
from typing import Any, Union, List, Optional, Tuple, Dict

import cv2 as cv2
import matplotlib.pyplot as plt
import numpy as np
import open_clip
import torch
from open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from PIL import Image
from torchvision import transforms

from .gem_wrapper import GEMWrapper
_MODELS = { | |
# B/32 | |
"ViT-B/32": [ | |
"openai", | |
"laion400m_e31", | |
"laion400m_e32", | |
"laion2b_e16", | |
"laion2b_s34b_b79k", | |
], | |
"ViT-B/32-quickgelu": [ | |
"metaclip_400m", | |
"metaclip_fullcc" | |
], | |
# B/16 | |
"ViT-B/16": [ | |
"openai", | |
"laion400m_e31", | |
"laion400m_e32", | |
"laion2b_s34b_b88k", | |
], | |
"ViT-B/16-quickgelu": [ | |
"metaclip_400m", | |
"metaclip_fullcc", | |
], | |
"ViT-B/16-plus-240": [ | |
"laion400m_e31", | |
"laion400m_e32" | |
], | |
# L/14 | |
"ViT-L/14": [ | |
"openai", | |
"laion400m_e31", | |
"laion400m_e32", | |
"laion2b_s32b_b82k", | |
], | |
"ViT-L/14-quickgelu": [ | |
"metaclip_400m", | |
"metaclip_fullcc" | |
], | |
"ViT-L/14-336": [ | |
"openai", | |
] | |
} | |
def available_models() -> str:
    """Return a human-readable table of the available GEM-VL models.

    Each line pairs an architecture name (left-padded to a 20-character
    column) with one of its pretraining tags, e.g.::

        ViT-B/32            : openai

    Returns:
        A newline-terminated string with one ``<architecture>: <pretrained>``
        entry per model/tag combination.  (Despite the historical ``List[str]``
        annotation this function has always returned a formatted string; the
        annotation is corrected here.)
    """
    # str.ljust reproduces the old `key + " " * (20 - len(key))` padding,
    # including the no-op for keys that are already 20+ characters wide.
    return "".join(
        f"{key.ljust(20)}: {value}\n"
        for key, values in _MODELS.items()
        for value in values
    )
def get_tokenizer(
        model_name: str = '',
        context_length: Optional[int] = None,
        **kwargs,
):
    """Thin passthrough to :func:`open_clip.get_tokenizer`.

    Args:
        model_name: open_clip model whose tokenizer should be loaded.
        context_length: Optional override of the tokenizer context length.
        **kwargs: Forwarded verbatim to ``open_clip.get_tokenizer``.
    """
    return open_clip.get_tokenizer(
        model_name=model_name,
        context_length=context_length,
        **kwargs,
    )
def get_gem_img_transform(
        img_size: Union[int, Tuple[int, int]] = (448, 448),
        mean: Optional[Tuple[float, ...]] = None,
        std: Optional[Tuple[float, ...]] = None,
):
    """Build the torchvision preprocessing pipeline for GEM image inputs.

    Resizes with bicubic interpolation, converts to a tensor, and normalizes.
    When ``mean``/``std`` are not supplied (or falsy), the OpenAI CLIP dataset
    statistics are used.

    Returns:
        A ``transforms.Compose`` callable mapping a PIL image to a tensor.
    """
    # `or`-style fallback kept on purpose: any falsy value selects the default.
    mean = OPENAI_DATASET_MEAN if not mean else mean
    std = OPENAI_DATASET_STD if not std else std
    pipeline = [
        transforms.Resize(size=img_size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    return transforms.Compose(pipeline)
def create_gem_model(
        model_name: str,
        pretrained: Optional[str] = None,
        gem_depth: int = 7,
        ss_attn_iter: int = 1,
        ss_attn_temp: Optional[float] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_preprocess_cfg: Optional[Dict[str, Any]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        **model_kwargs,
) -> "GEMWrapper":
    """Create an open_clip backbone and wrap it in a :class:`GEMWrapper`.

    Args:
        model_name: open_clip architecture name; "/" is replaced with "-" so
            both "ViT-B/16" and "ViT-B-16" spellings are accepted.
        pretrained: Pretraining tag (e.g. "openai", "laion2b_s34b_b88k").
        gem_depth: Passed to ``GEMWrapper`` as ``depth``.
        ss_attn_iter: Passed to ``GEMWrapper`` (self-self attention iterations).
        ss_attn_temp: Passed to ``GEMWrapper`` (self-self attention temperature).
        precision, device, jit, force_quick_gelu, force_custom_text,
        force_patch_dropout, force_image_size, force_preprocess_cfg,
        pretrained_image, pretrained_hf, cache_dir, output_dict,
        require_pretrained: Forwarded to ``open_clip.create_model``; see the
            open_clip documentation for their semantics.
        **model_kwargs: Extra keyword arguments for ``open_clip.create_model``.

    Returns:
        A ``GEMWrapper`` holding the loaded model and its matching tokenizer.
    """
    # open_clip uses "-"-separated names internally ("ViT-B-16").
    model_name = model_name.replace("/", "-")
    logging.info(f'Loading pretrained {model_name} from pretrained weights {pretrained}...')
    # NOTE(review): arguments are forwarded positionally — confirm this order
    # still matches open_clip.create_model's signature for the pinned
    # open_clip version before upgrading that dependency.
    open_clip_model = open_clip.create_model(model_name, pretrained, precision, device, jit, force_quick_gelu, force_custom_text,
                                             force_patch_dropout, force_image_size, force_preprocess_cfg, pretrained_image,
                                             pretrained_hf, cache_dir, output_dict, require_pretrained, **model_kwargs)
    tokenizer = open_clip.get_tokenizer(model_name=model_name)
    gem_model = GEMWrapper(model=open_clip_model, tokenizer=tokenizer, depth=gem_depth,
                           ss_attn_iter=ss_attn_iter, ss_attn_temp=ss_attn_temp)
    logging.info(f'Loaded GEM-{model_name} from pretrained weights {pretrained}!')
    return gem_model
def create_model_and_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        gem_depth: int = 7,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_preprocess_cfg: Optional[Dict[str, Any]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        **model_kwargs,
):
    """Create a GEM model together with its matching image transform.

    Convenience wrapper combining :func:`create_gem_model` and
    :func:`get_gem_img_transform`.

    Args:
        model_name / pretrained / gem_depth / precision / device / jit /
            force_* / pretrained_* / cache_dir / output_dict /
            require_pretrained: See :func:`create_gem_model`.
        **model_kwargs: Extra keyword arguments.  ``img_size``/``mean``/``std``
            are routed to the image transform; everything is also forwarded to
            model creation.

    Returns:
        Tuple ``(gem_model, transform)``.
    """
    # BUG FIX: these arguments were previously forwarded positionally, but
    # create_gem_model's signature has ss_attn_iter/ss_attn_temp in positions
    # 4-5, so `precision` landed in `ss_attn_iter`, `device` in `ss_attn_temp`,
    # and every later argument was shifted by two.  Keyword forwarding
    # restores the intended mapping.
    gem_model = create_gem_model(
        model_name=model_name,
        pretrained=pretrained,
        gem_depth=gem_depth,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        force_preprocess_cfg=force_preprocess_cfg,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
        require_pretrained=require_pretrained,
        **model_kwargs,
    )
    # BUG FIX: only forward the kwargs get_gem_img_transform actually accepts;
    # passing model-construction kwargs here raised TypeError.
    transform_kwargs = {k: v for k, v in model_kwargs.items()
                        if k in ('img_size', 'mean', 'std')}
    transform = get_gem_img_transform(**transform_kwargs)
    return gem_model, transform
def visualize(image, text, logits, alpha=0.6, save_path=None):
    """Overlay per-class heatmaps on an image and display them with pyplot.

    Args:
        image: PIL image, or a (C, H, W) / (1, C, H, W) tensor normalized
            with the OpenAI CLIP statistics (the normalization is undone
            before display).
        text: Sequence of class names, one per heatmap in ``logits``.
        logits: Tensor of shape (N, H, W) or (1, N, H, W); values are assumed
            to lie in [0, 1] — TODO confirm against GEMWrapper's output.
        alpha: Blend weight of the heatmap over the image (0 = image only).
        save_path: Optional directory; when given, one
            ``heatmap_<class>.png`` per class is written there.

    Raises:
        TypeError: if ``image`` is neither a PIL image nor a torch tensor.
    """
    # BUG FIX: heatmaps are (..., H, W); the old code unpacked them as (W, H)
    # and then called resize((W, H)) with the axes swapped (harmless only for
    # square maps).  PIL's resize takes (width, height).
    H, W = logits.shape[-2:]
    if isinstance(image, Image.Image):
        image = image.resize((W, H))
    elif isinstance(image, torch.Tensor):
        if image.ndim > 3:
            image = image.squeeze(0)  # drop a leading batch dim of size 1
        # Undo the CLIP normalization before converting back to a PIL image.
        image_unormed = (image.detach().cpu() * torch.Tensor(OPENAI_DATASET_STD)[:, None, None]) \
                        + torch.Tensor(OPENAI_DATASET_MEAN)[:, None, None]
        # clip() guards against uint8 wraparound from slight out-of-range values
        image = Image.fromarray(
            (image_unormed.permute(1, 2, 0).numpy().clip(0.0, 1.0) * 255).astype('uint8'))
    else:
        # BUG FIX: previously `raise f'...'` raised a bare string, which fails
        # with an unrelated "exceptions must derive from BaseException" error.
        raise TypeError(
            f'image should be either of type PIL.Image.Image or torch.Tensor '
            f'but found {type(image)}')

    # Show the (resized / de-normalized) input image first.
    plt.imshow(image)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

    if logits.ndim > 3:
        logits = logits.squeeze(0)  # drop a leading batch dim of size 1
    logits = logits.detach().cpu().numpy()

    # Blend each heatmap over the image in BGR space (OpenCV convention).
    img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    logits = (logits * 255).astype('uint8')
    heat_maps = [cv2.applyColorMap(logit, cv2.COLORMAP_JET) for logit in logits]
    vizs = [(1 - alpha) * img_cv + alpha * heat_map for heat_map in heat_maps]

    for viz, cls_name in zip(vizs, text):
        viz = cv2.cvtColor(viz.astype('uint8'), cv2.COLOR_BGR2RGB)
        plt.imshow(viz)
        plt.title(cls_name)
        plt.axis('off')
        plt.tight_layout()
        # BUG FIX: `save_path` was accepted but ignored (a hard-coded filename
        # was used), and savefig ran *after* plt.show(), which saves an empty
        # figure with non-interactive backends.  Save before showing, into
        # the requested directory.
        if save_path is not None:
            plt.savefig(os.path.join(save_path, f'heatmap_{cls_name}.png'))
        plt.show()