import numpy as np
import torch
import torch.nn as nn

from torch_utils import misc
from torch_utils import persistence


@persistence.persistent_class
class CLHead(torch.nn.Module):
    """MoCo-style contrastive learning head.

    Projects features through a small MLP, maintains a momentum-updated copy of
    that MLP and a fixed-size queue of past keys, and computes an InfoNCE loss.
    """

    def __init__(self,
        inplanes    = 256,    # Channel dimension of the incoming features.
        temperature = 0.2,    # Softmax temperature for the InfoNCE logits.
        queue_size  = 3500,   # Number of negative keys kept in the queue.
        momentum    = 0.999,  # EMA coefficient for the momentum (key) MLP.
    ):
        super().__init__()
        self.inplanes = inplanes
        self.temperature = temperature
        self.queue_size = queue_size
        self.m = momentum

        # Query MLP (trained by backprop) and its momentum-updated copy.
        self.mlp = nn.Sequential(nn.Linear(inplanes, inplanes), nn.ReLU(), nn.Linear(inplanes, 128))
        self.momentum_mlp = nn.Sequential(nn.Linear(inplanes, inplanes), nn.ReLU(), nn.Linear(inplanes, 128))
        self.momentum_mlp.requires_grad_(False)

        # Initialize the momentum MLP from the query MLP; it is never updated by gradients.
        for param_q, param_k in zip(self.mlp.parameters(), self.momentum_mlp.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        # Ring buffer of L2-normalized negative keys, stored column-wise as
        # (128, queue_size), plus a pointer marking the next write position.
        self.register_buffer("queue", torch.randn(128, self.queue_size))
        self.queue = nn.functional.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum (EMA) update of the key encoder:
        param_k <- m * param_k + (1 - m) * param_q.
        """
        for param_q, param_k in zip(self.mlp.parameters(), self.momentum_mlp.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
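    # Note (added comment): with the default m = 0.999, each call above moves the
    # key parameters only 0.1% of the way toward the query parameters, so the
    # momentum MLP drifts slowly and the keys stored in the queue stay roughly
    # consistent with the encoder that produced them.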
    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        # Gather keys from all GPUs before enqueuing them as shared negatives.
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]
        keys = keys.T  # (128, batch_size), matching the column-wise queue layout.
        ptr = int(self.queue_ptr)
        if batch_size > self.queue_size:
            # More keys than the queue can hold: overwrite the whole queue
            # (the pointer is left where it was).
            self.queue[:, 0:] = keys[:, :self.queue_size]
        elif ptr + batch_size > self.queue_size:
            # Wrap around: fill the tail of the queue, then continue from the front.
            self.queue[:, ptr:] = keys[:, :self.queue_size - ptr]
            self.queue[:, :batch_size - (self.queue_size - ptr)] = keys[:, self.queue_size - ptr:]
            self.queue_ptr[0] = batch_size - (self.queue_size - ptr)
        else:
            self.queue[:, ptr:ptr + batch_size] = keys
            self.queue_ptr[0] = ptr + batch_size
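    # Worked example of the wrap-around branch above (illustrative numbers only):
    # with queue_size=3500, ptr=3400 and a gathered batch of 200 keys, the first
    # 100 columns land in queue[:, 3400:3500], the remaining 100 in queue[:, 0:100],
    # and queue_ptr becomes 100.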
    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        Returns the shuffled batch plus the indices needed to undo the shuffle.
        """
        # Single-process fallback: nothing to shuffle across GPUs.
        if not torch.distributed.is_initialized():
            return x, torch.arange(x.shape[0])

        device = x.device
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # Random shuffling index, generated on rank 0 and broadcast to all ranks.
        idx_shuffle = torch.randperm(batch_size_all, device=device)
        torch.distributed.broadcast(idx_shuffle, src=0)

        # Index for restoring the original order later.
        idx_unshuffle = torch.argsort(idx_shuffle)

        # Keep only this rank's slice of the shuffled batch.
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle
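    # Note (added comment): in MoCo this shuffle prevents per-GPU BatchNorm
    # statistics from leaking information between queries and keys. The MLPs in
    # this head are Linear/ReLU/Linear with no BatchNorm, so the shuffle is kept
    # mainly for parity with the original MoCo recipe.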
    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        """
        # Single-process fallback: nothing to undo.
        if not torch.distributed.is_initialized():
            return x

        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # Keep only this rank's slice of the restored ordering.
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward(self, im_q, im_k, loss_only=False, update_q=False):
        """
        Input:
            im_q: a batch of query features, shape (N, C) or (N, C, H, W)
            im_k: a batch of key features of the same shape
        Output:
            the InfoNCE contrastive loss
        """
        device = im_q.device
        im_q = im_q.to(torch.float32)
        im_k = im_k.to(torch.float32)

        # Global-average-pool spatial feature maps down to vectors.
        if im_q.ndim > 2:
            im_q = im_q.mean([2, 3])
        # Compute query embeddings.
        q = self.mlp(im_q)
        q = nn.functional.normalize(q, dim=1)

        # Compute key embeddings with the momentum MLP (no gradients).
        with torch.no_grad():
            self._momentum_update_key_encoder()
            if im_k.ndim > 2:
                im_k = im_k.mean([2, 3])

            # Shuffle across GPUs before encoding, undo afterwards.
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)
            k = self.momentum_mlp(im_k)
            k = nn.functional.normalize(k, dim=1)
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # Positive logits, Nx1: similarity between each query and its own key.
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # Negative logits, NxK: similarity between each query and the queued keys.
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # Logits, Nx(1+K), with the positive in column 0.
        logits = torch.cat([l_pos, l_neg], dim=1)

        # Apply temperature.
        logits /= self.temperature

        # Labels: the positive key is always at index 0.
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=device)

        if not loss_only:
            if update_q:
                # Enqueue momentum-encoded query features instead of the keys.
                with torch.no_grad():
                    temp_im_q, idx_unshuffle = self._batch_shuffle_ddp(im_q)
                    temp_q = self.momentum_mlp(temp_im_q)
                    temp_q = nn.functional.normalize(temp_q, dim=1)
                    temp_q = self._batch_unshuffle_ddp(temp_q, idx_unshuffle)
                    self._dequeue_and_enqueue(temp_q)
            else:
                # Enqueue the current keys as negatives for future batches.
                self._dequeue_and_enqueue(k)

        loss = nn.functional.cross_entropy(logits, labels)
        return loss


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    # Single-process fallback: nothing to gather.
    if not torch.distributed.is_initialized():
        return tensor

    tensors_gather = [torch.ones_like(tensor)
        for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
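

# ----------------------------------------------------------------------------
# Minimal single-process smoke test (an added sketch, not part of the original
# training code). It assumes this file sits inside the repo so that
# `torch_utils` is importable, and feeds random query/key features of arbitrary
# shape (N, 256, H, W) through the head.

if __name__ == "__main__":
    torch.manual_seed(0)
    head = CLHead(inplanes=256, queue_size=3500)
    feat_q = torch.randn(8, 256, 4, 4)   # query feature maps (N, C, H, W)
    feat_k = torch.randn(8, 256, 4, 4)   # key feature maps of the same shape
    loss = head(feat_q, feat_k)          # InfoNCE loss; also enqueues the keys
    print(f'contrastive loss: {loss.item():.4f}')
    print(f'queue pointer after one step: {int(head.queue_ptr)}')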