from torch import nn
import torch

import constants as cst
from models.bin import BiN

class TABL_layer(nn.Module):
    """Temporal Attention augmented Bilinear Layer (TABL)."""

    def __init__(self, d2, d1, t1, t2):
        super().__init__()
        self.t1 = t1

        # Left projection over the feature dimension: d1 -> d2.
        weight = torch.Tensor(d2, d1)
        self.W1 = nn.Parameter(weight)
        nn.init.kaiming_uniform_(self.W1, nonlinearity='relu')

        # Attention weights over the temporal dimension, initialised to 1/t1.
        weight2 = torch.Tensor(t1, t1)
        self.W = nn.Parameter(weight2)
        nn.init.constant_(self.W, 1 / t1)

        # Right projection over the temporal dimension: t1 -> t2.
        weight3 = torch.Tensor(t1, t2)
        self.W2 = nn.Parameter(weight3)
        nn.init.kaiming_uniform_(self.W2, nonlinearity='relu')

        # Bias added after the bilinear mapping.
        bias1 = torch.Tensor(d2, t2)
        self.B = nn.Parameter(bias1)
        nn.init.constant_(self.B, 0)

        # Learnable mixing coefficient lambda, kept inside [0, 1] in forward().
        l = torch.Tensor(1,)
        self.l = nn.Parameter(l)
        nn.init.constant_(self.l, 0.5)

        self.activation = nn.ReLU()

    def forward(self, X):
        # Clamp lambda into [0, 1] in place. Re-wrapping it in a fresh
        # nn.Parameter (as the original clamping did) would detach the
        # coefficient from the optimizer, so it would never be updated again.
        self.l.data.clamp_(0.0, 1.0)

        # Feature-wise projection: (d2, d1) @ (batch, d1, t1) -> (batch, d2, t1).
        X = self.W1 @ X

        # Replace the diagonal of W with 1/t1, as required by the TABL attention.
        eye = torch.eye(self.t1, dtype=torch.float32).to(cst.DEVICE)
        W = self.W - self.W * eye + eye / self.t1

        # Attention scores over the temporal dimension.
        E = X @ W
        A = torch.softmax(E, dim=-1)

        # Blend the original and the attended representation with lambda.
        X = self.l[0] * X + (1.0 - self.l[0]) * X * A

        # Temporal projection plus bias: (batch, d2, t1) @ (t1, t2) + (d2, t2).
        y = X @ self.W2 + self.B
        return y


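# In equation form, TABL_layer.forward above computes
#   X' = W1 @ X,   E = X' @ W~,   A = softmax(E, dim=-1),
#   X'' = l * X' + (1 - l) * X' * A,   Y = X'' @ W2 + B,
# where W~ is W with its diagonal entries fixed to 1/t1 and * denotes the
# element-wise product.
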
class BL_layer(nn.Module):
    """Bilinear Layer: projects both the feature and the temporal dimension."""

    def __init__(self, d2, d1, t1, t2):
        super().__init__()
        # Feature projection d1 -> d2.
        weight1 = torch.Tensor(d2, d1)
        self.W1 = nn.Parameter(weight1)
        nn.init.kaiming_uniform_(self.W1, nonlinearity='relu')

        # Temporal projection t1 -> t2.
        weight2 = torch.Tensor(t1, t2)
        self.W2 = nn.Parameter(weight2)
        nn.init.kaiming_uniform_(self.W2, nonlinearity='relu')

        # Bias for the bilinear mapping.
        bias1 = torch.zeros((d2, t2))
        self.B = nn.Parameter(bias1)
        nn.init.constant_(self.B, 0)

        self.activation = nn.ReLU()

    def forward(self, x):
        # (d2, d1) @ (batch, d1, t1) @ (t1, t2) + (d2, t2) -> (batch, d2, t2).
        x = self.activation(self.W1 @ x @ self.W2 + self.B)
        return x


class BiN_CTABL(nn.Module):
    """BiN normalisation followed by two Bilinear Layers and a final TABL head."""

    def __init__(self, d2, d1, t1, t2, d3, t3, d4, t4):
        super().__init__()

        self.BiN = BiN(d1, t1)
        self.BL = BL_layer(d2, d1, t1, t2)
        self.BL2 = BL_layer(d3, d2, t2, t3)
        self.TABL = TABL_layer(d4, d3, t3, t4)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # The input arrives as (batch, t1, d1); the layers below expect (batch, d1, t1).
        x = x.permute(0, 2, 1)

        x = self.BiN(x)

        # Apply a max-norm constraint (Frobenius norm <= 10) to each weight
        # matrix before the corresponding layer is used.
        self.max_norm_(self.BL.W1.data)
        self.max_norm_(self.BL.W2.data)
        x = self.BL(x)
        x = self.dropout(x)

        self.max_norm_(self.BL2.W1.data)
        self.max_norm_(self.BL2.W2.data)
        x = self.BL2(x)
        x = self.dropout(x)

        self.max_norm_(self.TABL.W1.data)
        self.max_norm_(self.TABL.W.data)
        self.max_norm_(self.TABL.W2.data)
        x = self.TABL(x)

        # Drop the trailing time dimension (t4 = 1) explicitly; a bare squeeze()
        # would also drop the batch dimension when the batch size is 1.
        x = torch.squeeze(x, dim=-1)
        x = torch.softmax(x, 1)

        return x

    def max_norm_(self, w):
        # Rescale w in place so that its Frobenius norm never exceeds 10.
        with torch.no_grad():
            norm = torch.linalg.matrix_norm(w)
            if norm > 10.0:
                desired = torch.clamp(norm, min=0.0, max=10.0)
                w *= (desired / (1e-8 + norm))
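

# Minimal usage sketch. The dimensions below follow the common C(TABL) setup for
# 40-feature, 10-step LOB inputs (BL: 60x10, BL2: 120x5, TABL head: 3x1); they are
# illustrative assumptions, not values taken from this repository's configuration.
if __name__ == "__main__":
    model = BiN_CTABL(60, 40, 10, 10, 120, 5, 3, 1).to(cst.DEVICE)
    model.eval()  # disable dropout for the shape check
    dummy = torch.randn(32, 10, 40, device=cst.DEVICE)  # (batch, t1, d1)
    out = model(dummy)
    print(out.shape)  # expected: torch.Size([32, 3])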