import pandas as pd
import streamlit as st
from streamlit_calendar import calendar
from streamlit_timeline import st_timeline
import numpy as np
from sklearn.cluster import KMeans
import altair as alt

st.set_page_config(layout="wide")


# load data
@st.cache_data
def load_data(path="data/colon.csv"):
    """Read the patient event table and normalise its date columns.

    The result is cached by Streamlit so the CSV is not re-read and re-parsed
    on every script rerun; clear the cache (or restart the app) after
    changing the file on disk.

    Args:
        path: Location of the CSV file to load.

    Returns:
        A DataFrame with BIRTHDATE/START/STOP converted to ``datetime.date``
        values, rows lacking a DESCRIPTION or a usable START removed, sorted
        by ID (ascending), START (descending) and DESCRIPTION (ascending).
    """
    data = pd.read_csv(path)
    for column in ("BIRTHDATE", "START", "STOP"):
        data[column] = pd.to_datetime(data[column], errors="coerce").dt.date
    # Drop incomplete rows only *after* coercion: values that fail to parse
    # become NaT and would otherwise survive and break later .strftime calls.
    data = data.dropna(subset=["DESCRIPTION", "START"])
    return data.sort_values(
        by=["ID", "START", "DESCRIPTION"], ascending=[True, False, True]
    )


df = load_data()
unique_ids = df["ID"].unique()

# inject custom CSS to set the width of the sidebar
# (change the 600px value below to the desired width; the note lives here
# because "#" does not start a comment in CSS and made the rule malformed)
st.markdown(
    """
    <style>
        section[data-testid="stSidebar"] {
            width: 600px !important;
        }
    </style>
    """,
    unsafe_allow_html=True,
)

# pick id
st.sidebar.title("Patient information")
st.session_state.id = st.sidebar.selectbox(
    "Select patient ID:",
    unique_ids,
    index=0,
    placeholder="Type or select ID...",
)

# sidebar
# Filter the table once for the selected patient instead of repeating the
# same boolean mask for every displayed attribute.
patient_rows = df.loc[df["ID"] == st.session_state.id]


def _first_value(column):
    """Return the first value of *column* for the selected patient.

    Returns None when the selected patient has no rows.
    """
    values = patient_rows[column]
    return None if values.empty else values.iloc[0]


name = _first_value("NAME")
gender = _first_value("GENDER")
st.sidebar.write("Name:", name, f" ({gender})")

bd = _first_value("BIRTHDATE")
st.sidebar.write("Birthdate:", bd)

race = _first_value("RACE")
etn = _first_value("ETHNICITY")
st.sidebar.write("Race/Ethnicity:", race, " /", etn)

mar = _first_value("MARITAL")
st.sidebar.write("Marital status:", mar)

adr = _first_value("ADDRESS")
st.sidebar.write("Address:", adr)

# filter data
st.session_state.filtered_df = df[df["ID"] == st.session_state.id]

# The calendar opens on the patient's most recent event. Missing START values
# are ignored; when no usable date exists, fall back to today so that
# `initial_date` is always set and never carries over from a previous patient.
start_dates = st.session_state.filtered_df["START"].dropna()
if start_dates.empty:
    st.session_state.initial_date = pd.Timestamp.today().strftime("%Y-%m-%d")
else:
    st.session_state.initial_date = start_dates.max().strftime("%Y-%m-%d")

# One single-day calendar event per row. The list is rebuilt on every run
# (possibly empty) so events of a previously selected patient never linger.
st.session_state.events = [
    {
        "title": description,
        "color": "#3a6ad6",
        "start": start.strftime("%Y-%m-%d"),
        "end": start.strftime("%Y-%m-%d"),
    }
    for description, start in zip(
        st.session_state.filtered_df["DESCRIPTION"],
        st.session_state.filtered_df["START"],
    )
    if pd.notna(start)
]

# calendar
mode = st.sidebar.selectbox(
    "Calendar Mode:",
    (
        "daygrid",
        "list",
    ),
)

# Options shared by every calendar mode.
calendar_options = {
    "editable": "true",
    "navLinks": "true",
    "selectable": "true",
}

# View-specific settings, keyed by the value of the mode selector.
view_options = {
    "daygrid": {
        "headerToolbar": {
            "left": "today prev,next",
            "center": "title",
            "right": "dayGridDay,dayGridWeek,dayGridMonth",
        },
        "initialView": "dayGridMonth",
    },
    "list": {
        "initialView": "listMonth",
    },
}
calendar_options = {**calendar_options, **view_options.get(mode, {})}

# Only pass an initial date when one has been stored; reading the attribute
# directly raises if the key was never set in the session state.
initial_date = st.session_state.get("initial_date")
if initial_date is not None:
    calendar_options["initialDate"] = initial_date

# Render the calendar in the sidebar and keep its returned state in the
# session. The widget key follows the selected mode.
with st.sidebar:
    st.session_state.state = calendar(
        # Fall back to an empty list: the previous default expression read
        # `st.session_state.events` eagerly and raised when it was missing.
        events=st.session_state.get("events", []),
        options=calendar_options,
        custom_css="""
        .fc-event-past {
            opacity: 0.8;
        }
        .fc-event-time {
            font-style: italic;
        }
        .fc-event-title {
            font-weight: 700;
        }
        .fc-toolbar-title {
            font-size: 2rem;
        }
        .fc-button {
            background-color: #4CAF50;
            color: #ffffff;
            border: none;
            cursor: pointer;
        }
        .fc-button:hover {
            background-color: #45a049;
        }
        .fc-button-primary {
            background-color: #3a6ad6;
        }
        .fc-button-primary:hover {
            background-color: #3a6ad6;
        }
        .fc-button-secondary {
            background-color: #e7e7e7;
            color: black;
        }
        .fc-button-secondary:hover {
            background-color: #ddd;
        }
        """,
        key=mode,
    )


# When the component reports an "eventsSet" payload, store it as the current
# events in the session state.
if st.session_state.state.get("eventsSet") is not None:
    st.session_state["events"] = st.session_state.state["eventsSet"]

# clustering
def _encode_sequences(sequences):
    """Encode description sequences as a zero-padded integer matrix.

    One vocabulary is built across *all* sequences so that a given
    description receives the same integer code in every row (a per-sequence
    vocabulary would make codes incomparable between patients). Code 0 is
    reserved for the right-padding that brings every row to the length of
    the longest sequence.

    Args:
        sequences: Non-empty list of non-empty sequences of descriptions.

    Returns:
        Integer array of shape (len(sequences), longest sequence length).
    """
    vocabulary = {
        value: code
        for code, value in enumerate(
            sorted({value for sequence in sequences for value in sequence}),
            start=1,
        )
    }
    width = max(len(sequence) for sequence in sequences)
    encoded = np.zeros((len(sequences), width), dtype=int)
    for row, sequence in enumerate(sequences):
        encoded[row, : len(sequence)] = [vocabulary[value] for value in sequence]
    return encoded


col1, col2 = st.columns([1, 2])

with col1:
    # Slider colours. This is a plain (non f-) string, so CSS blocks use
    # single braces; the track rule is a solid colour, which is what the
    # former all-same-colour gradient with unfilled placeholders described.
    st.markdown(
        """
    <style>
    div.stSlider > div[data-baseweb="slider"] > div > div > div[role="slider"] {
        background-color: #3a6ad6;
        box-shadow: rgba(58, 106, 214, 0.2) 0px 0px 0px 0.2rem;
    }
    div.stSlider > div[data-baseweb="slider"] > div > div > div > div {
        color: #3a6ad6;
    }
    div.stSlider > div[data-baseweb="slider"] > div > div {
        background: #3a6ad6;
    }
    </style>
    """,
        unsafe_allow_html=True,
    )
    st.session_state.n_clusters = st.slider("Select number of clusters", 2, 5, 5)
    if st.button("Show cluster"):
        # One row per patient holding the list of that patient's descriptions.
        # A local name is used so the module-level `df` keeps all its columns.
        grouped = (
            df[["ID", "START", "STOP", "DESCRIPTION"]]
            .groupby("ID")
            .agg({"DESCRIPTION": list})
            .reset_index()
        )
        if len(grouped) < st.session_state.n_clusters:
            # KMeans cannot form more clusters than there are samples.
            st.warning("Not enough patients to form the requested number of clusters.")
        else:
            features = _encode_sequences(grouped["DESCRIPTION"].tolist())
            grouped["DESCRIPTION"] = grouped["DESCRIPTION"].apply(np.array)
            st.session_state.df = grouped
            st.session_state.kmeans = KMeans(
                n_clusters=st.session_state.n_clusters, random_state=42
            )
            st.session_state.cluster_labels = st.session_state.kmeans.fit_predict(
                features
            )

    # Results live in the session state, so they are shown on every rerun
    # once a clustering exists, and the message follows the selected patient.
    if "cluster_labels" in st.session_state and "df" in st.session_state:
        st.session_state.idx = st.session_state.df.index[
            st.session_state.df["ID"] == st.session_state.id
        ]
        if len(st.session_state.idx) > 0:
            st.write(
                "This patient belongs to cluster:",
                st.session_state.cluster_labels[st.session_state.idx][0],
            )

        # Number of patients assigned to each cluster label.
        st.session_state.label_counts = (
            pd.Series(st.session_state.cluster_labels).value_counts().sort_index()
        )
        st.session_state.cluster_df = pd.DataFrame(
            {
                "Cluster Label": st.session_state.label_counts.index,
                "Count": st.session_state.label_counts.values,
            }
        )
        chart = (
            alt.Chart(st.session_state.cluster_df)
            .mark_bar()
            .encode(x="Cluster Label:O", y="Count:Q")
            .properties(title="Number of people per cluster")
            .configure_legend(disable=True)  # Disable the legend
        )
        st.altair_chart(chart, use_container_width=True)

with col2:
    # Nothing to show until a clustering run has stored its labels and the
    # per-patient table in the session state. An explicit check replaces the
    # former blanket try/except so that genuine errors are no longer hidden.
    if "cluster_labels" in st.session_state and "df" in st.session_state:
        st.session_state.selected_cluster = st.selectbox(
            "Select cluster to view descriptions",
            np.unique(st.session_state.cluster_labels),
            0,
        )
        # Positions of the patients whose label equals the chosen cluster.
        st.session_state.indices = np.where(
            st.session_state.cluster_labels == st.session_state.selected_cluster
        )[0]
        st.session_state.seq_df = st.session_state.df.loc[st.session_state.indices]
        st.write(f"Descriptions for cluster {st.session_state.selected_cluster}:")
        st.dataframe(
            st.session_state.seq_df["DESCRIPTION"],
            use_container_width=True,
        )

# timeline
# Build one timeline item per event of the selected patient. The list is
# rebuilt on every run (possibly empty) so it is always defined and never
# shows items from a previously selected patient. Rows without a START are
# skipped because they cannot be formatted as a date.
timeline_rows = st.session_state.filtered_df.dropna(subset=["START"])
st.session_state.item = [
    {
        "id": item_id,
        "content": description,
        "start": start.strftime("%Y-%m-%d"),
    }
    for item_id, (description, start) in enumerate(
        zip(timeline_rows["DESCRIPTION"], timeline_rows["START"])
    )
]

st.session_state.timeline = st_timeline(
    st.session_state.item, groups=[], options={}, height="300px", width="100%"
)