File size: 3,044 Bytes
eaf2e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from abc import ABC, abstractmethod
from itertools import combinations
from math import sqrt

import torch

from src.gan.gankits import nz


class ExclusionReg(ABC):
    """Base class for exclusion regularisers scaled by a coefficient ``lbd``.

    Deriving from ``ABC`` makes ``@abstractmethod`` effective: this class, and
    any subclass that does not override ``forward``, can no longer be
    instantiated. Previously the decorator had no effect and the base
    ``forward`` silently returned ``None``.
    """

    # NOTE: To be maximised
    def __init__(self, lbd):
        """Store the weight ``lbd`` that subclasses apply to their result."""
        self.lbd = lbd

    @abstractmethod
    def forward(self, muss, stdss, betas):
        """Compute the regularisation term; subclasses must override this.

        Args:
            muss: the subclasses in this file unpack ``muss.shape`` as
                ``(b, m, d)``.
            stdss: indexed the same way as ``muss`` by the subclasses in this
                file -- presumably the matching spread parameters; TODO confirm.
            betas: indexed as ``betas[:, i]`` by the subclasses in this file --
                presumably ``(b, m)`` mixing weights; TODO confirm.
        """


class WassersteinExclusion(ExclusionReg):
    """Exclusion term built from pairwise distance-like scores ``w_ij``.

    For every unordered pair ``(i, j)`` along axis 1 of ``muss`` a score
    ``w_ij`` is computed and accumulated with weight
    ``betas[:, i] * betas[:, j]``; the sum is scaled by ``self.lbd``.
    """

    def forward(self, muss, stdss, betas):
        """Return ``lbd * sum_{i<j} betas_i * betas_j * w_ij``.

        Args:
            muss: 3-D tensor, unpacked as ``(b, m, d)``.
            stdss: indexed like ``muss`` -- presumably the matching spread
                parameters (std-devs or variances); TODO confirm.
            betas: indexed as ``betas[:, i]`` -- presumably ``(b, m)`` mixing
                weights; TODO confirm.

        Returns:
            The weighted sum scaled by ``self.lbd``; one value per batch row
            when ``betas`` is ``(b, m)``. With fewer than two components
            (``m < 2``) there are no pairs and the result is all zeros.
        """
        batch_size, n_components, _ = muss.shape
        # Seed the accumulator with the input dtype and accumulate
        # out-of-place so type promotion picks the result dtype. An in-place
        # add into a default-dtype tensor silently downcasts float64 terms
        # to float32.
        rho = torch.zeros([batch_size], device=muss.device, dtype=muss.dtype)
        for i, j in combinations(range(n_components), 2):
            # Squared Euclidean gap between the two mean vectors.
            sq_mean_gap = torch.square(muss[:, i, :] - muss[:, j, :]).sum(dim=-1)
            spread_sum = torch.sum(stdss[:, i, :] + stdss[:, j, :], dim=-1)
            # NOTE(review): the square root is taken after summing over the
            # feature axis; the closed-form Gaussian 2-Wasserstein distance
            # would sum per-dimension square roots -- confirm this is intended.
            cross = torch.sqrt((stdss[:, i, :] * stdss[:, j, :]).sum(dim=-1))
            w = (sq_mean_gap + spread_sum - 2 * cross).sqrt()
            rho = rho + betas[:, i] * betas[:, j] * w
        return self.lbd * rho


class LogWassersteinExclusion(ExclusionReg):
    """Exclusion term using ``log(w_ij + 1)`` of pairwise distance-like scores.

    For every unordered pair ``(i, j)`` along axis 1 of ``muss`` a score
    ``w_ij`` is computed, passed through ``log(w + 1)`` and accumulated with
    weight ``betas[:, i] * betas[:, j]``; the sum is scaled by ``self.lbd``.
    """

    def forward(self, muss, stdss, betas):
        """Return ``lbd * sum_{i<j} betas_i * betas_j * log(w_ij + 1)``.

        Args:
            muss: 3-D tensor, unpacked as ``(b, m, d)``.
            stdss: indexed like ``muss`` -- presumably the matching spread
                parameters (std-devs or variances); TODO confirm.
            betas: indexed as ``betas[:, i]`` -- presumably ``(b, m)`` mixing
                weights; TODO confirm.

        Returns:
            The weighted sum scaled by ``self.lbd``; one value per batch row
            when ``betas`` is ``(b, m)``. With fewer than two components
            (``m < 2``) there are no pairs and the result is all zeros.
        """
        batch_size, n_components, _ = muss.shape
        # Seed the accumulator with the input dtype and accumulate
        # out-of-place so type promotion picks the result dtype. An in-place
        # add into a default-dtype tensor silently downcasts float64 terms
        # to float32.
        rho = torch.zeros([batch_size], device=muss.device, dtype=muss.dtype)

        for i, j in combinations(range(n_components), 2):
            # Squared Euclidean gap between the two mean vectors.
            sq_mean_gap = torch.square(muss[:, i, :] - muss[:, j, :]).sum(dim=-1)
            spread_sum = torch.sum(stdss[:, i, :] + stdss[:, j, :], dim=-1)
            # NOTE(review): the square root is taken after summing over the
            # feature axis; the closed-form Gaussian 2-Wasserstein distance
            # would sum per-dimension square roots -- confirm this is intended.
            cross = torch.sqrt((stdss[:, i, :] * stdss[:, j, :]).sum(dim=-1))
            w = (sq_mean_gap + spread_sum - 2 * cross).sqrt()
            # The +1 keeps the logarithm non-negative for w >= 0.
            rho = rho + betas[:, i] * betas[:, j] * torch.log(w + 1)
        return self.lbd * rho


class ClipExclusion(ExclusionReg):
    """Exclusion term using pairwise scores ``w_ij`` capped at ``wbar``.

    For every unordered pair ``(i, j)`` along axis 1 of ``muss`` a score
    ``w_ij`` is computed, clipped from above at ``self.wbar`` and accumulated
    with weight ``betas[:, i] * betas[:, j]``; the sum is scaled by
    ``self.lbd``.
    """

    def __init__(self, lbd, wbar=0.6 * sqrt(nz)):
        """Store the weight ``lbd`` and the clipping bound ``wbar``.

        Args:
            lbd: multiplier applied to the accumulated term.
            wbar: upper bound applied to each pairwise score; defaults to
                ``0.6 * sqrt(nz)`` with ``nz`` imported from
                ``src.gan.gankits``.
        """
        super().__init__(lbd)
        self.wbar = wbar

    def forward(self, muss, stdss, betas):
        """Return ``lbd * sum_{i<j} betas_i * betas_j * min(w_ij, wbar)``.

        Args:
            muss: 3-D tensor, unpacked as ``(b, m, d)``.
            stdss: indexed like ``muss`` -- presumably the matching spread
                parameters (std-devs or variances); TODO confirm.
            betas: indexed as ``betas[:, i]`` -- presumably ``(b, m)`` mixing
                weights; TODO confirm.

        Returns:
            The weighted sum scaled by ``self.lbd``; one value per batch row
            when ``betas`` is ``(b, m)``. With fewer than two components
            (``m < 2``) there are no pairs and the result is all zeros.
        """
        batch_size, n_components, _ = muss.shape
        # Seed the accumulator with the input dtype and accumulate
        # out-of-place so type promotion picks the result dtype. An in-place
        # add into a default-dtype tensor silently downcasts float64 terms
        # to float32.
        rho = torch.zeros([batch_size], device=muss.device, dtype=muss.dtype)
        for i, j in combinations(range(n_components), 2):
            # Squared Euclidean gap between the two mean vectors.
            sq_mean_gap = torch.square(muss[:, i, :] - muss[:, j, :]).sum(dim=-1)
            spread_sum = torch.sum(stdss[:, i, :] + stdss[:, j, :], dim=-1)
            # NOTE(review): the square root is taken after summing over the
            # feature axis; the closed-form Gaussian 2-Wasserstein distance
            # would sum per-dimension square roots -- confirm this is intended.
            cross = torch.sqrt((stdss[:, i, :] * stdss[:, j, :]).sum(dim=-1))
            w = (sq_mean_gap + spread_sum - 2 * cross).sqrt()
            # Cap each pair's contribution at wbar.
            rho = rho + betas[:, i] * betas[:, j] * torch.clip(w, max=self.wbar)
        return self.lbd * rho


class LogClipExclusion(ExclusionReg):
    """Exclusion term using ``log(min(w_ij, wbar) + 1)`` of pairwise scores.

    For every unordered pair ``(i, j)`` along axis 1 of ``muss`` a score
    ``w_ij`` is computed, clipped from above at ``self.wbar``, passed through
    ``log(. + 1)`` and accumulated with weight ``betas[:, i] * betas[:, j]``;
    the sum is scaled by ``self.lbd``.
    """

    def __init__(self, lbd, wbar=0.6 * sqrt(nz)):
        """Store the weight ``lbd`` and the clipping bound ``wbar``.

        Args:
            lbd: multiplier applied to the accumulated term.
            wbar: upper bound applied to each pairwise score before the
                logarithm; defaults to ``0.6 * sqrt(nz)`` with ``nz`` imported
                from ``src.gan.gankits``.
        """
        super().__init__(lbd)
        self.wbar = wbar

    def forward(self, muss, stdss, betas):
        """Return ``lbd * sum_{i<j} betas_i * betas_j * log(min(w_ij, wbar) + 1)``.

        Args:
            muss: 3-D tensor, unpacked as ``(b, m, d)``.
            stdss: indexed like ``muss`` -- presumably the matching spread
                parameters (std-devs or variances); TODO confirm.
            betas: indexed as ``betas[:, i]`` -- presumably ``(b, m)`` mixing
                weights; TODO confirm.

        Returns:
            The weighted sum scaled by ``self.lbd``; one value per batch row
            when ``betas`` is ``(b, m)``. With fewer than two components
            (``m < 2``) there are no pairs and the result is all zeros.
        """
        batch_size, n_components, _ = muss.shape
        # Seed the accumulator with the input dtype and accumulate
        # out-of-place so type promotion picks the result dtype. An in-place
        # add into a default-dtype tensor silently downcasts float64 terms
        # to float32.
        rho = torch.zeros([batch_size], device=muss.device, dtype=muss.dtype)

        for i, j in combinations(range(n_components), 2):
            # Squared Euclidean gap between the two mean vectors.
            sq_mean_gap = torch.square(muss[:, i, :] - muss[:, j, :]).sum(dim=-1)
            spread_sum = torch.sum(stdss[:, i, :] + stdss[:, j, :], dim=-1)
            # NOTE(review): the square root is taken after summing over the
            # feature axis; the closed-form Gaussian 2-Wasserstein distance
            # would sum per-dimension square roots -- confirm this is intended.
            cross = torch.sqrt((stdss[:, i, :] * stdss[:, j, :]).sum(dim=-1))
            w = (sq_mean_gap + spread_sum - 2 * cross).sqrt()
            # Cap at wbar first, then compress with log(. + 1).
            capped = torch.clip(w, max=self.wbar)
            rho = rho + betas[:, i] * betas[:, j] * torch.log(capped + 1)
        return self.lbd * rho


# class SurrogateDistReg:
#     def __init__(self, lbd, clip=30.):
#         self.lbd = lbd
#         self.clip = clip
#
#     def forward(self, muss, stdss, betas):


# Script entry point: currently a no-op, so running this module directly
# only defines the classes above.
if __name__ == '__main__':
    pass