# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

# import torch.distributed as dist


logger = logging.getLogger("dinov2")


class KoLeoLoss(nn.Module):
    """Kozachenko-Leonenko entropic loss regularizer from Sablayrolles et al. - 2018 - Spreading vectors for similarity search"""

    def __init__(self):
        super().__init__()
        self.pdist = nn.PairwiseDistance(2, eps=1e-8)

    def pairwise_NNs_inner(self, x):
        """
        Pairwise nearest neighbors for L2-normalized vectors.
        Uses Torch rather than Faiss to remain on GPU.
        """
        # pairwise dot products (= inverse distance)
        dots = torch.mm(x, x.t())
        n = x.shape[0]
        dots.view(-1)[:: (n + 1)].fill_(-1)  # trick to fill diagonal with -1
        # max inner prod -> min distance
        _, I = torch.max(dots, dim=1)  # noqa: E741
        return I

    def forward(self, student_output, eps=1e-8):
        """
        Args:
            student_output (BxD): backbone output of student
        """
        with torch.cuda.amp.autocast(enabled=False):
            student_output = F.normalize(student_output, eps=eps, p=2, dim=-1)
            I = self.pairwise_NNs_inner(student_output)  # noqa: E741
            distances = self.pdist(student_output, student_output[I])  # BxD, BxD -> B
            loss = -torch.log(distances + eps).mean()
        return loss
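

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original DINOv2 file). It only shows
# how KoLeoLoss is called on a batch of backbone embeddings; the batch size,
# embedding dimension, and random input below are illustrative assumptions.
# As the code above implements it, the loss is the negative mean log-distance
# from each L2-normalized embedding to its nearest neighbor in the batch, so
# it penalizes embeddings that collapse onto each other and rewards a more
# uniform spread on the unit sphere.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    koleo = KoLeoLoss()
    # Fake student backbone output: B=16 samples with D=64 features (assumed shapes).
    student_output = torch.randn(16, 64)

    loss = koleo(student_output)
    print(f"KoLeo loss on random embeddings: {loss.item():.4f}")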