Spaces:
Sleeping
Sleeping
Run ok offline
Browse files- __pycache__/model.cpython-311.pyc +0 -0
- app.py +23 -8
- final_model.pt +2 -2
- model.py +314 -0
__pycache__/model.cpython-311.pyc
ADDED
Binary file (21.7 kB). View file
|
|
app.py
CHANGED
@@ -7,15 +7,26 @@ import gradio as gr
|
|
7 |
import pandas as pd
|
8 |
import re
|
9 |
|
|
|
|
|
10 |
|
11 |
examples_path = "examples"
|
12 |
-
|
13 |
-
|
14 |
-
model = torch.jit.load("final_model.pt").to(device)
|
15 |
correct_preds, wrong_preds = {}, {}
|
16 |
condition_lst = pd.read_csv("feature.csv", header = "infer", sep = ",", encoding = "utf-8", dtype=str)
|
17 |
D_LABITEMS = pd.read_csv("D_LABITEMS.csv", header = "infer", sep = ",", encoding = "utf-8", dtype=str)
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
def _check_patient_csv_format(df: pd.DataFrame):
|
20 |
if not (list(df.columns)[0:2] == ["condition", "value"]):
|
21 |
raise gr.Error(f"Column set [{list(df.columns)}]: not expected.", duration=None)
|
@@ -71,8 +82,12 @@ def _predict(file_path: str):
|
|
71 |
keep_default_na=False)
|
72 |
_check_patient_csv_format(df)
|
73 |
patient_data = torch.from_numpy(df["value"].to_numpy()).unsqueeze(dim=0).to(device)
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
76 |
return probability
|
77 |
|
78 |
|
@@ -81,7 +96,7 @@ def example_csv_click(patient_id: int):
|
|
81 |
|
82 |
patient = correct_preds[patient_id] if patient_id in correct_preds else wrong_preds[patient_id]
|
83 |
probability = _predict(patient['file_name'])
|
84 |
-
return [{"
|
85 |
patient['label']]
|
86 |
|
87 |
|
@@ -90,7 +105,7 @@ def user_csv_upload(temp_csv_file_path):
|
|
90 |
|
91 |
matches = _extract_patient_data_from_name(temp_csv_file_path)
|
92 |
probability = _predict(temp_csv_file_path)
|
93 |
-
return [{"
|
94 |
"(Not Available)" if matches is None else matches[1]]
|
95 |
|
96 |
|
@@ -128,7 +143,7 @@ css = \
|
|
128 |
#selectFileToUpload {max-height: 180px}
|
129 |
.gradio-container {
|
130 |
background: url(https://www.kindpng.com/picc/m/207-2075829_transparent-healthcare-clipart-medical-report-icon-hd-png.png);
|
131 |
-
background-position: 80%
|
132 |
background-repeat: no-repeat;
|
133 |
background-size: 200px;
|
134 |
}
|
|
|
7 |
import pandas as pd
|
8 |
import re
|
9 |
|
10 |
+
from model import VariationalGNN
|
11 |
+
|
12 |
|
13 |
examples_path = "examples"
|
14 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
15 |
+
|
|
|
16 |
correct_preds, wrong_preds = {}, {}
|
17 |
condition_lst = pd.read_csv("feature.csv", header = "infer", sep = ",", encoding = "utf-8", dtype=str)
|
18 |
D_LABITEMS = pd.read_csv("D_LABITEMS.csv", header = "infer", sep = ",", encoding = "utf-8", dtype=str)
|
19 |
|
20 |
+
|
21 |
+
def load_model():
    """Rebuild the trained ``VariationalGNN`` from ``final_model.pt``.

    The checkpoint is unpacked as a ``(kwargs, state_dict)`` pair: ``kwargs``
    are the constructor arguments of ``VariationalGNN`` and ``state_dict``
    holds the trained weights.

    Returns:
        VariationalGNN: the model with its weights loaded, placed on ``device``.
    """
    path = r"final_model.pt"
    # SECURITY: weights_only=False unpickles arbitrary Python objects, so this
    # must only ever be pointed at a trusted checkpoint file.
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU machine still loads on a CPU-only host.
    kwargs, state = torch.load(path, map_location=device, weights_only=False)
    model = VariationalGNN(**kwargs).to(device)
    model.load_state_dict(state)
    return model
|
27 |
+
|
28 |
+
model = load_model()
|
29 |
+
|
30 |
def _check_patient_csv_format(df: pd.DataFrame):
|
31 |
if not (list(df.columns)[0:2] == ["condition", "value"]):
|
32 |
raise gr.Error(f"Column set [{list(df.columns)}]: not expected.", duration=None)
|
|
|
82 |
keep_default_na=False)
|
83 |
_check_patient_csv_format(df)
|
84 |
patient_data = torch.from_numpy(df["value"].to_numpy()).unsqueeze(dim=0).to(device)
|
85 |
+
|
86 |
+
model.eval()
|
87 |
+
with torch.inference_mode():
|
88 |
+
probability, _ = model(patient_data)
|
89 |
+
probability = torch.sigmoid(probability.detach().cpu()[0]).item()
|
90 |
+
|
91 |
return probability
|
92 |
|
93 |
|
|
|
96 |
|
97 |
patient = correct_preds[patient_id] if patient_id in correct_preds else wrong_preds[patient_id]
|
98 |
probability = _predict(patient['file_name'])
|
99 |
+
return [{"dead": probability, "alive": 1-probability},
|
100 |
patient['label']]
|
101 |
|
102 |
|
|
|
105 |
|
106 |
matches = _extract_patient_data_from_name(temp_csv_file_path)
|
107 |
probability = _predict(temp_csv_file_path)
|
108 |
+
return [{"dead": probability, "alive": 1-probability},
|
109 |
"(Not Available)" if matches is None else matches[1]]
|
110 |
|
111 |
|
|
|
143 |
#selectFileToUpload {max-height: 180px}
|
144 |
.gradio-container {
|
145 |
background: url(https://www.kindpng.com/picc/m/207-2075829_transparent-healthcare-clipart-medical-report-icon-hd-png.png);
|
146 |
+
background-position: 80% 85%;
|
147 |
background-repeat: no-repeat;
|
148 |
background-size: 200px;
|
149 |
}
|
final_model.pt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fc8ce76b804f59492a8898b568978aaf73a64bcd8ab7e97cbf83626113ffde39
|
3 |
+
size 60944806
|
model.py
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
11 |
+
# device = torch.device("cpu")
|
12 |
+
|
13 |
+
def clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    replicas = []
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return nn.ModuleList(replicas)
|
15 |
+
|
16 |
+
|
17 |
+
def clone_params(param, N):
    """Return an ``nn.ParameterList`` holding ``N`` independent deep copies of ``param``."""
    replicas = []
    for _ in range(N):
        replicas.append(copy.deepcopy(param))
    return nn.ParameterList(replicas)
|
19 |
+
|
20 |
+
|
21 |
+
# TODO: replaced with https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html?
|
22 |
+
class LayerNorm(nn.Module):
    """Normalize over the last dimension, then apply a learned scale and shift.

    The input is centred with its mean along the last axis and divided by
    ``std + eps``, where ``std`` is ``Tensor.std`` with its default arguments
    (``eps`` is added to the standard deviation, not to the variance).

    Args:
        features: size of the last dimension; sets the length of ``a_2`` / ``b_2``.
        eps: constant added to the standard deviation to avoid division by zero.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        # Learned per-feature scale, initialised to ones.
        self.a_2 = nn.Parameter(torch.ones(features))
        # Learned per-feature shift, initialised to zeros.
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        """Return ``a_2 * (x - mean) / (std + eps) + b_2`` along the last axis."""
        centered = x - x.mean(-1, keepdim=True)
        spread = x.std(-1, keepdim=True) + self.eps
        # Scale first, then divide: same evaluation order as the formula above.
        scaled = self.a_2 * centered
        return scaled / spread + self.b_2
|
33 |
+
|
34 |
+
|
35 |
+
class GraphLayer(nn.Module):
    """Sparse graph-attention layer over ``num_of_nodes`` nodes.

    Each head projects the node features with its own linear map, scores every
    edge with a learned vector, and averages the projected features of the
    connected nodes weighted by those scores.

    With ``concat=True`` the head outputs are concatenated on the feature axis
    and passed through ``norm`` then ELU. With ``concat=False`` they are
    averaged, passed through ``norm`` then ReLU, and projected by ``V``.

    NOTE(review): ``norm`` is always sized ``hidden_features``, while the
    concatenated output has ``num_of_heads * hidden_features`` columns, so
    ``concat=True`` looks usable only with ``num_of_heads == 1`` - confirm.

    Args:
        in_features: width of the incoming node features.
        hidden_features: width of each head's projection.
        out_features: width produced by ``V`` (used when ``concat=False``).
        num_of_nodes: number of nodes ``N`` in the graph.
        num_of_heads: number of attention heads.
        dropout: dropout probability applied to the combined head output.
        alpha: negative slope of the LeakyReLU applied to the edge scores.
        concat: ``True`` to concatenate heads, ``False`` to average them
            (original author's note: encoder layers use True, the decoder False).
    """

    def __init__(self, in_features, hidden_features, out_features, num_of_nodes,
                 num_of_heads, dropout, alpha, concat=True):
        super(GraphLayer, self).__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        self.num_of_nodes = num_of_nodes
        self.num_of_heads = num_of_heads

        # One projection and one edge-scoring vector per attention head.
        self.W = clones(nn.Linear(in_features, hidden_features), num_of_heads)
        self.a = clone_params(nn.Parameter(torch.rand(size=(1, 2 * hidden_features)), requires_grad=True), num_of_heads)
        # Not used by forward(); kept so parameter names keep matching
        # state_dicts saved from this class.
        self.ffn = nn.Sequential(
            nn.Linear(out_features, out_features),
            nn.ReLU()
        )

        if not concat:
            self.V = nn.Linear(hidden_features, out_features)
        else:
            self.V = nn.Linear(num_of_heads * hidden_features, out_features)

        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(self.alpha)

        # The original built the same LayerNorm in both `concat` branches.
        self.norm = LayerNorm(hidden_features)

    def initialize(self):
        """Re-initialise the head weights (and ``V`` when averaging) with Xavier-normal."""
        for linear in self.W:
            nn.init.xavier_normal_(linear.weight.data)
        for score_vec in self.a:
            nn.init.xavier_normal_(score_vec.data)
        if not self.concat:
            # The original also touched `self.out_layer`, which this class never
            # defines, so this branch always raised AttributeError.
            nn.init.xavier_normal_(self.V.weight.data)

    def attention(self, linear, a, N, data, edge):
        """Run one attention head over the sparse edge list.

        Args:
            linear: the head's ``nn.Linear`` projection.
            a: the head's edge-scoring vector, shape ``(1, 2 * hidden_features)``.
            N: number of nodes; the attention matrix is ``N x N``.
            data: node features, one row per node (assumed ``N`` rows, since the
                result is ``attention_matrix @ projected_data``).
            edge: ``(2, E)`` integer tensor; row 0 and row 1 hold the two node
                indices of each edge.

        Returns:
            Tensor of shape ``(N, hidden_features)``: for every node, the
            score-weighted mean of the projected features of its neighbours.
        """
        data = linear(data).unsqueeze(0)
        assert not torch.isnan(data).any()
        # Keep the edge list on the same device as the features it indexes.
        edge = edge.to(data.device)
        # h: (2, E, hidden) - projected features of both endpoints of every edge.
        h = torch.cat((data[:, edge[0, :], :], data[:, edge[1, :], :]),
                      dim=0)
        data = data.squeeze(0)
        assert not torch.isnan(h).any()
        # edge_h: (2 * hidden, E) - both endpoints' features stacked per edge.
        edge_h = torch.cat((h[0, :, :], h[1, :, :]), dim=1).transpose(0, 1)
        # squeeze(0) keeps a 1-D (E,) vector even when there is a single edge.
        scores = a.mm(edge_h).squeeze(0)
        edge_e = torch.exp(self.leakyrelu(scores) / np.sqrt(self.hidden_features * self.num_of_heads))
        assert not torch.isnan(edge_e).any()
        # Scatter the E edge weights into a sparse N x N attention matrix.
        edge_e = torch.sparse_coo_tensor(edge, edge_e, torch.Size([N, N]))
        ones = torch.ones(size=(N, 1), dtype=data.dtype, device=data.device)
        # e_rowsum: (N, 1) - normalisation constant for each node.
        e_rowsum = torch.sparse.mm(edge_e, ones)
        # Nodes with no edge have a zero row sum: divide by 1 instead and give
        # them a unit self-loop so they keep their own projected features.
        row_check = (e_rowsum == 0)
        e_rowsum[row_check] = 1
        zero_idx = row_check.nonzero()[:, 0]
        self_loops = torch.sparse_coo_tensor(
            zero_idx.repeat(2, 1),
            torch.ones(len(zero_idx), dtype=data.dtype, device=data.device),
            torch.Size([N, N]))
        edge_e = edge_e.add(self_loops)
        # h_prime: (N, hidden) - unnormalised weighted sum over neighbours.
        h_prime = torch.sparse.mm(edge_e, data)
        assert not torch.isnan(h_prime).any()
        h_prime.div_(e_rowsum)
        assert not torch.isnan(h_prime).any()
        return h_prime

    def forward(self, edge, data=None):
        """Apply every head to ``data`` over ``edge`` and combine the results.

        Args:
            edge: ``(2, E)`` edge index tensor passed to each head.
            data: node features, one row per node.

        Returns:
            ``elu(norm(h))`` when ``concat`` is True, otherwise
            ``V(relu(norm(h)))``, where ``h`` is the combined head output after
            dropout.
        """
        N = self.num_of_nodes
        head_outputs = [self.attention(l, a, N, data, edge) for l, a in zip(self.W, self.a)]

        if self.concat:
            h_prime = torch.cat(head_outputs, dim=1)
        else:
            h_prime = torch.stack(head_outputs, dim=0).mean(dim=0)

        h_prime = self.dropout(h_prime)

        if self.concat:
            return F.elu(self.norm(h_prime))
        return self.V(F.relu(self.norm(h_prime)))
|
142 |
+
|
143 |
+
|
144 |
+
class VariationalGNN(nn.Module):
    """Graph-attention encoder/decoder producing one logit per input row.

    For every row of the input batch a graph is built from the row's non-zero
    entries (``data_to_edges``). ``in_att`` encodes the node embeddings over
    that graph, an optional variational step re-samples them, and ``out_att``
    decodes onto an extra last node whose features feed ``out_layer``.

    The graph has ``num_of_nodes + 2 - none_graph_features`` nodes. Index 0 is
    the embedding's padding index (original author's note: the "patient is
    normal" node), and the last node absorbs the features used for prediction.

    NOTE(review): ``parameterize`` and ``out_att`` are sized from ``out_features``
    and ``in_features`` in a way that seems to require
    ``in_features == out_features`` when ``variational`` is True - confirm.

    Args:
        in_features: embedding width and encoder width.
        out_features: decoder output width / latent width.
        num_of_nodes: number of input features before the two extra nodes.
        n_heads: attention heads per graph layer.
        n_layers: number of encoder graph layers.
        dropout: dropout probability used throughout.
        alpha: LeakyReLU negative slope handed to the graph layers.
        variational: enable the re-parameterisation step and KL term.
        none_graph_features: number of leading input columns that bypass the graph.
        concat: stored in ``kwargs`` only; the layers' ``concat`` flags are fixed
            in this constructor.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 num_of_nodes,
                 n_heads,
                 n_layers,
                 dropout,
                 alpha,
                 variational=True,
                 none_graph_features=0,
                 concat=True):
        super(VariationalGNN, self).__init__()

        # Constructor arguments, saved so the object can be rebuilt for inference.
        self.kwargs = {'in_features': in_features,
                       'out_features': out_features,
                       'num_of_nodes': num_of_nodes,
                       'n_heads': n_heads,
                       'n_layers': n_layers,
                       'dropout': dropout,
                       'alpha': alpha,
                       'variational': variational,
                       'none_graph_features': none_graph_features,
                       'concat': concat}

        self.variational = variational
        # Two extra nodes: the first (index 0) and the last, which collects the
        # features used for the prediction.
        self.num_of_nodes = num_of_nodes + 2 - none_graph_features
        # Learned embedding for every node of the graph.
        self.embed = nn.Embedding(self.num_of_nodes, in_features, padding_idx=0)

        self.in_att = clones(
            GraphLayer(in_features, in_features, in_features, self.num_of_nodes,
                       n_heads, dropout, alpha, concat=True), n_layers)
        self.out_features = out_features
        self.out_att = GraphLayer(in_features, in_features, out_features, self.num_of_nodes,
                                  n_heads, dropout, alpha, concat=False)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        # Maps node features to concatenated (mu, logvar).
        self.parameterize = nn.Linear(out_features, out_features * 2)
        self.out_layer = nn.Sequential(
            nn.Linear(out_features, out_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, 1))
        self.none_graph_features = none_graph_features
        if none_graph_features > 0:
            # Separate branch for the columns kept out of the graph; the head is
            # rebuilt to accept the concatenation of both branches.
            self.features_ffn = nn.Sequential(
                nn.Linear(none_graph_features, out_features//2),
                nn.ReLU(),
                nn.Dropout(dropout))
            self.out_layer = nn.Sequential(
                nn.Linear(out_features + out_features//2, out_features),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(out_features, 1))
        for i in range(n_layers):
            self.in_att[i].initialize()

    @staticmethod
    def _fallback_edges(length, dev):
        """Edges used when no feature is active: a self-loop on node 0 and one on node ``length + 1``."""
        input_edges = torch.tensor([[0], [0]], dtype=torch.long, device=dev)
        output_edges = torch.tensor([[length + 1], [length + 1]], dtype=torch.long, device=dev)
        return input_edges, output_edges

    def data_to_edges(self, data):
        """Build the encoder and decoder edge lists for one sample.

        ``data`` is a 1-D vector; every non-zero position ``i`` activates node
        ``i + 1`` (node 0 is reserved). All active nodes are connected to each
        other, self-loops included.

        In training mode each active position is independently dropped with
        probability 0.05 before the edges are built.

        Args:
            data: 1-D tensor of feature values for a single sample.

        Returns:
            Tuple ``(input_edges, output_edges)`` of ``(2, E)`` long tensors on
            ``data``'s device. ``input_edges`` fully connects the active nodes;
            ``output_edges`` additionally includes node ``length + 1``
            (assumed to be the last node of the graph - confirm that
            ``len(data) == self.num_of_nodes - 2``). When nothing is active,
            the single-self-loop fallback edges are returned instead.
        """
        length = data.size()[0]
        dev = data.device
        nonzero = data.nonzero()
        if nonzero.size()[0] == 0:
            # No active feature at all: still return something the layers can use.
            return self._fallback_edges(length, dev)
        if self.training:
            # Random feature dropout; the mask is drawn on the CPU generator as
            # before, then moved next to the indices it filters.
            mask = torch.rand(nonzero.size()[0]) > 0.05
            nonzero = nonzero[mask.to(nonzero.device)]
            if nonzero.size()[0] == 0:
                # Every active feature happened to be dropped.
                return self._fallback_edges(length, dev)

        # Shift by one because node 0 is reserved.
        nonzero = nonzero.transpose(0, 1) + 1
        lengths = nonzero.size()[1]
        # Cartesian product of the active nodes with themselves.
        input_edges = torch.cat((nonzero.repeat(1, lengths),
                                 nonzero.repeat(lengths, 1).transpose(0, 1)
                                 .contiguous().view((1, lengths ** 2))), dim=0)

        # Same construction with the extra node appended for the decoder.
        last_node = torch.tensor([[length + 1]], dtype=torch.long, device=dev)
        nonzero = torch.cat((nonzero, last_node), dim=1)
        lengths = nonzero.size()[1]
        output_edges = torch.cat((nonzero.repeat(1, lengths),
                                  nonzero.repeat(lengths, 1).transpose(0, 1)
                                  .contiguous().view((1, lengths ** 2))), dim=0)
        return input_edges.to(dev), output_edges.to(dev)

    def reparameterise(self, mu, logvar):
        """Sample ``mu + eps * std`` in training mode; return ``mu`` unchanged otherwise.

        ``logvar`` is the log of the variance (not of the standard deviation),
        hence the factor 0.5 before exponentiating.
        """
        if self.training:
            std = logvar.mul(0.5).exp_()
            # New tensor with std's dtype and device, filled with N(0, 1) noise.
            eps = std.data.new(std.size()).normal_()
            return eps.mul(std).add_(mu)
        else:
            return mu

    def encoder_decoder(self, data):
        """Encode one sample over the whole graph, then decode onto the last node.

        Args:
            data: 1-D feature vector of a single sample (see ``data_to_edges``).

        Returns:
            Tuple ``(last_node_features, kl)``. ``kl`` is the KL-divergence term
            computed over the sample's own nodes, or a zero scalar when the
            model is not variational.
        """
        N = self.num_of_nodes
        input_edges, output_edges = self.data_to_edges(data)
        # Start from the embeddings of all nodes, on the embedding's own device.
        h_prime = self.embed(torch.arange(N, device=self.embed.weight.device).long())

        # Encoder
        for attn in self.in_att:
            h_prime = attn(input_edges, h_prime)

        if self.variational:
            # The re-parameterisation is applied to every node of the graph.
            h_prime = self.parameterize(h_prime).view(-1, 2, self.out_features)
            h_prime = self.dropout(h_prime)
            mu = h_prime[:, 0, :]
            logvar = h_prime[:, 1, :]
            h_prime = self.reparameterise(mu, logvar)

            # For the KL term keep only the nodes of this sample's graph.
            # input_edges holds lengths**2 edges, and its first `lengths`
            # entries of row 0 list each active node once.
            split = int(math.sqrt(len(input_edges[0])))
            pat_diag_code_idx = input_edges[0][0:split]
            mu = mu[pat_diag_code_idx, :]
            logvar = logvar[pat_diag_code_idx, :]

        # Decoder
        h_prime = self.out_att(output_edges, h_prime)

        if self.variational:
            # Divided by the number of selected nodes (mu.size()[0]); the sum
            # itself runs over the nodes and all latent dimensions.
            return (h_prime[-1],
                    0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2)) / mu.size()[0]
                    )
        else:
            return (h_prime[-1],
                    torch.tensor(0.0, device=h_prime.device)
                    )

    def forward(self, data):
        """Score a batch of samples.

        Args:
            data: 2-D tensor, one sample per row. When ``none_graph_features``
                is positive, that many leading columns bypass the graph and go
                through ``features_ffn`` instead.

        Returns:
            Tuple ``(logits, kl)``: ``logits`` is the ``out_layer`` output for
            every row (raw scores, no sigmoid applied here) and ``kl`` is the
            sum of the per-sample KL terms.
        """
        batch_size = data.size()[0]
        if self.none_graph_features == 0:
            # Each row gets its own graph, so rows are processed one by one.
            outputs = [self.encoder_decoder(data[i, :]) for i in range(batch_size)]
            return self.out_layer(F.relu(torch.stack([out[0] for out in outputs]))), \
                torch.sum(torch.stack([out[1] for out in outputs]))
        else:
            # Pair the non-graph columns with the graph result of the remaining ones.
            outputs = [(data[i, :self.none_graph_features],
                        self.encoder_decoder(data[i, self.none_graph_features:])) for i in range(batch_size)]
            return self.out_layer(F.relu(
                torch.stack([torch.cat((self.features_ffn(torch.FloatTensor([out[0]]).to(data.device)), out[1][0]))
                             for out in outputs]))), \
                torch.sum(torch.stack([out[1][1] for out in outputs]), dim=-1)
|