# Basic app for generations 5 & 6 and middle management.
# The data is also pre-filtered to chain scale 4 (upper upscale).

import pickle
import pandas as pd
import shap
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import gradio.themes as gt #Here I imported a pre-set gradio theme. The one that I used is called "soft"


# Load the trained model used for predictions (presumably XGBoost, per the
# file name). The context manager guarantees the file handle is closed.
# NOTE: pickle can execute arbitrary code while loading -- only deploy a
# trusted "h47_xgb.pkl" next to this app.
with open("h47_xgb.pkl", "rb") as model_file:
    loaded_model = pickle.load(model_file)

# Build the SHAP explainer once at start-up so every prediction reuses it.
explainer = shap.Explainer(loaded_model)

# Pre-set employee profiles (Mr.Bean, Tom Hanks, and a neutral default).
# Each list holds six slider values in the order
# Engage2, Voice, Merit, Workload, WellBeing, SupportiveGM.
employee_selection = dict([
    ("Mr.Bean - At Risk/Medium Risk 🟥⚠️", [4.2, 3.6, 3.4, 3.5, 3.7, 3.9]),
    ("Tom Hanks - Happy 🟢", [5.0, 4.8, 4.7, 4.8, 4.9, 4.9]),
    ("Default", [3] * 6),
])

# Prediction + explanation callback wired to the "Predict" button.
def main_func(Engage2, Voice, Merit, Workload, WellBeing, SupportiveGM,
              ChainScale=4, ManagementLevel=2):
    """Predict stay/leave probabilities for one employee and explain them with SHAP.

    The six survey scores come from the UI sliders. ``ChainScale`` and
    ``ManagementLevel`` default to 4 and 2 and are not exposed in the UI.

    Returns:
        A ``(label_dict, figure)`` pair. ``label_dict`` maps the "Leave" key to
        the model's class-0 probability and the "Stay" key to its complement.
        ``figure`` is a matplotlib SHAP bar plot restricted to the six
        slider features.
    """
    # NOTE(review): column order presumably mirrors the training data -- keep it.
    new_row = pd.DataFrame.from_dict({
        'ManagementLevel': ManagementLevel, 'Engage2': Engage2, 'Voice': Voice,
        'Merit': Merit,
        'Workload': Workload, 'ChainScale': ChainScale, 'WellBeing': WellBeing,
        'SupportiveGM': SupportiveGM
    }, orient='index').transpose()

    prob = loaded_model.predict_proba(new_row)
    # NOTE(review): assumes class 0 of the model means "leave" -- confirm
    # against the training labels.
    leave_prob = float(prob[0][0])

    # Explain only the six user-adjustable features. ChainScale and
    # ManagementLevel are fixed in this app; previously they were plotted too
    # and could push a slider feature into the "other features" bar.
    shap_values = explainer(new_row)
    selected_features = ["Engage2", "Voice", "Merit", "Workload", "WellBeing", "SupportiveGM"]
    shap_values_filtered = shap_values[:, selected_features]

    # Generate SHAP bar plot for the single explained row.
    plt.figure(figsize=(6, 4))
    shap.plots.bar(shap_values_filtered[0], max_display=len(selected_features), show=False)

    plt.tight_layout()
    local_plot = plt.gcf()
    # Detach the figure from pyplot so repeated predictions don't accumulate
    # open figures; the Figure object itself is still returned.
    plt.close(local_plot)

    return {"Leave ❌": leave_prob, "Stay ✅ ": 1 - leave_prob}, local_plot

# Updates the sliders so that they show the values of each of the profiles
def update_sliders(profile):
    if profile in employee_selection:
        return employee_selection [profile]
    return [3,3,3,3,3,3]

# Text shown at the top of the UI.
title = "Hilton Employee Turnover Predictor & Interpreter 🏨"
description1 = (
    "\n"
    "This app predicts whether a Millennial/Generation Z employee in upper upscale hotels will stay or leave based on the top six important variables impacting intent to stay.\n"
)
description2 = (
    "\n"
    "Choose from the pre-set employee categories, or adjust the values to identify who will stay ✅ or leave ❌!\n"
)

with gr.Blocks(theme=gt.Soft()) as demo:
    # Header: title, summary, and usage hint separated by horizontal rules.
    for header_text in (f"## {title}", description1, """---""", description2, """---"""):
        gr.Markdown(header_text)

    with gr.Row():
        # Left column: profile picker, the six feature sliders, and the button.
        with gr.Column():
            profile_dropdown = gr.Dropdown(
                choices=["Default", "Tom Hanks - Happy 🟢", "Mr.Bean - At Risk/Medium Risk 🟥⚠️"],
                label="Select a profile to learn more!",
            )

            # Every survey feature shares the same 1-5 scale and starting value.
            slider_opts = dict(minimum=1, maximum=5, value=4, step=0.1)
            Engage2 = gr.Slider(label="Engagement (Engage2)", **slider_opts)
            Voice = gr.Slider(label="Voice", **slider_opts)
            Merit = gr.Slider(label="Merit", **slider_opts)
            Workload = gr.Slider(label="Workload", **slider_opts)
            WellBeing = gr.Slider(label="Well-being", **slider_opts)
            SupportiveGM = gr.Slider(label="Supportive GM", **slider_opts)

            submit_btn = gr.Button("Predict 🔍")

        # Right column: predicted probabilities and the SHAP explanation plot.
        with gr.Column(visible=True, scale=1, min_width=600) as output_col:
            label = gr.Label(label="Predicted Label")
            local_plot = gr.Plot(label="SHAP Analysis")

    # This order must match main_func's parameters and the value order of
    # each employee_selection profile returned by update_sliders.
    feature_sliders = [Engage2, Voice, Merit, Workload, WellBeing, SupportiveGM]
    submit_btn.click(main_func, feature_sliders, [label, local_plot])
    profile_dropdown.change(update_sliders, inputs=[profile_dropdown], outputs=feature_sliders)

demo.launch()