import datetime
import io
import json
import os
import random
import re
import sqlite3
from email.utils import parseaddr

import evaluate as nlp_evaluate
import pandas as pd
from huggingface_hub import HfApi
from tqdm import tqdm


from get_exact_and_f1_score.ext_services.jsql_parser import JSQLParser
# Aliased import: this module defines its own `evaluate()` below, which would
# otherwise shadow this name and break get_exec_match_acc.
from get_exact_and_f1_score.metrics.partial_match_eval.evaluate import (
    evaluate as pcm_evaluate,
)

random.seed(10001)

bleu = nlp_evaluate.load("bleu")
rouge = nlp_evaluate.load('rouge')


LEADERBOARD_PATH = "Exploration-Lab/BookSQL-Leaderboard"
RESULTS_PATH = "Exploration-Lab/BookSQL-Leaderboard-results"
api = HfApi()
TOKEN = os.environ.get("TOKEN", None)
YEAR_VERSION = "2024"

sqlite_path = "accounting/accounting_for_testing.sqlite"


_jsql_parser = JSQLParser.create()

def format_error(msg):
    return f"<p style='color: red; font-size: 20px; text-align: center;'>{msg}</p>"


def format_warning(msg):
    return f"<p style='color: orange; font-size: 20px; text-align: center;'>{msg}</p>"


def format_log(msg):
    return f"<p style='color: green; font-size: 20px; text-align: center;'>{msg}</p>"


def model_hyperlink(link, model_name):
    return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'


def input_verification(method_name, url, path_to_file, organisation, mail):
    for field in [method_name, url, path_to_file, organisation, mail]:
        if field == "":
            return format_warning("Please fill in all the fields.")

    if path_to_file is None:
        return format_warning("Please attach a file.")

    # Very basic email sanity check
    _, parsed_mail = parseaddr(mail)
    if "@" not in parsed_mail:
        return format_warning("Please provide a valid email address.")

    return parsed_mail

def replace_current_date_and_now(_sql, _date):
    # Pin relative date expressions to a fixed date so results are reproducible.
    _sql = _sql.replace("current_date", f"'{_date}'")
    _sql = _sql.replace(", now", f", '{_date}'")
    return _sql
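# Illustrative example (hypothetical query, not from the test set):
#   replace_current_date_and_now("select * from txns where date = current_date", "2022-06-01")
#   -> "select * from txns where date = '2022-06-01'"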

def remove_gold_Non_exec(data, df1, sqlite_path):
    """Run every gold query and flag the ones that fail to execute."""
    con = sqlite3.connect(sqlite_path)
    cur = con.cursor()

    out, non_exec = [], []
    new_df = df1.copy()
    new_df.loc[:, 'Exec/Non-Exec'] = 0
    for i, s in tqdm(enumerate(data)):
        _sql = str(s).replace('"', "'").lower()
        _sql = replace_current_date_and_now(_sql, '2022-06-01')
        _sql = replace_percent_symbol_y(_sql)
        try:
            cur.execute(_sql)
            cur.fetchall()
            out.append(i)
        except Exception:
            non_exec.append(i)
            print("Non-executable gold SQL: ", _sql)

    new_df.loc[out, 'Exec/Non-Exec'] = 1
    con.close()
    return out, non_exec, new_df

def remove_data_from_index(data, ind_list):
    # Keep only the entries whose indices appear in ind_list.
    return [data[i] for i in ind_list]

def get_exec_match_acc(gold, pred):
    assert len(gold) == len(pred)
    # Normalise before parsing: collapse extra spaces, unify quotes, lower-case.
    goldd = [re.sub(' +', ' ', str(g).replace("'", '"').lower()) for g in gold]
    predd = [re.sub(' +', ' ', str(p).replace("'", '"').lower()) for p in pred]

    goldd = _jsql_parser.translate_batch(goldd)
    predd = _jsql_parser.translate_batch(predd)
    pcm_f1_scores = pcm_evaluate(goldd, predd)
    pcm_em_scores = pcm_evaluate(goldd, predd, exact_match=True)

    # Queries JSQLParser failed to translate come back as error strings;
    # keep only the pairs where both scores are floats.
    _pcm_f1_scores, _pcm_em_scores = [], []
    for f1, em in zip(pcm_f1_scores, pcm_em_scores):
        if isinstance(f1, float) and isinstance(em, float):
            _pcm_f1_scores.append(f1)
            _pcm_em_scores.append(em)

    assert len(_pcm_f1_scores) == len(_pcm_em_scores)

    jsql_error_count = sum(1 for score in pcm_f1_scores if isinstance(score, str))
    print("JSQLError in sql: ", jsql_error_count)

    return (
        sum(_pcm_em_scores) / len(_pcm_em_scores),
        sum(_pcm_f1_scores) / len(_pcm_f1_scores),
    )
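# Note: the partial component match (PCM) scores above compare the parsed SQL
# roughly clause by clause (SELECT, WHERE, GROUP BY, ...), so a prediction with
# a correct SELECT list but a wrong WHERE clause still earns partial F1 credit,
# while exact_match=True requires the components to match exactly.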

def replace_percent_symbol_y(_sql):
    # SQLite's strftime() has no '%y' specifier; map it to the supported '%Y'.
    return _sql.replace('%y', "%Y")
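# Illustrative example (hypothetical query):
#   replace_percent_symbol_y("select strftime('%y', date) from txns")
#   -> "select strftime('%Y', date) from txns"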


def get_exec_results(sqlite_path, scores, df, flag):
    """Execute each query and record its result string in a column of df.

    flag selects the target column: 'g' -> GOLD_res, 'p' -> PRED_res,
    'd' -> DEBUG_res.
    """
    con = sqlite3.connect(sqlite_path)
    cur = con.cursor()

    out, non_exec = {}, {}
    new_df = df.copy()
    for i, s in tqdm(enumerate(scores)):
        _sql = str(s).replace('"', "'").lower()
        _sql = replace_current_date_and_now(_sql, '2022-06-01')
        _sql = replace_percent_symbol_y(_sql)
        try:
            cur.execute(_sql)
            res = cur.fetchall()
            out[i] = str(res)
        except Exception as err:
            non_exec[i] = err

    if flag == 'g':
        new_df.loc[list(out.keys()), 'GOLD_res'] = list(out.values())
    if flag == 'p':
        new_df.loc[list(out.keys()), 'PRED_res'] = list(out.values())
    if flag == 'd':
        new_df.loc[list(out.keys()), 'DEBUG_res'] = list(out.values())

    con.close()
    return out, non_exec, new_df

def get_scores(gold_dict, pred_dict):
    """Compare executed result strings; skip examples whose gold result is None."""
    correct_count, incorrect_count = 0, 0
    none_count = 0
    correct_sql, incorrect_sql = [], []
    for k, res in pred_dict.items():
        if k in gold_dict:
            if str(None) in gold_dict[k]:
                none_count += 1
                continue
            if res == gold_dict[k]:
                correct_count += 1
                correct_sql.append(k)
            else:
                incorrect_count += 1
                incorrect_sql.append(k)

    return correct_count, incorrect_count, none_count, correct_sql, incorrect_sql
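# Illustrative example (hypothetical result strings): with
#   gold_dict = {0: "[(5,)]", 1: "[(None,)]", 2: "[(3,)]"}
#   pred_dict = {0: "[(5,)]", 1: "[(7,)]", 2: "[(4,)]"}
# get_scores returns (1, 1, 1, [0], [2]): index 0 matches, index 1 is skipped
# because the gold result contains None, and index 2 mismatches.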

def get_total_gold_none_count(gold_dict):
    # Count gold results containing None; these are excluded from exec accuracy.
    none_count, ok_count = 0, 0
    for res in gold_dict.values():
        if str(None) in res:
            none_count += 1
        else:
            ok_count += 1
    return ok_count, none_count


def evaluate(df):
    # df - [id, pred_sql]
    pred_sql = df['pred_sql'].to_list()
    ids = df['id'].to_list()
    with open("tests/test.json") as f:
        questions_and_ids = json.load(f)
    with open("tests/test_sql.json") as ts:
        gold_sql = json.load(ts)

    gold_sql_list = []
    pred_sql_list = []
    questions_list = []
    for idx, pred in zip(ids, pred_sql):
        ques = questions_and_ids[idx]['Query']
        gd_sql = gold_sql[idx]['SQL']
        gold_sql_list.append(gd_sql)
        pred_sql_list.append(pred)
        questions_list.append(ques)

    df = pd.DataFrame({'NLQ': questions_list, 'GOLD SQL': gold_sql_list, 'PREDICTED SQL': pred_sql_list})

    test_size = len(df)

    pred_score = df['PREDICTED SQL'].str.lower().values
    gold_score1 = df['GOLD SQL'].str.lower().values


    print("Checking non-exec Gold sql query")
    gold_exec, gold_not_exec, new_df = remove_gold_Non_exec(gold_score1, df, sqlite_path)
    print("GOLD Total exec SQL query: {}/{}".format(len(gold_exec), test_size))
    print("GOLD Total non-exec SQL query: {}/{}".format(len(gold_not_exec), test_size))


    prev_non_exec_df = new_df[new_df['Exec/Non-Exec'] == 0]
    new_df = new_df[new_df['Exec/Non-Exec'] == 1]

    prev_non_exec_df = prev_non_exec_df.reset_index(drop=True)
    new_df = new_df.reset_index(drop=True)

    # Remove non-executable gold queries from gold and predictions alike
    print(f"Removing {len(gold_not_exec)} non-exec sql queries from Gold/Pred")
    gold_score1 = remove_data_from_index(gold_score1, gold_exec)
    pred_score = remove_data_from_index(pred_score, gold_exec)
    gold_score = [[x] for x in gold_score1]  # BLEU/ROUGE expect a list of references per prediction

    assert len(gold_score) == len(pred_score)

    pred_bleu_score  = bleu.compute(predictions=pred_score, references=gold_score)
    pred_rouge_score  = rouge.compute(predictions=pred_score, references=gold_score)
    pred_exact_match, pred_partial_f1_score = get_exec_match_acc(gold_score1, pred_score)

    print("PREDICTED_vs_GOLD Final bleu_score: ", pred_bleu_score['bleu'])
    print("PREDICTED_vs_GOLD Final rouge_score: ", pred_rouge_score['rougeL'])
    print("PREDICTED_vs_GOLD Exact Match Accuracy: ", pred_exact_match)
    print("PREDICTED_vs_GOLD Partial CM F1 score: ", pred_partial_f1_score)
    print()


    new_df.loc[:, 'GOLD_res'] = str(None)
    new_df.loc[:, 'PRED_res'] = str(None)

    print("Getting Gold results")
    gout_res_dict, gnon_exec_err_dict, new_df = get_exec_results(sqlite_path, gold_score1, new_df, 'g')

    total_gold_ok_count, total_gold_none_count = get_total_gold_none_count(gout_res_dict)
    print("Total Gold None count: ", total_gold_none_count)

    print("Getting Pred results")
    pout_res_dict, pnon_exec_err_dict, new_df = get_exec_results(sqlite_path, pred_score, new_df, 'p')
    # print("Getting Debug results")
    # dout_res_dict, dnon_exec_err_dict = get_exec_results(cur, debug_score, 'd')

    print("GOLD Total exec SQL query: {}/{}".format(len(gold_exec), test_size))
    print("GOLD Total non-exec SQL query: {}/{}".format(len(gold_not_exec), test_size))
    print()
    print("PRED Total exec SQL query: {}/{}".format(len(pout_res_dict), len(pred_score)))
    print("PRED Total non-exec SQL query: {}/{}".format(len(pnon_exec_err_dict), len(pred_score)))
    print()
    # print("DEBUG Total exec SQL query: {}/{}".format(len(dout_res_dict), len(debug_score)))
    # print("DEBUG Total non-exec SQL query: {}/{}".format(len(dnon_exec_err_dict), len(debug_score)))
    # print()
    pred_correct_exec_acc_count, pred_incorrect_exec_acc_count, pred_none_count, pred_correct_sql, pred_incorrect_sql  = get_scores(gout_res_dict, pout_res_dict)
    # debug_correct_exec_acc_count, debug_incorrect_exec_acc_count, debug_none_count, debug_correct_sql, debug_incorrect_sql   = get_scores(gout_res_dict, dout_res_dict)
    # print("PRED_vs_GOLD None_count: ", total_gold_none_count)
    print("PRED_vs_GOLD Correct_Exec_count without None: ", pred_correct_exec_acc_count)
    print("PRED_vs_GOLD Incorrect_Exec_count without None: ", pred_incorrect_exec_acc_count)
    print("PRED_vs_GOLD Exec_Accuracy: ", pred_correct_exec_acc_count/total_gold_ok_count)
    print()

    return pred_exact_match, pred_correct_exec_acc_count/total_gold_ok_count, pred_partial_f1_score, pred_bleu_score['bleu'], pred_rouge_score['rougeL']
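# Illustrative shape of the submission frame `evaluate` expects (column names
# taken from the code above; the id and SQL values here are hypothetical):
#
#   id,pred_sql
#   q_00001,"select sum(credit) from master_txn_table where vendor = 'abc'"
#   q_00002,"select count(distinct id) from chart_of_accounts"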

def add_new_eval(
    method_name: str,
    url: str,
    path_to_file: str,
    organisation: str,
    mail: str,
):

    parsed_mail = input_verification(
        method_name,
        url,
        path_to_file,
        organisation,
        mail,
    )
    # input_verification returns a formatted warning string on failure;
    # surface it instead of proceeding with a bad submission.
    if parsed_mail.startswith("<p"):
        return parsed_mail

    # load the file
    df = pd.read_csv(path_to_file)

    # modify the df to include metadata
    df["Method"] = method_name
    df["url"] = url
    df["organisation"] = organisation
    df["mail"] = parsed_mail
    df["timestamp"] = datetime.datetime.now()

    submission_df = pd.read_csv(path_to_file)
    submission_df["Method"] = method_name
    submission_df["Submitted By"] = organisation
    # build the destination path for the upload to the results repo

    path_in_repo = f"submissions/{method_name}"
    file_name = f"{method_name}-{organisation}-{datetime.datetime.now().strftime('%Y-%m-%d')}.csv"

    EM, EX, PCM_F1, BLEU, ROUGE = evaluate(submission_df)
    submission_df['EM'] = EM
    submission_df['EX'] = EX
    # submission_df['PCM_F1'] = PCM_F1
    submission_df['BLEU'] = BLEU
    submission_df['ROUGE'] = ROUGE

    # serialise the metadata-augmented df and upload it to the results repo
    buffer = io.BytesIO()
    df.to_csv(buffer, index=False)  # write the DataFrame to the buffer in CSV format
    buffer.seek(0)  # rewind the buffer to the beginning

    api.upload_file(
        repo_id=RESULTS_PATH,
        path_in_repo=f"{path_in_repo}/{file_name}",
        path_or_fileobj=buffer,
        token=TOKEN,
        repo_type="dataset",
    )
    # read the current leaderboard
    leaderboard_df = pd.read_csv("submissions/baseline/baseline.csv")

    # append the new submission to the leaderboard
    leaderboard_df = pd.concat([leaderboard_df, submission_df], ignore_index=True)

    # save the new leaderboard back to the space
    leaderboard_buffer = io.BytesIO()
    leaderboard_df.to_csv(leaderboard_buffer, index=False)
    leaderboard_buffer.seek(0)
    api.upload_file(
        repo_id=LEADERBOARD_PATH,
        path_in_repo="submissions/baseline/baseline.csv",
        path_or_fileobj=leaderboard_buffer,
        token=TOKEN,
        repo_type="space",
    )

    return format_log(
        f"Method {method_name} submitted by {organisation} successfully.\nPlease refresh the leaderboard and allow a moment for the score to appear."
    )
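# A minimal local smoke test (hypothetical paths and values; assumes
# tests/test.json, tests/test_sql.json and the accounting sqlite db exist,
# and that TOKEN grants write access to the HF repos):
#
# if __name__ == "__main__":
#     sample = pd.DataFrame({"id": ["q_00001"], "pred_sql": ["select 1"]})
#     sample.to_csv("sample_submission.csv", index=False)
#     print(add_new_eval(
#         method_name="my-method",
#         url="https://example.com",
#         path_to_file="sample_submission.csv",
#         organisation="my-org",
#         mail="user@example.com",
#     ))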