# Code is mainly borrowed from the official implementation of MoCo (https://github.com/facebookresearch/moco)

import numpy as np
import torch
import torch.nn as nn
from torch_utils import misc
from torch_utils import persistence

#----------------------------------------------------------------------------

# Contrastive head
@persistence.persistent_class
class CLHead(torch.nn.Module):
    def __init__(self,
        inplanes     = 256,   # Number of input features
        temperature  = 0.2,   # Temperature of logits
        queue_size   = 3500,  # Number of stored negative samples
        momentum     = 0.999, # Momentum for updating the key (momentum) encoder
    ):
        super().__init__()
        self.inplanes = inplanes
        self.temperature = temperature
        self.queue_size = queue_size
        self.m = momentum

        self.mlp = nn.Sequential(nn.Linear(inplanes, inplanes), nn.ReLU(), nn.Linear(inplanes, 128))
        self.momentum_mlp = nn.Sequential(nn.Linear(inplanes, inplanes), nn.ReLU(), nn.Linear(inplanes, 128))
        self.momentum_mlp.requires_grad_(False)

        for param_q, param_k in zip(self.mlp.parameters(), self.momentum_mlp.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue
        self.register_buffer("queue", torch.randn(128, self.queue_size))
        self.queue = nn.functional.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))


    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.mlp.parameters(), self.momentum_mlp.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        # gather keys before updating queue
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]
        keys = keys.T
        ptr = int(self.queue_ptr)
        if batch_size > self.queue_size:
            # The gathered batch is larger than the queue: overwrite the whole
            # queue with the newest keys. The pointer is left unchanged since
            # every slot already holds fresh keys.
            self.queue[:, :] = keys[:, :self.queue_size]
        elif ptr + batch_size > self.queue_size:
            # The batch does not fit before the end of the queue: wrap around.
            self.queue[:, ptr:] = keys[:, :self.queue_size - ptr]
            self.queue[:, :batch_size - (self.queue_size - ptr)] = keys[:, self.queue_size - ptr:]
            self.queue_ptr[0] = batch_size - (self.queue_size - ptr)
        else:
            self.queue[:, ptr:ptr + batch_size] = keys
            self.queue_ptr[0] = ptr + batch_size

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        """

        # If not running in distributed mode, return the raw input directly.
        # The effect of disabling shuffle BN on MoCo is unclear, so we
        # recommend always training InsGen with more than 1 GPU.
        if not torch.distributed.is_initialized():
            return x, torch.arange(x.shape[0])

        # gather from all gpus
        device = x.device
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda(device)

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        """
        # If not running in distributed mode, return the raw input directly.
        # The effect of disabling shuffle BN on MoCo is unclear, so we
        # recommend always training InsGen with more than 1 GPU.
        if not torch.distributed.is_initialized():
            return x

        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]


    def forward(self, im_q, im_k, loss_only=False, update_q=False):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images
        Output:
            logits, targets
        """
        device = im_q.device
        im_q = im_q.to(torch.float32)
        im_k = im_k.to(torch.float32)
        # compute query features
        if im_q.ndim > 2:
            im_q = im_q.mean([2,3])
        q = self.mlp(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder
            if im_k.ndim > 2:
                im_k = im_k.mean([2,3])
            # shuffle for making use of BN
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)
            k = self.momentum_mlp(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.temperature

        # labels: the positive key sits at index 0 of each row
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=device)

        # dequeue and enqueue
        if not loss_only:
            if update_q:
                # Enqueue momentum-encoded queries instead of the keys.
                with torch.no_grad():
                    temp_im_q, idx_unshuffle = self._batch_shuffle_ddp(im_q)
                    temp_q = self.momentum_mlp(temp_im_q)
                    temp_q = nn.functional.normalize(temp_q, dim=1)
                    temp_q = self._batch_unshuffle_ddp(temp_q, idx_unshuffle)
                    self._dequeue_and_enqueue(temp_q)
            else:
                self._dequeue_and_enqueue(k)

        # calculate loss
        loss = nn.functional.cross_entropy(logits, labels)

        return loss 

@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
        for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output

#----------------------------------------------------------------------------
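
# Minimal usage sketch (not part of the original training pipeline). It assumes
# single-process, non-distributed execution and that the tensors fed to CLHead
# are discriminator features with `inplanes` channels; the shapes below are
# hypothetical. With loss_only=True the negative queue is left untouched, so no
# distributed collectives are triggered.

if __name__ == '__main__':
    head = CLHead(inplanes=256, temperature=0.2, queue_size=3500)
    feat_q = torch.randn(8, 256, 4, 4)  # hypothetical query features (NxCxHxW)
    feat_k = torch.randn(8, 256, 4, 4)  # hypothetical key features (a second augmented view)
    loss = head(feat_q, feat_k, loss_only=True)
    print(f'InfoNCE loss: {loss.item():.4f}')

#----------------------------------------------------------------------------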