import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")


def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def clone_params(param, N):
    return nn.ParameterList([copy.deepcopy(param) for _ in range(N)])

# TODO: replace with https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html?
class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
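
# Regarding the TODO above: this hand-rolled LayerNorm is close to, but not identical to,
# torch.nn.LayerNorm -- it divides by the (unbiased) standard deviation and adds eps to the
# std rather than to the variance inside the square root. A near drop-in replacement
# (a sketch, assuming the small numerical difference is acceptable) would be:
#
#     norm = nn.LayerNorm(features, eps=1e-6)
#
# which also learns an elementwise affine transform equivalent to a_2 and b_2.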


class GraphLayer(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, num_of_nodes,
                 num_of_heads, dropout, alpha, concat=True):
        super(GraphLayer, self).__init__()
        self.in_features = in_features          # embedding size
        self.hidden_features = hidden_features  # embedding size
        self.out_features = out_features        # embedding size (the decoder graph layer differs here)
        self.alpha = alpha                      # hardcoded to 0.1 by the caller
        self.concat = concat                    # encoder graph -> True; decoder graph -> False
        self.num_of_nodes = num_of_nodes        # number of nodes in the graph
        self.num_of_heads = num_of_heads        # number of attention heads (1 for VGNN on MIMIC)
        # clones() is called, but the list holds a single element because num_of_heads = 1
        # (as stated in the paper).
        self.W = clones(nn.Linear(in_features, hidden_features), num_of_heads)
        self.a = clone_params(nn.Parameter(torch.rand(size=(1, 2 * hidden_features)), requires_grad=True), num_of_heads)
        self.ffn = nn.Sequential(
            nn.Linear(out_features, out_features),
            nn.ReLU()
        )
        if not concat:
            self.V = nn.Linear(hidden_features, out_features)
        else:
            self.V = nn.Linear(num_of_heads * hidden_features, out_features)
        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(self.alpha)
        if concat:  # note: both branches are currently identical
            self.norm = LayerNorm(hidden_features)
        else:
            self.norm = LayerNorm(hidden_features)

    def initialize(self):
        for i in range(len(self.W)):
            nn.init.xavier_normal_(self.W[i].weight.data)
        for i in range(len(self.a)):
            nn.init.xavier_normal_(self.a[i].data)
        if not self.concat:
            # Only the output projection V exists on this layer; in practice initialize() is
            # only called on the concat=True encoder layers (see VariationalGNN.__init__).
            nn.init.xavier_normal_(self.V.weight.data)

    def attention(self, linear, a, N, data, edge):
        """Single-head graph attention over an explicit edge list.

        Args:
            linear: per-head linear map W (R^{d x d}).
            a: per-head attention vector (R^{1 x 2d}).
            N: number of nodes.
            data: h_prime, the embeddings of all nodes.
            edge: edge index tensor of shape 2 x E, e.g. input_edges of shape 2 x 11664
                (108 x 108 = 11664 edges over 108 one-hot encoded lab values/procedures).

        Returns:
            Tensor: updated node embeddings of shape N x out.
        """
        data = linear(data).unsqueeze(0)
        assert not torch.isnan(data).any()
        # h: 2 x E x d -- source and target embeddings for every edge
        h = torch.cat((data[:, edge[0, :], :], data[:, edge[1, :], :]),
                      dim=0)
        data = data.squeeze(0)
        assert not torch.isnan(h).any()
        # edge_h: 2*d x E -- concatenated [W h_i || W h_j] per edge
        edge_h = torch.cat((h[0, :, :], h[1, :, :]), dim=1).transpose(0, 1)
        # edge_e: E -- unnormalized attention coefficients
        edge_e = torch.exp(self.leakyrelu(a.mm(edge_h).squeeze()) / np.sqrt(self.hidden_features * self.num_of_heads))
        assert not torch.isnan(edge_e).any()
        # edge_e as a sparse N x N attention matrix
        edge_e = torch.sparse_coo_tensor(edge, edge_e, torch.Size([N, N]))
        # e_rowsum: N x 1 -- per-node normalization constant
        e_rowsum = torch.sparse.mm(edge_e, torch.ones(size=(N, 1)).to(device))
        # Give isolated nodes (zero row sum) a self-loop of weight 1 to avoid division by zero.
        row_check = (e_rowsum == 0)
        e_rowsum[row_check] = 1
        zero_idx = row_check.nonzero()[:, 0]
        edge_e = edge_e.add(
            torch.sparse_coo_tensor(zero_idx.repeat(2, 1), torch.ones(len(zero_idx)).to(device), torch.Size([N, N])))
        # h_prime: N x out -- attention-weighted sum of neighbor embeddings
        h_prime = torch.sparse.mm(edge_e, data)
        assert not torch.isnan(h_prime).any()
        h_prime.div_(e_rowsum)
        # h_prime: N x out
        assert not torch.isnan(h_prime).any()
        return h_prime
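
    # The attention() computation above is a sparse, single-head variant of GAT-style attention:
    # for each edge (i, j) it scores
    #     e_ij = exp( LeakyReLU( a . [W h_i || W h_j] ) / sqrt(hidden_features * num_of_heads) )
    # and the update for node i is the row-normalized weighted sum
    #     h'_i = ( sum_j e_ij * W h_j ) / ( sum_j e_ij ),
    # computed with sparse matrix products so that only the listed edges contribute.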

    def forward(self, edge, data=None):
        # edge: e.g. input_edges of shape 2 x 11881; data: h_prime, the embeddings of all nodes.
        N = self.num_of_nodes
        if self.concat:  # True for the encoder layers
            # The zip runs over a single element because the paper uses one attention head.
            h_prime = torch.cat([self.attention(l, a, N, data, edge) for l, a in zip(self.W, self.a)], dim=1)
        else:
            h_prime = torch.stack([self.attention(l, a, N, data, edge) for l, a in zip(self.W, self.a)], dim=0).mean(
                dim=0)
        h_prime = self.dropout(h_prime)
        if self.concat:
            return F.elu(self.norm(h_prime))
        else:
            return self.V(F.relu(self.norm(h_prime)))
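
    # Summary of the two modes above: with concat=True (encoder layers) the head output is
    # layer-normalized and passed through ELU; with concat=False (the decoder layer) the head
    # outputs are averaged and projected by V to out_features after LayerNorm and ReLU.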


class VariationalGNN(nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 num_of_nodes,
                 n_heads,
                 n_layers,
                 dropout,
                 alpha,  # hardcoded to 0.1 by the caller
                 variational=True,
                 none_graph_features=0,
                 concat=True):
        # Save the constructor arguments so the model can be conveniently restored for inference.
        self.kwargs = {'in_features': in_features,
                       'out_features': out_features,
                       'num_of_nodes': num_of_nodes,
                       'n_heads': n_heads,
                       'n_layers': n_layers,
                       'dropout': dropout,
                       'alpha': alpha,
                       'variational': variational,
                       'none_graph_features': none_graph_features,
                       'concat': concat}
        super(VariationalGNN, self).__init__()
        self.variational = variational
        # Add two extra nodes: the first indicates that the patient is "normal" (no observed codes);
        # the last absorbs features from the patient-specific nodes and is used for prediction.
        self.num_of_nodes = num_of_nodes + 2 - none_graph_features
        # Lookup embedding for all nodes (the patient embedding described in the paper).
        self.embed = nn.Embedding(self.num_of_nodes, in_features, padding_idx=0)
        self.in_att = clones(
            GraphLayer(in_features, in_features, in_features, self.num_of_nodes,
                       n_heads, dropout, alpha, concat=True), n_layers)
        self.out_features = out_features
        self.out_att = GraphLayer(in_features, in_features, out_features, self.num_of_nodes,
                                  n_heads, dropout, alpha, concat=False)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        self.parameterize = nn.Linear(out_features, out_features * 2)
        self.out_layer = nn.Sequential(
            nn.Linear(out_features, out_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, 1))
        self.none_graph_features = none_graph_features
        #region none_graph_features > 0
        if none_graph_features > 0:
            self.features_ffn = nn.Sequential(
                nn.Linear(none_graph_features, out_features // 2),
                nn.ReLU(),
                nn.Dropout(dropout))
            self.out_layer = nn.Sequential(
                nn.Linear(out_features + out_features // 2, out_features),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(out_features, 1))
        #endregion
        for i in range(n_layers):
            self.in_att[i].initialize()
"""MyNote: Hàm này để chi? -> data là 1 patient sample với multihot encoding (chỉ bệnh).
Cần trả về các Edges nối các bệnh này với nhau. Nhớ rằng: mặc định tất cả các bệnh Connect với nhau.
"""
def data_to_edges(self, data):
"""MyNote: Must return (input_edges, output_edges)"""
length = data.size()[0]
nonzero = data.nonzero() # MyNote: return indices indicating non-zero values.
if nonzero.size()[0] == 0: # MyNote: case mà Patient bình thường! (ko có chẩn đoán, xét nghiệm gì!)
# MyNote: Why return so? shape(2, 1), shape(2, 1) Why length + 1? -> Khi bệnh nhân bình thường, vector bệnh của họ toàn là 0 -> cũng phải trả
# ra cái gì đó (vậy là chọn Node đầu và node cuối)
# MyNote: Right side: should include also torch.LongTensor([[0], [0]]) -> ám chỉ là "bình thường" (ko bệnh tật)???
return torch.LongTensor([[0], [0]]), torch.LongTensor([[length + 1], [length + 1]])
if self.training:
mask = torch.rand(nonzero.size()[0])
mask = mask > 0.05
nonzero = nonzero[mask]
if nonzero.size()[0] == 0:
# MyNote: có phải ý là ngay cả khi Patient có issue, 5% trong số đó ta sẽ đối xử như là ko có issue???
return torch.LongTensor([[0], [0]]), torch.LongTensor([[length + 1], [length + 1]])
# MyNote: case: when (testing/validating/infering) OR 95% probability bệnh nhân có ít nhất 1 issue nào đó.
nonzero = nonzero.transpose(0, 1) + 1 # MyNote: Why +1? -> Cộng để tăng Index vì có 2 Node giả đầu (là node chỉ bình thường) và cuối (là node absorb các node khác cho predict)
lengths = nonzero.size()[1]
input_edges = torch.cat((nonzero.repeat(1, lengths),
nonzero.repeat(lengths, 1).transpose(0, 1)
.contiguous().view((1, lengths ** 2))), dim=0)
nonzero = torch.cat((nonzero, torch.LongTensor([[length + 1]]).to(device)), dim=1)
lengths = nonzero.size()[1]
output_edges = torch.cat((nonzero.repeat(1, lengths),
nonzero.repeat(lengths, 1).transpose(0, 1)
.contiguous().view((1, lengths ** 2))), dim=0)
return input_edges.to(device), output_edges.to(device)
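
    # Example (a small illustration, not taken from the code): for data = [0, 1, 0, 1, 0] the
    # observed indices are {1, 3}; after the +1 shift they become {2, 4}, so input_edges
    # contains every ordered pair over {2, 4} and output_edges every ordered pair over
    # {2, 4, 6}, where 6 = length + 1 is the absorbing prediction node.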

    def reparameterise(self, mu, logvar):
        if self.training:
            # logvar is the log-variance (not the log-standard-deviation).
            std = logvar.mul(0.5).exp_()
            # tensor.new() constructs a tensor with the same dtype/device as `std`.
            eps = std.data.new(std.size()).normal_()
            return eps.mul(std).add_(mu)
        else:
            return mu
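
    # The sampling above is the standard VAE reparameterization trick:
    #     z = mu + exp(0.5 * logvar) * eps,  eps ~ N(0, I),
    # which keeps the sample differentiable with respect to mu and logvar. At evaluation time
    # the mean mu is used directly.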

    def encoder_decoder(self, data):
        """Encode one patient into the full graph, then decode into the last (absorbing) node.

        Args:
            data ([N]): multi-hot encoding of the patient's codes, e.g. shape = [1309].

        Returns:
            Tuple[Tensor, Tensor]: the last node's features and the KL divergence term.
        """
        N = self.num_of_nodes
        input_edges, output_edges = self.data_to_edges(data)
        h_prime = self.embed(torch.arange(N).long().to(device))
        # Encoder:
        for attn in self.in_att:
            h_prime = attn(input_edges, h_prime)
        if self.variational:
            # Even though only one patient's data is given, the parameterization is applied to
            # every node of the full graph.
            h_prime = self.parameterize(h_prime).view(-1, 2, self.out_features)
            h_prime = self.dropout(h_prime)
            mu = h_prime[:, 0, :]
            logvar = h_prime[:, 1, :]
            h_prime = self.reparameterise(mu, logvar)  # h_prime: [N, z_dim], e.g. 1311 x 256
            # Keep (mu, logvar) only for the patient's own nodes (not the full graph) when
            # computing the KL divergence later.
            split = int(math.sqrt(len(input_edges[0])))
            pat_diag_code_idx = input_edges[0][0:split]
            mu = mu[pat_diag_code_idx, :]
            logvar = logvar[pat_diag_code_idx, :]
        # Decoder:
        h_prime = self.out_att(output_edges, h_prime)
        if self.variational:
            # The standard formula sums over the latent dimensions; dividing by mu.size()[0]
            # averages the summed KL term over the selected nodes.
            return (h_prime[-1],  # the last (absorbing) node's features
                    0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2)) / mu.size()[0])
        else:
            return (h_prime[-1],
                    torch.tensor(0.0).to(device))
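
    # The KL term returned by encoder_decoder is the closed-form divergence between the
    # approximate posterior q = N(mu, diag(exp(logvar))) and the standard normal prior:
    #     KL(q || N(0, I)) = 0.5 * sum( exp(logvar) + mu^2 - 1 - logvar ),
    # summed over the selected nodes and latent dimensions, then divided by the number of
    # selected nodes.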

    def forward(self, data):
        # Process each sample in the batch independently, then stack the results.
        batch_size = data.size()[0]
        # In the eICU data, the first feature (whether the patient has been admitted before) is
        # not included in the graph.
        if self.none_graph_features == 0:  # typically 0; this does not mean the nodes carry no features
            # Encode/decode the graph separately for every patient encounter in the batch.
            outputs = [self.encoder_decoder(data[i, :]) for i in range(batch_size)]
            # Return logits (the raw output of out_layer, intended for BCEWithLogitsLoss)
            # together with the summed KL term.
            return self.out_layer(F.relu(torch.stack([out[0] for out in outputs]))), \
                   torch.sum(torch.stack([out[1] for out in outputs]))
        else:
            outputs = [(data[i, :self.none_graph_features],
                        self.encoder_decoder(data[i, self.none_graph_features:])) for i in range(batch_size)]
            return self.out_layer(F.relu(
                torch.stack([torch.cat((self.features_ffn(torch.FloatTensor([out[0]]).to(device)), out[1][0]))
                             for out in outputs]))), \
                   torch.sum(torch.stack([out[1][1] for out in outputs]), dim=-1)
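

if __name__ == "__main__":
    # Minimal smoke test (a sketch with arbitrary, assumed dimensions; the real training
    # configuration and data loaders live elsewhere).
    num_codes = 20  # assumed number of medical-code features, for illustration only
    model = VariationalGNN(in_features=32, out_features=32, num_of_nodes=num_codes,
                           n_heads=1, n_layers=1, dropout=0.1, alpha=0.1,
                           variational=True, none_graph_features=0).to(device)
    model.eval()
    # A batch of 4 patients with random multi-hot code vectors.
    batch = (torch.rand(4, num_codes) > 0.8).float().to(device)
    with torch.no_grad():
        logits, kld = model(batch)
    print(logits.shape, kld.item())  # expected: torch.Size([4, 1]) and a scalar KL term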