'''
@File       :   ImageReward.py
@Time       :   2023/01/28 19:53:00
@Author     :   Jiazheng Xu
@Contact    :   xjz22@mails.tsinghua.edu.cn
@Description:   ImageReward Reward model.
* Based on CLIP code base and improved-aesthetic-predictor code base
* https://github.com/openai/CLIP
* https://github.com/christophschuhmann/improved-aesthetic-predictor
'''

import os
import torch
import torch.nn as nn
from PIL import Image
from .models.BLIP.blip_pretrain import BLIP_Pretrain
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    # Older torchvision releases have no InterpolationMode; fall back to the
    # PIL resampling constant.
    BICUBIC = Image.BICUBIC


def _convert_image_to_rgb(image):
    """Return ``image`` converted to RGB mode."""
    return image.convert("RGB")


def _transform(n_px):
    """Build the image preprocessing pipeline for a square ``n_px`` input.

    The pipeline resizes (bicubic), center-crops to ``n_px``, converts to
    RGB, converts to a tensor and normalizes each channel with fixed
    mean/std constants.
    """
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


class MLP(nn.Module):
    """Reward head mapping an ``input_size``-dim feature to one scalar."""

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size

        # NOTE: the activations between the linear layers are intentionally
        # disabled (kept as comments). The order of the entries must not
        # change: nn.Sequential indexes its children, and those indices are
        # part of the parameter names stored in checkpoints.
        self.layers = nn.Sequential(
            nn.Linear(self.input_size, 1024),
            # nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            # nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            # nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            # nn.ReLU(),
            nn.Linear(16, 1)
        )

        # Initial MLP parameters: small normal weights, zero biases.
        for name, param in self.layers.named_parameters():
            if 'weight' in name:
                nn.init.normal_(param, mean=0.0, std=1.0 / (self.input_size + 1))
            if 'bias' in name:
                nn.init.constant_(param, val=0)

    def forward(self, input):
        """Apply the layer stack to ``input`` and return the result."""
        return self.layers(input)


class ImageReward(nn.Module):
    """Scores how well an image matches a text prompt.

    A BLIP visual encoder embeds the image, the BLIP text encoder attends
    over those embeddings while encoding the prompt, and an MLP head maps
    the first-token feature to a scalar reward that is then standardized
    with ``self.mean`` / ``self.std``.
    """

    def __init__(self, med_config, device='cpu'):
        """Build the BLIP backbone, the preprocessing pipeline and the MLP head.

        Args:
            med_config: Passed through to ``BLIP_Pretrain`` as ``med_config``.
            device: Device on which input tensors are placed by the scoring
                methods. The module's own parameters are not moved here.
        """
        super().__init__()
        self.device = device

        self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config)
        self.preprocess = _transform(224)
        self.mlp = MLP(768)

        # Constants used to standardize the raw MLP output.
        self.mean = 0.16717362830052426
        self.std = 1.0333394966054072

    # ------------------------------------------------------------------
    # Private helpers shared by the public scoring methods.
    # ------------------------------------------------------------------

    def _tokenize(self, prompt):
        """Tokenize ``prompt`` (padded/truncated to 35 tokens) onto ``self.device``."""
        return self.blip.tokenizer(
            prompt,
            padding='max_length',
            truncation=True,
            max_length=35,
            return_tensors="pt",
        ).to(self.device)

    @staticmethod
    def _load_image(image):
        """Return a PIL image for ``image`` (a PIL image or a file path).

        Raises:
            FileNotFoundError: ``image`` is a string that is not an existing file.
            TypeError: ``image`` is neither a PIL image nor a string.
        """
        if isinstance(image, Image.Image):
            return image
        if isinstance(image, str):
            # Previously a non-existent path fell through without assigning
            # the image, producing an UnboundLocalError (or, inside a loop,
            # silently re-scoring the previous image).
            if not os.path.isfile(image):
                raise FileNotFoundError('Image file not found: {!r}'.format(image))
            return Image.open(image)
        raise TypeError(
            r'This image parameter type has not been supportted yet. Please pass PIL.Image or file path str.')

    def _prepare_image(self, image):
        """Load ``image``, preprocess it and return a batch of one on ``self.device``."""
        pil_image = self._load_image(image)
        return self.preprocess(pil_image).unsqueeze(0).to(self.device)

    def _fused_features(self, input_ids, attention_mask, image):
        """Return the first-token text feature after cross-attending to ``image``.

        Args:
            input_ids: Token ids for the prompt.
            attention_mask: Attention mask matching ``input_ids``.
            image: Preprocessed image tensor accepted by the visual encoder.
        """
        image_embeds = self.blip.visual_encoder(image)

        # Text encode with cross attention over every image position.
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
        text_output = self.blip.text_encoder(
            input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        return text_output.last_hidden_state[:, 0, :]

    def _normalize(self, rewards):
        """Standardize raw MLP outputs with the stored mean and std."""
        return (rewards - self.mean) / self.std

    # ------------------------------------------------------------------
    # Public API.
    # ------------------------------------------------------------------

    def score_gard(self, prompt_ids, prompt_attention_mask, image):
        """Compute rewards from pre-tokenized prompts and image tensors.

        Unlike ``score`` this keeps the autograd graph and returns a tensor,
        so the result can be back-propagated through.

        Args:
            prompt_ids: Token ids for the prompt(s).
            prompt_attention_mask: Attention mask matching ``prompt_ids``.
            image: Preprocessed image tensor.

        Returns:
            The standardized reward tensor produced by the MLP head.
        """
        txt_features = self._fused_features(prompt_ids, prompt_attention_mask, image)
        rewards = self.mlp(txt_features)
        return self._normalize(rewards)

    def score(self, prompt, image):
        """Score one image (or a list of images) against ``prompt``.

        Args:
            prompt: Text prompt passed to the BLIP tokenizer.
            image: A PIL image, a path to an image file, or a list of either.

        Returns:
            A Python float for a single image. For a list, the rewards
            returned by ``inference_rank``.

        Raises:
            FileNotFoundError: A path string does not point to a file.
            TypeError: An image is neither a PIL image nor a string.
        """
        if isinstance(image, list):
            _, rewards = self.inference_rank(prompt, image)
            return rewards

        text_input = self._tokenize(prompt)
        image_tensor = self._prepare_image(image)

        # Only a plain number is returned, so no autograd graph is needed.
        with torch.no_grad():
            txt_features = self._fused_features(
                text_input.input_ids, text_input.attention_mask, image_tensor).float()
            rewards = self._normalize(self.mlp(txt_features))

        return rewards.detach().cpu().numpy().item()

    def inference_rank(self, prompt, generations_list):
        """Rank several images against one prompt.

        Args:
            prompt: Text prompt passed to the BLIP tokenizer.
            generations_list: Images to rank; each a PIL image or a file path.

        Returns:
            A ``(ranks, rewards)`` tuple. ``ranks[i]`` is the 1-based rank of
            image ``i`` (1 = highest reward) and ``rewards[i]`` its
            standardized reward.

        Raises:
            FileNotFoundError: A path string does not point to a file.
            TypeError: An image is neither a PIL image nor a string.
        """
        text_input = self._tokenize(prompt)

        # Results are converted to plain Python values, so skip autograd.
        with torch.no_grad():
            txt_set = []
            for generation in generations_list:
                image_tensor = self._prepare_image(generation)
                txt_set.append(self._fused_features(
                    text_input.input_ids, text_input.attention_mask, image_tensor))

            txt_features = torch.cat(txt_set, 0).float()  # [image_num, feature_dim]
            rewards = self._normalize(self.mlp(txt_features))  # [image_num, 1]

            # NOTE(review): torch.squeeze removes every size-1 dimension, so a
            # one-image list appears to yield scalars rather than one-element
            # lists. Left unchanged so existing callers keep the same result.
            rewards = torch.squeeze(rewards)

            # Sorting descending gives the image order by reward; sorting
            # that permutation again inverts it, giving each image's rank.
            _, rank = torch.sort(rewards, dim=0, descending=True)
            _, indices = torch.sort(rank, dim=0)
            indices = indices + 1

        return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()