import os
# from tree_sitter import Language, Parser
# # import pandas as pd
# import openpyxl
import json
import time
import csv
import pathlib
import difflib
import re
from bleu import _bleu
from fuzzywuzzy import fuzz
import random
import numpy as np
from transformers import RobertaTokenizer
#tokens = nltk.word_tokenize(sentence)
import argparse

# Command-line interface: --task selects the Code_Completion sub-dataset.
parser = argparse.ArgumentParser(description='Test')
parser.add_argument(
    "--task",
    default=None,
    type=str,
    required=True,
    help="Task Type: statement_level, next_statement",
)
args = parser.parse_args()


# All paths are anchored at this script's own directory, so the script does
# not depend on the current working directory.
folder = str(pathlib.Path(__file__).parent.resolve())
isa_type_dir = f"{folder}/../../../Dataset"
src_dir = f"{isa_type_dir}/Code_Completion/{args.task}"
dst_dir = folder

# Module-level lists; not populated anywhere in the visible code.
train_lis, valid_lis, test_lis = [], [], []

# Filled by get_target_clf_list(): "<col0> <col2>" -> list of col1 values.
target_clf = {}
def get_target_clf_list():
    """Populate the global ``target_clf`` from ``comback_isa_type.csv``.

    The CSV is read from ``isa_type_dir``. Rows whose second column is
    ``arc``, ``riscv`` or ``nvptx`` (case-insensitive) are skipped. For every
    other row, the second column is appended to the list stored under the key
    ``"<first column> <third column>"``.

    Entries accumulate across calls; the dict is never cleared here, so
    calling this twice duplicates every list entry.
    """
    global target_clf
    skipped_targets = ("arc", "riscv", "nvptx")
    csv_path = isa_type_dir + "/comback_isa_type.csv"
    with open(csv_path, "r", encoding="utf-8") as handle:
        for row in csv.reader(handle):
            target_name = row[1]
            if target_name.lower() in skipped_targets:
                continue
            group_key = row[0] + " " + row[2]
            target_clf.setdefault(group_key, []).append(target_name)




# Targets scored by Calculate_Completion, as (ISA category, result label).
# Tuple order is the order rows are written to result.csv and the order
# averages are reported.
_COMPLETION_TARGETS = (("GPU", "NVPTX"), ("MPU", "ARC"), ("CPU", "RISCV"))


def _target_stem(comp_type, label):
    """Return the dataset file stem / prediction-key prefix for a target.

    GCC uses the lower-case name (``GCC/riscv.jsonl``, keys like
    ``"riscv 0"``); LLVM uses the upper-case name (``LLVM/RISCV.jsonl``,
    keys like ``"RISCV 0"``).
    """
    return label.lower() if comp_type == "GCC" else label


def _load_ground_truth(jsonl_path):
    """Return the ground-truth code of every line of a JSONL file, in order.

    Each line is parsed as a JSON object, and its ``"ground_truth"`` value
    is joined with single spaces. The value is presumably a list of tokens;
    confirm this against the dataset.
    """
    with open(jsonl_path, "r", encoding="utf-8") as f:
        return [" ".join(json.loads(line)["ground_truth"]) for line in f]


def _score_target(comp_type, label, predictions):
    """Score the predictions of one compiler/target pair.

    Writes one row per predicted sample, plus one ``"average"`` row, to
    ``<dst_dir>/result.csv`` (append mode). A sample with no prediction has
    its key printed to stdout and contributes 0 to both averages, while still
    counting in the denominator.

    Returns:
        ``[exact_match_pct, edit_similarity]`` as strings rounded to 2
        decimals, or ``None`` when the ground-truth file has no samples.
    """
    stem = _target_stem(comp_type, label)
    jsonl_path = f"{src_dir}/{comp_type}/{stem}.jsonl"
    references = _load_ground_truth(jsonl_path)
    if not references:
        # Nothing to average over; report it instead of dividing by zero.
        print(f"No ground-truth samples in {jsonl_path}; skipping {comp_type} {label}")
        return None

    rows = []
    total_em = 0.0
    total_ed = 0.0
    for idx, src_code in enumerate(references):
        key = f"{stem} {idx}"
        if key not in predictions:
            print(key)
            continue
        # Compare with every space removed so spacing differences do not count.
        pred = predictions[key].replace(" ", "")
        ref = src_code.replace(" ", "")
        em = 1.0 if pred == ref else 0.0
        edit_sim = fuzz.ratio(pred, ref)
        total_em += em
        total_ed += edit_sim
        rows.append([comp_type, label, str(idx),
                     str(round(em * 100, 2)), str(round(float(edit_sim), 2))])

    count = len(references)
    average = [str(round((total_em / count) * 100, 2)),
               str(round(float(total_ed / count), 2))]
    rows.append([comp_type, label, "average", *average])

    # Open the result file once per target rather than once per sample.
    with open(dst_dir + "/result.csv", "a", newline="") as file:
        csv.writer(file).writerows(rows)
    return average


def Calculate_Completion():
    """Score the cleaned ChatGPT completions against the ground truth.

    Predictions are read from ``<dst_dir>/Input/chatgpt_next_output_cleaned.csv``
    when ``args.task == "next_statement"``, otherwise from
    ``<dst_dir>/Input/chatgpt_stmt_output_cleaned.csv``. Each CSV row is
    ``(compiler, target, index, code)``. Rows whose compiler is not ``"GCC"``
    are treated as LLVM.

    Every target in ``_COMPLETION_TARGETS`` is scored for GCC and then for
    LLVM. Per-sample and average rows are appended to ``<dst_dir>/result.csv``.

    Returns:
        dict mapping ``"<GCC|LLVM> <TARGET>"`` to
        ``[exact_match_pct, avg_fuzz_ratio]`` (strings, 2 decimals), in
        evaluation order.
    """
    get_target_clf_list()  # fills the global target_clf (not used below)
    print("############## Exp 2: Calculate ChatGPT Stmt Completion ################\n")

    if args.task == "next_statement":
        dst_file = dst_dir + "/Input/chatgpt_next_output_cleaned.csv"
    else:
        dst_file = dst_dir + "/Input/chatgpt_stmt_output_cleaned.csv"

    predictions = {"GCC": {}, "LLVM": {}}
    with open(dst_file, encoding="utf-8") as f:
        for row in csv.reader(f):
            compiler = "GCC" if row[0] == "GCC" else "LLVM"
            predictions[compiler][row[1] + " " + str(row[2])] = row[3]

    avg_accuracy = {}
    for comp_type in ("GCC", "LLVM"):
        for _isa_type, label in _COMPLETION_TARGETS:
            average = _score_target(comp_type, label, predictions[comp_type])
            if average is not None:
                avg_accuracy[comp_type + " " + label] = average
    return avg_accuracy




if __name__ == "__main__":
    # Start a fresh result.csv containing only the header row;
    # Calculate_Completion() then appends per-sample and "average" rows.
    with open(dst_dir + '/result.csv', 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["Compiler Type", "Target", "Idx", "Exact Match", "Edit Distance"])

    avg_dic = Calculate_Completion()

    # Print each "<compiler> <TARGET>" key followed by its two averages.
    for k, scores in avg_dic.items():
        print("########################")

        print(k)
        print(" ".join(["Exact Match", "Edit Distance"]))
        print(" ".join(scores))