import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")


def clones(module, N):
    """Return a ModuleList of N independent (deep-copied) copies of `module`."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def clone_params(param, N):
    """Return a ParameterList of N independent (deep-copied) copies of `param`."""
    return nn.ParameterList([copy.deepcopy(param) for _ in range(N)])


# TODO: could this be replaced with torch.nn.LayerNorm
# (https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html)?
class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # learnable gain
        self.b_2 = nn.Parameter(torch.zeros(features))  # learnable bias
        self.eps = eps

    def forward(self, x):
        # Normalize over the last (feature) dimension.
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
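
# Note on the TODO above (an observation, not a change to the code): nn.LayerNorm(features, eps=1e-6)
# would be a near drop-in replacement at the call sites, e.g. the hypothetical
# `self.norm = nn.LayerNorm(hidden_features, eps=1e-6)`, but it is not bit-identical:
# nn.LayerNorm divides by sqrt(biased_variance + eps), whereas this class uses the unbiased std
# (torch.std's default) and adds eps to the std itself. The difference is usually negligible.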


class GraphLayer(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, num_of_nodes,
                 num_of_heads, dropout, alpha, concat=True):
        super(GraphLayer, self).__init__()
        self.in_features = in_features          # MyNote: embedding size
        self.hidden_features = hidden_features  # MyNote: embedding size
        self.out_features = out_features        # MyNote: embedding size (except for the decoder graph, which differs here)
        self.alpha = alpha                      # MyNote: hardcoded to 0.1
        self.concat = concat                    # MyNote: encoder graph -> True; decoder graph -> False
        self.num_of_nodes = num_of_nodes        # MyNote: number of nodes in the graph
        self.num_of_heads = num_of_heads        # MyNote: number of attention heads -> 1 for VGNN on MIMIC
        # MyNote: clones() is called, but the list holds a single element because num_of_heads = 1 (as stated in the paper).
        self.W = clones(nn.Linear(in_features, hidden_features), num_of_heads)
        self.a = clone_params(nn.Parameter(torch.rand(size=(1, 2 * hidden_features)), requires_grad=True), num_of_heads)
        self.ffn = nn.Sequential(
            nn.Linear(out_features, out_features),
            nn.ReLU()
        )
        if not concat:
            self.V = nn.Linear(hidden_features, out_features)
        else:
            self.V = nn.Linear(num_of_heads * hidden_features, out_features)
        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(self.alpha)
        if concat:  # MyNote: both branches are identical, so concat makes no difference to the norm here.
            self.norm = LayerNorm(hidden_features)
        else:
            self.norm = LayerNorm(hidden_features)

    def initialize(self):
        for i in range(len(self.W)):
            nn.init.xavier_normal_(self.W[i].weight.data)
        for i in range(len(self.a)):
            nn.init.xavier_normal_(self.a[i].data)
        # Note: initialize() is only ever called on the concat=True encoder layers
        # (see VariationalGNN.__init__), so this branch is not exercised.
        if not self.concat:
            nn.init.xavier_normal_(self.V.weight.data)

    def attention(self, linear, a, N, data, edge):
        """Sparse GAT-style attention for a single head.

        Args:
            linear: the head's linear map W (nn.Linear, in_features -> hidden_features).
            a: the head's attention parameter vector, shape (1, 2 * hidden_features) (not a bias).
            N: number of nodes.
            data: h_prime, i.e. all nodes together with their embeddings.
            edge: edge index tensor, e.g. input_edges of shape 2 x 11664
                  (108 x 108 = 11664 for 108 lab values/procedures, one-hot encoded).

        Returns:
            Aggregated node features of shape N x hidden_features.
        """
        data = linear(data).unsqueeze(0)
        assert not torch.isnan(data).any()
        # data: 1 x N x hidden
        h = torch.cat((data[:, edge[0, :], :], data[:, edge[1, :], :]),
                      dim=0)
        data = data.squeeze(0)
        # h: 2 x E x hidden (features of both endpoints of every edge)
        assert not torch.isnan(h).any()
        # edge_h: 2*hidden x E
        edge_h = torch.cat((h[0, :, :], h[1, :, :]), dim=1).transpose(0, 1)
        # edge_e: E (unnormalized attention weights, scaled as in scaled dot-product attention)
        edge_e = torch.exp(self.leakyrelu(a.mm(edge_h).squeeze()) / np.sqrt(self.hidden_features * self.num_of_heads))
        assert not torch.isnan(edge_e).any()
        # edge_e as a sparse N x N attention matrix
        edge_e = torch.sparse_coo_tensor(edge, edge_e, torch.Size([N, N]))
        e_rowsum = torch.sparse.mm(edge_e, torch.ones(size=(N, 1)).to(device))
        # e_rowsum: N x 1
        # Rows with no edges get a weight-1 self-loop so the normalization below is well defined.
        row_check = (e_rowsum == 0)
        e_rowsum[row_check] = 1
        zero_idx = row_check.nonzero()[:, 0]
        edge_e = edge_e.add(
            torch.sparse_coo_tensor(zero_idx.repeat(2, 1), torch.ones(len(zero_idx)).to(device), torch.Size([N, N])))
        # h_prime: N x hidden
        h_prime = torch.sparse.mm(edge_e, data)
        assert not torch.isnan(h_prime).any()
        h_prime.div_(e_rowsum)
        # h_prime: N x hidden
        assert not torch.isnan(h_prime).any()
        return h_prime
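
    # What attention() computes, in equation form (an informal summary, not from the original code):
    # for each directed edge (i, j) in `edge`, with W the head's linear map and a its 1 x 2d vector,
    #     e_ij     = exp( LeakyReLU( a . [W h_i || W h_j] ) / sqrt(d_hidden * n_heads) )
    #     alpha_ij = e_ij / sum_k e_ik                    (the division by e_rowsum)
    #     h'_i     = sum_j alpha_ij * (W h_j)             (the sparse matrix products above)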

    def forward(self, edge, data=None):
        # MyNote: input: (input_edges, h_prime)
        # MyNote: e.g. edge -> input_edges of shape 2 x 11881
        # MyNote: data -> h_prime = all nodes together with their embeddings
        N = self.num_of_nodes
        if self.concat:  # MyNote: True for the encoder layers, False for the decoder layer
            # MyNote: zip over heads, but there is effectively one element because a single attention head is used (as in the paper).
            h_prime = torch.cat([self.attention(l, a, N, data, edge) for l, a in zip(self.W, self.a)], dim=1)
        else:
            h_prime = torch.stack([self.attention(l, a, N, data, edge) for l, a in zip(self.W, self.a)], dim=0).mean(
                dim=0)
        h_prime = self.dropout(h_prime)
        if self.concat:
            return F.elu(self.norm(h_prime))
        else:
            return self.V(F.relu(self.norm(h_prime)))
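

# A minimal usage sketch (not part of the original model): run GraphLayer on a toy fully connected
# 4-node graph with self-loops and feature size 8. The helper name and all toy values are
# assumptions for illustration only.
def _graphlayer_smoke_test():
    n, d = 4, 8
    layer = GraphLayer(in_features=d, hidden_features=d, out_features=d,
                       num_of_nodes=n, num_of_heads=1,
                       dropout=0.1, alpha=0.1, concat=True).to(device)
    idx = torch.arange(n)
    # All ordered node pairs as a 2 x E edge-index tensor (dense adjacency incl. self-loops).
    edges = torch.stack([idx.repeat_interleave(n), idx.repeat(n)]).to(device)
    h = torch.randn(n, d).to(device)
    out = layer(edges, h)
    assert out.shape == (n, d)
    return out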


class VariationalGNN(nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 num_of_nodes,
                 n_heads,
                 n_layers,
                 dropout,
                 alpha,  # MyNote: hardcoded to 0.1
                 variational=True,
                 none_graph_features=0,
                 concat=True):
        # Save input parameters for later convenient restoration of the object for inference.
        self.kwargs = {'in_features': in_features,
                       'out_features': out_features,
                       'num_of_nodes': num_of_nodes,
                       'n_heads': n_heads,
                       'n_layers': n_layers,
                       'dropout': dropout,
                       'alpha': alpha,
                       'variational': variational,
                       'none_graph_features': none_graph_features,
                       'concat': concat}
        super(VariationalGNN, self).__init__()
        self.variational = variational
        # Add two more nodes: the first indicates that the patient is normal; the last absorbs
        # features from the patient-specific nodes and is used to make the prediction.
        self.num_of_nodes = num_of_nodes + 2 - none_graph_features
        # MyNote: this is the lookup embedding from the paper (patient graph nodes).
        self.embed = nn.Embedding(self.num_of_nodes, in_features, padding_idx=0)
        self.in_att = clones(
            GraphLayer(in_features, in_features, in_features, self.num_of_nodes,
                       n_heads, dropout, alpha, concat=True), n_layers)
        self.out_features = out_features
        self.out_att = GraphLayer(in_features, in_features, out_features, self.num_of_nodes,
                                  n_heads, dropout, alpha, concat=False)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        self.parameterize = nn.Linear(out_features, out_features * 2)
        self.out_layer = nn.Sequential(
            nn.Linear(out_features, out_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, 1))
        self.none_graph_features = none_graph_features
        #region none_graph_features > 0
        if none_graph_features > 0:
            self.features_ffn = nn.Sequential(
                nn.Linear(none_graph_features, out_features // 2),
                nn.ReLU(),
                nn.Dropout(dropout))
            self.out_layer = nn.Sequential(
                nn.Linear(out_features + out_features // 2, out_features),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(out_features, 1))
        #endregion
        for i in range(n_layers):
            self.in_att[i].initialize()
"""MyNote: Hàm này để chi? -> data là 1 patient sample với multihot encoding (chỉ bệnh). | |
Cần trả về các Edges nối các bệnh này với nhau. Nhớ rằng: mặc định tất cả các bệnh Connect với nhau. | |
""" | |
def data_to_edges(self, data): | |
"""MyNote: Must return (input_edges, output_edges)""" | |
length = data.size()[0] | |
nonzero = data.nonzero() # MyNote: return indices indicating non-zero values. | |
if nonzero.size()[0] == 0: # MyNote: case mà Patient bình thường! (ko có chẩn đoán, xét nghiệm gì!) | |
# MyNote: Why return so? shape(2, 1), shape(2, 1) Why length + 1? -> Khi bệnh nhân bình thường, vector bệnh của họ toàn là 0 -> cũng phải trả | |
# ra cái gì đó (vậy là chọn Node đầu và node cuối) | |
# MyNote: Right side: should include also torch.LongTensor([[0], [0]]) -> ám chỉ là "bình thường" (ko bệnh tật)??? | |
return torch.LongTensor([[0], [0]]), torch.LongTensor([[length + 1], [length + 1]]) | |
if self.training: | |
mask = torch.rand(nonzero.size()[0]) | |
mask = mask > 0.05 | |
nonzero = nonzero[mask] | |
if nonzero.size()[0] == 0: | |
# MyNote: có phải ý là ngay cả khi Patient có issue, 5% trong số đó ta sẽ đối xử như là ko có issue??? | |
return torch.LongTensor([[0], [0]]), torch.LongTensor([[length + 1], [length + 1]]) | |
# MyNote: case: when (testing/validating/infering) OR 95% probability bệnh nhân có ít nhất 1 issue nào đó. | |
nonzero = nonzero.transpose(0, 1) + 1 # MyNote: Why +1? -> Cộng để tăng Index vì có 2 Node giả đầu (là node chỉ bình thường) và cuối (là node absorb các node khác cho predict) | |
lengths = nonzero.size()[1] | |
input_edges = torch.cat((nonzero.repeat(1, lengths), | |
nonzero.repeat(lengths, 1).transpose(0, 1) | |
.contiguous().view((1, lengths ** 2))), dim=0) | |
nonzero = torch.cat((nonzero, torch.LongTensor([[length + 1]]).to(device)), dim=1) | |
lengths = nonzero.size()[1] | |
output_edges = torch.cat((nonzero.repeat(1, lengths), | |
nonzero.repeat(lengths, 1).transpose(0, 1) | |
.contiguous().view((1, lengths ** 2))), dim=0) | |
return input_edges.to(device), output_edges.to(device) | |
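
    # Worked example for data_to_edges (illustrative values, not from the original code):
    # data = [0, 1, 0, 1] (length 4) -> nonzero indices {1, 3} -> shifted to graph nodes {2, 4}.
    #   input_edges  = all ordered pairs over {2, 4}     -> shape 2 x 4
    #   output_edges = all ordered pairs over {2, 4, 5}  -> shape 2 x 9
    # where node 5 = length + 1 is the extra absorbing/readout node and node 0 is reserved for
    # the "no findings" case.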

    def reparameterise(self, mu, logvar):
        if self.training:
            # Assume logvar is the log-variance (NOT the log of the standard deviation).
            std = logvar.mul(0.5).exp_()
            # MyNote: tensor.new() constructs a new tensor with the same dtype (and device) as the source tensor.
            eps = std.data.new(std.size()).normal_()
            return eps.mul(std).add_(mu)
        else:
            return mu
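
    # Reparameterization trick in formula form: z = mu + sigma * eps, with sigma = exp(0.5 * logvar)
    # and eps ~ N(0, I). An equivalent, more modern sketch of the two sampling lines above
    # (an alternative, not the original code) would be:
    #     std = torch.exp(0.5 * logvar)
    #     return mu + torch.randn_like(std) * std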

    def encoder_decoder(self, data):
        """Given one patient's data, encode it into the full graph, then decode into the last node.

        Args:
            data ([N]): multi-hot encoding (of medical codes), e.g. shape = [1309].

        Returns:
            Tuple[Tensor, Tensor]: the last node's features, plus the KL divergence.
        """
        N = self.num_of_nodes
        input_edges, output_edges = self.data_to_edges(data)
        h_prime = self.embed(torch.arange(N).long().to(device))
        # Encoder:
        for attn in self.in_att:
            h_prime = attn(input_edges, h_prime)
        if self.variational:
            # Even though only one patient's data is given, this parameterization affects the whole graph.
            h_prime = self.parameterize(h_prime).view(-1, 2, self.out_features)
            h_prime = self.dropout(h_prime)
            mu = h_prime[:, 0, :]
            logvar = h_prime[:, 1, :]
            h_prime = self.reparameterise(mu, logvar)  # h_prime.shape = [N, z_dim], e.g. (1311, 256)
            # Keep (mu, logvar) for computing the KL divergence later.
            # Note: only the patient's own nodes are considered (NOT the whole graph).
            # input_edges has lengths ** 2 columns, so the square root recovers the number of patient nodes.
            split = int(math.sqrt(len(input_edges[0])))
            pat_diag_code_idx = input_edges[0][0:split]
            mu = mu[pat_diag_code_idx, :]
            logvar = logvar[pat_diag_code_idx, :]
        # Decoder:
        h_prime = self.out_att(output_edges, h_prime)
        if self.variational:
            # Divide by mu.size()[0] (the number of selected nodes): the sum below runs over the
            # selected nodes as well as the latent dimensions, while the standard KL formula sums
            # only over the latent dimensions.
            return (h_prime[-1],  # The last node's features.
                    0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2)) / mu.size()[0])
        else:
            return (h_prime[-1],
                    torch.tensor(0.0).to(device))
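
    # KL term used above: for q = N(mu, diag(exp(logvar))) against the standard normal prior,
    #     KL(q || N(0, I)) = 0.5 * sum_d (exp(logvar_d) + mu_d^2 - logvar_d - 1),
    # summed here over the selected nodes and latent dimensions, then divided by the number of nodes.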

    def forward(self, data):
        # Concatenate over the batch.
        batch_size = data.size()[0]
        # In the eICU data, the first feature (whether the patient has been admitted before) is not included in the graph.
        if self.none_graph_features == 0:  # MyNote: none_graph_features is hardcoded to 0; this does NOT mean node features are unused.
            # MyNote: for each patient encounter, encode the graph specifically for that encounter.
            outputs = [self.encoder_decoder(data[i, :]) for i in range(batch_size)]
            # MyNote: returns logits (output of out_layer()); BCEWithLogitsLoss is applied later.
            return self.out_layer(F.relu(torch.stack([out[0] for out in outputs]))), \
                torch.sum(torch.stack([out[1] for out in outputs]))
        else:
            outputs = [(data[i, :self.none_graph_features],
                        self.encoder_decoder(data[i, self.none_graph_features:])) for i in range(batch_size)]
            return self.out_layer(F.relu(
                torch.stack([torch.cat((self.features_ffn(torch.FloatTensor([out[0]]).to(device)), out[1][0]))
                             for out in outputs]))), \
                torch.sum(torch.stack([out[1][1] for out in outputs]), dim=-1)
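

# End-to-end usage sketch (not part of the original code). Assumptions: in_features must equal
# out_features when variational=True (self.parameterize and the decoder reuse the encoder width),
# a toy vocabulary of 10 graph features, and hand-made multi-hot patient vectors.
if __name__ == "__main__":
    model = VariationalGNN(in_features=16, out_features=16, num_of_nodes=10,
                           n_heads=1, n_layers=1, dropout=0.1, alpha=0.1,
                           variational=True, none_graph_features=0).to(device)
    model.eval()  # disable dropout and the random code masking in data_to_edges

    batch = torch.zeros(2, 10)
    batch[0, [1, 3, 7]] = 1.0  # patient 0 has codes 1, 3 and 7
    batch[1, [0, 4]] = 1.0     # patient 1 has codes 0 and 4
    batch = batch.to(device)

    logits, kld = model(batch)
    print(logits.shape, kld.item())  # expected: torch.Size([2, 1]) and a scalar KL value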