import sys
sys.path.append('../../')
from data.clip_dataloader.flickr import FlickrDataModule
import pytorch_lightning as pl
import numpy as np
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import torch.nn.functional as F
import math
import copy
import argparse
from transformers import CLIPModel, BertForSequenceClassification
class CLIPLightning(pl.LightningModule):
    def __init__(self, model_name='ViT-B/32', minibatch_size=2):
        """A Lightning wrapper for a CLIP model as specified in the paper.

        Args:
            model_name (str): A case-sensitive visual model name.
            minibatch_size (int): Number of samples encoded per forward pass inside training_step.
        """
        super().__init__()
        self.prepare_data_per_node = True
        self.model_name = 'ViT-B/32'  # pinned, since the vision backbone below is always the ViT-B/32 CLIP checkpoint
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")  # NOTE: load from OpenAI
        self.text_encoder = BertForSequenceClassification.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese")
        self.minibatch_size = minibatch_size
        self.isViT = 'ViT' in self.model_name
        self.automatic_optimization = False

    # Training loss: https://github.com/openai/CLIP/issues/83
    # Mini-batching thanks to https://github.com/crowsonkb / https://twitter.com/RiversHaveWings
    # Multi-GPU support: https://github.com/MicPie/clasp
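    # training_step below implements a gradient-cached variant of the CLIP contrastive loss:
    # all minibatch embeddings are first computed under no_grad and gathered across GPUs, then
    # each minibatch is re-encoded with gradients enabled, its slice is swapped into the cached
    # embeddings, and the full-batch loss is backpropagated chunk by chunk. The effective
    # contrastive batch therefore spans every GPU while activation memory stays bounded by
    # the minibatch size.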
    def training_step(self, train_batch, idx):
        # get optimizers and scheduler
        optimizer = self.optimizers()

        image, text, labels = train_batch
        n = math.ceil(len(image) // self.minibatch_size)
        image_mbs = torch.chunk(image, n)
        text_mbs = torch.chunk(text, n)
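
        # Cache normalized embeddings for the whole batch without building a graph; they serve
        # as the fixed "other side" of the similarity matrix in the per-minibatch updates below.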
        with torch.no_grad():
            ims = [F.normalize(self.clip_model.get_image_features(im), dim=1) for im in image_mbs]
            txt = [F.normalize(self.text_encoder(t).logits, dim=1) for t in text_mbs]
            # gather from all GPUs: the loss must be computed over the embeddings from every GPU
            ims = self.all_gather(torch.cat(ims))
            txt = self.all_gather(torch.cat(txt))

            if len(ims.shape) == 3:
                ims = list(ims)
                txt = list(txt)
            else:
                ims = [ims]
                txt = [txt]

            image_logits = torch.cat(ims) @ torch.cat(txt).t() * self.clip_model.logit_scale.exp()
            ground_truth = torch.arange(len(image_logits)).long().to(image_logits.device)
            loss = (F.cross_entropy(image_logits, ground_truth) +
                    F.cross_entropy(image_logits.t(), ground_truth)).div(2)
            acc_i = (torch.argmax(image_logits, 1) == ground_truth).sum()
            acc_t = (torch.argmax(image_logits, 0) == ground_truth).sum()
            self.log_dict({'loss': loss / len(ims), 'acc': (acc_i + acc_t) / 2 / len(image) / len(ims)}, prog_bar=True)

        if isinstance(optimizer, list):
            optimizer = optimizer[0]
        optimizer.zero_grad()
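
        # Re-encode each minibatch with gradients, overwrite this rank's slice of the cached
        # embeddings, and backpropagate the full contrastive loss; gradients accumulate across
        # minibatches until the single optimizer.step() below.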
        # image loss
        for j, mb in enumerate(image_mbs[:-1]):
            # the last chunk of samples is discarded (workaround for an alignment bug)
            images_tmp = copy.deepcopy(ims)
            images_tmp[self.global_rank][j * self.minibatch_size:(j + 1) * self.minibatch_size] = \
                F.normalize(self.clip_model.get_image_features(mb), dim=1)
            image_logits = torch.cat(images_tmp) @ torch.cat(txt).t() * self.clip_model.logit_scale.exp()
            ground_truth = torch.arange(len(image_logits)).long().to(image_logits.device)
            loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(image_logits.t(), ground_truth)) / 2
            self.manual_backward(loss)

        # text loss
        for j, mb in enumerate(text_mbs[:-1]):
            text_tmp = copy.deepcopy(txt)
            text_tmp[self.global_rank][j * self.minibatch_size:(j + 1) * self.minibatch_size] = \
                F.normalize(self.text_encoder(mb).logits, dim=1)
            image_logits = torch.cat(ims) @ torch.cat(text_tmp).t() * self.clip_model.logit_scale.exp()
            loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(image_logits.t(), ground_truth)) / 2
            self.manual_backward(loss)

        optimizer.step()
        lr_scheduler = self.lr_schedulers()
        lr_scheduler.step()
        self.clip_model.logit_scale.data.clamp_(-np.log(100), np.log(100))
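
    # Validation uses the symmetric contrastive loss on the local batch only (no all_gather)
    # and returns normalized embeddings so retrieval metrics can be computed at epoch end.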
    def validation_step(self, val_batch, idx):
        image, text, labels = val_batch
        img_embed = self.clip_model.get_image_features(image)
        txt_embed = self.text_encoder(text).logits
        # print(img_embed.shape)
        image_norm = F.normalize(img_embed, dim=1)
        text_norm = F.normalize(txt_embed, dim=1)
        image_logits = image_norm @ text_norm.t() * self.clip_model.logit_scale.exp()
        text_logits = text_norm @ image_norm.t() * self.clip_model.logit_scale.exp()
        # print(image_logits.shape)
        # image_logits, text_logits = self.forward(image, text)
        ground_truth = torch.arange(len(image_logits)).long().to(image_logits.device)
        loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(text_logits, ground_truth)).div(2)
        self.log('val_loss', loss, prog_bar=True)
        return [image_norm, text_norm, labels]

    def validation_epoch_end(self, outputs):
        image_features = torch.cat([x[0] for x in outputs])
        text_features = torch.cat([x[1] for x in outputs])
        labels = [label for x in outputs for label in x[2]]
        print(image_features.shape, text_features.shape, len(labels))
        self.get_metrics(image_features, text_features, labels, 100)

    def test_step(self, test_batch, idx):
        image, text, labels = test_batch
        image_features = self.clip_model.get_image_features(image)
        text_features = self.text_encoder(text).logits
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)
        return [image_features, text_features, labels]

    def test_epoch_end(self, outputs):
        image_features = torch.cat([x[0] for x in outputs])
        text_features = torch.cat([x[1] for x in outputs])
        labels = [label for x in outputs for label in x[2]]
        print(image_features.shape, text_features.shape, len(labels))
        self.get_metrics(image_features, text_features, labels, 100)

    def get_metrics(self, image_features, text_features, labels, logit_scale):
        # Compute similarities; supports multiple positives per query (e.g. one image with
        # several captions). This is needed for img2txt, because one image may correspond to
        # several texts; txt2img does not need it (a text usually has exactly one image).
        # metrics = {}
        logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
        logits_per_text = logits_per_image.t().detach().cpu()
        logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}
        label2idx = {}  # map each label to the indices that carry it
        repeat_id = []
        for i, label in enumerate(labels):
            if label not in label2idx:
                label2idx[label] = [i]
            else:
                # this label has been seen before; record the index so its score can be
                # suppressed later when computing the txt2img metrics
                label2idx[label].append(i)
                repeat_id.append(i)
        # print(label2idx)  # indices recorded for every label
        # print('repeat_id:', repeat_id)

        ground_truth = [label2idx[label] for label in labels]
        # print(ground_truth)

        for name, logit in logits.items():
            # print(name, logit.shape)
            if name == 'text_to_image':
                logit[:, repeat_id] -= 1e8  # suppress these scores (repeated images are simply ignored)
            r1_stat, r5_stat, r10_stat = [], [], []
            ranking = torch.argsort(logit, descending=True)  # indices from the largest element to the smallest
            # print(name, ranking[:, :10])
            for i, each_query in enumerate(ranking[:, :10]):
                for j, q in enumerate(each_query):
                    if q in ground_truth[i]:
                        if j == 0:
                            r1_stat.append(1)
                            r5_stat.append(1)
                            r10_stat.append(1)
                            break
                        if j < 5:
                            r5_stat.append(1)
                            r10_stat.append(1)
                            break
                        if j < 10:
                            r10_stat.append(1)
                            break
            print(f'{name} r1:{sum(r1_stat)/len(logit)}, r5:{sum(r5_stat)/len(logit)}, r10:{sum(r10_stat)/len(logit)}')

    def configure_optimizers(self):
        lr = {
            "RN50": 5e-4,
            "RN101": 5e-4,
            "RN50x4": 5e-4,
            "RN50x16": 4e-4,
            "RN50x64": 3.6e-4,
            "ViT-B/32": 5e-4,
            "ViT-B/16": 5e-4,
            "ViT-L/14": 4e-4,
            "ViT-L/14-336px": 2e-5
        }[self.model_name]

        optimizer = torch.optim.AdamW(
            [{'params': self.clip_model.parameters()}, {'params': self.text_encoder.parameters()}],
            lr=lr,
            betas=(
                0.9,
                0.98 if self.isViT else 0.999
            ),
            eps=1e-6 if self.isViT else 1e-8,
            weight_decay=0.2
        )

        # Source: https://github.com/openai/CLIP/issues/107
        # For warmup support: pip install 'git+https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup'
        # (CosineAnnealingWarmupRestarts); torch's built-in CosineAnnealingWarmRestarts is used here instead.
        lr_scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=2000
        )
        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
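
    # Note: because optimization is manual, lr_scheduler.step() is called once per training
    # step (see training_step above), so T_0=2000 means a cosine restart every 2000 optimizer
    # steps rather than every 2000 epochs. The learning-rate table and AdamW settings follow
    # the per-model hyperparameters published for the original CLIP training runs.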


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # model_name
    parser.add_argument('--model', type=str,
                        default="ViT-B/32",
                        help='model definition')

    # experiment setting
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_epoches', type=int, default=1)
    parser.add_argument('--num_gpus', type=int, default=2)

    # dataset
    parser.add_argument('--train_filename', type=str,
                        help='dir or csv file')
    parser.add_argument('--train_root', type=str,
                        help='image root path')
    parser.add_argument('--val_filename', type=str,
                        help='dir or csv file')
    parser.add_argument('--val_root', type=str,
                        help='image root path')
    parser.add_argument('--test_filename', type=str,
                        help='dir or csv file')
    parser.add_argument('--test_root', type=str,
                        help='image root path')
    parser.add_argument('--num_workers', type=int, default=0)

    # huggingface pretrained model definition
    parser.add_argument('--pretrain_model', type=str,
                        default="openai/clip-vit-base-patch32",
                        help='default: load from openai')  # NOTE: "wf-genius/TaiYi-CLIP-ViT-B-32" is the checkpoint I trained

    args = parser.parse_args()
    dm = FlickrDataModule(args)

    model = CLIPLightning(model_name=args.model, minibatch_size=args.batch_size // 2)
    trainer = pl.Trainer(gpus=args.num_gpus, precision=16, max_epochs=args.num_epoches)
    trainer.test(model, dm)   # zero-shot test
    trainer.fit(model, dm)    # finetune on the train set
    trainer.test(model, dm)   # test again
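
# Example invocation (a sketch; the script name and data paths are placeholders to be adapted
# to the local Flickr CSV/image layout expected by FlickrDataModule):
#   python clip_finetune_flickr.py \
#       --train_filename train.csv --train_root images/train \
#       --val_filename val.csv --val_root images/val \
#       --test_filename test.csv --test_root images/test \
#       --batch_size 128 --num_gpus 2 --num_epoches 1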