import logging
from typing import Any, Union, List, Optional, Tuple, Dict
import open_clip
from open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import cv2
from .gem_wrapper import GEMWrapper

_MODELS = {
    # B/32
    "ViT-B/32": [
        "openai",
        "laion400m_e31",
        "laion400m_e32",
        "laion2b_e16",
        "laion2b_s34b_b79k",
    ],
    "ViT-B/32-quickgelu": [
        "metaclip_400m",
        "metaclip_fullcc",
    ],
    # B/16
    "ViT-B/16": [
        "openai",
        "laion400m_e31",
        "laion400m_e32",
        "laion2b_s34b_b88k",
    ],
    "ViT-B/16-quickgelu": [
        "metaclip_400m",
        "metaclip_fullcc",
    ],
    "ViT-B/16-plus-240": [
        "laion400m_e31",
        "laion400m_e32",
    ],
    # L/14
    "ViT-L/14": [
        "openai",
        "laion400m_e31",
        "laion400m_e32",
        "laion2b_s32b_b82k",
    ],
    "ViT-L/14-quickgelu": [
        "metaclip_400m",
        "metaclip_fullcc",
    ],
    "ViT-L/14-336": [
        "openai",
    ],
}


def available_models() -> str:
    """Return a formatted string listing the available GEM-VL models and their pretrained weights."""
    return "".join(
        f"{key.ljust(20)}: {value}\n"
        for key, values in _MODELS.items()
        for value in values
    )
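
# Example (sketch): print the supported model/pretrained-weight combinations.
#   >>> print(available_models())
#   ViT-B/32            : openai
#   ViT-B/32            : laion400m_e31
#   ...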


def get_tokenizer(
    model_name: str = '',
    context_length: Optional[int] = None,
    **kwargs,
):
    """Thin wrapper around open_clip.get_tokenizer."""
    return open_clip.get_tokenizer(model_name=model_name, context_length=context_length, **kwargs)
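
# Example (sketch): tokenize prompts. Note that this wrapper forwards the name
# unchanged, so open_clip expects the hyphenated form (e.g. 'ViT-B-16', not 'ViT-B/16').
#   >>> tokenizer = get_tokenizer('ViT-B-16')
#   >>> tokens = tokenizer(['a photo of a cat'])  # LongTensor of shape [1, context_length]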


def get_gem_img_transform(
    img_size: Union[int, Tuple[int, int]] = (448, 448),
    mean: Optional[Tuple[float, ...]] = None,
    std: Optional[Tuple[float, ...]] = None,
):
    """Build the image preprocessing pipeline: bicubic resize, tensor conversion, CLIP normalization."""
    mean = mean or OPENAI_DATASET_MEAN
    std = std or OPENAI_DATASET_STD
    transform = transforms.Compose([
        transforms.Resize(size=img_size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    return transform
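
# Example (sketch): preprocess a PIL image into a batched tensor
# ('cat.jpg' is a placeholder path).
#   >>> transform = get_gem_img_transform(img_size=(448, 448))
#   >>> img = transform(Image.open('cat.jpg').convert('RGB')).unsqueeze(0)  # [1, 3, 448, 448]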


def create_gem_model(
    model_name: str,
    pretrained: Optional[str] = None,
    gem_depth: int = 7,
    ss_attn_iter: int = 1,
    ss_attn_temp: Optional[float] = None,
    precision: str = 'fp32',
    device: Union[str, torch.device] = 'cpu',
    jit: bool = False,
    force_quick_gelu: bool = False,
    force_custom_text: bool = False,
    force_patch_dropout: Optional[float] = None,
    force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
    force_preprocess_cfg: Optional[Dict[str, Any]] = None,
    pretrained_image: bool = False,
    pretrained_hf: bool = True,
    cache_dir: Optional[str] = None,
    output_dict: Optional[bool] = None,
    require_pretrained: bool = False,
    **model_kwargs,
):
    """Create an open_clip model and wrap it in GEMWrapper, which applies GEM's
    self-self attention to the last `gem_depth` transformer blocks."""
    model_name = model_name.replace("/", "-")
    logging.info(f'Loading pretrained {model_name} from pretrained weights {pretrained}...')
    # Forward everything by keyword so the call stays correct if open_clip
    # reorders or extends create_model's parameters.
    open_clip_model = open_clip.create_model(
        model_name,
        pretrained=pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        force_preprocess_cfg=force_preprocess_cfg,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
        require_pretrained=require_pretrained,
        **model_kwargs,
    )
    tokenizer = open_clip.get_tokenizer(model_name=model_name)
    gem_model = GEMWrapper(model=open_clip_model, tokenizer=tokenizer, depth=gem_depth,
                           ss_attn_iter=ss_attn_iter, ss_attn_temp=ss_attn_temp)
    logging.info(f'Loaded GEM-{model_name} from pretrained weights {pretrained}!')
    return gem_model
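
# Example (sketch): build a GEM model on GPU when available. The model/pretrained
# combination below is one of the entries in _MODELS; see available_models().
#   >>> device = 'cuda' if torch.cuda.is_available() else 'cpu'
#   >>> gem_model = create_gem_model(model_name='ViT-B/16',
#   ...                              pretrained='laion2b_s34b_b88k', device=device)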


def create_model_and_transforms(
    model_name: str,
    pretrained: Optional[str] = None,
    gem_depth: int = 7,
    precision: str = 'fp32',
    device: Union[str, torch.device] = 'cpu',
    jit: bool = False,
    force_quick_gelu: bool = False,
    force_custom_text: bool = False,
    force_patch_dropout: Optional[float] = None,
    force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
    force_preprocess_cfg: Optional[Dict[str, Any]] = None,
    pretrained_image: bool = False,
    pretrained_hf: bool = True,
    cache_dir: Optional[str] = None,
    output_dict: Optional[bool] = None,
    require_pretrained: bool = False,
    **model_kwargs,
):
    """Convenience helper that returns a GEM model together with its matching image transform."""
    # Split off the transform-specific kwargs so they are not forwarded into the
    # open_clip model config, and pass the rest by keyword: create_gem_model also
    # takes ss_attn_iter/ss_attn_temp, which purely positional forwarding would
    # misalign with precision/device.
    transform_kwargs = {k: model_kwargs.pop(k) for k in ('img_size', 'mean', 'std') if k in model_kwargs}
    gem_model = create_gem_model(
        model_name=model_name,
        pretrained=pretrained,
        gem_depth=gem_depth,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        force_preprocess_cfg=force_preprocess_cfg,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
        require_pretrained=require_pretrained,
        **model_kwargs,
    )
    transform = get_gem_img_transform(**transform_kwargs)
    return gem_model, transform
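
# Example (sketch): model plus matching preprocessing in one call; img_size is
# routed to get_gem_img_transform, everything else to open_clip.
#   >>> gem_model, transform = create_model_and_transforms(
#   ...     'ViT-B/16', pretrained='laion2b_s34b_b88k', img_size=(448, 448))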


def visualize(image, text, logits, alpha=0.6, save_path=None):
    """Overlay per-class GEM heatmaps on the input image and display (optionally save) them."""
    H, W = logits.shape[-2:]
    if isinstance(image, Image.Image):
        image = image.resize((W, H))
    elif isinstance(image, torch.Tensor):
        if image.ndim > 3:
            image = image.squeeze(0)
        # Undo the CLIP normalization before converting back to PIL.
        image_unnormed = (image.detach().cpu() * torch.tensor(OPENAI_DATASET_STD)[:, None, None]
                          + torch.tensor(OPENAI_DATASET_MEAN)[:, None, None])
        image = Image.fromarray((image_unnormed.permute(1, 2, 0).numpy() * 255).astype('uint8'))
    else:
        raise TypeError(f'image should be of type PIL.Image.Image or torch.Tensor but found {type(image)}')

    # Plot the raw image.
    plt.imshow(image)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

    if logits.ndim > 3:
        logits = logits.squeeze(0)
    logits = logits.detach().cpu().numpy()
    img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    # Map each per-class logit map (assumed in [0, 1]) to a JET heatmap and
    # alpha-blend it with the image.
    logits = (logits * 255).astype('uint8')
    heat_maps = [cv2.applyColorMap(logit, cv2.COLORMAP_JET) for logit in logits]
    vizs = [(1 - alpha) * img_cv + alpha * heat_map for heat_map in heat_maps]
    for viz, cls_name in zip(vizs, text):
        viz = cv2.cvtColor(viz.astype('uint8'), cv2.COLOR_BGR2RGB)
        plt.imshow(viz)
        plt.title(cls_name)
        plt.axis('off')
        plt.tight_layout()
        if save_path is not None:
            # Save before plt.show(), which clears the current figure; note that
            # save_path only toggles saving, files go to the working directory.
            plt.savefig(f'heatmap_{cls_name}.png')
        plt.show()
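

if __name__ == '__main__':
    # Minimal end-to-end sketch. It assumes GEMWrapper is callable as
    # gem_model(image, texts) and returns per-class patch-level heatmaps of shape
    # [batch, num_texts, H, W] (see gem_wrapper.py); 'cat.jpg' is a placeholder path.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    gem_model = create_gem_model(model_name='ViT-B/16', pretrained='laion2b_s34b_b88k', device=device)
    transform = get_gem_img_transform()
    image = transform(Image.open('cat.jpg').convert('RGB')).unsqueeze(0).to(device)
    texts = ['cat', 'remote control']
    with torch.no_grad():
        logits = gem_model(image, texts)
    visualize(image, texts, logits)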