tnt306 commited on
Commit
09980bd
·
1 Parent(s): 3797309

Reorganized folders.

Browse files

Added training graph.

app.py CHANGED
@@ -14,12 +14,12 @@ examples_path = "examples"
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
  correct_preds, wrong_preds = {}, {}
17
- condition_lst = pd.read_csv("feature.csv", header = "infer", sep = ",", encoding = "utf-8", dtype=str)
18
- D_LABITEMS = pd.read_csv("D_LABITEMS.csv", header = "infer", sep = ",", encoding = "utf-8", dtype=str)
19
 
20
 
21
  def load_model():
22
- path = r"final_model.pt"
23
  kwargs, state = torch.load(path, weights_only=False, map_location=device)
24
  model = VariationalGNN(**kwargs).to(device)
25
  model.load_state_dict(state)
@@ -236,7 +236,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
236
  query_tbx.change(fn=query_input_change_event, inputs=[query_tbx, query_type], outputs=[query_btn, html])
237
  query_type.change(fn=query_input_change_event, inputs=[query_tbx, query_type], outputs=[query_btn, html])
238
  query_btn.click(fn=do_query, inputs=[query_tbx, query_type], outputs=html)
239
- with gr.Accordion("More on technical details...", open=False):
240
  gr.Markdown(
241
  """
242
  - Paper: [Variationally Regularized Graph-based Representation Learning for Electronic Health Records (Zhu et al, 2021)](https://arxiv.org/abs/1912.03761)
@@ -250,9 +250,13 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
250
  - Trained on NVIDIA A100 with PyTorch 2.4.0
251
  """
252
  )
 
 
 
 
253
 
254
  for partialFunc in resDispPartFuncs:
255
  partialFunc(outputs=[result_pred, result_label])
256
 
257
 
258
- demo.launch(debug=True)
 
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
  correct_preds, wrong_preds = {}, {}
17
+ condition_lst = pd.read_csv("data/feature.csv", header = "infer", sep = ",", encoding = "utf-8", dtype=str)
18
+ D_LABITEMS = pd.read_csv("data/D_LABITEMS.csv", header = "infer", sep = ",", encoding = "utf-8", dtype=str)
19
 
20
 
21
  def load_model():
22
+ path = r"models/final_model.pt"
23
  kwargs, state = torch.load(path, weights_only=False, map_location=device)
24
  model = VariationalGNN(**kwargs).to(device)
25
  model.load_state_dict(state)
 
236
  query_tbx.change(fn=query_input_change_event, inputs=[query_tbx, query_type], outputs=[query_btn, html])
237
  query_type.change(fn=query_input_change_event, inputs=[query_tbx, query_type], outputs=[query_btn, html])
238
  query_btn.click(fn=do_query, inputs=[query_tbx, query_type], outputs=html)
239
+ with gr.Accordion("More on Technical Details...", open=False):
240
  gr.Markdown(
241
  """
242
  - Paper: [Variationally Regularized Graph-based Representation Learning for Electronic Health Records (Zhu et al, 2021)](https://arxiv.org/abs/1912.03761)
 
250
  - Trained on NVIDIA A100 with PyTorch 2.4.0
251
  """
252
  )
253
+ with gr.Accordion("More on Training...", open=False):
254
+ gr.HTML("""
255
+ <img src="/file=images/AUPRC_Training_Graph.png" alt="">
256
+ """)
257
 
258
  for partialFunc in resDispPartFuncs:
259
  partialFunc(outputs=[result_pred, result_label])
260
 
261
 
262
+ demo.launch(debug=True, allowed_paths=["images/."])
D_LABITEMS.csv → data/D_LABITEMS.csv RENAMED
File without changes
feature.csv → data/feature.csv RENAMED
File without changes
images/AUPRC_Training_Graph.png ADDED
final_model.pt → models/final_model.pt RENAMED
File without changes