import sys
sys.path.append('../../')
from data.clip_dataloader.flickr import FlickrDataModule
import pytorch_lightning as pl
import numpy as np
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import torch.nn.functional as F
import math
import copy
import argparse
from transformers import CLIPModel, BertForSequenceClassification

class CLIPLightning(pl.LightningModule):
    def __init__(self, model_name='ViT-B/32', minibatch_size=2):
        """A lightning wrapper for a CLIP model as specified in the paper.

        Args:
            model_name (str): A case sensitive visual model name.
            config (dict): A dictionary containing the CLIP instantiation parameters.
        """
        super().__init__()

        self.prepare_data_per_node = True
        self.model_name = model_name
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")  # NOTE: vision tower loaded from the OpenAI checkpoint
        self.text_encoder = BertForSequenceClassification.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese")
        self.minibatch_size = minibatch_size
        self.isViT = 'ViT' in self.model_name
        self.automatic_optimization = False

    # Training loss: https://github.com/openai/CLIP/issues/83
    # Mini-batching thanks to https://github.com/crowsonkb / https://twitter.com/RiversHaveWings
    # Multi-GPU support: https://github.com/MicPie/clasp
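    # Rough flow of training_step (manual optimization):
    #   1. split the batch into minibatches and encode all images/texts under
    #      torch.no_grad(), then all_gather the embeddings across GPUs;
    #   2. re-encode one minibatch at a time *with* gradients, splice it into the
    #      cached embedding matrix, and backprop the contrastive loss for that slice;
    #   3. step the optimizer once after all minibatches have been accumulated.
    # This keeps activation memory proportional to the minibatch size while the loss
    # still sees the full gathered batch of negatives.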

    def training_step(self, train_batch, idx):
        # get optimizers and scheduler
        optimizer = self.optimizers()

        image, text, labels = train_batch
        n = math.ceil(len(image) / self.minibatch_size)
        image_mbs = torch.chunk(image, n)
        text_mbs = torch.chunk(text, n)

        with torch.no_grad():
            ims = [F.normalize(self.clip_model.get_image_features(im), dim=1) for im in image_mbs]
            txt = [F.normalize(self.text_encoder(t).logits, dim=1) for t in text_mbs]
            # gather from all GPUs: the contrastive loss must be computed over the embeddings collected from every GPU
            ims = self.all_gather(torch.cat(ims))
            txt = self.all_gather(torch.cat(txt))

            if len(ims.shape) == 3:
                ims = list(ims)
                txt = list(txt)
            else:
                ims = [ims]
                txt = [txt]

            image_logits = torch.cat(ims) @ torch.cat(txt).t() * self.clip_model.logit_scale.exp()
            ground_truth = torch.arange(len(image_logits)).long().to(image_logits.device)
            loss = (F.cross_entropy(image_logits, ground_truth) +
                    F.cross_entropy(image_logits.t(), ground_truth)).div(2)
            acc_i = (torch.argmax(image_logits, 1) == ground_truth).sum()
            acc_t = (torch.argmax(image_logits, 0) == ground_truth).sum()
            self.log_dict({'loss': loss / len(ims), 'acc': (acc_i + acc_t) / 2 / len(image) / len(ims)}, prog_bar=True)

        if isinstance(optimizer, list):
            optimizer = optimizer[0]
        optimizer.zero_grad()

        # image loss
        for j, mb in enumerate(image_mbs[:-1]):
            # the last minibatch is dropped (workaround for an alignment bug)
            images_tmp = copy.deepcopy(ims)
            images_tmp[self.global_rank][j * self.minibatch_size:(j+1)*self.minibatch_size] = \
                F.normalize(self.clip_model.get_image_features(mb), dim=1)
            image_logits = torch.cat(images_tmp) @ torch.cat(txt).t() * self.clip_model.logit_scale.exp()
            ground_truth = torch.arange(len(image_logits)).long().to(image_logits.device)
            loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(image_logits.t(), ground_truth))/2
            self.manual_backward(loss)

        # text loss
        for j, mb in enumerate(text_mbs[:-1]):
            text_tmp = copy.deepcopy(txt)
            text_tmp[self.global_rank][j * self.minibatch_size:(j+1)*self.minibatch_size] = \
                F.normalize(self.text_encoder(mb).logits, dim=1)
            image_logits = torch.cat(ims) @ torch.cat(text_tmp).t() * self.clip_model.logit_scale.exp()
            loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(image_logits.t(), ground_truth))/2
            self.manual_backward(loss)

        optimizer.step()
        lr_scheduler = self.lr_schedulers()
        lr_scheduler.step()
        # keep the learned temperature bounded so that exp(logit_scale) stays within [1/100, 100]
        self.clip_model.logit_scale.data.clamp_(-np.log(100), np.log(100))

    def validation_step(self, val_batch, idx):
        image, text, labels = val_batch
        img_embed = self.clip_model.get_image_features(image)
        txt_embed = self.text_encoder(text).logits
        # print(img_embed.shape)
        image_norm = F.normalize(img_embed, dim=1)
        text_norm = F.normalize(txt_embed, dim=1)
        image_logits = image_norm @ text_norm.t() * self.clip_model.logit_scale.exp()
        text_logits = text_norm @ image_norm.t() * self.clip_model.logit_scale.exp()
        # print(image_logits.shape)
        # image_logits, text_logits = self.forward(image, text)
        ground_truth = torch.arange(len(image_logits)).long().to(image_logits.device)
        loss = (F.cross_entropy(image_logits, ground_truth) + F.cross_entropy(text_logits, ground_truth)).div(2)
        self.log('val_loss', loss, prog_bar=True)
        return [image_norm, text_norm, labels]

    def validation_epoch_end(self, outputs):
        image_features = torch.cat([x[0] for x in outputs])
        text_features = torch.cat([x[1] for x in outputs])
        labels = [label for x in outputs for label in x[2]]
        print(image_features.shape, text_features.shape, len(labels))
        self.get_metrics(image_features, text_features, labels, 100)

    def test_step(self, test_batch, idx):
        image, text, labels = test_batch
        image_features = self.clip_model.get_image_features(image)
        text_features = self.text_encoder(text).logits
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)
        return [image_features, text_features, labels]

    def test_epoch_end(self, outputs):
        image_features = torch.cat([x[0] for x in outputs])
        text_features = torch.cat([x[1] for x in outputs])
        labels = [label for x in outputs for label in x[2]]
        print(image_features.shape, text_features.shape, len(labels))
        self.get_metrics(image_features, text_features, labels, 100)

    def get_metrics(self, image_features, text_features, labels, logit_scale):
        # Compute similarities, supporting the case where one label has multiple samples
        # (e.g. one image with several captions).
        # This matters for image-to-text, since an image may correspond to several texts;
        # for text-to-image it is usually unnecessary (each text normally has one matching image).
        # metrics = {}
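        # Illustrative note on the recall metrics computed below: if the first ground-truth
        # index for a query appears at rank j == 2 (0-based), the query counts toward
        # R@5 and R@10 but not R@1.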
        logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
        logits_per_text = logits_per_image.t().detach().cpu()

        logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}

        label2idx = {}  # map each label to the list of indices where it appears
        repeat_id = []
        for i, label in enumerate(labels):
            if label not in label2idx:
                label2idx[label] = [i]
            else:
                # this label has been seen before; record the index so its score can be
                # down-weighted later when computing the text-to-image metrics
                label2idx[label].append(i)
                repeat_id.append(i)
        # print(label2idx)  # indices recorded for each label

        # print('repeat_id:', repeat_id)
        ground_truth = [label2idx[label] for label in labels]
        # print(ground_truth)

        for name, logit in logits.items():
            # print(name, logit.shape)
            if name == 'text_to_image':
                logit[:, repeat_id] -= 1e8   # suppress duplicated images so they are effectively ignored in the ranking
            r1_stat, r5_stat, r10_stat = [], [], []
            ranking = torch.argsort(logit, descending=True)  # index of the largest element to the smallest
            # print(name, ranking[:, :10])
            for i, each_query in enumerate(ranking[:, :10]):
                for j, q in enumerate(each_query):
                    if q in ground_truth[i]:
                        if j == 0:
                            r1_stat.append(1)
                            r5_stat.append(1)
                            r10_stat.append(1)
                            break
                        if j < 5:
                            r5_stat.append(1)
                            r10_stat.append(1)
                            break
                        if j < 10:
                            r10_stat.append(1)
                            break
            print(f'{name} r1:{sum(r1_stat)/len(logit)}, r5:{sum(r5_stat)/len(logit)}, r10:{sum(r10_stat)/len(logit)}')

    def configure_optimizers(self):
        lr = {
            "RN50": 5e-4,
            "RN101": 5e-4,
            "RN50x4": 5e-4,
            "RN50x16": 4e-4,
            "RN50x64": 3.6e-4,
            "ViT-B/32": 5e-4,
            "ViT-B/16": 5e-4,
            "ViT-L/14": 4e-4,
            "ViT-L/14-336px": 2e-5
        }[self.model_name]

        optimizer = torch.optim.AdamW(
            [{'params': self.clip_model.parameters()}, {'params': self.text_encoder.parameters()}],
            lr=lr,
            betas=(
                0.9,
                0.98 if self.isViT else 0.999
            ),
            eps=1e-6 if self.isViT else 1e-8,
            weight_decay=0.2
        )

        # Source: https://github.com/openai/CLIP/issues/107
        # Note: this uses PyTorch's built-in CosineAnnealingWarmRestarts (no warmup phase).
        # For a warmup variant (CosineAnnealingWarmupRestarts), see
        # pip install 'git+https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup'
        lr_scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=2000
        )
        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # model_name
    parser.add_argument('--model', type=str,
                        default="ViT-B/32",
                        help='model definition')

    # experiment setting
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_epoches', type=int, default=1)
    parser.add_argument('--num_gpus', type=int, default=2)

    # dataset
    parser.add_argument('--train_filename', type=str,
                        help='dir or csv file')
    parser.add_argument('--train_root', type=str,
                        help='image root path')
    parser.add_argument('--val_filename', type=str,
                        help='dir or csv file')
    parser.add_argument('--val_root', type=str,
                        help='image root path')
    parser.add_argument('--test_filename', type=str,
                        help='dir or csv file')
    parser.add_argument('--test_root', type=str,
                        help='image root path')
    parser.add_argument('--num_workers', type=int, default=0)

    # huggingface pretrained model
    parser.add_argument('--pretrain_model', type=str,
                        default="openai/clip-vit-base-patch32",
                        help='default: load from the OpenAI checkpoint')    # NOTE: "wf-genius/TaiYi-CLIP-ViT-B-32" is the checkpoint I trained

    args = parser.parse_args()
    dm = FlickrDataModule(args)

    model = CLIPLightning(model_name=args.model, minibatch_size=args.batch_size//2)
    trainer = pl.Trainer(gpus=args.num_gpus, precision=16, max_epochs=args.num_epoches)
    trainer.test(model, dm)  # zero-shot test
    trainer.fit(model, dm)  # finetune on train set
    trainer.test(model, dm)  # test again
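
# Example launch (illustrative sketch only; the script name and data paths are
# placeholders, not taken from this repository):
#   python finetune_flickr_clip.py \
#       --train_filename /path/to/train.csv --train_root /path/to/images \
#       --val_filename /path/to/val.csv --val_root /path/to/images \
#       --test_filename /path/to/test.csv --test_root /path/to/images \
#       --batch_size 128 --num_epoches 1 --num_gpus 2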