Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import plotly.express as px | |
import plotly.graph_objects as go | |
# Plot Function Definitions
def create_gender_pie_chart(df):
    """Render a donut-style pie chart of the gender distribution.

    Counts the rows of ``df`` per value of the ``gender`` column, draws the
    counts as a Plotly Express pie chart with a 30% hole, and renders it in
    the current Streamlit container at full container width.

    Args:
        df: DataFrame with a ``gender`` column.
    """
    gender_counts = df['gender'].value_counts().reset_index()
    # The column names produced by value_counts().reset_index() differ between
    # pandas versions, so rename them positionally.
    gender_counts.columns = ['Gender', 'Count']
    fig_gender = px.pie(
        gender_counts,
        names='Gender',
        values='Count',
        hover_data=['Count'],
        hole=0.3,
    )
    st.plotly_chart(fig_gender, use_container_width=True)
def create_race_pie_chart(df):
    """Draw a donut chart showing how many rows fall into each race group.

    Args:
        df: DataFrame containing a ``race`` column.
    """
    race_table = (
        df['race']
        .value_counts()
        .rename_axis('Race Type')
        .reset_index(name='Count')
    )
    figure = px.pie(
        race_table,
        names='Race Type',
        values='Count',
        hover_data=['Count'],
        hole=0.3,
    )
    st.plotly_chart(figure, use_container_width=True)
def create_insurance_pie_chart(df):
    """Show the share of each insurance type as a Plotly donut chart.

    Args:
        df: DataFrame containing an ``insurance`` column.
    """
    tally = df['insurance'].value_counts()
    insurance_table = pd.DataFrame({
        'Insurance Type': tally.index,
        'Count': tally.values,
    })
    st.plotly_chart(
        px.pie(
            insurance_table,
            names='Insurance Type',
            values='Count',
            hover_data=['Count'],
            hole=0.3,
        ),
        use_container_width=True,
    )
def create_mortality_pie_chart(df):
    """Render a matplotlib pie chart of survived vs. died admissions.

    The "Died" slice is the sum of ``hospital_expire_flag`` and the
    "Survived" slice is the remaining row count. This assumes the flag is
    0/1 (TODO confirm against the data dictionary).

    Args:
        df: DataFrame with a ``hospital_expire_flag`` column.
    """
    total_admissions = df.shape[0]
    deaths = df['hospital_expire_flag'].sum()
    sizes = [total_admissions - deaths, deaths]
    labels = ['Survived', 'Died']
    colors = ['#66b3ff', '#ff6666']
    explode = (0.1, 0)

    # Draw on a figure owned by this call. Drawing on plt.gcf() reuses
    # whatever figure pyplot last made current. pyplot state survives
    # Streamlit reruns, so on a rerun that can be a figure left over from
    # another chart, and the pie would be painted on top of it.
    fig, ax = plt.subplots()
    ax.pie(sizes, explode=explode, labels=labels, colors=colors,
           autopct='%1.1f%%', startangle=140, textprops={'fontsize': 14})
    ax.axis('equal')
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot retains every figure until it is closed; close it so reruns do
    # not accumulate open figures.
    plt.close(fig)
def create_admission_type_bar_chart(df):
    """Plot one horizontal bar per admission type, sized by its row count.

    Args:
        df: DataFrame containing an ``admission_type`` column.
    """
    counts = df['admission_type'].value_counts()
    bar_data = counts.reset_index()
    bar_data.columns = ['Admission Type', 'Count']
    axis_labels = {
        'Count': 'Number of Admissions',
        'Admission Type': 'Admission Type',
    }
    chart = px.bar(
        bar_data,
        x='Count',
        y='Admission Type',
        color='Admission Type',
        labels=axis_labels,
        hover_data=['Count'],
    )
    st.plotly_chart(chart, use_container_width=True)
def create_time_series_heatmap(df):
    """Creates an admissions-over-time heatmap (year rows x month columns).

    Rows are counted per (``admission_year``, ``admission_month``) pair and
    drawn with ``px.density_heatmap``. Months are ordered January..December
    and the year axis is reversed so the earliest year is at the top.
    The caller's DataFrame is not modified.

    Args:
        df: DataFrame with ``admission_year`` and ``admission_month``
            (English month names) columns.
    """
    month_order = ['January', 'February', 'March', 'April', 'May', 'June',
                   'July', 'August', 'September', 'October', 'November', 'December']
    # Build the ordered month categorical on a small local frame instead of
    # assigning it back into ``df``. The caller passes a boolean-filtered
    # slice, and writing a column into it triggers SettingWithCopyWarning and
    # mutates the caller's frame.
    plot_df = df[['admission_year']].assign(
        admission_month=pd.Categorical(df['admission_month'], categories=month_order, ordered=True)
    )
    # observed=False keeps the long-standing default (unobserved year/month
    # combinations appear with count 0) and silences the pandas FutureWarning
    # about the default changing.
    heatmap_df = (
        plot_df.groupby(['admission_year', 'admission_month'], observed=False)
        .size()
        .reset_index(name='counts')
    )
    fig = px.density_heatmap(
        heatmap_df,
        x='admission_month',
        y='admission_year',
        z='counts',
        histfunc='sum',
        labels={'counts': 'Number of Admissions', 'admission_month': 'Admission Month', 'admission_year': 'Admission Year'},
        color_continuous_scale='rdbu',
    )
    fig.update_xaxes(categoryorder='array', categoryarray=month_order)
    fig.update_layout(yaxis=dict(autorange='reversed'))
    # Plotly Express binds continuous colour to the shared layout coloraxis,
    # which makes trace-level ``colorbar`` settings ignored. Set the title on
    # the coloraxis colorbar instead.
    fig.update_layout(coloraxis_colorbar=dict(title='Admissions'))
    st.plotly_chart(fig, use_container_width=True)
# def create_stacked_bar_admission_race(df): | |
# """Creates a stacked bar chart for Admission Types by Race.""" | |
# admission_race = df.groupby(['race', 'admission_type']).size().unstack(fill_value=0) | |
# admission_race_percent = admission_race.div(admission_race.sum(axis=1), axis=0) * 100 | |
# admission_race_percent.plot(kind='bar', stacked=True, figsize=(8, 6), colormap='tab20') | |
# plt.xlabel("Race") | |
# plt.ylabel("Percentage of Admission Types") | |
# plt.legend(title='Admission Type', bbox_to_anchor=(1.05, 1), loc='upper left') | |
# plt.tight_layout() | |
# st.pyplot(plt.gcf()) | |
# def create_los_by_race(df): | |
# """Creates a box plot for Length of Stay by Race.""" | |
# fig, ax = plt.subplots(figsize=(6, 4)) | |
# sns.boxplot(data=df, x='race', y='los', palette='Pastel1', ax=ax) | |
# ax.set_xlabel("Race") | |
# ax.set_ylabel("Length of Stay (Days)") | |
# ax.set_xticklabels(ax.get_xticklabels(), rotation=45) | |
# plt.tight_layout() | |
# st.pyplot(fig) | |
# def create_correlation_heatmap(df): | |
# """Creates a correlation heatmap for numerical features.""" | |
# numerical_features = df[['anchor_age', 'los']] | |
# corr_matrix = numerical_features.corr() | |
# fig, ax = plt.subplots(figsize=(3.5, 3)) | |
# sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', fmt=".2f", ax=ax) | |
# plt.tight_layout() | |
# st.pyplot(fig) | |
def create_age_distribution_by_gender(df):
    """Plot a 30-bin histogram of ``anchor_age`` with KDE, coloured by gender.

    Args:
        df: DataFrame with ``anchor_age`` and ``gender`` columns.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.histplot(data=df, x='anchor_age', bins=30,
                 kde=True, palette='bright', hue='gender', ax=ax)
    ax.set_xlabel('Age', fontsize=16)
    ax.set_ylabel('Number of Admissions', fontsize=16)
    ax.tick_params(axis='both', labelsize=16)
    fig.tight_layout()
    st.pyplot(fig)
    # Figures created through pyplot stay open until closed. Without this,
    # every Streamlit rerun leaks one figure.
    plt.close(fig)
def create_age_distribution_by_admission_type(df):
    """Draw a boxen plot of ``anchor_age`` for each ``admission_type``.

    Args:
        df: DataFrame with ``admission_type`` and ``anchor_age`` columns.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.boxenplot(data=df, x='admission_type',
                  y='anchor_age', palette='Set3', ax=ax)
    ax.set_xlabel('Admission Type', fontsize=16)
    ax.set_ylabel('Age', fontsize=16)
    ax.tick_params(axis='x', labelsize=16, labelrotation=45)
    ax.tick_params(axis='y', labelsize=16)
    fig.tight_layout()
    st.pyplot(fig)
    # Close the pyplot figure so Streamlit reruns do not accumulate figures.
    plt.close(fig)
def create_mortality_by_race(df):
    """Creates a bar chart for Mortality Rate by Race.

    The rate per race is the mean of ``hospital_expire_flag`` times 100,
    which is a percentage if the flag is 0/1 (assumed; TODO confirm).

    Args:
        df: DataFrame with ``race`` and ``hospital_expire_flag`` columns.
    """
    mortality_race = df.groupby('race')['hospital_expire_flag'].mean().reset_index()
    mortality_race['mortality_rate'] = mortality_race['hospital_expire_flag'] * 100
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(data=mortality_race, x='race', y='mortality_rate', palette='Set2', ax=ax)
    ax.set_xlabel("Race")
    ax.set_ylabel("Mortality Rate (%)")
    # Rotate the tick labels in place rather than round-tripping their text
    # through get_xticklabels()/set_xticklabels().
    ax.tick_params(axis='x', labelrotation=45)
    fig.tight_layout()
    st.pyplot(fig)
    # Close the pyplot figure so Streamlit reruns do not accumulate figures.
    plt.close(fig)
def create_mortality_by_gender(df):
    """Creates a bar chart for Mortality Rate by Gender.

    The rate per gender is the mean of ``hospital_expire_flag`` times 100,
    which is a percentage if the flag is 0/1 (assumed; TODO confirm).

    Args:
        df: DataFrame with ``gender`` and ``hospital_expire_flag`` columns.
    """
    mortality_gender = df.groupby('gender')['hospital_expire_flag'].mean().reset_index()
    mortality_gender['mortality_rate'] = mortality_gender['hospital_expire_flag'] * 100
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(data=mortality_gender, x='gender', y='mortality_rate', palette='Set3', ax=ax)
    ax.set_xlabel("Gender")
    ax.set_ylabel("Mortality Rate (%)")
    fig.tight_layout()
    st.pyplot(fig)
    # Close the pyplot figure so Streamlit reruns do not accumulate figures.
    plt.close(fig)
def create_mortality_by_age_group(df):
    """Creates a bar chart for Mortality Rate by Age Group.

    ``anchor_age`` is binned into left-closed, right-open intervals
    [0, 30), [30, 50), [50, 70), [70, 90), [90, 120). The mean of
    ``hospital_expire_flag`` (times 100) is plotted per bin. Empty bins are
    kept (as NaN) so the x-axis always shows every group. The caller's
    DataFrame is not modified.

    Args:
        df: DataFrame with ``anchor_age`` and ``hospital_expire_flag`` columns.
    """
    bins = [0, 30, 50, 70, 90, 120]
    # With right=False the bins are [lo, hi), so for integer ages each label
    # is the inclusive range lo..hi-1. The previous labels ('0-30', '31-50',
    # ...) put age 30 under "31-50", age 50 under "51-70", and so on.
    labels = ['0-29', '30-49', '50-69', '70-89', '90-119']
    # Bin into a local Series instead of adding a column to ``df``. The caller
    # passes a filtered slice, and assigning into it triggers
    # SettingWithCopyWarning and mutates the caller's frame.
    age_group = pd.cut(df['anchor_age'], bins=bins, labels=labels, right=False).rename('age_group')
    # observed=False keeps empty categories, matching the pandas default this
    # code relied on, without the FutureWarning about that default changing.
    mortality_age = (
        df['hospital_expire_flag']
        .groupby(age_group, observed=False)
        .mean()
        .reset_index()
    )
    mortality_age['mortality_rate'] = mortality_age['hospital_expire_flag'] * 100
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(data=mortality_age, x='age_group', y='mortality_rate', palette='coolwarm', ax=ax)
    ax.set_xlabel("Age Group")
    ax.set_ylabel("Mortality Rate (%)")
    fig.tight_layout()
    st.pyplot(fig)
    # Close the pyplot figure so Streamlit reruns do not accumulate figures.
    plt.close(fig)
def create_violin_age_race_mortality(df):
    """Creates a violin plot for Age Distribution by Race and Mortality.

    Each race gets one split violin of ``anchor_age``. One half covers rows
    whose ``hospital_expire_flag`` is 0 ("Survived"), the other half rows
    whose flag is 1 ("Died"). The caller's DataFrame is not modified.

    Args:
        df: DataFrame with ``race``, ``anchor_age`` and
            ``hospital_expire_flag`` columns.
    """
    outcome_order = ['Survived', 'Died']
    # Plot a small local frame with readable outcome labels (the same mapping
    # as the race/mortality treemap), so the legend reads Survived/Died
    # instead of 0/1. Pinning hue_order keeps both levels defined even when
    # the current filter leaves only one outcome, which split violins expect.
    plot_df = df[['race', 'anchor_age']].assign(
        Mortality=df['hospital_expire_flag'].map({0: 'Survived', 1: 'Died'})
    )
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.violinplot(
        data=plot_df,
        x='race',
        y='anchor_age',
        hue='Mortality',
        hue_order=outcome_order,
        split=True,
        palette='Set2',
        ax=ax,
    )
    ax.set_xlabel("Race")
    ax.set_ylabel("Age")
    ax.legend(title='Mortality', loc='upper right')
    fig.tight_layout()
    st.pyplot(fig)
    # Close the pyplot figure so Streamlit reruns do not accumulate figures.
    plt.close(fig)
def create_heatmap_race_gender_mortality(df):
    """Creates a heatmap for Mortality Rate by Race and Gender.

    Each cell is the mean of ``hospital_expire_flag`` for one race x gender
    combination, times 100, annotated to one decimal place. Combinations
    with no rows come out of the pivot as NaN.

    Args:
        df: DataFrame with ``race``, ``gender`` and ``hospital_expire_flag``
            columns.
    """
    pivot_table = df.pivot_table(
        index='race',
        columns='gender',
        values='hospital_expire_flag',
        aggfunc='mean',
    ) * 100
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(pivot_table, annot=True, fmt=".1f", cmap='YlOrRd', ax=ax)
    ax.set_xlabel("Gender")
    ax.set_ylabel("Race")
    fig.tight_layout()
    st.pyplot(fig)
    # Close the pyplot figure so Streamlit reruns do not accumulate figures.
    plt.close(fig)
def create_treemap_race_mortality(df):
    """Draw a Plotly treemap that nests survival outcome inside each race.

    Rows are counted per (``race``, ``hospital_expire_flag``) pair. The flag
    values 0/1 are shown as "Survived"/"Died" in blue/red.

    Args:
        df: DataFrame with ``race`` and ``hospital_expire_flag`` columns.
    """
    outcome_labels = {0: 'Survived', 1: 'Died'}
    outcome_colors = {'Survived': '#66b3ff', 'Died': '#ff6666'}
    tree_data = (
        df.groupby(['race', 'hospital_expire_flag'])
        .size()
        .reset_index(name='counts')
    )
    tree_data['Mortality'] = tree_data['hospital_expire_flag'].map(outcome_labels)
    treemap = px.treemap(
        tree_data,
        path=['race', 'Mortality'],
        values='counts',
        color='Mortality',
        color_discrete_map=outcome_colors,
    )
    treemap.update_layout(margin={'t': 30, 'l': 0, 'r': 0, 'b': 0})
    st.plotly_chart(treemap, use_container_width=True)
# Streamlit Application
# Page-level configuration, title, intro text and the sidebar heading.
dashboard_title = "MIMIC-IV ICU Patient Data Dashboard"
st.set_page_config(
    page_title=dashboard_title,
    layout="wide",
    initial_sidebar_state="expanded",
)
st.title(dashboard_title)

intro_text = '''
Explore the general feature distribution and demographics related bias in ICU patients from the MIMIC-IV dataset. Utilize the sidebar filters to customize the data view'''
st.markdown(intro_text)

# Sidebar Filters
st.sidebar.header("Filter Data")
@st.cache_data
def load_data():
    """Load, clean and enrich the MIMIC-IV admissions and patients tables.

    Steps:
      * read ``data/admissions.feather`` and ``data/patients.feather``;
      * collapse the detailed ``race`` values into broad groups via
        ``race_map``;
      * left-join patients onto admissions on ``subject_id``;
      * drop rows missing ``anchor_age``, ``gender``, ``race`` or
        ``hospital_expire_flag``;
      * parse ``admittime``/``dischtime``/``deathtime`` (unparseable death
        times become NaT);
      * add ``los`` (whole days from admit to discharge),
        ``admission_year``, ``admission_month`` (English month name) and
        ``admittime_date``.

    Streamlit re-executes the whole script on every widget interaction, so
    the result is cached with ``st.cache_data``. Without caching, the feather
    files were re-read and re-merged on every filter change. Each call
    returns a fresh copy of the cached frame. Changes to the feather files
    on disk are not picked up until the cache is cleared.

    Race values not listed in ``race_map`` map to NaN and are therefore
    dropped together with rows that have no race.
    NOTE(review): confirm the map covers every category present in the data.

    Returns:
        pandas.DataFrame: the merged admissions rows (one per admission,
        assuming ``subject_id`` is unique in the patients table).
    """
    admissions_df = pd.read_feather('data/admissions.feather')
    patients_df = pd.read_feather('data/patients.feather')
    # diagnoses_icd_df = pd.read_csv('data/diagnoses_icd.csv')
    # pharmacy_df = pd.read_csv('data/pharmacy.csv')
    # prescriptions_df = pd.read_csv('data/prescriptions.csv')
    # d_hcpcs_df = pd.read_csv('data/d_hcpcs.csv')
    # poe_detail_df = pd.read_csv('data/poe_detail.csv')
    # provider_df = pd.read_csv('data/provider.csv')
    race_map = {"WHITE": "WHITE",
                "BLACK/AFRICAN AMERICAN": "BLACK",
                "OTHER": "OTHER",
                "UNKNOWN": "UNKNOWN",
                "HISPANIC/LATINO - PUERTO RICAN": "HISPANIC",
                "WHITE - OTHER EUROPEAN": "WHITE",
                "HISPANIC OR LATINO": "HISPANIC",
                "ASIAN": "ASIAN",
                "ASIAN - CHINESE": "ASIAN",
                "WHITE - RUSSIAN": "WHITE",
                "BLACK/CAPE VERDEAN": "BLACK",
                "HISPANIC/LATINO - DOMINICAN": "HISPANIC",
                "BLACK/CARIBBEAN ISLAND": "BLACK",
                "BLACK/AFRICAN": "BLACK",
                "PATIENT DECLINED TO ANSWER": "UNKNOWN",
                "UNABLE TO OBTAIN": "UNKNOWN",
                "PORTUGUESE": "WHITE",
                "ASIAN - SOUTH EAST ASIAN": "ASIAN",
                "HISPANIC/LATINO - GUATEMALAN": "HISPANIC",
                "ASIAN - ASIAN INDIAN": "ASIAN",
                "WHITE - EASTERN EUROPEAN": "WHITE",
                "WHITE - BRAZILIAN": "WHITE",
                "AMERICAN INDIAN/ALASKA NATIVE": "NATIVES",
                "HISPANIC/LATINO - SALVADORAN": "HISPANIC",
                "HISPANIC/LATINO - MEXICAN": "HISPANIC",
                "HISPANIC/LATINO - COLUMBIAN": "HISPANIC",
                # Group label shown in the race filter and charts; previously
                # misspelled "MULTI-ETHINIC".
                "MULTIPLE RACE/ETHNICITY": "MULTI-ETHNIC",
                "HISPANIC/LATINO - HONDURAN": "HISPANIC",
                "ASIAN - KOREAN": "ASIAN",
                "SOUTH AMERICAN": "HISPANIC",
                "HISPANIC/LATINO - CUBAN": "HISPANIC",
                "HISPANIC/LATINO - CENTRAL AMERICAN": "HISPANIC",
                "NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER": "NATIVES"}
    admissions_df['race'] = admissions_df['race'].map(race_map)
    merged_df = pd.merge(admissions_df, patients_df, on='subject_id', how='left')
    merged_df = merged_df.dropna(subset=['anchor_age', 'gender', 'race', 'hospital_expire_flag'])
    merged_df['admittime'] = pd.to_datetime(merged_df['admittime'])
    merged_df['dischtime'] = pd.to_datetime(merged_df['dischtime'])
    merged_df['deathtime'] = pd.to_datetime(merged_df['deathtime'], errors='coerce')
    # Create derived features
    merged_df['los'] = (merged_df['dischtime'] - merged_df['admittime']).dt.days
    merged_df['admission_year'] = merged_df['admittime'].dt.year
    merged_df['admission_month'] = merged_df['admittime'].dt.month_name()
    merged_df['admittime_date'] = merged_df['admittime'].dt.date
    return merged_df


merged_df = load_data()
# Sidebar Filters Function
def add_sidebar_filters(df):
    """Render the sidebar filter widgets and return the matching rows of ``df``.

    Shows one multiselect each for admission type, insurance, gender and
    race, in that order. Each lists the sorted unique values and starts with
    all of them selected. These are followed by an admission-year range
    slider spanning the data; both ends of the range are inclusive.

    Args:
        df: DataFrame with ``admission_type``, ``insurance``, ``gender``,
            ``race`` and ``admission_year`` columns.

    Returns:
        The rows of ``df`` that satisfy every selected filter.
    """
    categorical_filters = [
        ('admission_type', "Select Admission Type(s):"),
        ('insurance', "Select Insurance Type(s):"),
        ('gender', "Select Gender(s):"),
        ('race', "Select Race(s):"),
    ]
    keep = pd.Series(True, index=df.index)
    for column, prompt in categorical_filters:
        choices = sorted(df[column].unique())
        picked = st.sidebar.multiselect(prompt, options=choices, default=choices)
        keep &= df[column].isin(picked)

    first_year = int(df['admission_year'].min())
    last_year = int(df['admission_year'].max())
    year_lo, year_hi = st.sidebar.slider(
        "Select Admission Year Range:",
        min_value=first_year,
        max_value=last_year,
        value=(first_year, last_year),
    )
    keep &= df['admission_year'].between(year_lo, year_hi)

    return df[keep]


filtered_df = add_sidebar_filters(merged_df)
# Display Summary Statistics for Q1
st.header("Summary Statistics")

if filtered_df.empty:
    # The sidebar multiselects can be cleared, which leaves no rows. The
    # mean-based metrics below would then show "nan", and the charts further
    # down may fail on empty data, so explain and end this script run here.
    st.warning("No admissions match the current filters. Adjust the sidebar selections to see results.")
    st.stop()

# Create four columns for metrics
col1, col2, col3, col4 = st.columns(4)
with col1:
    total_admissions = filtered_df.shape[0]
    st.metric("Total Admissions", f"{total_admissions:,}")
with col2:
    average_age = filtered_df['anchor_age'].mean()
    st.metric("Average Age", f"{average_age:.2f} years")
with col3:
    # NOTE(review): these are row (admission) counts per gender, not distinct
    # patients; a patient with several admissions is counted several times.
    gender_counts = filtered_df['gender'].value_counts()
    male_count = gender_counts.get('M', 0)
    female_count = gender_counts.get('F', 0)
    st.metric("Male Patients", f"{male_count:,}")
    st.metric("Female Patients", f"{female_count:,}")
with col4:
    mortality_rate = filtered_df['hospital_expire_flag'].mean() * 100  # Percentage
    st.metric("Mortality Rate", f"{mortality_rate:.2f}%")
st.markdown("---")
# Create Tabs for Q1 and Q2
tabs = st.tabs(["General Overview", "Potential Biases"])

# Q1: General Overview
with tabs[0]:
    st.subheader("General Feature Distribution and Outcome Metrics")

    # The pie charts are laid out two per row; the wider charts get a full
    # row each.
    overview_pairs = [
        ("Gender Distribution", create_gender_pie_chart),
        ("Race Distribution", create_race_pie_chart),
        ("Insurance Type Distribution", create_insurance_pie_chart),
        ("Mortality Rate of ICU Patients", create_mortality_pie_chart),
    ]
    overview_full_width = [
        ("Admission Type Count", create_admission_type_bar_chart),
        ("Admissions Over Time", create_time_series_heatmap),
    ]
    for per_row, charts in ((2, overview_pairs), (1, overview_full_width)):
        for start in range(0, len(charts), per_row):
            row_charts = charts[start:start + per_row]
            # zip stops at the shorter sequence, so a short final row leaves
            # its trailing columns empty.
            for column, (heading, draw) in zip(st.columns(per_row), row_charts):
                with column:
                    st.subheader(heading)
                    draw(filtered_df)
# Q2: Potential Biases
with tabs[1]:
    st.subheader("Analyzing Potential Biases Across Demographics")

    bias_charts = (
        ("Age Distribution of ICU Patients", create_age_distribution_by_gender),
        ("Boxen Plot of Age Distribution by Admission Type", create_age_distribution_by_admission_type),
        ("Mortality Rate by Race", create_mortality_by_race),
        ("Mortality Rate by Gender", create_mortality_by_gender),
        ("Mortality Rate by Age Group", create_mortality_by_age_group),
        ("Age Distribution by Race and Mortality", create_violin_age_race_mortality),
        ("Heatmap: Race & Gender vs. Mortality", create_heatmap_race_gender_mortality),
        ("Treemap of Race and Mortality", create_treemap_race_mortality),
    )
    charts_per_row = 2
    row_columns = None
    for position, (heading, draw) in enumerate(bias_charts):
        slot = position % charts_per_row
        if slot == 0:
            # First chart of a row: open a fresh set of columns.
            row_columns = st.columns(charts_per_row)
        with row_columns[slot]:
            st.subheader(heading)
            draw(filtered_df)
# Footer
footer_lines = [
    "**Data Source:** MIMIC-IV Dataset",
    "**Project:** Fairness in EHR Data",
    "**Developed with:** Streamlit, Python",
    "**Q3 Visuals:** https://idyllic-cucurucho-672fc1.netlify.app/",
]
# Markdown joins lines separated by a single newline into one paragraph.
# Join the entries with an explicit hard line break (two spaces + newline)
# so each one renders on its own line under the horizontal rule.
st.markdown("\n---\n" + "  \n".join(footer_lines) + "\n")