import torch
from torch import nn

import constants as cst
from models.bin import BiN


class TABL_layer(nn.Module):
    """Temporal Attention-augmented Bilinear Layer (TABL)."""

    def __init__(self, d2, d1, t1, t2):
        super().__init__()
        self.t1 = t1

        # W1 maps the first (feature) mode from d1 to d2
        weight = torch.Tensor(d2, d1)
        self.W1 = nn.Parameter(weight)
        nn.init.kaiming_uniform_(self.W1, nonlinearity='relu')

        # W models the dependencies between temporal instances
        weight2 = torch.Tensor(t1, t1)
        self.W = nn.Parameter(weight2)
        nn.init.constant_(self.W, 1 / t1)

        # W2 maps the temporal mode from t1 to t2
        weight3 = torch.Tensor(t1, t2)
        self.W2 = nn.Parameter(weight3)
        nn.init.kaiming_uniform_(self.W2, nonlinearity='relu')

        # bias of the output mapping
        bias1 = torch.Tensor(d2, t2)
        self.B = nn.Parameter(bias1)
        nn.init.constant_(self.B, 0)

        # lambda mixes the attended and non-attended representations
        l = torch.Tensor(1,)
        self.l = nn.Parameter(l)
        nn.init.constant_(self.l, 0.5)

        self.activation = nn.ReLU()

    def forward(self, X):
        # keep the mixing parameter lambda in [0, 1]; clamping in place
        # (instead of re-creating the Parameter) keeps the optimizer's
        # reference to self.l valid
        with torch.no_grad():
            self.l.clamp_(0.0, 1.0)

        # model the dependence along the first mode of X while keeping the temporal order intact (7)
        X = self.W1 @ X

        # enforce the constant 1/t1 on the diagonal of W
        eye = torch.eye(self.t1, dtype=torch.float32, device=cst.DEVICE)
        W = self.W - self.W * eye + eye / self.t1

        # attention: the second step learns how important the temporal instances are to each other (8)
        E = X @ W

        # compute the attention mask (9)
        A = torch.softmax(E, dim=-1)

        # soft attention mechanism (10): the attention mask A obtained in the
        # previous step is used to zero out the effect of unimportant elements
        X = self.l[0] * X + (1.0 - self.l[0]) * X * A

        # the final step of the layer estimates the temporal mapping W2, after the bias shift (11)
        y = X @ self.W2 + self.B
        return y


class BL_layer(nn.Module):
    """Bilinear Layer (BL): a bilinear mapping followed by a ReLU."""

    def __init__(self, d2, d1, t1, t2):
        super().__init__()
        # W1 maps the feature mode (d1 -> d2), W2 maps the temporal mode (t1 -> t2)
        weight1 = torch.Tensor(d2, d1)
        self.W1 = nn.Parameter(weight1)
        nn.init.kaiming_uniform_(self.W1, nonlinearity='relu')

        weight2 = torch.Tensor(t1, t2)
        self.W2 = nn.Parameter(weight2)
        nn.init.kaiming_uniform_(self.W2, nonlinearity='relu')

        bias1 = torch.zeros((d2, t2))
        self.B = nn.Parameter(bias1)

        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.activation(self.W1 @ x @ self.W2 + self.B)
        return x


class BiN_CTABL(nn.Module):
    """BiN normalization followed by the C(TABL) architecture: two BL layers and one TABL layer."""

    def __init__(self, d2, d1, t1, t2, d3, t3, d4, t4):
        super().__init__()
        self.BiN = BiN(d1, t1)
        self.BL = BL_layer(d2, d1, t1, t2)
        self.BL2 = BL_layer(d3, d2, t2, t3)
        self.TABL = TABL_layer(d4, d3, t3, t4)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # swap the temporal and feature dimensions: (batch, t1, d1) -> (batch, d1, t1)
        x = x.permute(0, 2, 1)

        # first pass the input through the BiN layer, then through the C(TABL) architecture
        x = self.BiN(x)

        self.max_norm_(self.BL.W1.data)
        self.max_norm_(self.BL.W2.data)
        x = self.BL(x)
        x = self.dropout(x)

        self.max_norm_(self.BL2.W1.data)
        self.max_norm_(self.BL2.W2.data)
        x = self.BL2(x)
        x = self.dropout(x)

        self.max_norm_(self.TABL.W1.data)
        self.max_norm_(self.TABL.W.data)
        self.max_norm_(self.TABL.W2.data)
        x = self.TABL(x)

        # drop only the singleton horizon dimension (squeezing all dims would
        # also remove the batch dimension when the batch size is 1)
        x = x.squeeze(-1)
        x = torch.softmax(x, dim=1)
        return x

    def max_norm_(self, w):
        # rescale w in place so that its Frobenius norm does not exceed 10
        with torch.no_grad():
            norm = torch.linalg.matrix_norm(w)
            if norm > 10.0:
                desired = torch.clamp(norm, min=0.0, max=10.0)
                w *= desired / (1e-8 + norm)
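

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The dimensions
    # below are illustrative assumptions only: a LOB-style input with t1 = 100
    # time steps and d1 = 40 features, two hidden bilinear stages, and a
    # 3-class output with a single horizon (t4 = 1). Adjust them to your data;
    # running this also requires models.bin.BiN and constants.DEVICE.
    model = BiN_CTABL(60, 40, 100, 10, 120, 5, 3, 1).to(cst.DEVICE)
    dummy = torch.randn(8, 100, 40, device=cst.DEVICE)  # (batch, time, features)
    probs = model(dummy)
    print(probs.shape)  # expected: torch.Size([8, 3])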