import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.stats import gaussian_kde

import gradio as gr
from pathlib import Path
import gradio as gr
import plotly.graph_objects as go


import re
import ast

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


def convert_google_sheet_url(url):
    """Rewrite a Google Sheets browser URL into its CSV-export URL.

    ``https://docs.google.com/spreadsheets/d/<id>/edit...`` becomes
    ``https://docs.google.com/spreadsheets/d/<id>/export?gid=<gid>&format=csv``.

    The sheet tab id (``gid``) is taken from the query string
    (``/edit?gid=123``) or the fragment (``/edit#gid=123``), whichever appears
    first after ``/edit``. When no gid is present it is omitted from the export
    URL, so the export falls back to the spreadsheet's default sheet.
    Text that does not match the Google Sheets URL pattern is returned unchanged.

    Args:
        url: A Google Sheets URL (or any string; non-matching text is kept).

    Returns:
        The URL with the Google Sheets part replaced by its CSV-export form.
    """
    # Sheet id, then an optional "/edit..." tail that may carry the gid.
    pattern = r'https://docs\.google\.com/spreadsheets/d/([a-zA-Z0-9_-]+)(/edit.*)?'
    # The gid can live in the query (?gid= / &gid=) or the fragment (#gid=);
    # modern share links use "edit?gid=N#gid=N", which the old pattern missed.
    gid_pattern = re.compile(r'[?#&]gid=(\d+)')

    def replacement(m):
        tail = m.group(2) or ''
        gid_match = gid_pattern.search(tail)
        gid_param = f'gid={gid_match.group(1)}&' if gid_match else ''
        return f'https://docs.google.com/spreadsheets/d/{m.group(1)}/export?{gid_param}format=csv'

    return re.sub(pattern, replacement, url)

# Source spreadsheet. Swap in your own Google Sheets URL here; the
# commented-out line below is an alternative sheet kept for reference.
# url = "https://docs.google.com/spreadsheets/d/1dlTjKJrGVwRDU8m-hT53IdSluRAsWXftnx5uRqnq4yE/edit?gid=0#gid=0"
url = "https://docs.google.com/spreadsheets/d/1MY0-DOitMZGnib73BAaSKg0TI7i5V1CXP8dF6jAgKWc/edit?gid=293606167#gid=293606167"

# Rewrite the browser URL into a CSV-export URL that pandas can download.
new_url = convert_google_sheet_url(url)


# NOTE(review): this downloads the sheet at import time (network I/O on import).
df = pd.read_csv(new_url)

# df1 keeps the sheet's original orientation, indexed by its 'Categories'
# column (one row per category). It is copied before df is reassigned below.
df1 = df.copy()
df1.set_index('Categories', inplace=True)

# df becomes the transposed view: one row per non-'Categories' sheet column
# (these index labels are used as model names further down) and one column
# per category. The transposed first row is promoted to the header, which
# presumes 'Categories' is the sheet's first column -- TODO confirm. That
# header row, labelled "Categories", is then dropped.
transposed_df = df.transpose()
transposed_df.columns = transposed_df.iloc[0]
df = transposed_df.drop(["Categories"])


# Blank cells (NaN) become the literal "[]" so they parse to empty lists below.
df = df.fillna("[]")
df1 = df1.fillna("[]")


# Parse every string cell as a Python literal (presumably list literals such
# as "[...]"); non-string cells pass through unchanged. ast.literal_eval only
# accepts literals, so it cannot execute code, but a malformed cell raises
# ValueError/SyntaxError here. Nothing is skipped: 'Categories' was dropped
# as a row above, so every column of df is a data column.
for col in df.columns:
    df[col] = df[col].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)


# Same parsing for df1; 'Categories' is its index, so it is not among the
# columns iterated here.
for col in df1.columns:
    df1[col] = df1[col].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)


cols = df.columns

# Cells of the first category column, one per row (model) of df.
column_data = df[cols[0]]

# Keep only rows whose parsed cell is not an empty list.
filtered_column_data = column_data[column_data.apply(lambda x: x != [])]



def get_score(avg_kl_div, kl_div, missing, extra, common, *, wc=1, wm=1.5, we=1.5):
    """Combined penalty for an entity's missing and extra values relative to its common value.

    score = (kl_div / avg_kl_div) * (((we*extra)**2 + (wm*missing)**2) / (wc*common)**2 - 2)

    A ``kl_div`` of -1 means "unknown" and is replaced by ``avg_kl_div``
    (factor 1). The weights are keyword-only, so existing positional callers
    are unaffected.

    Args:
        avg_kl_div: Reference (average) KL divergence used to normalise ``kl_div``.
        kl_div: This entity's KL divergence, or -1 if unavailable.
        missing: Missing value for the entity.
        extra: Extra value for the entity.
        common: Common value for the entity.
        wc, wm, we: Weights for the common, missing and extra terms.

    Returns:
        The scaled score.

    Raises:
        ZeroDivisionError: If ``common`` or ``avg_kl_div`` is 0 (for Python numbers).
    """
    if kl_div == -1:
        kl_div = avg_kl_div
    kl_div_factor = kl_div / avg_kl_div
    weighted_extra = (we * extra) ** 2
    weighted_missing = (wm * missing) ** 2
    # The common term uses its own weight; the original code squared
    # ``We * common`` and left ``Wc`` defined but unused.
    weighted_common = (wc * common) ** 2
    # (e^2 - c^2)/c^2 + (m^2 - c^2)/c^2 == (e^2 + m^2)/c^2 - 2
    return kl_div_factor * ((weighted_extra + weighted_missing) / weighted_common - 2)


def get_individual_score(avg_kl_div, kl_div, e_or_m, common):
    """Score one entity's extra-or-missing value relative to its common value.

    Returns ``avg_kl_div + sqrt((1 + e_or_m / common) * e_or_m * (e_or_m + 1) / 2)``.

    ``kl_div`` is accepted for interface compatibility but does not affect the
    result: the KL-divergence scaling (``kl_div / avg_kl_div``) was computed and
    then discarded. That dead computation also raised ZeroDivisionError whenever
    ``avg_kl_div`` was 0, which is the value ``get_entity_scores`` uses when no
    record has a KL divergence, so it has been removed.

    Args:
        avg_kl_div: Offset added to every score (the average KL divergence).
        kl_div: Unused; kept so existing callers keep working.
        e_or_m: The extra or missing value being scored.
        common: The common value used to normalise ``e_or_m``.

    Raises:
        ZeroDivisionError: If ``common`` is 0 (for Python numbers).
    """
    # X + sqrt[(1 + b/a) * n*(n+1)/2]; evaluation order matches the original
    # expression so floating-point results are unchanged.
    return avg_kl_div + ((1 + (e_or_m / common)) * (e_or_m * (e_or_m + 1)) / 2) ** 0.5


def get_entity_scores(ans4):
    """Score every entity record and return the sorted, trimmed score lists.

    Each record is read positionally:
    - ``[0]`` is a KL divergence (``-1`` marks it as unavailable);
    - ``[1]`` is the missing value;
    - ``[2]`` is the extra value;
    - ``[3]`` is the common value passed to ``get_individual_score``.

    All scores share the average KL divergence over the available records,
    which is 0 when none are available.

    Returns:
        ``(missing_scores, extra_scores)``: each list sorted ascending, keeping
        only the first ``int(0.95 * n)`` values (the largest ~5% are dropped).
    """
    # Mean KL divergence over the records that actually carry one.
    kl_total = 0
    kl_count = 0
    for record in ans4:
        if record[0] == -1:
            continue
        kl_total += record[0]
        kl_count += 1
    avg_kl_div = kl_total / kl_count if kl_count > 0 else 0

    extra_scores = []
    missing_scores = []
    for record in ans4:
        kl_div, common = record[0], record[3]
        extra_scores.append(get_individual_score(avg_kl_div, kl_div, record[2], common))
        missing_scores.append(get_individual_score(avg_kl_div, kl_div, record[1], common))

    extra_scores.sort()
    missing_scores.sort()

    def keep_lower_95(scores):
        # Trim the upper tail: keep the lowest 95% (truncating the count).
        return scores[:int(0.95 * len(scores))]

    return keep_lower_95(missing_scores), keep_lower_95(extra_scores)


# The first category column of df is the one compared across models.
compare = df.columns[0]
column_data = df[compare]

# Drop models whose cell for this category parsed to an empty list.
filtered_column_data = column_data[column_data.apply(lambda cell: cell != [])]

# Parallel lists: the remaining model labels and each one's parsed cell value.
models = filtered_column_data.index.to_list()
variables = filtered_column_data.to_list()

# Fixed palette; zip() below truncates, so models past the 12th get no
# entry in color_dict.
color_schemes = [
    '#d60000',  # Red
    '#2f5282',  # Navy Blue
    '#f15cd8',  # Pink
    '#66abb7',  # Light Teal
    '#ce7391',  # Rose
    '#6bdb7a',  # Light Green
    '#ea8569',  # Coral
    '#b36cc9',  # Lavender
    '#ffd700',  # Gold
    '#ff7f0e',  # Orange
    '#1f77b4',  # Blue
    '#2ca02c',  # Green
]

colors = color_schemes[:len(models)]

values_dict = dict(zip(models, variables))
color_dict = dict(zip(models, colors))


# plot_grouped_3d_kde(values_dict, models, color_dict, compare)


import numpy as np
import plotly.graph_objects as go
from scipy.stats import gaussian_kde
import plotly.express as px



def adjust_kde_range(data, increment=25, threshold=0.00005):
    """Evaluate a Gaussian KDE of *data* on a grid wide enough to show both tails.

    The grid starts at ``[min(data) - increment, max(data) + increment]``.
    It is widened by ``increment`` on both sides until the estimated density
    at each end is below ``threshold``. The grid always has 1000 points.

    Args:
        data: 1-D sequence of finite sample values.
        increment: Positive amount the range is extended by on each side per step.
        threshold: Positive density level both tails must fall below.

    Returns:
        ``(x_values, y_values)``: the 1000-point grid and the KDE density on it.

    Raises:
        ValueError: If ``increment`` or ``threshold`` is not positive (the
            widening loop could never finish), or if ``data`` contains
            NaN/inf. ``gaussian_kde`` itself raises for unusable data
            (e.g. fewer than two samples).
    """
    # Without these checks the loop below can spin forever: a density is
    # never negative, and NaN comparisons are always False.
    if not increment > 0:
        raise ValueError(f"increment must be positive, got {increment!r}")
    if not threshold > 0:
        raise ValueError(f"threshold must be positive, got {threshold!r}")
    if not np.all(np.isfinite(np.asarray(data, dtype=float))):
        raise ValueError("data must contain only finite values")

    kde = gaussian_kde(data)
    min_x, max_x = min(data) - increment, max(data) + increment

    # Keep expanding the range until both tails get close to zero.
    while True:
        x_values = np.linspace(min_x, max_x, 1000)
        y_values = kde(x_values)

        if y_values[0] < threshold and y_values[-1] < threshold:
            break

        min_x -= increment
        max_x += increment

    return x_values, y_values


def compute_kde_ranges(missing_scores, extra_scores):
    """Compute KDE curves for missing and (negated) extra scores plus plot ranges.

    Args:
        missing_scores: Sequence of missing-entity scores.
        extra_scores: Sequence of extra-entity scores; they are negated so
            they lie mirrored about zero on the score axis.

    Returns:
        ``(x_missing, y_missing, x_extra, y_extra, x_range, y_range)`` where:
        - ``x_missing``/``y_missing`` are the KDE grid and density of the missing scores;
        - ``x_extra``/``y_extra`` are the same for the negated extra scores;
        - ``x_range`` is ``[min_x, max_x]`` spanning both grids;
        - ``y_range`` is ``[-peak_extra, 1.25 * peak_missing]``. It is sized
          for the extra density drawn below zero and gives 25% headroom above
          the missing peak.
    """
    data_missing = np.array(missing_scores)
    data_extra = -np.array(extra_scores)  # Negate extra scores for alignment

    x_missing, y_missing = adjust_kde_range(data_missing)
    x_extra, y_extra = adjust_kde_range(data_extra)

    x_range = [
        min(np.min(x_missing), np.min(x_extra)),
        max(np.max(x_missing), np.max(x_extra)),
    ]
    y_range = [-np.max(y_extra), np.max(y_missing) * 1.25]

    return x_missing, y_missing, x_extra, y_extra, x_range, y_range


def calculate_ticks(x_min, x_max, num_ticks=20):
    """Return ``num_ticks`` evenly spaced tick positions from x_min to x_max inclusive.

    ``np.linspace`` is used instead of ``np.arange`` with a float step. The
    latter can emit an extra tick past ``x_max`` through rounding, and it
    fails when the step is zero. A degenerate range (``x_min == x_max``)
    yields that single value.

    Args:
        x_min: First tick value.
        x_max: Last tick value.
        num_ticks: Number of ticks to generate.

    Returns:
        A 1-D numpy array of tick positions.
    """
    if x_max == x_min:
        return np.array([x_min], dtype=float)
    return np.linspace(x_min, x_max, num_ticks)




def plot_filled_surface(x, z, y_level, color):
    """
    Create a 3D mesh to fill the surface between the KDE curve and the 0-axis.
    """
    x_full = np.concatenate([x, x[::-1]])  # X-axis values, with reverse for baseline
    z_full = np.concatenate([z, np.zeros_like(z)])  # Z-axis (KDE and baseline at 0)
    y_full = np.full_like(x_full, y_level)  # Flat Y plane (constant for each model)

    num_pts = len(x)
    i = np.arange(num_pts - 1)
    j = i + 1
    k = i + num_pts

    i = np.concatenate([i, i + num_pts])
    j = np.concatenate([j, j + num_pts])
    k = np.concatenate([k, i[:len(i)//2]])

    return go.Mesh3d(
        x=x_full, y=y_full, z=z_full,
        i=i, j=j, k=k,
        opacity=0.5,
        color=color,
        showscale=False,
        legendgroup='filling'
    )



def plot_kde_3d(values_dict, models, color_dict, compare):
    """Build a 3D figure of per-model KDE curves for missing and extra scores.

    For every model:
    - the missing-score density is drawn at z >= 0;
    - the extra-score density is drawn mirrored below zero (z <= 0);
    - the model occupies its own slice along the y axis, in the order of ``models``.

    The figure is also saved as ``<compare>.html`` in the working directory.
    Any ``/`` or ``\\`` in the name is replaced by ``_``.

    Args:
        values_dict: Mapping from model to the records ``get_entity_scores`` consumes.
        models: Ordered model labels; the order fixes each model's y position.
        color_dict: Mapping from model to a plotly color; missing entries fall
            back to translucent black.
        compare: Category name, used in the title and the HTML filename.

    Returns:
        The ``plotly.graph_objects.Figure``.

    Raises:
        ValueError: If ``models`` is empty.
    """
    if len(models) == 0:
        raise ValueError("plot_kde_3d needs at least one model to plot")

    default_color = 'rgba(0, 0, 0, 0.5)'
    fig = go.Figure()

    model_y_positions = {model: i for i, model in enumerate(models)}

    x_ranges = []
    y_ranges = []

    for model in models:
        missing_scores, extra_scores = get_entity_scores(values_dict[model])

        # KDE curves for this model plus the axis ranges it needs.
        x_m, y_m, x_e, y_e, x_range, y_range = compute_kde_ranges(missing_scores, extra_scores)
        x_ranges.append(x_range)
        y_ranges.append(y_range)

        y_pos = model_y_positions[model]
        color = color_dict.get(model, default_color)

        # Shaded areas between each curve and the zero plane.
        fig.add_trace(plot_filled_surface(x_m, y_m, y_pos, color))
        fig.add_trace(plot_filled_surface(x_e, -y_e, y_pos, color))

        # Outline curves: blue for missing, red for (mirrored) extra.
        fig.add_trace(go.Scatter3d(
            x=x_m,
            y=[y_pos] * len(x_m),
            z=y_m,
            mode='lines',
            line=dict(color='blue'),
            showlegend=False,
        ))
        fig.add_trace(go.Scatter3d(
            x=x_e,
            y=[y_pos] * len(x_e),
            z=-y_e,
            mode='lines',
            line=dict(color='red'),
            showlegend=False,
        ))

    # Global limits across all models.
    x_min = min(r[0] for r in x_ranges)
    x_max = max(r[1] for r in x_ranges)
    y_min = min(r[0] for r in y_ranges)
    y_max = max(r[1] for r in y_ranges)

    x_ticks = calculate_ticks(np.floor(x_min), np.ceil(x_max))
    y_ticks = list(model_y_positions.values())
    z_ticks = calculate_ticks(y_min, y_max)

    # One zero-density baseline per model, in the model's color; these are
    # the traces named after the models, so they label the legend.
    for model in models:
        color = color_dict.get(model, default_color)
        y_pos = model_y_positions[model]
        fig.add_trace(go.Scatter3d(
            x=[x_min, x_max],
            y=[y_pos, y_pos],
            z=[0, 0],
            mode='lines',
            line=dict(color=color),
            name=model,
        ))

    fig.update_layout(
        title=f'3D KDE Plots for {compare}',
        scene=dict(
            xaxis_title='Score',
            yaxis_title='Model',
            zaxis_title='Density',
            xaxis=dict(
                range=[x_min, x_max],
                tickvals=x_ticks,
                ticktext=[f'{tick:.2f}' for tick in x_ticks]
            ),
            yaxis=dict(
                tickvals=y_ticks,
                # Positions are enumerate() indices, so labels are simply the
                # keys in insertion order (was an O(n^2) reverse lookup).
                ticktext=list(model_y_positions.keys())
            ),
            zaxis=dict(
                range=[y_min, y_max],
                tickvals=z_ticks,
                ticktext=[f'{tick:.4f}' for tick in z_ticks]
            ),
            camera=dict(
                eye=dict(x=1.25, y=1.25, z=1.25)
            )
        ),
        autosize=True,
        width=1200 * 0.75,
        height=800 * 0.75
    )

    # Save a standalone HTML copy. Path separators in the category name would
    # otherwise point into (possibly missing) subdirectories.
    safe_name = re.sub(r'[\\/]', '_', str(compare))
    fig.write_html(f"{safe_name}.html")

    return fig



# Standalone module constants. NOTE(review): nothing visible in this file
# reads them (plot_kde_3d saves to "<compare>.html" instead); confirm before
# removing.
html_file_path, title = '3d_plot.html', 'My 3D Plot'

def display_plot():
    """Gradio callback: build and return the 3D KDE figure.

    Takes no inputs. It plots the module-level ``values_dict``, ``models``,
    ``color_dict`` and ``compare`` prepared when the module was loaded.
    """
    return plot_kde_3d(values_dict, models, color_dict, compare)


# Gradio front end: no input components, a single Plotly figure as output,
# produced by display_plot (non-live).
interface = gr.Interface(
    fn=display_plot,
    inputs=[],
    outputs=gr.Plot(),
    live=False,
    title='Plotly 3D Plot in Gradio',
    description='This app displays a 3D Plotly plot directly in the Gradio interface.',
)

if __name__ == "__main__":
    # Serve the app only when run as a script, not when imported.
    interface.launch()