File size: 4,895 Bytes
6f86460
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from arabert.preprocess import ArabertPreprocessor
import unicodedata
import arabic_reshaper
from bidi.algorithm import get_display
import torch
import random
import re
import gradio as gr

# Tokenizers: tokenizer1 pairs with the subjective question-generation model;
# tokenizer2 is the base mT5 tokenizer used with the MCQ/distractor model.
tokenizer1 = AutoTokenizer.from_pretrained("Reham721/Subjective_QG")
tokenizer2 = AutoTokenizer.from_pretrained("google/mt5-base")

# model1 generates open-ended (subjective) questions; model2 generates MCQ distractors.
model1 = AutoModelForSeq2SeqLM.from_pretrained("Reham721/Subjective_QG")
model2 = AutoModelForSeq2SeqLM.from_pretrained("Reham721/MCQs_QG")

# Arabic text preprocessor applied to the context before QA scoring, and an
# Arabic question-answering pipeline used to rank generated questions.
prep = ArabertPreprocessor("aubmindlab/araelectra-base-discriminator")
qa_pipe = pipeline("question-answering", model="wissamantoun/araelectra-base-artydiqa")

def generate_questions(model, tokenizer, input_sequence):
    """Generate candidate questions for *input_sequence* via beam search.

    Args:
        model: seq2seq model exposing a HuggingFace-style ``generate``.
        tokenizer: matching tokenizer providing ``encode``/``decode``.
        input_sequence: conditioning string (here "answer<sep>context").

    Returns:
        A list of 3 decoded question strings (the top beam candidates).
    """
    input_ids = tokenizer.encode(input_sequence, return_tensors='pt')
    # NOTE: the original passed temperature=1, which is ignored unless
    # do_sample=True and triggers warnings in recent transformers versions,
    # so it was removed. Decoding is pure beam search.
    outputs = model.generate(
        input_ids=input_ids,
        max_length=200,
        num_beams=3,
        no_repeat_ngram_size=3,
        early_stopping=True,
        num_return_sequences=3,
    )
    return [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

def get_sorted_questions(questions, context):
    """Score each question against *context* and rank by QA confidence.

    The context is normalized with the AraBERT preprocessor, then each
    question is scored by the extractive-QA pipeline. Questions the
    pipeline fails on get a score of 0 instead of aborting the batch.

    Returns:
        dict mapping question -> score, sorted by score descending.
    """
    scores = {}
    context = prep.preprocess(context)
    for question in questions:
        try:
            result = qa_pipe(question=question, context=context)
            scores[question] = result["score"]
        # Narrowed from a bare except: (which also swallowed SystemExit /
        # KeyboardInterrupt) while keeping the best-effort behavior.
        except Exception:
            scores[question] = 0
    return dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))

def is_arabic(text):
    """Return True if every alphabetic character in *text* is Arabic.

    Non-alphabetic characters (digits, punctuation, whitespace) are
    ignored, so a string like "مرحبا 123" still counts as Arabic. An
    empty string is vacuously Arabic.

    Fixes vs. the original:
    - ``unicodedata.name(char)`` raises ValueError for code points with
      no name; the default-argument form makes such characters simply
      fail the ARABIC check instead of crashing.
    - The arabic_reshaper / bidi pass was dropped: both are display-only
      transforms (presentation forms and ligatures keep names starting
      with 'ARABIC'), so they cannot change the result of this check.
    """
    for char in text:
        if char.isalpha() and not unicodedata.name(char, '').startswith('ARABIC'):
            return False
    return True

def generate_distractors(question, answer, context, num_distractors=3, k=10, max_attempts=5):
    """Generate wrong-answer options (distractors) for an MCQ.

    Args:
        question: the question text.
        answer: the correct answer (excluded from the distractors).
        context: the source paragraph.
        num_distractors: how many distractors to return (at most).
        k: cap on the candidate pool before the final random pick.
        max_attempts: bound on extra sampling rounds (new, defaulted —
            the original looped forever when the model kept producing
            duplicates or non-Arabic text).

    Returns:
        A list of up to *num_distractors* unique Arabic distractors.
    """
    input_sequence = f'{question} <sep> {answer} <sep> {context}'
    input_ids = tokenizer2.encode(input_sequence, return_tensors='pt')

    unique_distractors = []

    def _sample_batch(n):
        # One round of nucleus/top-k sampling producing n candidate sequences.
        return model2.generate(
            input_ids,
            do_sample=True,
            max_length=50,
            top_k=50,
            top_p=0.95,
            num_return_sequences=n,
            no_repeat_ngram_size=2
        )

    def _collect(outputs):
        # Decode each sequence, split on <...> tags / literal "None",
        # strip residual tags, keep unique Arabic candidates != answer.
        for output in outputs:
            decoded_output = tokenizer2.decode(output, skip_special_tokens=True)
            elements = [re.sub(r'<[^>]*>', '', e.strip())
                        for e in re.split(r'(<[^>]*>)|(?:None)', decoded_output) if e]
            for e in elements:
                if e and is_arabic(e) and e != answer and e not in unique_distractors:
                    unique_distractors.append(e)

    _collect(_sample_batch(num_distractors))

    # Top up with bounded retries instead of the original unbounded while-loop.
    attempts = 0
    while len(unique_distractors) < num_distractors and attempts < max_attempts:
        _collect(_sample_batch(num_distractors - len(unique_distractors)))
        attempts += 1

    # Cap the pool at k (the original shuffled via sorted(key=random.random);
    # random.sample is the idiomatic equivalent), then draw the final options.
    if len(unique_distractors) > k:
        unique_distractors = random.sample(unique_distractors, k)
    # min() guards against a short pool, where random.sample would raise.
    return random.sample(unique_distractors, min(num_distractors, len(unique_distractors)))

# Gradio UI components; labels and placeholders are Arabic.
# context: the source paragraph; answer: the gold answer to generate a
# question for; question_type: "essay question" vs "multiple-choice question";
# question: the generated output box.
context = gr.Textbox(lines=5, placeholder="أدخل الفقرة هنا", label="النص")
answer = gr.Textbox(lines=3, placeholder="أدخل الإجابة هنا", label="الإجابة")
question_type = gr.Radio(choices=["سؤال مقالي", "سؤال اختيار من متعدد"], label="نوع السؤال")
question = gr.Textbox(type="text", label="السؤال الناتج")

def generate_question(context, answer, question_type):
    """Gradio callback: return either a subjective question or a full MCQ.

    Builds the "answer<sep>context" prompt, generates candidate questions,
    ranks them by QA confidence, and — for the multiple-choice type —
    appends shuffled distractor options below the question text.
    """
    model_input = answer + "<sep>" + context
    candidates = generate_questions(model1, tokenizer1, model_input)
    ranked = get_sorted_questions(candidates, context)
    # Highest-scoring question, or an Arabic fallback message when ranking is empty.
    best_question = next(iter(ranked)) if ranked else "لم يتم توليد سؤال مناسب"

    if question_type == "سؤال مقالي":
        return best_question

    # MCQ: mix the correct answer in with the generated distractors.
    options = generate_distractors(best_question, answer, context)
    options.append(answer)
    random.shuffle(options)
    option_lines = ["- " + opt for opt in options]
    return best_question + "\n" + "\n".join(option_lines)

# Wire the callback to the UI components and start the app.
iface = gr.Interface(
    fn=generate_question,
    inputs=[context, answer, question_type],
    outputs=question
)

# debug=True surfaces tracebacks in the console; share=False keeps the app local.
iface.launch(debug=True, share=False)