Spaces:
Sleeping
Sleeping
# streamlit_app.py | |
import streamlit as st | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import plotly.express as px | |
import plotly.graph_objects as go | |
# --------------------------- | |
# Function Definitions | |
# --------------------------- | |
def create_histogram(df):
    """Creates a histogram for Age Distribution.

    Draws a 30-bin histogram (with KDE overlay) of ``df['anchor_age']`` and
    renders it into the Streamlit page.

    Args:
        df: Admissions frame with an ``anchor_age`` column.

    The figure is closed after rendering. pyplot keeps every figure it creates
    alive until it is closed, so leaving it open would accumulate figures on
    each Streamlit rerun.
    """
    fig, ax = plt.subplots(figsize=(5, 3.5))
    try:
        sns.histplot(df['anchor_age'], bins=30, kde=True, color='skyblue', ax=ax)
        ax.set_xlabel("Age")
        ax.set_ylabel("Number of Admissions")
        ax.set_title("Age Distribution")
        # Lay out this specific figure rather than pyplot's global "current" one.
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        plt.close(fig)
def create_gender_bar_chart(df):
    """Creates a bar chart for Gender Distribution.

    Counts admission rows per value of ``df['gender']`` and renders the bars
    into the Streamlit page.

    Args:
        df: Admissions frame with a ``gender`` column.

    The figure is closed after rendering so repeated Streamlit reruns do not
    pile up open figures in pyplot.
    """
    fig, ax = plt.subplots(figsize=(5, 3.5))
    try:
        sns.countplot(data=df, x='gender', palette='pastel', ax=ax)
        ax.set_title("Gender Distribution")
        ax.set_xlabel("Gender")
        ax.set_ylabel("Number of Admissions")
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        plt.close(fig)
def create_stacked_bar_admission_race(df):
    """Creates a stacked bar chart for Admission Types by Race.

    Each bar is one race. Its segments are the percentage share of each
    ``admission_type`` among that race's admissions, so every bar sums to 100%.

    Args:
        df: Admissions frame with ``race`` and ``admission_type`` columns.

    Draws on an explicitly created figure/axes pair instead of pyplot's
    implicit current figure (``plt.title`` / ``plt.gcf``), which is global
    state. The figure is closed once rendered.
    """
    # Rows: race, columns: admission type, cells: admission counts.
    admission_race = df.groupby(['race', 'admission_type']).size().unstack(fill_value=0)
    # Normalise each row to percentages so races of different sizes are comparable.
    admission_race_percent = admission_race.div(admission_race.sum(axis=1), axis=0) * 100
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        admission_race_percent.plot(kind='bar', stacked=True, colormap='tab20', ax=ax)
        ax.set_title("Admission Types by Race (%)")
        ax.set_xlabel("Race")
        ax.set_ylabel("Percentage of Admission Types")
        # Legend sits outside the axes on the right so it does not hide bars.
        ax.legend(title='Admission Type', bbox_to_anchor=(1.05, 1), loc='upper left')
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        plt.close(fig)
def create_los_by_race(df):
    """Creates a box plot for Length of Stay by Race.

    Args:
        df: Admissions frame with ``race`` and ``los`` (length of stay in days)
            columns.

    The figure is closed after rendering to avoid accumulating pyplot figures
    across Streamlit reruns.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        sns.boxplot(data=df, x='race', y='los', palette='Pastel1', ax=ax)
        ax.set_title("Length of Stay by Race")
        ax.set_xlabel("Race")
        ax.set_ylabel("Length of Stay (Days)")
        # Rotate the existing tick labels in place. Re-setting them via
        # set_xticklabels(get_xticklabels()) can trigger matplotlib's
        # FixedFormatter warning.
        ax.tick_params(axis='x', labelrotation=45)
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        plt.close(fig)
def create_correlation_heatmap(df):
    """Creates a correlation heatmap for numerical features.

    Computes the pairwise (Pearson, pandas' default) correlation between
    ``anchor_age`` and ``los`` and renders it as an annotated heatmap.

    Args:
        df: Admissions frame with ``anchor_age`` and ``los`` columns.

    The figure is closed after rendering.
    """
    numerical_features = df[['anchor_age', 'los']]
    corr_matrix = numerical_features.corr()
    fig, ax = plt.subplots(figsize=(3.5, 3))
    try:
        sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', fmt=".2f", ax=ax)
        ax.set_title("Correlation Heatmap")
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        plt.close(fig)
def create_time_series_heatmap(df):
    """Creates an admissions over time heatmap.

    Counts admissions per (``admission_year``, ``admission_month``) and renders
    a Plotly density heatmap with months on the x-axis in calendar order and
    years on the y-axis.

    Args:
        df: Admissions frame with ``admission_year`` and ``admission_month``
            (full English month names) columns. The frame is not modified.
    """
    month_order = ['January', 'February', 'March', 'April', 'May', 'June',
                   'July', 'August', 'September', 'October', 'November', 'December']
    # Ordered categorical so months group and sort chronologically. It is
    # attached to a copy via assign() rather than written into the caller's
    # frame, which is typically a filtered slice (SettingWithCopyWarning).
    months = pd.Categorical(df['admission_month'], categories=month_order, ordered=True)
    # observed=False keeps all twelve months for every year present (empty
    # months get count 0). Passing it explicitly avoids relying on pandas'
    # changing default for categorical groupers.
    heatmap_df = (
        df.assign(admission_month=months)
        .groupby(['admission_year', 'admission_month'], observed=False)
        .size()
        .reset_index(name='counts')
    )
    fig = px.density_heatmap(
        heatmap_df,
        x='admission_month',
        y='admission_year',
        z='counts',
        histfunc='sum',
        title='Admissions Over Time',
        labels={'counts': 'Number of Admissions'},
        color_continuous_scale='Blues'
    )
    fig.update_xaxes(categoryorder='array', categoryarray=month_order)
    # Reverse the y-axis so years increase from top to bottom.
    fig.update_layout(yaxis=dict(autorange='reversed'))
    # Plotly Express colours this trace through the shared layout coloraxis, so
    # the colourbar title must be set there. A trace-level colorbar is ignored.
    fig.update_layout(coloraxis_colorbar=dict(title='Admissions'))
    st.plotly_chart(fig, use_container_width=True)
def create_mortality_by_race(df):
    """Creates a bar chart for Mortality Rate by Race.

    The mortality rate of a race is the mean of ``hospital_expire_flag`` over
    that race's admissions, expressed as a percentage.

    Args:
        df: Admissions frame with ``race`` and ``hospital_expire_flag`` columns.

    The figure is closed after rendering.
    """
    mortality_race = df.groupby('race')['hospital_expire_flag'].mean().reset_index()
    mortality_race['mortality_rate'] = mortality_race['hospital_expire_flag'] * 100
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        sns.barplot(data=mortality_race, x='race', y='mortality_rate', palette='Set2', ax=ax)
        ax.set_title("Mortality Rate by Race")
        ax.set_xlabel("Race")
        ax.set_ylabel("Mortality Rate (%)")
        # Rotate existing labels instead of re-setting them (avoids the
        # FixedFormatter warning from set_xticklabels(get_xticklabels())).
        ax.tick_params(axis='x', labelrotation=45)
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        plt.close(fig)
def create_mortality_by_gender(df):
    """Creates a bar chart for Mortality Rate by Gender.

    The mortality rate of a gender is the mean of ``hospital_expire_flag`` over
    that gender's admissions, expressed as a percentage.

    Args:
        df: Admissions frame with ``gender`` and ``hospital_expire_flag`` columns.

    The figure is closed after rendering.
    """
    mortality_gender = df.groupby('gender')['hospital_expire_flag'].mean().reset_index()
    mortality_gender['mortality_rate'] = mortality_gender['hospital_expire_flag'] * 100
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        sns.barplot(data=mortality_gender, x='gender', y='mortality_rate', palette='Set3', ax=ax)
        ax.set_title("Mortality Rate by Gender")
        ax.set_xlabel("Gender")
        ax.set_ylabel("Mortality Rate (%)")
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        plt.close(fig)
def create_mortality_by_age_group(df):
    """Creates a bar chart for Mortality Rate by Age Group.

    Buckets ``anchor_age`` into the bands 0-30, 31-50, 51-70, 71-90 and 91-120
    (inclusive for whole-number ages). It then plots the mean of
    ``hospital_expire_flag`` per band as a percentage.

    Args:
        df: Admissions frame with ``anchor_age`` and ``hospital_expire_flag``
            columns. The frame is not modified.

    Ages outside the bin edges fall into no band and are left out. Bands with
    no admissions have an undefined (NaN) rate and show no bar.
    """
    bins = [0, 30, 50, 70, 90, 120]
    labels = ['0-30', '31-50', '51-70', '71-90', '91-120']
    # Right-closed bins with the lowest edge included give [0, 30], (30, 50],
    # (50, 70], (70, 90], (90, 120]. For integer ages these are exactly the
    # ranges the labels name. Left-closed bins (right=False) would file age 30
    # under '31-50', 50 under '51-70', 70 under '71-90' and 90 under '91-120'.
    age_group = pd.cut(df['anchor_age'], bins=bins, labels=labels,
                       right=True, include_lowest=True)
    # Group a copy (assign) rather than adding a column to the caller's frame.
    # observed=False keeps every band, including empty ones.
    mortality_age = (
        df.assign(age_group=age_group)
        .groupby('age_group', observed=False)['hospital_expire_flag']
        .mean()
        .reset_index()
    )
    mortality_age['mortality_rate'] = mortality_age['hospital_expire_flag'] * 100
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        sns.barplot(data=mortality_age, x='age_group', y='mortality_rate', palette='coolwarm', ax=ax)
        ax.set_title("Mortality Rate by Age Group")
        ax.set_xlabel("Age Group")
        ax.set_ylabel("Mortality Rate (%)")
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        plt.close(fig)
def create_violin_age_race_mortality(df):
    """Creates a violin plot for Age Distribution by Race and Mortality.

    One split violin per race: each side shows the ``anchor_age`` distribution
    for one value of ``hospital_expire_flag``.

    Args:
        df: Admissions frame with ``race``, ``anchor_age`` and
            ``hospital_expire_flag`` columns.

    The figure is closed after rendering.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        sns.violinplot(
            data=df,
            x='race',
            y='anchor_age',
            hue='hospital_expire_flag',
            split=True,
            palette='Set2',
            ax=ax
        )
        ax.set_title("Age Distribution by Race and Mortality")
        ax.set_xlabel("Race")
        ax.set_ylabel("Age")
        # Reposition and retitle seaborn's own legend. Calling ax.legend()
        # instead rebuilds it from artists attached to the axes, and seaborn's
        # legend handles are not necessarily among them, so entries can be lost.
        if ax.get_legend() is not None:
            sns.move_legend(ax, 'upper right', title='Mortality')
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        plt.close(fig)
def create_heatmap_race_gender_mortality(df):
    """Creates a heatmap for Mortality Rate by Race and Gender.

    Each cell is the mean of ``hospital_expire_flag`` (as a percentage) for one
    (race, gender) combination. Combinations with no admissions stay NaN.

    Args:
        df: Admissions frame with ``race``, ``gender`` and
            ``hospital_expire_flag`` columns.

    The figure is closed after rendering.
    """
    pivot_table = df.pivot_table(
        index='race',
        columns='gender',
        values='hospital_expire_flag',
        aggfunc='mean'
    ) * 100  # Convert to percentage
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        sns.heatmap(pivot_table, annot=True, fmt=".1f", cmap='YlOrRd', ax=ax)
        ax.set_title("Mortality Rate by Race and Gender (%)")
        ax.set_xlabel("Gender")
        ax.set_ylabel("Race")
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        plt.close(fig)
def create_parallel_coordinates(df):
    """Creates a parallel coordinates plot for Demographics and Outcomes.

    Axes: age, length of stay, mortality flag, race and gender. Lines are
    coloured by ``hospital_expire_flag``. Race and gender are plotted as
    integer category codes, and their axis ticks are labelled with the
    category names.

    Args:
        df: Admissions frame with ``anchor_age``, ``los``,
            ``hospital_expire_flag``, ``race`` and ``gender`` columns.
    """
    # Select relevant numerical features
    parallel_df = df[['anchor_age', 'los', 'hospital_expire_flag']].copy()
    # Encode categorical variables numerically. Codes index into the sorted
    # categories present in *this* frame, so their meaning changes as the
    # sidebar filters change. Hence the tick labels added below.
    race_cat = df['race'].astype('category')
    gender_cat = df['gender'].astype('category')
    parallel_df['race_code'] = race_cat.cat.codes
    parallel_df['gender_code'] = gender_cat.cat.codes
    # Create the parallel coordinates plot
    fig = px.parallel_coordinates(
        parallel_df,
        color='hospital_expire_flag',
        labels={
            'anchor_age': 'Age',
            'los': 'Length of Stay',
            'hospital_expire_flag': 'Mortality',
            'race_code': 'Race',
            'gender_code': 'Gender'
        },
        color_continuous_scale=px.colors.diverging.Tealrose,
        color_continuous_midpoint=0.5
    )
    # Replace the bare integer ticks on the Race and Gender axes with the
    # category names (tick value i <-> i-th category).
    axis_names = {
        'Race': [str(name) for name in race_cat.cat.categories],
        'Gender': [str(name) for name in gender_cat.cat.categories],
    }
    dimensions = []
    for dimension in fig.data[0].dimensions:
        spec = dimension.to_plotly_json()
        names = axis_names.get(spec.get('label'))
        if names:
            spec['tickvals'] = list(range(len(names)))
            spec['ticktext'] = names
        dimensions.append(spec)
    fig.update_traces(dimensions=dimensions)
    fig.update_layout(title='Parallel Coordinates Plot of Demographics and Outcomes')
    st.plotly_chart(fig, use_container_width=True)
def create_treemap_race_mortality(df):
    """Render a Plotly treemap of admission counts split by race, then outcome.

    Top-level tiles are races. Each is subdivided into 'Survived'
    (``hospital_expire_flag`` == 0) and 'Died' (== 1), with tile area
    proportional to the number of admissions.

    Args:
        df: Admissions frame with ``race`` and ``hospital_expire_flag`` columns.
    """
    outcome_names = {0: 'Survived', 1: 'Died'}
    outcome_colors = {'Survived': '#66b3ff', 'Died': '#ff6666'}

    tile_counts = (
        df.groupby(['race', 'hospital_expire_flag'])
        .size()
        .reset_index(name='counts')
    )
    tile_counts['Mortality'] = tile_counts['hospital_expire_flag'].map(outcome_names)

    fig = px.treemap(
        tile_counts,
        path=['race', 'Mortality'],
        values='counts',
        color='Mortality',
        color_discrete_map=outcome_colors,
        title='Treemap of Race and Mortality',
    )
    # Tight margins, leaving just enough room at the top for the title.
    fig.update_layout(margin={'t': 30, 'l': 0, 'r': 0, 'b': 0})
    st.plotly_chart(fig, use_container_width=True)
def create_sankey_race_mortality(df):
    """Render a Sankey diagram of admission flows from race to outcome.

    Nodes are the races present in ``df`` followed by the outcomes 'Survived'
    / 'Died' (mapped from ``hospital_expire_flag`` 0 / 1). Each link carries
    the number of admissions for one (race, outcome) pair.

    Args:
        df: Admissions frame with ``race`` and ``hospital_expire_flag`` columns.
    """
    outcome_names = {0: 'Survived', 1: 'Died'}

    flows = (
        df.groupby(['race', 'hospital_expire_flag'])
        .size()
        .reset_index(name='counts')
    )
    flows['Mortality'] = flows['hospital_expire_flag'].map(outcome_names)

    # Node list: races first, then outcomes, each in order of first appearance.
    race_nodes = flows['race'].unique().tolist()
    outcome_nodes = flows['Mortality'].unique().tolist()
    node_labels = race_nodes + outcome_nodes
    node_position = {name: pos for pos, name in enumerate(node_labels)}

    # One colour for race nodes (light salmon), another for outcomes (light sea green).
    node_palette = ["#FFA07A"] * len(race_nodes) + ["#20B2AA"] * len(outcome_nodes)

    sankey = go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=node_labels,
            color=node_palette,
        ),
        link=dict(
            source=[node_position[race] for race in flows['race'].tolist()],
            target=[node_position[outcome] for outcome in flows['Mortality'].tolist()],
            value=flows['counts'].tolist(),
        ),
    )
    fig = go.Figure(data=[sankey])
    fig.update_layout(
        title_text="Sankey Diagram of Race and Mortality Outcomes",
        font_size=10,
    )
    st.plotly_chart(fig, use_container_width=True)
# ---------------------------
# Streamlit Application
# ---------------------------
# Page-level settings: browser-tab title, full-width layout, and a sidebar
# that starts open.
st.set_page_config(
    initial_sidebar_state="expanded",
    layout="wide",
    page_title="MIMIC-IV ICU Patient Data Dashboard",
)

# Page heading followed by a one-paragraph introduction.
st.title("MIMIC-IV ICU Patient Data Dashboard")
st.markdown("""
Explore the general feature distribution and outcome metrics of ICU patients from the MIMIC-IV dataset. Utilize the sidebar filters to customize the data view and interact with various visualizations to uncover patterns and insights.
""")

# Heading for the sidebar; the filter widgets are added under it later.
st.sidebar.header("Filter Data")
@st.cache_data
def load_data():
    """Load, merge, and clean the MIMIC-IV admissions and patients tables.

    Steps:
      * read ``data/admissions.csv`` and ``data/patients.csv`` (paths relative
        to the working directory);
      * collapse detailed ``race`` labels into broad groups;
      * left-join patient rows onto admissions by ``subject_id``;
      * drop rows missing ``anchor_age``, ``gender``, ``race`` or
        ``hospital_expire_flag``;
      * add derived columns:
        - ``los``: whole days between ``admittime`` and ``dischtime``
          (``Timedelta.days``, i.e. floored);
        - ``admission_year``;
        - ``admission_month`` (full month name);
        - ``admittime_date``.

    Wrapped in ``st.cache_data`` so the CSVs are read and processed once
    instead of on every widget interaction. Streamlit hands each caller its
    own copy of the cached frame.

    Returns:
        pandas.DataFrame: the merged, cleaned admissions frame.
    """
    # Load the dataframes (update the paths as necessary)
    admissions_df = pd.read_csv('data/admissions.csv')
    patients_df = pd.read_csv('data/patients.csv')
    # Detailed MIMIC-IV race labels -> broad groups used throughout the dashboard.
    race_map = {"WHITE": "WHITE",
                "BLACK/AFRICAN AMERICAN": "BLACK",
                "OTHER": "OTHER",
                "UNKNOWN": "UNKNOWN",
                "HISPANIC/LATINO - PUERTO RICAN": "HISPANIC",
                "WHITE - OTHER EUROPEAN": "WHITE",
                "HISPANIC OR LATINO": "HISPANIC",
                "ASIAN": "ASIAN",
                "ASIAN - CHINESE": "ASIAN",
                "WHITE - RUSSIAN": "WHITE",
                "BLACK/CAPE VERDEAN": "BLACK",
                "HISPANIC/LATINO - DOMINICAN": "HISPANIC",
                "BLACK/CARIBBEAN ISLAND": "BLACK",
                "BLACK/AFRICAN": "BLACK",
                "PATIENT DECLINED TO ANSWER": "UNKNOWN",
                "UNABLE TO OBTAIN": "UNKNOWN",
                "PORTUGUESE": "WHITE",
                "ASIAN - SOUTH EAST ASIAN": "ASIAN",
                "HISPANIC/LATINO - GUATEMALAN": "HISPANIC",
                "ASIAN - ASIAN INDIAN": "ASIAN",
                "WHITE - EASTERN EUROPEAN": "WHITE",
                "WHITE - BRAZILIAN": "WHITE",
                "AMERICAN INDIAN/ALASKA NATIVE": "NATIVES",
                "HISPANIC/LATINO - SALVADORAN": "HISPANIC",
                "HISPANIC/LATINO - MEXICAN": "HISPANIC",
                "HISPANIC/LATINO - COLUMBIAN": "HISPANIC",
                "MULTIPLE RACE/ETHNICITY": "MULTI-ETHNIC",
                "HISPANIC/LATINO - HONDURAN": "HISPANIC",
                "ASIAN - KOREAN": "ASIAN",
                "SOUTH AMERICAN": "HISPANIC",
                "HISPANIC/LATINO - CUBAN": "HISPANIC",
                "HISPANIC/LATINO - CENTRAL AMERICAN": "HISPANIC",
                "NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER": "NATIVES"}
    mapped_race = admissions_df['race'].map(race_map)
    # A label missing from race_map would map to NaN and then be silently
    # removed by the dropna below, skewing every count. Bucket such labels as
    # OTHER instead. Genuinely missing race values stay NaN and are dropped.
    unmapped = mapped_race.isna() & admissions_df['race'].notna()
    admissions_df['race'] = mapped_race.mask(unmapped, 'OTHER')
    # Merge admissions and patients data on 'subject_id'
    merged_df = pd.merge(admissions_df, patients_df, on='subject_id', how='left')
    # Handle missing values by dropping rows with critical missing data
    merged_df = merged_df.dropna(subset=['anchor_age', 'gender', 'race', 'hospital_expire_flag'])
    # Convert datetime columns
    merged_df['admittime'] = pd.to_datetime(merged_df['admittime'])
    merged_df['dischtime'] = pd.to_datetime(merged_df['dischtime'])
    merged_df['deathtime'] = pd.to_datetime(merged_df['deathtime'], errors='coerce')  # Some may not have deathtime
    # Create derived features
    merged_df['los'] = (merged_df['dischtime'] - merged_df['admittime']).dt.days
    merged_df['admission_year'] = merged_df['admittime'].dt.year
    merged_df['admission_month'] = merged_df['admittime'].dt.month_name()
    merged_df['admittime_date'] = merged_df['admittime'].dt.date
    return merged_df


merged_df = load_data()
# Sidebar Filters Function
def add_sidebar_filters(df):
    """Add the filter widgets to the sidebar and return the matching rows of *df*.

    Widgets are created in this order:
      * multiselects for admission type, insurance, gender and race, each
        offering the column's sorted distinct values with all pre-selected;
      * an admission-year range slider spanning the data's min to max year
        (both ends inclusive).

    Args:
        df: Admissions frame with ``admission_type``, ``insurance``, ``gender``,
            ``race`` and ``admission_year`` columns.

    Returns:
        The rows of *df* that satisfy every selection.
    """
    def pick_all(label, column):
        # Multiselect over the column's sorted distinct values, all pre-selected.
        choices = sorted(df[column].unique())
        return st.sidebar.multiselect(label, options=choices, default=choices)

    admission_sel = pick_all("Select Admission Type(s):", 'admission_type')
    insurance_sel = pick_all("Select Insurance Type(s):", 'insurance')
    gender_sel = pick_all("Select Gender(s):", 'gender')
    race_sel = pick_all("Select Race(s):", 'race')

    first_year = int(df['admission_year'].min())
    last_year = int(df['admission_year'].max())
    year_lo, year_hi = st.sidebar.slider(
        "Select Admission Year Range:",
        min_value=first_year,
        max_value=last_year,
        value=(first_year, last_year),
    )

    keep = (
        df['admission_type'].isin(admission_sel)
        & df['insurance'].isin(insurance_sel)
        & df['gender'].isin(gender_sel)
        & df['race'].isin(race_sel)
        & df['admission_year'].between(year_lo, year_hi)
    )
    return df[keep]


filtered_df = add_sidebar_filters(merged_df)
# Display Summary Statistics for Q1
# If the filters exclude every admission, the metrics below would be NaN and
# some chart helpers are not written for empty input. End this run with a
# message instead of rendering broken or failing output.
if filtered_df.empty:
    st.warning(
        "No admissions match the current filters. "
        "Adjust the sidebar selections to display results."
    )
    st.stop()

st.header("Summary Statistics")
# Create four columns for metrics
col1, col2, col3, col4 = st.columns(4)
with col1:
    total_admissions = filtered_df.shape[0]
    st.metric("Total Admissions", f"{total_admissions:,}")
with col2:
    # Mean over admission rows: a patient with several admissions contributes
    # their anchor_age once per admission.
    average_age = filtered_df['anchor_age'].mean()
    st.metric("Average Age", f"{average_age:.2f} years")
with col3:
    # These metrics are labelled "Patients", so count distinct subject_ids per
    # gender. Counting rows would count each admission separately.
    patients_by_gender = filtered_df.groupby('gender')['subject_id'].nunique()
    male_count = int(patients_by_gender.get('M', 0))
    female_count = int(patients_by_gender.get('F', 0))
    st.metric("Male Patients", f"{male_count:,}")
    st.metric("Female Patients", f"{female_count:,}")
with col4:
    mortality_rate = filtered_df['hospital_expire_flag'].mean() * 100  # Percentage
    st.metric("Mortality Rate", f"{mortality_rate:.2f}%")
st.markdown("---")
# Create Tabs for Q1 and Q2
tabs = st.tabs(["General Overview", "Potential Biases"])

# ---------------------------
# Q1: General Overview
# ---------------------------
with tabs[0]:
    st.subheader("General Feature Distribution and Outcome Metrics")

    # Charts are laid out row by row, this many per row.
    num_cols = 2

    # (section title, chart function) pairs in display order. Every chart
    # function is called with the filtered admissions frame.
    q1_plots = [
        ("Age Distribution of ICU Patients", create_histogram),
        ("Gender Distribution of ICU Patients", create_gender_bar_chart),
        ("Admission Types by Race", create_stacked_bar_admission_race),
        ("Length of Stay by Race", create_los_by_race),
        ("Correlation Heatmap of Age and LOS", create_correlation_heatmap),
        ("Admissions Over Time", create_time_series_heatmap),
    ]

    for row_start in range(0, len(q1_plots), num_cols):
        row_plots = q1_plots[row_start:row_start + num_cols]
        # zip stops at the shorter sequence, so a final partial row simply
        # leaves its trailing columns empty.
        for column, (title, draw) in zip(st.columns(num_cols), row_plots):
            with column:
                st.subheader(title)
                draw(filtered_df)
# ---------------------------
# Q2: Potential Biases
# ---------------------------
with tabs[1]:
    st.subheader("Analyzing Potential Biases Across Demographics")

    # Charts are laid out row by row, this many per row.
    num_cols = 2

    # (section title, chart function) pairs in display order. Every chart
    # function is called with the filtered admissions frame.
    q2_plots = [
        ("Mortality Rate by Race", create_mortality_by_race),
        ("Mortality Rate by Gender", create_mortality_by_gender),
        ("Mortality Rate by Age Group", create_mortality_by_age_group),
        ("Age Distribution by Race and Mortality", create_violin_age_race_mortality),
        ("Heatmap: Race & Gender vs. Mortality", create_heatmap_race_gender_mortality),
        ("Parallel Coordinates Plot of Demographics and Outcomes", create_parallel_coordinates),
        ("Treemap of Race and Mortality", create_treemap_race_mortality),
        ("Sankey Diagram: Race to Mortality Outcomes", create_sankey_race_mortality),
    ]

    for row_start in range(0, len(q2_plots), num_cols):
        row_plots = q2_plots[row_start:row_start + num_cols]
        # zip stops at the shorter sequence, so a final partial row simply
        # leaves its trailing columns empty.
        for column, (title, draw) in zip(st.columns(num_cols), row_plots):
            with column:
                st.subheader(title)
                draw(filtered_df)
# Footer
# Each credit line ends with an explicit Markdown hard line break ("  \n").
# Without it, consecutive lines are joined into a single paragraph when
# rendered. Keeping the break inside the string literal means editors that
# strip trailing whitespace cannot remove it.
st.markdown(
    "\n---\n"
    "**Data Source:** MIMIC-IV Dataset  \n"
    "**Project:** Investigating Biases in ICU Patient Data  \n"
    "**Developed with:** Streamlit, Python\n"
)