hjc-owo
init repo
966ae59
'''
@File : utils.py
@Time : 2023/04/05 19:18:00
@Author : Jiazheng Xu
@Contact : [email protected]
* Based on CLIP code base
* https://github.com/openai/CLIP
* Checkpoints of CLIP/BLIP/Aesthetic are from:
* https://github.com/openai/CLIP
* https://github.com/salesforce/BLIP
* https://github.com/christophschuhmann/improved-aesthetic-predictor
'''
import os
import pathlib
import urllib
import urllib.request
from typing import Union, List

import torch
from tqdm import tqdm
from huggingface_hub import hf_hub_download

from .ImageReward import ImageReward
from .models.CLIPScore import CLIPScore
from .models.BLIPScore import BLIPScore
from .models.AestheticScore import AestheticScore
# Released ImageReward checkpoints, keyed by the model name accepted by `load`.
_MODELS = {
    "ImageReward-v1.0": "https://huggingface.co/THUDM/ImageReward/blob/main/ImageReward.pt",
}


def available_models() -> List[str]:
    """Return the registered ImageReward model names as a new list."""
    return list(_MODELS)
def ImageReward_download(url: str, root: str):
    """Fetch a file of the ``THUDM/ImageReward`` Hugging Face repo into ``root``.

    Only the last path component of ``url`` is used: it names both the file
    requested from the hub and the file expected inside ``root``.

    Parameters
    ----------
    url : str
        Address whose basename is the file to fetch
    root : str
        Destination directory; created when missing

    Returns
    -------
    str
        ``root`` joined with the file name
    """
    os.makedirs(root, exist_ok=True)
    target_name = os.path.basename(url)
    hf_hub_download(repo_id="THUDM/ImageReward", local_dir=root, filename=target_name)
    return os.path.join(root, target_name)
def load(name: str = "ImageReward-v1.0",
         device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
         download_root: str = None,
         med_config_path: str = None):
    """Load a ImageReward model

    Parameters
    ----------
    name: str
        A model name listed by `ImageReward.available_models()`, or the path to a model checkpoint containing the state_dict
    device: Union[str, torch.device]
        The device to put the loaded model
    download_root: str
        path to download the model files; by default, it uses "~/.cache/ImageReward"
    med_config_path: str
        path of the `med_config.json` handed to the `ImageReward` constructor; when omitted, the file is
        looked up in the download root and downloaded there if it is missing

    Returns
    -------
    model : torch.nn.Module
        The ImageReward model, moved to `device` and switched to eval mode

    Raises
    ------
    RuntimeError
        If `name` is neither an available model name nor an existing file
    """
    # pathlib never expands "~" on its own, so the default must be expanded
    # explicitly; otherwise files end up in a literal "~" directory below the
    # current working directory.
    cache_root = pathlib.Path(download_root or os.path.expanduser("~/.cache/ImageReward"))

    if name in _MODELS:
        # ImageReward_download stores a file under the basename of its URL,
        # so the cached checkpoint is looked up under that same name.
        model_path = cache_root / os.path.basename(_MODELS[name])
        if not model_path.exists():
            model_path = ImageReward_download(_MODELS[name], root=cache_root.as_posix())
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    print('-> load ImageReward model from %s' % model_path)
    state_dict = torch.load(model_path, map_location='cpu')

    # med_config: fall back to the copy in the cache directory, fetching it on first use
    if med_config_path is None:
        med_config_path = cache_root / 'med_config.json'
        if not med_config_path.exists():
            med_config_path = ImageReward_download(
                "https://huggingface.co/THUDM/ImageReward/blob/main/med_config.json",
                root=cache_root.as_posix())
    print('-> load ImageReward med_config from %s' % med_config_path)

    model = ImageReward(device=device, med_config=med_config_path).to(device)
    # strict=False: key mismatches between checkpoint and module are tolerated
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    return model
# Checkpoint URLs of the comparison score models, keyed by the name accepted by `load_score`.
_SCORES = {
    "CLIP": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "BLIP": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth",
    "Aesthetic": "https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/sac%2Blogos%2Bava1-l14-linearMSE.pth",
}


def available_scores() -> List[str]:
    """Return the registered score names as a new list."""
    return list(_SCORES)
def _download(url: str, root: str):
    """Download `url` into the directory `root` and return the local file path.

    The local file is named after the last path component of `url`. When a
    regular file of that name already exists in `root`, it is returned as-is
    and nothing is fetched.

    Parameters
    ----------
    url : str
        Address of the file to fetch
    root : str
        Directory to store the file in; created when missing

    Returns
    -------
    str
        Path of the downloaded (or already present) file

    Raises
    ------
    RuntimeError
        If the target path exists but is not a regular file
    """
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")
    if os.path.isfile(download_target):
        return download_target

    # Stream into a sibling ".part" file and rename it only after the transfer
    # completed, so an interrupted download never leaves a truncated file that
    # a later call would mistake for a valid cached checkpoint.
    partial_target = download_target + ".part"
    try:
        with urllib.request.urlopen(url) as source, open(partial_target, "wb") as output:
            # Content-Length is optional; without it the bar runs with an unknown total.
            content_length = source.info().get("Content-Length")
            total = int(content_length) if content_length is not None else None
            with tqdm(total=total, ncols=80, unit='iB', unit_scale=True,
                      unit_divisor=1024) as loop:
                for buffer in iter(lambda: source.read(8192), b""):
                    output.write(buffer)
                    loop.update(len(buffer))
        os.replace(partial_target, download_target)
    except BaseException:
        # Also covers KeyboardInterrupt: drop the incomplete file, then re-raise.
        if os.path.isfile(partial_target):
            os.remove(partial_target)
        raise

    return download_target
def load_score(name: str = "CLIP", device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
               download_root: str = None):
    """Load one of the comparison score models ("CLIP", "BLIP" or "Aesthetic")

    Parameters
    ----------
    name : str
        A score name listed by `available_scores()`
    device : Union[str, torch.device]
        The device to put the loaded model
    download_root: str
        path to download the model files; by default, it uses "~/.cache/ImageReward"

    Returns
    -------
    model : torch.nn.Module
        The requested score model, moved to `device` and switched to eval mode

    Raises
    ------
    RuntimeError
        If `name` is not an available score
    """
    cache_dir = download_root or os.path.expanduser("~/.cache/ImageReward")

    # Reject unknown names before touching the network.
    if name not in _SCORES:
        raise RuntimeError(f"Score {name} not found; available scores = {available_scores()}")

    checkpoint_path = _download(_SCORES[name], cache_dir)
    print('load checkpoint from %s' % checkpoint_path)

    if name == "CLIP":
        # CLIPScore only receives the download directory; presumably it picks up
        # the checkpoint fetched above from there -- confirm in CLIPScore.
        scorer = CLIPScore(download_root=cache_dir, device=device).to(device)
    elif name == "BLIP":
        weights = torch.load(checkpoint_path, map_location='cpu')
        med_config_file = ImageReward_download(
            "https://huggingface.co/THUDM/ImageReward/blob/main/med_config.json",
            cache_dir)
        scorer = BLIPScore(med_config=med_config_file, device=device).to(device)
        scorer.blip.load_state_dict(weights['model'], strict=False)
    elif name == "Aesthetic":
        weights = torch.load(checkpoint_path, map_location='cpu')
        scorer = AestheticScore(download_root=cache_dir, device=device).to(device)
        scorer.mlp.load_state_dict(weights, strict=False)
    else:
        # Defensive: only reachable if _SCORES gains an entry without a loader here.
        raise RuntimeError(f"Score {name} not found; available scores = {available_scores()}")

    print("checkpoint loaded")
    scorer.eval()
    return scorer