import html
import re
from functools import partial
from glob import glob
from pathlib import Path
from urllib.parse import quote_plus

import gradio as gr
import numpy as np
import pandas as pd
import torch

from model import VariationalGNN

examples_path = "examples"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Example patients discovered in `examples_path`, keyed by the patient-id string
# captured from the file name; values are {"label", "predicted", "file_name"}.
correct_preds, wrong_preds = {}, {}

# Reference feature list: an input CSV's "condition" column must match it row-for-row.
condition_lst = pd.read_csv("data/feature.csv", header="infer", sep=",", encoding="utf-8", dtype=str)
# Lab item dictionary used by the "Lab Value" query (looked up by ITEMID).
D_LABITEMS = pd.read_csv("data/D_LABITEMS.csv", header="infer", sep=",", encoding="utf-8", dtype=str)

# File-name convention for example CSVs: Patient_<id>_(Label-<alive|dead>)_(Predicted-<dead|alive>).csv
_PATIENT_FILE_PAT = re.compile(r"Patient_(\d+)_\(Label-(alive|dead)\)_\(Predicted-(dead|alive)\)\.csv")


def load_model():
    """Load the trained VariationalGNN checkpoint onto `device` and switch it to eval mode.

    The checkpoint is a ``(constructor_kwargs, state_dict)`` pair.
    """
    path = r"models/final_model.pt"
    # NOTE(security): weights_only=False unpickles arbitrary objects. This is only
    # acceptable because the checkpoint is a trusted file shipped with the app.
    kwargs, state = torch.load(path, weights_only=False, map_location=device)
    net = VariationalGNN(**kwargs).to(device)
    net.load_state_dict(state)
    # Inference-only app: set eval mode once instead of on every prediction.
    net.eval()
    return net


model = load_model()


def _check_patient_csv_format(df: pd.DataFrame):
    """Validate a patient DataFrame; raise ``gr.Error`` describing the first problem found.

    Requirements: the first two columns are "condition" and "value", the
    "condition" column equals data/feature.csv's "condition" column in order,
    and every "value" is 0 or 1.
    """
    if list(df.columns)[0:2] != ["condition", "value"]:
        raise gr.Error(f"Column set [{list(df.columns)}]: not expected.", duration=None)
    if condition_lst["condition"].to_list() != df["condition"].to_list():
        raise gr.Error("Condition set: not expected.", duration=None)
    vals = df["value"].unique()
    # A subset check: an all-0 (or all-1) patient is still valid binary input.
    if len(vals) == 0 or not np.isin(vals, [0.0, 1.0]).all():
        raise gr.Error("Column 'value': contain invalid values.", duration=None)


def _extract_patient_data_from_name(csv_file_name: str):
    """Parse an example file name.

    Returns ``(patient_id, label, predicted)`` as strings, or ``None`` when the
    base name does not follow the example naming convention.
    """
    matches = _PATIENT_FILE_PAT.fullmatch(Path(csv_file_name).name)
    return None if matches is None else matches.groups()


def _find_example_csv_files() -> None:
    """Scan `examples_path` and register each example CSV in correct_preds or wrong_preds."""
    # Sorted so the example buttons appear in a stable order across platforms.
    all_csv_files = sorted(glob(f"{examples_path}/*.csv"))
    if not all_csv_files:
        print("*** No csv files found.")
        return
    for one_csv_file in all_csv_files:
        matches = _extract_patient_data_from_name(one_csv_file)
        if matches is None:
            print(f"*** File [{one_csv_file}]: wrong name.")
            continue
        pat_id, pat_label, pat_predicted = matches
        if pat_id in correct_preds or pat_id in wrong_preds:
            print(f"*** File [{one_csv_file}]: already processed! How come?")
            continue
        target = correct_preds if pat_label == pat_predicted else wrong_preds
        target[pat_id] = {"label": pat_label, "predicted": pat_predicted, "file_name": one_csv_file}


_find_example_csv_files()


def _predict(file_path: str) -> float:
    """Read a patient CSV, validate it, and return the model's probability of "dead".

    Raises:
        gr.Error: if the file cannot be parsed or fails `_check_patient_csv_format`.
    """
    try:
        df = pd.read_csv(file_path, header="infer", sep=",", encoding="utf-8",
                         dtype={"condition": "str", "value": "float32"},
                         keep_default_na=False)
    except ValueError as err:  # covers ParserError, EmptyDataError, UnicodeDecodeError, bad floats
        raise gr.Error(f"File [{Path(file_path).name}]: cannot be parsed ({err}).", duration=None) from err
    _check_patient_csv_format(df)
    # Batch of one patient: shape (1, n_conditions).
    patient_data = torch.from_numpy(df["value"].to_numpy(dtype=np.float32)).unsqueeze(dim=0).to(device)
    with torch.inference_mode():
        logits, _ = model(patient_data)
        # [0] selects the single patient; the model output is treated as a logit.
        probability = torch.sigmoid(logits[0]).item()
    return probability


def _format_result(probability: float, label: str):
    """Build the ``[gr.Label payload, label text]`` pair shown in the result column."""
    return [{"dead": probability, "alive": 1 - probability}, label]


def example_csv_click(patient_id):
    """Predict one bundled example patient, identified by its id string."""
    # Normalise to str: the dict keys are the regex-captured id strings.
    key = str(patient_id)
    print(f"*** Predict patient {key} (Example CSV)")
    patient = correct_preds.get(key) or wrong_preds.get(key)
    if patient is None:
        raise gr.Error(f"Patient [{key}]: not found.", duration=None)
    probability = _predict(patient["file_name"])
    return _format_result(probability, patient["label"])


def user_csv_upload(temp_csv_file_path):
    """Predict a user-uploaded patient CSV.

    The ground-truth label is shown only when the uploaded file name follows
    the example naming convention.
    """
    print("*** Predict patient (User CSV Upload)")
    if temp_csv_file_path is None:
        raise gr.Error("No file uploaded.", duration=None)
    matches = _extract_patient_data_from_name(temp_csv_file_path)
    probability = _predict(temp_csv_file_path)
    return _format_result(probability, "(Not Available)" if matches is None else matches[1])


def do_query(query_str, query_type):
    """Explain a code entered by the user.

    Diagnosis/Procedure codes get a Google search link. For Lab Values, the
    part before the last "_" is looked up as an ITEMID in D_LABITEMS.
    """
    query_str = (query_str or "").strip()
    if query_type in ["Diagnosis", "Procedure"]:
        str_to_search = f"ICD-9 {query_type} Code " + query_str
        # quote_plus yields only URL-safe characters, so the URL is safe inside the attribute.
        search_url = "https://www.google.com/search?q=" + quote_plus(str_to_search)
        return gr.HTML(value=f'<a href="{search_url}" target="_blank" rel="noopener noreferrer">Google</a>',
                       visible=True)
    # Lab Value
    if (index := query_str.rfind("_")) >= 0:
        query_str = query_str[0:index]
    res = D_LABITEMS[D_LABITEMS["ITEMID"] == query_str]
    if res.shape[0] == 0:
        answer = "(Something wrong. No definition found.)"
    elif res.shape[0] == 1:
        row = res.iloc[0]
        # Escape: the value is rendered as HTML.
        answer = html.escape(f"{row['LABEL']}-{row['FLUID']}-{row['CATEGORY']}")
    else:
        answer = f"(Something wrong. Too many definitions, given code [{html.escape(query_str)}].)"
    return gr.HTML(value=answer, visible=True)


def query_input_change_event(query_str, query_type):
    """Enable the Query button only when a non-blank code and a type are both set; hide stale results."""
    ready = bool(query_str and query_str.strip()) and query_type is not None
    return [gr.Button(interactive=ready), gr.HTML(visible=False)]


# Event registrations deferred until the result components exist. Each entry is
# a partially-applied listener (e.g. `button.click`) still missing `outputs=`.
resDispPartFuncs = []


def _add_example_buttons(preds: dict) -> None:
    """Render one panel (predict + download buttons) per example patient in `preds`.

    Must be called inside an active gr.Blocks layout context. The click
    listeners are queued in `resDispPartFuncs` and completed later.
    """
    for patient_id, patient in preds.items():
        with gr.Column(variant="panel", min_width=100):
            patient_input_btn = gr.Button(f"Patient {patient_id}", size="sm")
            gr.DownloadButton(label="Download", value=patient["file_name"], size="sm")
            # Hidden Textbox keeps the id a string (e.g. preserves leading zeros) so it
            # matches the dict keys exactly.
            patient_id_box = gr.Textbox(value=patient_id, visible=False)
            resDispPartFuncs.append(partial(patient_input_btn.click, fn=example_csv_click,
                                            inputs=patient_id_box, api_name="predict"))


css = """
#selectFileToUpload {max-height: 180px}
.gradio-container {
    background: url(https://www.kindpng.com/picc/m/207-2075829_transparent-healthcare-clipart-medical-report-icon-hd-png.png);
    background-position: 80% 85%;
    background-repeat: no-repeat;
    background-size: 200px;
}
#label-label { height: 50px !important; }
#label-label > .container { height: 50px !important; }
#label-label > .container > h2 {
    /* height: 50px !important; */
    padding: 0 !important;
}
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                ## Input:
                (See examples for file structure)
                """
            )
            patient_upload_file = gr.File(label="Upload A Patient", file_types=[".csv"],
                                          file_count="single", elem_id="selectFileToUpload")
            # Outputs are created in the right-hand column below; wire them up at the end.
            resDispPartFuncs.append(partial(patient_upload_file.upload, fn=user_csv_upload,
                                            inputs=patient_upload_file))
            gr.Markdown(
                """
                ## Examples - Correct Prediction:
                """
            )
            with gr.Row():
                _add_example_buttons(correct_preds)
            gr.Markdown(
                """
                ## Examples - Wrong Prediction:
                """
            )
            with gr.Row():
                _add_example_buttons(wrong_preds)
        with gr.Column():
            gr.Markdown(
                """
                ## Mortality Prediction:
                In 24 hours after ICU admission.
                """
            )
            result_pred = gr.Label(num_top_classes=2, label="Predicted")
            result_label = gr.Label(label="Label", elem_id="label-label")
            with gr.Accordion("More on Patient Conditions...", open=False):
                query_tbx = gr.Textbox(label="Enter one ICD-9 Diagnosis/Procedure Code or Lab Value:",
                                       lines=1, max_lines=1,
                                       placeholder="00869 for 'Other viral intes infec' (Diagnosis)")
                query_type = gr.Radio(["Diagnosis", "Procedure", "Lab Value"], show_label=False)
                query_btn = gr.Button(value="Query", size="sm", interactive=False)
                html_out = gr.HTML("", visible=False)
                query_tbx.change(fn=query_input_change_event, inputs=[query_tbx, query_type],
                                 outputs=[query_btn, html_out])
                query_type.change(fn=query_input_change_event, inputs=[query_tbx, query_type],
                                  outputs=[query_btn, html_out])
                query_btn.click(fn=do_query, inputs=[query_tbx, query_type], outputs=html_out)
            with gr.Accordion("More on Technical Details...", open=False):
                gr.Markdown(
                    """
                    - Paper: [Variationally Regularized Graph-based Representation Learning for Electronic Health Records (Zhu et al, 2021)](https://arxiv.org/abs/1912.03761)
                    - Dataset: [MIMIC-III](https://physionet.org/content/mimiciii/1.4/)
                        - 50,314 records, 10,591 features
                        - 5,315 positive, 44,999 negative (11.8%)
                        - Split: 80% training, 10% validation, 10% testing
                    - Notable points:
                        - Result: AUPRC 0.7027 (Baseline: 0.118) on Val split
                        - Variational Regularization, inspired by [Kipf et al., 2016](https://arxiv.org/abs/1611.07308)
                        - Trained on NVIDIA A100 with PyTorch 2.4.0
                    - Code on GitHub: [pytorch-variational-gcn-ehr-public](https://github.com/ThachNgocTran/pytorch-variational-gcn-ehr-public)
                    """
                )
            with gr.Accordion("More on Training...", open=False):
                # NOTE(review): this HTML is empty as seen here; confirm the intended content
                # (launch() exposes "images/." via allowed_paths).
                gr.HTML(""" """)

    # Complete every deferred listener now that the result components exist.
    for partial_func in resDispPartFuncs:
        partial_func(outputs=[result_pred, result_label])

demo.launch(debug=True, allowed_paths=["images/."])