import functools
import importlib
import json

import gradio as gr
import requests
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForQuestionAnswering
from datasets import load_dataset
import datasets
import plotly.io as pio
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import pandas as pd
from sklearn.metrics import confusion_matrix
import torch
from dash import Dash, html, dcc
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score


def load_model(model_type: str, model_name_or_path: str, dataset_name: str, config_name: str):
    """Load a tokenizer and a task-specific model for ``model_type``.

    For the two classification types the dataset's ``train`` split is loaded to
    size the model's output layer (``num_labels``) from its label feature.

    Args:
        model_type: "text_classification", "token_classification" or
            "question_answering".
        model_name_or_path: checkpoint id or local path given to ``from_pretrained``.
        dataset_name: dataset passed to ``load_dataset`` (unused for question_answering).
        config_name: dataset config passed to ``load_dataset`` (unused for question_answering).

    Returns:
        (tokenizer, model)

    Raises:
        ValueError: if ``model_type`` is not one of the three values above.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    if model_type == "text_classification":
        train_features = load_dataset(dataset_name, config_name)["train"].features
        num_labels = len(train_features["label"].names)
        # The Auto class resolves the architecture from the checkpoint's config,
        # so RoBERTa checkpoints need no name-based special case (a substring
        # test on the name would also misfire on non-RoBERTa checkpoints whose
        # name happens to contain "roberta").
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name_or_path, num_labels=num_labels)
    elif model_type == "token_classification":
        train_features = load_dataset(dataset_name, config_name)["train"].features
        # "ner_tags" is a sequence feature; its inner feature carries the tag names.
        num_labels = len(train_features["ner_tags"].feature.names)
        model = AutoModelForTokenClassification.from_pretrained(
            model_name_or_path, num_labels=num_labels)
    elif model_type == "question_answering":
        model = AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)
    else:
        raise ValueError(f"Invalid model type: {model_type}")

    return tokenizer, model


def test_model(tokenizer, model, test_data: list, label_map: dict):
    """Run the classifier over ``test_data`` and collect its predictions.

    Args:
        tokenizer: callable turning one text into keyword inputs for ``model``.
        model: callable whose output exposes ``.logits``.
        test_data: iterable of (text, context, true_label) tuples; context is ignored.
        label_map: maps a predicted class index to its label.

    Returns:
        List of (text, true_label, predicted_label) tuples in input order.
    """
    results = []
    # Pure inference: disabling autograd avoids building (and holding on to) a
    # computation graph for every example.
    with torch.no_grad():
        for text, _context, true_label in test_data:
            inputs = tokenizer(text, return_tensors="pt",
                               truncation=True, padding=True)
            logits = model(**inputs).logits
            # int() needs a single-element tensor, i.e. exactly one class id per text.
            pred_label = label_map[int(logits.argmax(dim=-1))]
            results.append((text, true_label, pred_label))
    return results


def generate_label_map(dataset):
    """Build a class-index -> label mapping from the dataset's "label" column.

    Returns ``{}`` when the dataset has no usable "label" feature. For a
    ``datasets.ClassLabel`` feature, index ``i`` maps to ``names[i]``.
    Otherwise the distinct values in the column are enumerated in sorted
    order, falling back to first-seen order if they cannot be compared. Either
    way the mapping is the same on every run.
    """
    features = dataset.features
    if "label" not in features or features["label"] is None:
        return {}

    label_feature = features["label"]
    if isinstance(label_feature, datasets.ClassLabel):
        return dict(enumerate(label_feature.names))

    # Read the column once; dict.fromkeys de-duplicates while keeping first-seen order.
    distinct_labels = list(dict.fromkeys(dataset["label"]))
    try:
        # Iterating a plain set of strings gives an order that changes with hash
        # randomisation between interpreter runs; sorting keeps index -> label stable.
        distinct_labels = sorted(distinct_labels)
    except TypeError:
        pass  # mixed, non-comparable label types: keep first-seen order
    return dict(enumerate(distinct_labels))

def calculate_fairness_score(results, label_map):
    """Return (overall accuracy, fairness gap) for a list of predictions.

    The gap is the mean absolute pairwise difference between per-class recalls.
    Per-class recall is the share of a true class that is predicted correctly.
    The true classes stand in for groups, since no protected attribute is
    available. This is in the spirit of the equal-opportunity criterion
    surveyed in https://arxiv.org/pdf/1908.09635.pdf.

    The gap lies in [0, 1], and 0 means every class is served equally well.
    Classes from ``label_map`` that do not occur among the true labels are
    skipped. With fewer than two such classes the gap is 0.0.

    Args:
        results: list of (text, true_label, predicted_label) tuples.
        label_map: class index -> label; its values are the classes compared.

    Returns:
        (accuracy, score) as floats; (0.0, 0.0) when ``results`` is empty.
    """
    from itertools import combinations

    if not results:
        return 0.0, 0.0

    true_labels = [r[1] for r in results]
    correct = [r[1] == r[2] for r in results]
    accuracy = sum(correct) / len(correct)

    class_recalls = []
    for group in label_map.values():
        hits = [ok for ok, label in zip(correct, true_labels) if label == group]
        if hits:  # absent classes have no recall; skipping avoids dividing by zero
            class_recalls.append(sum(hits) / len(hits))

    gaps = [abs(a - b) for a, b in combinations(class_recalls, 2)]
    score = sum(gaps) / len(gaps) if gaps else 0.0
    return accuracy, score

def calculate_per_class_metrics(true_labels, pred_labels, label_map, metric='accuracy'):
    """Compute one metric value per class, in ``label_map.values()`` order.

    The output order matches ``label_map.values()`` so callers can zip it with
    the labels directly.

    Args:
        true_labels: sequence of true labels.
        pred_labels: sequence of predicted labels, aligned with ``true_labels``.
        label_map: class index -> label; its values define the classes and order.
        metric: 'accuracy' for per-class accuracy (the share of a class's true
            samples predicted as that class, i.e. its recall), or 'f1'.

    Returns:
        List of floats, one per class. Per-class accuracy is NaN for a class
        with no true samples.

    Raises:
        ValueError: for any other ``metric``.
    """
    labels = list(label_map.values())

    if metric == 'accuracy':
        metrics = []
        for label in labels:
            hits = [pred == label
                    for true, pred in zip(true_labels, pred_labels) if true == label]
            metrics.append(sum(hits) / len(hits) if hits else float('nan'))
        return metrics
    if metric == 'f1':
        return f1_score(true_labels, pred_labels, labels=labels, average=None).tolist()
    raise ValueError(f"Invalid metric: {metric}")

def generate_fairness_statement(accuracy, fairness_score):
    """Return a short human-readable assessment of a fairness score.

    Args:
        accuracy: overall accuracy as a fraction; shown as a percentage.
        fairness_score: bias score where lower means fairer. Scores <= 0.15 are
            described as low, <= 0.3 as moderate, and anything higher as high.

    Returns:
        A sentence prefixed with "Assessment: ".
    """
    if fairness_score <= 0.15:
        verdict = f"The low fairness score ({fairness_score:.2f}) and accuracy ({accuracy * 100:.2f}%) indicate that the model is relatively fair and does not exhibit significant bias across different groups."
    elif fairness_score <= 0.3:
        verdict = f"The moderate fairness score ({fairness_score:.2f}) and accuracy ({accuracy * 100:.2f}%) suggest that the model may have some bias across different groups, and further investigation is needed to ensure it does not disproportionately affect certain groups."
    else:
        verdict = f"The high fairness score ({fairness_score:.2f}) and accuracy ({accuracy * 100:.2f}%) indicate that the model exhibits significant bias across different groups, and it's recommended to address this issue to ensure fair predictions for all groups."

    return "Assessment: " + verdict

def _per_class_bar_figure(labels, values, title, yaxis_title, background_color, text_color):
    """Build a bar chart with one coloured bar (and legend entry) per class."""
    colors = px.colors.qualitative.Plotly
    fig = go.Figure()
    for i, (label, value) in enumerate(zip(labels, values)):
        fig.add_trace(go.Bar(
            x=[label],
            y=[value],
            name=label,
            marker_color=colors[i % len(colors)],
        ))

    # Axis lines use the text colour so they stay visible on the dark background.
    axis_style = dict(showgrid=True, gridwidth=1, gridcolor='LightGray',
                      linecolor=text_color, linewidth=1)
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)
    fig.update_layout(plot_bgcolor=background_color,
                      paper_bgcolor=background_color,
                      font=dict(color=text_color),
                      title=title,
                      xaxis_title='Class', yaxis_title=yaxis_title)
    return fig


def generate_visualization(visualization_type, results, label_map, chart_mode):
    """Return the plotly figure selected by ``visualization_type``.

    Args:
        visualization_type: "confusion_matrix", "per_class_accuracy",
            "per_class_f1" or "interactive_dashboard".
        results: list of (text, true_label, predicted_label) tuples.
        label_map: class index -> label; its values give the bar order.
        chart_mode: "Light" for a white background, anything else for black.

    Raises:
        ValueError: for an unknown ``visualization_type``.
    """
    if visualization_type == "confusion_matrix":
        return generate_report_card(results, label_map, chart_mode)["fig"]
    if visualization_type == "interactive_dashboard":
        return generate_interactive_dashboard(results, label_map, chart_mode)

    # visualization_type -> (metric name, chart title, y-axis title)
    per_class_charts = {
        "per_class_accuracy": ('accuracy', 'Per-Class Accuracy', 'Accuracy'),
        "per_class_f1": ('f1', 'Per-Class F1-Score', 'F1-Score'),
    }
    if visualization_type not in per_class_charts:
        raise ValueError(f"Invalid visualization type: {visualization_type}")
    metric, title, yaxis_title = per_class_charts[visualization_type]

    true_labels = [r[1] for r in results]
    pred_labels = [r[2] for r in results]
    values = calculate_per_class_metrics(
        true_labels, pred_labels, label_map, metric=metric)

    background_color = "white" if chart_mode == "Light" else "black"
    text_color = "black" if chart_mode == "Light" else "white"
    return _per_class_bar_figure(list(label_map.values()), values, title,
                                 yaxis_title, background_color, text_color)

def generate_interactive_dashboard(results, label_map, chart_mode):
    """Combine the confusion matrix and both per-class bar charts in one figure.

    Layout: the confusion matrix spans the top row, per-class accuracy sits
    bottom-left and per-class F1 bottom-right.

    Args:
        results: list of (text, true_label, predicted_label) tuples.
        label_map: class index -> label; its values give the bar order.
        chart_mode: "Light" for a white background, anything else for black.

    Returns:
        A plotly Figure.
    """
    true_labels = [r[1] for r in results]
    pred_labels = [r[2] for r in results]
    class_names = list(label_map.values())

    palette = ['#EF553B', '#00CC96', '#636EFA', '#AB63FA', '#FFA15A',
               '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52']
    # One colour per bar, cycling so datasets with more than ten classes still
    # get a colour for every bar.
    bar_colors = [palette[i % len(palette)] for i in range(len(class_names))]

    background_color = "white" if chart_mode == "Light" else "black"
    text_color = "black" if chart_mode == "Light" else "white"

    # Reuse the heatmap trace from the standalone confusion-matrix figure.
    cm_trace = generate_report_card(results, label_map, chart_mode)["fig"]["data"][0]
    accuracy_values = calculate_per_class_metrics(
        true_labels, pred_labels, label_map, metric='accuracy')
    f1_values = calculate_per_class_metrics(
        true_labels, pred_labels, label_map, metric='f1')

    # x-axes are not shared: shared_xaxes keeps tick labels only on the bottom
    # row of each column, which hid the confusion matrix's predicted-label
    # ticks. print_grid stays at its default (off) so no grid diagram is
    # printed to stdout on every request.
    fig = make_subplots(
        rows=2, cols=2,
        specs=[[{"colspan": 2}, None],
               [{}, {}]],
        subplot_titles=("Confusion Matrix", "Per-Class Accuracy", "Per-Class F1-Score"),
    )
    fig.add_trace(cm_trace, row=1, col=1)
    fig.add_trace(go.Bar(x=class_names, y=accuracy_values,
                         marker=dict(color=bar_colors)), row=2, col=1)
    fig.add_trace(go.Bar(x=class_names, y=f1_values,
                         marker=dict(color=bar_colors)), row=2, col=2)

    # Axis lines use the text colour so they stay visible on the dark background.
    axis_style = dict(showgrid=True, gridwidth=1, gridcolor='LightGray',
                      linecolor=text_color, linewidth=1)
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)
    fig.update_layout(height=700, width=650,
                      plot_bgcolor=background_color,
                      paper_bgcolor=background_color,
                      font=dict(color=text_color),
                      title="Fairness Report", showlegend=False)
    return fig

def generate_report_card(results, label_map, chart_mode):
    """Build the confusion-matrix figure and summary metrics for ``results``.

    Args:
        results: list of (text, true_label, predicted_label) tuples.
        label_map: class index -> label; its values fix the matrix row/column
            order and the axis tick labels.
        chart_mode: "Light" for a white background, anything else for black.

    Returns:
        Dict with:
          "fig": plotly Figure of the row-normalised confusion matrix (hover
              shows the raw counts);
          "accuracy": number of correct predictions (accuracy_score with
              normalize=False, so a count rather than a fraction);
          "fairness_score": the (accuracy, score) tuple returned by
              calculate_fairness_score;
          "per_class_accuracy" / "per_class_f1": per-class metric lists from
              calculate_per_class_metrics.
    """
    true_labels = [r[1] for r in results]
    pred_labels = [r[2] for r in results]
    class_names = list(label_map.values())

    background_color = "white" if chart_mode == "Light" else "black"
    text_color = "black" if chart_mode == "Light" else "white"

    # Pass labels= so rows/columns follow label_map, the same order used for the
    # tick labels below. Without it sklearn sorts the labels, which need not
    # match label_map's order, and drops classes absent from the sample.
    cm = confusion_matrix(true_labels, pred_labels, labels=class_names)

    # Row-normalise so each cell is the share of that true class. Rows with no
    # true samples stay 0 instead of becoming 0/0 = NaN.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(cm, row_totals,
                              out=np.zeros(cm.shape, dtype=float),
                              where=row_totals != 0)

    fig = go.Figure(go.Heatmap(
        z=cm_normalized,
        x=class_names,
        y=class_names,
        text=cm,
        hovertemplate='%{text}',
        colorscale=[[0, '#EF553B'], [1, '#00CC96']],
        showscale=False,
        zmin=0, zmax=1,
    ))

    # Axis lines use the text colour so they stay visible on the dark background.
    axis_style = dict(showgrid=True, gridwidth=1, gridcolor='LightGray',
                      linecolor=text_color, linewidth=1)
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)
    fig.update_layout(
        plot_bgcolor=background_color,
        paper_bgcolor=background_color,
        font=dict(color=text_color),
        height=500, width=600,
        title='Confusion Matrix',
        xaxis=dict(title='Predicted Labels'),
        yaxis=dict(title='True Labels'),
    )

    return {
        "fig": fig,
        "accuracy": accuracy_score(true_labels, pred_labels, normalize=False),
        "fairness_score": calculate_fairness_score(results, label_map),
        "per_class_accuracy": calculate_per_class_metrics(
            true_labels, pred_labels, label_map, metric='accuracy'),
        "per_class_f1": calculate_per_class_metrics(
            true_labels, pred_labels, label_map, metric='f1'),
    }
  

def generate_insights(custom_text, model_name, dataset_name, accuracy, fairness_score, report_card, generator):
    """Ask a text-generation pipeline to comment on the evaluation results.

    Args:
        custom_text: text placed at the start of the prompt.
        model_name: model identifier mentioned in the prompt.
        dataset_name: dataset name mentioned in the prompt.
        accuracy: overall accuracy as a fraction; shown as a percentage.
        fairness_score: number shown with two decimals.
        report_card: mapping that may hold 'per_class_accuracy' and 'per_class_f1'.
        generator: callable returning [{'generated_text': ...}, ...].

    Returns:
        The 'generated_text' of the first generation.
    """
    metrics_by_kind = {
        'accuracy': report_card.get('per_class_accuracy', []),
        'f1': report_card.get('per_class_f1', []),
    }
    have_metrics = metrics_by_kind['accuracy'] and metrics_by_kind['f1']

    if have_metrics:
        closing = (f"The per-class metrics are: {metrics_by_kind}. "
                   "Please provide some interesting insights about the fairness, bias, and per-class performance.")
    else:
        closing = ("Per-class metrics could not be calculated. "
                   "Please provide some interesting insights about the fairness and bias of the model.")

    prompt = (f"{custom_text} The model {model_name} has been evaluated on the {dataset_name} dataset. "
              f"It has an overall accuracy of {accuracy * 100:.2f}%. "
              f"The fairness score is {fairness_score:.2f}. "
              + closing)

    generations = generator(prompt, max_length=600, do_sample=True, temperature=0.7)
    first = generations[0]
    return first['generated_text']


@functools.lru_cache(maxsize=1)
def _get_text_generator():
    """Return the GPT-2 text-generation pipeline, building it only on first use."""
    # Use GPU 0 when CUDA is available, otherwise -1 (CPU).
    device = 0 if torch.cuda.is_available() else -1
    # Other checkpoints (e.g. distilgpt2, EleutherAI/gpt-neo-1.3B) can be
    # substituted here to trade speed for quality.
    return pipeline('text-generation', model='gpt2', device=device)


def _build_test_data(dataset):
    """Turn a labelled split into (text, context, true_label_name) tuples for test_model."""
    # GLUE configs such as sst2/cola use "sentence"; tweet_eval uses "text".
    text_column = "sentence" if "sentence" in dataset.column_names else "text"
    label_names = dataset.features["label"].names
    # Rows without a label (negative label ids, e.g. GLUE test splits) are
    # skipped: names[-1] would otherwise silently mislabel them as the last class.
    return [(item[text_column], None, label_names[item["label"]])
            for item in dataset if item["label"] >= 0]


def app(model_type: str, model_name_or_path: str, dataset_name: str, config_name: str, dataset_split: str, num_samples: int, visualization_type: str, chart_mode: str):
    """Gradio callback: evaluate a classifier on a dataset slice and report fairness.

    Returns:
        (report_text, figure): the generated insights plus accuracy, fairness
        score and per-class metrics as text, and the chosen plotly figure.

    Raises:
        ValueError: for model types other than text_classification, a sample
            count below 1, or a slice without labelled examples.
    """
    # test_model collapses each input's logits to a single class id, which only
    # fits sequence classification.
    if model_type != "text_classification":
        raise ValueError(
            f"Only text_classification is currently supported, got: {model_type}")

    # gr.Number delivers a float; the split slice below needs an int.
    num_samples = int(num_samples)
    if num_samples < 1:
        raise ValueError(f"Number of samples must be at least 1, got: {num_samples}")

    tokenizer, model = load_model(
        model_type, model_name_or_path, dataset_name, config_name)

    dataset = load_dataset(
        dataset_name, config_name, split=f"{dataset_split}[:{num_samples}]")
    test_data = _build_test_data(dataset)
    if not test_data:
        raise ValueError(
            f"The '{dataset_split}' split has no labelled examples to evaluate.")

    label_map = generate_label_map(dataset)
    results = test_model(tokenizer, model, test_data, label_map)

    report_card = generate_report_card(results, label_map, chart_mode)
    visualization = generate_visualization(
        visualization_type, results, label_map, chart_mode)

    per_class_metrics_str = "\n".join(
        f"{label}: Acc {acc:.2f}, F1 {f1:.2f}"
        for label, acc, f1 in zip(label_map.values(),
                                  report_card['per_class_accuracy'],
                                  report_card['per_class_f1']))

    # The report card already holds (accuracy, fairness score); reuse it
    # instead of computing it a second time.
    accuracy, fairness_score = report_card['fairness_score']
    fairness_statement = generate_fairness_statement(accuracy, fairness_score)

    insights = generate_insights(fairness_statement, model_name_or_path, dataset_name,
                                 accuracy, fairness_score, report_card,
                                 _get_text_generator())

    return (f"{insights}\n\n"
            f"Accuracy: {report_card['accuracy']}, Fairness Score: {fairness_score:.2f}\n\n"
            f"Per-Class Metrics:\n{per_class_metrics_str}"), visualization

# Gradio UI: the inputs map positionally onto app()'s parameters, and the
# outputs match app()'s (report_text, figure) return value.
interface = gr.Interface(
    fn=app,
    inputs=[
        gr.Radio(["text_classification", "token_classification", "question_answering"],
                 label="Model Type", value="text_classification"),
        gr.Textbox(lines=1, label="Model Name or Path",
                   placeholder="ex: distilbert-base-uncased-finetuned-sst-2-english",
                   value="distilbert-base-uncased-finetuned-sst-2-english"),
        gr.Textbox(lines=1, label="Dataset Name",
                   placeholder="ex: glue", value="glue"),
        # Default to sst2 so the default model (fine-tuned on SST-2) is
        # evaluated on the task and label set it was trained for.
        gr.Textbox(lines=1, label="Config Name",
                   placeholder="ex: sst2", value="sst2"),
        gr.Dropdown(choices=["train", "validation", "test"],
                    label="Dataset Split", value="validation"),
        gr.Number(value=100, label="Number of Samples"),
        gr.Dropdown(
            choices=["interactive_dashboard", "confusion_matrix",
                     "per_class_accuracy", "per_class_f1"],
            label="Visualization Type", value="interactive_dashboard"),
        gr.Radio(["Light", "Dark"], label="Chart Mode", value="Light"),
    ],
    outputs=[
        gr.Textbox(label="Fairness and Bias Metrics"),
        gr.Plot(label="Graph"),
    ],
    title="Fairness and Bias Testing",
    description="Enter a model and dataset to test for fairness and bias.",
)

# Module-level label map.
# NOTE(review): nothing visible in this file reads this global. app() builds
# its own label_map from the dataset via generate_label_map(), and the
# functions taking a `label_map` parameter shadow this name. Confirm no
# external importer relies on it before removing.
label_map = {0: "negative", 1: "positive"}

# Start the Gradio server only when run as a script, not on import.
if __name__ == "__main__":
    interface.launch()