import gradio as gr
import html
import torch

from transformers import (
    AutoModel,
    AutoModelForQuestionAnswering,
    AutoModelForTokenClassification,
    AutoTokenizer,
    MBartForConditionalGeneration,
    pipeline,
)

from torch import nn
from underthesea import word_tokenize
 
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load multi-task models (BARTPho fine-tuned for sentiment analysis and EN<->VI translation)
bartpho_mt_base = MBartForConditionalGeneration.from_pretrained("mc0c0z/BARTPho-multi-task")
bartpho_mt_base_tokenizer = AutoTokenizer.from_pretrained("mc0c0z/BARTPho-multi-task")
bartpho_mt_base.to(device)

bartpho_mt = MBartForConditionalGeneration.from_pretrained("mc0c0z/BARTPho-Large-multi-task")
bartpho_mt_tokenizer = AutoTokenizer.from_pretrained("mc0c0z/BARTPho-Large-multi-task")
bartpho_mt.to(device)

def segmenter(text):
    text = html.unescape(text)
    tokens = word_tokenize(text)
    result = []
    for token in tokens:
        if ' ' in token:
            result.append(token.replace(' ', '_'))
        else:
            result.append(token)
    return result
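
# Illustrative example (the exact tokenization depends on the underthesea version):
#   segmenter("Tôi yêu Hà Nội") -> ['Tôi', 'yêu', 'Hà_Nội']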

class MultiTaskModel:
    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
    
    def get_prompt(self, task):
        if task == 'sa':
            return "Classify the sentiment: "
        elif task == 'mt-en-vi':
            return "Translate English to Vietnamese: "
        elif task == 'mt-vi-en':
            return "Translate Vietnamese to English: "
        else:
            return "" 
        
    def inference(self, task, sentence, device):
        # Preprocess the input sentence the same way as in the training CustomDataset (word segmentation)
        tokenized_text = segmenter(sentence)
        source = self.get_prompt(task) + " ".join(tokenized_text)
        
        # Tokenize input
        inputs = self.tokenizer(source, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
        
        # Move inputs to the target device
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        
        # Generate a prediction
        self.model.eval()
        with torch.no_grad():
            generated_output = self.model.generate(input_ids, attention_mask=attention_mask, max_length=128)
        
        # Decode the prediction
        prediction = self.tokenizer.decode(generated_output[0], skip_special_tokens=True)

        if task == 'sa':
            # The multi-task model generates the class id as text, e.g. '0' or '1'
            class_names = ["Negative", "Positive"]
            return class_names[int(prediction[0])]
        return html.unescape(prediction)
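
# Illustrative usage (hypothetical input, not executed at import time):
#   MultiTaskModel(bartpho_mt, bartpho_mt_tokenizer, device).inference(
#       'mt-en-vi', 'Good morning', device)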
    
# Load SA models
class CustomModel(nn.Module):
    def __init__(self, bert_model):
        super(CustomModel, self).__init__()
        self.bert = bert_model
        self.mlp = nn.Sequential(
            nn.Linear(768 * 5, 512),  # 768 hidden size x 5 concatenated [CLS] layers
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 3)  # 3 output classes: Negative, Positive, Neutral
        )
    
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        
        # Take the [CLS] embedding from each of the last 5 hidden layers
        last_hidden_states = outputs.hidden_states[-5:]
        cls_embeddings = torch.cat([state[:, 0, :] for state in last_hidden_states], dim=1)

        # Pass through the MLP head
        logits = self.mlp(cls_embeddings)
        return logits
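
# Design note: concatenating the [CLS] vector from the last 5 layers yields a
# 768 * 5 = 3840-dim feature, matching in_features of the first Linear layer.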
    
## PhoBERT
phobert_sa = AutoModel.from_pretrained("vinai/phobert-base", output_hidden_states=True)
phobert_sa_tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
phobert_sa = CustomModel(phobert_sa)
phobert_sa.load_state_dict(torch.load('phobert_sentiment_analysis.pth', map_location=device))
phobert_sa.to(device)

## PhoBERTv2
phobertv2_sa = AutoModel.from_pretrained("vinai/phobert-base-v2", output_hidden_states=True)
phobertv2_sa_tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")
phobertv2_sa = CustomModel(phobertv2_sa)
phobertv2_sa.load_state_dict(torch.load('phobertv2_sentiment_analysis.pth', map_location=device))
phobertv2_sa.to(device)

## Multilingual BERT
m_bert_sa = AutoModel.from_pretrained("google-bert/bert-base-multilingual-cased", output_hidden_states=True)
m_bert_sa_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-multilingual-cased")
m_bert_sa = CustomModel(m_bert_sa)
m_bert_sa.load_state_dict(torch.load('bert_model_sentiment_analysis.pth', map_location=device))
m_bert_sa.to(device)

# Load QA models

## XLM-RoBERTa-Large
roberta_large_qa = AutoModelForQuestionAnswering.from_pretrained("HungLV2512/Vietnamese-QA-fine-tuned")
roberta_large_qa_tokenizer = AutoTokenizer.from_pretrained("HungLV2512/Vietnamese-QA-fine-tuned")
roberta_large_qa.to(device)

## XLM-RoBERTa-Base
roberta_base_qa = AutoModelForQuestionAnswering.from_pretrained("HungLV2512/xlm-roberta-base-fine-tuned-qa-vietnamese")
roberta_base_qa_tokenizer = AutoTokenizer.from_pretrained("HungLV2512/xlm-roberta-base-fine-tuned-qa-vietnamese")
roberta_base_qa.to(device)

## Multilingual BERT
m_bert_qa = AutoModelForQuestionAnswering.from_pretrained("HungLV2512/bert-base-multilingual-cased-fine-tuned-qa-vietnamese")
m_bert_qa_tokenizer = AutoTokenizer.from_pretrained("HungLV2512/bert-base-multilingual-cased-fine-tuned-qa-vietnamese")
m_bert_qa.to(device)

# Load NER models
label_map = {
    'B-LOC': 0,
    'B-MISC': 1,
    'B-ORG': 2,
    'B-PER': 3,
    'I-LOC': 4,
    'I-MISC': 5,
    'I-ORG': 6,
    'I-PER': 7,
    'O': 8
}
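
# BIO tag set the fine-tuned NER checkpoints are assumed to use; ner_inference
# below inverts this map to decode predicted ids back to labels.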

## PhoBERT
phobert_ner = AutoModelForTokenClassification.from_pretrained("DrRinS/NER-PhoBERT", num_labels=len(label_map))
phobert_ner_tokenizer = AutoTokenizer.from_pretrained("DrRinS/NER-PhoBERT")
phobert_ner.to(device)

## PhoBERTv2
phobertv2_ner = AutoModelForTokenClassification.from_pretrained("DrRinS/NER-PhoBERTv2", num_labels=len(label_map))
phobertv2_ner_tokenizer = AutoTokenizer.from_pretrained("DrRinS/NER-PhoBERTv2")
phobertv2_ner.to(device)

## Multilingual BERT
m_bert_ner = AutoModelForTokenClassification.from_pretrained("DrRinS/NER_MultilingualBERT", num_labels=len(label_map))
m_bert_ner_tokenizer = AutoTokenizer.from_pretrained("DrRinS/NER_MultilingualBERT")
m_bert_ner.to(device)

# Inference functions
def sentiment_inference(model, tokenizer, text, device):
    # Segment the input text
    text = " ".join(segmenter(text))
    
    # Tokenize the segmented text
    inputs = tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )
    
    # Move inputs to the correct device
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    
    # Perform inference
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask)
        _, preds = torch.max(outputs, dim=1)
    
    # Map predictions to class names
    class_names = ["Negative", "Positive", "Neutral"]
    return class_names[preds.cpu().item()]
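
# Illustrative usage (hypothetical input; the actual prediction depends on the
# checkpoint):
#   sentiment_inference(phobert_sa, phobert_sa_tokenizer, "Món ăn rất ngon", device)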

def multitask_inference(model, tokenizer, text, task, device):
    multitask_model = MultiTaskModel(model, tokenizer, device)
    return multitask_model.inference(task, text, device)

def qa_inference(model, tokenizer, question, context, device):
    qa_pipeline = pipeline('question-answering', model=model, tokenizer=tokenizer, device=device)
    res = qa_pipeline(question=question, context=context)
    return res['answer']
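
# Note: the pipeline is rebuilt on every call, which keeps qa_inference
# stateless; a busier demo could cache one pipeline per model instead.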

def ner_inference(model, tokenizer, text, device):
    # Tokenize the input text
    inputs = tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )
    
    # Move inputs to the correct device
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    
    # Perform inference
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask)
        _, preds = torch.max(outputs.logits, dim=2)
    
    # Convert predictions to labels
    id_to_label = {v: k for k, v in label_map.items()}
    predictions = preds[attention_mask.bool()].cpu().numpy().flatten()
    labels = [id_to_label[p] for p in predictions]
    
    # Decode the input ids to tokens
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0], skip_special_tokens=True)
    
    # Drop the labels for the <s> and </s> special tokens so they line up with
    # the skip_special_tokens tokens above
    labels = labels[1:-1]
    # Combine tokens with their NER labels
    ner_tags = list(zip(tokens, labels))
    
    return ner_tags
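
# Illustrative output format (hypothetical tokens/labels; tokens are subword
# pieces from the tokenizer, not whole words):
#   [('Nguyễn', 'B-PER'), ('Văn', 'I-PER'), ('A', 'I-PER'), ('là', 'O'), ...]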

def process_input(input_text, context, task):
    results = {}
    
    if task == "Sentiment Analysis":
        results["PhoBERT"] = sentiment_inference(phobert_sa, phobert_sa_tokenizer, input_text, device)
        results["PhoBERTv2"] = sentiment_inference(phobertv2_sa, phobertv2_sa_tokenizer, input_text, device)
        results["Multilingual BERT"] = sentiment_inference(m_bert_sa, m_bert_sa_tokenizer, input_text, device)
        results["BARTPho Base"] = multitask_inference(bartpho_mt_base, bartpho_mt_base_tokenizer, input_text, "sa", device)
        results["BARTPho Large"] = multitask_inference(bartpho_mt, bartpho_mt_tokenizer, input_text, "sa", device)
    elif task == "English to Vietnamese":
        results["BARTPho Base"] = multitask_inference(bartpho_mt_base, bartpho_mt_base_tokenizer, input_text, "mt-en-vi", device)
        results["BARTPho Large"] = multitask_inference(bartpho_mt, bartpho_mt_tokenizer, input_text, "mt-en-vi", device)
    elif task == "Vietnamese to English":
        results["BARTPho Base"] = multitask_inference(bartpho_mt_base, bartpho_mt_base_tokenizer, input_text, "mt-vi-en", device)
        results["BARTPho Large"] = multitask_inference(bartpho_mt, bartpho_mt_tokenizer, input_text, "mt-vi-en", device)
    elif task == "Question Answering":
        results["RoBERTa Base"] = qa_inference(roberta_base_qa, roberta_base_qa_tokenizer, input_text, context, device)
        results["RoBERTa Large"] = qa_inference(roberta_large_qa, roberta_large_qa_tokenizer, input_text, context, device)
        results["Multilingual BERT"] = qa_inference(m_bert_qa, m_bert_qa_tokenizer, input_text, context, device)
    elif task == "Named Entity Recognition":
        results["PhoBERT"] = ner_inference(phobert_ner, phobert_ner_tokenizer, input_text, device)
        results["PhoBERTv2"] = ner_inference(phobertv2_ner, phobertv2_ner_tokenizer, input_text, device)
        results["Multilingual BERT"] = ner_inference(m_bert_ner, m_bert_ner_tokenizer, input_text, device)
    return results
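
# Illustrative smoke test (hypothetical input; uncomment to try without the UI):
#   print(process_input("Sản phẩm này rất tốt", "", "Sentiment Analysis"))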

with gr.Blocks() as iface:
    gr.Markdown("# Multi-task NLP Demo")
    gr.Markdown("Perform sentiment analysis, machine translation, question answering, or named entity recognition using various models.")
    
    with gr.Row():
        task = gr.Radio(["Sentiment Analysis", "Question Answering", "Named Entity Recognition", "English to Vietnamese", "Vietnamese to English"], label="Task")
    
    with gr.Row():
        input_text = gr.Textbox(label="Input Text")
        context = gr.Textbox(label="Context", visible=False)
    
    output = gr.JSON(label="Results")
    
    submit = gr.Button("Submit")
    
    def on_task_change(task):
        if task == "Question Answering":
            return {
                input_text: gr.update(label="Question", visible=True),
                context: gr.update(visible=True)
            }
        else:
            return {
                input_text: gr.update(label="Input Text", visible=True),
                context: gr.update(visible=False)
            }
    
    task.change(on_task_change, task, [input_text, context])
    
    submit.click(
        process_input,
        inputs=[input_text, context, task],
        outputs=output
    )

if __name__ == "__main__":
    iface.launch(share=True)