# fengshen/examples/clip_finetune/clip_finetune_flickr.py
import sys
sys.path.append('../../')
from data.clip_dataloader.flickr import FlickrDataModule
import pytorch_lightning as pl
import numpy as np
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import torch.nn.functional as F
import math
import copy
import argparse
from transformers import CLIPModel, BertForSequenceClassification
class CLIPLightning(pl.LightningModule):
def __init__(self, model_name='ViT-B/32', minibatch_size=2):
"""A lightning wrapper for a CLIP model as specified in the paper.
Args:
model_name (str): A case sensitive visual model name.
config (dict): A dictionary containing the CLIP instantiation parameters.
"""
super().__init__()
self.prepare_data_per_node = True
        self.model_name = model_name  # only used to pick optimizer hyperparameters; the vision tower below is fixed to ViT-B/32
# self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")  # NOTE: loaded from OpenAI
self.text_encoder = BertForSequenceClassification.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese")
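        # Chinese text tower: a RoBERTa model with a sequence-classification head whose logits
        # are used directly as the text embedding and compared against the CLIP image features.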
self.minibatch_size = minibatch_size
self.isViT = 'ViT' in self.model_name
self.automatic_optimization = False
# Training loss: https://github.com/openai/CLIP/issues/83
# Mini-batching thanks to https://github.com/crowsonkb / https://twitter.com/RiversHaveWings
# Multi-GPU support: https://github.com/MicPie/clasp
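    # Strategy: first encode every minibatch without gradients and all_gather the features,
    # then re-encode one minibatch at a time with gradients and backpropagate it against the
    # cached full-batch features, so the contrastive loss sees the full effective batch while
    # activations only ever live for one minibatch (a gradient-cache style trick).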
def training_step(self, train_batch, idx):
# get optimizers and scheduler
optimizer = self.optimizers()
image, text, labels = train_batch
        n = math.ceil(len(image) / self.minibatch_size)  # number of minibatches (true division, so the ceil matters)
image_mbs = torch.chunk(image, n)
text_mbs = torch.chunk(text, n)
with torch.no_grad():
ims = [F.normalize(self.clip_model.get_image_features(im), dim=1) for im in image_mbs]
txt = [F.normalize(self.text_encoder(t).logits, dim=1) for t in text_mbs]
            # gather from all GPUs: the contrastive loss must be computed over features pooled from every GPU
ims = self.all_gather(torch.cat(ims))
txt = self.all_gather(torch.cat(txt))
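            # With multiple devices, all_gather returns a tensor with a leading world-size
            # dimension; unpack it into a per-rank list so this rank's slice can be swapped
            # out below. On a single device the features are just wrapped in a list.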
if len(ims.shape) == 3:
ims = list(ims)
txt = list(txt)
else:
ims = [ims]
txt = [txt]
image_logits = torch.cat(ims) @ torch.cat(txt).t() * self.clip_model.logit_scale.exp()
ground_truth = torch.arange(len(image_logits)).long().to(image_logits.device)
loss = (F.cross_entropy(image_logits, ground_truth) +
F.cross_entropy(image_logits.t(), ground_truth)).div(2)
acc_i = (torch.argmax(image_logits, 1) == ground_truth).sum()
acc_t = (torch.argmax(image_logits, 0) == ground_truth).sum()
self.log_dict({'loss': loss / len(ims), 'acc': (acc_i + acc_t) / 2 / len(image) / len(ims)}, prog_bar=True)
if isinstance(optimizer, list):
optimizer = optimizer[0]
optimizer.zero_grad()
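        # Gradients from every image and text minibatch accumulate through repeated
        # manual_backward calls; a single optimizer.step() is taken at the end.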
# image loss
for j, mb in enumerate(image_mbs[:-1]):
            # the last minibatch is discarded (works around an alignment bug)
images_tmp = copy.deepcopy(ims)
images_tmp[self.global_rank][j * self.minibatch_size:(j+1)*self.minibatch_size] = \
F.normalize(self.clip_model.get_image_features(mb), dim=1)
image_logits = torch.cat(images_tmp) @ torch.cat(txt).t() * self.clip_model.logit_scale.exp()
ground_truth = torch.arange(len(image_logits)).long().to(image_logits.device)
loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(image_logits.t(), ground_truth))/2
self.manual_backward(loss)
# text loss
for j, mb in enumerate(text_mbs[:-1]):
text_tmp = copy.deepcopy(txt)
text_tmp[self.global_rank][j * self.minibatch_size:(j+1)*self.minibatch_size] = \
F.normalize(self.text_encoder(mb).logits, dim=1)
image_logits = torch.cat(ims) @ torch.cat(text_tmp).t() * self.clip_model.logit_scale.exp()
loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(image_logits.t(), ground_truth))/2
self.manual_backward(loss)
optimizer.step()
lr_scheduler = self.lr_schedulers()
lr_scheduler.step()
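        # clamp the learnable temperature so exp(logit_scale) stays within [1/100, 100], as in CLIP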
self.clip_model.logit_scale.data.clamp_(-np.log(100), np.log(100))
def validation_step(self, val_batch, idx):
image, text, labels = val_batch
img_embed = self.clip_model.get_image_features(image)
txt_embed = self.text_encoder(text).logits
# print(img_embed.shape)
image_norm = F.normalize(img_embed, dim=1)
text_norm = F.normalize(txt_embed, dim=1)
image_logits = image_norm @ text_norm.t() * self.clip_model.logit_scale.exp()
text_logits = text_norm @ image_norm.t() * self.clip_model.logit_scale.exp()
# print(image_logits.shape)
# image_logits, text_logits = self.forward(image, text)
ground_truth = torch.arange(len(image_logits)).long().to(image_logits.device)
loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(text_logits, ground_truth)).div(2)
self.log('val_loss', loss, prog_bar=True)
return [image_norm, text_norm, labels]
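    # Aggregate the embeddings returned by every validation_step and report retrieval
    # metrics (R@1/R@5/R@10) over the whole validation set.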
def validation_epoch_end(self, outputs):
image_features = torch.cat([x[0] for x in outputs])
text_features = torch.cat([x[1] for x in outputs])
labels = [label for x in outputs for label in x[2]]
print(image_features.shape, text_features.shape, len(labels))
self.get_metrics(image_features, text_features, labels, 100)
def test_step(self, test_batch, idx):
image, text, labels = test_batch
image_features = self.clip_model.get_image_features(image)
text_features = self.text_encoder(text).logits
image_features = image_features / image_features.norm(dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)
return [image_features, text_features, labels]
def test_epoch_end(self, outputs):
image_features = torch.cat([x[0] for x in outputs])
text_features = torch.cat([x[1] for x in outputs])
labels = [label for x in outputs for label in x[2]]
print(image_features.shape, text_features.shape, len(labels))
self.get_metrics(image_features, text_features, labels, 100)
def get_metrics(self, image_features, text_features, labels, logit_scale):
        # Compute similarity; supports the case of multiple positives per query
        # (e.g. one image with several captions).
        # Needed for image-to-text, since one image may correspond to multiple texts;
        # not needed for text-to-image (usually each text has exactly one matching image).
# metrics = {}
logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
logits_per_text = logits_per_image.t().detach().cpu()
logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}
        label2idx = {}  # map each label to the list of indices where it appears
repeat_id = []
for i, label in enumerate(labels):
if label not in label2idx:
label2idx[label] = [i]
else:
                # this label has been seen before: record the index so its score can be
                # down-weighted when computing the text-to-image metrics
label2idx[label].append(i)
repeat_id.append(i)
        # print(label2idx)  # index list recorded for every label
# print('repeat_id:', repeat_id)
ground_truth = [label2idx[label] for label in labels]
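        # Toy illustration (hypothetical labels): labels = ['a', 'b', 'a'] gives
        # label2idx = {'a': [0, 2], 'b': [1]}, repeat_id = [2] and
        # ground_truth = [[0, 2], [1], [0, 2]].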
# print(ground_truth)
for name, logit in logits.items():
# print(name, logit.shape)
if name == 'text_to_image':
                logit[:, repeat_id] -= 1e8  # suppress these scores (duplicate images are simply ignored)
r1_stat, r5_stat, r10_stat = [], [], []
ranking = torch.argsort(logit, descending=True) # index of the largest element to the smallest
# print(name, ranking[:, :10])
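        # Walk each query's top-10 candidates: a hit at rank 0 counts towards R@1/R@5/R@10,
        # a hit within the top 5 towards R@5/R@10, and a hit within the top 10 towards R@10.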
for i, each_query in enumerate(ranking[:, :10]):
for j, q in enumerate(each_query):
if q in ground_truth[i]:
if j == 0:
r1_stat.append(1)
r5_stat.append(1)
r10_stat.append(1)
break
if j < 5:
r5_stat.append(1)
r10_stat.append(1)
break
if j < 10:
r10_stat.append(1)
break
print(f'{name} r1:{sum(r1_stat)/len(logit)}, r5:{sum(r5_stat)/len(logit)}, r10:{sum(r10_stat)/len(logit)}')
def configure_optimizers(self):
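        # Per-model learning rates below match the pre-training hyperparameters reported in the CLIP paper.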
lr = {
"RN50": 5e-4,
"RN101": 5e-4,
"RN50x4": 5e-4,
"RN50x16": 4e-4,
"RN50x64": 3.6e-4,
"ViT-B/32": 5e-4,
"ViT-B/16": 5e-4,
"ViT-L/14": 4e-4,
"ViT-L/14-336px": 2e-5
}[self.model_name]
optimizer = torch.optim.AdamW(
[{'params': self.clip_model.parameters()}, {'params': self.text_encoder.parameters()}],
lr=lr,
betas=(
0.9,
0.98 if self.isViT else 0.999
),
eps=1e-6 if self.isViT else 1e-8,
weight_decay=0.2
)
        # Source: https://github.com/openai/CLIP/issues/107
        # The issue suggests a cosine schedule with warmup
        # (pip install 'git+https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup');
        # the built-in CosineAnnealingWarmRestarts is used here instead.
lr_scheduler = CosineAnnealingWarmRestarts(
optimizer,
T_0=2000
)
return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
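# Example launch (a sketch; the csv/root paths below are placeholders for your own
# Flickr-style annotation files and image directories):
#   python clip_finetune_flickr.py \
#       --train_filename train.csv --train_root ./images \
#       --val_filename val.csv --val_root ./images \
#       --test_filename test.csv --test_root ./images \
#       --batch_size 128 --num_gpus 2 --num_epoches 1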
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# model_name
parser.add_argument('--model', type=str,
default="ViT-B/32",
help='model definition')
# experiment setting
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--num_epoches', type=int, default=1)
parser.add_argument('--num_gpus', type=int, default=2)
# dataset
parser.add_argument('--train_filename', type=str,
help='dir or csv file')
parser.add_argument('--train_root', type=str,
help='image root path')
parser.add_argument('--val_filename', type=str,
help='dir or csv file')
parser.add_argument('--val_root', type=str,
help='image root path')
parser.add_argument('--test_filename', type=str,
help='dir or csv file')
parser.add_argument('--test_root', type=str,
help='image root path')
parser.add_argument('--num_workers', type=int, default=0)
    # huggingface pretrained model
    parser.add_argument('--pretrain_model', type=str,
                        default="openai/clip-vit-base-patch32",
                        help='default: load from openai')  # NOTE: "wf-genius/TaiYi-CLIP-ViT-B-32" is a checkpoint trained by the author
args = parser.parse_args()
dm = FlickrDataModule(args)
model = CLIPLightning(model_name=args.model, minibatch_size=args.batch_size//2)
trainer = pl.Trainer(gpus=args.num_gpus, precision=16, max_epochs=args.num_epoches)
trainer.test(model, dm) # zero-shot test
trainer.fit(model, dm) # finetune on train set
trainer.test(model, dm) # test again