# Code is mainly borrowed from the official implementation of MoCo (https://github.com/facebookresearch/moco)

import numpy as np
import torch
import torch.nn as nn
from torch_utils import misc
from torch_utils import persistence
#----------------------------------------------------------------------------
# Contrastive head
@persistence.persistent_class
class CLHead(torch.nn.Module):
    def __init__(self,
        inplanes    = 256,      # Number of input features
        temperature = 0.2,      # Temperature of logits
        queue_size  = 3500,     # Number of stored negative samples
        momentum    = 0.999,    # Momentum for updating the key encoder
    ):
        super().__init__()
        self.inplanes = inplanes
        self.temperature = temperature
        self.queue_size = queue_size
        self.m = momentum

        self.mlp = nn.Sequential(nn.Linear(inplanes, inplanes), nn.ReLU(), nn.Linear(inplanes, 128))
        self.momentum_mlp = nn.Sequential(nn.Linear(inplanes, inplanes), nn.ReLU(), nn.Linear(inplanes, 128))
        self.momentum_mlp.requires_grad_(False)
        for param_q, param_k in zip(self.mlp.parameters(), self.momentum_mlp.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False     # not updated by gradient

        # Create the queue of negative keys (one 128-d key per column).
        self.register_buffer("queue", torch.randn(128, self.queue_size))
        self.queue = nn.functional.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder: param_k = m * param_k + (1 - m) * param_q.
        """
        for param_q, param_k in zip(self.mlp.parameters(), self.momentum_mlp.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        # Gather keys from all GPUs before updating the queue.
        keys = concat_all_gather(keys)
        batch_size = keys.shape[0]
        keys = keys.T
        ptr = int(self.queue_ptr)
        if batch_size > self.queue_size:
            # More keys than slots: refresh the whole queue and reset the pointer.
            self.queue[:, :] = keys[:, :self.queue_size]
            self.queue_ptr[0] = 0
        elif ptr + batch_size > self.queue_size:
            # Wrap around the ring buffer: fill the tail, then the head.
            self.queue[:, ptr:] = keys[:, :self.queue_size - ptr]
            self.queue[:, :batch_size - (self.queue_size - ptr)] = keys[:, self.queue_size - ptr:]
            self.queue_ptr[0] = batch_size - (self.queue_size - ptr)
        else:
            # Replace the oldest keys at the current pointer.
            self.queue[:, ptr:ptr + batch_size] = keys
            self.queue_ptr[0] = ptr + batch_size

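    # Worked example of the ring-buffer update above (illustrative numbers):
    # with queue_size=3500, ptr=3400, and a global batch of 256 keys, the
    # first 100 keys fill columns 3400:3500, the remaining 156 wrap around
    # to columns 0:156, and queue_ptr becomes 156.
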
    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        """
        # In the non-distributed case, return the raw input directly.
        # The effect of disabling shuffled BN on MoCo has not been studied,
        # so we recommend always training InsGen with more than 1 GPU.
        if not torch.distributed.is_initialized():
            return x, torch.arange(x.shape[0])
        # Gather from all GPUs.
        device = x.device
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]
        num_gpus = batch_size_all // batch_size_this
        # Random shuffling index over the global batch.
        idx_shuffle = torch.randperm(batch_size_all, device=device)
        # Broadcast the same permutation to all GPUs.
        torch.distributed.broadcast(idx_shuffle, src=0)
        # Index for restoring the original order.
        idx_unshuffle = torch.argsort(idx_shuffle)
        # Shuffled index for this GPU.
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
        return x_gather[idx_this], idx_unshuffle

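    # Note: argsort inverts the broadcast permutation, so unshuffling always
    # restores the original order. E.g., idx_shuffle = [2, 0, 1] gives
    # idx_unshuffle = argsort([2, 0, 1]) = [1, 2, 0], and
    # x[idx_shuffle][idx_unshuffle] == x.
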
    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        """
        # In the non-distributed case, return the raw input directly
        # (see _batch_shuffle_ddp above).
        if not torch.distributed.is_initialized():
            return x
        # Gather from all GPUs.
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]
        num_gpus = batch_size_all // batch_size_this
        # Restored index for this GPU.
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]
        return x_gather[idx_this]

    def forward(self, im_q, im_k, loss_only=False, update_q=False):
        """
        Input:
            im_q: a batch of query images or feature maps (pooled over H, W if 4-D)
            im_k: a batch of key images or feature maps
        Output:
            loss: the contrastive (InfoNCE) loss
        """
        device = im_q.device
        im_q = im_q.to(torch.float32)
        im_k = im_k.to(torch.float32)

        # Compute query features.
        if im_q.ndim > 2:
            im_q = im_q.mean([2, 3])  # global average pooling over H, W
        q = self.mlp(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # Compute key features.
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder
            if im_k.ndim > 2:
                im_k = im_k.mean([2, 3])
            # Shuffle for making use of BN.
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)
            k = self.momentum_mlp(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)
            # Undo shuffle.
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # Compute logits. Einstein sum is more intuitive.
        # Positive logits: Nx1.
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # Negative logits: NxK.
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
        # Logits: Nx(1+K).
        logits = torch.cat([l_pos, l_neg], dim=1)
        # Apply temperature.
        logits /= self.temperature
        # Labels: the positive key is always at index 0.
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=device)

        # Dequeue and enqueue.
        if not loss_only:
            if update_q:
                # Enqueue momentum-encoded queries instead of keys.
                with torch.no_grad():
                    temp_im_q, idx_unshuffle = self._batch_shuffle_ddp(im_q)
                    temp_q = self.momentum_mlp(temp_im_q)
                    temp_q = nn.functional.normalize(temp_q, dim=1)
                    temp_q = self._batch_unshuffle_ddp(temp_q, idx_unshuffle)
                self._dequeue_and_enqueue(temp_q)
            else:
                self._dequeue_and_enqueue(k)

        # Calculate the loss.
        loss = nn.functional.cross_entropy(logits, labels)
        return loss

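# Note on CLHead.forward() above: the returned loss is the InfoNCE objective
# from MoCo. For each query q with positive key k+ and queue negatives {n_j},
#     L = -log( exp(q.k+ / T) / (exp(q.k+ / T) + sum_j exp(q.n_j / T)) ),
# which the code realizes as cross-entropy over Nx(1+K) logits with target 0.
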
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    # In the non-distributed case there is nothing to gather.
    if not torch.distributed.is_initialized():
        return tensor
    tensors_gather = [torch.ones_like(tensor)
        for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    output = torch.cat(tensors_gather, dim=0)
    return output

#----------------------------------------------------------------------------
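
# A minimal single-process smoke test (a sketch, not part of the original
# training pipeline). It assumes CLHead can run without distributed
# initialization, which holds given the non-distributed guards above; the
# batch size and feature shapes below are arbitrary illustrative choices.
if __name__ == '__main__':
    head = CLHead(inplanes=256, queue_size=3500)
    feat_q = torch.randn(8, 256, 4, 4)  # hypothetical query features (NxCxHxW)
    feat_k = torch.randn(8, 256, 4, 4)  # hypothetical key features
    loss = head(feat_q, feat_k)         # InfoNCE loss; also updates the queue
    print(f'contrastive loss: {loss.item():.4f}')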