Spaces:
Running
Running
'''
@File       :   ImageReward.py
@Time       :   2023/01/28 19:53:00
@Author     :   Jiazheng Xu
@Contact    :   [email protected]
@Description:   ImageReward reward model.
* Based on the CLIP code base and the improved-aesthetic-predictor code base
* https://github.com/openai/CLIP
* https://github.com/christophschuhmann/improved-aesthetic-predictor
'''
import os | |
import torch | |
import torch.nn as nn | |
from PIL import Image | |
from .models.BLIP.blip_pretrain import BLIP_Pretrain | |
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize | |
# Bicubic resampling flag passed to ``Resize``.
# Use torchvision's ``InterpolationMode`` enum when this torchvision provides it.
# Otherwise fall back to PIL's constant (presumably what older torchvision
# releases expect).
try:
    from torchvision.transforms import InterpolationMode
except ImportError:
    BICUBIC = Image.BICUBIC
else:
    BICUBIC = InterpolationMode.BICUBIC
def _convert_image_to_rgb(image): | |
return image.convert("RGB") | |
def _transform(n_px):
    """Build the image preprocessing pipeline for square ``n_px`` inputs.

    Steps, in order:
    1. Bicubic ``Resize`` to ``n_px``.
    2. ``CenterCrop`` to ``n_px``.
    3. Conversion to RGB.
    4. ``ToTensor``.
    5. ``Normalize`` with the per-channel constants taken from the CLIP code base.
    """
    clip_channel_mean = (0.48145466, 0.4578275, 0.40821073)
    clip_channel_std = (0.26862954, 0.26130258, 0.27577711)
    pipeline_steps = [
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize(clip_channel_mean, clip_channel_std),
    ]
    return Compose(pipeline_steps)
class MLP(nn.Module):
    """Regression head mapping an ``input_size`` feature vector to one score.

    The stack contains only ``nn.Linear`` and ``nn.Dropout`` modules; the
    ReLU activations were deliberately disabled, so no non-linearity is applied.

    The position of every module inside ``self.layers`` determines the
    state-dict keys (``layers.0``, ``layers.2``, ..., ``layers.7``).
    Keep the order unchanged so previously saved weights still load.
    """

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        self.layers = nn.Sequential(
            nn.Linear(input_size, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        )

        # Every Linear layer gets N(0, 1 / (input_size + 1)) weights.
        # The spread is derived from the *input* size for all layers, not
        # each layer's fan-in. Biases are zeroed.
        # Weights are initialised in module order, matching the original
        # parameter iteration, so seeded runs draw identical values.
        weight_std = 1.0 / (self.input_size + 1)
        for module in self.layers:
            if not isinstance(module, nn.Linear):
                continue
            nn.init.normal_(module.weight, mean=0.0, std=weight_std)
            nn.init.constant_(module.bias, val=0)

    def forward(self, input):
        """Apply the layer stack; the last dimension of *input* must be ``input_size``."""
        return self.layers(input)
class ImageReward(nn.Module):
    """BLIP-based reward model scoring how well an image matches a text prompt.

    Pipeline:
    1. The BLIP visual encoder embeds the image.
    2. The BLIP text encoder cross-attends from the prompt tokens to those embeddings.
    3. The prompt's first-token ([CLS]) hidden state is fed to an MLP head.
    4. The head's raw output is standardized as ``(raw - mean) / std``.
    """

    def __init__(self, med_config, device='cpu'):
        super().__init__()
        # Device that tokenized prompts and preprocessed images are moved to.
        # NOTE(review): this attribute is only set here and is not updated
        # by ``.to()``; keep it in sync with where the weights live.
        self.device = device
        self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config)
        self.preprocess = _transform(224)
        self.mlp = MLP(768)
        # Standardization constants applied to the raw MLP output.
        # NOTE(review): presumably statistics of raw rewards on a reference set.
        self.mean = 0.16717362830052426
        self.std = 1.0333394966054072

    # ------------------------------------------------------------------ helpers

    def _tokenize_prompt(self, prompt):
        """Tokenize *prompt* (padded/truncated to 35 tokens) and move it to ``self.device``."""
        return self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35,
                                   return_tensors="pt").to(self.device)

    def _prepare_image(self, image):
        """Turn one image input into a single-image batch on ``self.device``.

        Accepted inputs:
        - a ``PIL.Image.Image``, preprocessed with ``self.preprocess``;
        - a path (``str`` or ``os.PathLike``) to an image file, also preprocessed;
        - an already-preprocessed ``torch.Tensor`` of shape (C, H, W) or
          (1, C, H, W), used as-is.

        Raises:
            FileNotFoundError: a path was given that is not an existing file.
            ValueError: a tensor was given that is not a single image.
            TypeError: any other input type.
        """
        if isinstance(image, torch.Tensor):
            pixels = image.unsqueeze(0) if image.dim() == 3 else image
            if pixels.dim() != 4 or pixels.size(0) != 1:
                raise ValueError(
                    f'Expected an image tensor of shape (C, H, W) or (1, C, H, W), '
                    f'got {tuple(image.shape)}.')
        elif isinstance(image, (str, os.PathLike)):
            if not os.path.isfile(image):
                raise FileNotFoundError(f'Image file not found: {image}')
            # Preprocess inside the context manager so the file handle is
            # closed deterministically instead of waiting for garbage collection.
            with Image.open(image) as pil_image:
                pixels = self.preprocess(pil_image).unsqueeze(0)
        elif isinstance(image, Image.Image):
            pixels = self.preprocess(image).unsqueeze(0)
        else:
            raise TypeError(
                r'This image parameter type has not been supported yet. '
                r'Please pass PIL.Image, an image file path, or a preprocessed torch.Tensor.')
        return pixels.to(self.device)

    def _cross_encode(self, input_ids, attention_mask, image):
        """Return the prompt's first-token hidden state after cross-attending to *image*."""
        image_embeds = self.blip.visual_encoder(image)
        # All image tokens are attendable. The mask is built on the embeddings'
        # device so it always matches the encoder output.
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long,
                                device=image_embeds.device)
        text_output = self.blip.text_encoder(input_ids,
                                             attention_mask=attention_mask,
                                             encoder_hidden_states=image_embeds,
                                             encoder_attention_mask=image_atts,
                                             return_dict=True,
                                             )
        return text_output.last_hidden_state[:, 0, :]  # (batch, feature_dim)

    def _standardize(self, rewards):
        """Apply ``(rewards - self.mean) / self.std``."""
        return (rewards - self.mean) / self.std

    # --------------------------------------------------------------- public API

    def score_gard(self, prompt_ids, prompt_attention_mask, image):
        """Differentiable reward for pre-tokenized prompts and preprocessed images.

        Gradient tracking is left on and the result stays a tensor, so it can
        be back-propagated.

        Args:
            prompt_ids: token ids, as passed to the BLIP text encoder.
            prompt_attention_mask: attention mask matching ``prompt_ids``.
            image: preprocessed image batch, as passed to the BLIP visual encoder.

        Returns:
            Standardized rewards of shape (batch, 1).

        The historical misspelled name is kept for compatibility; see also
        ``score_grad``.
        """
        txt_features = self._cross_encode(prompt_ids, prompt_attention_mask, image)
        return self._standardize(self.mlp(txt_features))

    def score_grad(self, prompt_ids, prompt_attention_mask, image):
        """Correctly spelled alias of ``score_gard``."""
        return self.score_gard(prompt_ids, prompt_attention_mask, image)

    @torch.no_grad()
    def score(self, prompt, image):
        """Score one image, or a list/tuple of images, against *prompt*.

        Args:
            prompt: text prompt.
            image: a PIL image, an image file path, a preprocessed tensor, or
                a list/tuple of these.

        Returns:
            A float for a single image, or a list of floats (one per image,
            in input order) for a list/tuple.

        Raises:
            TypeError, ValueError, FileNotFoundError: for unusable image inputs.
        """
        if isinstance(image, (list, tuple)):
            _, rewards = self.inference_rank(prompt, image)
            return rewards

        text_input = self._tokenize_prompt(prompt)
        pixels = self._prepare_image(image)
        txt_features = self._cross_encode(text_input.input_ids,
                                          text_input.attention_mask,
                                          pixels).float()
        rewards = self._standardize(self.mlp(txt_features))
        return rewards.item()

    @torch.no_grad()
    def inference_rank(self, prompt, generations_list):
        """Score several images against *prompt* and rank them.

        Args:
            prompt: text prompt.
            generations_list: iterable of images (PIL images, file paths,
                or preprocessed tensors).

        Returns:
            ``(ranks, rewards)``: two lists aligned with the input order.
            - ``ranks[i]`` is the 1-based rank of image ``i`` (1 = highest reward).
            - ``rewards[i]`` is its standardized reward.
            Both are lists even for a single image. An empty input yields ``([], [])``.
        """
        text_input = self._tokenize_prompt(prompt)
        txt_set = [
            self._cross_encode(text_input.input_ids,
                               text_input.attention_mask,
                               self._prepare_image(generation))
            for generation in generations_list
        ]
        if not txt_set:
            return [], []

        txt_features = torch.cat(txt_set, 0).float()  # [image_num, feature_dim]
        # Flatten [image_num, 1] -> [image_num]. Unlike squeeze(), this keeps a
        # 1-D result for a single image, so tolist() still returns a list.
        rewards = self._standardize(self.mlp(txt_features)).reshape(-1)

        # Double argsort: the position of each image in the descending order
        # is its rank.
        _, order = torch.sort(rewards, dim=0, descending=True)
        _, positions = torch.sort(order, dim=0)
        ranks = positions + 1
        return ranks.cpu().tolist(), rewards.cpu().tolist()