import streamlit as st
from utils import validate_sequence, predict, plot_prediction_graphs
from model import models
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def _default_names(count, start=1):
    """Return placeholder labels ``"Seq <start>"``, ``"Seq <start+1>"``, ... for *count* items."""
    return [f"Seq {i}" for i in range(start, start + count)]


def _collect_sequences(sequence, uploaded_file):
    """Gather parallel ``(names, sequences)`` lists from the sidebar inputs.

    The typed sequence (if any) comes first and is labelled ``"Seq 1"``; rows
    from the uploaded CSV follow.  The two returned lists always have the same
    length, so zipping them pairs every sequence with its own label.

    Args:
        sequence: Text typed into the sidebar input ("" when empty).
        uploaded_file: File object from ``st.file_uploader`` or ``None``.

    Returns:
        Tuple ``(names, sequences)`` of equal-length lists.

    Raises:
        ValueError: if the CSV has no ``sequence`` column or cannot be parsed
            (pandas' ``EmptyDataError``/``ParserError`` subclass ValueError).
    """
    names = []
    sequences = []
    if sequence:
        names.append("Seq 1")
        sequences.append(sequence)

    if uploaded_file is not None:
        df = pd.read_csv(uploaded_file)
        if "sequence" not in df.columns:
            raise ValueError("CSV must contain a 'sequence' column")
        # Skip blank rows: a NaN cell would otherwise reach validation as a float.
        df = df.dropna(subset=["sequence"])
        csv_sequences = df["sequence"].tolist()
        if "name" in df.columns:
            csv_names = df["name"].tolist()
        else:
            # No name column: continue the default numbering after any typed sequence.
            csv_names = _default_names(len(csv_sequences), start=len(names) + 1)
        names.extend(csv_names)
        sequences.extend(csv_sequences)

    return names, sequences


def _unique_label(name, taken):
    """Return *name*, or *name* suffixed " (2)", " (3)", ... if already in *taken*."""
    label = name
    suffix = 2
    while label in taken:
        label = f"{name} ({suffix})"
        suffix += 1
    return label


def main():
    """Render the amino-acid property inference Streamlit page.

    Sequences come from the sidebar: one typed sequence and/or a CSV upload
    with a ``sequence`` column and an optional ``name`` column.  When
    "Analyze Sequence" is pressed, every model in ``models`` is run on each
    sequence that passes ``validate_sequence``.  The per-model predictions
    and confidences are shown as a table, optionally followed by graphs.
    """
    st.set_page_config(layout="wide")  # Keep the wide layout for overall flexibility
    st.title("AA Property Inference Demo", anchor=None)

    # Instructional text below title
    st.markdown("""
        <style>
        .reportview-container {
            font-family: 'Courier New', monospace;
        }
        </style>
        <p style='font-size:16px;'><span style='font-size:24px;'>&larr;</span> Don't know where to start? Open tab to input a sequence.</p>
        """, unsafe_allow_html=True)

    # Input section in the sidebar
    sequence = st.sidebar.text_input("Enter your amino acid sequence:")
    uploaded_file = st.sidebar.file_uploader("Or upload a CSV file with amino acid sequences", type="csv")
    analyze_pressed = st.sidebar.button("Analyze Sequence")
    show_graphs = st.sidebar.checkbox("Show Prediction Graphs")

    try:
        names, sequences = _collect_sequences(sequence, uploaded_file)
    except ValueError as err:
        # Report a malformed upload instead of crashing the whole page.
        st.sidebar.error(f"Could not read the uploaded CSV: {err}")
        return

    if not analyze_pressed:
        return

    results = []
    all_data = {}  # unique label -> {model_name: (prediction, confidence)}
    for name, seq in zip(names, sequences):
        if not validate_sequence(seq):
            st.sidebar.error(f"Invalid sequence for {name}: {seq}")
            continue

        # Duplicate names would otherwise overwrite earlier graph data.
        label = _unique_label(name, all_data)
        model_results = {}
        graph_data = {}
        for model_name, model in models.items():
            prediction, confidence = predict(model, seq)
            model_results[f"{model_name}_prediction"] = prediction
            model_results[f"{model_name}_confidence"] = round(confidence, 3)
            graph_data[model_name] = (prediction, confidence)
        results.append({"Name": label, "Sequence": seq, **model_results})
        all_data[label] = graph_data

    if not results:
        return

    results_df = pd.DataFrame(results)
    st.write("### Results")
    st.dataframe(results_df.style.format(precision=3), width=None, height=None)

    # all_data is non-empty whenever results is, so no extra check is needed.
    if show_graphs:
        st.write("## Graphs")
        plot_prediction_graphs(all_data, models.keys())


# Entry point: build the page when this file is executed as the main module.
if __name__ == "__main__":
    main()