tnt306 commited on
Commit
759caee
·
1 Parent(s): 6cbef49
Files changed (3) hide show
  1. .vscode/launch.json +15 -0
  2. app.py +72 -56
  3. final_model.pth +3 -0
.vscode/launch.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ // Use IntelliSense to learn about possible attributes.
3
+ // Hover to view descriptions of existing attributes.
4
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5
+ "version": "0.2.0",
6
+ "configurations": [
7
+ {
8
+ "name": "Python Debugger: Current File",
9
+ "type": "debugpy",
10
+ "request": "launch",
11
+ "program": "app.py",
12
+ "console": "integratedTerminal"
13
+ }
14
+ ]
15
+ }
app.py CHANGED
@@ -1,78 +1,94 @@
1
  from typing import Dict
2
  from pathlib import Path
3
- import pandas as pd
4
- from io import StringIO
5
 
 
 
 
 
6
  import gradio as gr
 
7
 
8
- # def predict(text: str) -> Dict:
9
- # return {"alive": 0.9, "death": 0.1}
10
-
11
- # example_list = [[1.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
12
-
13
- # # Create title, description and article strings
14
- # title = "This is title."
15
- # description = "This is description."
16
- # article = "This is article."
17
-
18
- # default_csv = "Phase,Activity,Start date,End date\n\"Mapping the Field\",\"Literature review\",2024-01-01,2024-01-31"
19
-
20
- # def process_csv_text(temp_file):
21
- # if isinstance(temp_file, str):
22
- # print("1")
23
- # df = pd.read_csv(temp_file, header = "infer", sep = ",", encoding = "utf-8")
24
- # else:
25
- # print("2")
26
- # df = pd.read_csv(temp_file.name)
27
- # print("***")
28
- # print(df)
29
- # print("***")
30
- # return df
31
-
32
- # with gr.Blocks() as demo:
33
- # upload_button = gr.UploadButton(label="Upload Timetable", file_types = ['.csv'], file_count = "single")
34
- # table = gr.Dataframe(headers=["Phase", "Activity", "Start date", "End date"], type="pandas", col_count=4)
35
- # upload_button.upload(fn=process_csv_text, inputs=upload_button, outputs=table, api_name="upload_csv")
36
-
37
- # demo.launch(debug=True)
38
 
39
- def predict():
 
 
 
 
 
 
 
 
40
  return {"Death": 0.9, "Alive": 0.1}
41
 
42
- def download_patient(patient_id: str) -> str:
43
- my_file = Path(f"Patient{patient_id}.csv")
44
- print(f"File to download [{str(my_file)}].")
45
- if not my_file.is_file():
46
- raise Exception(f"[{my_file}] not found.")
47
- print(f"Downloading file [{str(my_file)}].")
48
- return str(my_file)
49
 
50
- with gr.Blocks() as demo:
51
  with gr.Row():
52
  with gr.Column():
 
 
 
 
 
 
53
  patient_upload_file = gr.File(label="Upload A Patient",
54
  file_types = ['.csv'],
55
- file_count = "single")
 
 
 
 
 
 
56
  with gr.Row():
57
- with gr.Column(min_width=100):
58
- patient_1_input_btn = gr.Button("Patient 1")
59
- patient_1_download_btn = gr.DownloadButton(label="Download 1", value="Patient1.csv")
60
- with gr.Column(min_width=100):
61
- patient_2_input_btn = gr.Button("Patient 2")
62
- patient_2_download_btn = gr.DownloadButton(label="Download 2", value="Patient2.csv")
63
- with gr.Column(min_width=100):
64
- patient_3_input_btn = gr.Button("Patient 3")
65
- patient_3_download_btn = gr.DownloadButton(label="Download 3", value="Patient3.csv")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  with gr.Column():
 
 
 
 
 
 
67
  result = gr.Label(num_top_classes=2, label="Predictions")
68
 
69
  # Choose a patient to predict.
70
- patient_1_input_btn.click(fn=predict, inputs=None, outputs=result, api_name="predict")
 
 
71
 
72
- # Download a patient ehr profile.
73
- # patient_1_download_btn.click(fn=download_patient, inputs=patient_1_download_btn, outputs=patient_1_download_btn)
74
- # patient_2_download_btn.click(fn=download_patient, inputs=patient_2_download_btn, outputs=patient_2_download_btn)
75
- # patient_3_download_btn.click(fn=download_patient, inputs=patient_3_download_btn, outputs=patient_3_download_btn)
76
 
77
 
78
  demo.launch(debug=True)
 
 
1
  from typing import Dict
2
  from pathlib import Path
3
+ import pickle, logging, sys
4
+ from typing import Tuple, List, Dict
5
 
6
+ import numpy as np
7
+ import numpy.typing as npt
8
+ from scipy.sparse import csr_matrix
9
+ import torch
10
  import gradio as gr
11
+ import pandas as pd
12
 
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ model = torch.jit.load("final_model.pth")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ def predict(patient_id: int):
17
+ print(f"predict patient {patient_id}")
18
+ df = pd.read_csv(f"Patient{patient_id}.csv",
19
+ header="infer",
20
+ sep=",",
21
+ encoding="utf-8",
22
+ dtype={'condition': 'str', 'user_key': 'float32'},
23
+ keep_default_na=False)
24
+
25
  return {"Death": 0.9, "Alive": 0.1}
26
 
 
 
 
 
 
 
 
27
 
28
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
29
  with gr.Row():
30
  with gr.Column():
31
+ gr.Markdown(
32
+ """
33
+ ## Input:
34
+ (See examples for file structure)
35
+ """
36
+ )
37
  patient_upload_file = gr.File(label="Upload A Patient",
38
  file_types = ['.csv'],
39
+ file_count = "single",
40
+ height=100)
41
+ gr.Markdown(
42
+ """
43
+ ## Examples - Correct Prediction:
44
+ """
45
+ )
46
  with gr.Row():
47
+ with gr.Column(variant='panel', min_width=100):
48
+ patient_1_input_btn = gr.Button("Patient 1", size="sm")
49
+ patient_1_download_btn = gr.DownloadButton(label="Download", value="Patient1.csv", size="sm")
50
+ patient_id_1 = gr.Number(value=1, visible=False)
51
+ with gr.Column(variant='panel', min_width=100):
52
+ patient_2_input_btn = gr.Button("Patient 2", size="sm")
53
+ patient_2_download_btn = gr.DownloadButton(label="Download", value="Patient2.csv", size="sm")
54
+ patient_id_2 = gr.Number(value=2, visible=False)
55
+ with gr.Column(variant='panel', min_width=100):
56
+ patient_3_input_btn = gr.Button("Patient 3", size="sm")
57
+ patient_3_download_btn = gr.DownloadButton(label="Download", value="Patient3.csv", size="sm")
58
+ patient_id_3 = gr.Number(value=3, visible=False)
59
+ gr.Markdown(
60
+ """
61
+ ## Examples - Wrong Prediction:
62
+ """
63
+ )
64
+ with gr.Row():
65
+ with gr.Column(variant='panel', min_width=100):
66
+ patient_4_input_btn = gr.Button("Patient 4", size="sm")
67
+ patient_4_download_btn = gr.DownloadButton(label="Download", value="Patient4.csv", size="sm")
68
+ patient_id_4 = gr.Number(value=4, visible=False)
69
+ with gr.Column(variant='panel', min_width=100):
70
+ patient_5_input_btn = gr.Button("Patient 5", size="sm")
71
+ patient_5_download_btn = gr.DownloadButton(label="Download", value="Patient5.csv", size="sm")
72
+ patient_id_5 = gr.Number(value=5, visible=False)
73
+ with gr.Column(variant='panel', min_width=100):
74
+ patient_6_input_btn = gr.Button("Patient 6", size="sm")
75
+ patient_6_download_btn = gr.DownloadButton(label="Download", value="Patient6.csv", size="sm")
76
+ patient_id_6 = gr.Number(value=6, visible=False)
77
  with gr.Column():
78
+ gr.Markdown(
79
+ """
80
+ ## Mortality Prediction:
81
+ In 24 hours after ICU admission.
82
+ """
83
+ )
84
  result = gr.Label(num_top_classes=2, label="Predictions")
85
 
86
  # Choose a patient to predict.
87
+ patient_1_input_btn.click(fn=predict, inputs=patient_id_1, outputs=result, api_name="predict")
88
+ patient_2_input_btn.click(fn=predict, inputs=patient_id_2, outputs=result, api_name="predict")
89
+ patient_3_input_btn.click(fn=predict, inputs=patient_id_3, outputs=result, api_name="predict")
90
 
 
 
 
 
91
 
92
 
93
  demo.launch(debug=True)
94
+
final_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:731461d450bb0ec5f011c3751c98621171fb9abd017331da48d1f6fb3dec184d
3
+ size 61002953