"""Gradio demo that predicts patient mortality from a condition CSV file.

At import time the module loads a TorchScript model and two lookup tables,
indexes the example CSV files found under ``examples_path``, builds the
Gradio UI and launches it.
"""

from functools import partial
from glob import glob
from html import escape
from pathlib import Path
from urllib.parse import quote_plus
import re

import gradio as gr
import numpy as np
import pandas as pd
import torch

examples_path = "examples"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only host.
model = torch.jit.load("final_model.pth", map_location=device)

# Example patients keyed by patient id (str), split by whether the prediction
# recorded in the file name agrees with the label recorded in the file name.
correct_preds, wrong_preds = {}, {}

condition_lst = pd.read_csv("feature.csv", header="infer", sep=",", encoding="utf-8", dtype=str)
D_LABITEMS = pd.read_csv("D_LABITEMS.csv", header="infer", sep=",", encoding="utf-8", dtype=str)

# Example file names look like "Patient_<id>_(Label-<x>)_(Predicted-<y>).csv".
_PATIENT_FILE_PATTERN = re.compile(
    r"^Patient_(\d+)_\(Label-(alive|dead)\)_\(Predicted-(dead|alive)\)\.csv$"
)


def _check_patient_csv_format(df: pd.DataFrame) -> None:
    """Validate a patient data frame before it is fed to the model.

    The first two columns must be ``condition`` and ``value``, the condition
    column must equal the ``condition`` column of ``feature.csv`` (same items,
    same order) and every value must be 0 or 1.

    Raises:
        gr.Error: if any of the checks fails.
    """
    columns = list(df.columns)
    if columns[0:2] != ["condition", "value"]:
        raise gr.Error(f"Column set [{columns}]: not expected.", duration=None)

    if condition_lst["condition"].to_list() != df["condition"].to_list():
        raise gr.Error("Condition set: not expected.", duration=None)

    # Membership test rather than "unique values == [0, 1]", so a file whose
    # flags are all 0 (or all 1) is not rejected as invalid.
    if not np.isin(df["value"].to_numpy(), (0.0, 1.0)).all():
        raise gr.Error("Column 'value': contain invalid values.", duration=None)


def _extract_patient_data_from_name(csv_file_name: str):
    """Parse ``(patient_id, label, predicted)`` out of a patient file name.

    Only the final path component is inspected. Returns ``None`` when the
    name does not follow the example-file naming pattern.
    """
    match = _PATIENT_FILE_PATTERN.search(Path(csv_file_name).name)
    if match is None:
        return None
    return match.groups()


def _find_example_csv_files() -> None:
    """Index the example CSV files into ``correct_preds`` / ``wrong_preds``.

    Files with an unexpected name or a patient id that was already seen are
    reported on stdout and skipped.
    """
    # Sorted so the example buttons appear in a deterministic order.
    all_csv_files = sorted(glob(f"{examples_path}/*.csv"))
    if not all_csv_files:
        print("*** No csv files found.")
        return

    for one_csv_file in all_csv_files:
        matches = _extract_patient_data_from_name(one_csv_file)
        if matches is None:
            print(f"*** File [{one_csv_file}]: wrong name.")
            continue

        pat_id, pat_label, pat_predicted = matches
        if pat_id in correct_preds or pat_id in wrong_preds:
            print(f"*** File [{one_csv_file}]: already processed! How come?")
            continue

        target = correct_preds if pat_label == pat_predicted else wrong_preds
        target[pat_id] = {
            "label": pat_label,
            "predicted": pat_predicted,
            "file_name": one_csv_file,
        }


_find_example_csv_files()


def _predict(file_path: str) -> float:
    """Run the model on one patient CSV file and return its first output.

    The returned number is shown as the "Death" probability by the callers.

    Raises:
        gr.Error: if the file cannot be read or fails the format checks.
    """
    try:
        df = pd.read_csv(
            str(file_path),
            header="infer",
            sep=",",
            encoding="utf-8",
            dtype={"condition": "str", "value": "float32"},
            keep_default_na=False,
        )
    except (OSError, ValueError) as err:
        # pandas parser/empty-data errors and bad float conversions are all
        # ValueError subclasses; surface them in the UI instead of crashing.
        raise gr.Error(
            f"File [{Path(str(file_path)).name}]: cannot be read as a patient CSV.",
            duration=None,
        ) from err

    _check_patient_csv_format(df)

    # Add a leading batch dimension of size 1.
    patient_data = torch.from_numpy(df["value"].to_numpy()).unsqueeze(dim=0).to(device)
    with torch.no_grad():  # inference only, no autograd graph needed
        probability, _ = model(patient_data)
    return probability.detach().cpu()[0].item()


def _normalize_patient_id(patient_id) -> str:
    """Convert a patient id coming from the UI into the ``str`` dict key form.

    The id is stored in a hidden ``gr.Number``, so it may come back as an
    int or float instead of the string captured from the file name.
    """
    if isinstance(patient_id, float) and patient_id.is_integer():
        patient_id = int(patient_id)
    return str(patient_id).strip()


def example_csv_click(patient_id):
    """Predict for one of the bundled example patients.

    Returns:
        ``[{"Death": p, "Alive": 1 - p}, label]`` for the two result widgets.

    Raises:
        gr.Error: if the patient id is unknown or its file is invalid.
    """
    patient_key = _normalize_patient_id(patient_id)
    print(f"*** Predict patient {patient_key} (Example CSV)")

    patient = correct_preds.get(patient_key)
    if patient is None:
        patient = wrong_preds.get(patient_key)
    if patient is None:
        raise gr.Error(f"Patient [{patient_key}]: no example file found.", duration=None)

    probability = _predict(patient["file_name"])
    return [{"Death": probability, "Alive": 1 - probability}, patient["label"]]


def user_csv_upload(temp_csv_file_path):
    """Predict for a CSV file uploaded by the user.

    The label is taken from the file name when it follows the example naming
    pattern, otherwise "(Not Available)" is shown.
    """
    print("*** Predict patient (User CSV Upload)")
    matches = _extract_patient_data_from_name(temp_csv_file_path)
    probability = _predict(temp_csv_file_path)
    label = "(Not Available)" if matches is None else matches[1]
    return [{"Death": probability, "Alive": 1 - probability}, label]


def do_query(query_str, query_type):
    """Answer a code query from the "More on Patient Conditions" panel.

    Diagnosis/Procedure codes produce a link to a Google search for the
    ICD-9 code; any other query type is looked up in ``D_LABITEMS``.
    """
    if query_type in ["Diagnosis", "Procedure"]:
        # The search string used to be built and then dropped, leaving the
        # bare word "Google"; turn it into an actual search link.
        str_to_search = f"ICD-9 {query_type} Code " + query_str.strip()
        url = f"https://www.google.com/search?q={quote_plus(str_to_search)}"
        link = f'<a href="{escape(url)}" target="_blank" rel="noopener noreferrer">Google</a>'
        return gr.HTML(value=link, visible=True)

    # Lab Code
    query_str = query_str.strip()
    # Keep only the part before the last underscore as the ITEMID.
    if (index := query_str.rfind("_")) >= 0:
        query_str = query_str[0:index]

    res = D_LABITEMS[D_LABITEMS["ITEMID"] == query_str]
    if res.shape[0] == 0:
        answer = "(Something wrong. No definition found.)"
    elif res.shape[0] == 1:
        row = res.iloc[0]
        # Values are rendered as HTML, so escape them.
        answer = "-".join(escape(str(row[col])) for col in ("LABEL", "FLUID", "CATEGORY"))
    else:
        answer = f"(Something wrong. Too many definitions, given code [{escape(query_str)}].)"
    return gr.HTML(value=answer, visible=True)


def query_input_change_event(query_str, query_type):
    """Enable the Query button only when both a code and a type are given.

    The previous answer is hidden whenever either input changes.
    """
    ready = bool(query_str and query_str.strip()) and query_type is not None
    return [gr.Button(interactive=ready), gr.HTML(visible=False)]


# Event listeners whose outputs are created later in the layout; each entry is
# a partial that only still needs ``outputs=...``.
resDispPartFuncs = []


def _add_example_buttons(preds: dict) -> None:
    """Create one predict/download panel per example patient in ``preds``.

    Must be called inside an open Gradio layout context. The click listeners
    are queued in ``resDispPartFuncs`` because the result widgets do not
    exist yet.
    """
    for pat_id, info in preds.items():
        with gr.Column(variant="panel", min_width=100):
            patient_input_btn = gr.Button(f"Patient {pat_id}", size="sm")
            gr.DownloadButton(label="Download", value=info["file_name"], size="sm")
            # Hidden component that carries the patient id to the callback.
            patient_id_num = gr.Number(value=pat_id, visible=False)
        resDispPartFuncs.append(
            partial(
                patient_input_btn.click,
                fn=example_csv_click,
                inputs=patient_id_num,
                api_name="predict",
            )
        )


css = """
#selectFileToUpload {max-height: 180px}
.gradio-container {
    background: url(https://www.kindpng.com/picc/m/207-2075829_transparent-healthcare-clipart-medical-report-icon-hd-png.png);
    background-position: 80% 80%;
    background-repeat: no-repeat;
    background-size: 200px;
}
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                ## Input:
                (See examples for file structure)
                """
            )
            patient_upload_file = gr.File(
                label="Upload A Patient",
                file_types=[".csv"],
                file_count="single",
                elem_id="selectFileToUpload",
            )
            # Deferred like the example buttons so the upload result is shown
            # in the result widgets instead of being discarded.
            resDispPartFuncs.append(
                partial(
                    patient_upload_file.upload,
                    fn=user_csv_upload,
                    inputs=patient_upload_file,
                )
            )

            gr.Markdown(
                """
                ## Examples - Correct Prediction:
                """
            )
            with gr.Row():
                _add_example_buttons(correct_preds)

            gr.Markdown(
                """
                ## Examples - Wrong Prediction:
                """
            )
            with gr.Row():
                _add_example_buttons(wrong_preds)

        with gr.Column():
            gr.Markdown(
                """
                ## Mortality Prediction:
                In 24 hours after ICU admission.
                """
            )
            result_pred = gr.Label(num_top_classes=2, label="Predicted")
            result_label = gr.Label(label="Label")

            with gr.Accordion("More on Patient Conditions...", open=False):
                query_tbx = gr.Textbox(
                    label="Enter one ICD-9 Diagnosis/Procedure Code or Lab Value:",
                    lines=1,
                    max_lines=1,
                    placeholder="00869 for 'Other viral intes infec' (Diagnosis)",
                )
                query_type = gr.Radio(["Diagnosis", "Procedure", "Lab Value"], show_label=False)
                query_btn = gr.Button(value="Query", size="sm", interactive=False)
                query_html = gr.HTML("", visible=False)

                query_tbx.change(
                    fn=query_input_change_event,
                    inputs=[query_tbx, query_type],
                    outputs=[query_btn, query_html],
                )
                query_type.change(
                    fn=query_input_change_event,
                    inputs=[query_tbx, query_type],
                    outputs=[query_btn, query_html],
                )
                query_btn.click(fn=do_query, inputs=[query_tbx, query_type], outputs=query_html)

    # The result widgets exist now, so the queued listeners can be attached.
    for partialFunc in resDispPartFuncs:
        partialFunc(outputs=[result_pred, result_label])

demo.launch(debug=True)