from tenacity import retry, stop_after_attempt, wait_random_exponential
from tqdm import tqdm
import time
import sys

# MODEL_NAME = str(sys.argv[1])
# num_shots = int(sys.argv[2])
# method = str(sys.argv[3]) #['fixed', 'random', 'bm25']

# ADDED K-SHOT SETTING, WHERE K IS VARIABLE

# import openai
import time
# import pandas as pd
import random
random.seed(1)

import csv
import os
import pickle
import json
import nltk
nltk.download('punkt')
nltk.download('stopwords')
from nltk.tokenize import sent_tokenize
from nltk.corpus import stopwords
import string

from langchain.chat_models import AzureChatOpenAI
from langchain.schema import HumanMessage, SystemMessage
from langchain.callbacks import get_openai_callback
from langchain.llms import OpenAI
import tiktoken

import re
from nltk.tokenize import sent_tokenize
from collections import defaultdict


import nltk
from nltk.tokenize import sent_tokenize
from nltk.tokenize import word_tokenize
import numpy as np

# Get the parent directory
# parent_dir = "/home/abnandy/sensei-fs-link"#os.path.abspath(os.path.join(os.getcwd(), os.pardir))
# Add the parent directory to the system path
# sys.path.append(parent_dir)

from utils import AzureModels, write_to_file, read_from_file
# from utils_open import OpenModels

def remove_stopwords_and_punctuation(text):
    """Return ``text`` with punctuation and English stopwords removed.

    Every character of ``string.punctuation`` except ``_`` and ``@`` is
    deleted, so tokens such as ``@cite_3`` come through unchanged.  The text is
    then split on whitespace, tokens whose lowercase form is an NLTK English
    stopword are dropped, and the rest are re-joined with single spaces.
    """
    english_stopwords = set(stopwords.words('english'))

    # '_' and '@' are kept because the citation tags handled in this file
    # (e.g. "@cite_1") are built from them.
    removable = string.punctuation.replace('_', '').replace('@', '')
    without_punctuation = text.translate(str.maketrans('', '', removable))

    kept_tokens = []
    for token in without_punctuation.split():
        if token.lower() in english_stopwords:
            continue
        kept_tokens.append(token)

    return ' '.join(kept_tokens)

def get_key(list_):
    """Merge several citation tags into one combined key.

    Each item has every ``@cite`` occurrence removed and the remainders are
    concatenated after a single leading ``@cite``; for example
    ``["@cite_1", "@cite_2"]`` becomes ``"@cite_1_2"``.
    """
    suffixes = (tag.replace('@cite', '') for tag in list_)
    return '@cite' + ''.join(suffixes)

def group_citations(key):
    """Expand a combined key into a comma-separated list of citation tags.

    ``"@cite_1_2"`` becomes ``"@cite_1, @cite_2"``: every ``@cite_`` is removed
    from ``key``, the remainder is split on ``_`` and each piece is prefixed
    with ``@cite_`` again.
    """
    pieces = key.replace("@cite_", "").split("_")
    tags = []
    for piece in pieces:
        tags.append("@cite_" + piece)
    return ", ".join(tags)

def code_to_extra_info(code_str):
    """Turn clustering code into a natural-language note about grouped citations.

    ``code_str`` is scanned line by line.  The text before the first ``=`` of a
    line is inspected for ``citation_bracket["<key>"]`` and ``sentence["<key>"]``
    targets, and the keys are collected.  Only keys containing more than one
    ``_`` are turned into a sentence, using ``group_citations`` to list the tags.

    Returns the empty string when no such key exists; otherwise a block that
    starts with ``"\\n\\nNOTE THAT -\\n\\n"`` followed by the citation-bracket
    sentences, a blank line, and the same-sentence sentences.
    """
    bracket_keys = []
    same_sentence_keys = []
    for raw_line in code_str.split("\n"):
        target = raw_line.split("=")[0]
        if "citation_bracket[" in target:
            bracket_keys.append(target.split('citation_bracket["')[-1].split('"]')[0])
        if "sentence[" in target:
            same_sentence_keys.append(target.split('sentence["')[-1].split('"]')[0])

    cb_template = "{} are in the same citation bracket (i.e., they are right next to each other) within the section of the Wikipedia Article."
    sent_template = "{} are in the same sentence within the section of the Wikipedia Article."

    def describe(keys, template):
        # Keys with at most one underscore are skipped.
        notes = []
        for key in keys:
            if key.count("_") > 1:
                notes.append(template.format(group_citations(key)))
        return notes

    bracket_notes = describe(bracket_keys, cb_template)
    sentence_notes = describe(same_sentence_keys, sent_template)

    if not bracket_notes and not sentence_notes:
        return ""

    return "\n\nNOTE THAT -\n\n" + "\n".join(bracket_notes) + "\n\n" + "\n".join(sentence_notes)

def _as_code_list(items):
    """Render ``items`` (already-formatted source fragments) as ``[a, b, c]``."""
    return "[" + ", ".join(items) + "]"


def get_code_str(related_work, reference_dict):
    """Build the "ground truth" clustering code for one related-work passage.

    ``related_work`` is split into sentences with ``sent_tokenize``; each
    sentence is cleaned with ``remove_stopwords_and_punctuation`` and split on
    single spaces.  Every maximal run of consecutive words that are keys of
    ``reference_dict`` is treated as one citation bracket and yields a line

        citation_bracket["<combined key>"] = ["<tag>", ...]

    and every sentence containing at least one bracket yields a line

        sentence["<combined key>"] = [citation_bracket["<key>"], ...]

    Args:
        related_work: Text whose citations should be grouped.
        reference_dict: Container supporting ``in``; its members are the
            citation tags to look for.

    Returns:
        A string with all ``citation_bracket`` lines, a blank line, then all
        ``sentence`` lines, each indented by eight spaces.
    """
    # Local import: the file's top-level imports do not provide groupby.
    from itertools import groupby

    citation_bracket_code_lines = []
    sentence_code_lines = []

    for sentence in sent_tokenize(related_work):
        words = remove_stopwords_and_punctuation(sentence).split(' ')
        bracket_keys = []

        # groupby yields maximal runs of adjacent words sharing the same
        # "is a citation tag" flag, i.e. exactly one run per citation bracket.
        for is_citation, run in groupby(words, key=lambda word: word in reference_dict):
            if not is_citation:
                continue
            tags = list(run)
            bracket_key = get_key(tags)
            citation_bracket_code_lines.append(
                'citation_bracket["{}"] = {}'.format(
                    bracket_key, _as_code_list('"' + tag + '"' for tag in tags)
                )
            )
            bracket_keys.append(bracket_key)

        if bracket_keys:
            sentence_code_lines.append(
                'sentence["{}"] = {}'.format(
                    get_key(bracket_keys),
                    _as_code_list('citation_bracket["{}"]'.format(key) for key in bracket_keys),
                )
            )

    return (
        "        " + "\n        ".join(citation_bracket_code_lines)
        + "\n\n        " + "\n        ".join(sentence_code_lines)
    )

def get_prompt(list_, i, prompt_template):
    """Assemble the code-style prompt and reference clustering for sample ``i``.

    Reads ``related_work``, ``abstract`` and ``ref_abstract`` from ``list_[i]``.
    Each reference becomes a ``reference_articles["<tag>"] = "<abstract>"`` line
    and the sample abstract becomes an ``intent = "..."`` line; these lines are
    indented by four spaces and substituted into ``prompt_template`` (which is
    formatted with that single positional argument).

    Returns:
        ``(gt_summary, prompt, code_str)`` where ``gt_summary`` is the stripped
        related-work text, ``prompt`` the filled template and ``code_str`` the
        output of ``get_code_str`` for this sample.
    """
    sample = list_[i]
    gt_summary = sample['related_work'].strip()
    intent_text = sample['abstract'].strip()
    references = sample['ref_abstract']

    stripped_abstracts = {tag: references[tag]['abstract'].strip() for tag in references}

    assignments = [
        'reference_articles["{}"] = "{}"'.format(tag, abstract)
        for tag, abstract in stripped_abstracts.items()
    ]
    assignments.append('intent = "{}"'.format(intent_text))
    input_code_str = "    " + "\n    ".join(assignments)

    code_str = get_code_str(gt_summary, references)
    prompt = prompt_template.format(input_code_str)
    return gt_summary, prompt, code_str

def preprocess_retrieved_out(tmp_keys, out):
    """Parse per-citation summaries out of a model completion.

    For every key in ``tmp_keys`` the first line of ``out`` that mentions the
    key as a whole tag is taken; the text after the first ``:`` on that line
    (or the whole line when it has no ``:``), stripped, becomes the summary.
    Keys that appear on no line are left out of the result.

    A key only matches when it is not embedded in a longer word, so
    ``@cite_1`` no longer picks up the line written for ``@cite_10``.

    Args:
        tmp_keys: Iterable of citation tags to look for.
        out: Raw completion text, one summary per line.

    Returns:
        ``{key: {"abstract": summary}}`` for every key that was found.
    """
    lines = out.split("\n")  # split once instead of once per key
    new_dict = {}
    for key in tmp_keys:
        # Word-boundary lookarounds stop "@cite_1" from matching "@cite_10".
        tag_pattern = re.compile(r"(?<!\w)" + re.escape(key) + r"(?!\w)")
        for line in lines:
            if tag_pattern.search(line):
                summ_doc = line.split(":", 1)[-1].strip()
                new_dict[key] = {"abstract": summ_doc}
                print(key)
                print(summ_doc)
                print()
                break
    return new_dict

def get_slide(topic, text):
    """Ask the "gpt4o" Azure model to restructure ``text`` as markdown slide content.

    The instruction names ``topic`` as the slide topic and appends ``text``.
    Returns whatever ``AzureModels.get_completion`` returns for that prompt.
    """
    instruction = (
        'Convert this text into more structured text (in markdown) that can be put into the content of a slide in a presentation (e.g. use bullet points, numbered points, proper layout, etc.). Also, the include the topic "{}" of the slide. -\n'
        '    \n'
        '{}'
    )
    client = AzureModels("gpt4o")
    request = instruction.format(topic, text)
    # Second argument is 100; presumably a max-token limit - confirm against AzureModels.
    slide_markdown = client.get_completion(request, 100)
    time.sleep(2)  # fixed two-second pause after the call
    return slide_markdown

# Checkpoint identifiers handed to OpenModels for each non-gpt4 MODEL_NAME alias.
_OPEN_MODEL_IDS = {
    "gemma2b": "google/gemma-2b-it",
    "gemma7b": "google/gemma-7b-it",
    "mistral7b": "mistralai/Mistral-7B-Instruct-v0.3",
    "llama7b": "meta-llama/Llama-2-7b-chat-hf",
    "llama13b": "meta-llama/Llama-2-13b-chat-hf",
    "llama3": "meta-llama/Meta-Llama-3-8B-Instruct",
    "galactica7b": "facebook/galactica-6.7b",
}

# Instruction preamble used for non-gpt4 models in code mode.
_CODE_INSTRUCTION_PREFIX = (
    "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n"
    "        \n"
    "        ### Instruction:\n"
    "        "
)

# Prompt that asks for one summary line per reference (organize_out is None, code is False).
_RETRIEVAL_PROMPT = (
    'Given are a set of articles referenced in a Wikipedia Article, and the intent -\n'
    '    \n'
    '    Reference Articles:\n'
    '    {articles}\n'
    '    \n'
    '    Intent:\n'
    '    {intent}\n'
    '    \n'
    '    Summarize each reference article (generate in the format "@cite_K : <SUMMARIZED CONTENT CORREPONDING TO @cite_K>", each in a new line, where @cite_K represents each of the following citation/reference tags - {tags}, given in Reference Articles), given the reference articles as documents, and the intent.{extra_info}\n'
    '    \n'
    '    {response_prefix}Answer: '
)

# Prompt that asks for the final section text (organize_out given, code is False).
# It has no {tags} field; named fields keep extra_info/response_prefix in their
# intended slots regardless of that.
_ORGANIZE_PROMPT = (
    'Given are a set of articles referenced in a Wikipedia Article, and the intent -\n'
    '        \n'
    '        Reference Articles:\n'
    '        {articles}\n'
    '        \n'
    '        Intent:\n'
    '        {intent}\n'
    '        \n'
    '        Generate the wikipedia article section in 100-200 words based on the intent as an intent-based multi-document summary, given the reference articles as documents, and the intent.{extra_info}\n'
    '        \n'
    '        {response_prefix}Answer: '
)

# Code-mode prompt; get_prompt fills the single positional field.
_CODE_PROMPT = (
    'def main():\n'
    '            # Given is a dictionary of articles that are referenced in a section of the Wikipedia Article, and the intent -\n'
    '        \n'
    '            reference_articles = dict()\n'
    '        \n'
    '        {}'
)

# Function header appended after every code-mode prompt.
_HIER_CLUSTER_PROMPT = (
    "\n    def hierarchical_clustering():\n"
    "        # Hierarchical Clustering of references within a section of the Wikipedia Article, based on the reference articles and the intent\n"
    "        citation_bracket = {} # This dictionary contains lists as values that shows how references are grouped within the same citation bracket in the section of the Wikipedia Article\n"
    "        sentence = {} # This dictionary contains lists, where each list contains references in a sentence in the section of the Wikipedia Article\n\n"
)


def _select_icl_indices(method, num_shots, test_idx, train_list, retrieve_dict, current_indices):
    """Pick the ``train_list`` indices used as in-context examples for one test item."""
    if method == "random":
        try:
            # Use a module-level ``holdout_indices`` when one has been defined.
            candidates = holdout_indices
        except NameError:
            # Nothing in this module defines it; fall back to all training indices.
            candidates = range(len(train_list))
        return random.sample(candidates, num_shots)
    if method in ("bm25", "gat"):
        ranked = retrieve_dict[str(test_idx)]
        return [int(ranked[rank]) for rank in range(num_shots)]
    if method == "fixed":
        return current_indices[:num_shots]
    return current_indices


def _format_reference_lines(references, nested):
    """Render references as ``"<tag> : <abstract>"`` lines joined by newlines.

    With ``nested`` true each value is a mapping whose ``'abstract'`` entry is
    used; otherwise the value itself is used as the abstract text.
    """
    rendered = []
    for tag in references:
        abstract = references[tag]['abstract'] if nested else references[tag]
        rendered.append(tag + " : " + abstract)
    return "\n".join(rendered)


def get_retrieved_results(MODEL_NAME, num_shots, method, train_list, test_list, code=False, organize_out=None):
    """Prompt a model once per test item and collect its completions.

    Only test items with more than one entry in ``'ref_abstract'`` are
    processed.  Three modes exist:

    * ``code=True``: a code-style prompt (``get_prompt``) followed by the
      ``hierarchical_clustering`` header is sent with a 500 token budget.
    * ``code=False`` and ``organize_out is None``: the model is asked for one
      summary line per reference (1000 tokens); the parsed result replaces
      ``test_list[i]['ref_abstract']`` in place.
    * ``code=False`` and ``organize_out`` given: ``organize_out[str(i)]`` is
      converted with ``code_to_extra_info`` into a grouping note and the model
      is asked for the section text (300 tokens).

    Args:
        MODEL_NAME: Names containing ``"gpt4"`` use ``AzureModels``; any other
            name must be a key of ``_OPEN_MODEL_IDS`` and uses ``OpenModels``.
        num_shots: Number of in-context examples; ``0`` disables them.
        method: ``"fixed"``, ``"random"``, ``"bm25"`` or ``"gat"``.  The last
            two read precomputed rankings from JSON files via ``read_from_file``.
        train_list: Samples the in-context examples are drawn from.
        test_list: Samples to run; mutated in the summary mode described above.
        code: Selects the code-style prompt.
        organize_out: Mapping from ``str(test index)`` to clustering code, or
            ``None``.

    Returns:
        ``{index: completion}`` in code mode and in organize mode; ``test_list``
        itself in the summary mode.

    Raises:
        ValueError: ``MODEL_NAME`` is neither a gpt4 name nor a known alias.
    """
    use_azure = 'gpt4' in MODEL_NAME
    instruction_template = ''
    response_template = ''

    if use_azure:
        azure_models = AzureModels(MODEL_NAME)
    else:
        if code:
            instruction_template = _CODE_INSTRUCTION_PREFIX
            response_template = '### Response:\n'
        else:
            response_template = '### Assistant: '
        if MODEL_NAME not in _OPEN_MODEL_IDS:
            raise ValueError(
                "Unknown MODEL_NAME {!r}; expected a name containing 'gpt4' or one of {}".format(
                    MODEL_NAME, sorted(_OPEN_MODEL_IDS)
                )
            )
        # Imported here because the module-level import of utils_open is
        # commented out in this file.
        from utils_open import OpenModels
        open_models = OpenModels(_OPEN_MODEL_IDS[MODEL_NAME])

    def complete(prompt_text, max_tokens, **open_model_kwargs):
        # Keyword arguments are only forwarded to open models; the Azure call
        # is followed by a fixed two-second pause.
        if use_azure:
            completion = azure_models.get_completion(prompt_text, max_tokens)
            time.sleep(2)
            return completion
        return open_models.open_completion(prompt_text, max_tokens, **open_model_kwargs)

    if method == 'bm25':
        retrieve_dict = read_from_file("bm25_10_icl_samples_50_holdout_samples.json")
    elif method == "gat":
        retrieve_dict = read_from_file("gat_20_icl_samples_50_holdout_samples.json")
    else:
        retrieve_dict = None

    # Default in-context examples (used by method == 'fixed').
    icl_train_indices = [0, 1]

    if code:
        final_dict = {}
        for i in tqdm(range(len(test_list))):
            if len(test_list[i]['ref_abstract']) <= 1:
                continue

            prompt_parts = []
            if num_shots > 0:
                icl_train_indices = _select_icl_indices(
                    method, num_shots, i, train_list, retrieve_dict, icl_train_indices
                )
                for example_no, icl_train_idx in enumerate(icl_train_indices, start=1):
                    _, icl_prompt, icl_code_str = get_prompt(train_list, icl_train_idx, _CODE_PROMPT)
                    prompt_parts.append(
                        "##Example {}:\n\n".format(example_no)
                        + instruction_template + icl_prompt + _HIER_CLUSTER_PROMPT
                        + icl_code_str + "\n\n"
                    )
                prompt_parts.append("##Example {}:\n\n".format(num_shots + 1))

            _, prompt, _ = get_prompt(test_list, i, _CODE_PROMPT)
            prompt_parts.append(
                instruction_template + prompt + _HIER_CLUSTER_PROMPT
                + "        # only generate the code that comes after this, as if you are on autocomplete mode\n"
                + response_template
            )

            # Open models stop when they begin the header of a further example.
            final_dict[i] = complete(
                "".join(prompt_parts), 500, stop_token="##Example {}".format(num_shots + 2)
            )
        return final_dict

    organizing = organize_out is not None
    prompt_template = _ORGANIZE_PROMPT if organizing else _RETRIEVAL_PROMPT
    max_tokens = 300 if organizing else 1000
    pred_dict = {}

    for i in tqdm(range(len(test_list))):
        sample = test_list[i]
        references = sample['ref_abstract']
        if len(references) <= 1:
            continue

        prompt_parts = []
        if num_shots > 0:
            icl_train_indices = _select_icl_indices(
                method, num_shots, i, train_list, retrieve_dict, icl_train_indices
            )
            for example_no, icl_train_idx in enumerate(icl_train_indices, start=1):
                example = train_list[icl_train_idx]
                example_refs = example['ref_abstract']
                example_summary = example['related_work']

                example_extra_info = ""
                if organizing:
                    example_extra_info = code_to_extra_info(
                        get_code_str(example_summary, example_refs)
                    )

                prompt_parts.append("##Example {}:\n\n".format(example_no))
                prompt_parts.append(prompt_template.format(
                    articles=_format_reference_lines(example_refs, organizing),
                    intent=example['abstract'],
                    tags=" ".join(example_refs),
                    extra_info=example_extra_info,
                    response_prefix=response_template,
                ))
                prompt_parts.append(example_summary.strip() + "\n\n")
            prompt_parts.append("##Example {}:\n\n".format(num_shots + 1))

        test_extra_info = code_to_extra_info(organize_out[str(i)]) if organizing else ""

        prompt_parts.append(prompt_template.format(
            articles=_format_reference_lines(references, organizing),
            intent=sample['abstract'],
            tags=" ".join(references),
            extra_info=test_extra_info,
            response_prefix=response_template,
        ))

        completion = complete("".join(prompt_parts), max_tokens, temperature=0.7)

        if organizing:
            pred_dict[i] = completion
        else:
            # The summaries replace the original reference abstracts in place.
            sample["ref_abstract"] = preprocess_retrieved_out(references, completion)

    return pred_dict if organizing else test_list