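"""Exclusion regularisers that push the components of a Gaussian mixture
apart by rewarding pairwise Wasserstein-style distances between diagonal
Gaussian components. NOTE: these terms are to be *maximised*.
"""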
import torch
from math import sqrt
from abc import ABC, abstractmethod
from itertools import combinations

from src.gan.gankits import nz

def _pairwise_w2(muss, stdss, i, j):
    # Wasserstein-style separation between the i-th and j-th diagonal
    # Gaussian components, per batch element. Assuming stdss holds
    # per-dimension variances, this mirrors the closed-form squared
    # 2-Wasserstein distance
    #   ||mu_i - mu_j||^2 + tr(S_i + S_j - 2 (S_i S_j)^(1/2)),
    # except that the cross term sums the variance products over
    # dimensions before taking the square root.
    x = torch.square(muss[:, i, :] - muss[:, j, :]).sum(dim=-1)
    y = torch.sum(stdss[:, i, :] + stdss[:, j, :], dim=-1)
    z = torch.sqrt((stdss[:, i, :] * stdss[:, j, :]).sum(dim=-1))
    return (x + y - 2 * z).sqrt()


class ExclusionReg(ABC):
    # NOTE: To be maximised
    def __init__(self, lbd):
        self.lbd = lbd

    @abstractmethod
    def forward(self, muss, stdss, betas):
        pass


class WassersteinExclusion(ExclusionReg):
    def forward(self, muss, stdss, betas):
        # muss, stdss: [batch, m, d]; betas: [batch, m] component weights.
        b, m, d = muss.shape
        rho = torch.zeros([b], device=muss.device)
        for i, j in combinations(range(m), 2):
            w = _pairwise_w2(muss, stdss, i, j)
            rho += betas[:, i] * betas[:, j] * w
        return self.lbd * rho


class LogWassersteinExclusion(ExclusionReg):
    def forward(self, muss, stdss, betas):
        # log(1 + w) damps the reward for pairs that are already far apart.
        b, m, d = muss.shape
        rho = torch.zeros([b], device=muss.device)
        for i, j in combinations(range(m), 2):
            w = _pairwise_w2(muss, stdss, i, j)
            rho += betas[:, i] * betas[:, j] * torch.log(w + 1)
        return self.lbd * rho


class ClipExclusion(ExclusionReg):
    def __init__(self, lbd, wbar=0.6 * sqrt(nz)):
        super().__init__(lbd)
        self.wbar = wbar  # distances beyond wbar earn no extra reward

    def forward(self, muss, stdss, betas):
        b, m, d = muss.shape
        rho = torch.zeros([b], device=muss.device)
        for i, j in combinations(range(m), 2):
            w = _pairwise_w2(muss, stdss, i, j)
            rho += betas[:, i] * betas[:, j] * torch.clip(w, max=self.wbar)
        return self.lbd * rho


class LogClipExclusion(ExclusionReg):
    def __init__(self, lbd, wbar=0.6 * sqrt(nz)):
        super().__init__(lbd)
        self.wbar = wbar

    def forward(self, muss, stdss, betas):
        b, m, d = muss.shape
        rho = torch.zeros([b], device=muss.device)
        for i, j in combinations(range(m), 2):
            w = _pairwise_w2(muss, stdss, i, j)
            rho += betas[:, i] * betas[:, j] * torch.log(torch.clip(w, max=self.wbar) + 1)
        return self.lbd * rho

# class SurrogateDistReg:
#     def __init__(self, lbd, clip=30.):
#         self.lbd = lbd
#         self.clip = clip
#
#     def forward(self, muss, stdss, betas):

if __name__ == '__main__':
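    # Minimal smoke test; the shapes and softmax weighting below are
    # illustrative assumptions, not taken from the training pipeline.
    b, m = 4, 3
    muss = torch.randn(b, m, nz)
    stdss = torch.rand(b, m, nz) + 0.1  # keep scales strictly positive
    betas = torch.softmax(torch.randn(b, m), dim=-1)
    for reg in (WassersteinExclusion(1.0), LogWassersteinExclusion(1.0),
                ClipExclusion(1.0), LogClipExclusion(1.0)):
        print(f'{type(reg).__name__}: {reg.forward(muss, stdss, betas)}')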