import re
import streamlit as st
import requests
import pandas as pd
from io import StringIO
import plotly.graph_objs as go
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError

#from yall import create_yall

def place_holder_dataframe():
    """Return a small hard-coded leaderboard used as placeholder data.

    Every link/name column is derived from the model's Hub id, so the fields of
    one row always refer to the same repository.

    Returns:
        pandas.DataFrame: one row per model with the columns ``gist_id``,
        ``filename``, ``url``, ``model_name``, ``model_id``, ``Model``,
        ``Elo`` and ``Undetected rate``.
    """
    # (Hub model id, Elo, Undetected rate) -- placeholder values.
    entries = [
        ("mistralai/Mistral-7B-Instruct-v0.3", 1200, 0.27),
        ("mistralai/Mixtral-8x22B-Instruct-v0.1", 1950, 0.63),
        ("mistralai/Mixtral-8x7B-Instruct-v0.1", 1467, 0.41),
    ]

    rows = []
    for model_id, elo, undetected_rate in entries:
        # Repo name without the organisation prefix, e.g. "Mistral-7B-Instruct-v0.3".
        model_name = model_id.split("/", 1)[-1]
        url = f"https://huggingface.co/{model_id}"
        rows.append({
            "gist_id": model_id,
            "filename": f"{url}/blob/main/README.md",
            "url": url,
            "model_name": model_name,
            "model_id": model_id,
            "Model": model_name,
            "Elo": elo,
            "Undetected rate": undetected_rate,
        })

    return pd.DataFrame(rows)


def convert_markdown_table_to_dataframe(md_content):
    """Convert a pipe-delimited markdown table into a pandas DataFrame.

    ``Model`` cells of the form ``[name](url) [label](other_url)`` are split:
    the first link's target goes into a new ``URL`` column and the cell keeps
    only the first link's text. Cells that do not match (or are not strings)
    are left unchanged and get a missing ``URL``.

    Args:
        md_content: Markdown source of a table whose header has a ``Model`` column.

    Returns:
        pandas.DataFrame: the table body (header/body separator row removed),
        with whitespace-stripped column names plus a ``URL`` column. Cell
        values themselves are not stripped.
    """
    # Drop the leading and trailing "|" of every line so read_csv does not
    # create empty first/last columns.
    cleaned_content = re.sub(r'^\|\s*', '', md_content, flags=re.MULTILINE)
    cleaned_content = re.sub(r'\|\s*$', '', cleaned_content, flags=re.MULTILINE)

    # With the python engine a multi-character separator is a regex, so the
    # pipe is escaped; the raw string avoids an invalid "\|" escape sequence.
    df = pd.read_csv(StringIO(cleaned_content), sep=r"\|", engine='python')

    # Row 0 is the markdown "---|---" separator. errors='ignore' keeps a
    # header-only table (no rows at all) from raising KeyError.
    df = df.drop(index=0, errors='ignore')

    df.columns = df.columns.str.strip()

    # Compiled once; each cell is searched a single time.
    model_link_pattern = re.compile(r'\[(.*?)\]\((.*?)\)\s*\[.*?\]\(.*?\)')

    def _extract_url(cell):
        # Target of the first link, or None when the cell has no link pair.
        if not isinstance(cell, str):
            return None
        match = model_link_pattern.search(cell)
        return match.group(2) if match else None

    def _link_text(cell):
        # Replace the link pair by the first link's text; non-strings pass through.
        if not isinstance(cell, str):
            return cell
        return model_link_pattern.sub(r'\1', cell)

    df['URL'] = df['Model'].apply(_extract_url)
    df['Model'] = df['Model'].apply(_link_text)

    return df

@st.cache_data
def get_model_info(df):
    """Add Hugging Face Hub ``Likes`` and ``Tags`` columns for every model.

    Each ``Model`` value (whitespace-stripped) is looked up as a Hub repo id
    with ``HfApi.model_info``. Models whose repository or revision is not found
    get ``Likes == -1`` and ``Tags == ''``; any other exception propagates.

    Args:
        df: DataFrame with a ``Model`` column holding Hub repo ids.

    Returns:
        pandas.DataFrame: a copy of ``df`` with ``Likes`` (like counts) and
        ``Tags`` (comma-separated tag string) columns added. The caller's
        DataFrame is not modified.
    """
    api = HfApi()
    # Work on a copy so the caller's DataFrame is never mutated.
    df = df.copy()

    likes, tags = [], []
    for model in df['Model']:
        repo_id = str(model).strip()
        try:
            model_info = api.model_info(repo_id=repo_id)
        except (RepositoryNotFoundError, RevisionNotFoundError):
            # Sentinel values for models that cannot be found on the Hub.
            likes.append(-1)
            tags.append('')
            continue
        likes.append(model_info.likes)
        # ModelInfo.tags is optional; treat a missing list as "no tags".
        tags.append(', '.join(model_info.tags or []))

    # Assign whole columns positionally (not per-label .loc writes) so
    # duplicate index labels cannot make rows overwrite each other.
    df['Likes'] = likes
    df['Tags'] = tags
    return df



def create_bar_chart(df, category, colorscale='Inferno', row_height=35):
    """Render a heading and a horizontal bar chart of ``category`` per model.

    Args:
        df: DataFrame with a ``Model`` column and a numeric ``category`` column.
        category: Name of the score column to plot; also used in the heading.
        colorscale: Plotly colorscale name used for the bar colour gradient.
        row_height: Pixels allotted per bar; the figure height grows with the
            number of rows in ``df``.
    """
    st.write(f"### {category} Scores")

    # Ascending order: a horizontal bar chart draws the first category at the
    # bottom, so the highest score ends up at the top.
    sorted_df = df[['Model', category]].sort_values(by=category, ascending=True)

    # Bar colour follows the score through the chosen colorscale.
    fig = go.Figure(go.Bar(
        x=sorted_df[category],
        y=sorted_df['Model'],
        orientation='h',
        marker=dict(color=sorted_df[category], colorscale=colorscale),
    ))

    margin = dict(l=20, r=20, t=20, b=20)
    fig.update_layout(
        margin=margin,
        # Size the figure from the row count (at least one row) plus the
        # vertical margins, instead of a fixed value passed to st.plotly_chart.
        height=row_height * max(len(sorted_df), 1) + margin['t'] + margin['b'],
    )

    st.plotly_chart(fig, use_container_width=True)

# Example usage:
# create_bar_chart(your_dataframe, 'Your_Category')


def _render_leaderboard(df):
    """Render the leaderboard tab: search, main table, comparison, CSV export, charts."""
    df = df.sort_values(by='Elo', ascending=False)

    # Filter first so the main table, the comparison picker, the export and
    # the charts all reflect the search query.
    search_query = st.text_input("Search models", "")
    if search_query:
        # regex=False: match user input literally, so characters like "(" or
        # "+" cannot raise a regex error.
        df = df[df['Model'].str.contains(search_query, case=False, regex=False, na=False)]

    st.dataframe(
        df[['Model', 'Elo', 'url', 'Undetected rate']],
        use_container_width=True,
        column_config={
            "url": st.column_config.LinkColumn("url"),
        },
        hide_index=True,
    )

    # Comparison between models
    selected_models = st.multiselect('Select models to compare', df['Model'].unique())
    comparison_df = df[df['Model'].isin(selected_models)]
    st.dataframe(
        comparison_df,
        use_container_width=True,
        column_config={
            "url": st.column_config.LinkColumn("url"),
        },
        hide_index=True,
    )

    # A single download button. Nesting it under st.button would require two
    # clicks and hide it again on the rerun its own click triggers.
    st.download_button(
        label="Download CSV",
        data=df.to_csv(index=False),
        file_name="leaderboard.csv",
        mime="text/csv",
        key="download-csv",
        help="Click to download the CSV file",
    )

    # Full-width plot for the first category
    create_bar_chart(df, "Elo")

    # Second plot in the left half of a two-column row
    col1, _ = st.columns(2)
    with col1:
        create_bar_chart(df, "Undetected rate")


def main():
    """Build the Streamlit page: a leaderboard tab and an About tab."""
    st.set_page_config(page_title="LLM Roleplay Leaderboard", layout="wide")

    st.title("🏆🎭 LLM Roleplay Leaderboard")
    st.markdown("LLM Roleplay Leaderboard that uses scores from the matou garou roleplay game 🏠🐈‍.")
    #content = create_yall()
    tab1, tab2 = st.tabs(["🏆🎭 Leaderboard", "📝 About"])

    df = place_holder_dataframe()
    with tab1:
        if len(df) > 0:
            try:
                _render_leaderboard(df)
            except Exception as e:
                # Top-level UI boundary: show the failure instead of a blank tab.
                st.error("An error occurred while rendering the leaderboard.")
                st.error(str(e))
        else:
            st.error("No leaderboard data is available.")

    # About tab
    with tab2:
        st.markdown('''
            ### Roleplay Leaderboard

        This space is here to present the results from the Matou-Garou space, where human and AI play a game of werewolf.
        
        It is meant as a social experience to see if you would be able to detect if talking to an AI.
        We also hope that this leaderboard can be used by video game creator in the future to select what model to select for LLM based NPCs
                
           Popularized by [Teknium](https://huggingface.co/teknium) and [NousResearch](https://huggingface.co/NousResearch), this benchmark suite aggregates four benchmarks
           Leaderboard copied from [Maxime Labonne](https://huggingface.co/mlabonne)
        ''')
        
# Build the page only when this file is executed as the main script.
if __name__ == "__main__":
    main()