import logging
import random
import re
import unicodedata

import arabic_reshaper
import gradio as gr
import torch
from arabert.preprocess import ArabertPreprocessor
from bidi.algorithm import get_display
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Question-generation checkpoint: tokenizer1/model1 are used by generate_question
# for every request (both essay and multiple-choice modes). Loaded once at import time.
tokenizer1 = AutoTokenizer.from_pretrained("Reham721/Subjective_QG")
# Tokenizer paired with model2 in generate_distractors.
# NOTE(review): this is the base google/mt5-base tokenizer, not one loaded from
# "Reham721/MCQs_QG"; presumably that checkpoint was fine-tuned from mt5-base and
# shares its vocabulary — confirm.
tokenizer2 = AutoTokenizer.from_pretrained("google/mt5-base")

model1 = AutoModelForSeq2SeqLM.from_pretrained("Reham721/Subjective_QG")
# Sampled (do_sample=True) in generate_distractors to produce wrong answer options.
model2 = AutoModelForSeq2SeqLM.from_pretrained("Reham721/MCQs_QG")

# Arabert preprocessor; get_sorted_questions applies it to the context before QA scoring.
prep = ArabertPreprocessor("aubmindlab/araelectra-base-discriminator")
# Extractive QA pipeline; its confidence "score" is used to rank generated questions.
qa_pipe = pipeline("question-answering", model="wissamantoun/araelectra-base-artydiqa")

def generate_questions(model, tokenizer, input_sequence, num_return_sequences=3,
                       num_beams=3, max_length=200):
    """Generate candidate questions for ``input_sequence`` with beam search.

    Args:
        model: Seq2seq model exposing a Hugging Face style ``generate`` method.
        tokenizer: Matching tokenizer (callable, plus ``decode``).
        input_sequence: Text fed to the model, e.g. ``answer + "<sep>" + context``.
        num_return_sequences: How many beams to return; must not exceed
            ``num_beams`` (``generate`` rejects that combination).
        num_beams: Beam width.
        max_length: Maximum generated length, in tokens.

    Returns:
        list[str]: One decoded string per returned beam, special tokens removed.
    """
    # Calling the tokenizer (rather than ``encode``) also yields the attention
    # mask, so generate() receives it explicitly instead of inferring it.
    encoded = tokenizer(input_sequence, return_tensors='pt')
    outputs = model.generate(
        input_ids=encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        max_length=max_length,
        num_beams=num_beams,
        no_repeat_ngram_size=3,
        early_stopping=True,
        # No ``temperature``: it only affects sampling, and this is pure beam search.
        num_return_sequences=num_return_sequences,
    )
    return [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

def get_sorted_questions(questions, context, qa=None, preprocessor=None):
    """Rank candidate questions by the QA model's confidence in answering them.

    The context is preprocessed once. Each question is then run through the
    extractive QA callable against it, and the returned ``score`` is the
    ranking key.

    Args:
        questions: Iterable of candidate question strings.
        context: Raw passage the questions were generated from.
        qa: Optional callable with the ``qa(question=..., context=...)`` ->
            ``{"score": ...}`` interface. Defaults to the module-level ``qa_pipe``.
        preprocessor: Optional object with ``preprocess(str) -> str``.
            Defaults to the module-level ``prep``.

    Returns:
        dict: Each distinct question mapped to its score, ordered from highest
        to lowest score (ties keep their input order). A question whose QA
        call failed gets score 0.
    """
    qa = qa_pipe if qa is None else qa
    preprocessor = prep if preprocessor is None else preprocessor
    processed_context = preprocessor.preprocess(context)
    scores = {}
    for question in questions:
        try:
            scores[question] = qa(question=question, context=processed_context)["score"]
        except Exception:
            # Best-effort ranking: one failing question must not abort the request.
            # Unlike a bare except, this lets KeyboardInterrupt/SystemExit propagate.
            logging.getLogger(__name__).warning(
                "QA scoring failed for question %r", question, exc_info=True
            )
            scores[question] = 0
    return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))

def is_arabic(text):
    """Return True if every alphabetic character in ``text`` is an Arabic-script letter.

    Non-alphabetic characters (digits, whitespace, punctuation, combining
    diacritics) are ignored, so an empty or letter-free string counts as Arabic.
    A letter whose Unicode name is unavailable is treated as non-Arabic instead
    of raising.
    """
    # No reshaping / bidi reordering is needed here. Reshaping only swaps letters
    # for their presentation forms, whose names also start with "ARABIC", and
    # reordering does not affect a per-character test.
    return all(
        unicodedata.name(char, '').startswith('ARABIC')
        for char in text
        if char.isalpha()
    )

# Splits a decoded output on angle-bracket tags (captured) or the literal word "None".
_DISTRACTOR_SPLIT_RE = re.compile(r'(<[^>]*>)|(?:None)')
# Removes any remaining angle-bracket tag text from a piece.
_TAG_RE = re.compile(r'<[^>]*>')


def _sample_distractor_texts(input_ids, count):
    """Sample ``count`` sequences from the distractor model and decode them to text."""
    outputs = model2.generate(
        input_ids,
        do_sample=True,
        max_length=50,
        top_k=50,
        top_p=0.95,
        num_return_sequences=count,
        no_repeat_ngram_size=2,
    )
    return [tokenizer2.decode(output, skip_special_tokens=True) for output in outputs]


def _extract_distractor_candidates(decoded_output):
    """Split one decoded output into cleaned, non-empty, Arabic-only candidate strings."""
    # Because of the capturing group, re.split also returns each matched tag, and
    # None where the "None" alternative matched instead. Falsy pieces are dropped,
    # and tag pieces collapse to "" after the tag-stripping substitution.
    pieces = (
        _TAG_RE.sub('', piece.strip())
        for piece in _DISTRACTOR_SPLIT_RE.split(decoded_output)
        if piece
    )
    return [piece for piece in pieces if piece and is_arabic(piece)]


def generate_distractors(question, answer, context, num_distractors=3, k=10, max_attempts=10):
    """Generate wrong answer options (distractors) for a multiple-choice question.

    Args:
        question: The question text.
        answer: The correct answer; it is never returned as a distractor.
        context: Source passage.
        num_distractors: Number of distractors wanted.
        k: Before the final sample, the candidate pool is randomly cut down to
            this size. It is never cut below ``num_distractors``.
        max_attempts: Maximum number of extra sampling rounds when the first
            round yields too few unique candidates.

    Returns:
        list[str]: Up to ``num_distractors`` distinct Arabic distractors in
        random order. The list is shorter only if the model could not produce
        enough unique candidates within ``max_attempts`` extra rounds.
    """
    if num_distractors <= 0:
        return []
    input_sequence = f'{question} <sep> {answer} <sep> {context}'
    input_ids = tokenizer2.encode(input_sequence, return_tensors='pt')
    # Candidates are whitespace-stripped, so compare them against a stripped answer.
    answer_text = answer.strip()

    unique_distractors = []
    request = num_distractors
    # One initial round plus at most max_attempts top-ups. The bound keeps a model
    # that only emits the answer, duplicates, or non-Arabic text from hanging the request.
    for _ in range(max_attempts + 1):
        for decoded_output in _sample_distractor_texts(input_ids, request):
            for candidate in _extract_distractor_candidates(decoded_output):
                if candidate != answer_text and candidate not in unique_distractors:
                    unique_distractors.append(candidate)
        if len(unique_distractors) >= num_distractors:
            break
        request = num_distractors - len(unique_distractors)

    # Never cap below num_distractors, otherwise the final sample could not be drawn.
    pool_cap = max(k, num_distractors)
    if len(unique_distractors) > pool_cap:
        unique_distractors = random.sample(unique_distractors, pool_cap)
    return random.sample(unique_distractors, min(num_distractors, len(unique_distractors)))

# Gradio widgets (UI text is Arabic):
#   context       - label "The text", placeholder "Enter the paragraph here" (input)
#   answer        - label "The answer", placeholder "Enter the answer here" (input)
#   question_type - label "Question type": essay question / multiple-choice question (input)
#   question      - label "The resulting question" (output)
context = gr.Textbox(lines=5, placeholder="أدخل الفقرة هنا", label="النص")
answer = gr.Textbox(lines=3, placeholder="أدخل الإجابة هنا", label="الإجابة")
# Preselect multiple choice. generate_question treats every value other than the
# essay option (including "nothing selected") as multiple choice, so this makes
# the effective default visible to the user.
question_type = gr.Radio(
    choices=["سؤال مقالي", "سؤال اختيار من متعدد"],
    value="سؤال اختيار من متعدد",
    label="نوع السؤال",
)
question = gr.Textbox(type="text", label="السؤال الناتج")

def generate_question(context, answer, question_type):
    """Gradio handler: build one question (plus options for multiple choice).

    Args:
        context: Source passage text.
        answer: Answer span the question should target.
        question_type: Radio value. The essay option returns only the question.
            Any other value (the multiple-choice option, or None when nothing
            is selected) also appends shuffled answer options.

    Returns:
        str: The best-ranked question. In multiple-choice mode it is followed
        by one "- option" line per choice (distractors plus the correct answer).
        If no usable question was generated, an Arabic "no suitable question
        was generated" message is returned instead.
    """
    article = answer + "<sep>" + context
    # Drop blank generations so an empty string can never be picked as the best question.
    candidates = [q for q in generate_questions(model1, tokenizer1, article) if q.strip()]
    ranked = get_sorted_questions(candidates, context)
    if not ranked:
        # Return the message directly rather than asking the distractor model
        # to build options around it.
        return "لم يتم توليد سؤال مناسب"
    best_question = next(iter(ranked))  # ranked is ordered by descending QA score
    if question_type == "سؤال مقالي":
        return best_question
    options = generate_distractors(best_question, answer, context)
    options.append(answer)
    random.shuffle(options)
    return best_question + "\n" + "\n".join("- " + opt for opt in options)

# Connect the handler to the input/output widgets defined earlier in this module.
iface = gr.Interface(
    fn=generate_question,
    inputs=[context, answer, question_type],
    outputs=question,
)

# Launch only when executed as a script. Importing the module (e.g. from tests
# or another app) then does not start a blocking web server.
if __name__ == "__main__":
    iface.launch(debug=True, share=False)