Spaces:
Running
Running
'''
@File : utils.py
@Time : 2023/04/05 19:18:00
@Author : Jiazheng Xu
@Contact : [email protected]
* Based on CLIP code base
* https://github.com/openai/CLIP
* Checkpoints of CLIP/BLIP/Aesthetic are from:
* https://github.com/openai/CLIP
* https://github.com/salesforce/BLIP
* https://github.com/christophschuhmann/improved-aesthetic-predictor
'''
import os | |
import urllib | |
from typing import Union, List | |
import pathlib | |
import torch | |
from tqdm import tqdm | |
from huggingface_hub import hf_hub_download | |
from .ImageReward import ImageReward | |
from .models.CLIPScore import CLIPScore | |
from .models.BLIPScore import BLIPScore | |
from .models.AestheticScore import AestheticScore | |
_MODELS = { | |
"ImageReward-v1.0": "https://huggingface.co/THUDM/ImageReward/blob/main/ImageReward.pt", | |
} | |
def available_models() -> List[str]:
    """Return the names of the registered ImageReward models, in registry order."""
    names = [*_MODELS]
    return names
def ImageReward_download(url: str, root: str, repo_id: str = "THUDM/ImageReward"):
    """Fetch one file of a Hugging Face repository into ``root``.

    Only the basename of ``url`` is used: it names the file to fetch from
    ``repo_id``. The rest of the URL is ignored.

    Parameters
    ----------
    url: str
        URL (or path) whose last component is the file name inside the repo.
    root: str
        Local directory to place the file in; created if needed. A leading "~"
        is expanded to the user's home directory.
    repo_id: str
        Hugging Face repository to download from (default: "THUDM/ImageReward").

    Returns
    -------
    str
        ``os.path.join(root, basename(url))`` with ``root`` ``~``-expanded.
    """
    # os.makedirs does not expand "~"; without this a literal "~" directory would
    # be created under the current working directory.
    root = os.path.expanduser(root)
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)
    download_target = os.path.join(root, filename)
    hf_hub_download(repo_id=repo_id, filename=filename, local_dir=root)
    return download_target
def load(name: str = "ImageReward-v1.0",
         device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
         download_root: str = None,
         med_config_path: str = None):
    """Load an ImageReward model.

    Parameters
    ----------
    name: str
        A model name listed by `available_models()`, or the path to a model checkpoint
        containing the state_dict
    device: Union[str, torch.device]
        The device to put the loaded model
    download_root: str
        path to download the model files; by default, it uses "~/.cache/ImageReward".
        A leading "~" is expanded to the user's home directory.
    med_config_path: str
        Path to the ``med_config.json`` passed to ``ImageReward``. When omitted, it is
        looked up in ``download_root`` and downloaded there if missing.

    Returns
    -------
    model : torch.nn.Module
        The ImageReward model, moved to ``device`` and switched to eval mode

    Raises
    ------
    RuntimeError
        If ``name`` is neither a registered model name nor an existing file.
    """
    # pathlib.Path / os.makedirs do not expand "~": without this the default root
    # would resolve to a literal "~" directory under the current working directory.
    download_root = os.path.expanduser(download_root or "~/.cache/ImageReward")

    if name in _MODELS:
        checkpoint_url = _MODELS[name]
        model_path = pathlib.Path(download_root) / os.path.basename(checkpoint_url)
        if not model_path.exists():
            model_path = ImageReward_download(checkpoint_url, root=download_root)
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    print('-> load ImageReward model from %s' % model_path)
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location='cpu')

    # med_config: reuse a cached copy next to the checkpoint, download it otherwise.
    if med_config_path is None:
        med_config_path = pathlib.Path(download_root) / 'med_config.json'
        if not med_config_path.exists():
            med_config_path = ImageReward_download(
                "https://huggingface.co/THUDM/ImageReward/blob/main/med_config.json",
                root=download_root)
    print('-> load ImageReward med_config from %s' % med_config_path)

    model = ImageReward(device=device, med_config=med_config_path).to(device)
    # strict=False tolerates missing/unexpected keys in the checkpoint.
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    return model
_SCORES = { | |
"CLIP": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", | |
"BLIP": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth", | |
"Aesthetic": "https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/sac%2Blogos%2Bava1-l14-linearMSE.pth", | |
} | |
def available_scores() -> List[str]:
    """Return the names accepted by `load_score`, in registry order."""
    return [score_name for score_name in _SCORES]
def _download(url: str, root: str): | |
os.makedirs(root, exist_ok=True) | |
filename = os.path.basename(url) | |
download_target = os.path.join(root, filename) | |
if os.path.exists(download_target) and not os.path.isfile(download_target): | |
raise RuntimeError(f"{download_target} exists and is not a regular file") | |
if os.path.isfile(download_target): | |
return download_target | |
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: | |
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, | |
unit_divisor=1024) as loop: | |
while True: | |
buffer = source.read(8192) | |
if not buffer: | |
break | |
output.write(buffer) | |
loop.update(len(buffer)) | |
return download_target | |
def load_score(name: str = "CLIP", device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
               download_root: str = None):
    """Load a baseline scoring model (CLIP, BLIP or Aesthetic).

    Parameters
    ----------
    name : str
        A score name listed by `available_scores()`
    device : Union[str, torch.device]
        The device to put the loaded model
    download_root: str
        path to download the model files; by default, it uses "~/.cache/ImageReward".
        A leading "~" is expanded to the user's home directory.

    Returns
    -------
    model : torch.nn.Module
        A ``CLIPScore``, ``BLIPScore`` or ``AestheticScore`` instance, moved to
        ``device`` and switched to eval mode

    Raises
    ------
    RuntimeError
        If ``name`` is not one of `available_scores()`.
    """
    model_download_root = os.path.expanduser(download_root or "~/.cache/ImageReward")
    # Validate before touching the network.
    if name not in _SCORES:
        raise RuntimeError(f"Score {name} not found; available scores = {available_scores()}")
    model_path = _download(_SCORES[name], model_download_root)
    print('load checkpoint from %s' % model_path)

    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    if name == "BLIP":
        state_dict = torch.load(model_path, map_location='cpu')
        # Reuse a cached med_config.json (as `load` does) instead of re-fetching it.
        med_config = os.path.join(model_download_root, 'med_config.json')
        if not os.path.isfile(med_config):
            med_config = ImageReward_download("https://huggingface.co/THUDM/ImageReward/blob/main/med_config.json",
                                              model_download_root)
        model = BLIPScore(med_config=med_config, device=device).to(device)
        # The BLIP weights are read from the checkpoint's 'model' entry.
        model.blip.load_state_dict(state_dict['model'], strict=False)
    elif name == "CLIP":
        model = CLIPScore(download_root=model_download_root, device=device).to(device)
    elif name == "Aesthetic":
        state_dict = torch.load(model_path, map_location='cpu')
        model = AestheticScore(download_root=model_download_root, device=device).to(device)
        model.mlp.load_state_dict(state_dict, strict=False)
    else:
        # Reached only if _SCORES gains an entry without a constructor branch here.
        raise RuntimeError(f"Score {name} not found; available scores = {available_scores()}")
    print("checkpoint loaded")
    model.eval()
    return model