tnt306 commited on
Commit
5f1cd98
·
1 Parent(s): 8938f9b

Run ok offline

Browse files
Files changed (4) hide show
  1. __pycache__/model.cpython-311.pyc +0 -0
  2. app.py +23 -8
  3. final_model.pt +2 -2
  4. model.py +314 -0
__pycache__/model.cpython-311.pyc ADDED
Binary file (21.7 kB). View file
 
app.py CHANGED
@@ -7,15 +7,26 @@ import gradio as gr
7
  import pandas as pd
8
  import re
9
 
 
 
10
 
11
  examples_path = "examples"
12
- #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
- device = torch.device("cpu")
14
- model = torch.jit.load("final_model.pt").to(device)
15
  correct_preds, wrong_preds = {}, {}
16
  condition_lst = pd.read_csv("feature.csv", header = "infer", sep = ",", encoding = "utf-8", dtype=str)
17
  D_LABITEMS = pd.read_csv("D_LABITEMS.csv", header = "infer", sep = ",", encoding = "utf-8", dtype=str)
18
 
 
 
 
 
 
 
 
 
 
 
19
  def _check_patient_csv_format(df: pd.DataFrame):
20
  if not (list(df.columns)[0:2] == ["condition", "value"]):
21
  raise gr.Error(f"Column set [{list(df.columns)}]: not expected.", duration=None)
@@ -71,8 +82,12 @@ def _predict(file_path: str):
71
  keep_default_na=False)
72
  _check_patient_csv_format(df)
73
  patient_data = torch.from_numpy(df["value"].to_numpy()).unsqueeze(dim=0).to(device)
74
- probability, _ = model(patient_data)
75
- probability = probability.detach().cpu()[0].item()
 
 
 
 
76
  return probability
77
 
78
 
@@ -81,7 +96,7 @@ def example_csv_click(patient_id: int):
81
 
82
  patient = correct_preds[patient_id] if patient_id in correct_preds else wrong_preds[patient_id]
83
  probability = _predict(patient['file_name'])
84
- return [{"Death": probability, "Alive": 1-probability},
85
  patient['label']]
86
 
87
 
@@ -90,7 +105,7 @@ def user_csv_upload(temp_csv_file_path):
90
 
91
  matches = _extract_patient_data_from_name(temp_csv_file_path)
92
  probability = _predict(temp_csv_file_path)
93
- return [{"Death": probability, "Alive": 1-probability},
94
  "(Not Available)" if matches is None else matches[1]]
95
 
96
 
@@ -128,7 +143,7 @@ css = \
128
  #selectFileToUpload {max-height: 180px}
129
  .gradio-container {
130
  background: url(https://www.kindpng.com/picc/m/207-2075829_transparent-healthcare-clipart-medical-report-icon-hd-png.png);
131
- background-position: 80% 80%;
132
  background-repeat: no-repeat;
133
  background-size: 200px;
134
  }
 
7
  import pandas as pd
8
  import re
9
 
10
+ from model import VariationalGNN
11
+
12
 
13
  examples_path = "examples"
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
 
16
  correct_preds, wrong_preds = {}, {}
17
  condition_lst = pd.read_csv("feature.csv", header = "infer", sep = ",", encoding = "utf-8", dtype=str)
18
  D_LABITEMS = pd.read_csv("D_LABITEMS.csv", header = "infer", sep = ",", encoding = "utf-8", dtype=str)
19
 
20
+
21
+ def load_model():
22
+ path = r"final_model.pt"
23
+ kwargs, state = torch.load(path, weights_only=False)
24
+ model = VariationalGNN(**kwargs).to(device)
25
+ model.load_state_dict(state)
26
+ return model
27
+
28
+ model = load_model()
29
+
30
  def _check_patient_csv_format(df: pd.DataFrame):
31
  if not (list(df.columns)[0:2] == ["condition", "value"]):
32
  raise gr.Error(f"Column set [{list(df.columns)}]: not expected.", duration=None)
 
82
  keep_default_na=False)
83
  _check_patient_csv_format(df)
84
  patient_data = torch.from_numpy(df["value"].to_numpy()).unsqueeze(dim=0).to(device)
85
+
86
+ model.eval()
87
+ with torch.inference_mode():
88
+ probability, _ = model(patient_data)
89
+ probability = torch.sigmoid(probability.detach().cpu()[0]).item()
90
+
91
  return probability
92
 
93
 
 
96
 
97
  patient = correct_preds[patient_id] if patient_id in correct_preds else wrong_preds[patient_id]
98
  probability = _predict(patient['file_name'])
99
+ return [{"dead": probability, "alive": 1-probability},
100
  patient['label']]
101
 
102
 
 
105
 
106
  matches = _extract_patient_data_from_name(temp_csv_file_path)
107
  probability = _predict(temp_csv_file_path)
108
+ return [{"dead": probability, "alive": 1-probability},
109
  "(Not Available)" if matches is None else matches[1]]
110
 
111
 
 
143
  #selectFileToUpload {max-height: 180px}
144
  .gradio-container {
145
  background: url(https://www.kindpng.com/picc/m/207-2075829_transparent-healthcare-clipart-medical-report-icon-hd-png.png);
146
+ background-position: 80% 85%;
147
  background-repeat: no-repeat;
148
  background-size: 200px;
149
  }
final_model.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ca894ef68b6e4e9af3a425e22ab9378688b2d199d39aa07cd11b8307acc45967
3
- size 60999527
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc8ce76b804f59492a8898b568978aaf73a64bcd8ab7e97cbf83626113ffde39
3
+ size 60944806
model.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ import numpy as np
9
+
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ # device = torch.device("cpu")
12
+
13
+ def clones(module, N):
14
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
15
+
16
+
17
+ def clone_params(param, N):
18
+ return nn.ParameterList([copy.deepcopy(param) for _ in range(N)])
19
+
20
+
21
+ # TODO: replaced with https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html?
22
+ class LayerNorm(nn.Module):
23
+ def __init__(self, features, eps=1e-6):
24
+ super(LayerNorm, self).__init__()
25
+ self.a_2 = nn.Parameter(torch.ones(features))
26
+ self.b_2 = nn.Parameter(torch.zeros(features))
27
+ self.eps = eps
28
+
29
+ def forward(self, x):
30
+ mean = x.mean(-1, keepdim=True)
31
+ std = x.std(-1, keepdim=True)
32
+ return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
33
+
34
+
35
+ class GraphLayer(nn.Module):
36
+
37
+ def __init__(self, in_features, hidden_features, out_features, num_of_nodes,
38
+ num_of_heads, dropout, alpha, concat=True):
39
+ super(GraphLayer, self).__init__()
40
+ self.in_features = in_features # MyNote: Embedding size
41
+ self.hidden_features = hidden_features # MyNote: Embedding size
42
+ self.out_features = out_features # MyNote: Embedding size (ngoại trừ Decoder Graph, khác chỗ này)
43
+ self.alpha = alpha # MyNote: hardcoded 0.1
44
+ self.concat = concat # MyNote: Encoder graph ->True; Decoder Graph -> False.
45
+ self.num_of_nodes = num_of_nodes # MyNote: Số node trong Graph.
46
+ self.num_of_heads = num_of_heads # MyNote: Số attention head. -> là 1 (VGNN/Mimic)
47
+
48
+ # MyNote: gọi clones() nhưng List chỉ có 1 phần tử vì num_of_heads=1 (ghi trong paper).
49
+ self.W = clones(nn.Linear(in_features, hidden_features), num_of_heads)
50
+ self.a = clone_params(nn.Parameter(torch.rand(size=(1, 2 * hidden_features)), requires_grad=True), num_of_heads)
51
+ self.ffn = nn.Sequential(
52
+ nn.Linear(out_features, out_features),
53
+ nn.ReLU()
54
+ )
55
+
56
+ if not concat:
57
+ self.V = nn.Linear(hidden_features, out_features)
58
+ else:
59
+ self.V = nn.Linear(num_of_heads * hidden_features, out_features)
60
+
61
+ self.dropout = nn.Dropout(dropout)
62
+ self.leakyrelu = nn.LeakyReLU(self.alpha)
63
+
64
+ if concat: # MyNote: Ko hiểu khác nhau chỗ nào?
65
+ self.norm = LayerNorm(hidden_features)
66
+ else:
67
+ self.norm = LayerNorm(hidden_features)
68
+
69
+ def initialize(self):
70
+ for i in range(len(self.W)):
71
+ nn.init.xavier_normal_(self.W[i].weight.data)
72
+ for i in range(len(self.a)):
73
+ nn.init.xavier_normal_(self.a[i].data)
74
+ if not self.concat:
75
+ nn.init.xavier_normal_(self.V.weight.data)
76
+ nn.init.xavier_normal_(self.out_layer.weight.data)
77
+
78
+ def attention(self, linear, a, N, data, edge):
79
+ """MyNote: _summary_
80
+
81
+ Args:
82
+ linear (_type_): weights (R^(dxd))
83
+ a (_type_): bias (R^(1x(2*d)))
84
+ N (_type_): number of nodes
85
+ data (_type_): h_prime = Toàn bộ Nodes & Embedding của nó.
86
+ edge (_type_): Vd: edge -> input_edges = 2x11664
87
+ 108x108=11664 -> 108 lab-value/procedure... (one-hot encoding)
88
+
89
+ Returns:
90
+ _type_: _description_
91
+ """
92
+ data = linear(data).unsqueeze(0)
93
+ assert not torch.isnan(data).any()
94
+ # edge: 2*D x E
95
+ h = torch.cat((data[:, edge[0, :], :], data[:, edge[1, :], :]),
96
+ dim=0)
97
+ data = data.squeeze(0)
98
+ # h: N x out
99
+ assert not torch.isnan(h).any()
100
+ # edge_h: 2*D x E
101
+ edge_h = torch.cat((h[0, :, :], h[1, :, :]), dim=1).transpose(0, 1)
102
+ # edge: 2*D x E
103
+ edge_e = torch.exp(self.leakyrelu(a.mm(edge_h).squeeze()) / np.sqrt(self.hidden_features * self.num_of_heads))
104
+ assert not torch.isnan(edge_e).any()
105
+ # edge_e: E
106
+ edge_e = torch.sparse_coo_tensor(edge, edge_e, torch.Size([N, N]))
107
+ e_rowsum = torch.sparse.mm(edge_e, torch.ones(size=(N, 1)).to(device))
108
+ # e_rowsum: N x 1
109
+ row_check = (e_rowsum == 0)
110
+ e_rowsum[row_check] = 1
111
+ zero_idx = row_check.nonzero()[:, 0]
112
+ edge_e = edge_e.add(
113
+ torch.sparse.FloatTensor(zero_idx.repeat(2, 1), torch.ones(len(zero_idx)).to(device), torch.Size([N, N]))) # type: ignore
114
+ # edge_e: E
115
+ h_prime = torch.sparse.mm(edge_e, data)
116
+ assert not torch.isnan(h_prime).any()
117
+ # h_prime: N x out
118
+ h_prime.div_(e_rowsum)
119
+ # h_prime: N x out
120
+ assert not torch.isnan(h_prime).any()
121
+ return h_prime
122
+
123
+ def forward(self, edge, data=None):
124
+ # MyNote: input: (input_edges, h_prime)
125
+ # MyNote: Vd: edge -> input_edges = 2x11881
126
+ # MyNote: data -> h_prime = Toàn bộ Nodes & Embedding của nó.
127
+ N = self.num_of_nodes
128
+
129
+ if self.concat: # MyNote: hardcoded True
130
+ # MyNote: Zip nhưng thực ra chỉ có 1 element vì Attention head là 1 (ghi trong paper).
131
+ h_prime = torch.cat([self.attention(l, a, N, data, edge) for l, a in zip(self.W, self.a)], dim=1)
132
+ else:
133
+ h_prime = torch.stack([self.attention(l, a, N, data, edge) for l, a in zip(self.W, self.a)], dim=0).mean(
134
+ dim=0)
135
+
136
+ h_prime = self.dropout(h_prime)
137
+
138
+ if self.concat:
139
+ return F.elu(self.norm(h_prime))
140
+ else:
141
+ return self.V(F.relu(self.norm(h_prime)))
142
+
143
+
144
+ class VariationalGNN(nn.Module):
145
+
146
+ def __init__(self,
147
+ in_features,
148
+ out_features,
149
+ num_of_nodes,
150
+ n_heads,
151
+ n_layers,
152
+ dropout,
153
+ alpha, # MyNote: hardcoded 0.1
154
+ variational=True,
155
+ none_graph_features=0,
156
+ concat=True):
157
+
158
+ # Save input parameters for later convenient restoration of the object for inference.
159
+ self.kwargs = {'in_features': in_features,
160
+ 'out_features': out_features,
161
+ 'num_of_nodes': num_of_nodes,
162
+ 'n_heads': n_heads,
163
+ 'n_layers': n_layers,
164
+ 'dropout': dropout,
165
+ 'alpha': alpha,
166
+ 'variational': variational,
167
+ 'none_graph_features': none_graph_features,
168
+ 'concat': concat}
169
+
170
+ super(VariationalGNN, self).__init__()
171
+ self.variational = variational
172
+ # Add two more nodes: the 1st indicates the patient is normal; the last node is used to absorb features from specific nodes of specific patients, to make prediction.
173
+ self.num_of_nodes = num_of_nodes + 2 - none_graph_features
174
+ # MyNote: this is the lookup embedding in paper. (Patient)
175
+ self.embed = nn.Embedding(self.num_of_nodes, in_features, padding_idx=0)
176
+
177
+ self.in_att = clones(
178
+ GraphLayer(in_features, in_features, in_features, self.num_of_nodes,
179
+ n_heads, dropout, alpha, concat=True), n_layers)
180
+ self.out_features = out_features
181
+ self.out_att = GraphLayer(in_features, in_features, out_features, self.num_of_nodes,
182
+ n_heads, dropout, alpha, concat=False)
183
+ self.n_heads = n_heads
184
+ self.dropout = nn.Dropout(dropout)
185
+ self.parameterize = nn.Linear(out_features, out_features * 2)
186
+ self.out_layer = nn.Sequential(
187
+ nn.Linear(out_features, out_features),
188
+ nn.ReLU(),
189
+ nn.Dropout(dropout),
190
+ nn.Linear(out_features, 1))
191
+ self.none_graph_features = none_graph_features
192
+ #region none_graph_features > 0
193
+ if none_graph_features > 0:
194
+ self.features_ffn = nn.Sequential(
195
+ nn.Linear(none_graph_features, out_features//2),
196
+ nn.ReLU(),
197
+ nn.Dropout(dropout))
198
+ self.out_layer = nn.Sequential(
199
+ nn.Linear(out_features + out_features//2, out_features),
200
+ nn.ReLU(),
201
+ nn.Dropout(dropout),
202
+ nn.Linear(out_features, 1))
203
+ #endregion
204
+ for i in range(n_layers):
205
+ self.in_att[i].initialize()
206
+
207
+ """MyNote: Hàm này để chi? -> data là 1 patient sample với multihot encoding (chỉ bệnh).
208
+ Cần trả về các Edges nối các bệnh này với nhau. Nhớ rằng: mặc định tất cả các bệnh Connect với nhau.
209
+ """
210
+ def data_to_edges(self, data):
211
+ """MyNote: Must return (input_edges, output_edges)"""
212
+ length = data.size()[0]
213
+ nonzero = data.nonzero() # MyNote: return indices indicating non-zero values.
214
+ if nonzero.size()[0] == 0: # MyNote: case mà Patient bình thường! (ko có chẩn đoán, xét nghiệm gì!)
215
+ # MyNote: Why return so? shape(2, 1), shape(2, 1) Why length + 1? -> Khi bệnh nhân bình thường, vector bệnh của họ toàn là 0 -> cũng phải trả
216
+ # ra cái gì đó (vậy là chọn Node đầu và node cuối)
217
+ # MyNote: Right side: should include also torch.LongTensor([[0], [0]]) -> ám chỉ là "bình thường" (ko bệnh tật)???
218
+ return torch.LongTensor([[0], [0]]), torch.LongTensor([[length + 1], [length + 1]])
219
+ if self.training:
220
+ mask = torch.rand(nonzero.size()[0])
221
+ mask = mask > 0.05
222
+ nonzero = nonzero[mask]
223
+ if nonzero.size()[0] == 0:
224
+ # MyNote: có phải ý là ngay cả khi Patient có issue, 5% trong số đó ta sẽ đối x��� như là ko có issue???
225
+ return torch.LongTensor([[0], [0]]), torch.LongTensor([[length + 1], [length + 1]])
226
+
227
+ # MyNote: case: when (testing/validating/infering) OR 95% probability bệnh nhân có ít nhất 1 issue nào đó.
228
+ nonzero = nonzero.transpose(0, 1) + 1 # MyNote: Why +1? -> Cộng để tăng Index vì có 2 Node giả đầu (là node chỉ bình thường) và cuối (là node absorb các node khác cho predict)
229
+ lengths = nonzero.size()[1]
230
+ input_edges = torch.cat((nonzero.repeat(1, lengths),
231
+ nonzero.repeat(lengths, 1).transpose(0, 1)
232
+ .contiguous().view((1, lengths ** 2))), dim=0)
233
+
234
+ nonzero = torch.cat((nonzero, torch.LongTensor([[length + 1]]).to(device)), dim=1)
235
+ lengths = nonzero.size()[1]
236
+ output_edges = torch.cat((nonzero.repeat(1, lengths),
237
+ nonzero.repeat(lengths, 1).transpose(0, 1)
238
+ .contiguous().view((1, lengths ** 2))), dim=0)
239
+ return input_edges.to(device), output_edges.to(device)
240
+
241
+ def reparameterise(self, mu, logvar):
242
+ if self.training:
243
+ # Assume log_variation (NOT log_standard_deviation!)
244
+ std = logvar.mul(0.5).exp_()
245
+ # MyNote: tensor.new() -> Constructs a new tensor of the same data type as self tensor.
246
+ eps = std.data.new(std.size()).normal_()
247
+ return eps.mul(std).add_(mu)
248
+ else:
249
+ return mu
250
+
251
+ def encoder_decoder(self, data):
252
+ """Given a patient data, encode it into the total graph, then decode to the last node.
253
+
254
+ Args:
255
+ data ([N]): multi-hot encoding (of diagnose codes). E.g. shape = [1309]
256
+
257
+ Returns:
258
+ Tuple[Tensor, Tensor]: The last node's features, plus KL Divergence
259
+ """
260
+ N = self.num_of_nodes
261
+ input_edges, output_edges = self.data_to_edges(data)
262
+ h_prime = self.embed(torch.arange(N).long().to(device))
263
+
264
+ # Encoder:
265
+ for attn in self.in_att:
266
+ h_prime = attn(input_edges, h_prime)
267
+
268
+ if self.variational:
269
+ # Even given only a patient's data, this parameterization affects the total graph.
270
+ h_prime = self.parameterize(h_prime).view(-1, 2, self.out_features)
271
+ h_prime = self.dropout(h_prime)
272
+ mu = h_prime[:, 0, :]
273
+ logvar = h_prime[:, 1, :]
274
+ h_prime = self.reparameterise(mu, logvar) # h_prime.shape = [N, z_dim] e.g. (1311x256)
275
+
276
+ # Essential variables (mu, ,logvar) for computing DL Divergence later.
277
+ # Note: only consider the patient's graph (NOT the total graph).
278
+ split = int(math.sqrt(len(input_edges[0])))
279
+ pat_diag_code_idx = input_edges[0][0:split]
280
+ mu = mu[pat_diag_code_idx, :]
281
+ logvar = logvar[pat_diag_code_idx, :]
282
+
283
+ # Decoder:
284
+ h_prime = self.out_att(output_edges, h_prime)
285
+
286
+ if self.variational:
287
+ """
288
+ Need to divide with mu.size()[0] because the original formula sums over all latent dimensions.
289
+ """
290
+ return (h_prime[-1], # The last node's features.
291
+ 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2)) / mu.size()[0]
292
+ )
293
+ else:
294
+ return (h_prime[-1], \
295
+ torch.tensor(0.0).to(device)
296
+ )
297
+
298
+ def forward(self, data):
299
+ # Concate batches
300
+ batch_size = data.size()[0]
301
+ # In eicu data the first feature whether have be admitted before is not included in the graph
302
+ if self.none_graph_features == 0: # MyNote: self.none_graph_features hardcoded = 0!!! -> cái này ko phải ám chỉ là ko dùng features cho nodes!
303
+ # MyNote: for each Patient-Encounter, encode the graph specifically for that.
304
+ outputs = [self.encoder_decoder(data[i, :]) for i in range(batch_size)]
305
+ # MyNote: return logits (output of out_layer()) -> later use BCEWithLogitsLoss
306
+ return self.out_layer(F.relu(torch.stack([out[0] for out in outputs]))), \
307
+ torch.sum(torch.stack([out[1] for out in outputs]))
308
+ else:
309
+ outputs = [(data[i, :self.none_graph_features],
310
+ self.encoder_decoder(data[i, self.none_graph_features:])) for i in range(batch_size)]
311
+ return self.out_layer(F.relu(
312
+ torch.stack([torch.cat((self.features_ffn(torch.FloatTensor([out[0]]).to(device)), out[1][0]))
313
+ for out in outputs]))), \
314
+ torch.sum(torch.stack([out[1][1] for out in outputs]), dim=-1)