joaopimenta committed on
Commit
6e396aa
·
verified ·
1 Parent(s): 2d75828

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +879 -0
app.py ADDED
@@ -0,0 +1,879 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import pandas as pd
4
+ import joblib
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModel
7
+ from xgboost import XGBClassifier
8
+ from sklearn.preprocessing import StandardScaler
9
+ from sklearn.decomposition import PCA
10
+ from sklearn.metrics import precision_recall_curve, roc_curve, confusion_matrix, classification_report
11
+ import matplotlib.pyplot as plt
12
+ import shap
13
+ import plotly.express as px
14
+ import streamlit as st
15
+ import pandas as pd
16
+ import datetime
17
+ import json
18
+ import requests
19
+ from streamlit_lottie import st_lottie
20
+ import streamlit.components.v1 as components
21
+ from streamlit_navigation_bar import st_navbar
22
+ from transformers import AutoTokenizer, AutoModel
23
+ import re
24
+ from tqdm import tqdm
25
+ import torch
26
+ import os
27
+ from hugchat.login import Login
28
+ from hugchat import hugchat
29
+ from transformers import pipeline
30
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
31
+ import torch.nn as nn
32
+ import time
33
+
34
+
35
+
36
+ pages = ["Home", "Tabular data", "Clinical text notes", "Ensemble prediction"]
37
+
38
+ styles = {
39
+ "nav": {
40
+ "background-color": "rgba(0, 0, 0, 0.5)",
41
+ # Add 50% transparency
42
+ },
43
+ "div": {
44
+ "max-width": "32rem",
45
+ },
46
+ "span": {
47
+ "border-radius": "0.26rem",
48
+ "color": "rgb(255 ,255, 255)",
49
+ "margin": "0 0.225rem",
50
+ "padding": "0.375rem 0.625rem",
51
+ },
52
+ "active": {
53
+ "background-color": "rgba(0 ,0, 200, 0.95)",
54
+ },
55
+ "hover": {
56
+ "background-color": "rgba(255, 255, 255, 0.95)",
57
+ },
58
+ }
59
+
60
+ page = st_navbar(pages, styles=styles)
61
+
62
+ if page=="Home":
63
+
64
+ st.markdown("""
65
+ <style>
66
+ .title {
67
+ text-align: center;
68
+ font-size: 36px;
69
+ font-weight: bold;
70
+ color: #2C3E50;
71
+ }
72
+ .subtitle {
73
+ text-align: center;
74
+ font-size: 22px;
75
+ color: #7F8C8D;
76
+ }
77
+ .box {
78
+ background-color: #ECF0F1;
79
+ padding: 15px;
80
+ border-radius: 10px;
81
+ text-align: center;
82
+ margin-bottom: 10px;
83
+ font-size: 18px;
84
+ }
85
+ </style>
86
+ """, unsafe_allow_html=True)
87
+
88
+ # Header
89
+ st.markdown("<h1 class='title'>📊 AI Clinical Readmission Predictor</h1>", unsafe_allow_html=True)
90
+ st.markdown("<h2 class='subtitle'>Using Machine Learning for Better Patient Outcomes</h2>", unsafe_allow_html=True)
91
+ image_1 ='https://content.presspage.com/uploads/2110/4970f578-5f20-4675-acc2-3b2cda25fa96/1920_ai-machine-learning-cedars-sinai.jpg?10000'
92
+ image_2 = 'https://med-tech.world/app/uploads/2024/10/AI-Hospitals.jpg.webp'
93
+
94
+
95
+ st.image(image_2, width=1350) # Hospital Icon
96
+
97
+ st.write("This app helps predict patient readmission risk using machine learning models. "
98
+ "Upload data, analyze clinical notes, and see predictions from our ensemble model.")
99
+
100
+ # Navigation Buttons
101
+ st.markdown("---")
102
+ st.markdown("<h3 style='text-align: center;'>🚀 Explore the App</h3>", unsafe_allow_html=True)
103
+
104
+ elif page== "Tabular data":
105
+
106
+ # Function to load Lottie animation
107
def load_lottie(url, timeout=10):
    """Download a Lottie animation and return its decoded JSON.

    Args:
        url: Address of the Lottie JSON file.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The decoded JSON payload, or ``None`` when the request fails, the
        server does not answer with HTTP 200, or the body is not valid JSON.
    """
    try:
        # Without a timeout a stalled server would block the page forever.
        response = requests.get(url, timeout=timeout)
    except requests.RequestException:
        # The animation is decorative; a network failure must not crash the page.
        return None
    if response.status_code != 200:
        return None
    try:
        return response.json()
    except ValueError:
        # Body was not valid JSON.
        return None
112
+
113
+ # Load Lottie Animation
114
+ lottie_hello = load_lottie("https://assets7.lottiefiles.com/packages/lf20_jcikwtux.json")
115
+ if lottie_hello:
116
+ st_lottie(lottie_hello, speed=1, loop=True, height=200)
117
+
118
+ # Load dataset
119
+ df = pd.read_csv('/Users/joaopimenta/Downloads/ensemble_test.csv')
120
+
121
+ # Streamlit App Header
122
+ st.title('🏥 Hospital Readmission Prediction')
123
+ st.markdown("""
124
+ <h3 style='text-align: center; color: gray;'>Predict ICU hospital readmission using Artificial Intelligence</h3>
125
+ """, unsafe_allow_html=True)
126
+ st.markdown("---")
127
+
128
+ # Helper Functions
129
def get_age_group(age):
    """Map an age to the name of its age-group column.

    Brackets are checked from the oldest downwards using lower bounds only,
    so fractional ages (e.g. 50.5) land in the bracket below the next
    threshold instead of falling through to the "below 36" label.

    Args:
        age: Age in years (int or float).

    Returns:
        One of the four ``age_group_*`` labels used below, or
        ``"age_group_Below_36"`` for anything under 36.
    """
    if age >= 81:
        return "age_group_81+ (Elderly)"
    if age >= 66:
        return "age_group_66-80 (Senior Adults)"
    if age >= 51:
        return "age_group_51-65 (Older Middle-Aged Adults)"
    if age >= 36:
        return "age_group_36-50 (Middle-Aged Adults)"
    return "age_group_Below_36"
140
+
141
+
142
def get_period(hour):
    """Label an hour of the day as "Morning" (06:00-17:59) or "Night"."""
    is_daytime = hour >= 6 and hour < 18
    if is_daytime:
        return "Morning"
    return "Night"
145
+
146
+ # **User Inputs**
147
+ st.subheader("📌 Select the admission's Characteristics")
148
+
149
+ admission_type = st.selectbox("🛑 Type of Admission", df.columns[df.columns.str.startswith('admission_type_')])
150
+ admission_location = st.selectbox("📍 Admission Location", df.columns[df.columns.str.startswith('admission_location_')])
151
+ discharge_location = st.selectbox("🏥 Discharge Location", df.columns[df.columns.str.startswith('discharge_location_')])
152
+ insurance = st.selectbox("💰 Insurance Type", df.columns[df.columns.str.startswith('insurance_')])
153
+
154
+ st.sidebar.subheader("📊 Patient Information")
155
+ language = st.sidebar.selectbox("🗣 Language", df.columns[df.columns.str.startswith('language_')])
156
+ marital_status = st.sidebar.selectbox("💍 Marital Status", df.columns[df.columns.str.startswith('marital_status_')])
157
+ race = st.sidebar.selectbox("🧑 Race", df.columns[df.columns.str.startswith('race_')])
158
+ sex = st.sidebar.selectbox("⚧ Sex", ['gender_M', 'gender_F'])
159
+ age = st.sidebar.slider("📅 Age", 18, 100, 50)
160
+
161
+ admission_time = st.time_input("⏳ Admission Time", value=datetime.time(12, 0))
162
+ discharge_time = st.time_input("⏳ Discharge Time", value=datetime.time(12, 0))
163
+
164
+ # Laboratory & Clinical Values
165
+ st.subheader("📈 Clinical Values")
166
+ numerical_features = ['los_days', 'previous_stays', 'n_meds', 'drg_severity', 'drg_mortality', 'time_since_last_stay',
167
+ 'blood_cells', 'hemoglobin', 'glucose', 'creatine', 'plaquete']
168
+ numeric_inputs = {}
169
+ cols = st.columns(len(numerical_features))
170
+
171
+ # General Numerical Values
172
+ st.subheader("📊 General Hosptal Information")
173
+ general_numerical_features = ['los_days', 'previous_stays', 'n_meds', 'drg_severity',
174
+ 'drg_mortality', 'time_since_last_stay']
175
+
176
+ general_inputs = {}
177
+ cols = st.columns(3) # Three columns for general values
178
+
179
+ for i, feature in enumerate(general_numerical_features):
180
+ col_index = i % 3 # Distribute across columns
181
+ min_val, max_val = df[feature].min(), df[feature].max()
182
+
183
+ with cols[col_index]:
184
+ general_inputs[feature] = st.slider(
185
+ f"📌 {feature.replace('_', ' ').title()}",
186
+ float(min_val),
187
+ float(max_val),
188
+ float((min_val + max_val) / 2)
189
+ )
190
+
191
+ # Laboratory Values
192
+ st.subheader("🧪 Laboratory Test Results")
193
+ lab_numerical_features = ['blood_cells', 'hemoglobin', 'glucose',
194
+ 'creatine', 'plaquete']
195
+
196
+ lab_inputs = {}
197
+ lab_cols = st.columns(3) # Three columns for lab values
198
+
199
+ for i, feature in enumerate(lab_numerical_features):
200
+ col_index = i % 3 # Distribute across columns
201
+ min_val, max_val = df[feature].min(), df[feature].max()
202
+
203
+ with lab_cols[col_index]:
204
+ lab_inputs[feature] = st.slider(
205
+ f"🩸 {feature.replace('_', ' ').title()}",
206
+ float(min_val),
207
+ float(max_val),
208
+ float((min_val + max_val) / 2)
209
+ )
210
+ min_val, max_val = df["cci_score"].min(), df["cci_score"].max()
211
+ lab_inputs["cci_score"] = st.sidebar.slider(
212
+ f"📌 CCI Score",
213
+ float(min_val),
214
+ float(max_val),
215
+ float((min_val + max_val) / 2)
216
+ )
217
+
218
+ # Process Inputs into Features
219
+ feature_vector = {col: 0 for col in df.columns}
220
+ feature_vector.update({
221
+ admission_type: 1,
222
+ admission_location: 1,
223
+ discharge_location: 1,
224
+ insurance: 1,
225
+ language: 1,
226
+ marital_status: 1,
227
+ race: 1,
228
+ "gender_M": 1 if sex == "gender_M" else 0,
229
+ f"admit_period_{get_period(admission_time.hour)}": 1,
230
+ f"discharge_period_{get_period(discharge_time.hour)}": 1
231
+ })
232
+ age_group = get_age_group(age) # This function now returns correct dataset column names
233
+
234
+ # Use the exact column names from the dataset
235
+ for group in [
236
+ "age_group_36-50 (Middle-Aged Adults)",
237
+ "age_group_51-65 (Older Middle-Aged Adults)",
238
+ "age_group_66-80 (Senior Adults)",
239
+ "age_group_81+ (Elderly)"
240
+ ]:
241
+ feature_vector[group] = 1 if group == age_group else 0 # Set selected group to 1, others to 0
242
+
243
+ feature_vector.update(numeric_inputs)
244
+ # Display Processed Data
245
+ st.markdown("---")
246
+
247
+ # Load XGBoost model
248
+ tabular_model_path = "/Users/joaopimenta/Downloads/final_xgboost_model.pkl"
249
+ tabular_model = joblib.load(tabular_model_path)
250
+ print("✅ XGBoost Tabular Model loaded successfully!")
251
+
252
+ # Load dataset columns (use the same order as training)
253
+ expected_columns = [
254
+ col for col in df.columns if col not in ["Unnamed: 0", "subject_id", "hadm_id", "probs"]
255
+ ]
256
+
257
+ # Define correct dataset column names for age groups
258
+ age_group_mapping = {
259
+ "age_group_36-50": "age_group_36-50 (Middle-Aged Adults)",
260
+ "age_group_51-65": "age_group_51-65 (Older Middle-Aged Adults)",
261
+ "age_group_66-80": "age_group_66-80 (Senior Adults)",
262
+ "age_group_81+": "age_group_81+ (Elderly)",
263
+ }
264
+
265
+ # Process Inputs into Features
266
+ feature_vector = {col: 0 for col in df.columns}
267
+
268
+ # Set selected categorical features to 1
269
+ feature_vector.update({
270
+ admission_type: 1,
271
+ admission_location: 1,
272
+ discharge_location: 1,
273
+ insurance: 1,
274
+ language: 1,
275
+ marital_status: 1,
276
+ race: 1,
277
+ "gender_M": 1 if sex == "gender_M" else 0,
278
+ f"admit_period_{get_period(admission_time.hour)}": 1,
279
+ f"discharge_period_{get_period(discharge_time.hour)}": 1
280
+ })
281
+
282
+ # Set correct age group
283
+ age_group = get_age_group(age)
284
+ for group in [
285
+ "age_group_36-50 (Middle-Aged Adults)",
286
+ "age_group_51-65 (Older Middle-Aged Adults)",
287
+ "age_group_66-80 (Senior Adults)",
288
+ "age_group_81+ (Elderly)"
289
+ ]:
290
+ feature_vector[group] = 1 if group == age_group else 0
291
+
292
+ # Update with numerical inputs
293
+ feature_vector.update(general_inputs)
294
+ feature_vector.update(lab_inputs)
295
+
296
+ # Ensure feature order matches expected model input
297
+ fixed_feature_vector = {age_group_mapping.get(k, k): v for k, v in feature_vector.items()}
298
+ feature_df = pd.DataFrame([fixed_feature_vector]).reindex(columns=expected_columns, fill_value=0)
299
+
300
+ st.write(feature_df)
301
+ # Predict probability of readmission
302
+ prediction_proba = tabular_model.predict_proba(feature_df)[:, 1]
303
+ probability = float(prediction_proba[0]) # Convert NumPy array to scalar
304
+ st.session_state["XGBoost probability"] = probability
305
+ prediction = (prediction_proba >= 0.5).astype(int)
306
+
307
+ import shap
308
+ import matplotlib.pyplot as plt
309
+ import streamlit.components.v1 as components # Required for displaying SHAP force plot
310
+
311
+ st.write(f"Raw Prediction Probability: {probability:.4f}")
312
+
313
+ # Prediction Button
314
+ if st.button("🚀 Predict Readmission"):
315
+ with st.spinner("🔍 Processing Prediction..."):
316
+ st.subheader("🎯 Prediction Results")
317
+ col1, col2 = st.columns(2)
318
+
319
+ with col1:
320
+ st.metric(label="🧮 Readmission Probability", value=f"{probability:.2%}")
321
+
322
+ with col2:
323
+ if prediction == 1:
324
+ st.error("⚠️ High Risk of Readmission")
325
+ else:
326
+ st.success("✅ Low Risk of Readmission")
327
+
328
+ # Feature Importance Button
329
+ if st.button("🔍 Feature Importance for Prediction"):
330
+ st.metric(label="🧮 Readmission Probability", value=f"{probability:.2%}")
331
+ # ✅ Initialize SHAP Explainer for XGBoost
332
+ explainer = shap.TreeExplainer(tabular_model)
333
+ shap_values = explainer.shap_values(feature_df) # SHAP values for all samples
334
+
335
+ # ✅ Convert SHAP values into a DataFrame (Sorting First)
336
+ shap_df = pd.DataFrame({
337
+ "Feature": feature_df.columns,
338
+ "SHAP Value": shap_values[0] # SHAP values for the first instance
339
+ })
340
+
341
+ # ✅ Select **Top 10 Most Important Features** (Sorted by Absolute SHAP Value)
342
+ shap_df["abs_SHAP"] = shap_df["SHAP Value"].abs() # Add column with absolute values
343
+ shap_df = shap_df.sort_values(by="abs_SHAP", ascending=False).head(10) # Top 10
344
+
345
+ # Get top features and their SHAP impact values (shap_df assumed to be available)
346
+ top_features = sorted(zip(shap_df['Feature'], shap_df['SHAP Value']), key=lambda x: abs(x[1]), reverse=True)
347
+
348
+ # Create a formatted string for `top_factors` to be shown in the UI
349
+ top_factors = "\n".join([f"- {feat}: {round(value, 2)} impact" for feat, value in top_features])
350
+
351
+ # ✅ Login to HuggingChat (credentials are read from Streamlit secrets, not hard-coded)
352
+ EMAIL = st.secrets["email"]
353
+ PASSWD = st.secrets["passwd"]
354
+ cookie_path_dir = "./cookies_snapshot"
355
+
356
+ # Log in and save cookies
357
+ try:
358
+ sign = Login(EMAIL, PASSWD)
359
+ cookies = sign.login(cookie_dir_path=cookie_path_dir, save_cookies=True)
360
+ sign.saveCookiesToDir(cookie_path_dir)
361
+ except Exception as e:
362
+ st.error(f"❌ Login to HuggingChat failed. Error: {e}")
363
+ st.stop()
364
+
365
+ # ✅ Create HuggingChat bot instance
366
+ chatbot = hugchat.ChatBot(cookies=cookies.get_dict())
367
+
368
+ # 🎭 **Streamlit UI**
369
+ st.title("🩺 AI-Powered Patient Readmission Analysis")
370
+
371
+ # ✅ Construct the AI query with real SHAP feature impacts
372
+ # ✅ Construct the AI query with real SHAP feature impacts
373
+ hugging_prompt = f"""
374
+ A hospital AI model predicts patient readmission based on the following feature impacts:
375
+ {top_factors}
376
+
377
+ Can you explain why the model made this decision? Specifically, what were the key characteristics of the patient or their admission that influenced the model’s prediction the most?
378
+ """
379
+
380
+ # ✅ Query HuggingChat
381
+ with st.spinner("🤖 Analyzing..."):
382
+ try:
383
+ response = chatbot.chat(hugging_prompt) # Corrected method to 'chat' instead of 'query'
384
+ ai_output = response # Extract AI response
385
+ # 🎭 **Show AI Response in a Stylish Chat Format**
386
+ with st.chat_message("assistant"):
387
+ st.markdown(f"**💡 AI Explanation:**\n\n{ai_output}")
388
+ except Exception as e:
389
+ st.error(f"⚠️ Error retrieving response: {e}")
390
+ st.stop()
391
+
392
+ # ✅ **Expand for SHAP Feature Details**
393
+ with st.expander("📜 Click to see detailed feature impacts"):
394
+ st.markdown(f"```{top_factors}```")
395
+
396
+ # Show Top 10 Features
397
+ #st.write(shap_df[["Feature", "SHAP Value"]]) # Display only relevant columns
398
+
399
+ # ✅ SHAP Bar Plot (Corrected for Top 10 Selection)
400
+ fig, ax = plt.subplots(figsize=(8, 6))
401
+ shap.bar_plot(shap_df["SHAP Value"].values, shap_df["Feature"].values) # Correct Top 10
402
+ st.pyplot(fig)
403
+
404
+ # 🎯 SHAP Force Plot (How Features Affected the Prediction)
405
+ st.subheader("🎯 SHAP Force Plot (How Features Affected the Prediction)")
406
+
407
+ # ✅ Fix: Use explainer.expected_value (single scalar)
408
+ force_plot = shap.force_plot(
409
+ explainer.expected_value, shap_values[0], feature_df.iloc[0], matplotlib=False
410
+ )
411
+
412
+ # ✅ Convert SHAP force plot to HTML
413
+ shap_html = f"<head>{shap.getjs()}</head><body>{force_plot.html()}</body>"
414
+
415
+ # ✅ Render SHAP force plot in Streamlit
416
+ components.html(shap_html, height=400)
417
+
418
+ elif page == "Clinical text notes":
419
+ # Set Streamlit Page Title
420
+ st.subheader("📝 Clinical Text Note")
421
+
422
+ # Utility Functions
423
+
424
def clean_text(text):
    """Normalise a clinical note before further processing.

    The substitutions run in this fixed order:
      1. every character outside printable ASCII becomes a space;
      2. runs of two or more underscores are removed;
      3. whitespace runs collapse to a single space;
      4. any character that is not a word character, whitespace or one of
         ``.,:;*%()[]-`` is removed.
    The result is lower-cased and stripped of surrounding whitespace.
    """
    substitutions = (
        (r"[^\x20-\x7E]", " "),
        (r"_{2,}", ""),
        (r"\s+", " "),
        (r"[^\w\s.,:;*%()\[\]-]", ""),
    )
    normalized = text
    for pattern, replacement in substitutions:
        normalized = re.sub(pattern, replacement, normalized)
    return normalized.lower().strip()
431
+
432
+
433
+ import re
434
+
435
def extract_fields(text):
    """Extract named sections from a clinical note with regular expressions.

    Each section starts at its header (optionally followed by ``:`` or ``-``)
    and runs until the first of its stop headers, or until the end of the
    text when no stop header follows. Matching is case-insensitive and spans
    newlines.

    The end-of-text fallback does not require trailing whitespace, so the
    final section of an already-stripped note is still captured.

    Args:
        text: The clinical note.

    Returns:
        Dict mapping each section name that was found to its stripped text.
        Sections that are not present are omitted.
    """
    # header -> (capture group, headers that terminate the section)
    sections = {
        "Discharge Medications": (
            r"(.+?)",
            ("Discharge Disposition", "Discharge Condition",
             "Discharge Instructions", "Followup Instructions"),
        ),
        "Discharge Diagnosis": (
            r"(.+?)",
            ("Discharge Condition", "Discharge Medications",
             "Discharge Instructions", "Followup Instructions"),
        ),
        "Discharge Instructions": (
            r"(.*?)",
            ("Followup Instructions", "Discharge Disposition",
             "Discharge Condition"),
        ),
        "History of Present Illness": (
            r"(.+?)",
            ("Past Medical History", "Social History",
             "Family History", "Physical Exam"),
        ),
        "Past Medical History": (
            r"(.+?)",
            ("Social History", "Family History", "Physical Exam"),
        ),
    }

    extracted_data = {}
    for field, (capture, stop_headers) in sections.items():
        stops = "|".join(stop_headers)
        # Either whitespace followed by a stop header, or (optional
        # whitespace and) the end of the text.
        pattern = rf"{field}[:\-]?\s*{capture}(?:\s+(?:{stops})|\s*$)"
        match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
        if match:
            extracted_data[field] = match.group(1).strip()

    return extracted_data
453
+
454
def extract_features(texts, model, tokenizer, device, batch_size=8):
    """Return the position-0 hidden state of ``model`` for every text.

    Texts are tokenised in batches of ``batch_size`` (padded, truncated to
    4096 tokens) and moved to ``device``. A global attention mask with only
    position 0 set is passed to the model, and the first-token vector of
    ``last_hidden_state`` is kept for each text.

    Returns:
        A tensor made by concatenating the per-batch results along dim 0.
    """
    cls_chunks = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]
        encoded = tokenizer(
            chunk,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=4096,
        ).to(device)

        # Global attention on the first token only.
        global_mask = torch.zeros_like(encoded["input_ids"]).to(device)
        global_mask[:, 0] = 1

        with torch.no_grad():
            result = model(**encoded, global_attention_mask=global_mask)

        cls_chunks.append(result.last_hidden_state[:, 0, :])

    return torch.cat(cls_chunks, dim=0)
469
+
470
+
471
def extract_entities(text, pipe, entity_group):
    """Return the words of one entity group found in ``text`` by a NER pipeline.

    Args:
        text: Text to analyse. Empty or whitespace-only text is not sent to
            the pipeline at all.
        pipe: Callable that takes the text and returns an iterable of dicts
            with at least the keys ``'word'`` and ``'entity_group'``.
        entity_group: Value of ``'entity_group'`` to keep.

    Returns:
        List of matching ``'word'`` values, or the single-element placeholder
        list ``["No relevant entities found"]`` when nothing matches.
    """
    placeholder = ["No relevant entities found"]
    # Missing sections reach this function as "" (or just spaces); there is
    # nothing to tag, so skip the model call.
    if not text or not text.strip():
        return placeholder
    entities = pipe(text)
    matches = [ent['word'] for ent in entities if ent['entity_group'] == entity_group]
    return matches or placeholder
475
+
476
+ # Load Model and Tokenizer
477
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
478
+
479
@st.cache_resource()
def load_models():
    """Load the Clinical-Longformer encoder and the biomedical NER pipeline.

    Returns:
        Tuple ``(longformer_tokenizer, longformer_model, ner_pipe)``. The
        Longformer model is moved to ``device`` and put in eval mode; the NER
        pipeline uses the "simple" aggregation strategy.
    """
    longformer_name = "yikuan8/Clinical-Longformer"
    ner_name = "d4data/biomedical-ner-all"

    lf_tokenizer = AutoTokenizer.from_pretrained(longformer_name)
    lf_model = AutoModel.from_pretrained(longformer_name).to(device).eval()

    ner_pipeline = pipeline(
        "ner",
        model=AutoModelForTokenClassification.from_pretrained(ner_name),
        tokenizer=AutoTokenizer.from_pretrained(ner_name),
        aggregation_strategy="simple",
    )

    return lf_tokenizer, lf_model, ner_pipeline
490
+
491
+ longformer_tokenizer, longformer_model, ner_pipe = load_models()
492
+
493
+ # Text Input
494
+ clinical_note = st.text_area("✍️ Enter Clinical Note", placeholder="Write the clinical note here...")
495
+
496
+ if clinical_note:
497
+ cleaned_note = clean_text(clinical_note)
498
+ #st.write("### 📝 Cleaned Clinical Note:")
499
+ #st.write(cleaned_note)
500
+
501
+ # Extract Fields
502
+ extracted_data = extract_fields(cleaned_note)
503
+ st.write("### Extracted Fields")
504
+ st.write(extracted_data)
505
+
506
+ # Extract Embeddings
507
+ with st.spinner("🔍 Extracting embeddings..."):
508
+ embeddings = extract_features([cleaned_note], longformer_model, longformer_tokenizer, device)
509
+ #st.write("### Extracted Embeddings")
510
+ #st.write(embeddings)
511
+ # Define the RobustMLPClassifier class
512
class RobustMLPClassifier(nn.Module):
    """Feed-forward binary classifier that outputs a single logit per row.

    Every hidden layer is ``Linear -> BatchNorm1d -> activation -> Dropout``;
    a final ``Linear`` maps to one output unit. No sigmoid is applied here.

    Args:
        input_dim: Number of input features.
        hidden_dims: Sizes of the hidden layers, in order. An immutable tuple
            is used as the default so it cannot be mutated across instances;
            lists are still accepted.
        dropout: Dropout probability applied after each activation.
        activation: Activation module inserted after each batch-norm layer
            (the same module object is reused for every hidden layer).
    """

    def __init__(self, input_dim, hidden_dims=(256, 128, 64), dropout=0.3, activation=nn.ReLU()):
        super().__init__()
        layers = []
        current_dim = input_dim

        for h in hidden_dims:
            layers.append(nn.Linear(current_dim, h))
            layers.append(nn.BatchNorm1d(h))
            layers.append(activation)
            layers.append(nn.Dropout(dropout))
            current_dim = h

        layers.append(nn.Linear(current_dim, 1))
        # Attribute name and layer order are kept so saved models still load.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw logits produced by the layer stack for ``x``."""
        return self.net(x)
530
+
531
+ # --- Load MLP Model and PCA ---
532
mlp_model_path = "/Users/joaopimenta/Downloads/Capstone/best_mlp_model_full.pth"
pca_path = "/Users/joaopimenta/Downloads/Capstone/best_pca_model.pkl"

# map_location lets a checkpoint saved on a GPU load on a CPU-only host.
# weights_only=False is needed because the file name suggests a whole pickled
# module rather than a state dict (TODO confirm); recent torch versions
# reject that by default.
# NOTE(review): both loads below unpickle arbitrary objects; only use trusted files.
best_mlp_model = torch.load(mlp_model_path, map_location=device, weights_only=False)
best_mlp_model.to(device)
best_mlp_model.eval()

pca = joblib.load(pca_path)
540
+
541
def predict_readmission(texts):
    """Return readmission probabilities for a list of note texts.

    The texts are embedded with ``extract_features``, projected with the
    loaded ``pca`` object, passed through ``best_mlp_model`` and squashed
    with a sigmoid.

    Returns:
        NumPy array holding the sigmoid of the MLP output.
    """
    cls_embeddings = extract_features(texts, longformer_model, longformer_tokenizer, device)
    reduced = pca.transform(cls_embeddings.cpu().numpy())  # PCA runs on NumPy data

    mlp_input = torch.FloatTensor(reduced).to(device)

    with torch.no_grad():
        scores = best_mlp_model(mlp_input)
        return torch.sigmoid(scores).cpu().numpy()
553
+
554
+ # Extract Medical Entities
555
+ with st.spinner("🔍 Identifying medical entities..."):
556
+ extracted_data["Extracted Medications"] = extract_entities(
557
+ extracted_data.get("Discharge Medications", ""), ner_pipe, "Medication"
558
+ )
559
+
560
+ extracted_data["Extracted Diseases"] = extract_entities(
561
+ extracted_data.get("Discharge Diagnosis", ""), ner_pipe, "Disease_disorder"
562
+ )
563
+
564
+ extracted_data["Extracted Diseases (Past Medical History)"] = extract_entities(
565
+ extracted_data.get("Past Medical History", ""), ner_pipe, "Disease_disorder"
566
+ )
567
+
568
+ extracted_data["Extracted Diseases (History of Present Illness)"] = extract_entities(
569
+ extracted_data.get("History of Present Illness", ""), ner_pipe, "Disease_disorder"
570
+ )
571
+
572
+ # Symptom extraction now also includes "History of Present Illness"
573
+ extracted_data["Extracted Symptoms"] = extract_entities(
574
+ extracted_data.get("Review of Systems", "") + " " + extracted_data.get("History of Present Illness", ""),
575
+ ner_pipe, "Sign_symptom"
576
+ )
577
+
578
+
579
def clean_entities(entities):
    """Rebuild word-piece fragments, drop noise and de-duplicate entities.

    Tokens starting with ``##`` are glued onto the preceding token. The
    placeholder ``"No relevant entities found"`` (returned by
    ``extract_entities`` when nothing matched) is discarded so it is never
    counted or displayed as a real entity. Words of two characters or fewer,
    and words made only of non-word characters/underscores, are dropped.

    Args:
        entities: List of entity strings.

    Returns:
        Sorted list of unique cleaned entities (possibly empty).
    """
    placeholder = "No relevant entities found"
    cleaned = []
    temp = ""

    for entity in entities:
        if entity == placeholder:
            # Sentinel, not an entity.
            continue
        if entity.startswith("##"):  # Fragmented token: append to the current word
            temp += entity.replace("##", "")
        else:
            if temp:
                cleaned.append(temp)  # Flush the reconstructed token
            temp = entity
    if temp:
        cleaned.append(temp)  # Flush the last processed token

    # Filter out irrelevant short words and symbol-only tokens
    cleaned = [word for word in cleaned if len(word) > 2 and not re.match(r"^[\W_]+$", word)]

    return sorted(set(cleaned))  # Remove duplicates and sort
598
+
599
+ # Clean extracted diseases and symptoms
600
+ diseases_cleaned = clean_entities(
601
+ extracted_data.get("Extracted Diseases", []) +
602
+ extracted_data.get("Extracted Diseases (Past Medical History)", []) +
603
+ extracted_data.get("Extracted Diseases (History of Present Illness)", [])
604
+ )
605
+ # Clean and reconstruct medication names
606
+ medications_cleaned = clean_entities(extracted_data.get("Extracted Medications", []))
607
+
608
+ # Store cleaned data in the main dictionary
609
+ extracted_data["Extracted Medications Cleaned"] = medications_cleaned
610
+
611
+ symptoms_cleaned = clean_entities(extracted_data.get("Extracted Symptoms", []))
612
+
613
+ # Display extracted entities
614
def display_list(title, items, icon="📌"):
    """Render ``items`` as a bulleted list inside a Streamlit expander.

    The expander label shows ``title`` and the item count. When ``items`` is
    empty a "no information" note is shown instead of the list.
    """
    with st.expander(f"**{title} ({len(items)})**"):
        if not items:
            st.markdown("_No information available._")
            return
        for entry in items:
            st.markdown(f"- {icon} **{entry}**")
622
+
623
+
624
+ # Layout Header
625
+ st.markdown("## 🏥 **Patient Medical Analysis**")
626
+ st.markdown("---")
627
+
628
+ # Creating columns for metrics
629
+ col1, col2, col3 = st.columns(3)
630
+
631
+ # Medications Metrics
632
+ num_medications = len(medications_cleaned )
633
+ col1.metric(label="💊 Total Medications", value=num_medications)
634
+
635
+ # Diseases Metrics
636
+ num_diseases = len(diseases_cleaned)
637
+ col2.metric(label="🦠 Total Diseases", value=num_diseases)
638
+
639
+ # Symptoms Metrics
640
+ num_symptoms = len(symptoms_cleaned)
641
+ col3.metric(label="🤒 Total Symptoms", value=num_symptoms)
642
+
643
+ st.markdown("---")
644
+
645
+ # Organizing lists in two columns
646
+ col1, col2 = st.columns(2)
647
+
648
+ # Display Medications List
649
+ with col1:
650
+ st.markdown("### 💊 **Medications**")
651
+ display_list("Medication List", medications_cleaned , icon="💊")
652
+
653
+ # Display Diseases List
654
+ with col2:
655
+ st.markdown("### 🦠 **Diseases**")
656
+ display_list("Disease List", diseases_cleaned, icon="🦠")
657
+
658
+ # Symptoms Section
659
+ st.markdown("### 🤒 **Symptoms**")
660
+ display_list("Symptoms List", symptoms_cleaned, icon="🤒")
661
+
662
+ st.markdown("---")
663
+
664
+ # Load tokenizer
665
+ tokenizer = AutoTokenizer.from_pretrained("yikuan8/Clinical-Longformer")
666
+
667
+ # Functions for token count and truncation
668
def count_tokens(text):
    """Return how many tokens ``tokenizer.tokenize`` produces for ``text``."""
    return len(tokenizer.tokenize(text))
671
+
672
def trunced_text(nr):
    """Return 1 when the token count ``nr`` exceeds 4096, otherwise 0."""
    return int(nr > 4096)
674
+
675
+ # Dictionary of diseases with synonyms (matching capitalization in the image)
676
+ disease_synonyms = {
677
+ "Pneumonia": ["pneumonia", "pneumonitis"],
678
+ "Diabetes": ["diabetes", "diabetes mellitus", "dm"],
679
+ "CHF": ["CHF", "congestive heart failure", "heart failure"],
680
+ "Septicemia": ["septicemia", "sepsis", "blood infection"],
681
+ "Cirrhosis": ["cirrhosis", "liver cirrhosis", "hepatic cirrhosis"],
682
+ "COPD": ["COPD", "chronic obstructive pulmonary disease"],
683
+ "Renal_Failure": ["renal failure", "kidney failure", "chronic kidney disease", "CKD"]
684
+ }
685
+
686
+ # Extract relevant fields (assuming extract_fields is defined elsewhere)
687
+ extracted_data = extract_fields(clinical_note)
688
+
689
+ # Compute token counts
690
+ number_of_tokens = count_tokens(clinical_note)
691
+ number_of_tokens_med = count_tokens(extracted_data.get("Discharge Medications", ""))
692
+ number_of_tokens_dis = count_tokens(extracted_data.get("Discharge Diagnosis", ""))
693
+ trunced = trunced_text(number_of_tokens)
694
+
695
+ # Convert diagnosis text to lowercase for case-insensitive matching
696
+ full_diagnosis_text = extracted_data.get("Discharge Diagnosis", "").lower()
697
+
698
+ # Function to check for any synonym in the diagnosis text
699
def check_disease_presence(disease_list, text):
    """Return 1 if any synonym occurs in ``text`` as a whole word, else 0.

    Matching is case-insensitive. Each synonym is escaped before being put
    into the pattern, so characters such as ``.`` or ``+`` in a synonym are
    matched literally instead of being read as regex syntax.

    Args:
        disease_list: Iterable of synonym strings for one disease.
        text: Text to search.
    """
    return int(any(
        re.search(rf"\b{re.escape(synonym)}\b", text, re.IGNORECASE)
        for synonym in disease_list
    ))
701
+
702
+ # Create binary columns for each disease based on synonyms
703
+ disease_flags = {disease: check_disease_presence(synonyms, full_diagnosis_text)
704
+ for disease, synonyms in disease_synonyms.items()}
705
+
706
+ # Count total diseases found
707
+ disease_flags["total_conditions"] = sum(disease_flags.values())
708
+
709
+ # Create DataFrame with a single row (matching column names from the image)
710
+ df = pd.DataFrame([{
711
+ 'number_of_tokens_dis': number_of_tokens_dis,
712
+ 'number_of_tokens': number_of_tokens,
713
+ 'number_of_tokens_med': number_of_tokens_med,
714
+ 'diagnostic_count': num_diseases, # Ensuring column name matches
715
+ 'total_conditions': disease_flags["total_conditions"], # Matching name
716
+ 'trunced': trunced,
717
+ **{disease: disease_flags[disease] for disease in disease_synonyms.keys()} # Disease presence flags
718
+ }])
719
+
720
+ # Display DataFrame
721
+ #st.write(df)
722
+
723
+ #load lighGBoost model
724
+ light_path = '/Users/joaopimenta/Downloads/best_lgbm_model.pkl'
725
+ light_model = joblib.load(light_path)
726
+ #st.write("LightGBoost Model loaded sucessfully!")
727
+
728
+ # Ensure df is already created from previous steps
729
+ # Select only the columns that match the model input
730
+ model_features = light_model.feature_name_
731
+
732
+ # Check if all required features are in df
733
+ missing_features = [feat for feat in model_features if feat not in df.columns]
734
+ if missing_features:
735
+ st.write(f"⚠️ Warning: Missing features in df: {missing_features}")
736
+
737
+ # Fill missing columns with 0 (if needed, assuming binary features)
738
+ for feat in missing_features:
739
+ df[feat] = 0 # Default to 0 for missing binary disease indicators
740
+
741
+ # Reorder df to match model features exactly
742
+ df = df[model_features]
743
+
744
+ # Convert df to NumPy array for LightGBM prediction
745
+ X = df.to_numpy()
746
+
747
+ # Make prediction
748
+ # Get probability of readmission
749
# Probability of class 1 (readmission) for the single input row.
light_probability = light_model.predict_proba(X)[:, 1]
# Store a plain Python float in session_state (not the NumPy array), the same
# way the XGBoost page stores its probability, so the ensemble page can read
# every model's value uniformly.
st.session_state["lightgbm probability"] = float(light_probability[0])
752
+
753
+ # Output results
754
+ #st.write(f"🔹 Readmission Prediction: {probability}")
755
+
756
+ # Prediction button: on click, score the cleaned note with predict_readmission, store the result and render it
757
+ if st.button("🚀 Predict Readmission"):
758
+ with st.spinner("🔍 Extracting embeddings and predicting readmission..."):
759
+ readmission_prob = predict_readmission([cleaned_note])[0][0] # Compute only once; predict_readmission is defined outside this section, [0][0] picks the first value of its first row
760
+ st.session_state["MLP probability"] = readmission_prob # read back on the "Ensemble prediction" page under this key
761
+ prediction = 1 if readmission_prob > 0.5 else 0 # Binary label using a 0.5 cut-off
762
+
763
+ # Display Results
764
+ st.subheader("🎯 Prediction Results")
765
+ col1, col2 = st.columns(2)
766
+
767
+ with col1:
768
+ st.metric(label="🧮 Readmission Probability", value=f"{readmission_prob:.2%}")
769
+
770
+ with col2:
771
+ if prediction == 1:
772
+ st.error("⚠️ High Risk of Readmission")
773
+ else:
774
+ st.success("✅ Low Risk of Readmission")
775
+
776
+ # Display Readmission Probability with Centered Styling (text is red above the same 0.5 cut-off, green otherwise)
777
+ st.markdown(f"""
778
+ <div style="text-align:center; padding: 20px; background-color: #f8f9fa; border-radius: 10px;">
779
+ <h3>📊 Readmission Probability</h3>
780
+ <h2 style="color: {'red' if readmission_prob > 0.5 else 'green'};">{readmission_prob:.2%}</h2>
781
+ </div>
782
+ """, unsafe_allow_html=True)
783
+
784
+ elif page == "Ensemble prediction":
785
+ # Ensemble page: collect the three per-model probabilities from session state, animate a "vote", then score them with the ensemble model
786
+ # Load the ensemble model
787
+ ensemble_model = joblib.load("/Users/joaopimenta/Downloads/best_ensemble_model.pkl") # NOTE(review): absolute path on a developer machine; presumably missing on any other host -- confirm where the model ships
788
+ #st.write("✅ Ensemble Model loaded successfully!")
789
+
790
+ # Model names; each one is looked up in session state under the key "<name> probability"
791
+ models = ["XGBoost", "lightgbm", "MLP"]
792
+
793
+ # Retrieve stored probabilities from session state and ensure they are numeric
794
+ probabilities = []
795
+ for model in models:
796
+ key = f"{model} probability"
797
+ if key in st.session_state:
798
+ try:
799
+ prob = float(st.session_state[key]) # NOTE(review): the "lightgbm probability" entry is stored as a 1-D NumPy array; float() on an array with ndim > 0 is deprecated since NumPy 1.25
800
+ probabilities.append(prob)
801
+ except ValueError: # NOTE(review): float() raises TypeError (not ValueError) for objects such as None, which is not caught here
802
+ st.error(f"⚠️ Invalid probability value for {model}: {st.session_state[key]}")
803
+ probabilities.append(None)
804
+ else:
805
+ probabilities.append(None)
806
+
807
+ # Ensure all probabilities are valid before proceeding (a None marks a missing or invalid entry)
808
+ if None not in probabilities:
809
+ st.write("### 🗳️ Voting Process in Progress...")
810
+
811
+ progress_bar = st.progress(0) # Progress bar
812
+ voting_display = st.empty() # Placeholder for voting animation
813
+
814
+ votes = []
815
+ for i, (model, prob) in enumerate(zip(models, probabilities)):
816
+ time.sleep(1) # Simulate suspense (UI delay only)
817
+
818
+ # Simulated blinking effect
819
+ for _ in range(3):
820
+ voting_display.markdown(f"⏳ {model} is deciding...")
821
+ time.sleep(0.5)
822
+ voting_display.markdown("")
823
+ time.sleep(0.5)
824
+
825
+ # Convert probability to a display label: < 0.33 Low, < 0.46 Medium, otherwise High. NOTE(review): these cut-offs differ from the 0.25 threshold applied to the ensemble output below -- confirm intended
826
+ if prob < 0.33:
827
+ vote = "🟢 Low"
828
+ elif prob < 0.46:
829
+ vote = "🟡 Medium"
830
+ else:
831
+ vote = "🔴 High"
832
+
833
+ votes.append(vote) # labels are only displayed; the ensemble below is fed the raw probabilities
834
+ voting_display.markdown(f"✅ **{model} voted: {vote}**")
835
+ progress_bar.progress((i + 1) / len(models))
836
+
837
+ time.sleep(1)
838
+ progress_bar.empty()
839
+
840
+ # Create a one-row DataFrame with numeric probabilities; columns follow the order of `models` (XGBoost -> probs, lightgbm -> probs_lgb, MLP -> probs_mlp)
841
+ final_df = pd.DataFrame([probabilities], columns=['probs', 'probs_lgb', 'probs_mlp'])
842
+ final_df = final_df.astype(float) # Ensure all values are float
843
+
844
+ # Make the final prediction with the ensemble
845
+ final_probability = ensemble_model.predict_proba(final_df)[:, 1][0] # Probability of class 1 for the single row
846
+ final_prediction = 1 if final_probability >= 0.25 else 0 # Apply the 0.25 decision threshold
847
+
848
+ # Styling of the final result
849
+ st.markdown("---")
850
+ if final_prediction == 1:
851
+ st.markdown(f"""
852
+ <div style="text-align: center; background-color: #ffdddd; padding: 15px; border-radius: 10px;">
853
+ <h2>🚨 <b>Final Prediction: 1</b> (Readmission Likely) </h2>
854
+ <h3>🔍 Probability: {final_probability:.2f} (Threshold: 0.25)</h3>
855
+ </div>
856
+ """, unsafe_allow_html=True)
857
+ else:
858
+ st.markdown(f"""
859
+ <div style="text-align: center; background-color: #ddffdd; padding: 15px; border-radius: 10px;">
860
+ <h2>✅ <b>Final Prediction: 0</b> (No Readmission Risk) </h2>
861
+ <h3>🔍 Probability: {final_probability:.2f} (Threshold: 0.25)</h3>
862
+ </div>
863
+ """, unsafe_allow_html=True)
864
+
865
+ # 🎨 Bar chart of each model's input probability. NOTE(review): despite the heading below, this plots the raw probabilities, not ensemble weights or contributions
866
+ st.write("### ⚖️ Model Contribution to Final Decision")
867
+ fig, ax = plt.subplots() # NOTE(review): the figure is never closed in this section; consider plt.close(fig) after st.pyplot to avoid accumulating figures across reruns
868
+ ax.bar(models, probabilities, color=["blue", "green", "red"])
869
+ ax.set_ylabel("Probability")
870
+ ax.set_title("Model Prediction Probabilities")
871
+ st.pyplot(fig)
872
+
873
+ # Show detailed voting breakdown
874
+ st.write("### 📊 Voting Breakdown:")
875
+ for model, vote in zip(models, votes):
876
+ st.write(f"🔹 {model}: **{vote}** (Prob: {probabilities[models.index(model)]:.2f})") # models.index(model) recovers the position of the current model in `probabilities`
877
+
878
+ else:
879
+ st.warning("⚠️ Some model predictions are missing. Please run all models before voting.")