"""Streamlit dashboard for exploring MIMIC-IV patient admissions.

The page loads the admissions and patients tables, lets the user filter them
from the sidebar, and renders summary metrics plus two tabs of charts: a
general overview and a set of views aimed at spotting demographic biases.
"""

import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go

# Shared constants

MONTH_ORDER = ['January', 'February', 'March', 'April', 'May', 'June',
               'July', 'August', 'September', 'October', 'November', 'December']

# Labels and colours for the binary ``hospital_expire_flag`` outcome, shared by
# every chart that splits admissions by in-hospital mortality.
OUTCOME_LABELS = {0: 'Survived', 1: 'Died'}
OUTCOME_COLORS = {'Survived': '#66b3ff', 'Died': '#ff6666'}

# Age buckets for the mortality-by-age chart. Bins are closed on the right and
# the first bin also includes 0, so each label matches the integer ages it
# contains (age 30 -> '0-30', age 31 -> '31-50', age 50 -> '31-50').
AGE_BINS = [0, 30, 50, 70, 90, 120]
AGE_LABELS = ['0-30', '31-50', '51-70', '71-90', '91-120']

# Collapses the detailed MIMIC-IV race strings into broad groups.
RACE_MAP = {
    "WHITE": "WHITE",
    "BLACK/AFRICAN AMERICAN": "BLACK",
    "OTHER": "OTHER",
    "UNKNOWN": "UNKNOWN",
    "HISPANIC/LATINO - PUERTO RICAN": "HISPANIC",
    "WHITE - OTHER EUROPEAN": "WHITE",
    "HISPANIC OR LATINO": "HISPANIC",
    "ASIAN": "ASIAN",
    "ASIAN - CHINESE": "ASIAN",
    "WHITE - RUSSIAN": "WHITE",
    "BLACK/CAPE VERDEAN": "BLACK",
    "HISPANIC/LATINO - DOMINICAN": "HISPANIC",
    "BLACK/CARIBBEAN ISLAND": "BLACK",
    "BLACK/AFRICAN": "BLACK",
    "PATIENT DECLINED TO ANSWER": "UNKNOWN",
    "UNABLE TO OBTAIN": "UNKNOWN",
    "PORTUGUESE": "WHITE",
    "ASIAN - SOUTH EAST ASIAN": "ASIAN",
    "HISPANIC/LATINO - GUATEMALAN": "HISPANIC",
    "ASIAN - ASIAN INDIAN": "ASIAN",
    "WHITE - EASTERN EUROPEAN": "WHITE",
    "WHITE - BRAZILIAN": "WHITE",
    "AMERICAN INDIAN/ALASKA NATIVE": "NATIVES",
    "HISPANIC/LATINO - SALVADORAN": "HISPANIC",
    "HISPANIC/LATINO - MEXICAN": "HISPANIC",
    "HISPANIC/LATINO - COLUMBIAN": "HISPANIC",
    "MULTIPLE RACE/ETHNICITY": "MULTI-ETHNIC",
    "HISPANIC/LATINO - HONDURAN": "HISPANIC",
    "ASIAN - KOREAN": "ASIAN",
    "SOUTH AMERICAN": "HISPANIC",
    "HISPANIC/LATINO - CUBAN": "HISPANIC",
    "HISPANIC/LATINO - CENTRAL AMERICAN": "HISPANIC",
    "NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER": "NATIVES",
}


# Plot helpers

def _render_matplotlib_figure(fig):
    """Lay out *fig*, draw it in Streamlit, then close it.

    Streamlit re-runs this script on every widget interaction, while pyplot
    keeps a global registry of open figures. Closing each figure after it is
    rendered stops figures from piling up across reruns. It also stops later
    implicit ``plt.*`` calls from drawing onto a stale figure.
    """
    fig.tight_layout()
    st.pyplot(fig)
    plt.close(fig)


def _create_category_pie_chart(df, column, label):
    """Render a donut chart with one slice per distinct value of ``df[column]``.

    Args:
        df: Admissions data frame.
        column: Name of the categorical column to count.
        label: Display name used for the category in the legend and hover text.
    """
    counts = df[column].value_counts().reset_index()
    # Rename positionally: the default names produced by reset_index() here
    # differ between pandas versions.
    counts.columns = [label, 'Count']
    fig = px.pie(
        counts,
        names=label,
        values='Count',
        hover_data=['Count'],
        hole=0.3,
    )
    st.plotly_chart(fig, use_container_width=True)


def _mortality_rate_by(df, column):
    """Return one row per value of ``df[column]`` with its mortality rate in percent.

    The result has the grouping column, the mean ``hospital_expire_flag``, and
    ``mortality_rate`` (that mean * 100). ``observed=False`` keeps unused
    categories of a categorical column (their rate is NaN).
    """
    rates = (
        df.groupby(column, observed=False)['hospital_expire_flag']
        .mean()
        .reset_index()
    )
    rates['mortality_rate'] = rates['hospital_expire_flag'] * 100
    return rates


# Plot Function Definitions

def create_gender_pie_chart(df):
    """Creates a pie chart for Gender Distribution."""
    _create_category_pie_chart(df, 'gender', 'Gender')


def create_race_pie_chart(df):
    """Creates a pie chart for Race Distribution."""
    _create_category_pie_chart(df, 'race', 'Race Type')


def create_insurance_pie_chart(df):
    """Creates a pie chart for Insurance Type Distribution."""
    _create_category_pie_chart(df, 'insurance', 'Insurance Type')


def create_mortality_pie_chart(df):
    """Creates a pie chart of admissions that survived vs. died in hospital."""
    deaths = df['hospital_expire_flag'].sum()
    survivors = df.shape[0] - deaths
    # Use a dedicated figure rather than pyplot's implicit "current" figure,
    # which on a rerun can be a figure left over from another chart.
    fig, ax = plt.subplots()
    ax.pie(
        [survivors, deaths],
        explode=(0.1, 0),
        labels=[OUTCOME_LABELS[0], OUTCOME_LABELS[1]],
        colors=[OUTCOME_COLORS[OUTCOME_LABELS[0]], OUTCOME_COLORS[OUTCOME_LABELS[1]]],
        autopct='%1.1f%%',
        startangle=140,
        textprops={'fontsize': 14},
    )
    ax.axis('equal')
    _render_matplotlib_figure(fig)


def create_admission_type_bar_chart(df):
    """Creates a horizontal bar chart of admission counts per admission type."""
    admission_counts = df['admission_type'].value_counts().reset_index()
    admission_counts.columns = ['Admission Type', 'Count']
    fig_admission = px.bar(
        admission_counts,
        y='Admission Type',
        x='Count',
        color='Admission Type',
        labels={'Count': 'Number of Admissions', 'Admission Type': 'Admission Type'},
        hover_data=['Count'],
    )
    st.plotly_chart(fig_admission, use_container_width=True)


def create_time_series_heatmap(df):
    """Creates an admissions-over-time heatmap (admission year x month).

    The month ordering is applied to a copy, so the caller's frame is not modified.
    """
    months = pd.Categorical(df['admission_month'], categories=MONTH_ORDER, ordered=True)
    heatmap_df = (
        df.assign(admission_month=months)
        # observed=False keeps all 12 months for every year, including
        # zero-count cells. Passing it explicitly also avoids pandas'
        # FutureWarning about the changing default.
        .groupby(['admission_year', 'admission_month'], observed=False)
        .size()
        .reset_index(name='counts')
    )
    fig = px.density_heatmap(
        heatmap_df,
        x='admission_month',
        y='admission_year',
        z='counts',
        histfunc='sum',
        labels={
            'counts': 'Number of Admissions',
            'admission_month': 'Admission Month',
            'admission_year': 'Admission Year',
        },
        color_continuous_scale='rdbu',
    )
    fig.update_xaxes(categoryorder='array', categoryarray=MONTH_ORDER)
    # Plotly Express links density_heatmap colours to the shared layout
    # coloraxis, so the colorbar title has to be set there. A per-trace
    # ``colorbar`` setting is not used.
    fig.update_layout(
        yaxis=dict(autorange='reversed'),
        coloraxis_colorbar=dict(title='Admissions'),
    )
    st.plotly_chart(fig, use_container_width=True)


def create_age_distribution_by_gender(df):
    """Creates a histogram (with KDE) of patient age, coloured by gender."""
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.histplot(data=df, x='anchor_age', bins=30, kde=True, palette='bright', hue='gender', ax=ax)
    ax.set_xlabel('Age', fontsize=16)
    ax.set_ylabel('Number of Admissions', fontsize=16)
    ax.tick_params(axis='both', labelsize=16)
    _render_matplotlib_figure(fig)


def create_age_distribution_by_admission_type(df):
    """Creates a boxen plot of patient age for each admission type."""
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.boxenplot(data=df, x='admission_type', y='anchor_age', palette='Set3', ax=ax)
    ax.set_xlabel('Admission Type', fontsize=16)
    ax.set_ylabel('Age', fontsize=16)
    ax.tick_params(axis='x', labelsize=16, labelrotation=45)
    ax.tick_params(axis='y', labelsize=16)
    _render_matplotlib_figure(fig)


def create_mortality_by_race(df):
    """Creates a bar chart for Mortality Rate by Race."""
    mortality_race = _mortality_rate_by(df, 'race')
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(data=mortality_race, x='race', y='mortality_rate', palette='Set2', ax=ax)
    ax.set_xlabel("Race")
    ax.set_ylabel("Mortality Rate (%)")
    ax.tick_params(axis='x', labelrotation=45)
    _render_matplotlib_figure(fig)


def create_mortality_by_gender(df):
    """Creates a bar chart for Mortality Rate by Gender."""
    mortality_gender = _mortality_rate_by(df, 'gender')
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(data=mortality_gender, x='gender', y='mortality_rate', palette='Set3', ax=ax)
    ax.set_xlabel("Gender")
    ax.set_ylabel("Mortality Rate (%)")
    _render_matplotlib_figure(fig)


def create_mortality_by_age_group(df):
    """Creates a bar chart for Mortality Rate by Age Group (see AGE_BINS/AGE_LABELS)."""
    age_group = pd.cut(
        df['anchor_age'],
        bins=AGE_BINS,
        labels=AGE_LABELS,
        right=True,
        include_lowest=True,
    )
    # Aggregate on a copy instead of adding an 'age_group' column to the caller's frame.
    mortality_age = _mortality_rate_by(df.assign(age_group=age_group), 'age_group')
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(data=mortality_age, x='age_group', y='mortality_rate', palette='coolwarm', ax=ax)
    ax.set_xlabel("Age Group")
    ax.set_ylabel("Mortality Rate (%)")
    _render_matplotlib_figure(fig)


def create_violin_age_race_mortality(df):
    """Creates a violin plot for Age Distribution by Race and Mortality."""
    plot_df = df.assign(Mortality=df['hospital_expire_flag'].map(OUTCOME_LABELS))
    # Fixed Survived/Died order (restricted to levels actually present) keeps
    # colours stable across filter changes.
    present = set(plot_df['Mortality'].dropna())
    outcome_order = [label for label in OUTCOME_LABELS.values() if label in present]
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.violinplot(
        data=plot_df,
        x='race',
        y='anchor_age',
        hue='Mortality',
        hue_order=outcome_order,
        # seaborn versions before 0.13 raise on split=True unless exactly two
        # hue levels exist, e.g. when the filters leave no deaths.
        split=len(outcome_order) == 2,
        palette='Set2',
        ax=ax,
    )
    ax.set_xlabel("Race")
    ax.set_ylabel("Age")
    ax.legend(title='Mortality', loc='upper right')
    _render_matplotlib_figure(fig)


def create_heatmap_race_gender_mortality(df):
    """Creates a heatmap for Mortality Rate (%) by Race and Gender."""
    pivot_table = df.pivot_table(
        index='race',
        columns='gender',
        values='hospital_expire_flag',
        aggfunc='mean',
    ) * 100
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(pivot_table, annot=True, fmt=".1f", cmap='YlOrRd', ax=ax)
    ax.set_xlabel("Gender")
    ax.set_ylabel("Race")
    _render_matplotlib_figure(fig)


def create_treemap_race_mortality(df):
    """Creates a treemap of admission counts by Race, then by outcome."""
    treemap_df = df.groupby(['race', 'hospital_expire_flag']).size().reset_index(name='counts')
    treemap_df['Mortality'] = treemap_df['hospital_expire_flag'].map(OUTCOME_LABELS)
    fig = px.treemap(
        treemap_df,
        path=['race', 'Mortality'],
        values='counts',
        color='Mortality',
        color_discrete_map=OUTCOME_COLORS,
    )
    fig.update_layout(margin=dict(t=30, l=0, r=0, b=0))
    st.plotly_chart(fig, use_container_width=True)


def render_plot_grid(plots, df, num_cols):
    """Render ``(title, plot_fn)`` pairs left-to-right in rows of ``num_cols`` columns.

    Each ``plot_fn`` is called with *df* inside its own column, below a subheader.
    """
    for start in range(0, len(plots), num_cols):
        columns = st.columns(num_cols)
        for column, (title, plot_fn) in zip(columns, plots[start:start + num_cols]):
            with column:
                st.subheader(title)
                plot_fn(df)


# Streamlit Application

# Set Streamlit page configuration (this must be the first Streamlit command).
st.set_page_config(
    page_title="MIMIC-IV ICU Patient Data Dashboard",
    layout="wide",
    initial_sidebar_state="expanded",
)

st.title("MIMIC-IV ICU Patient Data Dashboard")
st.markdown('''
Explore the general feature distribution and demographics related bias in ICU patients from the MIMIC-IV dataset.
Utilize the sidebar filters to customize the data view'''
)

# Sidebar Filters
st.sidebar.header("Filter Data")


@st.cache_data
def load_data():
    """Load, join and clean the admissions and patients tables.

    Admissions are left-joined with patients on ``subject_id``. Rows with a
    missing age, gender, (mapped) race or outcome are dropped. Derived columns
    are added: ``los`` (whole days between admit and discharge),
    ``admission_year``, ``admission_month`` (month name) and ``admittime_date``.
    """
    admissions_df = pd.read_feather('data/admissions.feather')
    patients_df = pd.read_feather('data/patients.feather')

    # Race strings not listed in RACE_MAP become NaN here and are dropped
    # below together with the other incomplete rows.
    admissions_df['race'] = admissions_df['race'].map(RACE_MAP)

    merged_df = pd.merge(admissions_df, patients_df, on='subject_id', how='left')
    merged_df = merged_df.dropna(subset=['anchor_age', 'gender', 'race', 'hospital_expire_flag'])

    merged_df['admittime'] = pd.to_datetime(merged_df['admittime'])
    merged_df['dischtime'] = pd.to_datetime(merged_df['dischtime'])
    merged_df['deathtime'] = pd.to_datetime(merged_df['deathtime'], errors='coerce')

    # Create derived features
    merged_df['los'] = (merged_df['dischtime'] - merged_df['admittime']).dt.days
    merged_df['admission_year'] = merged_df['admittime'].dt.year
    merged_df['admission_month'] = merged_df['admittime'].dt.month_name()
    merged_df['admittime_date'] = merged_df['admittime'].dt.date
    return merged_df


merged_df = load_data()


def _sidebar_multiselect(df, column, label):
    """Show a sidebar multiselect over the distinct values of ``df[column]``, all preselected.

    Missing values are left out of the options, because sorting NaN together
    with strings raises TypeError. Rows with a missing value in this column
    are therefore excluded by the filter.
    """
    options = sorted(df[column].dropna().unique())
    return st.sidebar.multiselect(label, options=options, default=options)


# Sidebar Filters Function
def add_sidebar_filters(df):
    """Render the sidebar filter widgets and return the rows of *df* that match them."""
    selected_admission_types = _sidebar_multiselect(df, 'admission_type', "Select Admission Type(s):")
    selected_insurance_types = _sidebar_multiselect(df, 'insurance', "Select Insurance Type(s):")
    selected_genders = _sidebar_multiselect(df, 'gender', "Select Gender(s):")
    selected_races = _sidebar_multiselect(df, 'race', "Select Race(s):")

    # Year Range
    min_year = int(df['admission_year'].min())
    max_year = int(df['admission_year'].max())
    start_year, end_year = st.sidebar.slider(
        "Select Admission Year Range:",
        min_value=min_year,
        max_value=max_year,
        value=(min_year, max_year),
    )

    # Apply Filters (year bounds are inclusive on both ends)
    mask = (
        df['admission_type'].isin(selected_admission_types)
        & df['insurance'].isin(selected_insurance_types)
        & df['gender'].isin(selected_genders)
        & df['race'].isin(selected_races)
        & df['admission_year'].between(start_year, end_year)
    )
    return df[mask]


filtered_df = add_sidebar_filters(merged_df)

# Every metric and chart below needs at least one row. Stop with a hint
# instead of rendering NaN metrics and empty or failing charts.
if filtered_df.empty:
    st.warning("No admissions match the current filters. Widen the sidebar selection to see results.")
    st.stop()

# Display Summary Statistics for Q1
st.header("Summary Statistics")

# Create four columns for metrics
col1, col2, col3, col4 = st.columns(4)

with col1:
    total_admissions = filtered_df.shape[0]
    st.metric("Total Admissions", f"{total_admissions:,}")

with col2:
    average_age = filtered_df['anchor_age'].mean()
    st.metric("Average Age", f"{average_age:.2f} years")

with col3:
    # NOTE(review): these count admissions (rows), not distinct patients;
    # confirm the "Patients" wording is intended.
    gender_counts = filtered_df['gender'].value_counts()
    male_count = gender_counts.get('M', 0)
    female_count = gender_counts.get('F', 0)
    st.metric("Male Patients", f"{male_count:,}")
    st.metric("Female Patients", f"{female_count:,}")

with col4:
    mortality_rate = filtered_df['hospital_expire_flag'].mean() * 100  # Percentage
    st.metric("Mortality Rate", f"{mortality_rate:.2f}%")

st.markdown("---")

# Create Tabs for Q1 and Q2
tabs = st.tabs(["General Overview", "Potential Biases"])

# Q1: General Overview
with tabs[0]:
    st.subheader("General Feature Distribution and Outcome Metrics")

    q1_plots_2_col = [
        ("Gender Distribution", create_gender_pie_chart),
        ("Race Distribution", create_race_pie_chart),
        ("Insurance Type Distribution", create_insurance_pie_chart),
        ("Mortality Rate of ICU Patients", create_mortality_pie_chart),
    ]
    render_plot_grid(q1_plots_2_col, filtered_df, num_cols=2)

    q1_plots_1_col = [
        ("Admission Type Count", create_admission_type_bar_chart),
        ("Admissions Over Time", create_time_series_heatmap),
    ]
    render_plot_grid(q1_plots_1_col, filtered_df, num_cols=1)

# Q2: Potential Biases
with tabs[1]:
    st.subheader("Analyzing Potential Biases Across Demographics")

    q2_plots = [
        ("Age Distribution of ICU Patients", create_age_distribution_by_gender),
        ("Boxen Plot of Age Distribution by Admission Type", create_age_distribution_by_admission_type),
        ("Mortality Rate by Race", create_mortality_by_race),
        ("Mortality Rate by Gender", create_mortality_by_gender),
        ("Mortality Rate by Age Group", create_mortality_by_age_group),
        ("Age Distribution by Race and Mortality", create_violin_age_race_mortality),
        ("Heatmap: Race & Gender vs. Mortality", create_heatmap_race_gender_mortality),
        ("Treemap of Race and Mortality", create_treemap_race_mortality),
    ]
    render_plot_grid(q2_plots, filtered_df, num_cols=2)

# Footer
st.markdown("""
---
**Data Source:** MIMIC-IV Dataset

**Project:** Fairness in EHR Data

**Developed with:** Streamlit, Python

**Q3 Visuals:** https://idyllic-cucurucho-672fc1.netlify.app/
""")