import torch
from torch import nn

import constants as cst
from models.bin import BiN


class TABL_layer(nn.Module):
    """Temporal Attention-augmented Bilinear Layer (TABL)."""

    def __init__(self, d2, d1, t1, t2):
        super().__init__()
        self.t1 = t1

        # W1: maps the feature dimension d1 -> d2
        weight = torch.Tensor(d2, d1)
        self.W1 = nn.Parameter(weight)
        nn.init.kaiming_uniform_(self.W1, nonlinearity='relu')

        # W: square matrix modelling interactions between temporal instances
        weight2 = torch.Tensor(t1, t1)
        self.W = nn.Parameter(weight2)
        nn.init.constant_(self.W, 1 / t1)

        # W2: maps the temporal dimension t1 -> t2
        weight3 = torch.Tensor(t1, t2)
        self.W2 = nn.Parameter(weight3)
        nn.init.kaiming_uniform_(self.W2, nonlinearity='relu')

        # B: bias term
        bias1 = torch.Tensor(d2, t2)
        self.B = nn.Parameter(bias1)
        nn.init.constant_(self.B, 0)

        # l: learnable mixing coefficient between the raw and attended representations
        l = torch.Tensor(1,)
        self.l = nn.Parameter(l)
        nn.init.constant_(self.l, 0.5)

        self.activation = nn.ReLU()

    def forward(self, X):
        # Keep the mixing coefficient l in [0, 1]. Clamping in place avoids
        # re-creating the parameter, which would detach it from the optimizer.
        with torch.no_grad():
            self.l.clamp_(0.0, 1.0)

        # Model the dependence along the first mode of X while keeping the temporal order intact (7).
        X = self.W1 @ X

        # Enforce the constant 1/t1 on the diagonal of W (1).
        W = self.W - self.W * torch.eye(self.t1, dtype=torch.float32).to(cst.DEVICE) \
            + torch.eye(self.t1, dtype=torch.float32).to(cst.DEVICE) / self.t1

        # Attention: learn how important the temporal instances are to each other (8).
        E = X @ W

        # Compute the attention mask (9).
        A = torch.softmax(E, dim=-1)

        # Apply a soft attention mechanism (10): the attention mask A obtained in the
        # previous step is used to zero out the effect of unimportant elements.
        X = self.l[0] * X + (1.0 - self.l[0]) * X * A

        # The final step estimates the temporal mapping W2, followed by the bias shift (11).
        y = X @ self.W2 + self.B
        return y


class BL_layer(nn.Module):
    """Bilinear Layer (BL): a bilinear projection followed by a ReLU."""

    def __init__(self, d2, d1, t1, t2):
        super().__init__()
        weight1 = torch.Tensor(d2, d1)
        self.W1 = nn.Parameter(weight1)
        nn.init.kaiming_uniform_(self.W1, nonlinearity='relu')

        weight2 = torch.Tensor(t1, t2)
        self.W2 = nn.Parameter(weight2)
        nn.init.kaiming_uniform_(self.W2, nonlinearity='relu')

        bias1 = torch.zeros((d2, t2))
        self.B = nn.Parameter(bias1)
        nn.init.constant_(self.B, 0)

        self.activation = nn.ReLU()

    def forward(self, x):
        return self.activation(self.W1 @ x @ self.W2 + self.B)


class BiN_CTABL(nn.Module):
    """BiN normalization followed by the C(TABL) architecture: BL -> BL -> TABL."""

    def __init__(self, d2, d1, t1, t2, d3, t3, d4, t4):
        super().__init__()
        self.BiN = BiN(d1, t1)
        self.BL = BL_layer(d2, d1, t1, t2)
        self.BL2 = BL_layer(d3, d2, t2, t3)
        self.TABL = TABL_layer(d4, d3, t3, t4)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # (batch, time, features) -> (batch, features, time)
        x = x.permute(0, 2, 1)

        # First pass the input through the BiN layer, then apply the C(TABL) architecture.
        x = self.BiN(x)

        self.max_norm_(self.BL.W1.data)
        self.max_norm_(self.BL.W2.data)
        x = self.BL(x)
        x = self.dropout(x)

        self.max_norm_(self.BL2.W1.data)
        self.max_norm_(self.BL2.W2.data)
        x = self.BL2(x)
        x = self.dropout(x)

        self.max_norm_(self.TABL.W1.data)
        self.max_norm_(self.TABL.W.data)
        self.max_norm_(self.TABL.W2.data)
        x = self.TABL(x)

        # Drop the singleton temporal dimension (t4 is expected to be 1) and
        # turn the logits into class probabilities.
        x = torch.squeeze(x)
        x = torch.softmax(x, 1)
        return x

    def max_norm_(self, w):
        # Rescale w in place so that its Frobenius norm does not exceed 10.
        with torch.no_grad():
            norm = torch.linalg.matrix_norm(w)
            if norm > 10.0:
                desired = torch.clamp(norm, min=0.0, max=10.0)
                w *= desired / (1e-8 + norm)
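
# --- Usage sketch (illustrative, not part of the original model code) ---
# The dimensions below are assumptions for a LOB-style input with 40 features
# over 10 time steps and a 3-class output, i.e. BL(60, 10) -> BL(120, 5) ->
# TABL(3, 1); adjust them to your data. This relies on the repository's own
# BiN module and cst.DEVICE, so it only runs inside the project.
if __name__ == "__main__":
    model = BiN_CTABL(d2=60, d1=40, t1=10, t2=10, d3=120, t3=5, d4=3, t4=1).to(cst.DEVICE)
    x = torch.randn(32, 10, 40).to(cst.DEVICE)  # (batch, time, features)
    probs = model(x)                            # (32, 3): softmax class probabilities
    print(probs.shape)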