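"""Fine-tune a CLIP ViT-B/32 vision encoder together with a Chinese RoBERTa text encoder
(Taiyi-CLIP) on Flickr-style image-text pairs, using PyTorch Lightning with manual
optimization and all-gathered, cached features to build a large contrastive batch."""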
import sys
sys.path.append('../../')

from data.clip_dataloader.flickr import FlickrDataModule

import pytorch_lightning as pl
import numpy as np
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import torch.nn.functional as F
import math
import copy
import argparse
from transformers import CLIPModel, BertForSequenceClassification

class CLIPLightning(pl.LightningModule):

    def __init__(self, model_name='ViT-B/32', minibatch_size=2):
        """A lightning wrapper for a CLIP model as specified in the paper.

        Args:
            model_name (str): A case-sensitive visual model name, used to select the
                optimizer hyperparameters in ``configure_optimizers``.
            minibatch_size (int): Size of the sub-batches processed at a time inside
                ``training_step``.
        """
        super().__init__()

        self.prepare_data_per_node = True
        self.model_name = model_name

        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.text_encoder = BertForSequenceClassification.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese")
        self.minibatch_size = minibatch_size
        self.isViT = 'ViT' in self.model_name
        # Backward and optimizer steps are driven by hand in training_step.
        self.automatic_optimization = False
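    # training_step implements a simple gradient-caching scheme: features for the whole
    # (all-gathered) batch are first computed without gradients, and each sub-batch is then
    # re-encoded with gradients while the remaining cached features stay fixed, so the
    # contrastive loss sees the full batch without keeping every activation in memory.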
    def training_step(self, train_batch, idx):
        optimizer = self.optimizers()

        image, text, labels = train_batch
        n = math.ceil(len(image) / self.minibatch_size)
        image_mbs = torch.chunk(image, n)
        text_mbs = torch.chunk(text, n)

        # Cache normalized features for the local batch without tracking gradients.
        with torch.no_grad():
            ims = [F.normalize(self.clip_model.get_image_features(im), dim=1) for im in image_mbs]
            txt = [F.normalize(self.text_encoder(t).logits, dim=1) for t in text_mbs]

            # Gather features from all devices to form the global contrastive batch.
            ims = self.all_gather(torch.cat(ims))
            txt = self.all_gather(torch.cat(txt))

            if len(ims.shape) == 3:
                ims = list(ims)
                txt = list(txt)
            else:
                ims = [ims]
                txt = [txt]

            image_logits = torch.cat(ims) @ torch.cat(txt).t() * self.clip_model.logit_scale.exp()
            ground_truth = torch.arange(len(image_logits)).long().to(image_logits.device)
            loss = (F.cross_entropy(image_logits, ground_truth) +
                    F.cross_entropy(image_logits.t(), ground_truth)).div(2)
            acc_i = (torch.argmax(image_logits, 1) == ground_truth).sum()
            acc_t = (torch.argmax(image_logits, 0) == ground_truth).sum()
            self.log_dict({'loss': loss / len(ims), 'acc': (acc_i + acc_t) / 2 / len(image) / len(ims)}, prog_bar=True)

        if isinstance(optimizer, list):
            optimizer = optimizer[0]
        optimizer.zero_grad()

        # Image loss: re-encode each image sub-batch with gradients, splice it into this
        # rank's cached features, and backpropagate the full-batch contrastive loss.
        for j, mb in enumerate(image_mbs[:-1]):
            images_tmp = copy.deepcopy(ims)
            images_tmp[self.global_rank][j * self.minibatch_size:(j + 1) * self.minibatch_size] = \
                F.normalize(self.clip_model.get_image_features(mb), dim=1)
            image_logits = torch.cat(images_tmp) @ torch.cat(txt).t() * self.clip_model.logit_scale.exp()
            ground_truth = torch.arange(len(image_logits)).long().to(image_logits.device)
            loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(image_logits.t(), ground_truth)) / 2
            self.manual_backward(loss)

        # Text loss: same procedure for each text sub-batch.
        for j, mb in enumerate(text_mbs[:-1]):
            text_tmp = copy.deepcopy(txt)
            text_tmp[self.global_rank][j * self.minibatch_size:(j + 1) * self.minibatch_size] = \
                F.normalize(self.text_encoder(mb).logits, dim=1)
            image_logits = torch.cat(ims) @ torch.cat(text_tmp).t() * self.clip_model.logit_scale.exp()
            loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(image_logits.t(), ground_truth)) / 2
            self.manual_backward(loss)

        optimizer.step()
        lr_scheduler = self.lr_schedulers()
        lr_scheduler.step()
        # Keep the learned temperature in range: exp(logit_scale) is clamped to at most 100.
        self.clip_model.logit_scale.data.clamp_(-np.log(100), np.log(100))

    def validation_step(self, val_batch, idx):
        image, text, labels = val_batch
        img_embed = self.clip_model.get_image_features(image)
        txt_embed = self.text_encoder(text).logits

        image_norm = F.normalize(img_embed, dim=1)
        text_norm = F.normalize(txt_embed, dim=1)
        image_logits = image_norm @ text_norm.t() * self.clip_model.logit_scale.exp()
        text_logits = text_norm @ image_norm.t() * self.clip_model.logit_scale.exp()

        ground_truth = torch.arange(len(image_logits)).long().to(image_logits.device)
        loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(text_logits, ground_truth)).div(2)
        self.log('val_loss', loss, prog_bar=True)
        return [image_norm, text_norm, labels]

    def validation_epoch_end(self, outputs):
        image_features = torch.cat([x[0] for x in outputs])
        text_features = torch.cat([x[1] for x in outputs])
        labels = [label for x in outputs for label in x[2]]
        print(image_features.shape, text_features.shape, len(labels))
        self.get_metrics(image_features, text_features, labels, 100)

    def test_step(self, test_batch, idx):
        image, text, labels = test_batch
        image_features = self.clip_model.get_image_features(image)
        text_features = self.text_encoder(text).logits
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)
        return [image_features, text_features, labels]

    def test_epoch_end(self, outputs):
        image_features = torch.cat([x[0] for x in outputs])
        text_features = torch.cat([x[1] for x in outputs])
        labels = [label for x in outputs for label in x[2]]
        print(image_features.shape, text_features.shape, len(labels))
        self.get_metrics(image_features, text_features, labels, 100)
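    # get_metrics reports retrieval Recall@1/5/10 in both directions. Samples that share a
    # label are treated as matches for the same query; for text->image retrieval the
    # duplicated gallery columns are masked out so each image is counted only once.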
    def get_metrics(self, image_features, text_features, labels, logit_scale):
        # Similarity matrices between all queries and gallery items.
        logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
        logits_per_text = logits_per_image.t().detach().cpu()

        logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}

        # Map each label to every index that carries it; remember the duplicate occurrences.
        label2idx = {}
        repeat_id = []
        for i, label in enumerate(labels):
            if label not in label2idx:
                label2idx[label] = [i]
            else:
                label2idx[label].append(i)
                repeat_id.append(i)

        # A query at position i is correct if it retrieves any index sharing its label.
        ground_truth = [label2idx[label] for label in labels]

        for name, logit in logits.items():
            if name == 'text_to_image':
                # Mask duplicated gallery columns so every image appears only once.
                logit[:, repeat_id] -= 1e8
            r1_stat, r5_stat, r10_stat = [], [], []
            ranking = torch.argsort(logit, descending=True)

            # Count the rank of the first correct hit within the top 10 for each query.
            for i, each_query in enumerate(ranking[:, :10]):
                for j, q in enumerate(each_query):
                    if q in ground_truth[i]:
                        if j == 0:
                            r1_stat.append(1)
                            r5_stat.append(1)
                            r10_stat.append(1)
                            break
                        if j < 5:
                            r5_stat.append(1)
                            r10_stat.append(1)
                            break
                        if j < 10:
                            r10_stat.append(1)
                            break
            print(f'{name} r1:{sum(r1_stat)/len(logit)}, r5:{sum(r5_stat)/len(logit)}, r10:{sum(r10_stat)/len(logit)}')
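    # The per-model learning rates and Adam settings below appear to follow the
    # hyperparameters reported for the corresponding architectures in the CLIP paper.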
    def configure_optimizers(self):
        lr = {
            "RN50": 5e-4,
            "RN101": 5e-4,
            "RN50x4": 5e-4,
            "RN50x16": 4e-4,
            "RN50x64": 3.6e-4,
            "ViT-B/32": 5e-4,
            "ViT-B/16": 5e-4,
            "ViT-L/14": 4e-4,
            "ViT-L/14-336px": 2e-5
        }[self.model_name]

        optimizer = torch.optim.AdamW(
            [{'params': self.clip_model.parameters()}, {'params': self.text_encoder.parameters()}],
            lr=lr,
            betas=(
                0.9,
                0.98 if self.isViT else 0.999
            ),
            eps=1e-6 if self.isViT else 1e-8,
            weight_decay=0.2
        )

        lr_scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=2000
        )

        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--model', type=str,
                        default="ViT-B/32",
                        help='model definition')

    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_epochs', type=int, default=1)
    parser.add_argument('--num_gpus', type=int, default=2)

    parser.add_argument('--train_filename', type=str,
                        help='dir or csv file')
    parser.add_argument('--train_root', type=str,
                        help='image root path')
    parser.add_argument('--val_filename', type=str,
                        help='dir or csv file')
    parser.add_argument('--val_root', type=str,
                        help='image root path')
    parser.add_argument('--test_filename', type=str,
                        help='dir or csv file')
    parser.add_argument('--test_root', type=str,
                        help='image root path')
    parser.add_argument('--num_workers', type=int, default=0)

    parser.add_argument('--pretrain_model', type=str,
                        default="openai/clip-vit-base-patch32",
                        help='default: load from openai')

    args = parser.parse_args()
    dm = FlickrDataModule(args)

    model = CLIPLightning(model_name=args.model, minibatch_size=args.batch_size // 2)
    trainer = pl.Trainer(gpus=args.num_gpus, precision=16, max_epochs=args.num_epochs)
    # Evaluate the pretrained checkpoint first, then fine-tune and evaluate again.
    trainer.test(model, dm)
    trainer.fit(model, dm)
    trainer.test(model, dm)
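
# Example invocation (script name and data paths are placeholders for your own setup):
#   python <this_script>.py --num_gpus 2 --batch_size 128 --num_epochs 1 \
#       --train_filename /path/to/train.csv --train_root /path/to/train_images \
#       --val_filename /path/to/val.csv --val_root /path/to/val_images \
#       --test_filename /path/to/test.csv --test_root /path/to/test_images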