import torch
import numpy as np
from transformers import AutoTokenizer
from bartpho.preprocess import tokenize, normalize


tag_dict = {
    "RESTAURANT#GENERAL": "chung về nhà_hàng",
    "RESTAURANT#PRICES": "giá của nhà_hàng",
    "RESTAURANT#MISCELLANEOUS": "tổng_quát về nhà_hàng",
    "FOOD#PRICES": "giá đồ ăn",
    "FOOD#QUALITY": "chất_lượng đồ ăn",
    "FOOD#STYLE&OPTIONS": "phong_cách và lựa_chọn đồ ăn",
    "DRINKS#PRICES": "giá đồ uống",
    "DRINKS#QUALITY": "chất_lượng đồ uống",
    "DRINKS#STYLE&OPTIONS": "phong_cách và lựa_chọn đồ uống",
    "AMBIENCE#GENERAL": "bầu không_khí",
    "SERVICE#GENERAL": "dịch_vụ",
    "LOCATION#GENERAL": "vị_trí",
}

polarity_dict = {
    "không có": "không có",
    "positive": "tích_cực",
    "neutral": "trung_lập",
    "negative": "tiêu_cực"
}

polarity_list = ["không có", "tích_cực", "trung_lập", "tiêu_cực"]
tags = ["chung về nhà_hàng", "giá của nhà_hàng", "tổng_quát về nhà_hàng", "giá đồ ăn",
        "chất_lượng đồ ăn", "phong_cách và lựa_chọn đồ ăn", "giá đồ uống", "chất_lượng đồ uống",
        "phong_cách và lựa_chọn đồ uống", "bầu không_khí", "dịch_vụ", "vị_trí"]
eng_tags = ["RESTAURANT#GENERAL", "RESTAURANT#PRICES", "RESTAURANT#MISCELLANEOUS", "FOOD#PRICES",
            "FOOD#QUALITY", "FOOD#STYLE&OPTIONS", "DRINKS#PRICES", "DRINKS#QUALITY",
            "DRINKS#STYLE&OPTIONS", "AMBIENCE#GENERAL", "SERVICE#GENERAL", "LOCATION#GENERAL"]
eng_polarity = ["không có", "positive", "neutral", "negative"]
detect_labels = ['không', 'có']
no_polarity = len(polarity_list)
no_tag = len(tags)
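
# Each candidate label is scored as a decoder target sentence of the form
# "Nhận_xét <tag> <polarity> ." (roughly "Comment: <aspect> <polarity> ."),
# e.g. "Nhận_xét chất_lượng đồ ăn tích_cực ." for FOOD#QUALITY / positive;
# the polarity whose sentence gets the highest sequence likelihood wins.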

def predict(model, text, tokenizer, model_tokenize=None, processed=True, printout=False):
    """Score every (tag, polarity) template and return one polarity index per tag (0 = not mentioned)."""
    predicts = []
    device = 'cpu'
    model.to(device)
    model.eval()
    model.config.use_cache = False
    
    if not processed:
        text = normalize(text)
        text = tokenize(text, model_tokenize)
        
    for i in range(no_tag):
        tag = tags[i]
        score_list = []

        # Encode the review once per candidate polarity; each decoder target is
        # "Nhận_xét <tag> <polarity> ." for one polarity.
        input_ids = tokenizer([text] * no_polarity, return_tensors='pt')['input_ids'].to(device)
        target_list = ["Nhận_xét " + tag.lower() + " " + polarity.lower() + " ." for polarity in polarity_list]
        output_ids = tokenizer(target_list, return_tensors='pt', padding=True, truncation=True)['input_ids'].to(device)

        with torch.no_grad():
            output = model(input_ids=input_ids, decoder_input_ids=output_ids)[0]
            logits = output.softmax(dim=-1).cpu().numpy()
        for m in range(no_polarity):
            # Sum log-probabilities of the target tokens; decoder position n
            # predicts token n + 1, so skip BOS and stop before the final token.
            token_ids = output_ids[m].cpu().numpy()
            score = np.sum(np.log(logits[m][range(len(token_ids) - 2), token_ids[1:-1]]))
            score_list.append(score)
        predict = int(np.argmax(score_list))  # cast np.int64 to a plain int
        predicts.append(predict)
        
    if printout:
        # Map non-zero predictions back to readable tag/polarity names,
        # skipping tags predicted as "không có" (index 0).
        result = {tags[i]: polarity_list[predicts[i]]
                  for i in range(no_tag) if predicts[i] != 0}
        print(result)
    return predicts
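
# Minimal usage sketch (hedged: the checkpoint path below is hypothetical;
# any BARTpho seq2seq checkpoint fine-tuned on these templates should work):
#
#   from transformers import AutoModelForSeq2SeqLM
#   tokenizer = AutoTokenizer.from_pretrained('vinai/bartpho-word-base')
#   model = AutoModelForSeq2SeqLM.from_pretrained('path/to/finetuned-bartpho')
#   preds = predict(model, 'Đồ ăn ngon nhưng dịch_vụ hơi chậm .', tokenizer, printout=True)
#   # preds -> list of 12 polarity indices, one per entry in `tags`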

def predict_df(model, df, tokenizer=None, model_tokenize=None, tokenizer_name='vinai/bartpho-word-base', processed=True, printout=True):
    """Evaluate `predict` over a DataFrame with a 'text' column and one column per tag in `eng_tags`."""
    model.eval()
    device = 'cpu'
    model.to(device)
    model.config.use_cache = False
    if not tokenizer:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    
    count_acc = count_detect = f1_detect = f1_absa = pre_detect = rec_detect = pre_absa = rec_absa = 0
    total_f1 = len(df)
    total = len(df) * no_tag
    
    for i in range(total_f1):
        text = df['text'][i]
        labels = [df[x][i] for x in eng_tags]
        predicts = predict(model, text, tokenizer, model_tokenize, processed)
        
        labels_detect = [i for i in range(no_tag) if labels[i] != 0]
        predicts_detect = [i for i in range(no_tag) if predicts[i] != 0]
        common_detect = [x for x in labels_detect if x in predicts_detect]
        
        if common_detect:
            precision_detect = len(common_detect) / len(predicts_detect)
            recall_detect = len(common_detect) / len(labels_detect)
            f1_detect += (2 * precision_detect * recall_detect / (precision_detect + recall_detect))
            pre_detect += precision_detect
            rec_detect += recall_detect

            labels_absa = [str(i) + '-' + str(labels[i]) for i in range(no_tag) if labels[i] != 0]
            predicts_absa = [str(i) + '-' + str(predicts[i]) for i in range(no_tag) if predicts[i] != 0]
            common_absa = [x for x in labels_absa if x in predicts_absa]
            
            if common_absa:
                precision_absa = len(common_absa) / len(predicts_absa)
                recall_absa = len(common_absa) / len(labels_absa)
                f1_absa += (2 * precision_absa * recall_absa / (precision_absa + recall_absa))
                pre_absa += precision_absa
                rec_absa += recall_absa
                
        for j in range(no_tag):
            if labels[j] == predicts[j]:
                count_acc += 1
                count_detect += 1
            else:
                if labels[j] != 0 and predicts[j] != 0:
                    count_detect += 1
    
    acc_detect = count_detect / total
    pre_detect = pre_detect / total_f1
    rec_detect = rec_detect / total_f1
    f1_detect = f1_detect / total_f1
    
    acc = count_acc / total
    pre_absa = pre_absa / total_f1
    rec_absa = rec_absa / total_f1
    f1_absa = f1_absa / total_f1
    
    if printout:
        print(f"Detect acc: {acc_detect:.4f}%")
        print(f"Detect precision: {pre_detect:.4f}%")
        print(f"Detect recall: {rec_detect:.4f}%")
        print(f"Detect f1: {f1_detect:.4f}%")
        print()
        print(f"Absa acc: {acc:.4f}%")
        print(f"Absa precision: {pre_absa:.4f}%")
        print(f"Absa recall: {rec_absa:.4f}%")
        print(f"Absa f1: {f1_absa:.4f}%")
    
    return acc_detect, pre_detect, rec_detect, f1_detect, acc, pre_absa, rec_absa, f1_absa
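
# Usage sketch (the CSV path is hypothetical): `df` needs a 'text' column plus
# one column per tag in `eng_tags`, each holding an int in {0..3} that indexes
# polarity_list:
#
#   import pandas as pd
#   test_df = pd.read_csv('path/to/test.csv')
#   acc_detect, pre_d, rec_d, f1_d, acc, pre_a, rec_a, f1_a = predict_df(
#       model, test_df, tokenizer=tokenizer)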

def predict_detect(model, text, tokenizer, model_tokenize=None, processed=True, printout=False):
    """Two-stage prediction: first detect which tags are mentioned, then classify polarity for detected tags only."""
    detect_predicts = []
    device = 'cpu'
    model.to(device)
    model.eval()
    model.config.use_cache = False
    
    if not processed:
        text = normalize(text)
        text = tokenize(text, model_tokenize)
        
    for i in range(no_tag):
        tag = tags[i]
        detect_score_list = []
        # Stage 1: detection. Score the two targets
        # "<tag> không/có được nhận_xét ." ("<tag> is not / is commented on").
        input_ids = tokenizer([text] * 2, return_tensors='pt')['input_ids']
        target_list = [tag.lower() + " " + detect_label.lower() + " được nhận_xét ." for detect_label in detect_labels]
        output_ids = tokenizer(target_list, return_tensors='pt', padding=True, truncation=True)['input_ids']

        with torch.no_grad():
            output = model(input_ids=input_ids.to(device), decoder_input_ids=output_ids.to(device))[0]
            logits = output.softmax(dim=-1).cpu().numpy()
        for m in range(2):
            # Product of target-token probabilities; the targets are short
            # enough that underflow is not a practical concern.
            detect_score = 1
            for n in range(logits[m].shape[0] - 2):
                detect_score *= logits[m][n][output_ids[m][n + 1]]
            detect_score_list.append(detect_score)
        detect_predict = int(np.argmax(detect_score_list))
        detect_predicts.append(detect_predict)
        
    predicts = []
    for i in range(no_tag):
        if detect_predicts[i] == 0:
            predicts.append(0)
        else:
            tag = tags[i]
            score_list = []
            # Stage 2: polarity. "không có" is excluded since the tag was
            # already detected, so only the three real polarities are scored.
            input_ids = tokenizer([text] * (no_polarity - 1), return_tensors='pt')['input_ids']
            target_list = ["Nhận_xét " + tag.lower() + " " + polarity.lower() + " ." for polarity in polarity_list if polarity != "không có"]
            output_ids = tokenizer(target_list, return_tensors='pt', padding=True, truncation=True)['input_ids']

            with torch.no_grad():
                output = model(input_ids=input_ids.to(device), decoder_input_ids=output_ids.to(device))[0]
                logits = output.softmax(dim=-1).cpu().numpy()
            for m in range(no_polarity - 1):
                score = 1
                for n in range(logits[m].shape[0] - 2):
                    score *= logits[m][n][output_ids[m][n + 1]]
                score_list.append(score)
            predict = int(np.argmax(score_list)) + 1  # shift past "không có" (index 0)
            predicts.append(predict)

    if printout:
        result = {}
        for i in range(no_tag):
            if predicts[i] != 0:
                result[eng_tags[i]] = eng_polarity[predicts[i]]
        print(result)
    return predicts
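
# The two-stage variant replaces one 4-way comparison per tag with a 2-way
# detection pass plus a 3-way polarity pass. Quick sanity check (the input
# text is illustrative):
#
#   preds = predict_detect(model, 'Không_gian quán rất đẹp .', tokenizer, printout=True)
#   # e.g. prints {'AMBIENCE#GENERAL': 'positive'}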

def predict_df_detect(model, df, tokenizer=None, model_tokenize=None, tokenizer_name='vinai/bartpho-word-base', processed=True, printout=True):
    """Evaluate `predict_detect` over a DataFrame; same metrics and column layout as `predict_df`."""
    model.eval()
    device = 'cpu'
    model.to(device)
    model.config.use_cache = False
    if not tokenizer:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    count_acc = count_detect = f1_detect = f1_absa = pre_detect = rec_detect = pre_absa = rec_absa = 0
    
    total_f1 = len(df)
    total = len(df) * no_tag
    
    for i in range(total_f1):
        text = df['text'][i]
        labels = [df[x][i] for x in eng_tags]
        predicts = predict_detect(model, text, tokenizer, model_tokenize, processed)
        
        labels_detect = [i for i in range(no_tag) if labels[i] != 0]
        predicts_detect = [i for i in range(no_tag) if predicts[i] != 0]
        common_detect = [x for x in labels_detect if x in predicts_detect]
        if common_detect:
            precision_detect = len(common_detect) / len(predicts_detect)
            recall_detect = len(common_detect) / len(labels_detect)
            f1_detect += (2 * precision_detect * recall_detect / (precision_detect + recall_detect))
            pre_detect += precision_detect
            rec_detect += recall_detect
            
            labels_absa = [str(i) + '-' + str(labels[i]) for i in range(no_tag) if labels[i] != 0]
            predicts_absa = [str(i) + '-' + str(predicts[i]) for i in range(no_tag) if predicts[i] != 0]
            common_absa = [x for x in labels_absa if x in predicts_absa]
            if common_absa:
                precision_absa = len(common_absa) / len(predicts_absa)
                recall_absa = len(common_absa) / len(labels_absa)
                f1_absa += (2 * precision_absa * recall_absa / (precision_absa + recall_absa))
                pre_absa += precision_absa
                rec_absa += recall_absa

        for j in range(no_tag):
            if labels[j] == predicts[j]:
                count_acc += 1
                count_detect += 1
            else:
                if labels[j] != 0 and predicts[j] != 0:
                    count_detect += 1

    acc_detect = count_detect / total
    pre_detect = pre_detect / total_f1
    rec_detect = rec_detect / total_f1
    f1_detect = f1_detect / total_f1
    
    acc = count_acc / total
    pre_absa = pre_absa / total_f1
    rec_absa = rec_absa / total_f1
    f1_absa = f1_absa / total_f1
    
    if printout:
        print(f"Detect acc: {acc_detect:.4f}%")
        print(f"Detect precision: {pre_detect:.4f}%")
        print(f"Detect recall: {rec_detect:.4f}%")
        print(f"Detect f1: {f1_detect:.4f}%")
        print()
        print(f"Absa acc: {acc:.4f}%")
        print(f"Absa precision: {pre_absa:.4f}%")
        print(f"Absa recall: {rec_absa:.4f}%")
        print(f"Absa f1: {f1_absa:.4f}%")
    
    return acc_detect, pre_detect, rec_detect, f1_detect, acc, pre_absa, rec_absa, f1_absa
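

# Hedged end-to-end sketch, assuming a fine-tuned checkpoint and a test CSV
# with a 'text' column plus the 12 `eng_tags` columns (both paths are
# hypothetical):
if __name__ == '__main__':
    import pandas as pd
    from transformers import AutoModelForSeq2SeqLM

    tokenizer = AutoTokenizer.from_pretrained('vinai/bartpho-word-base')
    model = AutoModelForSeq2SeqLM.from_pretrained('path/to/finetuned-bartpho')
    test_df = pd.read_csv('path/to/test.csv')

    # Single-stage scoring over all four polarities per tag...
    predict_df(model, test_df, tokenizer=tokenizer)
    # ...and the two-stage detect-then-classify variant.
    predict_df_detect(model, test_df, tokenizer=tokenizer)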