# Source header (scrape residue, kept for provenance): Liangrj5 / init / ebf5d87
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class GatedEmbeddingUnit(nn.Module):
    """Project an input, apply context gating, and L2-normalize.

    Pipeline: Linear(input_dimension -> output_dimension) -> ContextGating
    -> F.normalize (L2 over dim 1).
    """

    def __init__(self, input_dimension, output_dimension):
        super(GatedEmbeddingUnit, self).__init__()
        self.fc = nn.Linear(input_dimension, output_dimension)
        self.cg = ContextGating(output_dimension)

    def forward(self, x):
        """Return the gated, unit-norm embedding of ``x`` (batch, input_dimension)."""
        projected = self.fc(x)
        gated = self.cg(projected)
        return F.normalize(gated)
class ContextGating(nn.Module):
    """Element-wise self-gating: out = x * sigmoid(gates(x)), via GLU."""

    def __init__(self, dimension, add_batch_norm=True):
        super(ContextGating, self).__init__()
        self.fc = nn.Linear(dimension, dimension)
        self.add_batch_norm = add_batch_norm
        # Registered unconditionally so state_dict keys do not depend on the flag.
        self.batch_norm = nn.BatchNorm1d(dimension)

    def forward(self, x):
        """Gate ``x`` (batch, dimension) element-wise; output has the same shape."""
        gates = self.fc(x)
        if self.add_batch_norm:
            gates = self.batch_norm(gates)
        # glu splits the concatenation back into (x, gates) and
        # returns x * sigmoid(gates).
        return F.glu(torch.cat((x, gates), 1), 1)
class MaxMarginRankingLoss(nn.Module):
    """Bidirectional max-margin ranking loss over an n x n similarity matrix.

    Expects ``x[i, j]`` to be the similarity of query ``i`` and candidate
    ``j`` with matched pairs on the diagonal.  Every off-diagonal score is
    penalized when it comes within ``margin`` of the corresponding diagonal
    score, in both the row and the column direction.
    """

    def __init__(self, margin=1):
        super(MaxMarginRankingLoss, self).__init__()
        self.margin = margin

    def forward(self, x):
        n = x.size(0)
        # Positive scores: each diagonal entry repeated across its row,
        # stacked once per direction (row-wise and column-wise negatives).
        diag = torch.diag(x).unsqueeze(1).expand(n, n).contiguous().view(-1, 1)
        positives = torch.cat((diag, diag), 0)
        # Negative scores: the matrix flattened row-wise, then its transpose.
        negatives = torch.cat(
            (x.view(-1, 1), x.transpose(0, 1).contiguous().view(-1, 1)), 0)
        # NOTE: diagonal pairs are included and each contributes a constant
        # ``margin`` term — kept to match the original implementation.
        return F.relu(self.margin - (positives - negatives)).mean()
class NetVLAD(nn.Module):
def __init__(self, cluster_size, feature_size, add_batch_norm=True):
super(NetVLAD, self).__init__()
self.feature_size = feature_size
self.cluster_size = cluster_size
self.clusters = nn.Parameter((1 / math.sqrt(feature_size))
* torch.randn(feature_size, cluster_size))
self.clusters2 = nn.Parameter((1 / math.sqrt(feature_size))
* torch.randn(1, feature_size, cluster_size))
self.add_batch_norm = add_batch_norm
self.batch_norm = nn.BatchNorm1d(cluster_size)
self.out_dim = cluster_size * feature_size
def forward(self, x):
max_sample = x.size()[1]
x = x.view(-1, self.feature_size)
assignment = torch.matmul(x, self.clusters)
if self.add_batch_norm:
assignment = self.batch_norm(assignment)
assignment = F.softmax(assignment, dim=1)
assignment = assignment.view(-1, max_sample, self.cluster_size)
a_sum = torch.sum(assignment, -2, keepdim=True)
a = a_sum * self.clusters2
assignment = assignment.transpose(1, 2)
x = x.view(-1, max_sample, self.feature_size)
vlad = torch.matmul(assignment, x)
vlad = vlad.transpose(1, 2)
vlad = vlad - a
# L2 intra norm
vlad = F.normalize(vlad)
# flattening + L2 norm
vlad = vlad.view(-1, self.cluster_size * self.feature_size)
vlad = F.normalize(vlad)
return vlad