# Source: app.py uploaded by BiswajitPadhi99 (commit 7c3768c, "Add app.py", 19.7 kB).
# streamlit_app.py
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
# ---------------------------
# Function Definitions
# ---------------------------
def create_histogram(df):
    """Render a histogram (with KDE overlay) of patient age across admissions.

    Args:
        df: Admissions DataFrame; must contain an ``anchor_age`` column.
    """
    fig, ax = plt.subplots(figsize=(5, 3.5))
    sns.histplot(df['anchor_age'], bins=30, kde=True, color='skyblue', ax=ax)
    ax.set_xlabel("Age")
    ax.set_ylabel("Number of Admissions")
    ax.set_title("Age Distribution")
    # Lay out *this* figure rather than pyplot's implicit "current" figure,
    # which is process-global state and may be touched by other sessions.
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot retains every figure until it is explicitly closed; without this
    # each Streamlit rerun leaks one figure.
    plt.close(fig)
def create_gender_bar_chart(df):
    """Render a bar chart of the number of admissions per gender.

    Args:
        df: Admissions DataFrame; must contain a ``gender`` column.
    """
    fig, ax = plt.subplots(figsize=(5, 3.5))
    sns.countplot(data=df, x='gender', palette='pastel', ax=ax)
    ax.set_title("Gender Distribution")
    ax.set_xlabel("Gender")
    ax.set_ylabel("Number of Admissions")
    # Operate on this figure explicitly instead of pyplot's global current one.
    fig.tight_layout()
    st.pyplot(fig)
    # Release the figure: pyplot keeps it alive until closed, so every rerun
    # would otherwise leak one.
    plt.close(fig)
def admission_type_share_by_race(df):
    """Return the percentage of each admission type within each race.

    Args:
        df: DataFrame with ``race`` and ``admission_type`` columns.

    Returns:
        DataFrame indexed by race with one column per admission type; each
        row sums to 100. Combinations that never occur are 0.
    """
    counts = df.groupby(['race', 'admission_type']).size().unstack(fill_value=0)
    return counts.div(counts.sum(axis=1), axis=0) * 100


def create_stacked_bar_admission_race(df):
    """Render a 100%-stacked bar chart of admission types for each race.

    Args:
        df: Admissions DataFrame with ``race`` and ``admission_type`` columns.
    """
    admission_race_percent = admission_type_share_by_race(df)
    if admission_race_percent.empty:
        # DataFrame.plot raises when there is nothing numeric to draw.
        st.info("No admissions to display for the current filters.")
        return
    # Draw into an explicit figure/axes rather than relying on pyplot's global
    # "current figure", which DataFrame.plot would otherwise create implicitly.
    fig, ax = plt.subplots(figsize=(8, 6))
    admission_race_percent.plot(kind='bar', stacked=True, colormap='tab20', ax=ax)
    ax.set_title("Admission Types by Race (%)")
    ax.set_xlabel("Race")
    ax.set_ylabel("Percentage of Admission Types")
    ax.legend(title='Admission Type', bbox_to_anchor=(1.05, 1), loc='upper left')
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot retains figures until closed; avoid leaking one per rerun.
    plt.close(fig)
def create_los_by_race(df):
    """Render a box plot of length of stay (days) for each race.

    Args:
        df: Admissions DataFrame with ``race`` and ``los`` columns.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.boxplot(data=df, x='race', y='los', palette='Pastel1', ax=ax)
    ax.set_title("Length of Stay by Race")
    ax.set_xlabel("Race")
    ax.set_ylabel("Length of Stay (Days)")
    # Rotate the existing tick labels in place (instead of re-setting them)
    # and right-align them so each label ends under its own box.
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot retains figures until closed; avoid leaking one per rerun.
    plt.close(fig)
def create_correlation_heatmap(df):
    """Render an annotated Pearson correlation heatmap of age and length of stay.

    Args:
        df: Admissions DataFrame with ``anchor_age`` and ``los`` columns.
    """
    numerical_features = df[['anchor_age', 'los']]
    corr_matrix = numerical_features.corr()
    fig, ax = plt.subplots(figsize=(3.5, 3))
    # Pin the diverging colormap to the full correlation range [-1, 1].
    # Otherwise it is scaled to the observed min/max (here just r and the
    # diagonal 1.0), so any off-diagonal r would be colored as a strong value.
    sns.heatmap(
        corr_matrix,
        annot=True,
        cmap='coolwarm',
        fmt=".2f",
        vmin=-1,
        vmax=1,
        ax=ax,
    )
    ax.set_title("Correlation Heatmap")
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot retains figures until closed; avoid leaking one per rerun.
    plt.close(fig)
MONTH_ORDER = ['January', 'February', 'March', 'April', 'May', 'June',
               'July', 'August', 'September', 'October', 'November', 'December']


def monthly_admission_counts(df):
    """Count admissions for every (admission_year, admission_month) pair.

    The input frame is not modified.

    Args:
        df: DataFrame with ``admission_year`` and ``admission_month`` (English
            month name) columns.

    Returns:
        DataFrame with columns ``admission_year``, ``admission_month`` (ordered
        categorical, January..December) and ``counts``. Every month of every
        year present in ``df`` is included, with a count of 0 when there were
        no admissions.
    """
    months = pd.Categorical(df['admission_month'], categories=MONTH_ORDER, ordered=True)
    return (
        df.assign(admission_month=months)
        # observed=False keeps months with no admissions (as zero counts).
        # Pass it explicitly because pandas' default for categorical groupers
        # is changing.
        .groupby(['admission_year', 'admission_month'], observed=False)
        .size()
        .reset_index(name='counts')
    )


def create_time_series_heatmap(df):
    """Render a Plotly heatmap of admission counts by year (rows) and month (columns).

    Args:
        df: Admissions DataFrame with ``admission_year`` and
            ``admission_month`` columns. It is not modified.
    """
    heatmap_df = monthly_admission_counts(df)
    fig = px.density_heatmap(
        heatmap_df,
        x='admission_month',
        y='admission_year',
        z='counts',
        histfunc='sum',
        title='Admissions Over Time',
        labels={'counts': 'Number of Admissions'},
        color_continuous_scale='Blues'
    )
    fig.update_xaxes(categoryorder='array', categoryarray=MONTH_ORDER)
    # NOTE(review): px normally routes continuous color through the shared
    # layout coloraxis, where a per-trace colorbar setting has no effect.
    # Set the title in both places so it shows either way.
    fig.update_layout(
        yaxis=dict(autorange='reversed'),
        coloraxis_colorbar=dict(title='Admissions'),
    )
    fig.update_traces(colorbar=dict(title='Admissions'))
    st.plotly_chart(fig, use_container_width=True)
def create_mortality_by_race(df):
    """Render a bar chart of in-hospital mortality rate (%) for each race.

    Args:
        df: Admissions DataFrame with ``race`` and ``hospital_expire_flag``
            (0/1) columns; the rate is the mean flag times 100.
    """
    mortality_race = df.groupby('race')['hospital_expire_flag'].mean().reset_index()
    mortality_race['mortality_rate'] = mortality_race['hospital_expire_flag'] * 100
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(data=mortality_race, x='race', y='mortality_rate', palette='Set2', ax=ax)
    ax.set_title("Mortality Rate by Race")
    ax.set_xlabel("Race")
    ax.set_ylabel("Mortality Rate (%)")
    # Rotate the existing tick labels in place and right-align them under
    # their bars.
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot retains figures until closed; avoid leaking one per rerun.
    plt.close(fig)
def create_mortality_by_gender(df):
    """Render a bar chart of in-hospital mortality rate (%) for each gender.

    Args:
        df: Admissions DataFrame with ``gender`` and ``hospital_expire_flag``
            (0/1) columns; the rate is the mean flag times 100.
    """
    mortality_gender = df.groupby('gender')['hospital_expire_flag'].mean().reset_index()
    mortality_gender['mortality_rate'] = mortality_gender['hospital_expire_flag'] * 100
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(data=mortality_gender, x='gender', y='mortality_rate', palette='Set3', ax=ax)
    ax.set_title("Mortality Rate by Gender")
    ax.set_xlabel("Gender")
    ax.set_ylabel("Mortality Rate (%)")
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot retains figures until closed; avoid leaking one per rerun.
    plt.close(fig)
def mortality_rate_by_age_group(df):
    """Return in-hospital mortality rate (%) per age group.

    Groups are 0-30, 31-50, 51-70, 71-90 and 91-120. They use right-closed
    bins with the lowest edge included, so each integer age lands in the
    group its label names (e.g. 30 -> '0-30', 31 -> '31-50'). Groups with no
    admissions are kept, with a NaN rate. The input frame is not modified.

    Args:
        df: DataFrame with ``anchor_age`` and ``hospital_expire_flag`` (0/1).

    Returns:
        DataFrame with columns ``age_group`` (ordered categorical) and
        ``mortality_rate``.
    """
    age_group = pd.cut(
        df['anchor_age'],
        bins=[0, 30, 50, 70, 90, 120],
        labels=['0-30', '31-50', '51-70', '71-90', '91-120'],
        right=True,
        include_lowest=True,
    )
    rates = (
        df['hospital_expire_flag']
        # observed=False keeps empty age groups on the axis.
        .groupby(age_group, observed=False)
        .mean()
        .mul(100)
    )
    return rates.rename_axis('age_group').reset_index(name='mortality_rate')


def create_mortality_by_age_group(df):
    """Render a bar chart of in-hospital mortality rate (%) per age group.

    Args:
        df: Admissions DataFrame with ``anchor_age`` and
            ``hospital_expire_flag`` columns. It is not modified.
    """
    mortality_age = mortality_rate_by_age_group(df)
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(data=mortality_age, x='age_group', y='mortality_rate', palette='coolwarm', ax=ax)
    ax.set_title("Mortality Rate by Age Group")
    ax.set_xlabel("Age Group")
    ax.set_ylabel("Mortality Rate (%)")
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot retains figures until closed; avoid leaking one per rerun.
    plt.close(fig)
def create_violin_age_race_mortality(df):
    """Render split violins of patient age per race, split by in-hospital outcome.

    Args:
        df: Admissions DataFrame with ``race``, ``anchor_age`` and
            ``hospital_expire_flag`` (0 -> 'Survived', 1 -> 'Died') columns.
            It is not modified.
    """
    # Plot readable outcome labels instead of the raw 0/1 flag. Fixing the
    # level order keeps colors stable across filter changes and always gives
    # `split` two hue levels, even when only one outcome is present.
    plot_df = df.assign(
        Mortality=df['hospital_expire_flag'].map({0: 'Survived', 1: 'Died'})
    )
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.violinplot(
        data=plot_df,
        x='race',
        y='anchor_age',
        hue='Mortality',
        hue_order=['Survived', 'Died'],
        split=True,
        palette='Set2',
        ax=ax,
    )
    ax.set_title("Age Distribution by Race and Mortality")
    ax.set_xlabel("Race")
    ax.set_ylabel("Age")
    # Move seaborn's own legend instead of rebuilding it with ax.legend(),
    # which only finds artists that carry labels on the axes.
    if ax.get_legend() is not None:
        sns.move_legend(ax, 'upper right', title='Mortality')
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot retains figures until closed; avoid leaking one per rerun.
    plt.close(fig)
def create_heatmap_race_gender_mortality(df):
    """Render an annotated heatmap of mortality rate (%) by race and gender.

    Args:
        df: Admissions DataFrame with ``race``, ``gender`` and
            ``hospital_expire_flag`` (0/1) columns; each cell is the mean flag
            for that race/gender pair times 100.
    """
    pivot_table = df.pivot_table(
        index='race',
        columns='gender',
        values='hospital_expire_flag',
        aggfunc='mean'
    ) * 100  # Convert to percentage
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(pivot_table, annot=True, fmt=".1f", cmap='YlOrRd', ax=ax)
    ax.set_title("Mortality Rate by Race and Gender (%)")
    ax.set_xlabel("Gender")
    ax.set_ylabel("Race")
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot retains figures until closed; avoid leaking one per rerun.
    plt.close(fig)
def create_parallel_coordinates(df):
    """Render a Plotly parallel-coordinates chart of demographics vs. outcome.

    Race and gender are categorical, so they are turned into integer category
    codes to sit on numeric axes. Lines are colored by the mortality flag.
    """
    axis_titles = {
        'anchor_age': 'Age',
        'los': 'Length of Stay',
        'hospital_expire_flag': 'Mortality',
        'race_code': 'Race',
        'gender_code': 'Gender',
    }
    coded_columns = {'race_code': 'race', 'gender_code': 'gender'}

    coords = df[['anchor_age', 'los', 'hospital_expire_flag']].copy()
    for code_col, source_col in coded_columns.items():
        coords[code_col] = df[source_col].astype('category').cat.codes

    chart = px.parallel_coordinates(
        coords,
        color='hospital_expire_flag',
        labels=axis_titles,
        color_continuous_scale=px.colors.diverging.Tealrose,
        color_continuous_midpoint=0.5,
    )
    chart.update_layout(title='Parallel Coordinates Plot of Demographics and Outcomes')
    st.plotly_chart(chart, use_container_width=True)
def create_treemap_race_mortality(df):
    """Render a Plotly treemap of admission counts nested as race -> outcome.

    Outcome leaves are labelled 'Survived' (flag 0) or 'Died' (flag 1) and
    colored blue / red respectively.
    """
    outcome_names = {0: 'Survived', 1: 'Died'}
    outcome_colors = {'Survived': '#66b3ff', 'Died': '#ff6666'}

    tiles = (
        df.groupby(['race', 'hospital_expire_flag'])
        .size()
        .reset_index(name='counts')
    )
    tiles['Mortality'] = tiles['hospital_expire_flag'].map(outcome_names)

    treemap = px.treemap(
        tiles,
        path=['race', 'Mortality'],
        values='counts',
        color='Mortality',
        color_discrete_map=outcome_colors,
        title='Treemap of Race and Mortality',
    )
    treemap.update_layout(margin={'t': 30, 'l': 0, 'r': 0, 'b': 0})
    st.plotly_chart(treemap, use_container_width=True)
def create_sankey_race_mortality(df):
    """Render a Plotly Sankey diagram of admissions flowing from race to outcome.

    Source nodes are the races present in ``df``. Target nodes are the outcome
    labels ('Survived' for flag 0, 'Died' for flag 1). Link widths are the
    admission counts for each (race, outcome) pair.
    """
    flows = (
        df.groupby(['race', 'hospital_expire_flag'])
        .size()
        .reset_index(name='counts')
    )
    flows['Mortality'] = flows['hospital_expire_flag'].map({0: 'Survived', 1: 'Died'})

    # Node list: all races first, then the outcome labels, each in order of
    # first appearance.
    race_nodes = flows['race'].unique().tolist()
    outcome_nodes = flows['Mortality'].unique().tolist()
    node_labels = race_nodes + outcome_nodes
    node_index = dict(zip(node_labels, range(len(node_labels))))

    # Races are drawn in light salmon, outcomes in light sea green.
    node_colors = ["#FFA07A"] * len(race_nodes) + ["#20B2AA"] * len(outcome_nodes)

    node_spec = dict(
        pad=15,
        thickness=20,
        line=dict(color="black", width=0.5),
        label=node_labels,
        color=node_colors,
    )
    link_spec = dict(
        source=[node_index[race] for race in flows['race'].tolist()],
        target=[node_index[outcome] for outcome in flows['Mortality'].tolist()],
        value=flows['counts'].tolist(),
    )

    fig = go.Figure(data=[go.Sankey(node=node_spec, link=link_spec)])
    fig.update_layout(
        title_text="Sankey Diagram of Race and Mortality Outcomes",
        font_size=10,
    )
    st.plotly_chart(fig, use_container_width=True)
# ---------------------------
# Streamlit Application
# ---------------------------
APP_TITLE = "MIMIC-IV ICU Patient Data Dashboard"
APP_INTRO = """
Explore the general feature distribution and outcome metrics of ICU patients from the MIMIC-IV dataset. Utilize the sidebar filters to customize the data view and interact with various visualizations to uncover patterns and insights.
"""

# Page configuration is issued before any other Streamlit command in this script.
st.set_page_config(
    page_title=APP_TITLE,
    layout="wide",
    initial_sidebar_state="expanded",
)

# Page heading and introductory blurb.
st.title(APP_TITLE)
st.markdown(APP_INTRO)

# Sidebar heading; the filter widgets are added later by add_sidebar_filters().
st.sidebar.header("Filter Data")
# Collapse MIMIC-IV's detailed race/ethnicity strings into the broad groups
# used by the dashboard's filters and charts.
RACE_MAP = {
    "WHITE": "WHITE",
    "BLACK/AFRICAN AMERICAN": "BLACK",
    "OTHER": "OTHER",
    "UNKNOWN": "UNKNOWN",
    "HISPANIC/LATINO - PUERTO RICAN": "HISPANIC",
    "WHITE - OTHER EUROPEAN": "WHITE",
    "HISPANIC OR LATINO": "HISPANIC",
    "ASIAN": "ASIAN",
    "ASIAN - CHINESE": "ASIAN",
    "WHITE - RUSSIAN": "WHITE",
    "BLACK/CAPE VERDEAN": "BLACK",
    "HISPANIC/LATINO - DOMINICAN": "HISPANIC",
    "BLACK/CARIBBEAN ISLAND": "BLACK",
    "BLACK/AFRICAN": "BLACK",
    "PATIENT DECLINED TO ANSWER": "UNKNOWN",
    "UNABLE TO OBTAIN": "UNKNOWN",
    "PORTUGUESE": "WHITE",
    "ASIAN - SOUTH EAST ASIAN": "ASIAN",
    "HISPANIC/LATINO - GUATEMALAN": "HISPANIC",
    "ASIAN - ASIAN INDIAN": "ASIAN",
    "WHITE - EASTERN EUROPEAN": "WHITE",
    "WHITE - BRAZILIAN": "WHITE",
    "AMERICAN INDIAN/ALASKA NATIVE": "NATIVES",
    "HISPANIC/LATINO - SALVADORAN": "HISPANIC",
    "HISPANIC/LATINO - MEXICAN": "HISPANIC",
    "HISPANIC/LATINO - COLUMBIAN": "HISPANIC",
    "MULTIPLE RACE/ETHNICITY": "MULTI-ETHNIC",
    "HISPANIC/LATINO - HONDURAN": "HISPANIC",
    "ASIAN - KOREAN": "ASIAN",
    "SOUTH AMERICAN": "HISPANIC",
    "HISPANIC/LATINO - CUBAN": "HISPANIC",
    "HISPANIC/LATINO - CENTRAL AMERICAN": "HISPANIC",
    "NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER": "NATIVES",
}


@st.cache_data
def load_data():
    """Load, clean and enrich the MIMIC-IV admissions and patients tables.

    Steps:
    * Reads ``data/admissions.csv`` and ``data/patients.csv``.
    * Collapses ``race`` into broad groups via ``RACE_MAP``.
    * Left-joins patient data on ``subject_id``.
    * Drops rows missing any of ``anchor_age``, ``gender``, ``race`` or
      ``hospital_expire_flag``.
    * Parses the admit/discharge/death timestamps.

    Derived columns:
    * ``los`` -- length of stay in whole days (``dischtime - admittime``,
      the days component of the timedelta).
    * ``admission_year`` and ``admission_month`` (English month name).
    * ``admittime_date`` -- calendar date of admission.

    The result is memoised by ``st.cache_data``, so reruns reuse it.

    Returns:
        The merged, cleaned DataFrame (one row per admission).
    """
    # Load the dataframes (update the paths as necessary). Other MIMIC-IV
    # tables (diagnoses_icd, pharmacy, prescriptions, d_hcpcs, poe_detail,
    # provider) are not loaded; this dashboard only needs these two.
    admissions_df = pd.read_csv('data/admissions.csv')
    patients_df = pd.read_csv('data/patients.csv')

    raw_race = admissions_df['race']
    grouped_race = raw_race.map(RACE_MAP)
    # A race string missing from RACE_MAP would otherwise become NaN, and the
    # admission would then be silently dropped by the dropna below. Bucket
    # such values as OTHER instead; genuinely missing races stay NaN.
    admissions_df['race'] = grouped_race.where(
        grouped_race.notna() | raw_race.isna(), 'OTHER'
    )

    # Merge admissions and patients data on 'subject_id'
    merged_df = pd.merge(admissions_df, patients_df, on='subject_id', how='left')
    # Handle missing values by dropping rows with critical missing data
    merged_df = merged_df.dropna(subset=['anchor_age', 'gender', 'race', 'hospital_expire_flag'])

    # Convert datetime columns
    merged_df['admittime'] = pd.to_datetime(merged_df['admittime'])
    merged_df['dischtime'] = pd.to_datetime(merged_df['dischtime'])
    merged_df['deathtime'] = pd.to_datetime(merged_df['deathtime'], errors='coerce')  # Some may not have deathtime

    # Create derived features
    merged_df['los'] = (merged_df['dischtime'] - merged_df['admittime']).dt.days
    merged_df['admission_year'] = merged_df['admittime'].dt.year
    merged_df['admission_month'] = merged_df['admittime'].dt.month_name()
    merged_df['admittime_date'] = merged_df['admittime'].dt.date
    return merged_df


merged_df = load_data()
# Sidebar Filters Function
def apply_filters(df, admission_types, insurance_types, genders, races, year_range):
    """Return the rows of ``df`` that match every selected filter.

    Args:
        df: Admissions DataFrame with ``admission_type``, ``insurance``,
            ``gender``, ``race`` and ``admission_year`` columns.
        admission_types, insurance_types, genders, races: Allowed values for
            the corresponding column.
        year_range: ``(start, end)`` admission years, both inclusive.

    Returns:
        The filtered DataFrame (the original index is preserved).
    """
    start_year, end_year = year_range
    mask = (
        df['admission_type'].isin(admission_types)
        & df['insurance'].isin(insurance_types)
        & df['gender'].isin(genders)
        & df['race'].isin(races)
        & df['admission_year'].between(start_year, end_year)
    )
    return df[mask]


def add_sidebar_filters(df):
    """Draw the sidebar filter widgets and return ``df`` filtered by them.

    Every multiselect defaults to all available values and the year slider
    defaults to the full range, so the initial result keeps every row with
    non-missing filter values.
    """

    def options_for(column):
        # Drop NaN so sorted() never has to compare strings with floats.
        return sorted(df[column].dropna().unique())

    # Admission Types
    admission_types = options_for('admission_type')
    selected_admission_types = st.sidebar.multiselect(
        "Select Admission Type(s):",
        options=admission_types,
        default=admission_types
    )
    # Insurance Types
    insurance_types = options_for('insurance')
    selected_insurance_types = st.sidebar.multiselect(
        "Select Insurance Type(s):",
        options=insurance_types,
        default=insurance_types
    )
    # Gender
    genders = options_for('gender')
    selected_genders = st.sidebar.multiselect(
        "Select Gender(s):",
        options=genders,
        default=genders
    )
    # Race
    races = options_for('race')
    selected_races = st.sidebar.multiselect(
        "Select Race(s):",
        options=races,
        default=races
    )
    # Year Range
    min_year = int(df['admission_year'].min())
    max_year = int(df['admission_year'].max())
    if min_year < max_year:
        selected_years = st.sidebar.slider(
            "Select Admission Year Range:",
            min_value=min_year,
            max_value=max_year,
            value=(min_year, max_year)
        )
    else:
        # With a single year there is nothing to choose, and st.slider may
        # reject min_value == max_value, so skip the widget.
        selected_years = (min_year, max_year)

    return apply_filters(
        df,
        selected_admission_types,
        selected_insurance_types,
        selected_genders,
        selected_races,
        selected_years,
    )
def render_plot_grid(plots, df, num_cols=2):
    """Lay out plots in rows of ``num_cols`` Streamlit columns.

    Args:
        plots: Sequence of ``(title, plot_fn)`` pairs. Each ``plot_fn`` is
            called with ``df`` and draws into the active Streamlit column,
            under a subheader showing ``title``.
        df: DataFrame passed to every plot function.
        num_cols: Number of plots per row.
    """
    for start in range(0, len(plots), num_cols):
        row_plots = plots[start:start + num_cols]
        for column, (title, plot_fn) in zip(st.columns(num_cols), row_plots):
            with column:
                st.subheader(title)
                plot_fn(df)


filtered_df = add_sidebar_filters(merged_df)

# With no matching rows the metrics would be NaN and the charts would have
# nothing to show, so explain the situation and end this script run early.
if filtered_df.empty:
    st.warning(
        "No admissions match the selected filters. "
        "Adjust the sidebar filters to see results."
    )
    st.stop()

# Display Summary Statistics for Q1
st.header("Summary Statistics")
col1, col2, col3, col4 = st.columns(4)
with col1:
    total_admissions = filtered_df.shape[0]
    st.metric("Total Admissions", f"{total_admissions:,}")
with col2:
    average_age = filtered_df['anchor_age'].mean()
    st.metric("Average Age", f"{average_age:.2f} years")
with col3:
    # Each row is an admission. De-duplicate on subject_id so patients with
    # several admissions are counted once in the "Patients" metrics.
    gender_counts = (
        filtered_df.drop_duplicates(subset='subject_id')['gender'].value_counts()
    )
    male_count = gender_counts.get('M', 0)
    female_count = gender_counts.get('F', 0)
    st.metric("Male Patients", f"{male_count:,}")
    st.metric("Female Patients", f"{female_count:,}")
with col4:
    mortality_rate = filtered_df['hospital_expire_flag'].mean() * 100  # Percentage
    st.metric("Mortality Rate", f"{mortality_rate:.2f}%")
st.markdown("---")

# Create Tabs for Q1 and Q2
tabs = st.tabs(["General Overview", "Potential Biases"])

# ---------------------------
# Q1: General Overview
# ---------------------------
with tabs[0]:
    st.subheader("General Feature Distribution and Outcome Metrics")
    render_plot_grid(
        [
            ("Age Distribution of ICU Patients", create_histogram),
            ("Gender Distribution of ICU Patients", create_gender_bar_chart),
            ("Admission Types by Race", create_stacked_bar_admission_race),
            ("Length of Stay by Race", create_los_by_race),
            ("Correlation Heatmap of Age and LOS", create_correlation_heatmap),
            ("Admissions Over Time", create_time_series_heatmap),
        ],
        filtered_df,
    )

# ---------------------------
# Q2: Potential Biases
# ---------------------------
with tabs[1]:
    st.subheader("Analyzing Potential Biases Across Demographics")
    render_plot_grid(
        [
            ("Mortality Rate by Race", create_mortality_by_race),
            ("Mortality Rate by Gender", create_mortality_by_gender),
            ("Mortality Rate by Age Group", create_mortality_by_age_group),
            ("Age Distribution by Race and Mortality", create_violin_age_race_mortality),
            ("Heatmap: Race & Gender vs. Mortality", create_heatmap_race_gender_mortality),
            ("Parallel Coordinates Plot of Demographics and Outcomes", create_parallel_coordinates),
            ("Treemap of Race and Mortality", create_treemap_race_mortality),
            ("Sankey Diagram: Race to Mortality Outcomes", create_sankey_race_mortality),
        ],
        filtered_df,
    )

# Footer. Markdown joins consecutive lines into a single paragraph, so each
# credit line ends with two spaces (a hard line break) to keep it on its own line.
st.markdown("---")
st.markdown(
    "**Data Source:** MIMIC-IV Dataset  \n"
    "**Project:** Investigating Biases in ICU Patient Data  \n"
    "**Developed with:** Streamlit, Python"
)