# streamlit_app.py
"""Streamlit dashboard over MIMIC-IV admissions and patient data.

The script loads the admissions and patients tables, lets the user filter
them from the sidebar, and renders summary metrics plus two tabs of charts
("General Overview" and "Potential Biases").
"""

import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go

# ---------------------------
# Constants
# ---------------------------

# Calendar order used for the month axis of the admissions heatmap.
MONTH_ORDER = [
    'January', 'February', 'March', 'April', 'May', 'June',
    'July', 'August', 'September', 'October', 'November', 'December',
]

# Age buckets for the mortality-by-age chart.  The bins are right-closed
# ((0, 30], (30, 50], ...) so that each label matches the ages it contains.
AGE_BINS = [0, 30, 50, 70, 90, 120]
AGE_LABELS = ['0-30', '31-50', '51-70', '71-90', '91-120']

# Collapses the detailed race strings of the admissions table into coarse
# groups.  Values missing from this map become NaN and are dropped in
# load_data().
RACE_MAP = {
    "WHITE": "WHITE",
    "BLACK/AFRICAN AMERICAN": "BLACK",
    "OTHER": "OTHER",
    "UNKNOWN": "UNKNOWN",
    "HISPANIC/LATINO - PUERTO RICAN": "HISPANIC",
    "WHITE - OTHER EUROPEAN": "WHITE",
    "HISPANIC OR LATINO": "HISPANIC",
    "ASIAN": "ASIAN",
    "ASIAN - CHINESE": "ASIAN",
    "WHITE - RUSSIAN": "WHITE",
    "BLACK/CAPE VERDEAN": "BLACK",
    "HISPANIC/LATINO - DOMINICAN": "HISPANIC",
    "BLACK/CARIBBEAN ISLAND": "BLACK",
    "BLACK/AFRICAN": "BLACK",
    "PATIENT DECLINED TO ANSWER": "UNKNOWN",
    "UNABLE TO OBTAIN": "UNKNOWN",
    "PORTUGUESE": "WHITE",
    "ASIAN - SOUTH EAST ASIAN": "ASIAN",
    "HISPANIC/LATINO - GUATEMALAN": "HISPANIC",
    "ASIAN - ASIAN INDIAN": "ASIAN",
    "WHITE - EASTERN EUROPEAN": "WHITE",
    "WHITE - BRAZILIAN": "WHITE",
    "AMERICAN INDIAN/ALASKA NATIVE": "NATIVES",
    "HISPANIC/LATINO - SALVADORAN": "HISPANIC",
    "HISPANIC/LATINO - MEXICAN": "HISPANIC",
    "HISPANIC/LATINO - COLUMBIAN": "HISPANIC",
    "MULTIPLE RACE/ETHNICITY": "MULTI-ETHINIC",
    "HISPANIC/LATINO - HONDURAN": "HISPANIC",
    "ASIAN - KOREAN": "ASIAN",
    "SOUTH AMERICAN": "HISPANIC",
    "HISPANIC/LATINO - CUBAN": "HISPANIC",
    "HISPANIC/LATINO - CENTRAL AMERICAN": "HISPANIC",
    "NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER": "NATIVES",
}


# ---------------------------
# Function Definitions
# ---------------------------

def _show_matplotlib(fig) -> None:
    """Render a Matplotlib figure in Streamlit and release it afterwards."""
    fig.tight_layout()
    st.pyplot(fig)
    # Streamlit reruns the whole script on every interaction; without an
    # explicit close each rerun would leave another figure alive in pyplot.
    plt.close(fig)


def _render_plot_grid(plots, df: pd.DataFrame, num_cols: int = 2) -> None:
    """Lay out ``(title, plot_function)`` pairs in rows of *num_cols* columns."""
    for start in range(0, len(plots), num_cols):
        row = plots[start:start + num_cols]
        # zip() stops at the shorter sequence, so a short last row simply
        # leaves the remaining columns empty.
        for column, (title, plot_fn) in zip(st.columns(num_cols), row):
            with column:
                st.subheader(title)
                plot_fn(df)


def create_histogram(df: pd.DataFrame) -> None:
    """Creates a histogram for Age Distribution."""
    fig, ax = plt.subplots(figsize=(5, 3.5))
    sns.histplot(df['anchor_age'], bins=30, kde=True, color='skyblue', ax=ax)
    ax.set_xlabel("Age")
    ax.set_ylabel("Number of Admissions")
    ax.set_title("Age Distribution")
    _show_matplotlib(fig)


def create_gender_bar_chart(df: pd.DataFrame) -> None:
    """Creates a bar chart for Gender Distribution."""
    fig, ax = plt.subplots(figsize=(5, 3.5))
    sns.countplot(data=df, x='gender', palette='pastel', ax=ax)
    ax.set_title("Gender Distribution")
    ax.set_xlabel("Gender")
    ax.set_ylabel("Number of Admissions")
    _show_matplotlib(fig)


def create_stacked_bar_admission_race(df: pd.DataFrame) -> None:
    """Creates a stacked bar chart for Admission Types by Race."""
    admission_race = (
        df.groupby(['race', 'admission_type']).size().unstack(fill_value=0)
    )
    # Normalise each race row to 100% so races of different size compare.
    admission_race_percent = (
        admission_race.div(admission_race.sum(axis=1), axis=0) * 100
    )

    # Draw on an explicit figure instead of pyplot's implicit current figure.
    fig, ax = plt.subplots(figsize=(8, 6))
    admission_race_percent.plot(
        kind='bar', stacked=True, colormap='tab20', ax=ax
    )
    ax.set_title("Admission Types by Race (%)")
    ax.set_xlabel("Race")
    ax.set_ylabel("Percentage of Admission Types")
    ax.legend(
        title='Admission Type', bbox_to_anchor=(1.05, 1), loc='upper left'
    )
    _show_matplotlib(fig)


def create_los_by_race(df: pd.DataFrame) -> None:
    """Creates a box plot for Length of Stay by Race."""
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.boxplot(data=df, x='race', y='los', palette='Pastel1', ax=ax)
    ax.set_title("Length of Stay by Race")
    ax.set_xlabel("Race")
    ax.set_ylabel("Length of Stay (Days)")
    ax.tick_params(axis='x', labelrotation=45)
    _show_matplotlib(fig)


def create_correlation_heatmap(df: pd.DataFrame) -> None:
    """Creates a correlation heatmap for numerical features."""
    corr_matrix = df[['anchor_age', 'los']].corr()
    fig, ax = plt.subplots(figsize=(3.5, 3))
    sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', fmt=".2f", ax=ax)
    ax.set_title("Correlation Heatmap")
    _show_matplotlib(fig)


def create_time_series_heatmap(df: pd.DataFrame) -> None:
    """Creates an admissions over time heatmap."""
    # Build the ordered month column on a derived frame so the caller's
    # DataFrame is left untouched.
    ordered_months = pd.Categorical(
        df['admission_month'], categories=MONTH_ORDER, ordered=True
    )
    heatmap_df = (
        df.assign(admission_month=ordered_months)
        # observed=False keeps year/month combinations with zero admissions.
        .groupby(['admission_year', 'admission_month'], observed=False)
        .size()
        .reset_index(name='counts')
    )

    fig = px.density_heatmap(
        heatmap_df,
        x='admission_month',
        y='admission_year',
        z='counts',
        histfunc='sum',
        title='Admissions Over Time',
        labels={'counts': 'Number of Admissions'},
        color_continuous_scale='Blues',
    )
    fig.update_xaxes(categoryorder='array', categoryarray=MONTH_ORDER)
    fig.update_layout(yaxis=dict(autorange='reversed'))
    fig.update_traces(colorbar=dict(title='Admissions'))
    st.plotly_chart(fig, use_container_width=True)


def create_mortality_by_race(df: pd.DataFrame) -> None:
    """Creates a bar chart for Mortality Rate by Race."""
    mortality_race = (
        df.groupby('race')['hospital_expire_flag'].mean().reset_index()
    )
    mortality_race['mortality_rate'] = (
        mortality_race['hospital_expire_flag'] * 100
    )

    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(
        data=mortality_race, x='race', y='mortality_rate',
        palette='Set2', ax=ax,
    )
    ax.set_title("Mortality Rate by Race")
    ax.set_xlabel("Race")
    ax.set_ylabel("Mortality Rate (%)")
    ax.tick_params(axis='x', labelrotation=45)
    _show_matplotlib(fig)


def create_mortality_by_gender(df: pd.DataFrame) -> None:
    """Creates a bar chart for Mortality Rate by Gender."""
    mortality_gender = (
        df.groupby('gender')['hospital_expire_flag'].mean().reset_index()
    )
    mortality_gender['mortality_rate'] = (
        mortality_gender['hospital_expire_flag'] * 100
    )

    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(
        data=mortality_gender, x='gender', y='mortality_rate',
        palette='Set3', ax=ax,
    )
    ax.set_title("Mortality Rate by Gender")
    ax.set_xlabel("Gender")
    ax.set_ylabel("Mortality Rate (%)")
    _show_matplotlib(fig)


def create_mortality_by_age_group(df: pd.DataFrame) -> None:
    """Creates a bar chart for Mortality Rate by Age Group."""
    # Right-closed bins make the labels accurate: age 30 falls in '0-30',
    # age 31 in '31-50'.  include_lowest keeps age 0 in the first bucket.
    age_group = pd.cut(
        df['anchor_age'],
        bins=AGE_BINS,
        labels=AGE_LABELS,
        right=True,
        include_lowest=True,
    )
    mortality_age = (
        # assign() works on a copy, so the caller's frame gains no column.
        df.assign(age_group=age_group)
        .groupby('age_group', observed=False)['hospital_expire_flag']
        .mean()
        .reset_index()
    )
    mortality_age['mortality_rate'] = (
        mortality_age['hospital_expire_flag'] * 100
    )

    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(
        data=mortality_age, x='age_group', y='mortality_rate',
        palette='coolwarm', ax=ax,
    )
    ax.set_title("Mortality Rate by Age Group")
    ax.set_xlabel("Age Group")
    ax.set_ylabel("Mortality Rate (%)")
    _show_matplotlib(fig)


def create_violin_age_race_mortality(df: pd.DataFrame) -> None:
    """Creates a violin plot for Age Distribution by Race and Mortality."""
    # A split violin needs exactly two hue levels; fall back to side-by-side
    # violins when the filtered data contains only one outcome.
    can_split = df['hospital_expire_flag'].nunique() == 2

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.violinplot(
        data=df,
        x='race',
        y='anchor_age',
        hue='hospital_expire_flag',
        split=can_split,
        palette='Set2',
        ax=ax,
    )
    ax.set_title("Age Distribution by Race and Mortality")
    ax.set_xlabel("Race")
    ax.set_ylabel("Age")
    ax.legend(title='Mortality', loc='upper right')
    _show_matplotlib(fig)


def create_heatmap_race_gender_mortality(df: pd.DataFrame) -> None:
    """Creates a heatmap for Mortality Rate by Race and Gender."""
    pivot_table = df.pivot_table(
        index='race',
        columns='gender',
        values='hospital_expire_flag',
        aggfunc='mean',
    ) * 100  # Convert to percentage

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(pivot_table, annot=True, fmt=".1f", cmap='YlOrRd', ax=ax)
    ax.set_title("Mortality Rate by Race and Gender (%)")
    ax.set_xlabel("Gender")
    ax.set_ylabel("Race")
    _show_matplotlib(fig)


def create_parallel_coordinates(df: pd.DataFrame) -> None:
    """Creates a parallel coordinates plot for Demographics and Outcomes."""
    # Select relevant numerical features
    parallel_df = df[['anchor_age', 'los', 'hospital_expire_flag']].copy()

    # Encode categorical variables numerically
    parallel_df['race_code'] = df['race'].astype('category').cat.codes
    parallel_df['gender_code'] = df['gender'].astype('category').cat.codes

    fig = px.parallel_coordinates(
        parallel_df,
        color='hospital_expire_flag',
        labels={
            'anchor_age': 'Age',
            'los': 'Length of Stay',
            'hospital_expire_flag': 'Mortality',
            'race_code': 'Race',
            'gender_code': 'Gender',
        },
        color_continuous_scale=px.colors.diverging.Tealrose,
        color_continuous_midpoint=0.5,
    )
    fig.update_layout(
        title='Parallel Coordinates Plot of Demographics and Outcomes'
    )
    st.plotly_chart(fig, use_container_width=True)


def create_treemap_race_mortality(df: pd.DataFrame) -> None:
    """Creates a treemap for Race and Mortality."""
    treemap_df = (
        df.groupby(['race', 'hospital_expire_flag'])
        .size()
        .reset_index(name='counts')
    )
    treemap_df['Mortality'] = treemap_df['hospital_expire_flag'].map(
        {0: 'Survived', 1: 'Died'}
    )

    fig = px.treemap(
        treemap_df,
        path=['race', 'Mortality'],
        values='counts',
        color='Mortality',
        color_discrete_map={'Survived': '#66b3ff', 'Died': '#ff6666'},
        title='Treemap of Race and Mortality',
    )
    fig.update_layout(margin=dict(t=30, l=0, r=0, b=0))
    st.plotly_chart(fig, use_container_width=True)


def create_sankey_race_mortality(df: pd.DataFrame) -> None:
    """Creates a Sankey diagram for Race to Mortality Outcomes."""
    sankey_df = (
        df.groupby(['race', 'hospital_expire_flag'])
        .size()
        .reset_index(name='counts')
    )
    # Map 'hospital_expire_flag' to 'Mortality' status
    sankey_df['Mortality'] = sankey_df['hospital_expire_flag'].map(
        {0: 'Survived', 1: 'Died'}
    )

    # Node list: every race first, then every outcome.
    unique_races = sankey_df['race'].unique().tolist()
    unique_mortality = sankey_df['Mortality'].unique().tolist()
    labels = unique_races + unique_mortality

    # Sankey links refer to nodes by position in `labels`.
    label_to_index = {label: idx for idx, label in enumerate(labels)}
    source_indices = [label_to_index[race] for race in sankey_df['race']]
    target_indices = [
        label_to_index[outcome] for outcome in sankey_df['Mortality']
    ]
    values = sankey_df['counts'].tolist()

    # One colour for race nodes, another for outcome nodes.
    race_color = "#FFA07A"  # Light Salmon
    mortality_color = "#20B2AA"  # Light Sea Green
    node_colors = (
        [race_color] * len(unique_races)
        + [mortality_color] * len(unique_mortality)
    )

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=labels,
            color=node_colors,
        ),
        link=dict(
            source=source_indices,
            target=target_indices,
            value=values,
        ),
    )])
    fig.update_layout(
        title_text="Sankey Diagram of Race and Mortality Outcomes",
        font_size=10,
    )
    st.plotly_chart(fig, use_container_width=True)


# ---------------------------
# Streamlit Application
# ---------------------------

# Set Streamlit page configuration
st.set_page_config(
    page_title="MIMIC-IV ICU Patient Data Dashboard",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Title and Description
st.title("MIMIC-IV ICU Patient Data Dashboard")
st.markdown("""
Explore the general feature distribution and outcome metrics of ICU patients from the MIMIC-IV dataset. Utilize the sidebar filters to customize the data view and interact with various visualizations to uncover patterns and insights.
""")

# Sidebar Filters
st.sidebar.header("Filter Data")


@st.cache_data
def load_data() -> pd.DataFrame:
    """Load, merge and clean the admissions and patients tables.

    Returns:
        One row per admission with patient columns joined on ``subject_id``,
        the race column collapsed through ``RACE_MAP``, rows lacking age,
        gender, race or ``hospital_expire_flag`` removed, and the derived
        columns ``los`` (whole days), ``admission_year``,
        ``admission_month`` and ``admittime_date`` added.
    """
    admissions_df = pd.read_feather('data/admissions.feather')
    patients_df = pd.read_feather('data/patients.feather')
    # The pharmacy table used to be read here as well but nothing consumed
    # it, so the read was dropped.

    admissions_df['race'] = admissions_df['race'].map(RACE_MAP)

    # Merge admissions and patients data on 'subject_id'
    merged = pd.merge(admissions_df, patients_df, on='subject_id', how='left')

    # Handle missing values by dropping rows with critical missing data
    merged = merged.dropna(
        subset=['anchor_age', 'gender', 'race', 'hospital_expire_flag']
    )

    # Convert datetime columns
    merged['admittime'] = pd.to_datetime(merged['admittime'])
    merged['dischtime'] = pd.to_datetime(merged['dischtime'])
    merged['deathtime'] = pd.to_datetime(merged['deathtime'], errors='coerce')

    # Create derived features
    merged['los'] = (merged['dischtime'] - merged['admittime']).dt.days
    merged['admission_year'] = merged['admittime'].dt.year
    merged['admission_month'] = merged['admittime'].dt.month_name()
    merged['admittime_date'] = merged['admittime'].dt.date
    return merged


merged_df = load_data()


# Sidebar Filters Function
def add_sidebar_filters(df: pd.DataFrame) -> pd.DataFrame:
    """Render the sidebar filter widgets and return the rows that match.

    Every categorical filter starts with all values selected; the year
    slider starts at the full range found in *df*.
    """
    def pick(label: str, column: str) -> list:
        # One multiselect per categorical column, defaulting to "all".
        options = sorted(df[column].unique())
        return st.sidebar.multiselect(label, options=options, default=options)

    selected_admission_types = pick(
        "Select Admission Type(s):", 'admission_type'
    )
    selected_insurance_types = pick("Select Insurance Type(s):", 'insurance')
    selected_genders = pick("Select Gender(s):", 'gender')
    selected_races = pick("Select Race(s):", 'race')

    # Year Range
    min_year = int(df['admission_year'].min())
    max_year = int(df['admission_year'].max())
    first_year, last_year = st.sidebar.slider(
        "Select Admission Year Range:",
        min_value=min_year,
        max_value=max_year,
        value=(min_year, max_year),
    )

    # Apply Filters
    keep = (
        df['admission_type'].isin(selected_admission_types)
        & df['insurance'].isin(selected_insurance_types)
        & df['gender'].isin(selected_genders)
        & df['race'].isin(selected_races)
        & df['admission_year'].between(first_year, last_year)
    )
    return df[keep]


filtered_df = add_sidebar_filters(merged_df)

# With nothing selected every metric would be NaN and several plots would
# raise on empty input, so stop early with an explanation instead.
if filtered_df.empty:
    st.warning(
        "No admissions match the selected filters. "
        "Adjust the sidebar filters to see results."
    )
    st.stop()

# Display Summary Statistics for Q1
st.header("Summary Statistics")
col1, col2, col3, col4 = st.columns(4)

with col1:
    total_admissions = filtered_df.shape[0]
    st.metric("Total Admissions", f"{total_admissions:,}")

with col2:
    average_age = filtered_df['anchor_age'].mean()
    st.metric("Average Age", f"{average_age:.2f} years")

with col3:
    gender_counts = filtered_df['gender'].value_counts()
    male_count = gender_counts.get('M', 0)
    female_count = gender_counts.get('F', 0)
    st.metric("Male Patients", f"{male_count:,}")
    st.metric("Female Patients", f"{female_count:,}")

with col4:
    mortality_rate = filtered_df['hospital_expire_flag'].mean() * 100  # Percentage
    st.metric("Mortality Rate", f"{mortality_rate:.2f}%")

st.markdown("---")

# Create Tabs for Q1 and Q2
tabs = st.tabs(["General Overview", "Potential Biases"])

# Q1: General Overview
with tabs[0]:
    st.subheader("General Feature Distribution and Outcome Metrics")
    q1_plots = [
        ("Age Distribution of ICU Patients", create_histogram),
        ("Gender Distribution of ICU Patients", create_gender_bar_chart),
        ("Admission Types by Race", create_stacked_bar_admission_race),
        ("Length of Stay by Race", create_los_by_race),
        ("Correlation Heatmap of Age and LOS", create_correlation_heatmap),
        ("Admissions Over Time", create_time_series_heatmap),
    ]
    _render_plot_grid(q1_plots, filtered_df, num_cols=2)

# Q2: Potential Biases from patient side
with tabs[1]:
    st.subheader("Analyzing Potential Biases Across Demographics")
    q2_plots = [
        ("Mortality Rate by Race", create_mortality_by_race),
        ("Mortality Rate by Gender", create_mortality_by_gender),
        ("Mortality Rate by Age Group", create_mortality_by_age_group),
        (
            "Age Distribution by Race and Mortality",
            create_violin_age_race_mortality,
        ),
        (
            "Heatmap: Race & Gender vs. Mortality",
            create_heatmap_race_gender_mortality,
        ),
        (
            "Parallel Coordinates Plot of Demographics and Outcomes",
            create_parallel_coordinates,
        ),
        ("Treemap of Race and Mortality", create_treemap_race_mortality),
        (
            "Sankey Diagram: Race to Mortality Outcomes",
            create_sankey_race_mortality,
        ),
    ]
    _render_plot_grid(q2_plots, filtered_df, num_cols=2)

# Footer
st.markdown("""
---
**Data Source:** MIMIC-IV Dataset
**Project:** Investigating Biases in ICU Patient Data
**Developed with:** Streamlit, Python
""")