import sklearn
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tqdm import tqdm
import sys
# import openai
import time
# import pandas as pd
import random
import csv
import os
import pickle
import json

from langchain.chat_models import AzureChatOpenAI
from langchain.schema import HumanMessage, SystemMessage
from langchain.callbacks import get_openai_callback
from langchain.llms import OpenAI        

import tiktoken

from sklearn.feature_extraction.text import CountVectorizer
from collections import Counter
import math

import io
import contextlib

# os.system('pip install pandas reportlab')
# os.system('pip install openai==0.27.2')
# os.system('pip install tenacity')

import requests
from bs4 import BeautifulSoup
import ast

import nltk
nltk.download('punkt')
nltk.download('stopwords')
from nltk.tokenize import sent_tokenize
from nltk.corpus import stopwords
import string
from nltk.tokenize import sent_tokenize
from nltk.tokenize import word_tokenize
import numpy as np
import evaluate

def tree_edit_distance(tree1, tree2):
    """Return an edit distance between two trees encoded as nested lists.

    A ``list`` is an internal node whose elements are its ordered children; any
    other value is a leaf. Costs used:

    * leaf vs leaf: 0 if the values are equal, otherwise 1;
    * leaf vs list (either order): the size of the list subtree;
    * inserting or deleting a child: the size of that child's subtree;
    * list vs list: the cheapest alignment of the two child sequences
      (sequence edit distance whose substitution cost is the recursive distance).

    Sizes count every list and every leaf as one node; an empty list is one node.

    Returns:
        int: the edit distance (0 when the trees are identical).
    """
    def cost(node1, node2):
        """Cost to transform leaf node1 into leaf node2."""
        return 0 if node1 == node2 else 1

    def tree_size(tree):
        """Number of nodes in *tree* (non-lists and empty lists count as 1)."""
        if not isinstance(tree, list) or not tree:
            return 1
        return 1 + sum(tree_size(child) for child in tree)

    def ted(t1, t2):
        """Tree edit distance between subtrees t1 and t2."""
        if not isinstance(t1, list) and not isinstance(t2, list):
            return cost(t1, t2)
        if not isinstance(t1, list):
            return tree_size(t2)
        if not isinstance(t2, list):
            return tree_size(t1)
        if not t1 and not t2:
            return 0

        # Child subtree sizes are loop-invariant for the DP below; compute them
        # once instead of re-walking each child for every cell of the table.
        sizes1 = [tree_size(child) for child in t1]
        sizes2 = [tree_size(child) for child in t2]
        if not t1:
            return sum(sizes2)
        if not t2:
            return sum(sizes1)

        n, m = len(t1), len(t2)
        # dp[i][j] = cost of turning the first i children of t1 into the first j of t2.
        dp = [[0] * (m + 1) for _ in range(n + 1)]
        for i in range(1, n + 1):
            dp[i][0] = dp[i - 1][0] + sizes1[i - 1]
        for j in range(1, m + 1):
            dp[0][j] = dp[0][j - 1] + sizes2[j - 1]

        for i in range(1, n + 1):
            for j in range(1, m + 1):
                dp[i][j] = min(
                    dp[i - 1][j] + sizes1[i - 1],                  # delete child i
                    dp[i][j - 1] + sizes2[j - 1],                  # insert child j
                    dp[i - 1][j - 1] + ted(t1[i - 1], t2[j - 1]),  # transform i -> j
                )
        return dp[n][m]

    return ted(tree1, tree2)

def preprocess_code_str(code_str):
    """Reduce a code string to its ``citation_bracket[...]`` / ``sentence[...]`` assignments.

    A line is kept when the text before its first ``=`` contains
    ``citation_bracket[`` or ``sentence[``. Kept lines lose their leading
    indentation and are wrapped in a tiny program that first initialises both
    names to empty dicts and finally prints ``sentence``.

    Returns:
        str: the program text, ending in ``print(sentence)``.
    """
    prefix = "citation_bracket = {}\nsentence = {}\n"
    kept_lines = []
    for raw_line in code_str.split("\n"):
        # Only single-line item assignments survive, so the original block nesting
        # is irrelevant: drop all leading indentation. (Removing every 8-space run
        # also altered string literals and left deeper indents partially in place.)
        line = raw_line.lstrip()
        target = line.split("=")[0]
        # One combined check, so ``sentence[citation_bracket[0]] = ...`` is kept once.
        if "citation_bracket[" in target or "sentence[" in target:
            kept_lines.append(line)

    return prefix + "\n".join(kept_lines) + "\nprint(sentence)"

def run_code(code_str):
    """Execute the ``sentence`` assignments found in *code_str* and return ``sentence``.

    *code_str* is first reduced by ``preprocess_code_str``. The resulting program
    prints ``sentence``; that output is captured and parsed back with
    ``ast.literal_eval``.

    SECURITY NOTE(review): this ``exec``s text that presumably comes from a model.
    Running it in a fresh namespace keeps it away from this function's locals and
    the module globals, but it is NOT a sandbox (builtins such as ``__import__``
    remain reachable). Only use on trusted input.

    Raises:
        Whatever ``exec`` or ``ast.literal_eval`` raise on malformed input
        (e.g. ``SyntaxError``, ``NameError``, ``ValueError``).
    """
    program = preprocess_code_str(code_str)

    # Redirect stdout so the trailing print(sentence) can be captured.
    captured = io.StringIO()
    with contextlib.redirect_stdout(captured):
        exec(program, {})

    return ast.literal_eval(captured.getvalue())

def replace_with_char(input_list, char='a'):
    """Copy a nested list, replacing every non-list element with *char*.

    The list nesting is preserved exactly. A non-list *input_list* is itself
    replaced, so the result is just *char*.
    """
    if not isinstance(input_list, list):
        return char
    return [replace_with_char(element, char) for element in input_list]

def top_k_keys(input_dict, k):
    """Return the keys of *input_dict* with the *k* largest values, largest first.

    Equal values keep the dict's iteration order (``sorted`` stays stable with
    ``reverse=True``). The result is a plain ``[:k]`` slice of the ranking.
    """
    ranking = sorted(input_dict, key=lambda key: input_dict[key], reverse=True)
    return ranking[:k]



def keys_with_least_k_values(d, k):
    """Return the keys of *d* holding the *k* smallest values, smallest first.

    Equal values keep the dict's iteration order. A non-positive *k* gives ``[]``.
    """
    if k > 0:
        ascending = sorted(d, key=lambda key: d[key])
        return ascending[:k]
    return []

def edit_distance_code_str(code1, code2, just_tree_structure=False):
    """Tree edit distance between the ``sentence`` structures built by two code strings.

    Each code string is executed with ``run_code``. The values of the returned
    ``sentence`` mapping, taken in iteration order, become the children of a tree.
    With ``just_tree_structure=True`` every leaf is first replaced by the same
    character, so only the nesting shape is compared.
    """
    trees = []
    for code in (code1, code2):
        sentence = run_code(code)
        trees.append([sentence[key] for key in sentence])

    if just_tree_structure:
        trees = [replace_with_char(tree) for tree in trees]

    first, second = trees
    return tree_edit_distance(first, second)

class eval_metrics:
    """Convenience wrappers around Hugging Face ``evaluate`` text-generation metrics.

    Each ``get_*`` method takes parallel lists of predictions and references and
    returns one score. A metric module is loaded with ``evaluate.load`` the first
    time it is needed and cached on the instance, rather than re-loaded per call.
    """

    def __init__(self):
        # metric name -> metric module returned by evaluate.load
        self._metrics = {}

    def _load(self, name, **load_kwargs):
        """Return the metric *name*, calling ``evaluate.load`` only on first use."""
        metric = self._metrics.get(name)
        if metric is None:
            metric = evaluate.load(name, **load_kwargs)
            self._metrics[name] = metric
        return metric

    def get_rouge_l(self, pred, refs):
        """Return the ``rougeL`` entry of the rouge metric's output."""
        results = self._load('rouge').compute(predictions=pred, references=refs)
        return results['rougeL']

    def get_bleu(self, pred, refs):
        """Return the ``bleu`` score, using each entry of *refs* as the sole reference."""
        # Wrap every reference in its own list: one reference per prediction.
        wrapped_refs = [[ref] for ref in refs]
        results = self._load('bleu').compute(predictions=pred, references=wrapped_refs)
        return results['bleu']

    def get_meteor(self, pred, refs):
        """Return the ``meteor`` entry of the meteor metric's output."""
        results = self._load('meteor').compute(predictions=pred, references=refs)
        return results['meteor']

    def get_bertscore(self, pred, refs):
        """Return the mean BERTScore F1 over all pairs (computed with ``lang="en"``)."""
        results = self._load('bertscore').compute(predictions=pred, references=refs, lang="en")
        return np.mean(results['f1'])

    def get_bleurt(self, pred, refs):
        """Return the mean BLEURT score over all pairs."""
        bleurt = self._load('bleurt', module_type="metric")
        results = bleurt.compute(predictions=pred, references=refs)
        return np.mean(results['scores'])
    
class BM25:
    """Okapi BM25 ranking over a fixed list of documents.

    Documents and queries are tokenised with a ``CountVectorizer()`` fitted on
    *documents* (default settings); query terms outside that vocabulary are
    ignored. IDF is ``log((N - df + 0.5) / (df + 0.5) + 1)``.

    Args:
        documents: the corpus to rank; returned items come from this sequence.
        k1: term-frequency saturation parameter.
        b: document-length normalisation strength.
    """

    def __init__(self, documents, k1=1.5, b=0.75):
        self.documents = documents
        self.k1 = k1
        self.b = b
        self.vectorizer = CountVectorizer().fit(documents)
        self.doc_term_matrix = self.vectorizer.transform(documents)
        self.doc_lengths = np.array(self.doc_term_matrix.sum(axis=1)).flatten()
        self.avg_doc_length = np.mean(self.doc_lengths)
        # Column-major copy: compute_bm25 reads whole term columns, which is O(df) in CSC.
        self._doc_term_csc = self.doc_term_matrix.tocsc()
        # Document frequency of each term = number of stored entries in its column.
        self.df = np.diff(self._doc_term_csc.indptr)
        self.idf = self.compute_idf()

    def compute_idf(self):
        """Return the per-term IDF vector (indexed like the vectorizer vocabulary)."""
        N = len(self.documents)
        idf = np.log((N - self.df + 0.5) / (self.df + 0.5) + 1)
        return idf

    def compute_bm25(self, query):
        """Return a list with the BM25 score of *query* against every document.

        Each distinct query term present in the vocabulary contributes once per
        document that contains it. Scoring walks only the non-zero entries of
        the query terms' columns instead of slicing every document row.
        """
        query_vec = self.vectorizer.transform([query])
        csc = self._doc_term_csc
        scores = np.zeros(self.doc_term_matrix.shape[0], dtype=float)
        # Per-document length penalty; independent of the query term.
        length_norm = self.k1 * (1 - self.b + self.b * (self.doc_lengths / self.avg_doc_length))

        for term_idx in query_vec.indices:
            start, end = csc.indptr[term_idx], csc.indptr[term_idx + 1]
            doc_ids = csc.indices[start:end]  # documents containing the term
            tf = csc.data[start:end]
            numerator = tf * (self.k1 + 1)
            denominator = tf + length_norm[doc_ids]
            scores[doc_ids] += self.idf[term_idx] * numerator / denominator

        return scores.tolist()

    def get_top_k(self, query, k=5):
        """Return ``(top_docs, top_indices)`` for the *k* highest-scoring documents."""
        scores = self.compute_bm25(query)
        top_k_indices = np.argsort(scores)[::-1][:k]
        top_k_docs = [self.documents[i] for i in top_k_indices]
        return top_k_docs, top_k_indices

def get_nmis(true_dict, pred_dict):
    """Compare a predicted clustering with a reference clustering.

    Both dicts map item -> cluster label; only the items of *true_dict* are
    scored. An item missing from *pred_dict* is put in its own new singleton
    cluster, numbered after the largest predicted label, so missing items never
    count as agreeing with one another.

    NOTE(review): a predicted label that is literally ``-1`` is treated exactly
    like a missing prediction (it is the internal placeholder). Confirm callers
    never use ``-1`` as a real cluster id.

    Returns:
        tuple: ``(normalized_mutual_info_score, adjusted_mutual_info_score)``.
    """
    # Import the submodule explicitly: ``import sklearn`` on its own does not
    # guarantee that ``sklearn.metrics`` is available as an attribute.
    from sklearn.metrics import adjusted_mutual_info_score, normalized_mutual_info_score

    labels_true = [true_dict[key] for key in true_dict]
    labels_pred = [pred_dict[key] if key in pred_dict else -1 for key in true_dict]

    next_label = np.max(labels_pred) + 1 if labels_pred else 0
    for label_idx, label in enumerate(labels_pred):
        if label == -1:
            labels_pred[label_idx] = next_label
            next_label += 1

    return (
        normalized_mutual_info_score(labels_true=labels_true, labels_pred=labels_pred),
        adjusted_mutual_info_score(labels_true=labels_true, labels_pred=labels_pred),
    )

def calculate_precision_recall_f1(predicted, ground_truth):
    """Set-based precision, recall and F1 of *predicted* against *ground_truth*.

    Both inputs are converted to sets, so duplicates are ignored. Precision is 0
    when nothing is predicted, recall is 0 when the ground truth is empty, and
    F1 is 0 whenever precision and recall are both 0.

    Returns:
        tuple: ``(precision, recall, f1_score)``.
    """
    pred_items = set(predicted)
    gold_items = set(ground_truth)
    hits = len(pred_items & gold_items)

    if pred_items:
        precision = hits / len(pred_items)
    else:
        precision = 0

    if gold_items:
        recall = hits / len(gold_items)
    else:
        recall = 0

    if precision + recall == 0:
        return precision, recall, 0
    return precision, recall, 2 * (precision * recall) / (precision + recall)

def _collect_introduction_text(soup):
    """Return the text of the first h2/h3 section whose heading mentions 'introduction'.

    The heading text comes first, followed by the text of each following sibling
    element up to (not including) the next h2/h3. Every part is stripped and
    followed by a blank line. Returns "" when no matching heading exists.
    """
    for heading in soup.find_all(['h2', 'h3']):
        if 'introduction' not in heading.text.lower():
            continue
        parts = [heading.text.strip()]
        node = heading.find_next_sibling()
        while node is not None and node.name not in ['h2', 'h3']:
            parts.append(node.get_text().strip())
            node = node.find_next_sibling()
        return "".join(part + "\n\n" for part in parts)
    return ""


def get_introduction(arxiv_id, timeout=60):
    """Fetch the ar5iv HTML rendering of *arxiv_id* and return its introduction text.

    The introduction is located heuristically: the first h2/h3 heading containing
    "introduction" plus the sibling elements after it, up to the next h2/h3.
    The heuristic may need adjusting for papers with a different structure.

    Args:
        arxiv_id: arXiv identifier, inserted into ``https://ar5iv.org/html/<id>``.
        timeout: seconds passed to ``requests.get``; without it an unresponsive
            server could block the call indefinitely.

    Returns:
        str: the introduction text, or "" if no introduction heading is found.

    Raises:
        Exception: if the response status code is not 200.
        requests.RequestException: on network errors, including timeouts.
    """
    url = f"https://ar5iv.org/html/{arxiv_id}"

    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise Exception(f"Failed to fetch the page: Status code {response.status_code}")

    soup = BeautifulSoup(response.content, 'html.parser')
    return _collect_introduction_text(soup)

def write_to_file(filepath, content):
    """Serialise *content* to *filepath*, picking the format from the path string.

    Formats are tried in this order, and the first whose marker appears ANYWHERE
    in the path wins (this is a substring test, not a suffix test):
    '.txt' -> text, '.json' -> JSON, '.pickle' / '.pkl' -> pickle,
    '.npy' -> ``np.save``. If no marker matches, nothing is written.

    NOTE(review): because of the substring test, a path such as 'run.txt.json'
    is written as plain text.
    """
    def as_text(path):
        with open(path, 'w') as handle:
            handle.write(content)

    def as_json(path):
        with open(path, 'w') as handle:
            json.dump(content, handle)

    def as_pickle(path):
        with open(path, 'wb') as handle:
            pickle.dump(content, handle)

    def as_numpy(path):
        np.save(path, content)

    writers = (
        (('.txt',), as_text),
        (('.json',), as_json),
        (('.pickle', '.pkl'), as_pickle),
        (('.npy',), as_numpy),
    )
    for markers, writer in writers:
        if any(marker in filepath for marker in markers):
            writer(filepath)
            return

def read_from_file(filepath):
    """Load and return the data stored at *filepath*, picking the format from the path.

    Formats are tried in this order, and the first whose marker appears ANYWHERE
    in the path wins (this is a substring test, not a suffix test):
    '.txt' -> text string, '.json' -> parsed JSON, '.pickle' / '.pkl' -> unpickled
    object, '.npy' -> ``np.load``. Returns None if no marker matches.

    SECURITY: unpickling (and ``np.load`` of object arrays) can execute code; only
    read such files from trusted sources.
    """
    def from_text(path):
        with open(path, 'r') as handle:
            return handle.read()

    def from_json(path):
        with open(path, 'r') as handle:
            return json.load(handle)

    def from_pickle(path):
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    def from_numpy(path):
        return np.load(path)

    loaders = (
        (('.txt',), from_text),
        (('.json',), from_json),
        (('.pickle', '.pkl'), from_pickle),
        (('.npy',), from_numpy),
    )
    for markers, loader in loaders:
        if any(marker in filepath for marker in markers):
            return loader(filepath)
    return None

# Translation table deleting ASCII punctuation except '_' and '@'; built once at import.
_STRIP_PUNCTUATION = str.maketrans('', '', string.punctuation.replace('_', '').replace('@', ''))

# language -> set of NLTK stop words, filled on first use of each language.
_stopword_cache = {}


def _stopword_set(language):
    """Return NLTK's stop words for *language* as a set, fetching them only once."""
    if language not in _stopword_cache:
        _stopword_cache[language] = set(stopwords.words(language))
    return _stopword_cache[language]


def remove_stopwords_and_punctuation(text, language='english'):
    """Strip punctuation and stop words from *text*.

    ASCII punctuation other than '_' and '@' is deleted. The text is then split
    on whitespace, and words whose lowercase form is an NLTK stop word for
    *language* are dropped. The remaining words are joined with single spaces.

    Args:
        text: input string.
        language: NLTK stop-word list to use (default 'english').

    Returns:
        str: the filtered text.
    """
    stop_words = _stopword_set(language)
    words = text.translate(_STRIP_PUNCTUATION).split()
    return ' '.join(word for word in words if word.lower() not in stop_words)

class AzureModels:
    """Chat-completion wrapper around Azure OpenAI deployments, accessed via LangChain.

    Supported ``model_name`` values are the keys of ``_CONFIGS``. API keys are not
    stored in source code: pass ``api_key`` explicitly or set the environment
    variable named in the model's config.
    """

    # NOTE(review): earlier versions of this file hard-coded live API keys for both
    # deployments. They remain in version-control history and should be rotated.
    _CONFIGS = {
        "gpt4": {
            "deployment_name": "gentech-gpt4-research",
            "base_url": "https://gentechworkbench-stage.openai.azure.com/",
            "api_version": "2023-03-15-preview",
            "tiktoken_model": "gpt-4-0314",
            "key_env_var": "AZURE_OPENAI_API_KEY_GPT4",
        },
        "gpt4o": {
            "deployment_name": "gpt-4o",
            "base_url": "https://docexpresearch.openai.azure.com/",
            "api_version": "2023-07-01-preview",
            "tiktoken_model": "gpt-4o",
            "key_env_var": "AZURE_OPENAI_API_KEY_GPT4O",
        },
    }

    def __init__(self, model_name, api_key=None):
        """Create the LangChain chat model and tiktoken encoder for *model_name*.

        Args:
            model_name: one of ``_CONFIGS`` ("gpt4" or "gpt4o").
            api_key: Azure OpenAI key; if None, read from the model's environment variable.

        Raises:
            ValueError: *model_name* is not supported.
            RuntimeError: no key was given and the environment variable is unset or empty.
        """
        try:
            config = self._CONFIGS[model_name]
        except KeyError:
            raise ValueError(
                f"Unsupported model_name {model_name!r}; expected one of {sorted(self._CONFIGS)}"
            ) from None

        if api_key is None:
            api_key = os.environ.get(config["key_env_var"])
        if not api_key:
            raise RuntimeError(
                f"No API key for {model_name!r}: pass api_key or set {config['key_env_var']}"
            )

        self.model = AzureChatOpenAI(
            openai_api_base=config["base_url"],
            openai_api_version=config["api_version"],
            deployment_name=config["deployment_name"],
            openai_api_key=api_key,
            openai_api_type="azure",
        )
        self.enc = tiktoken.encoding_for_model(config["tiktoken_model"])

    @retry(wait=wait_random_exponential(min=30, max=80), stop=stop_after_attempt(5))
    def get_completion(self, question, max_tokens, stop=None):
        """Send *question* as a single human message and return the reply as a string.

        Retried by tenacity: at most 5 attempts, with randomised exponential waits
        between 30 and 80 seconds. After the final failure tenacity raises its
        ``RetryError``.
        """
        reply = self.model([HumanMessage(content=question)], max_tokens=max_tokens, stop=stop)
        content = reply.content
        # Coerce non-string message content to str.
        return content if isinstance(content, str) else str(content)

    def get_num_inp_tokens(self, inp):
        """Return the number of tokens in *inp* under this model's tiktoken encoding."""
        tokens = self.enc.encode(inp)
        return len(tokens)