# Author: BiswajitPadhi99
# Commit 1c4c601: "Add Q3 Visuals link"
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
# Plot Function Definitions
def create_gender_pie_chart(df):
    """Render a donut (pie) chart of admission counts per gender.

    Counts the values of ``df['gender']`` and draws them with Plotly into the
    current Streamlit container.
    """
    gender_counts = (
        df['gender']
        .value_counts()
        .rename_axis('Gender')
        .reset_index(name='Count')
    )
    fig_gender = px.pie(
        gender_counts,
        names='Gender',
        values='Count',
        hover_data=['Count'],
        hole=0.3,  # non-zero hole turns the pie into a donut
    )
    st.plotly_chart(fig_gender, use_container_width=True)
def create_race_pie_chart(df):
    """Render a donut chart of admission counts per race category.

    Tallies ``df['race']`` and draws the result with Plotly into the current
    Streamlit container.
    """
    label_col, count_col = 'Race Type', 'Count'
    counts = df['race'].value_counts().rename_axis(label_col).reset_index(name=count_col)
    chart = px.pie(
        counts,
        names=label_col,
        values=count_col,
        hover_data=[count_col],
        hole=0.3,
    )
    st.plotly_chart(chart, use_container_width=True)
def create_insurance_pie_chart(df):
    """Render a donut chart of admission counts per insurance type.

    Tallies ``df['insurance']`` and draws the result with Plotly into the
    current Streamlit container.
    """
    label_col, count_col = 'Insurance Type', 'Count'
    counts = df['insurance'].value_counts().rename_axis(label_col).reset_index(name=count_col)
    chart = px.pie(
        counts,
        names=label_col,
        values=count_col,
        hover_data=[count_col],
        hole=0.3,
    )
    st.plotly_chart(chart, use_container_width=True)
def create_mortality_pie_chart(df):
    """Render a matplotlib pie chart of in-hospital survival vs. death.

    Deaths are the sum of ``df['hospital_expire_flag']``; survivors are the
    remaining rows of ``df``.
    """
    total_admissions = df.shape[0]
    deaths = df['hospital_expire_flag'].sum()
    sizes = [total_admissions - deaths, deaths]

    # Draw on a dedicated figure rather than pyplot's implicit "current"
    # figure: pyplot state survives Streamlit reruns, so the current figure
    # can be one left open by a different chart.
    fig, ax = plt.subplots()
    ax.pie(
        sizes,
        explode=(0.1, 0),
        labels=['Survived', 'Died'],
        colors=['#66b3ff', '#ff6666'],
        autopct='%1.1f%%',
        startangle=140,
        textprops={'fontsize': 14},
    )
    ax.axis('equal')  # keep the pie circular
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot keeps figures alive until closed; without this every rerun
    # leaks one figure.
    plt.close(fig)
def create_admission_type_bar_chart(df):
    """Render a horizontal bar chart of admission counts per admission type.

    Tallies ``df['admission_type']``; each bar is coloured by its type and
    drawn with Plotly into the current Streamlit container.
    """
    counts = (
        df['admission_type']
        .value_counts()
        .rename_axis('Admission Type')
        .reset_index(name='Count')
    )
    axis_labels = {'Count': 'Number of Admissions', 'Admission Type': 'Admission Type'}
    chart = px.bar(
        counts,
        x='Count',
        y='Admission Type',
        color='Admission Type',
        labels=axis_labels,
        hover_data=['Count'],
    )
    st.plotly_chart(chart, use_container_width=True)
def create_time_series_heatmap(df):
    """Render a year-by-month heatmap of admission counts.

    Uses ``df['admission_year']`` and ``df['admission_month']`` (English month
    names). ``df`` itself is left unmodified.
    """
    month_order = ['January', 'February', 'March', 'April', 'May', 'June',
                   'July', 'August', 'September', 'October', 'November', 'December']
    month_dtype = pd.CategoricalDtype(categories=month_order, ordered=True)
    # Work on a derived frame instead of assigning into the caller's
    # DataFrame (which is typically a filtered slice).
    months = df['admission_month'].astype(month_dtype)
    # observed=False keeps every (year, month) combination, so months without
    # admissions appear as zero-count cells; it is passed explicitly because
    # pandas is changing the default for categorical groupers.
    heatmap_df = (
        df.assign(admission_month=months)
        .groupby(['admission_year', 'admission_month'], observed=False)
        .size()
        .reset_index(name='counts')
    )
    fig = px.density_heatmap(
        heatmap_df,
        x='admission_month',
        y='admission_year',
        z='counts',
        histfunc='sum',
        labels={'counts': 'Number of Admissions', 'admission_month': 'Admission Month', 'admission_year': 'Admission Year'},
        color_continuous_scale='rdbu'
    )
    fig.update_xaxes(categoryorder='array', categoryarray=month_order)
    fig.update_layout(
        yaxis=dict(autorange='reversed'),  # earliest year at the top
        # px.density_heatmap colours its trace through the shared layout
        # coloraxis, so the colorbar title must be set there rather than on
        # the trace.
        coloraxis_colorbar=dict(title='Admissions'),
    )
    st.plotly_chart(fig, use_container_width=True)
# def create_stacked_bar_admission_race(df):
# """Creates a stacked bar chart for Admission Types by Race."""
# admission_race = df.groupby(['race', 'admission_type']).size().unstack(fill_value=0)
# admission_race_percent = admission_race.div(admission_race.sum(axis=1), axis=0) * 100
# admission_race_percent.plot(kind='bar', stacked=True, figsize=(8, 6), colormap='tab20')
# plt.xlabel("Race")
# plt.ylabel("Percentage of Admission Types")
# plt.legend(title='Admission Type', bbox_to_anchor=(1.05, 1), loc='upper left')
# plt.tight_layout()
# st.pyplot(plt.gcf())
# def create_los_by_race(df):
# """Creates a box plot for Length of Stay by Race."""
# fig, ax = plt.subplots(figsize=(6, 4))
# sns.boxplot(data=df, x='race', y='los', palette='Pastel1', ax=ax)
# ax.set_xlabel("Race")
# ax.set_ylabel("Length of Stay (Days)")
# ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
# plt.tight_layout()
# st.pyplot(fig)
# def create_correlation_heatmap(df):
# """Creates a correlation heatmap for numerical features."""
# numerical_features = df[['anchor_age', 'los']]
# corr_matrix = numerical_features.corr()
# fig, ax = plt.subplots(figsize=(3.5, 3))
# sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', fmt=".2f", ax=ax)
# plt.tight_layout()
# st.pyplot(fig)
def create_age_distribution_by_gender(df):
    """Render a 30-bin histogram (with KDE) of ``anchor_age``, coloured by gender."""
    # Use an explicit figure/axes instead of pyplot's global current figure,
    # whose state persists across Streamlit reruns.
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.histplot(data=df, x='anchor_age', bins=30,
                 kde=True, palette='bright', hue='gender', ax=ax)
    ax.set_xlabel('Age', fontsize=16)
    ax.set_ylabel('Number of Admissions', fontsize=16)
    ax.tick_params(axis='both', labelsize=16)
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot keeps figures alive until closed; avoid leaking one per rerun.
    plt.close(fig)
def create_age_distribution_by_admission_type(df):
    """Render a boxen plot of ``anchor_age`` for each ``admission_type``."""
    # Explicit figure/axes instead of pyplot's global current figure.
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.boxenplot(data=df, x='admission_type',
                  y='anchor_age', palette='Set3', ax=ax)
    ax.set_xlabel('Admission Type', fontsize=16)
    ax.set_ylabel('Age', fontsize=16)
    # Rotate the (long) admission-type names so they do not overlap.
    ax.tick_params(axis='x', labelsize=16, labelrotation=45)
    ax.tick_params(axis='y', labelsize=16)
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot keeps figures alive until closed; avoid leaking one per rerun.
    plt.close(fig)
def create_mortality_by_race(df):
    """Render a bar chart of in-hospital mortality rate (%) per race.

    The rate is the mean of ``hospital_expire_flag`` within each ``race``
    group, multiplied by 100.
    """
    mortality_race = df.groupby('race')['hospital_expire_flag'].mean().reset_index()
    mortality_race['mortality_rate'] = mortality_race['hospital_expire_flag'] * 100
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(data=mortality_race, x='race', y='mortality_rate', palette='Set2', ax=ax)
    ax.set_xlabel("Race")
    ax.set_ylabel("Mortality Rate (%)")
    # Rotate existing tick labels in place; re-setting them via
    # set_xticklabels(get_xticklabels()) is fragile with categorical axes.
    ax.tick_params(axis='x', labelrotation=45)
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot keeps figures alive until closed; avoid leaking one per rerun.
    plt.close(fig)
def create_mortality_by_gender(df):
    """Render a bar chart of in-hospital mortality rate (%) per gender.

    The rate is the mean of ``hospital_expire_flag`` within each ``gender``
    group, multiplied by 100.
    """
    mortality_gender = df.groupby('gender')['hospital_expire_flag'].mean().reset_index()
    mortality_gender['mortality_rate'] = mortality_gender['hospital_expire_flag'] * 100
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(data=mortality_gender, x='gender', y='mortality_rate', palette='Set3', ax=ax)
    ax.set_xlabel("Gender")
    ax.set_ylabel("Mortality Rate (%)")
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot keeps figures alive until closed; avoid leaking one per rerun.
    plt.close(fig)
def create_mortality_by_age_group(df):
    """Render a bar chart of in-hospital mortality rate (%) per age group.

    ``anchor_age`` is bucketed into 0-30, 31-50, 51-70, 71-90 and 91-120
    (inclusive on both ends for integer ages). The rate is the mean of
    ``hospital_expire_flag`` per bucket, multiplied by 100. ``df`` is not
    modified.
    """
    bins = [0, 30, 50, 70, 90, 120]
    labels = ['0-30', '31-50', '51-70', '71-90', '91-120']
    # Right-closed intervals [0, 30], (30, 50], (50, 70], ... so each integer
    # age falls into the bucket its label names (e.g. 30 -> '0-30',
    # 31 -> '31-50'). include_lowest keeps age 0 in the first bucket.
    age_group = pd.cut(
        df['anchor_age'], bins=bins, labels=labels, right=True, include_lowest=True
    ).rename('age_group')
    # observed=False keeps empty buckets on the axis; passed explicitly
    # because pandas is changing the default for categorical groupers.
    mortality_age = (
        df['hospital_expire_flag']
        .groupby(age_group, observed=False)
        .mean()
        .mul(100)
        .rename('mortality_rate')
        .reset_index()
    )
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(data=mortality_age, x='age_group', y='mortality_rate', palette='coolwarm', ax=ax)
    ax.set_xlabel("Age Group")
    ax.set_ylabel("Mortality Rate (%)")
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot keeps figures alive until closed; avoid leaking one per rerun.
    plt.close(fig)
def create_violin_age_race_mortality(df):
    """Render split violin plots of ``anchor_age`` per race.

    Each violin is split by ``hospital_expire_flag``; the legend is titled
    "Mortality" and shows the flag values.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.violinplot(
        data=df,
        x='race',
        y='anchor_age',
        hue='hospital_expire_flag',
        split=True,
        palette='Set2',
        ax=ax
    )
    ax.set_xlabel("Race")
    ax.set_ylabel("Age")
    ax.legend(title='Mortality', loc='upper right')
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot keeps figures alive until closed; avoid leaking one per rerun.
    plt.close(fig)
def create_heatmap_race_gender_mortality(df):
    """Render an annotated heatmap of mortality rate (%) by race x gender.

    Each cell is the mean of ``hospital_expire_flag`` for that
    (race, gender) pair, multiplied by 100 and shown to one decimal place.
    """
    pivot_table = df.pivot_table(
        index='race',
        columns='gender',
        values='hospital_expire_flag',
        aggfunc='mean'
    ) * 100
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(pivot_table, annot=True, fmt=".1f", cmap='YlOrRd', ax=ax)
    ax.set_xlabel("Gender")
    ax.set_ylabel("Race")
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot keeps figures alive until closed; avoid leaking one per rerun.
    plt.close(fig)
def create_treemap_race_mortality(df):
    """Render a Plotly treemap of admissions nested as race -> outcome.

    Leaf size is the number of admissions; ``hospital_expire_flag`` values 0/1
    are shown as 'Survived'/'Died' with fixed colours.
    """
    outcome_names = {0: 'Survived', 1: 'Died'}
    outcome_colors = {'Survived': '#66b3ff', 'Died': '#ff6666'}

    grouped = df.groupby(['race', 'hospital_expire_flag']).size().reset_index(name='counts')
    grouped['Mortality'] = grouped['hospital_expire_flag'].map(outcome_names)

    chart = px.treemap(
        grouped,
        path=['race', 'Mortality'],
        values='counts',
        color='Mortality',
        color_discrete_map=outcome_colors,
    )
    chart.update_layout(margin={'t': 30, 'l': 0, 'r': 0, 'b': 0})
    st.plotly_chart(chart, use_container_width=True)
# Streamlit Application
# Configure the page before anything else is rendered.
DASHBOARD_TITLE = "MIMIC-IV ICU Patient Data Dashboard"
st.set_page_config(
    page_title=DASHBOARD_TITLE,
    layout="wide",
    initial_sidebar_state="expanded",
)
st.title(DASHBOARD_TITLE)
st.markdown(
    "\n"
    "Explore the general feature distribution and demographics related bias in ICU "
    "patients from the MIMIC-IV dataset. Utilize the sidebar filters to customize "
    "the data view"
)

# Sidebar Filters
st.sidebar.header("Filter Data")
@st.cache_data
def load_data():
    """Load MIMIC-IV admissions joined with patient demographics.

    Steps:
      * read ``data/admissions.feather`` and ``data/patients.feather``;
      * collapse the detailed ``race`` strings into broad groups via
        ``race_map`` (values missing from the map become NaN);
      * left-join patients onto admissions on ``subject_id``;
      * drop rows missing ``anchor_age``, ``gender``, ``race`` or
        ``hospital_expire_flag`` (this also removes unmapped races);
      * parse ``admittime`` / ``dischtime`` / ``deathtime`` (unparseable
        ``deathtime`` values become NaT);
      * derive ``los`` (whole days from admission to discharge),
        ``admission_year``, ``admission_month`` (month name) and
        ``admittime_date``.

    The result is cached by ``st.cache_data`` so the files are not re-read on
    every rerun.

    Returns:
        pandas.DataFrame: one row per admission with the columns above.
    """
    admissions_df = pd.read_feather('data/admissions.feather')
    patients_df = pd.read_feather('data/patients.feather')
    # diagnoses_icd_df = pd.read_csv('data/diagnoses_icd.csv')
    # pharmacy_df = pd.read_csv('data/pharmacy.csv')
    # prescriptions_df = pd.read_csv('data/prescriptions.csv')
    # d_hcpcs_df = pd.read_csv('data/d_hcpcs.csv')
    # poe_detail_df = pd.read_csv('data/poe_detail.csv')
    # provider_df = pd.read_csv('data/provider.csv')

    # Kept inside the function so that edits to the mapping change the
    # function source and therefore invalidate the st.cache_data entry.
    race_map = {"WHITE": "WHITE",
                "BLACK/AFRICAN AMERICAN": "BLACK",
                "OTHER": "OTHER",
                "UNKNOWN": "UNKNOWN",
                "HISPANIC/LATINO - PUERTO RICAN": "HISPANIC",
                "WHITE - OTHER EUROPEAN": "WHITE",
                "HISPANIC OR LATINO": "HISPANIC",
                "ASIAN": "ASIAN",
                "ASIAN - CHINESE": "ASIAN",
                "WHITE - RUSSIAN": "WHITE",
                "BLACK/CAPE VERDEAN": "BLACK",
                "HISPANIC/LATINO - DOMINICAN": "HISPANIC",
                "BLACK/CARIBBEAN ISLAND": "BLACK",
                "BLACK/AFRICAN": "BLACK",
                "PATIENT DECLINED TO ANSWER": "UNKNOWN",
                "UNABLE TO OBTAIN": "UNKNOWN",
                "PORTUGUESE": "WHITE",
                "ASIAN - SOUTH EAST ASIAN": "ASIAN",
                "HISPANIC/LATINO - GUATEMALAN": "HISPANIC",
                "ASIAN - ASIAN INDIAN": "ASIAN",
                "WHITE - EASTERN EUROPEAN": "WHITE",
                "WHITE - BRAZILIAN": "WHITE",
                "AMERICAN INDIAN/ALASKA NATIVE": "NATIVES",
                "HISPANIC/LATINO - SALVADORAN": "HISPANIC",
                "HISPANIC/LATINO - MEXICAN": "HISPANIC",
                "HISPANIC/LATINO - COLUMBIAN": "HISPANIC",
                "MULTIPLE RACE/ETHNICITY": "MULTI-ETHNIC",
                "HISPANIC/LATINO - HONDURAN": "HISPANIC",
                "ASIAN - KOREAN": "ASIAN",
                "SOUTH AMERICAN": "HISPANIC",
                "HISPANIC/LATINO - CUBAN": "HISPANIC",
                "HISPANIC/LATINO - CENTRAL AMERICAN": "HISPANIC",
                "NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER": "NATIVES"}
    admissions_df['race'] = admissions_df['race'].map(race_map)

    merged_df = pd.merge(admissions_df, patients_df, on='subject_id', how='left')
    merged_df = merged_df.dropna(subset=['anchor_age', 'gender', 'race', 'hospital_expire_flag'])
    merged_df['admittime'] = pd.to_datetime(merged_df['admittime'])
    merged_df['dischtime'] = pd.to_datetime(merged_df['dischtime'])
    merged_df['deathtime'] = pd.to_datetime(merged_df['deathtime'], errors='coerce')

    # Create derived features
    merged_df['los'] = (merged_df['dischtime'] - merged_df['admittime']).dt.days
    merged_df['admission_year'] = merged_df['admittime'].dt.year
    merged_df['admission_month'] = merged_df['admittime'].dt.month_name()
    merged_df['admittime_date'] = merged_df['admittime'].dt.date
    return merged_df


merged_df = load_data()
# Sidebar Filters Function
def add_sidebar_filters(df):
    """Render the sidebar filter widgets and return the matching rows of ``df``.

    Four multiselects (admission type, insurance, gender, race) each default
    to every available value, and the year slider defaults to the full
    ``admission_year`` range, so the initial selection is unfiltered. Rows with
    a missing value in a filtered categorical column never match.

    Returns:
        pandas.DataFrame: the subset of ``df`` matching all selections.
    """

    def _multiselect_all(label, column):
        # Offer the distinct non-missing values, all pre-selected. NaN is
        # dropped first because sorted() cannot compare a float NaN with
        # strings and would raise TypeError.
        options = sorted(df[column].dropna().unique())
        return st.sidebar.multiselect(label, options=options, default=options)

    selected_admission_types = _multiselect_all("Select Admission Type(s):", 'admission_type')
    selected_insurance_types = _multiselect_all("Select Insurance Type(s):", 'insurance')
    selected_genders = _multiselect_all("Select Gender(s):", 'gender')
    selected_races = _multiselect_all("Select Race(s):", 'race')

    # Year Range
    min_year = int(df['admission_year'].min())
    max_year = int(df['admission_year'].max())
    year_from, year_to = st.sidebar.slider(
        "Select Admission Year Range:",
        min_value=min_year,
        max_value=max_year,
        value=(min_year, max_year),
    )

    # Apply Filters (between() is inclusive on both ends)
    mask = (
        df['admission_type'].isin(selected_admission_types)
        & df['insurance'].isin(selected_insurance_types)
        & df['gender'].isin(selected_genders)
        & df['race'].isin(selected_races)
        & df['admission_year'].between(year_from, year_to)
    )
    return df[mask]
filtered_df = add_sidebar_filters(merged_df)


def _render_plot_grid(plots, num_cols):
    """Lay out ``plots`` row by row, ``num_cols`` per row.

    Each entry is a dict with a ``"title"`` (rendered as a subheader) and a
    zero-argument ``"plot"`` callable that draws into the current column.
    """
    for row_start in range(0, len(plots), num_cols):
        # Always request num_cols columns so a short final row keeps the same
        # column widths as the rows above it.
        cols = st.columns(num_cols)
        for col, spec in zip(cols, plots[row_start:row_start + num_cols]):
            with col:
                st.subheader(spec["title"])
                spec["plot"]()


def _render_summary_statistics(df):
    """Show the headline Q1 metrics for ``df`` in four columns."""
    st.header("Summary Statistics")
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Total Admissions", f"{df.shape[0]:,}")
    with col2:
        st.metric("Average Age", f"{df['anchor_age'].mean():.2f} years")
    with col3:
        gender_counts = df['gender'].value_counts()
        st.metric("Male Patients", f"{gender_counts.get('M', 0):,}")
        st.metric("Female Patients", f"{gender_counts.get('F', 0):,}")
    with col4:
        mortality_rate = df['hospital_expire_flag'].mean() * 100  # Percentage
        st.metric("Mortality Rate", f"{mortality_rate:.2f}%")
    st.markdown("---")


def _render_general_overview(df):
    """Q1 tab: overall feature distributions and outcome metrics."""
    st.subheader("General Feature Distribution and Outcome Metrics")
    two_column_plots = [
        {"title": "Gender Distribution", "plot": lambda: create_gender_pie_chart(df)},
        {"title": "Race Distribution", "plot": lambda: create_race_pie_chart(df)},
        {"title": "Insurance Type Distribution", "plot": lambda: create_insurance_pie_chart(df)},
        {"title": "Mortality Rate of ICU Patients", "plot": lambda: create_mortality_pie_chart(df)},
    ]
    _render_plot_grid(two_column_plots, num_cols=2)
    # Wide charts get a full-width row each.
    one_column_plots = [
        {"title": "Admission Type Count", "plot": lambda: create_admission_type_bar_chart(df)},
        {"title": "Admissions Over Time", "plot": lambda: create_time_series_heatmap(df)},
    ]
    _render_plot_grid(one_column_plots, num_cols=1)


def _render_potential_biases(df):
    """Q2 tab: age and mortality broken down by demographic group."""
    st.subheader("Analyzing Potential Biases Across Demographics")
    plots = [
        {"title": "Age Distribution of ICU Patients",
         "plot": lambda: create_age_distribution_by_gender(df)},
        {"title": "Boxen Plot of Age Distribution by Admission Type",
         "plot": lambda: create_age_distribution_by_admission_type(df)},
        {"title": "Mortality Rate by Race",
         "plot": lambda: create_mortality_by_race(df)},
        {"title": "Mortality Rate by Gender",
         "plot": lambda: create_mortality_by_gender(df)},
        {"title": "Mortality Rate by Age Group",
         "plot": lambda: create_mortality_by_age_group(df)},
        {"title": "Age Distribution by Race and Mortality",
         "plot": lambda: create_violin_age_race_mortality(df)},
        {"title": "Heatmap: Race & Gender vs. Mortality",
         "plot": lambda: create_heatmap_race_gender_mortality(df)},
        {"title": "Treemap of Race and Mortality",
         "plot": lambda: create_treemap_race_mortality(df)},
    ]
    _render_plot_grid(plots, num_cols=2)


if filtered_df.empty:
    # Deselecting every value in any multiselect yields no rows: the
    # mean-based metrics would render as "nan" and several charts cannot be
    # drawn from an empty frame, so show a hint instead.
    st.warning("No admissions match the current filters. "
               "Adjust the sidebar selections to see the dashboard.")
else:
    _render_summary_statistics(filtered_df)
    # Create Tabs for Q1 and Q2
    general_tab, bias_tab = st.tabs(["General Overview", "Potential Biases"])
    with general_tab:
        _render_general_overview(filtered_df)
    with bias_tab:
        _render_potential_biases(filtered_df)

# Footer. Each entry ends with two spaces + newline (a Markdown hard line
# break); bare newlines would merge all entries into one line.
st.markdown(
    "\n---\n\n"
    "**Data Source:** MIMIC-IV Dataset  \n"
    "**Project:** Fairness in EHR Data  \n"
    "**Developed with:** Streamlit, Python  \n"
    "**Q3 Visuals:** https://idyllic-cucurucho-672fc1.netlify.app/\n"
)