'''
@File       : CLIPScore.py
@Time       : 2023/02/12 13:14:00
@Author     : Jiazheng Xu
@Contact    : [email protected]
@Description: CLIPScore.
* Based on the CLIP code base
* https://github.com/openai/CLIP
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import clip

class CLIPScore(nn.Module):
    def __init__(self, download_root, device='cpu'):
        super().__init__()
        self.device = device
        self.clip_model, self.preprocess = clip.load("ViT-L/14", device=self.device, jit=False,
                                                     download_root=download_root)

        if device == "cpu":
            self.clip_model.float()
        else:
            # convert_weights casts the model to float16; CLIP already loads in
            # float16 on GPU, so this call is kept only for safety.
            clip.model.convert_weights(self.clip_model)

        # The logit scale is not needed for scoring, so freeze it.
        self.clip_model.logit_scale.requires_grad_(False)
    def score(self, prompt, image_path):
        # A list of image paths delegates to inference_rank and returns the scores.
        if isinstance(image_path, list):
            _, rewards = self.inference_rank(prompt, image_path)
            return rewards

        # text encode
        text = clip.tokenize(prompt, truncate=True).to(self.device)
        txt_features = F.normalize(self.clip_model.encode_text(text))

        # image encode
        pil_image = Image.open(image_path)
        image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
        image_features = F.normalize(self.clip_model.encode_image(image))

        # score: cosine similarity between the normalized text and image features
        rewards = torch.sum(torch.mul(txt_features, image_features), dim=1, keepdim=True)

        return rewards.detach().cpu().numpy().item()
    def inference_rank(self, prompt, generations_list):
        # text encode (the same prompt is shared by every generated image)
        text = clip.tokenize(prompt, truncate=True).to(self.device)
        txt_feature = F.normalize(self.clip_model.encode_text(text))

        txt_set = []
        img_set = []
        for img_path in generations_list:
            # image encode
            pil_image = Image.open(img_path)
            image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
            image_features = F.normalize(self.clip_model.encode_image(image))
            img_set.append(image_features)
            txt_set.append(txt_feature)

        txt_features = torch.cat(txt_set, 0).float()  # [image_num, feature_dim]
        img_features = torch.cat(img_set, 0).float()  # [image_num, feature_dim]
        rewards = torch.sum(torch.mul(txt_features, img_features), dim=1, keepdim=True)
        rewards = torch.squeeze(rewards)

        # Convert scores to 1-based ranks (1 = highest score).
        _, rank = torch.sort(rewards, dim=0, descending=True)
        _, indices = torch.sort(rank, dim=0)
        indices = indices + 1

        return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
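

# Minimal usage sketch (not part of the original file). The weight directory
# "checkpoints/" and the image files "img_1.png" / "img_2.png" are placeholder
# names for illustration only.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    scorer = CLIPScore(download_root="checkpoints/", device=device)

    with torch.no_grad():
        # Single image: score() returns a scalar CLIP similarity.
        single = scorer.score("a photo of a cat", "img_1.png")
        print("score:", single)

        # Several candidates: inference_rank() returns 1-based ranks and per-image scores.
        ranks, scores = scorer.inference_rank("a photo of a cat", ["img_1.png", "img_2.png"])
        print("ranks:", ranks, "scores:", scores)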