import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedEmbeddingUnit(nn.Module):
    """Linear projection followed by Context Gating and L2 normalization."""

    def __init__(self, input_dimension, output_dimension):
        super(GatedEmbeddingUnit, self).__init__()
        self.fc = nn.Linear(input_dimension, output_dimension)
        self.cg = ContextGating(output_dimension)

    def forward(self, x):
        x = self.fc(x)
        x = self.cg(x)
        # L2-normalize so that dot products between embeddings behave as
        # cosine similarities.
        x = F.normalize(x)
        return x
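

# Usage sketch (the dimensions are illustrative assumptions, not values
# taken from this file):
#   geu = GatedEmbeddingUnit(2048, 512)
#   emb = geu(torch.randn(8, 2048))  # -> (8, 512), rows have unit norm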


class ContextGating(nn.Module):
    """Context Gating: the elementwise self-gating x * sigmoid(W x + b)."""

    def __init__(self, dimension, add_batch_norm=True):
        super(ContextGating, self).__init__()
        self.fc = nn.Linear(dimension, dimension)
        self.add_batch_norm = add_batch_norm
        self.batch_norm = nn.BatchNorm1d(dimension)

    def forward(self, x):
        # Gate activations, optionally batch-normalized before the sigmoid.
        x1 = self.fc(x)
        if self.add_batch_norm:
            x1 = self.batch_norm(x1)
        # glu splits the concatenation back into (x, x1) and returns
        # x * sigmoid(x1).
        x = torch.cat((x, x1), 1)
        return F.glu(x, 1)
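

# Equivalence check (a quick sketch, with batch norm disabled so the gate
# is exactly x * sigmoid(fc(x))):
#   cg = ContextGating(4, add_batch_norm=False)
#   x = torch.randn(3, 4)
#   assert torch.allclose(cg(x), x * torch.sigmoid(cg.fc(x)))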


class MaxMarginRankingLoss(nn.Module):
    """Bidirectional max-margin ranking loss over a similarity matrix."""

    def __init__(self, margin=1):
        super(MaxMarginRankingLoss, self).__init__()
        self.margin = margin

    def forward(self, x):
        # x is an (n, n) similarity matrix whose diagonal holds the scores
        # of the positive (matching) pairs.
        n = x.size(0)

        # Positive scores S_ii, repeated so each one is compared against
        # every entry of its row (first half) and its column (second half).
        x1 = torch.diag(x)
        x1 = x1.unsqueeze(1)
        x1 = x1.expand(n, n)
        x1 = x1.contiguous().view(-1, 1)
        x1 = torch.cat((x1, x1), 0)

        # Negative scores: S_ij (row direction) and S_ji (column direction).
        x2 = x.view(-1, 1)
        x3 = x.transpose(0, 1).contiguous().view(-1, 1)
        x2 = torch.cat((x2, x3), 0)

        # Hinge on the margin between positive and negative scores.
        max_margin = F.relu(self.margin - (x1 - x2))
        return max_margin.mean()
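

# In formula form, with S the n x n similarity matrix and m the margin, the
# loss above is
#
#   L = (1 / 2n^2) * sum_{i,j} [ max(0, m - S_ii + S_ij)
#                              + max(0, m - S_ii + S_ji) ]
#
# The i == j terms contribute a constant m each, which offsets the loss
# value but not its gradients.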


class NetVLAD(nn.Module):
    """NetVLAD pooling: soft-assigns each descriptor to learned clusters
    and aggregates the residuals into a fixed-size representation."""

    def __init__(self, cluster_size, feature_size, add_batch_norm=True):
        super(NetVLAD, self).__init__()
        self.feature_size = feature_size
        self.cluster_size = cluster_size
        # Soft-assignment weights and cluster centers, both initialized at
        # scale 1 / sqrt(feature_size).
        self.clusters = nn.Parameter((1 / math.sqrt(feature_size))
                                     * torch.randn(feature_size, cluster_size))
        self.clusters2 = nn.Parameter((1 / math.sqrt(feature_size))
                                      * torch.randn(1, feature_size, cluster_size))
        self.add_batch_norm = add_batch_norm
        self.batch_norm = nn.BatchNorm1d(cluster_size)
        self.out_dim = cluster_size * feature_size

    def forward(self, x):
        # x: (batch, max_sample, feature_size)
        max_sample = x.size(1)
        x = x.view(-1, self.feature_size)

        # Soft-assignment of every descriptor to every cluster.
        assignment = torch.matmul(x, self.clusters)
        if self.add_batch_norm:
            assignment = self.batch_norm(assignment)
        assignment = F.softmax(assignment, dim=1)
        assignment = assignment.view(-1, max_sample, self.cluster_size)

        # "Center" part of the residuals: total assignment per cluster
        # times the cluster centers.
        a_sum = torch.sum(assignment, -2, keepdim=True)
        a = a_sum * self.clusters2

        assignment = assignment.transpose(1, 2)

        # Assignment-weighted sum of descriptors per cluster, minus the
        # centers: vlad[b, :, k] = sum_t assignment[b, t, k] * (x[b, t] - c_k).
        x = x.view(-1, max_sample, self.feature_size)
        vlad = torch.matmul(assignment, x)
        vlad = vlad.transpose(1, 2)
        vlad = vlad - a

        # Intra-normalization: L2-normalize each cluster's residual.
        vlad = F.normalize(vlad)

        # Flatten and L2-normalize the full VLAD descriptor.
        vlad = vlad.view(-1, self.cluster_size * self.feature_size)
        vlad = F.normalize(vlad)
        return vlad
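

if __name__ == '__main__':
    # Minimal smoke test wiring the modules together. All sizes here
    # (batch 8, 20 descriptors of dimension 64, 16 clusters, a 32-d joint
    # space) are illustrative assumptions, not values from the original
    # configuration.
    features = torch.randn(8, 20, 64)

    netvlad = NetVLAD(cluster_size=16, feature_size=64)
    pooled = netvlad(features)                # (8, 16 * 64)
    assert pooled.shape == (8, netvlad.out_dim)

    geu = GatedEmbeddingUnit(netvlad.out_dim, 32)
    video_emb = geu(pooled)                   # (8, 32), unit-norm rows

    # Stand-in for embeddings of the other modality; their similarity
    # matrix feeds the ranking loss.
    text_emb = F.normalize(torch.randn(8, 32))
    sims = torch.matmul(video_emb, text_emb.t())  # (8, 8)

    loss = MaxMarginRankingLoss(margin=1)(sims)
    print('max-margin ranking loss:', loss.item())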