|
import torch
|
|
from torch import nn
|
|
import constants as cst
|
|
|
|
class BiN(nn.Module):
    """Normalise a 3-D tensor along its last two axes and mix the results.

    The input is standardised (zero mean, unit standard deviation) twice:
    once along dim 2 and once along dim 1.  Each standardised copy gets a
    learnable scale and bias, and the two copies are combined with two
    learnable scalar weights ``y1`` and ``y2``.

    NOTE(review): the name suggests "bilinear normalisation", with ``d1``
    the feature count and ``t1`` the number of time steps -- confirm
    against the callers.

    Parameters (registration order is preserved for ``state_dict``
    compatibility):
        B1, l1: bias / scale of shape ``(t1, 1)`` for the dim-1 branch.
        B2, l2: bias / scale of shape ``(d1, 1)`` for the dim-2 branch.
        y1, y2: scalar mixing weights of shape ``(1,)``, initialised to 0.5.
    """

    # Standard deviations below this threshold are replaced by 1 so that a
    # (nearly) constant slice does not produce inf/NaN after the division.
    _MIN_STD = 1e-4
    # Value a mixing weight is reset to once it becomes negative.
    _RESET_WEIGHT = 0.01

    def __init__(self, d1: int, t1: int) -> None:
        """Create the learnable scales, biases and mixing weights.

        Args:
            d1: Size of dim 1 of the tensors passed to ``forward``.
            t1: Size of dim 2 of the tensors passed to ``forward``.
        """
        super().__init__()
        self.t1 = t1
        self.d1 = d1

        # dim-1 branch: one bias / scale per position of dim 2.
        self.B1 = nn.Parameter(torch.zeros(t1, 1))
        self.l1 = nn.Parameter(torch.empty(t1, 1))
        nn.init.xavier_normal_(self.l1)

        # dim-2 branch: one bias / scale per position of dim 1.
        self.B2 = nn.Parameter(torch.zeros(d1, 1))
        self.l2 = nn.Parameter(torch.empty(d1, 1))
        nn.init.xavier_normal_(self.l2)

        # Scalar weights used to mix the two branches.
        self.y1 = nn.Parameter(torch.full((1,), 0.5))
        self.y2 = nn.Parameter(torch.full((1,), 0.5))

    @classmethod
    def _reset_if_negative(cls, weight: torch.Tensor) -> None:
        """Reset ``weight`` in place to ``_RESET_WEIGHT`` if it is negative.

        The update is done in place (under ``no_grad``) so the parameter
        object stays the same: it remains on its current device and any
        optimizer that already references it keeps training it.
        """
        if weight[0] < 0:
            with torch.no_grad():
                weight.fill_(cls._RESET_WEIGHT)

    @classmethod
    def _standardize(cls, x: torch.Tensor, dim: int) -> torch.Tensor:
        """Return ``x`` with zero mean and unit std along ``dim``.

        Standard deviations smaller than ``_MIN_STD`` are treated as 1.
        """
        mean = x.mean(dim=dim, keepdim=True)
        std = x.std(dim=dim, keepdim=True)
        # Out-of-place replacement: safe for autograd even when ``x``
        # requires grad (an in-place masked write would invalidate the
        # tensor saved for the backward pass of ``std``).
        std = torch.where(std < cls._MIN_STD, torch.ones_like(std), std)
        return (x - mean) / std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both normalisations and return their weighted sum.

        Args:
            x: Tensor of shape ``(n, d1, t1)``; the leading dimension is
                presumably the batch.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"BiN expects a 3-D input of shape (n, d1, t1), got {tuple(x.shape)}"
            )

        # Keep the mixing weights non-negative.
        self._reset_if_negative(self.y1)
        self._reset_if_negative(self.y2)

        # Branch 2: statistics over dim 2; (d1, 1) scale / bias broadcast
        # over the leading dimension and dim 2.
        z2 = self._standardize(x, dim=2)
        X2 = self.l2 * z2 + self.B2

        # Branch 1: statistics over dim 1; (1, t1) scale / bias broadcast
        # over the leading dimension and dim 1.
        z1 = self._standardize(x, dim=1)
        X1 = self.l1.T * z1 + self.B1.T

        return self.y1 * X1 + self.y2 * X2