import torch
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.feature_extraction.text import CountVectorizer
import seaborn as sns
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
from run import run_pipeline

# Page configuration: use Streamlit's "wide" layout so the two-column chart
# grids rendered below can span the full browser width. Streamlit has
# historically required set_page_config to be the first st.* call in a
# script, so keep it ahead of any other Streamlit output.
st.set_page_config(layout="wide")

# Function to load and clean data
def load_and_clean_data():
    """Load the four labelled CSV datasets and return one cleaned DataFrame.

    Reads the English and Tamil social-media/news CSVs from ``data/``,
    stacks them, normalises two ``Domain`` spellings ("MUSLIM" ->
    "Muslim", "Other-Ethnic" -> "Other-Ethnicity"), treats the literal
    placeholder strings "nan" and "None" (plus "No" for Sentiment and
    Discrimination) as missing, and drops every row that lacks a Domain,
    Sentiment or Discrimination value.

    Returns:
        pandas.DataFrame with a contiguous 0..n-1 index.
    """
    csv_paths = [
        "data/reviewed_social_media_english.csv",
        "data/reviewed_news_english.csv",
        "data/tamil_social_media.csv",
        "data/tamil_news.csv",
    ]
    # ignore_index=True: each CSV starts its own 0-based index, so a plain
    # concat would give duplicate row labels across the four sources.
    df_combined = pd.concat([pd.read_csv(path) for path in csv_paths], ignore_index=True)

    # Map placeholder strings to pd.NA (so dropna() removes them) and fix
    # inconsistent Domain spellings.
    df_combined['Domain'] = df_combined['Domain'].replace({"MUSLIM": "Muslim", "nan": pd.NA, "None": pd.NA, "Other-Ethnic": "Other-Ethnicity"})

    # "No" is not a usable Sentiment/Discrimination label, so it is treated
    # as missing alongside the placeholder strings.
    label_placeholders = {"nan": pd.NA, "None": pd.NA, "No": pd.NA}
    df_combined['Sentiment'] = df_combined['Sentiment'].replace(label_placeholders)
    df_combined['Discrimination'] = df_combined['Discrimination'].replace(label_placeholders)

    # Drop rows with NA values in 'Domain', 'Sentiment', and 'Discrimination'
    df_combined = df_combined.dropna(subset=['Domain', 'Sentiment', 'Discrimination'])

    # dropna leaves gaps in the index; renumber for a clean 0..n-1 index.
    return df_combined.reset_index(drop=True)

df = load_and_clean_data()

# Page navigation setup. The first entry must match the page name that
# render_dashboard checks for ("...Conversations..."); a mismatch leaves the
# default page blank.
page_names = ["Dashboard for GESI Conversations in Sri Lanka", "GESI Overview", "Sentiment Analysis", "Discrimination Analysis", "Channel Analysis"]
page = st.sidebar.selectbox("Choose a page", page_names)

# Sidebar filters: every distinct non-missing value is offered and selected
# by default.
domain_options = df['Domain'].dropna().unique()
channel_options = df['Channel'].dropna().unique()
sentiment_options = df['Sentiment'].dropna().unique()
discrimination_options = df['Discrimination'].dropna().unique()

domain_filter = st.sidebar.multiselect('Select Domain', options=domain_options, default=domain_options)
channel_filter = st.sidebar.multiselect('Select Channel', options=channel_options, default=channel_options)
sentiment_filter = st.sidebar.multiselect('Select Sentiment', options=sentiment_options, default=sentiment_options)
discrimination_filter = st.sidebar.multiselect('Select Discrimination', options=discrimination_options, default=discrimination_options)

# Keep only rows whose four categorical fields are all among the selected
# values. Because the options exclude missing values, rows with a missing
# Channel never pass this filter.
df_filtered = df[
    df['Domain'].isin(domain_filter)
    & df['Channel'].isin(channel_filter)
    & df['Sentiment'].isin(sentiment_filter)
    & df['Discrimination'].isin(discrimination_filter)
]

# Sequential Viridis color list from plotly express. Only create_pie_chart
# applies it (as the slice colors); the other charts use plotly defaults or
# their own explicit color maps.
color_palette = px.colors.sequential.Viridis

# Function to render the model prediction visualization page
def render_prediction_page():
    """Render the free-text analysis page.

    The user enters text and presses the button. The text is passed to
    ``run_pipeline``, and the returned domain and discrimination labels are
    shown side by side. Each score appears as a progress bar and as a bold
    two-decimal number. Empty input is rejected with a warning instead of
    being sent to the model.
    """
    st.title("Dashboard for GESI Conversations in Sri Lanka")
    st.write("""
    Instant Analysis: Enter any text snippet and get immediate predictions from our model trained on English, Sinhala, and Tamil languages.\n\n
    Domain Identification: Discover the subject matter of your text with a quantifiable domain score.
    """)

    # User input text area
    user_input = st.text_area("Enter Text/Content here to analyze", height=150)

    if not st.button("Perform Contextual Analysis"):
        return

    # Don't run the model on empty / whitespace-only text.
    if not user_input.strip():
        st.warning("Please enter some text to analyze.")
        return

    predictions = run_pipeline(user_input)

    # Missing keys fall back to "Unknown" labels and zero scores.
    domain_label = predictions.get("domain_label", "Unknown")
    discrimination_label = predictions.get("discrimination_label", "Unknown")
    # float() normalises numeric score types (e.g. numpy/tensor scalars)
    # before formatting. st.progress only accepts floats in [0.0, 1.0], so
    # the bar value is clamped to avoid an exception.
    # NOTE(review): assumes scores are probabilities in [0, 1]; confirm
    # against run_pipeline.
    domain_score = float(predictions.get("domain_score", 0))
    discrimination_score = float(predictions.get("discrimination_score", 0))

    # Visualization layout
    col1, col2 = st.columns(2)

    with col1:
        st.markdown("#### Domain Label")
        st.markdown(f"## {domain_label}")
        st.progress(min(max(domain_score, 0.0), 1.0))

    with col2:
        st.markdown("#### Discrimination Label")
        st.markdown(f"## {discrimination_label}")
        st.progress(min(max(discrimination_score, 0.0), 1.0))

    col3, col4 = st.columns(2)

    with col3:
        # Plain markdown bold; no raw HTML is involved.
        st.markdown(f'**Domain Score: {domain_score:.2f}**')

    with col4:
        st.markdown(f'**Discrimination Score: {discrimination_score:.2f}**')

# Generic donut chart of value counts for one column (used for the Domain distribution)
def create_pie_chart(df, column, title):
    """Build a donut chart showing how often each value of ``column`` occurs.

    Args:
        df: DataFrame holding the rows to count.
        column: name of the column whose distinct values become the slices.
        title: chart title.

    Returns:
        plotly Figure with a 35% hole and compact margins. The module-level
        ``color_palette`` is applied to the slice markers.
    """
    chart = px.pie(df, names=column, title=title, hole=0.35)
    layout_options = {
        "margin": {"l": 20, "r": 20, "t": 30, "b": 20},
        "legend": {"x": 0.1, "y": 1},
        "font": {"size": 12},
    }
    chart.update_layout(**layout_options)
    chart.update_traces(marker={"colors": color_palette})
    return chart

# Visualization for Distribution of Gender versus Ethnicity
def create_gender_ethnicity_distribution_chart(df):
    """Donut chart splitting rows into gender vs. ethnicity domains.

    Rows whose ``Domain`` is "Women" or "LGBTQIA+" are labelled
    "Gender: Women & LGBTQIA+"; every other row is labelled "Ethnicity".

    The grouping column is added to a copy (``DataFrame.assign``). This
    leaves the caller's DataFrame, typically a filtered slice, unmodified.

    Returns:
        plotly Figure.
    """
    is_gender = df['Domain'].isin(["Women", "LGBTQIA+"])
    plot_df = df.assign(
        GenderOrEthnicity=is_gender.map({True: "Gender: Women & LGBTQIA+", False: "Ethnicity"})
    )
    fig = px.pie(plot_df, names='GenderOrEthnicity', title='Distribution of Gender versus Ethnicity', hole=0.35)
    fig.update_layout(margin=dict(l=20, r=20, t=30, b=20), legend=dict(x=0.1, y=1), font=dict(size=12))
    return fig

# Visualization for Sentiment Distribution Across Domains
def create_sentiment_distribution_chart(df):
    """Stacked bar chart of row counts per Domain, split by Sentiment.

    Rows are counted per (Domain, Sentiment) pair and ordered by ascending
    count before plotting. Negative, Positive and Neutral are colored red,
    blue and light blue.

    Returns:
        plotly Figure.
    """
    sentiment_colors = {'Negative': 'red', 'Positive': 'blue', 'Neutral': 'lightblue'}
    pair_counts = (
        df.groupby(['Domain', 'Sentiment'])
        .size()
        .reset_index(name='counts')
        .sort_values('counts')
    )
    figure = px.bar(
        pair_counts,
        x='Domain',
        y='counts',
        color='Sentiment',
        color_discrete_map=sentiment_colors,
        title="Sentiment Distribution Across Domains",
        barmode='stack',
    )
    figure.update_layout(
        margin={'l': 20, 'r': 20, 't': 50, 'b': 20},
        xaxis_title="Domain",
        yaxis_title="Counts",
        font={'size': 10},
    )
    return figure

# Visualization of Sentiment vs. Discrimination label counts (grouped bars)
def create_sentiment_discrimination_grouped_chart(df):
    """Grouped bar chart of row counts for each Sentiment / Discrimination pair.

    Returns:
        A plotly Figure, or the string
        "No data to display for the selected filters." when there is nothing
        to plot. Callers check ``isinstance(result, str)``.
    """
    no_data_message = "No data to display for the selected filters."
    if df.empty:
        return no_data_message

    # Creating a crosstab of 'Sentiment' and 'Discrimination'
    crosstab_df = pd.crosstab(df['Sentiment'], df['Discrimination'])

    # The non-discriminative label appears in this file both as
    # 'Non Discriminative' and 'Non-Discriminative'; accept either spelling.
    # Labels removed by the sidebar filter are absent from the crosstab and
    # are simply skipped.
    known_labels = ['Discriminative', 'Non Discriminative', 'Non-Discriminative']
    value_vars = crosstab_df.columns.intersection(known_labels).tolist()
    if not value_vars:
        return no_data_message

    melted_df = pd.melt(crosstab_df.reset_index(), id_vars='Sentiment', value_vars=value_vars,
                        var_name='Discrimination', value_name='Count')

    fig = px.bar(melted_df, x='Sentiment', y='Count', color='Discrimination', barmode='group', title="Sentiment vs. Discrimination")
    fig.update_layout(margin=dict(l=20, r=20, t=50, b=20), xaxis_title="Sentiment", yaxis_title="Count", font=dict(size=10))
    return fig

# Function for Top Domains with Negative Sentiment Chart
def create_top_negative_sentiment_domains_chart(df):
    """Horizontal bar chart of the three domains with the most Negative rows.

    Returns:
        plotly Figure with at most three bars, one per domain.
    """
    domain_counts = df.groupby(['Domain', 'Sentiment']).size().unstack(fill_value=0)

    # Select 'Negative' by name. Column order follows the sorted Sentiment
    # labels, so position 0 is not guaranteed to be 'Negative'. If the
    # sidebar filter removed every Negative row, the column is absent; fall
    # back to zero counts instead of raising KeyError.
    if 'Negative' in domain_counts.columns:
        negative_counts = domain_counts['Negative']
    else:
        negative_counts = pd.Series(0, index=domain_counts.index)

    top_domains = (
        negative_counts.sort_values(ascending=False)
        .head(3)
        .rename('Count')
        .rename_axis('Domain')
        .reset_index()
    )

    colors = ['limegreen', 'crimson', 'darkcyan']
    fig = px.bar(top_domains, x='Count', y='Domain', title='Top Domains with Negative Sentiment', color='Domain',
                 orientation='h', color_discrete_sequence=colors)
    fig.update_layout(margin=dict(l=20, r=20, t=50, b=20), xaxis_title="Negative Sentiment Content Count", yaxis_title="Domain", font=dict(size=10))
    return fig

# Function for Key Phrases in Negative Sentiment Content Chart
def create_key_phrases_negative_sentiment_chart(df):
    """Bar chart of the ten most frequent trigrams in Negative-sentiment Content.

    Trigrams are counted with CountVectorizer using English stop-word
    removal. Ties are broken by reverse-alphabetical trigram. If there is no
    Negative content, or no trigram survives stop-word removal, an empty
    chart with the same title is returned instead of raising.

    Returns:
        plotly Figure.
    """
    import numpy as np  # local import: numpy is not imported at module level

    title = 'Key Phrases in Negative Sentiment Content'
    # CountVectorizer rejects NaN documents, so drop missing Content first.
    texts = df.loc[df['Sentiment'] == 'Negative', 'Content'].dropna().astype(str)

    freq_pairs = []
    if not texts.empty:
        cv = CountVectorizer(ngram_range=(3, 3), stop_words='english')
        try:
            trigrams = cv.fit_transform(texts)
        except ValueError:
            # Raised ("empty vocabulary") when no document yields a trigram
            # after stop-word removal; fall through to an empty chart.
            pass
        else:
            # Column sums on the sparse matrix; avoids densifying docs x vocab.
            count_values = np.asarray(trigrams.sum(axis=0)).ravel()
            freq_pairs = sorted(((count_values[i], k) for k, i in cv.vocabulary_.items()), reverse=True)

    if freq_pairs:
        ngram_freq = pd.DataFrame(freq_pairs[:10], columns=['frequency', 'ngram'])
        fig = px.bar(ngram_freq, x='frequency', y='ngram', orientation='h', title=title)
    else:
        fig = go.Figure()
        fig.update_layout(title=title)
    fig.update_layout(margin=dict(l=20, r=20, t=50, b=20), xaxis_title="Frequency", yaxis_title="Trigram", font=dict(size=10))
    return fig

# Function for Key Phrases in Positive Sentiment Content Chart
def create_key_phrases_positive_sentiment_chart(df):
    """Bar chart of the ten most frequent trigrams in Positive-sentiment Content.

    Trigrams are counted with CountVectorizer using English stop-word
    removal. Ties are broken by reverse-alphabetical trigram. If there is no
    Positive content, or no trigram survives stop-word removal, an empty
    chart with the same title is returned instead of raising.

    Returns:
        plotly Figure.
    """
    import numpy as np  # local import: numpy is not imported at module level

    title = 'Key Phrases in Positive Sentiment Content'
    # Filter for positive sentiment and drop rows with NaN in 'Content'
    # (CountVectorizer rejects NaN documents).
    texts = df.loc[df['Sentiment'] == 'Positive', 'Content'].dropna().astype(str)

    freq_pairs = []
    if not texts.empty:
        cv = CountVectorizer(ngram_range=(3, 3), stop_words='english')
        try:
            trigrams = cv.fit_transform(texts)
        except ValueError:
            # Raised ("empty vocabulary") when no document yields a trigram
            # after stop-word removal; fall through to an empty chart.
            pass
        else:
            # Column sums on the sparse matrix; avoids densifying docs x vocab.
            count_values = np.asarray(trigrams.sum(axis=0)).ravel()
            freq_pairs = sorted(((count_values[i], k) for k, i in cv.vocabulary_.items()), reverse=True)

    if freq_pairs:
        ngram_freq = pd.DataFrame(freq_pairs[:10], columns=['frequency', 'ngram'])
        fig = px.bar(ngram_freq, x='frequency', y='ngram', orientation='h', title=title)
    else:
        fig = go.Figure()
        fig.update_layout(title=title)

    # Update layout settings
    fig.update_layout(margin=dict(l=20, r=20, t=50, b=20), xaxis_title="Frequency", yaxis_title="Trigram", font=dict(size=10))
    return fig

# Function for Prevalence of Discriminatory Content Chart
def create_prevalence_discriminatory_content_chart(df):
    """Grouped bar chart of discriminative vs. non-discriminative rows per Domain.

    Only the discrimination labels actually present in the (filtered) data
    are plotted. If none are present, an empty chart with the same title is
    returned.

    Returns:
        plotly Figure.
    """
    domain_counts = df.groupby(['Domain', 'Discrimination']).size().unstack(fill_value=0)

    # The non-discriminative label appears in this file both as
    # 'Non-Discriminative' and 'Non Discriminative'. A label can also vanish
    # after sidebar filtering, so plot only the columns that exist; a
    # hard-coded list raises when one is missing.
    label_candidates = ('Discriminative', 'Non-Discriminative', 'Non Discriminative')
    present_labels = [label for label in label_candidates if label in domain_counts.columns]

    title = 'Prevalence of Discriminatory Content'
    if present_labels:
        fig = px.bar(domain_counts, x=domain_counts.index, y=present_labels, barmode='group', title=title)
    else:
        fig = go.Figure()
        fig.update_layout(title=title)
    fig.update_layout(margin=dict(l=20, r=20, t=50, b=20), xaxis_title="Domain", yaxis_title="Count", font=dict(size=10))
    return fig

# Function for Top Domains with Discriminatory Content Chart
def create_top_discriminatory_domains_chart(df):
    """Horizontal bar chart of the three domains with the most 'Discriminative' rows.

    Returns:
        plotly Figure with at most three bars, one per domain.
    """
    domain_counts = df.groupby(['Domain', 'Discrimination']).size().unstack(fill_value=0)

    # If the sidebar filter removed every 'Discriminative' row, the column
    # is absent; fall back to zero counts instead of raising KeyError.
    if 'Discriminative' in domain_counts.columns:
        discriminative_counts = domain_counts['Discriminative']
    else:
        discriminative_counts = pd.Series(0, index=domain_counts.index)

    top_domains = (
        discriminative_counts.sort_values(ascending=False)
        .head(3)
        .rename('Count')
        .rename_axis('Domain')
        .reset_index()
    )

    fig = px.bar(top_domains, x='Count', y='Domain', orientation='h',
                 title='Top Domains with Discriminatory Content')
    fig.update_layout(margin=dict(l=20, r=20, t=50, b=20), xaxis_title="Discriminatory Content Count", yaxis_title="Domain", font=dict(size=10))
    return fig

# Function for Sentiment Distribution by Channel Chart
def create_sentiment_distribution_by_channel_chart(df):
    """Grouped bar chart of row counts per Channel, broken down by Sentiment.

    Positive, Neutral and Negative are colored blue, light blue and red, and
    the title is centered.

    Returns:
        plotly Figure.
    """
    per_channel = (
        df.groupby(['Channel', 'Sentiment'])
        .size()
        .reset_index(name='counts')
    )
    palette = {'Positive': 'blue', 'Neutral': 'lightblue', 'Negative': 'red'}
    chart = px.bar(
        per_channel,
        x='Channel',
        y='counts',
        color='Sentiment',
        title="Sentiment Distribution by Channel",
        barmode='group',
        color_discrete_map=palette,
    )
    chart.update_layout(
        margin={'l': 20, 'r': 20, 't': 50, 'b': 20},
        xaxis_title="Channel",
        yaxis_title="Counts",
        font={'size': 10},
        title_x=0.5,
    )
    return chart

# Function for Channel-wise Distribution of Discriminative Content Chart
def create_channel_discrimination_chart(df):
    """Grouped bar chart of discriminative vs. non-discriminative rows per Channel.

    Only the discrimination labels actually present in the (filtered) data
    are plotted. If none are present, an empty chart with the same title is
    returned.

    Returns:
        plotly Figure with a centered title.
    """
    channel_discrimination = df.groupby(['Channel', 'Discrimination']).size().unstack(fill_value=0)

    # The non-discriminative label appears in this file both as
    # 'Non-Discriminative' and 'Non Discriminative'. A label can also vanish
    # after sidebar filtering, so plot only the columns that exist.
    label_candidates = ('Discriminative', 'Non-Discriminative', 'Non Discriminative')
    present_labels = [label for label in label_candidates if label in channel_discrimination.columns]

    if present_labels:
        fig = px.bar(channel_discrimination, x=channel_discrimination.index, y=present_labels, barmode='group')
    else:
        fig = go.Figure()
    fig.update_layout(title='Channel-wise Distribution of Discriminative Content', margin=dict(l=20, r=20, t=50, b=20), font=dict(size=10), title_x=0.5)
    return fig

# Function for rendering dashboard
def render_dashboard(page, df_filtered):
    """Render the dashboard page selected in the sidebar.

    Args:
        page: page name chosen in the sidebar selectbox.
        df_filtered: sidebar-filtered data passed to every chart builder.

    Unrecognised page names render nothing.
    """
    # The sidebar label ("Conversation") and this check ("Conversations")
    # have disagreed, which left the default first page blank. Accept both
    # spellings.
    if page in ("Dashboard for GESI Conversations in Sri Lanka",
                "Dashboard for GESI Conversation in Sri Lanka"):
        render_prediction_page()
    elif page == "GESI Overview":
        st.title("GESI Overview Dashboard")
        col1, col2 = st.columns(2)
        with col1:
            st.plotly_chart(create_pie_chart(df_filtered, 'Domain', 'Distribution of Domains'))
        with col2:
            st.plotly_chart(create_gender_ethnicity_distribution_chart(df_filtered))

        col3, col4 = st.columns(2)
        with col3:
            st.plotly_chart(create_sentiment_distribution_chart(df_filtered))
        with col4:
            # This builder returns a message string instead of a figure when
            # there is nothing to plot.
            chart = create_sentiment_discrimination_grouped_chart(df_filtered)
            if isinstance(chart, str):
                st.write(chart)
            else:
                st.plotly_chart(chart)

    elif page == "Sentiment Analysis":
        st.title("Sentiment Analysis Dashboard")
        col1, col2 = st.columns(2)
        with col1:
            st.plotly_chart(create_sentiment_distribution_chart(df_filtered))
        with col2:
            st.plotly_chart(create_top_negative_sentiment_domains_chart(df_filtered))

        col3, col4 = st.columns(2)
        with col3:
            st.plotly_chart(create_key_phrases_negative_sentiment_chart(df_filtered))
        with col4:
            st.plotly_chart(create_key_phrases_positive_sentiment_chart(df_filtered))

    elif page == "Discrimination Analysis":
        st.title("Discrimination Analysis Dashboard")
        col1, col2 = st.columns(2)
        with col1:
            st.plotly_chart(create_prevalence_discriminatory_content_chart(df_filtered))
        with col2:
            st.plotly_chart(create_top_discriminatory_domains_chart(df_filtered))

    elif page == "Channel Analysis":
        st.title("Channel Analysis Dashboard")
        col1, col2 = st.columns(2)
        with col1:
            st.plotly_chart(create_sentiment_distribution_by_channel_chart(df_filtered))
        with col2:
            st.plotly_chart(create_channel_discrimination_chart(df_filtered))

# Draw whichever page the sidebar selected, using the sidebar-filtered data.
render_dashboard(page=page, df_filtered=df_filtered)