import torch
from torch import nn

import constants as cst


class BiN(nn.Module):
    """Bilinear input normalization (BiN): normalizes the input along both
    the temporal and the feature dimension, then combines the two normalized
    views with learnable scalar weights."""

    def __init__(self, d1, t1):
        super().__init__()
        self.t1 = t1
        self.d1 = d1

        # learnable bias and scale for the feature-wise normalization
        bias1 = torch.empty(t1, 1)
        self.B1 = nn.Parameter(bias1)
        nn.init.constant_(self.B1, 0)

        l1 = torch.empty(t1, 1)
        self.l1 = nn.Parameter(l1)
        nn.init.xavier_normal_(self.l1)

        # learnable bias and scale for the temporal normalization
        bias2 = torch.empty(d1, 1)
        self.B2 = nn.Parameter(bias2)
        nn.init.constant_(self.B2, 0)

        l2 = torch.empty(d1, 1)
        self.l2 = nn.Parameter(l2)
        nn.init.xavier_normal_(self.l2)

        # scalar weights that balance the two normalizations
        y1 = torch.empty(1)
        self.y1 = nn.Parameter(y1)
        nn.init.constant_(self.y1, 0.5)

        y2 = torch.empty(1)
        self.y2 = nn.Parameter(y2)
        nn.init.constant_(self.y2, 0.5)

    def forward(self, x):
        # if either scalar weight has drifted negative, reset it in-place to
        # a small positive value (re-registering a fresh nn.Parameter here
        # would detach the weight from the optimizer)
        with torch.no_grad():
            if self.y1[0] < 0:
                self.y1.fill_(0.01)
            if self.y2[0] < 0:
                self.y2.fill_(0.01)

        # normalization along the temporal dimension
        T2 = torch.ones([self.t1, 1], device=cst.DEVICE)
        x2 = torch.mean(x, dim=2)
        x2 = torch.reshape(x2, (x2.shape[0], x2.shape[1], 1))

        std = torch.std(x, dim=2)
        std = torch.reshape(std, (std.shape[0], std.shape[1], 1))
        # the std of some temporal slices can be 0, which would produce inf
        # values, so we set those entries to one
        std[std < 1e-4] = 1

        diff = x - (x2 @ T2.T)
        Z2 = diff / (std @ T2.T)

        X2 = self.l2 @ T2.T
        X2 = X2 * Z2
        X2 = X2 + (self.B2 @ T2.T)

        # normalization along the feature dimension
        T1 = torch.ones([self.d1, 1], device=cst.DEVICE)
        x1 = torch.mean(x, dim=1)
        x1 = torch.reshape(x1, (x1.shape[0], x1.shape[1], 1))

        std = torch.std(x, dim=1)
        std = torch.reshape(std, (std.shape[0], std.shape[1], 1))
        # guard against zero std here as well, for the same reason as above
        std[std < 1e-4] = 1

        op1 = x1 @ T1.T
        op1 = torch.permute(op1, (0, 2, 1))
        op2 = std @ T1.T
        op2 = torch.permute(op2, (0, 2, 1))

        z1 = (x - op1) / op2

        X1 = T1 @ self.l1.T
        X1 = X1 * z1
        X1 = X1 + (T1 @ self.B1.T)

        # weighting the importance of temporal and feature normalization
        x = self.y1 * X1 + self.y2 * X2

        return x
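

# --- minimal smoke test (illustrative sketch, not part of the original
# module) ---
# Assumes cst.DEVICE is a valid torch device defined in constants.py, and
# that inputs are shaped (batch, features, time) = (batch, d1, t1), as
# implied by the reductions in forward(); the sizes below are hypothetical.
if __name__ == "__main__":
    batch, d1, t1 = 8, 40, 100
    bin_layer = BiN(d1, t1).to(cst.DEVICE)
    x = torch.randn(batch, d1, t1, device=cst.DEVICE)
    out = bin_layer(x)
    # BiN preserves the input shape; it only rescales and recombines values
    assert out.shape == x.shape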