import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedEmbeddingUnit(nn.Module):
    """Linear projection followed by Context Gating and L2 normalization."""

    def __init__(self, input_dimension, output_dimension):
        super(GatedEmbeddingUnit, self).__init__()
        self.fc = nn.Linear(input_dimension, output_dimension)
        self.cg = ContextGating(output_dimension)

    def forward(self, x):
        x = self.fc(x)
        x = self.cg(x)
        x = F.normalize(x)
        return x


class ContextGating(nn.Module):
    """Gates the input with a sigmoid-activated linear transform (via GLU)."""

    def __init__(self, dimension, add_batch_norm=True):
        super(ContextGating, self).__init__()
        self.fc = nn.Linear(dimension, dimension)
        self.add_batch_norm = add_batch_norm
        self.batch_norm = nn.BatchNorm1d(dimension)

    def forward(self, x):
        x1 = self.fc(x)
        if self.add_batch_norm:
            x1 = self.batch_norm(x1)
        # GLU over the concatenation computes x * sigmoid(x1)
        x = torch.cat((x, x1), 1)
        return F.glu(x, 1)


class MaxMarginRankingLoss(nn.Module):
    """Max-margin ranking loss over a similarity matrix whose diagonal
    holds the positive (matching) pairs."""

    def __init__(self, margin=1):
        super(MaxMarginRankingLoss, self).__init__()
        self.margin = margin

    def forward(self, x):
        n = x.size()[0]

        # Positive scores: diagonal entries, repeated once for row-wise
        # and once for column-wise negatives.
        x1 = torch.diag(x)
        x1 = x1.unsqueeze(1)
        x1 = x1.expand(n, n)
        x1 = x1.contiguous().view(-1, 1)
        x1 = torch.cat((x1, x1), 0)

        # Negative scores: all entries of x, traversed row-wise and column-wise.
        x2 = x.view(-1, 1)
        x3 = x.transpose(0, 1).contiguous().view(-1, 1)
        x2 = torch.cat((x2, x3), 0)

        max_margin = F.relu(self.margin - (x1 - x2))
        return max_margin.mean()


class NetVLAD(nn.Module):
    """NetVLAD aggregation: soft-assigns each descriptor to learned clusters
    and accumulates the residuals to cluster centers."""

    def __init__(self, cluster_size, feature_size, add_batch_norm=True):
        super(NetVLAD, self).__init__()
        self.feature_size = feature_size
        self.cluster_size = cluster_size
        self.clusters = nn.Parameter(
            (1 / math.sqrt(feature_size)) * torch.randn(feature_size, cluster_size))
        self.clusters2 = nn.Parameter(
            (1 / math.sqrt(feature_size)) * torch.randn(1, feature_size, cluster_size))
        self.add_batch_norm = add_batch_norm
        self.batch_norm = nn.BatchNorm1d(cluster_size)
        self.out_dim = cluster_size * feature_size

    def forward(self, x):
        # x: (batch, max_sample, feature_size)
        max_sample = x.size()[1]
        x = x.view(-1, self.feature_size)

        # Soft-assignment of each descriptor to the clusters.
        assignment = torch.matmul(x, self.clusters)
        if self.add_batch_norm:
            assignment = self.batch_norm(assignment)
        assignment = F.softmax(assignment, dim=1)
        assignment = assignment.view(-1, max_sample, self.cluster_size)

        a_sum = torch.sum(assignment, -2, keepdim=True)
        a = a_sum * self.clusters2

        assignment = assignment.transpose(1, 2)

        x = x.view(-1, max_sample, self.feature_size)
        vlad = torch.matmul(assignment, x)
        vlad = vlad.transpose(1, 2)
        vlad = vlad - a

        # L2 intra norm
        vlad = F.normalize(vlad)

        # flattening + L2 norm
        vlad = vlad.view(-1, self.cluster_size * self.feature_size)
        vlad = F.normalize(vlad)
        return vlad
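

# A minimal smoke-test sketch (not part of the original module): the batch size,
# sequence length, and dimensions below are illustrative assumptions chosen only
# to exercise the tensor shapes these modules expect.
if __name__ == "__main__":
    torch.manual_seed(0)

    batch, max_sample, feat_dim, clusters, embed_dim = 8, 20, 128, 16, 256

    # Aggregate a batch of descriptor sequences with NetVLAD:
    # (batch, max_sample, feat_dim) -> (batch, clusters * feat_dim).
    netvlad = NetVLAD(cluster_size=clusters, feature_size=feat_dim)
    video = torch.randn(batch, max_sample, feat_dim)
    vlad = netvlad(video)
    assert vlad.shape == (batch, netvlad.out_dim)

    # Project the aggregated features into a gated, L2-normalized embedding.
    geu = GatedEmbeddingUnit(netvlad.out_dim, embed_dim)
    video_embd = geu(vlad)

    # Score against a second (here random) embedding and compute the ranking
    # loss, treating the diagonal of the similarity matrix as positive pairs.
    text_embd = F.normalize(torch.randn(batch, embed_dim))
    sim = torch.matmul(video_embd, text_embd.t())
    loss = MaxMarginRankingLoss(margin=1)(sim)
    print(loss.item())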